# Chrome Dino re-write 2023

Use PyTorch, Stable-Baselines3, etc. to build a model that plays chrome://dino with reinforcement learning.
Works best with "Start Slow" disabled.

## dependencies (todo: make these a requirements.txt)
- CUDA-enabled PyTorch (PyTorch 2.0; CUDA 11.8)
- Stable-Baselines3 (with extras like OpenCV): https://stable-baselines3.readthedocs.io/en/master/
- Protobuf (a training dependency) (3.20.*)
- pytesseract (interface to Google Tesseract)
- Google Tesseract-OCR (5.3.1.20230401)
- Gym (gym v0.21 since this is used by Stable-Baselines3; RL environment library): https://gymnasium.farama.org/
- MSS (crossplatform screenshots)
- openCV (2)
- selenium (chrome test driver) (ChromeDriver 112.0.5615.49)

In [None]:
%%capture
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
%%capture
# needed to support old version of gym (0.21) used with Stable-Baselines3
# gym 0.21 has installation issue; gym moved to Gymnasium; SB3 still does not support Gymnasium 
!pip install setuptools==66 Cmake git+https://github.com/openai/gym.git@9180d12e1b66e7e2a1a622614f787a6ec147ac40

In [None]:
%%capture
!pip install stable-baselines3[extra] protobuf==3.20.*

In [None]:
%%capture
!pip install mss pydirectinput selenium webdriver-manager

In [None]:
%%capture
!pip install pytesseract

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import time

from mss import mss
from gym import Env
from gym.spaces import Box, Discrete
import pytesseract
import pydirectinput

from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager
from selenium.common.exceptions import WebDriverException

## set up environment

In [2]:
# create a base Gym environment for managing state
# Env must implement step and reset; optionally, render 
# https://www.gymlibrary.dev/api/core/
class ChromeDinoRL(Env):
    """Gym environment that drives chrome://dino via screenshots and key presses.

    Observations are grayscale captures of ``game_window`` shrunk to a quarter
    of the window size per axis and shaped ``(1, obs_height, obs_width)``.
    There are two discrete actions: 0 presses the space bar, 1 does nothing.
    """

    def __init__(self):
        """Create the screen grabber and declare observation/action spaces."""
        super().__init__()

        # screen regions handed to the mss grabber
        self.capture = mss()
        self.game_window = {'top': 140, 'left': 20, 'width': 540, 'height': 300}
        self.game_over_window = {'top': 260, 'left': 240, 'width': 320, 'height': 60}

        # observation size: game window divided by 4 along each axis
        self.obs_width = int(self.game_window["width"] / 4)
        self.obs_height = int(self.game_window["height"] / 4)

        # two actions: jump or not
        self.action_space = Discrete(2)
        # one-channel uint8 image as the observation
        self.observation_space = Box(
            low=0,
            high=255,
            shape=(1, self.obs_height, self.obs_width),
            dtype=np.uint8,
        )

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, game_over, info)``.

        The reward is a constant 1 per step and ``info`` is an empty dict.
        """
        # 0 -> spacebar, 1 -> noop (nothing is sent to the keyboard for 1)
        key_for_action = {0: 'space', 1: 'no_op'}
        if action == 0:
            pydirectinput.press(key_for_action[action])

        # the game-over check runs first, then the frame is captured
        done = self.game_over()[0]
        frame = self.observe_env()

        # experiment with different reward mechanisms
        alive_reward = 1
        return frame, alive_reward, done, {}

    def reset(self):
        """Wait a second, click the page, press space, and return an observation."""
        time.sleep(1)
        pydirectinput.click(x=200, y=200)
        pydirectinput.press('space')
        first_frame = self.observe_env()
        return first_frame

    def render(self):
        """Show the current game-window capture with cv2; pressing "q" closes it."""
        snapshot = np.array(self.capture.grab(self.game_window))
        cv2.imshow('Chrome Dino', snapshot)
        pressed = cv2.waitKey(1) & 0xFF
        if pressed == ord('q'):
            self.close()

    def close(self):
        """Destroy every OpenCV window."""
        cv2.destroyAllWindows()

    def observe_env(self):
        """Grab the game window and return it as a ``(1, H, W)`` grayscale array."""
        raw = np.array(self.capture.grab(self.game_window)).astype(np.uint8)

        # grayscale -> shrink -> add the leading channel axis for Stable Baselines
        mono = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        small = cv2.resize(mono, (self.obs_width, self.obs_height))
        return np.reshape(small, (1, self.obs_height, self.obs_width))

    def game_over(self):
        """Check the game-over region for the words GAME and OVER using OCR.

        Returns a 4-tuple ``(detected, ocr_text, preprocessed_image, words)``:
        whether both words were found, the raw OCR string, the thresholded
        and inverted image, and the OCR string split on whitespace.
        """
        region = np.array(self.capture.grab(self.game_over_window)).astype(np.uint8)

        # grayscale, gaussian blur, Otsu's threshold
        # reference: https://opencv24-python-tutorials.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_thresholding/py_thresholding.html#otsus-binarization
        mono = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        smoothed = cv2.GaussianBlur(mono, (3, 3), 0)
        _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

        # opening (erosion + dilation) to drop leftover noise, then invert
        rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, rect_kernel, iterations=1)
        invert = 255 - opened

        # NOTE(review): OCR is run on the raw capture, not on `invert`; the
        # preprocessed image is only returned to the caller - confirm intent.
        OCR = pytesseract.image_to_string(region)
        words = OCR.split()

        detected = {'GAME', 'OVER'}.issubset(words)
        return detected, OCR, invert, words
    

## set up driver

automated window setup for consistent environment

In [3]:
# set up webdriver for selenium on chrome
# the chromedriver binary is resolved by webdriver_manager when run() is called
class ChromeDinoDriver():
    """Open and close a fixed-size Chrome window showing the dino game."""

    def __init__(self):
        """Prepare Chrome options; no browser is started until run()."""
        self.chrome_driver = None
        # kept for reference; run() does not read this path
        self.chrome_path = r"chromedriver\chromedriver.exe"
        self.startpage = "chrome://dino/"
        self.chrome_options = Options()
        self.chrome_options.add_argument("--window-size=640,480")

    def run(self):
        """Start Chrome, move the window to (0, 0), and load the start page."""
        self.chrome_driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=self.chrome_options)
        self.chrome_driver.set_window_position(0, 0, windowHandle='current')
        try:
            self.chrome_driver.get(self.startpage)
        except WebDriverException:
            pass # ignore selenium complaining about offline

    def end(self):
        """Quit the browser if one is running; safe to call more than once."""
        if self.chrome_driver is None:
            # run() was never called (or end() already ran): nothing to close
            return
        try:
            self.chrome_driver.quit()
        finally:
            # forget the handle so a repeated end() is a no-op
            self.chrome_driver = None

### misc testing

In [4]:
# test environment
# instantiate the Gym environment used by the manual checks in the following cells
env = ChromeDinoRL()

In [None]:
# print a randomly sampled action and plot channel 0 of a sample drawn from the observation space
# (sample() does not go through observe_env(), so the plot is not a screenshot)
print(env.action_space.sample())
plt.imshow(env.observation_space.sample()[0])

In [None]:
# try an observation
obs = env.observe_env()
# obs[0] is the single grayscale channel (2-D); BGR2RGB needs 3 or 4 channels,
# so convert GRAY -> RGB for display instead
plt.imshow(cv2.cvtColor(obs[0], cv2.COLOR_GRAY2RGB))

In [None]:
# test driver
# open the browser on the dino page, leave it up for two seconds, then quit it
driver = ChromeDinoDriver()
driver.run()
time.sleep(2)
driver.end()

In [None]:
# start driver and try an observation
driver = ChromeDinoDriver()
driver.run()
try:
    # observe_env() returns a (1, H, W) grayscale array; channel 0 is 2-D,
    # so it has to be converted GRAY -> RGB (BGR2RGB rejects 1-channel input)
    plt.imshow(cv2.cvtColor(env.observe_env()[0], cv2.COLOR_GRAY2RGB))
finally:
    # close the browser even if the capture or plot fails
    driver.end()

In [None]:
# start driver for testing
# the browser is left running here; a later cell calls driver.end()
driver = ChromeDinoDriver()
driver.run()

In [None]:
# show the current game-window capture in an OpenCV window
env.render() # close with env.close()

In [None]:
# destroy the OpenCV preview window(s)
env.close()

In [None]:
# click into the page and press space; the cell output is the returned observation
env.reset()

In [None]:
# channel 0 of the observation is a 2-D grayscale image, so convert GRAY -> RGB
# (COLOR_BGR2RGB fails on single-channel input)
plt.imshow(cv2.cvtColor(env.observe_env()[0], cv2.COLOR_GRAY2RGB))

In [None]:
# game_over() returns (detected, ocr_text, preprocessed_image, words)
go = env.game_over()
print("Game Over? : " + str(go[0]))
print(go[1])
print(go[3])
# go[2] is the thresholded single-channel image; convert GRAY -> RGB for display
# (COLOR_BGR2RGB fails on single-channel input)
plt.imshow(cv2.cvtColor(go[2], cv2.COLOR_GRAY2RGB))

In [None]:
# quit the browser that was started for the manual tests
driver.end()

In [None]:
# test run: play four episodes with uniformly random actions

env = ChromeDinoRL()
driver = ChromeDinoDriver()
driver.run()

try:
    for ep in range(4):
        obs = env.reset()
        game_over = False
        score = 0

        # keep stepping until the environment reports game over
        while not game_over:
            obs, reward, game_over, info = env.step(env.action_space.sample())
            score += reward

        print('Score for ep {} = {}.'.format(ep, score))
finally:
    # close the browser even if the loop is interrupted
    driver.end()

## train model

Train the DQN and save it using a Stable-Baselines3 callback

In [5]:
# training imports
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
import os

In [6]:
# create stable baselines callback for saving model and logs during training
    
# output locations used by the training cells
CHECKPOINT_DIR = './train/'
MODELS_DIR = './models/'
LOG_DIR = './logs/'

class TrainAndLogCallback(BaseCallback):
    """Save the model every ``check_freq`` calls and report the recent mean reward.

    Args:
        check_freq: number of callback calls between checkpoints.
        save_path: directory the checkpoints are written to (created if missing).
        verbose: values above 0 print the timestep count and mean rewards.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq # number of steps between checkpoints
        self.save_path = save_path
        self.best_mean_reward = -np.inf
        self.log_dir = LOG_DIR

    def _init_callback(self):
        """Create the checkpoint directory before training starts."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Checkpoint and log on every ``check_freq``-th call; always continue training."""
        if self.n_calls % self.check_freq != 0:
            return True

        # save model every check_freq steps
        model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
        self.model.save(model_path)

        # get training reward from the monitor logs
        x, y = ts2xy(load_results(self.log_dir), 'timesteps')
        if len(x) > 0:
            # mean training reward over the last check_freq episodes
            mean_reward = np.mean(y[-self.check_freq:])
            if self.verbose > 0:
                print("Num timesteps: {}".format(self.num_timesteps))
                print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))
            # only compare when a mean exists; with no finished episode yet
            # there is nothing to compare and mean_reward is undefined
            if mean_reward > self.best_mean_reward:
                self.best_mean_reward = mean_reward

        return True

In [7]:
# checkpoint into CHECKPOINT_DIR every 360 callback calls
callback = TrainAndLogCallback(check_freq=360, save_path=CHECKPOINT_DIR)

In [8]:
# create DQN
# reference: https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html

# make sure the log directory exists before Monitor is given its path, so the
# path is handled as a directory rather than as a file-name prefix
os.makedirs(LOG_DIR, exist_ok=True)

env = ChromeDinoRL()
env = Monitor(env, LOG_DIR) # record episode stats to LOG_DIR (read back by load_results in the callback)

# check environment - no output means no issues
#env_checker.check_env(env)

model = DQN(
    'CnnPolicy', # passing image observation
    env, # the Monitor-wrapped ChromeDinoRL instance
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=600000, # size of replay buffer; 600k = 12GB-ish RAM
    learning_starts=100) # how many steps to collect transitions for before learning starts (100 maybe?)

Using cuda device
Wrapping the env in a DummyVecEnv.


In [10]:
# train DQN
# higher timesteps = longer training, think TF epochs

driver = ChromeDinoDriver()
driver.run()
time.sleep(1)

timesteps = 360000
try:
    model.learn(total_timesteps=timesteps, callback=callback) #100k decent model
    time.sleep(1)
finally:
    # close the browser even when training is interrupted (e.g. KeyboardInterrupt)
    driver.end()

# LOG_DIR is the module-level constant; the lowercase name `log_dir` is not defined
plot_results([LOG_DIR], timesteps, results_plotter.X_TIMESTEPS, "Chrome Dino")
plt.show()

Logging to ./logs/DQN_7
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 15.8     |
|    ep_rew_mean      | 15.8     |
|    exploration_rate | 0.998    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 2        |
|    time_elapsed     | 29       |
|    total_timesteps  | 63       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 12.5     |
|    ep_rew_mean      | 12.5     |
|    exploration_rate | 0.997    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1        |
|    time_elapsed     | 50       |
|    total_timesteps  | 100      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 12.5     |
|    ep_rew_mean      | 12.5     |
|    exploration_rate | 0.996    |
| time/               |        

KeyboardInterrupt: 

## continual learning

Load weights for a pre-trained model and continue training it

In [None]:
# pick up where left off and continue training
# reference: https://stable-baselines.readthedocs.io/en/master/guide/examples.html#continual-learning
# reference: https://github.com/hill-a/stable-baselines/issues/599

env = ChromeDinoRL()
env = Monitor(env, LOG_DIR) # record episode stats to LOG_DIR (read back by load_results in the callback)

# check environment - no output means no issues
#env_checker.check_env(env)

# placeholders: complete model_name with the step suffix of a checkpoint saved
# by the callback ('best_model_<n_calls>' in CHECKPOINT_DIR) and set log_name
model_name = 'best_model_'
log_name = ''
model = DQN.load(os.path.join(CHECKPOINT_DIR, model_name), tensorboard_log=log_name)
model.set_env(env)

driver = ChromeDinoDriver()
driver.run()
time.sleep(1)

timesteps = 36000
try:
    model.learn(total_timesteps=timesteps, callback=callback)
    time.sleep(1)
finally:
    # close the browser even when training is interrupted
    driver.end()

# LOG_DIR is the module-level constant; the lowercase name `log_dir` is not defined
plot_results([LOG_DIR], timesteps, results_plotter.X_TIMESTEPS, "Chrome Dino")
plt.show()

## use model

Load weights for a pre-trained model and use it to play the game

In [None]:
# load model
# model_name is a placeholder prefix; complete it with the file to load.
# NOTE(review): the training callback writes checkpoints to CHECKPOINT_DIR, while this
# cell reads from MODELS_DIR - presumably the chosen checkpoint is copied there first; confirm.

model_name = 'best_model_'
model = DQN.load(os.path.join(MODELS_DIR, model_name))

In [None]:
# run model in driver and see how it does
env = ChromeDinoRL()
driver = ChromeDinoDriver()
driver.run()
time.sleep(1)

try:
    for ep in range(1):
        obs = env.reset()
        game_over = False
        score = 0

        # let the loaded model pick each action until game over is detected
        while not game_over:
            action, _ = model.predict(obs)
            obs, reward, game_over, info = env.step(int(action))
            score += reward

        print('Score for ep {} = {}.'.format(ep, score))
        time.sleep(3)
finally:
    # close the browser even if the episode loop is interrupted
    driver.end()