In [15]:
import numpy as np
from PIL import Image

import cv2
import io
import time
import random
import pickle
import os
from io import BytesIO
import base64
import json
import pandas as pd
from time import sleep

from collections import deque
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.keys import Keys

from webdriver_manager.chrome import ChromeDriverManager

In [16]:
# Reference implementation (Paperspace DinoRunTutorial): https://github.com/Paperspace/DinoRunTutorial/blob/master/Reinforcement%20Learning%20Dino%20Run.ipynb

In [17]:
# JavaScript snippets executed inside the game page.
# Give the game's canvas (class "runner-canvas") an id so it can be looked up directly.
init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'"
# Return the canvas as base64 PNG data; substring(22) strips the "data:image/png;base64," prefix.
getbase64Script = "canvasRunner = document.getElementById('runner-canvas'); return canvasRunner.toDataURL().substring(22)"

# CSV log destinations (loss, actions, Q-values, scores).
loss_file_path = "./objects/loss.csv"
actions_file_path = "./objects/actions.csv"
q_value_file_path = "./objects/q_values.csv"
scores_file_path = "./objects/scores.csv"

# Browser target, plus the local chromedriver binary path returned by webdriver_manager.
game_url = "chrome://dino"
chrome_driver_path = ChromeDriverManager().install()

In [18]:
def grab_screen(_driver):
    """Capture the game canvas from the browser and return it preprocessed.

    Runs ``getbase64Script`` in the page to get the canvas as base64 PNG data,
    decodes it with PIL into a numpy array and hands that to ``process_img``.
    """
    encoded_png = _driver.execute_script(getbase64Script)
    png_stream = BytesIO(base64.b64decode(encoded_png))
    frame = np.array(Image.open(png_stream))
    return process_img(frame)

def process_img(image):
    """Convert a colour frame to single-channel grayscale and shrink it to 80x80.

    NOTE(review): the conversion assumes BGR channel order, but frames from
    ``grab_screen`` come from PIL, which typically yields RGB(A). That swaps
    the red/blue weights; confirm this is intended.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (80, 80))

def show_img(graphs = False):
    """Coroutine that displays frames sent to it in an OpenCV window.

    Prime it with ``next()``, then ``send(frame)`` once per frame. Each frame is
    upscaled to 800x400 before display. Pressing 'q' in the window destroys all
    OpenCV windows and ends the generator, so the following ``send`` raises
    ``StopIteration``.

    Args:
        graphs: chooses the window title: "logs" if True, else "game_play".
    """
    window_title = "logs" if graphs else "game_play"  # loop-invariant
    while True:
        screen = (yield)
        cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)
        # Show the enlarged copy; previously it was computed and then discarded
        # while the raw (80x80) frame was displayed instead.
        enlarged = cv2.resize(screen, (800, 400))
        cv2.imshow(window_title, enlarged)
        if (cv2.waitKey(1) & 0xFF) == ord('q'):
            cv2.destroyAllWindows()
            break

In [19]:
class Game:
    """Selenium wrapper around Chrome's built-in dino game at ``game_url``.

    Launches Chrome through the chromedriver at ``chrome_driver_path`` and
    exposes small helpers that read or drive the page's ``Runner`` JS object.
    """

    def __init__(self, custom_config=True):
        """Open a 900x600 Chrome window at (300, 300) and load the game.

        Args:
            custom_config: if True (default), run ``Runner.config.ACCELERATION=0``
                in the page (presumably disabling the game's speed-up). This flag
                used to be accepted but ignored.
        """
        chrome_options = Options()
        chrome_options.add_argument("disable-infobars")
        chrome_options.add_argument("--mute-audio")
        service = Service(chrome_driver_path)
        self._driver = webdriver.Chrome(service=service, options=chrome_options)
        self._driver.set_window_position(x=300, y=300)
        self._driver.set_window_size(900, 600)

        try:
            self._driver.get(game_url)
        except Exception:
            # NOTE(review): errors from this navigation are deliberately ignored;
            # presumably chrome://dino can raise even though the game page loads.
            # Confirm. Narrowed from a bare `except:` so Ctrl-C / SystemExit still work.
            pass

        if custom_config:
            self._driver.execute_script("Runner.config.ACCELERATION=0")
        # Give the game canvas an id so getbase64Script can find it by id.
        self._driver.execute_script(init_script)

    def get_crashed(self):
        """Return the page's ``Runner.instance_.crashed`` flag."""
        return self._driver.execute_script("return Runner.instance_.crashed")

    def get_playing(self):
        """Return the page's ``Runner.instance_.playing`` flag."""
        return self._driver.execute_script("return Runner.instance_.playing")

    def restart(self):
        """Call ``Runner.instance_.restart()`` in the page."""
        self._driver.execute_script("Runner.instance_.restart()")

    def press_up(self):
        """Send an ARROW_UP key press to the page body."""
        self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_UP)

    def press_down(self):
        """Send an ARROW_DOWN key press to the page body."""
        self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_DOWN)

    def get_score(self):
        """Return the distance-meter reading as an int.

        The meter's ``digits`` are joined into a string and parsed. When there
        are no digits yet (empty array or null), this returns 0 instead of
        raising ``ValueError`` from ``int('')``.
        """
        score_array = self._driver.execute_script("return Runner.instance_.distanceMeter.digits")
        score = ''.join(score_array or [])
        return int(score) if score else 0

    def pause(self):
        """Call ``Runner.instance_.stop()`` in the page."""
        return self._driver.execute_script("return Runner.instance_.stop()")

    def resume(self):
        """Call ``Runner.instance_.play()`` in the page."""
        return self._driver.execute_script("return Runner.instance_.play()")

    def end(self):
        """Shut down the browser session and the chromedriver process.

        Uses ``quit()`` rather than ``close()``: ``close()`` only closes the
        current window and leaves the WebDriver session/driver process alive.
        """
        self._driver.quit()

In [20]:
class DinoAgent:
    """Player-side facade over a ``Game``: exposes jump/duck and status queries."""

    def __init__(self, game):
        """Bind to ``game``, wait one second, then jump once (presumably to start the run)."""
        self._game = game
        sleep(1)
        self.jump()

    def jump(self):
        """Press the up arrow in the game."""
        self._game.press_up()

    def duck(self):
        """Press the down arrow in the game."""
        self._game.press_down()

    def is_running(self):
        """Return whatever the game reports from ``get_playing()``."""
        game = self._game
        return game.get_playing()

    def is_crashed(self):
        """Return whatever the game reports from ``get_crashed()``."""
        game = self._game
        return game.get_crashed()

In [21]:
class Game_state:
    """Glue between agent/game and learner: applies one action, returns one transition."""

    def __init__(self, agent, game):
        self._agent = agent
        self._game = game
        # Prime the show_img coroutine so it is parked at its first `yield`
        # and ready to receive frames through send().
        self._display = show_img()
        next(self._display)

    def get_state(self, actions):
        """Apply ``actions`` and return ``(frame, reward, is_over)``.

        ``actions[1] == 1`` triggers a jump. Reward is 0 on a jump step and 0.1
        otherwise. On a crash the game is restarted and the reward is -5 with
        ``is_over`` True.
        """
        # NOTE(review): the score is read but never used (logging was removed).
        self._game.get_score()

        jumped = actions[1] == 1
        if jumped:
            self._agent.jump()
        step_reward = 0 if jumped else 0.1

        frame = grab_screen(self._game._driver)
        self._display.send(frame)

        if not self._agent.is_crashed():
            return frame, step_reward, False

        self._game.restart()
        return frame, -5, True

In [22]:
# Parameters
PRETRAINED = False        # load ./latest.pth into the model before training
ACTIONS = 2               # number of actions; a_t[1] == 1 means "jump"
GAMMA = 0.99              # discount factor for the one-step TD target
OBSERVATION = 100.0       # timesteps to observe before training (NOTE(review): not used by the visible loop)
EXPLORE = 100_000.0       # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001    # final value of epsilon
INITIAL_EPSILON = 0.01    # starting value of epsilon
REPLAY_MEMORY = 50_000    # transitions to remember (NOTE(review): no replay buffer in the visible code)
LEARNING_RATE = 0.0001    # Adam learning rate

### Model

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class DinoNet(nn.Module):
    """Convolutional Q-network mapping one grayscale frame to ``ACTIONS`` Q-values.

    Three conv -> batch-norm -> ReLU -> 2x2 max-pool stages feed two linear
    layers. ``fc1`` takes 64 features, which matches a 1x80x80 input (the size
    produced by ``process_img``). Spatial size goes 80 -> 19 -> 9 -> 4 -> 2
    -> 3 -> 1, so 64 channels x 1 x 1 remain. Other input sizes would need a
    different ``fc1`` width.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it fixes the state_dict keys and the
        # seeded initialisation, so existing checkpoints stay loadable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, ACTIONS)

    def forward(self, x):
        """Return Q-values of shape (batch, ACTIONS) for ``x`` of shape (batch, 1, H, W)."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.max_pool2d(F.relu(norm(conv(x))), 2)
        hidden = F.relu(self.fc1(torch.flatten(x, 1)))
        return self.fc2(hidden)


In [24]:
# Q-network, its optimizer and the regression loss used by the training loop.
model = DinoNet()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# Ensure the checkpoint directory exists (training saves ./model/episode_<t>.pth there).
# exist_ok=True replaces the racy isdir()-then-makedirs() check.
os.makedirs("./model", exist_ok=True)

In [25]:
def load_model(path="./latest.pth"):
    """Load weights from ``path`` into the module-level ``model`` in place.

    Args:
        path: a file written by ``torch.save(model.state_dict(), ...)``. The
            default is ``./latest.pth``, the rolling checkpoint saved during training.

    NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; load_state_dict then copies values into model's parameters.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    
# Optionally warm-start `model` from the saved checkpoint before training.
if PRETRAINED:
    load_model()

In [26]:
def _frame_to_tensor(frame):
    """Wrap one 2-D frame as a float tensor of shape (1, 1, H, W) for DinoNet."""
    return torch.tensor(frame).float().unsqueeze(0).unsqueeze(0)


def trainNetwork(model, game_state):
    """Train ``model`` online against the live game, one transition per step.

    Each step picks an epsilon-greedy action and applies it through
    ``game_state.get_state``. It then builds the one-step TD target
    ``r + GAMMA * max_a Q(s', a)`` (just ``r`` on a crash) and takes one
    optimizer step on that single transition. There is no replay memory and
    no separate target network.

    The loop never returns on its own. It stops only when an exception
    escapes, e.g. ``StopIteration`` from the display coroutine once 'q' is pressed.

    Side effects:
        * Updates ``model`` through the module-level ``optimizer`` and ``loss_fn``.
          ``optimizer`` was built from the global ``model``, so pass that one.
        * Every 1000 steps, pauses the game and writes ``./model/episode_{t}.pth``
          and ``./latest.pth``.
        * Prints one log line per step.
    """
    epsilon = INITIAL_EPSILON
    # Linear annealing step: epsilon reaches FINAL_EPSILON after EXPLORE frames.
    epsilon_step = (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
    t = 0

    # Bootstrap the first observation with the "no jump" action (actions[1] == 0).
    s_t, _, _ = game_state.get_state(np.array([1, 0]))

    while True:
        sleep(0.03)
        s_t_tensor = _frame_to_tensor(s_t)

        # Epsilon-greedy action choice; the greedy pass needs no autograd graph.
        random_action = np.random.rand() <= epsilon
        if random_action:
            action_index = np.random.randint(ACTIONS)
        else:
            with torch.no_grad():
                action_index = model(s_t_tensor).argmax().item()
        a_t = np.zeros([ACTIONS])
        a_t[action_index] = 1

        # Anneal on every frame so the schedule really spans EXPLORE frames
        # (a `t % 1000` guard would stretch it 1000x). Clamp at the floor.
        if epsilon > FINAL_EPSILON:
            epsilon = max(FINAL_EPSILON, epsilon - epsilon_step)

        # Observe outcome.
        s_t1, r_t, terminal = game_state.get_state(a_t)

        # TD target. A crash has no next state to bootstrap from, so Q_max is
        # logged as nan rather than reusing the previous step's (or an unbound) value.
        target = r_t
        q_max = float("nan")
        if not terminal:
            with torch.no_grad():
                q_max = model(_frame_to_tensor(s_t1)).max().item()
            target = r_t + GAMMA * q_max

        # Single-transition update: only the taken action's output has non-zero error.
        q_val = model(s_t_tensor)
        target_f = q_val.detach().clone()
        target_f[0, action_index] = target
        loss = loss_fn(q_val, target_f)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Transition to new state.
        s_t = s_t1
        t += 1

        # Save progress every 1000 iterations; always resume the game afterwards.
        if t % 1000 == 0:
            game_state._game.pause()  # pause game while saving to filesystem
            try:
                torch.save(model.state_dict(), f"./model/episode_{t}.pth")
                torch.save(model.state_dict(), "./latest.pth")
            finally:
                game_state._game.resume()

        print(
            f"timestep: {t}, random: {random_action}, epsilon: {epsilon:.3f}, "
            f"action: {action_index}, reward: {r_t}, Q_max: {q_max:.3f}, loss: {loss.item():.3f}"
        )
        

In [27]:
def playGame():
    """Launch the game and train on it; always shut the browser down afterwards.

    ``StopIteration``, raised once the display coroutine ends after 'q' is
    pressed, is treated as a normal stop. Any other exception (for example the
    browser window being closed) still propagates, but only after cleanup.
    """
    game = Game()
    try:
        dino = DinoAgent(game)
        game_state = Game_state(dino, game)
        trainNetwork(model, game_state)
    except StopIteration:
        pass  # 'q' pressed in the display window: normal shutdown
    finally:
        # Best-effort cleanup: if the window is already gone, ending the
        # session may itself fail. Do not let that mask the original error.
        try:
            game.end()
        except Exception as exc:
            print(f"warning: failed to shut down the game cleanly: {exc}")

In [28]:
# Launch Chrome and train until the display window is closed with 'q' (or an error escapes).
playGame()

timestep: 1, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: 0.020999999716877937, loss: 0.011
timestep: 2, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.020999999716877937, loss: 0.0
timestep: 3, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.16099999845027924, loss: 0.005
timestep: 4, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.20800000429153442, loss: 0.0
timestep: 5, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.05999999865889549, loss: 0.015
timestep: 6, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: 0.06199999898672104, loss: 0.009
timestep: 7, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.07699999958276749, loss: 0.011
timestep: 8, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.289000004529953, loss: 0.022
timestep: 9, random: False, epsilon: 0.01, action: 1, reward: 0, Q_max: -0.3050000071525574, loss: 0.0
timestep: 10, random: False, epsilon: 0.01, action: 1, 

NoSuchWindowException: Message: no such window: target window already closed
from unknown error: web view not found
  (Session info: chrome=123.0.6312.122)
Stacktrace:
	GetHandleVerifier [0x00734CA3+225091]
	(No symbol) [0x00664DF1]
	(No symbol) [0x00509A7A]
	(No symbol) [0x004EE312]
	(No symbol) [0x0056517B]
	(No symbol) [0x005755A6]
	(No symbol) [0x0055F2F6]
	(No symbol) [0x005379B9]
	(No symbol) [0x0053879D]
	sqlite3_dbdata_init [0x00BA9A43+4064547]
	sqlite3_dbdata_init [0x00BB104A+4094762]
	sqlite3_dbdata_init [0x00BAB948+4072488]
	sqlite3_dbdata_init [0x008AC9A9+930953]
	(No symbol) [0x006707C4]
	(No symbol) [0x0066ACE8]
	(No symbol) [0x0066AE11]
	(No symbol) [0x0065CA80]
	BaseThreadInitThunk [0x75947BA9+25]
	RtlInitializeExceptionChain [0x76FABE3B+107]
	RtlClearBits [0x76FABDBF+191]
