In [1]:
import numpy as np
from PIL import Image

import cv2
import io
import time
import random
import pickle
import os
from io import BytesIO
import base64
import json
import pandas as pd
from time import sleep

from collections import deque
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.keys import Keys

from webdriver_manager.chrome import ChromeDriverManager

In [2]:
# https://github.com/Paperspace/DinoRunTutorial/blob/master/Reinforcement%20Learning%20Dino%20Run.ipynb

In [3]:
game_url = "chrome://dino"
chrome_driver_path = ChromeDriverManager().install()

loss_file_path = "./objects/loss.csv"
actions_file_path = "./objects/actions.csv"
q_value_file_path = "./objects/q_values.csv"
scores_file_path = "./objects/scores.csv"

init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'"
getbase64Script = "canvasRunner = document.getElementById('runner-canvas'); return canvasRunner.toDataURL().substring(22)"

In [4]:
def grab_screen(_driver):
    image_b64 = _driver.execute_script(getbase64Script)
    screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))
    image = process_img(screen)
    return image

def process_img(image):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = cv2.resize(image, (80, 80))
    return image

def show_img(graphs = False):
    while True:
        screen = (yield)
        window_title = "logs" if graphs else "game_play"
        cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)
        imS = cv2.resize(screen, (800, 400))
        cv2.imshow(window_title, screen)
        if (cv2.waitKey(1) & 0xFF == ord('q')):
            cv2.destroyAllWindows()
            break

In [5]:
class Game:
    def __init__(self, custom_config=True):
        chrome_options = Options()
        chrome_options.add_argument("disable-infobars")
        chrome_options.add_argument("--mute-audio")
        service = Service(chrome_driver_path)
        self._driver = webdriver.Chrome(service=service, options=chrome_options)
        self._driver.set_window_position(x=300,y=300)
        self._driver.set_window_size(900, 600)
        
        try : 
            self._driver.get(game_url)
        except:
            pass
        
        self._driver.execute_script("Runner.config.ACCELERATION=0")
        self._driver.execute_script(init_script)
        
    def get_crashed(self):
        return self._driver.execute_script("return Runner.instance_.crashed")
    def get_playing(self):
        return self._driver.execute_script("return Runner.instance_.playing")
    def restart(self):
        self._driver.execute_script("Runner.instance_.restart()")
    def press_up(self):
        self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_UP)
    def press_down(self):
        self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_DOWN)
    def get_score(self):
        score_array = self._driver.execute_script("return Runner.instance_.distanceMeter.digits")
        score = ''.join(score_array)
        return int(score)
    def pause(self):
        return self._driver.execute_script("return Runner.instance_.stop()")
    def resume(self):
        return self._driver.execute_script("return Runner.instance_.play()")
    def end(self):
        self._driver.close()

In [6]:
class DinoAgent:
    def __init__(self, game):
        self._game = game
        sleep(1)
        self.jump()
    def is_running(self):
        return self._game.get_playing()
    def is_crashed(self):
        return self._game.get_crashed()
    def jump(self):
        self._game.press_up()
    def duck(self):
        self._game.press_down()

In [7]:
class Game_state:
    def __init__(self, agent, game):
        self._agent = agent
        self._game = game
        self._display = show_img()
        self._display.__next__()
        
    def get_state(self, actions):
        score = self._game.get_score()
        reward = 1
        is_over = False
        
        if actions[1] == 1:
            self._agent.jump()
            reward = -3
        
        image = grab_screen(self._game._driver)
        self._display.send(image)
        
        if self._agent.is_crashed():
            self._game.restart()
            reward = -100
            is_over = True
        
        return image, reward, is_over

In [8]:
# Parameters
PRETRAINED = False
ACTIONS = 2
GAMMA = 0.99
OBSERVATION = 100.  # timesteps to observe before training
EXPLORE = 100000.  # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001  # final value of epsilon
INITIAL_EPSILON = 0.01  # starting value of epsilon
LEARNING_RATE = 1e-4
REPLAY_MEMORY = 50000  # number of previous transitions to remember

### Model

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models


print("cuda : ", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", device)

cuda :  True


In [10]:
class DinoNet(nn.Module):
    def __init__(self, actions):
        super(DinoNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        
        self.adv_hid = nn.Linear(64, 64)
        self.adv = nn.Linear(64, actions)
        self.val_hid = nn.Linear(64, 64)
        self.val = nn.Linear(64, 1)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        adv = F.relu(self.adv_hid(x))
        adv = self.adv(adv)
        val = F.relu(self.val_hid(x))
        val = self.val(val)
        
        q_value = val + adv - adv.mean(dim=1, keepdim=True)
        return q_value

In [11]:
model = DinoNet(actions=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# create a new model if not exist
if not os.path.isdir("./model"):
    os.makedirs("./model")

In [12]:
def load_model():
    model.load_state_dict(torch.load(f"./latest.pth"))
    
if PRETRAINED:
    load_model()

In [13]:
def trainNetwork(model, game_state, optimizer, loss_fn, num_episodes, batch_size=32):
    replay_memory = deque(maxlen=REPLAY_MEMORY)
    epsilon = INITIAL_EPSILON
    global_step = 0

    for episode in range(num_episodes):
        total_loss = 0
        total_reward = 0
        game_state._game.restart()

        # Get initial state
        state, reward, done = game_state.get_state([1, 0])  # Start with no action
        state = torch.tensor(state, device=device, dtype=torch.float).unsqueeze(0)
        
        # Reduce epsilon
        if episode % 1 == 0:
            epsilon = max(FINAL_EPSILON, epsilon - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE)

        while not game_state._agent.is_crashed():
            global_step += 1
            action = [0, 1] if random.random() <= epsilon else [1, 0]  # Random or best action based on epsilon
            next_state, reward, done = game_state.get_state(action)
            next_state = torch.tensor(next_state, device=device, dtype=torch.float).unsqueeze(0)

            # Save transition to replay memory
            replay_memory.append((state, action, reward, next_state, done))
            state = next_state
            total_reward += reward

            # Check if the memory is sufficient to sample from
            if len(replay_memory) >= batch_size:
                # Sample a minibatch from replay memory
                minibatch = random.sample(replay_memory, batch_size)
                state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)

                state_batch = torch.stack(state_batch)
                next_state_batch = torch.stack(next_state_batch)
                reward_batch = torch.tensor(reward_batch, device=device, dtype=torch.float)
                action_batch = torch.tensor([a.index(1) for a in action_batch], device=device, dtype=torch.long).unsqueeze(1)
                done_batch = torch.tensor(done_batch, device=device, dtype=torch.float)

                # Compute Q(s_t, a)
                current_q_values = model(state_batch).gather(1, action_batch)

                # Compute Q(s_t+1) for all next states.
                next_q_values = model(next_state_batch).max(1)[0]
                # Compute the target Q values
                target_q_values = reward_batch + (GAMMA * next_q_values * (1 - done_batch))

                # Compute Bellman error
                loss = loss_fn(current_q_values.squeeze(1), target_q_values.detach())

                # Optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                print(f"g_step: {global_step}, action: {np.argmax(action)}, reward: {reward}, loss: {loss.item()}")

            sleep(0.05)  # Sleep to decrease the speed of execution

        print(f"Episode {episode + 1}, Total reward: {total_reward}, Total loss: {total_loss}, Epsilon: {epsilon}")

        # Optionally save the model
        if episode % 10 == 0:
            torch.save(model.state_dict(), f"./model/dino_net_{episode}.pth")

    # Save the final model
    torch.save(model.state_dict(), "./model/dino_net_final.pth")


In [14]:
def playGame():
    game = Game()
    dino = DinoAgent(game)
    game_state = Game_state(dino, game)
    try :
        trainNetwork(model, game_state, optimizer, loss_fn, num_episodes=1000)
    except StopIteration:
        game.end()

In [15]:
playGame()

g_step: 32, action: 0, reward: 1, loss: 2.500885009765625
Episode 1, Total reward: 32, Total loss: 2.500885009765625, Epsilon: 0.009999901
g_step: 33, action: 0, reward: 1, loss: 2.023440361022949
g_step: 34, action: 0, reward: 1, loss: 1.6406183242797852
g_step: 35, action: 0, reward: 1, loss: 1.4479327201843262
g_step: 36, action: 0, reward: 1, loss: 1.2986059188842773
g_step: 37, action: 0, reward: 1, loss: 1.1987905502319336
g_step: 38, action: 0, reward: 1, loss: 1.1500740051269531
g_step: 39, action: 0, reward: 1, loss: 1.2284672260284424
g_step: 40, action: 0, reward: 1, loss: 1.1804471015930176
g_step: 41, action: 0, reward: 1, loss: 1.1831492185592651
g_step: 42, action: 0, reward: 1, loss: 1.2884538173675537
g_step: 43, action: 0, reward: 1, loss: 1.2667040824890137
g_step: 44, action: 0, reward: 1, loss: 1.2824392318725586
g_step: 45, action: 0, reward: 1, loss: 1.1499900817871094
g_step: 46, action: 0, reward: 1, loss: 0.9541349411010742
g_step: 47, action: 0, reward: 1, lo

NoSuchWindowException: Message: no such window: target window already closed
from unknown error: web view not found
  (Session info: chrome=124.0.6367.91)
Stacktrace:
	GetHandleVerifier [0x0104C113+48259]
	(No symbol) [0x00FDCA41]
	(No symbol) [0x00ED0A17]
	(No symbol) [0x00EAE02B]
	(No symbol) [0x00F3742E]
	(No symbol) [0x00F49476]
	(No symbol) [0x00F30B36]
	(No symbol) [0x00F0570D]
	(No symbol) [0x00F062CD]
	GetHandleVerifier [0x01306533+2908323]
	GetHandleVerifier [0x01343B4B+3159739]
	GetHandleVerifier [0x010E505B+674763]
	GetHandleVerifier [0x010EB21C+699788]
	(No symbol) [0x00FE6244]
	(No symbol) [0x00FE2298]
	(No symbol) [0x00FE242C]
	(No symbol) [0x00FD4BB0]
	BaseThreadInitThunk [0x75F87BA9+25]
	RtlInitializeExceptionChain [0x777FBE3B+107]
	RtlClearBits [0x777FBDBF+191]
