# 1. Install and Import Dependencies

In [None]:
!pip install torch torchvision torchaudio 

In [None]:
!pip install jupyter_contrib_nbextensions
!jupyter contrib nbextension install --user
!jupyter nbextension enable hinterland/hinterland

# 2. Build the Environment

In [None]:
!pip install stable-baselines3[extra] protobuf==3.20.*
# Protobuf is efficient for serializing and deserializing complex data structures, 
# which is beneficial for machine learning models and their configurations

In [None]:
# PyDirectInput (https://pypi.org/project/PyDirectInput/) will be used to send key input
# pytesseract (https://pypi.org/project/pytesseract/) will be used to read characters from the game
## pytesseract needs the Tesseract binary - see https://github.com/UB-Mannheim/tesseract/wiki
## mss is used for screen capture - faster than OpenCV

In [None]:
!pip install pytesseract mss pydirectinput

In [1]:
# Mss for screen cap
from mss import mss
# Sending commands 
import pydirectinput
# OpenCV is used for frame processing
import cv2
# Transformational framework
import numpy as np
# For OCR "GAME OVER" extraction
import pytesseract
# Visualize captured frames
from matplotlib import pyplot as plt
# For pauses - between when we send commands vs get back frames
import time
# Environment components - Box is shape of game input, Discrete is shape of action
from gym import Env
from gym.spaces import Box, Discrete

In [2]:
# Tell pytesseract where the Tesseract-OCR executable lives.
# NOTE: this is an absolute Windows path - adjust it for other machines.
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

In [3]:
# Sanity check: print the version reported by the configured Tesseract binary.
# (pytesseract is already imported in the imports cell; this re-import is redundant.)
import pytesseract
print(pytesseract.get_tesseract_version())

5.4.0.20240606


## 2.1 Create Environment

In [4]:
class WebGame(Env):
    """Gym environment that plays an on-screen game via screen capture and key presses.

    Observations are grayscale grabs of ``game_location`` resized to the shape of
    ``observation_space`` (channels first). The end of an episode is detected by
    running OCR over ``done_location`` and comparing the first four characters
    against a list of known strings.
    """

    # Key sent for each discrete action; None means "send nothing" (no-op).
    # 0 = Spacebar, 1 = No action (no op)
    ACTION_KEYS = {0: 'space', 1: None}

    def __init__(self, game_location=None, done_location=None):
        """Set up the action/observation spaces and the capture regions.

        Args:
            game_location: optional mss region dict (``top``, ``left``, ``width``,
                ``height``) covering the play area. Defaults to the original
                hard-coded region.
            done_location: optional mss region dict covering the area where the
                game-over text appears. Defaults to the original hard-coded region.
        """
        # Initialise the gym base class
        super().__init__()
        # Observation: a single-channel 83x100 uint8 image
        self.observation_space = Box(low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)
        # Two actions, see ACTION_KEYS
        self.action_space = Discrete(2)
        # Define extraction parameters for the game
        # https://python-mss.readthedocs.io/api.html?highlight=grab#mss.base.MSSBase.grab
        self.cap = mss()
        if game_location is None:
            game_location = {'top':300, 'left':0, 'width':600, 'height':500}
        if done_location is None:
            done_location = {'top':400, 'left':500, 'width':900, 'height':140}
        self.game_location = game_location
        self.done_location = done_location

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Actions that have no key in ``ACTION_KEYS`` (including the no-op action)
        send no input. Previously the no-op action was forwarded to
        ``pydirectinput.press`` as the string 'no_op', which is not a key name.
        """
        # int() lets numpy integers / 0-d arrays be used as the lookup key
        key = self.ACTION_KEYS.get(int(action))
        if key is not None:
            pydirectinput.press(key)

        # Checking if the game is done (the captured image is not needed here)
        done, _ = self.get_done()
        # Get the next observation
        new_observation = self.get_observation()
        # Reward - one point for every step taken
        reward = 1
        # Info dictionary
        info = {}

        return new_observation, reward, done, info

    def render(self):
        """Show the current game region in an OpenCV window; pressing 'q' closes it."""
        cv2.imshow('Game', np.array(self.cap.grab(self.game_location))[:,:,:3])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def reset(self):
        """Restart the game and return the first observation.

        Waits one second, clicks at screen position (50, 50), then presses space.
        """
        time.sleep(1)
        pydirectinput.click(x=50, y=50)
        pydirectinput.press('space')
        return self.get_observation()

    def get_observation(self):
        """Grab the game region and return it as a channels-first grayscale array."""
        # Get the screen capture of game, dropping the 4th (alpha) channel
        raw = np.array(self.cap.grab(self.game_location))[:,:,:3]
        # Grayscaling
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # Resize to the observation space; cv2.resize takes (width, height)
        _, height, width = self.observation_space.shape
        resized = cv2.resize(gray, (width, height))
        # Add channels first
        return np.reshape(resized, self.observation_space.shape)

    def get_done(self):
        """Check whether the game-over text is on screen.

        Returns:
            ``(done, done_cap)``: ``done`` is True when the first four characters
            read by OCR match one of the known strings; ``done_cap`` is the
            captured image of the game-over region.
        """
        # Get done screen
        done_cap = np.array(self.cap.grab(self.done_location))[:,:,:3]
        # Valid done text (includes OCR misreadings that were observed)
        done_strings = ["GAME", "chro", "GAHE"]

        # Apply OCR and compare only the first four characters
        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in done_strings

        return done, done_cap

    def close(self):
        """Close any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

In [None]:
# Instantiate the screen-capture environment defined above.
env = WebGame()

In [None]:
# Run the game-over check once, keeping both the flag and the captured region for inspection.
done, done_cap = env.get_done()

In [None]:
# Restart the game (sleep, click, press space); the cell output is the first observation.
env.reset()

In [None]:
# Display the game-over flag captured by the earlier get_done() call.
done

In [None]:
# OCR the captured game-over region and show the first four characters - the same text get_done() compares.
pytesseract.image_to_string(done_cap)[:4]

In [None]:
# get_done() returns (done, done_cap); only the captured image (index 1) can be plotted.
# The capture is in BGR channel order, so colours appear swapped in matplotlib.
plt.imshow(env.get_done()[1])

In [None]:
# Draw one random action from the environment's action space.
env.action_space.sample()

In [None]:
# The observation's single channel is a 2-D grayscale image, so expand it with
# GRAY2RGB (BGR2RGB requires a 3-channel input and fails here).
plt.imshow(cv2.cvtColor(env.get_observation()[0], cv2.COLOR_GRAY2RGB))

## 2.2 Test Environment

In [12]:
# Create a fresh environment instance for the tests and training below.
env = WebGame()

In [None]:
# Capture a single processed frame from the game region.
obs = env.get_observation()

In [None]:
# The original cell held an unfinished expression (a bare `pytesseract.`), which is a
# SyntaxError and stops Restart & Run All. Confirm the OCR binary is reachable instead.
pytesseract.get_tesseract_version()

# 3. Train the Model

In [None]:
# Smoke-test the environment: play a single episode with randomly sampled
# actions and report the reward accumulated over it.
for episode in range(1):
    obs = env.reset()
    total_reward = 0
    done = False

    # The episode always takes at least one step and stops once the env reports done
    while True:
        obs, reward, done, info = env.step(env.action_space.sample())
        total_reward = total_reward + reward
        if done:
            break
    print(f'Total Reward for espisode {episode} is {total_reward}')

## 3.1 Create Callback

In [5]:
# Import os for file navigation
import os
# Import callback class from sb3
from stable_baselines3.common.callbacks import BaseCallback
# Import the sb3 environment checker
from stable_baselines3.common import env_checker

In [None]:
# Run the Stable-Baselines3 environment checker against the custom environment.
env_checker.check_env(env)

In [None]:
# Take one step with a random action; the output is (observation, reward, done, info).
env.step(env.action_space.sample())

In [6]:
class TrainingLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Checkpoints are written to ``save_path`` as ``best_model_<n_calls>``.
    Passing ``save_path=None`` disables saving.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        """Store the save interval and the target directory.

        Args:
            check_freq: save a checkpoint whenever ``n_calls`` is a multiple of this.
            save_path: directory for the checkpoints, or None to skip saving.
            verbose: verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call; always returns True."""
        # Guard on save_path like _init_callback does, so a None path does not
        # crash inside os.path.join.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [7]:
# Directory where the training callback writes model checkpoints.
CHECKPOINT_DIR = './train/'
# Directory passed to the model as its TensorBoard log location.
LOG_DIR = './logs/ChromeDino'

In [8]:
# Save a checkpoint into CHECKPOINT_DIR every 1000 callback calls.
callback = TrainingLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

## 3.2 Build DQN and Train

In [10]:
# Import DQN algo
from stable_baselines3 import DQN

In [13]:
# Create the DQN model: CNN policy over the image observations, TensorBoard logs
# written under LOG_DIR, a 200,000-transition replay buffer, and learning_starts=1000.
# NOTE(review): a buffer this large holding 1x83x100 frames may need a lot of RAM - confirm.
model = DQN('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1,
            buffer_size=200_000, learning_starts=1000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [14]:
# Training: run 5000 environment steps, with the checkpoint callback attached.
# The game window must stay visible while this runs, because the environment
# reads the screen and sends key presses.
model.learn(total_timesteps=5000, callback=callback)

Logging to ./logs/ChromeDino\DQN_4
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 9.5      |
|    ep_rew_mean      | 9.5      |
|    exploration_rate | 0.928    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1        |
|    time_elapsed     | 27       |
|    total_timesteps  | 38       |
----------------------------------


FailSafeException: PyDirectInput fail-safe triggered from mouse moving to a corner of the screen. To disable this fail-safe, set pydirectinput.FAILSAFE to False. DISABLING FAIL-SAFE IS NOT RECOMMENDED.

# 4. Test out the Model

In [None]:
# Load a checkpoint written by the training callback. os.path.join('') is an
# empty path, so the original call could not locate any model file.
# Checkpoints are named 'best_model_<step>'; pick the step you want to evaluate.
model = DQN.load(os.path.join(CHECKPOINT_DIR, 'best_model_5000'))

In [None]:
# Evaluate the loaded model over 10 episodes, printing each episode's total reward.
for episode in range(10):
    obs = env.reset()
    total_reward = 0
    done = False

    # Each episode takes at least one step and stops once the env reports done
    while True:
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(int(action))
        total_reward = total_reward + reward
        if done:
            break
    print(f'episode reward for {episode} is {total_reward}')

In [None]:
# Close any OpenCV windows opened by env.render().
env.close()