# 1. Import dependencies

In [None]:
# MSS used for screen capture in MacOS
from mss.darwin import MSS as mss
# Sending commands
import pyautogui
# OpenCV allows us to do frame processing
import cv2
# Transformational framework
import numpy as np
# OCR for game over extraction
import pytesseract
# Visualize capture frames
import matplotlib.pyplot as plt
# Bring in time for pauses
import time
# Environment components
from gym import Env
from gym.spaces import Box, Discrete

# 2. Build the environment

### 2.1 Create Environment

In [None]:
class WebGame(Env):
    """Gym environment that drives an on-screen game through screen capture.

    Observations are grayscale grabs of ``game_location`` resized to 90x40
    and reshaped channel-first to ``(1, 40, 90)``, matching the declared
    ``observation_space``. There are two discrete actions: 0 presses the
    "up" key and 1 does nothing. An episode is reported as done when OCR of
    the ``done_location`` region starts with one of the game-over strings.
    """

    # Key pressed for each action; actions without an entry are no-ops.
    # A three-action variant would be {0: "up", 1: "down"} with 2 as no-op.
    _ACTION_KEYS = {0: "up"}
    # Accepted first four OCR characters for the game-over text
    # ("GAHE" is presumably an OCR misread of "GAME").
    _DONE_STRINGS = ("GAME", "GAHE")

    # Setup the environment action and observation shapes
    def __init__(self) -> None:
        # Subclass model
        super().__init__()
        # Setup spaces
        self.observation_space = Box(
            low=0, high=255, shape=(1, 40, 90), dtype=np.uint8
        )
        self.action_space = Discrete(2)

        # Screen regions (in screen pixels) for the game frame and for the
        # area where the game-over text is read
        self.game_location = {"top": 300, "left": 75, "width": 400, "height": 100}
        self.done_location = {"top": 270, "left": 275, "width": 100, "height": 70}

    # What is called to do something in the game
    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Action 0 presses "up"; any action without a mapped key sends no key
        press. The reward is a constant 1 per step and ``info`` is empty.
        """
        # Only send a key press for actions that actually map to a key, so
        # the no-op action never reaches pyautogui with a bogus key name.
        key = self._ACTION_KEYS.get(int(action))
        if key is not None:
            pyautogui.press(key)

        # Checking whether the game is done (the captured image is not needed here)
        done, _ = self.get_done()
        # Get the next observation
        new_obs = self.get_observation()
        # Reward - We get a point for every frame we're alive
        reward = 1
        # Info dictionary
        info = {}

        return new_obs, reward, done, info

    # Visualize the game
    def render(self, mode="human"):
        """Show the current game region in an OpenCV window; "q" closes it."""
        with mss() as sct:
            frame = np.array(sct.grab(self.game_location))[:, :, :3]
        cv2.imshow("Game", frame)
        # Mask the key code to its low byte before comparing; the previous
        # `and 0xFF == ord("q")` was always False, so "q" never closed it.
        if (cv2.waitKey(1) & 0xFF) == ord("q"):
            self.close()

    # Restart the game
    def reset(self):
        """Restart the game and return the first observation."""
        # Pause, click at a fixed screen position, then press "up"
        time.sleep(1)
        pyautogui.click(x=150, y=200)
        pyautogui.press("up")
        return self.get_observation()

    # This closes down the observation
    def close(self):
        """Destroy any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

    # Get the part of the observation of the game that we want
    def get_observation(self):
        """Capture the game region and return it as a ``(1, 40, 90)`` array."""
        with mss() as sct:
            # Get screen capture of game, keeping the first three channels
            raw = np.array(sct.grab(self.game_location))[:, :, :3]
        # Grayscale
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # Resize (cv2 takes width, height)
        resized = cv2.resize(gray, (90, 40))
        # Add channels first
        return np.reshape(resized, (1, 40, 90))

    # Get the done text using OCR
    def get_done(self):
        """Return ``(done, done_cap)`` for the game-over region.

        ``done`` is True when the first four OCR characters of the captured
        region match one of the accepted game-over strings; ``done_cap`` is
        the captured image that was passed to OCR.
        """
        # Get done screen
        with mss() as sct:
            done_cap = np.array(sct.grab(self.done_location))[:, :, :3]
        # Apply OCR and compare only the first four characters
        result = pytesseract.image_to_string(done_cap)[:4]
        done = result in self._DONE_STRINGS
        return done, done_cap


### 2.2 Test Environment

In [None]:
# Instantiate the environment for the manual checks below
env = WebGame()

In [None]:
# Display the single channel of the current (1, 40, 90) observation
plt.imshow(env.get_observation()[0])

In [None]:
# Run game-over detection once and show the region that was passed to OCR,
# which helps when tuning env.done_location
done, done_cap = env.get_done()
print(done)
plt.imshow(done_cap)

In [None]:
# # Play 10 games
# for episode in range(10):
#     obs = env.reset()
#     done = False
#     total_reward = 0

#     while not done:
#         obs, reward, done, info = env.step(env.action_space.sample())
#         total_reward += reward
#     print(f"Total Reward for episode {episode} is {total_reward}")


# 3. Train the Model

### 3.1 Create Callback

In [None]:
# Import os for file path management
import os
# Import Base Callback for saving models
from stable_baselines3.common.callbacks import BaseCallback
# Check Environment
from stable_baselines3.common import env_checker

In [None]:
# Create a fresh environment and run stable-baselines3's environment checker on it
env = WebGame()
env_checker.check_env(env)

In [None]:
# Checkpoint-saving callback; this pattern is reusable in most RL training scripts
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Checkpoints are written to ``save_path`` as ``best_model_<n_calls>``.
    When ``save_path`` is None, no directory is created and nothing is saved.
    """

    def __init__(self, check_freq, save_path, verbose: int = 1):
        """Store the save interval and the checkpoint directory.

        Args:
            check_freq: Number of callback calls between two saves.
            save_path: Directory for checkpoints, or None to disable saving.
            verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Save a checkpoint on every ``check_freq``-th call; always returns True."""
        # Mirror the None check in _init_callback: without it, a None
        # save_path would make os.path.join raise a TypeError here.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f"best_model_{self.n_calls}")
            self.model.save(model_path)
        return True


In [None]:
# Directory where the training callback writes model checkpoints
CHECKPOINT_DIR = './train/'
# Directory passed to the model as its TensorBoard log location
LOG_DIR = './logs/'

In [None]:
# Save a checkpoint into CHECKPOINT_DIR every 1000 callback calls
callback = TrainAndLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

### 3.2 Build DQN and Training

In [None]:
# Import DQN algorithm
from stable_baselines3.dqn import DQN

In [None]:
# Create the DQN model with a CNN policy on the WebGame environment
model = DQN(
    "CnnPolicy",
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=180000, # Depends on your available RAM
    learning_starts=0,
)

In [None]:
# Kick off training for 10000 timesteps, saving checkpoints through the callback
model.learn(total_timesteps=10000, callback=callback)

# 4. Test Out Model

In [None]:
# Load a saved checkpoint: set best_model to the checkpoint file name
# in CHECKPOINT_DIR, without the ".zip" extension
best_model = "best_model_8000"
# load() is a classmethod that returns a new model rather than updating the
# instance it is called on, so the result must be bound back to `model`;
# otherwise the checkpoint is discarded and the in-memory model is evaluated.
model = DQN.load(os.path.join(CHECKPOINT_DIR, best_model), env=env)

In [None]:
# Evaluate the trained model over 10 episodes
for episode in range(10):
    obs = env.reset()
    total_reward = 0

    # Keep stepping until the environment reports game over
    while True:
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(int(action))
        total_reward += reward
        if done:
            break
    print(f"Total Reward for episode {episode} is {total_reward}")
    # Short pause before the next episode starts
    time.sleep(2)