# 1. Install and Import Dependencies


In [None]:
# Execute this first
%pip install git+https://github.com/DLR-RM/stable-baselines3
# Then install the package with extras (gymnasium, atari, etc)
%pip install stable-baselines3[extra]

In [None]:
# Install CUDA acceleration
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

In [None]:
# You need to have tesseract installed: https://github.com/UB-Mannheim/tesseract/wiki
%pip install --upgrade pip
%pip install --upgrade Pillow
%pip install pytesseract

In [None]:
# Screen capture and game control, using MSS instead of OpenCV
%pip install mss pydirectinput

In [None]:
import os
# MSS used for screen cap
from mss import mss
# Sending commands
import pydirectinput
# OpenCV allows us to do frame processing
import cv2
# Transformational framework
import numpy as np
# OCR for game over extraction
# Add tesseract to path (Windows = C:\Program Files\Tesseract-OCR)
import pytesseract
# Visualize captured frames
from matplotlib import pyplot as plt
# For pauses
import time
# Environment components
from gymnasium import Env
from gymnasium.spaces import Box, Discrete

# 2. Build the Environment


## 2.1. Create Environment


In [None]:
class WebGame(Env):
    """Gymnasium environment that plays a browser game through the screen.

    Observations are screen captures of the game region, converted to
    grayscale, resized to 100x83 and returned channel-first as an array of
    shape ``(1, 83, 100)``. Actions are sent as key presses through
    ``pydirectinput``. The end of an episode is detected by running OCR over
    the region where the game-over text is expected.
    """

    # Action index -> key to press. Action 2 is the no-op and has no key.
    _ACTION_KEYS = {0: 'space', 1: 'down'}
    _NO_OP = 2
    # OCR output (first four characters) accepted as the game-over text,
    # including common misreads.
    _DONE_STRINGS = ('GAME', 'GAHE', 'Go A', 'G A ')

    def __init__(self):
        super().__init__()
        # Image
        self.observation_space = Box(
            low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)
        # Action space
        self.action_space = Discrete(3)
        # Define extraction parameters for the game
        self.cap = mss()
        self.game_location = {'top': 300,
                              'left': 0, 'width': 800, 'height': 500}
        self.done_location = {'top': 405,
                              'left': 650, 'width': 300, 'height': 70}

    def _grab(self, region):
        """Capture ``region`` of the screen and keep its first three channels."""
        # The grab has four channels; only the first three are used.
        return np.array(self.cap.grab(region))[:, :, :3]

    def step(self, action):
        """Apply ``action`` and return the resulting transition.

        Args:
            action: 0 = Space, 1 = Duck (down key), 2 = No action (no op).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False`` and ``info`` is always empty.

        Raises:
            KeyError: if ``action`` is not 0, 1 or 2.
        """
        # Policies and ``action_space.sample()`` hand over NumPy integers;
        # normalise to a plain int before using it as a dictionary key.
        action = int(action)
        if action != self._NO_OP:
            pydirectinput.press(self._ACTION_KEYS[action])

        # Check whether the game is done before grabbing the next frame
        terminated, _ = self.get_done()

        # Get next observation
        new_observation = self.get_observation()

        # Reward - one point for every step we stay alive
        reward = 1

        # Info dictionary
        info = {}

        return new_observation, reward, terminated, False, info

    def render(self):
        """Show the current game region once in an OpenCV window.

        Pressing 'q' while the window has focus closes it.
        """
        frame = self._grab(self.game_location)

        cv2.imshow("Game", frame)  # This will open an independent window
        if cv2.waitKey(1) & 0xFF == ord('q'):  # quit when 'q' is pressed
            self.close()

    def reset(self, seed=None, options=None):
        """Restart the game and return the first observation.

        Args:
            seed: forwarded to ``Env.reset``.
            options: accepted so the signature matches the Gymnasium
                ``Env.reset`` API; not used.

        Returns:
            ``(observation, info)`` with an empty ``info`` dictionary.
        """
        super().reset(seed=seed)
        time.sleep(1)
        # Click inside the game window, then press space to start a new run
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        info = {}
        return self.get_observation(), info

    def close(self):
        """Close any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

    def get_observation(self):
        """Return the processed game frame as a ``(1, 83, 100)`` array."""
        # Get screen capture of the game
        raw = self._grab(self.game_location)
        # Grayscale
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # Resize (cv2.resize takes the target size as (width, height))
        resized = cv2.resize(gray, (100, 83))
        # Add channels first
        return np.reshape(resized, (1, 83, 100))

    def get_done(self):
        """Detect the game-over text using OCR.

        Returns:
            ``(done, done_cap)``: ``done`` is ``True`` when the first four
            characters read from the done region match a known game-over
            string; ``done_cap`` is the captured region.
        """
        done_cap = self._grab(self.done_location)

        # res - first four characters of the text extracted from the capture
        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in self._DONE_STRINGS

        return done, done_cap


# Game instance used by the cells below
env = WebGame()

In [None]:
# Show a sample drawn from the observation space (channel 0);
# this comes from the space definition, not from a captured game frame
plt.imshow(env.observation_space.sample()[0])

### Dino capture


In [None]:
# Capture one processed observation from the game region
env.get_observation()

In [None]:
# Get a screen capture with image processing applied (channel 0 of the observation)
plt.imshow(env.get_observation()[0])

In [None]:
# Get a screen capture with colors.
# get_observation() returns a single-channel grayscale frame, which cannot be
# colour-converted, so grab the raw capture of the game region instead and
# convert it from BGR to the RGB order matplotlib expects.
raw_frame = np.array(env.cap.grab(env.game_location))[:, :, :3]
plt.imshow(cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB))

### Done capture


In [None]:
# Run the OCR-based game-over check; also returns the captured done region
done, cap_done = env.get_done()

In [None]:
# Show the game-over flag returned by get_done()
done

In [None]:
# Display the captured done region.
# NOTE(review): the capture is presumably in BGR order (get_observation converts
# the same kind of grab with COLOR_BGR2GRAY), so colours may look swapped here.
plt.imshow(cap_done)

## 2.2. Test Environment


In [None]:
# Play 10 episodes with random actions to check that the environment works
for episode in range(10):
    # reset() returns a tuple (obs, info)
    obs, info = env.reset()
    done = False
    total_reward = 0

    while not done:
        obs, reward, terminated, truncated, info = env.step(
            env.action_space.sample())
        # An episode ends when it is either terminated or truncated
        done = terminated or truncated
        total_reward += reward
    print(f'Total reward for episode {episode} is {total_reward}')

# 3. Train the Model


In [None]:
# Directory where model checkpoints are saved
CHECKPOINT_DIR = './train/'
# Directory for training logs (passed to the model as tensorboard_log)
LOG_DIR = './logs/'

## 3.1. Create Callback


In [None]:
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common import env_checker

In [None]:
# Check that the environment is OK by running the stable_baselines3
# environment checker on the game instance
env_checker.check_env(env)

### Custom Callback


In [None]:
class TraingAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Args:
        check_freq: number of callback calls between two saves.
        save_path: directory the checkpoints are written to; created on
            construction when it is not ``None``.
        verbose: verbosity level forwarded to ``BaseCallback``.
    """

    # The constructor must be spelled ``__init__`` (double underscores);
    # otherwise Python never calls it and the keyword arguments are rejected
    # by the base-class constructor.
    def __init__(self, check_freq: int, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint when due; return ``True`` to keep training."""
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(
                self.save_path, f'best_model_{self.n_calls}')
            self.model.save(model_path)
        return True

In [None]:
# Checkpoint callback: check_freq=1000, checkpoints written to CHECKPOINT_DIR
callback = TraingAndLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

### EvalCallback


In [None]:
# Use deterministic actions for evaluation.
# Evaluation runs every 1000 steps; the best model is written to
# CHECKPOINT_DIR and the evaluation logs to LOG_DIR.
eval_callback = EvalCallback(env, best_model_save_path=CHECKPOINT_DIR,
                             log_path=LOG_DIR, eval_freq=1000,
                             deterministic=True, render=False)

## 3.2. Build DQN and Train


In [None]:
from stable_baselines3 import DQN

In [None]:
# Build the DQN agent: CNN policy over the game frames, logging to LOG_DIR
model = DQN(
    policy='CnnPolicy',
    env=env,
    buffer_size=600000,
    learning_starts=1000,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

In [None]:
# Train for 3000 timesteps with periodic evaluation.
# NOTE: only eval_callback is passed here; the checkpoint `callback` created
# earlier is not used by this call.
model.learn(total_timesteps=3000, callback=eval_callback)

# 4. Test out Model


In [None]:
# Load the best model stored under CHECKPOINT_DIR
model = DQN.load(os.path.join(CHECKPOINT_DIR, 'best_model'))

for episode in range(10):
    # reset() returns a tuple (obs, info)
    obs, info = env.reset()
    done = False
    total_reward = 0

    while not done:
        # Act deterministically, matching the deterministic evaluation
        # configured for the EvalCallback
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(int(action))
        # An episode ends when it is either terminated or truncated
        done = terminated or truncated
        total_reward += reward
    print(f'Total reward for episode {episode} is {total_reward}')