# Reinforcement Learning on Chrome's offline Dino
___

In [7]:
import os
import pandas as pd
import numpy as np
import json
import random
import pickle
import enum
import time
from tqdm import tqdm
from collections import deque

# image processing
import cv2
import base64
from io import BytesIO
%matplotlib inline
from matplotlib import pyplot as plt
from PIL import Image

# selenium and drivers
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager

# ML (keras + tensorflow)
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.optimizers import Adam
from keras.callbacks import TensorBoard

import tensorflow as tf
from ModifiedTensorBoard import ModifiedTensorBoard
from helpers import load_pkl, save_pkl

## Notebook variables

In [8]:
# --- DQN hyper-parameters ---------------------------------------------------
# Discount factor applied to the best future Q-value when building targets.
DISCOUNT = 0.99
# Capacity of the replay buffer (most recent transitions are kept).
REPLAY_MEMORY_SIZE = 50_000
# Training is skipped until the replay buffer holds at least this many transitions.
MIN_REPLAY_MEMORY_SIZE = 1_000
# Number of transitions sampled from the replay buffer per training call.
MINIBATCH_SIZE = 32
# Terminal states (episode ends) counted by train() between target-network weight copies.
UPDATE_TARGET_EVERY = 5
# Prefix used for log directories and for every file saved under ./models, ./stats, ./store.
MODEL_NAME = '2x256'
# Reward threshold for saving a model; only referenced by commented-out code in the training cell.
MIN_REWARD = -200
# Not referenced by any code in the visible cells.
MEMORY_FRACTION = 0.20

# --- Environment settings ---------------------------------------------------
# Upper bound (inclusive) of the episode counter in the training loop.
EPISODES = 10000

# --- Exploration settings ---------------------------------------------------
# Current exploration rate; this one is a variable and is decayed after each episode.
epsilon = 1
# Multiplicative decay applied to epsilon once per episode.
EPSILON_DECAY = 0.999
# Floor below which epsilon is never decayed.
MIN_EPSILON = 0.001

# --- Stats settings ---------------------------------------------------------
# Aggregate statistics and write checkpoints every this many episodes.
AGGREGATE_STATS_EVERY = 20
# Not referenced by any code in the visible cells.
SHOW_PREVIEW = False

## Implementation

In [9]:
class Game:
    """Selenium wrapper around Chrome's built-in dino game (``chrome://dino``).

    Opens a Chrome window on the game page and exposes small helpers that
    read or drive the page's JavaScript ``Runner.instance_`` object.
    """

    def __init__(self):
        """Start Chrome, load the game page and prepare its canvas for screenshots."""
        options = Options()
        options.add_argument("disable-infobars")
        options.add_argument("--mute-audio")
        # The options must be handed to the driver; otherwise the flags above
        # are built but never applied to the browser.
        self.driver = webdriver.Chrome(ChromeDriverManager().install(), options=options)
        self.driver.get("chrome://dino")
        # Presumably keeps the game speed constant over time -- TODO confirm.
        self.driver.execute_script("Runner.config.ACCELERATION=0")
        # Give the canvas an id so it can later be fetched with getElementById.
        self.driver.execute_script("document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'")

    def __exec(self, command):
        """Run ``Runner.instance_.<command>`` in the page, discarding the result."""
        self.driver.execute_script("Runner.instance_.{}".format(command))

    def __get_val(self, value):
        """Return the value of ``Runner.instance_.<value>`` from the page."""
        return self.driver.execute_script("return Runner.instance_.{}".format(value))

    def press_key(self, key):
        """Send a single key press to the page body."""
        self.driver.find_element_by_tag_name("body").send_keys(key)

    def pause(self):
        """Call the runner's ``stop()``."""
        self.__exec("stop()")

    def resume(self):
        """Call the runner's ``play()``."""
        self.__exec("play()")

    def restart(self):
        """Call the runner's ``restart()``."""
        self.__exec("restart()")

    def get_crashed(self):
        """Return the runner's ``crashed`` flag."""
        return self.__get_val("crashed")

    def get_playing(self):
        """Return the runner's ``playing`` flag."""
        return self.__get_val("playing")

    def get_score(self):
        """Return the current score, built by joining the distance meter's digits."""
        return int(''.join(self.__get_val("distanceMeter.digits")))

    def get_driver(self):
        """Return the underlying Selenium driver."""
        return self.driver

    def end_game(self):
        """Shut the browser session down.

        ``quit()`` is used instead of ``close()`` so the chromedriver process
        is terminated as well, not just the current window.
        """
        self.driver.quit()

In [10]:
class Player:
    """Translates discrete action indices into key presses on a game object."""

    def __init__(self, game):
        """Keep a reference to the game whose keys will be pressed."""
        self._game = game

    def do_action(self, choice):
        """Perform the action identified by *choice*.

        1 presses the up arrow (jump), 2 presses the down arrow (duck);
        0 (walk) and any other value press nothing.
        """
        if choice == 1:
            self._game.press_key(Keys.ARROW_UP)
        elif choice == 2:
            self._game.press_key(Keys.ARROW_DOWN)

    def is_running(self):
        """Return the game's "playing" flag."""
        playing = self._game.get_playing()
        return playing

    def is_crashed(self):
        """Return the game's "crashed" flag."""
        crashed = self._game.get_crashed()
        return crashed


    
class GameEnv:
    """Environment that turns the browser game into (state, reward, done) steps.

    A state is a stack of the four most recent grayscale screenshots, newest
    first along the last axis, with shape ``OBSERVATION_SPACE_SIZE``.
    """

    MOVE_REWARD = 1          # reward for every step survived
    CRASH_PENALTY = -10      # reward for the step on which the player crashed
    IMAGE_SIZE = (84, 84)    # size each screenshot is resized to
    OBSERVATION_SPACE_SIZE = (*IMAGE_SIZE, 4)
    ACTION_SPACE_SIZE = 2    # only actions 0 (walk) and 1 (jump) are sampled

    def __init__(self, game, player):
        """Store the collaborators and build the first frame stack."""
        self._game = game
        self._player = player

        self.current_state = self.get_initial_state()

    def get_screenshot(self):
        """Grab the game canvas and return it as a grayscale ``IMAGE_SIZE`` array."""
        getbase64Script = "canvasRunner = document.getElementById('runner-canvas');\
        return canvasRunner.toDataURL().substring(22)"

        # take screenshot (substring(22) strips the data-URL prefix, leaving base64)
        image_b64 = self._game.driver.execute_script(getbase64Script)
        screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))

        # color, crop and resize
        image = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        image = image[:300, :600]
        image = cv2.resize(image, self.IMAGE_SIZE)

        return image

    def get_initial_state(self):
        """Reset the frame stack to four copies of a fresh screenshot and return it.

        The stack is also stored on ``self.current_state`` so that the next
        ``step()`` builds on it rather than on frames left over from a
        previous episode.
        """
        observation = self.get_screenshot()
        stack_depth = self.OBSERVATION_SPACE_SIZE[-1]
        self.current_state = np.stack((observation,) * stack_depth, axis=2)
        return self.current_state

    def restart_game(self):
        """Restart the underlying game."""
        self._game.restart()

    def step(self, action):
        """Perform *action* and return ``(new_state, reward, done)``.

        The new screenshot is pushed to the front of the frame stack and the
        oldest frame is dropped. ``done`` is True, and the reward is
        ``CRASH_PENALTY``, when the player has crashed.
        """
        self._player.do_action(action)

        reward = self.MOVE_REWARD
        done = False

        new_observation = self.get_screenshot()  # (84, 84)
        new_state = np.concatenate(
            (new_observation[:, :, np.newaxis], self.current_state[:, :, :-1]),
            axis=2,
        )
        self.current_state = new_state

        if self._player.is_crashed():
            reward = self.CRASH_PENALTY
            done = True

        return new_state, reward, done
    
class DQNAgent:
    """Deep Q-learning agent with a main network, a target network and a replay buffer."""

    def __init__(self, env):
        """Build both networks for *env* and initialise the training bookkeeping.

        Args:
            env: environment exposing ``OBSERVATION_SPACE_SIZE`` and
                ``ACTION_SPACE_SIZE``.
        """
        # Must be set first: create_model() reads the sizes from self.env.
        self.env = env

        # Main model
        self.model = self.create_model()

        # Target network, starting with the same weights as the main model
        self.target_model = self.create_model()
        self.target_model.set_weights(self.model.get_weights())

        # An array with last n steps for training
        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)

        # Custom tensorboard object
        self.tensorboard = ModifiedTensorBoard(log_dir="logs/{}-{}".format(MODEL_NAME, int(time.time())))

        # Used to count when to update target network with main network's weights
        self.target_update_counter = 0

    def create_model(self):
        """Return a compiled CNN mapping a frame stack to one Q-value per action."""
        model = Sequential()

        model.add(Conv2D(32, (8, 8), padding='same', strides=(4, 4),
                         input_shape=self.env.OBSERVATION_SPACE_SIZE))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Activation('relu'))

        model.add(Conv2D(64, (4, 4), strides=(2, 2), padding='same'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Activation('relu'))

        model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Activation('relu'))

        model.add(Flatten())
        model.add(Dense(64))
        model.add(Activation('relu'))

        # Linear output layer: one Q-value per action
        model.add(Dense(self.env.ACTION_SPACE_SIZE))
        # NOTE(review): lr=1e-2 looks high for Adam here and may explain the
        # constant Q-values seen in the inspection cells -- TODO confirm.
        model.compile(loss='mse', optimizer=Adam(lr=1e-2))

        return model

    def update_replay_memory(self, transition):
        """Append a ``(state, action, reward, new_state, done)`` transition to the buffer."""
        self.replay_memory.append(transition)

    # Trains main network every step during episode
    def train(self, terminal_state, step):
        """Fit the main network on one random minibatch from the replay buffer.

        Args:
            terminal_state: True when the step just taken ended the episode;
                used to count episodes towards the next target-network sync.
            step: step index within the episode (currently unused).
        """
        # Start training only if certain number of samples is already saved
        if len(self.replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        # Get a minibatch of random samples from memory replay table
        minibatch = random.sample(self.replay_memory, MINIBATCH_SIZE)

        # Current states -> Q-values from the main network (pixels scaled to [0, 1])
        current_states = np.array([transition[0] for transition in minibatch])
        current_qs_list = self.model.predict(current_states/255)

        # Next states -> Q-values from the target network
        new_current_states = np.array([transition[3] for transition in minibatch])
        future_qs_list = self.target_model.predict(new_current_states/255)

        # Overwrite the Q-value of the action actually taken with its target;
        # the other actions keep the network's own prediction.
        for index, (_, action, reward, _, done) in enumerate(minibatch):
            if not done:
                new_q = reward + DISCOUNT * np.max(future_qs_list[index])
            else:
                # Terminal transition: there is no future reward to bootstrap from
                new_q = reward
            current_qs_list[index][action] = new_q

        # Fit on all samples as one batch
        self.model.fit(current_states/255, current_qs_list, batch_size=MINIBATCH_SIZE,
                       verbose=0, shuffle=False, callbacks=None)

        # Update target network counter every episode
        if terminal_state:
            self.target_update_counter += 1

        # If counter reaches set value, update target network with weights of main network
        if self.target_update_counter >= UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0

    def get_qs(self, state):
        """Return the main network's Q-values for a single *state*."""
        return self.model.predict(np.array(state).reshape(-1, *state.shape)/255)[0]


In [15]:
# Debug cell: reproduce the first half of DQNAgent.train() by hand to inspect
# the Q-values both networks produce for a random minibatch.
minibatch = random.sample(agent.replay_memory, MINIBATCH_SIZE)

# Get current states from minibatch, then query the main network for Q values
current_states = np.array([transition[0] for transition in minibatch])
current_qs_list = agent.model.predict(current_states/255)  # was `agent.amodel` (AttributeError)

# Get future states from minibatch, then query the target network for Q values
new_current_states = np.array([transition[3] for transition in minibatch])
future_qs_list = agent.target_model.predict(new_current_states/255)

In [44]:
# Element 4 of a stored transition is its `done` flag.
agent.replay_memory[1000][4]

False

In [46]:
# Target-network Q-values for the next-state of the first sampled transition.
future_qs_list[0]

array([1., 1.], dtype=float32)

In [33]:
# Target-network Q-values for every next-state in the sampled minibatch.
future_qs_list

array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

In [11]:
# Wire up the browser game, the player, the environment and the agent.
# Order matters: each object is handed the ones created before it.
game = Game()  # starts Chrome and loads chrome://dino
player = Player(game)
env = GameEnv(game, player)  # takes a first screenshot to build its initial state
agent = DQNAgent(env)

# When True, the training cell loads stats, parameters, weights and replay
# memory from ./stats and ./store instead of starting from scratch.
CONTINUE_FROM_CHECKPOINT = False


Looking for [chromedriver 81.0.4044.69 mac64] driver in cache 
File found in cache by path [/Users/halvorreiten/.wdm/drivers/chromedriver/81.0.4044.69/mac64/chromedriver]


In [12]:
# Start with empty statistics unless a resume from checkpoint was requested.
if not CONTINUE_FROM_CHECKPOINT:
    ep_stats = pd.DataFrame(
        [],
        columns=[
            "episode",
            "min_reward", "max_reward", "avg_reward",
            "min_score", "max_score", "avg_score",
            "epsilon",
        ],
    )
    ep_rewards = np.array([])
    ep_scores = np.array([])
    highscore = -200
    epsilon = 1
    episode = 0
else:
    # Statistics written by a previous run
    ep_stats = pd.read_csv(f'./stats/{MODEL_NAME}-ep_stats.csv')
    ep_rewards = np.load(f'./stats/{MODEL_NAME}-ep_rewards.npy')
    ep_scores = np.load(f'./stats/{MODEL_NAME}-ep_scores.npy')
    highscore = max(ep_scores)

    # Exploration rate and episode counter at the time of the last checkpoint
    # NOTE(review): load_pkl presumably unpickles -- only load files you trust.
    notebook_checkpoint = load_pkl(f'./store/{MODEL_NAME}-notebook_checkpoint.pkl')
    epsilon = notebook_checkpoint['epsilon']
    episode = notebook_checkpoint['episode']

    # Both networks restart from the most recently saved weights
    agent.model.load_weights(f'./store/{MODEL_NAME}-latest_weights.h5')
    agent.target_model.load_weights(f'./store/{MODEL_NAME}-latest_weights.h5')

    # Restore the replay buffer as it was at the checkpoint
    agent.replay_memory = load_pkl(f'./store/{MODEL_NAME}-replay_memory.pkl')

# Seed every RNG in use for more reproducible results.
# NOTE(review): the agent's networks are created in an earlier cell, so these
# seeds presumably do not influence their initial weights -- TODO confirm.
random.seed(1)
np.random.seed(1)
tf.random.set_seed(1)

# Create every folder the training loop writes to: models (highscore models),
# stats (csv/npy statistics) and store (checkpoint, replay memory, weights).
# Without stats/ and store/ the first save fails on a fresh checkout.
for _output_dir in ('models', 'stats', 'store'):
    os.makedirs(_output_dir, exist_ok=True)
    
# Iterate over episodes (continues after `episode` when resuming from a checkpoint)
for ep in tqdm(range(episode+1, EPISODES+1), ascii=True, unit='episodes'):

    # Update tensorboard step every episode
    agent.tensorboard.step = ep

    # Restarting episode - reset episode reward and step number
    episode_reward = 0
    step = 0

    # Get a fresh frame stack for the new episode. The environment keeps its
    # own copy of the stack for step(); sync it explicitly so the first
    # transition is not built from frames of the previous (crashed) episode.
    current_state = env.get_initial_state()
    env.current_state = current_state

    # Reset flag and start iterating until episode ends
    done = False
    while not done:

        # Epsilon-greedy action selection
        if np.random.random() > epsilon:
            # Exploit: best action according to the main network
            action = np.argmax(agent.get_qs(current_state))
        else:
            # Explore: uniformly random action
            action = np.random.randint(0, env.ACTION_SPACE_SIZE)

        new_state, reward, done = env.step(action)

        # Accumulate the episode's total reward
        episode_reward += reward

        # Every step we update replay memory and train main network
        agent.update_replay_memory((current_state, action, reward, new_state, done))
        agent.train(done, step)

        current_state = new_state
        step += 1

    episode_score = game.get_score()

    # Save model if it is a new highscore
    if (episode_score > highscore):
        print(f'New highscore: {episode_score}')
        agent.model.save(f'./models/{MODEL_NAME}-highscore-{episode_score}.model')
        highscore = episode_score

    # Append episode reward to a list and log stats (every given number of episodes)
    ep_rewards = np.append(ep_rewards, episode_reward)
    ep_scores = np.append(ep_scores, episode_score)

    if not ep % AGGREGATE_STATS_EVERY or ep == 1:
        # Aggregate over the most recent AGGREGATE_STATS_EVERY episodes
        average_reward = ep_rewards[-AGGREGATE_STATS_EVERY:].mean()
        min_reward = ep_rewards[-AGGREGATE_STATS_EVERY:].min()
        max_reward = ep_rewards[-AGGREGATE_STATS_EVERY:].max()

        avg_score = ep_scores[-AGGREGATE_STATS_EVERY:].mean()
        min_score = ep_scores[-AGGREGATE_STATS_EVERY:].min()
        max_score = ep_scores[-AGGREGATE_STATS_EVERY:].max()

        ep_stats.loc[len(ep_stats)] = [ep, min_reward, max_reward, average_reward,min_score,max_score,avg_score,epsilon]

        # Save stats
        ep_stats.to_csv(f'./stats/{MODEL_NAME}-ep_stats.csv', index=False)
        np.save(f'./stats/{MODEL_NAME}-ep_rewards.npy', ep_rewards)
        np.save(f'./stats/{MODEL_NAME}-ep_scores.npy', ep_scores)

        # Save notebook checkpoint
        notebook_params = {'episode': ep, 'epsilon': epsilon}
        save_pkl(f'./store/{MODEL_NAME}-notebook_checkpoint.pkl', notebook_params)

        # Save replay memory
        save_pkl(f'./store/{MODEL_NAME}-replay_memory.pkl', agent.replay_memory)

        # Save model weights
        agent.model.save_weights(f'./store/{MODEL_NAME}-latest_weights.h5')

        # Save model, but only when min reward is greater or equal a set value
        #if min_reward >= MIN_REWARD:
        #    agent.model.save(f'./models/{MODEL_NAME}__{max_reward:_>7.2f}max_{average_reward:_>7.2f}avg_{min_reward:_>7.2f}min__{int(time.time())}.model')

    # Decay epsilon once per episode, never below MIN_EPSILON
    if epsilon > MIN_EPSILON:
        epsilon *= EPSILON_DECAY
        epsilon = max(MIN_EPSILON, epsilon)

    env.restart_game()

  0%|          | 0/10000 [00:00<?, ?episodes/s]

New highscore: 50
Saving stats


  0%|          | 2/10000 [00:11<16:19:57,  5.88s/episodes]

New highscore: 85


  0%|          | 19/10000 [01:50<16:53:51,  6.09s/episodes]

Saving stats


  0%|          | 38/10000 [03:48<15:59:07,  5.78s/episodes]

New highscore: 92


  0%|          | 39/10000 [03:58<19:46:29,  7.15s/episodes]

Saving stats


  1%|          | 59/10000 [32:40<171:17:05, 62.03s/episodes]  

Saving stats


  1%|          | 79/10000 [34:46<17:51:00,  6.48s/episodes] 

Saving stats


  1%|          | 99/10000 [36:39<14:51:35,  5.40s/episodes]

Saving stats


  1%|1         | 119/10000 [38:28<15:34:08,  5.67s/episodes]

Saving stats


  1%|1         | 139/10000 [40:19<14:26:32,  5.27s/episodes]

Saving stats


  2%|1         | 159/10000 [42:16<15:25:57,  5.65s/episodes]

Saving stats


  2%|1         | 179/10000 [1:14:08<16:41:53,  6.12s/episodes]   

Saving stats


  2%|1         | 199/10000 [1:16:03<14:34:42,  5.35s/episodes]

Saving stats


  2%|2         | 219/10000 [1:17:57<17:17:57,  6.37s/episodes]

Saving stats


  2%|2         | 239/10000 [1:19:56<16:53:56,  6.23s/episodes]

Saving stats


  3%|2         | 259/10000 [1:21:43<13:37:20,  5.03s/episodes]

Saving stats


  3%|2         | 279/10000 [1:23:36<14:15:48,  5.28s/episodes]

Saving stats


  3%|2         | 299/10000 [1:25:42<18:25:50,  6.84s/episodes]

Saving stats


  3%|3         | 319/10000 [1:27:36<14:55:40,  5.55s/episodes]

Saving stats


  3%|3         | 331/10000 [1:28:45<14:44:29,  5.49s/episodes]

New highscore: 102


  3%|3         | 339/10000 [1:29:34<14:39:52,  5.46s/episodes]

Saving stats


  4%|3         | 359/10000 [1:31:29<14:00:53,  5.23s/episodes]

Saving stats


  4%|3         | 379/10000 [1:33:26<14:40:46,  5.49s/episodes]

Saving stats


  4%|3         | 399/10000 [1:35:18<13:33:26,  5.08s/episodes]

Saving stats


  4%|4         | 419/10000 [1:37:13<14:40:11,  5.51s/episodes]

Saving stats


  4%|4         | 439/10000 [1:39:06<15:37:33,  5.88s/episodes]

Saving stats


  5%|4         | 459/10000 [1:40:58<14:15:45,  5.38s/episodes]

Saving stats


  5%|4         | 479/10000 [1:42:56<14:59:16,  5.67s/episodes]

Saving stats


  5%|4         | 499/10000 [1:44:46<13:37:55,  5.17s/episodes]

Saving stats


  5%|5         | 519/10000 [1:46:42<14:20:46,  5.45s/episodes]

Saving stats


  5%|5         | 539/10000 [1:48:41<16:38:39,  6.33s/episodes]

Saving stats


  6%|5         | 559/10000 [1:50:36<14:09:53,  5.40s/episodes]

Saving stats


  6%|5         | 579/10000 [1:52:34<16:42:23,  6.38s/episodes]

Saving stats


  6%|5         | 599/10000 [1:54:24<12:42:18,  4.87s/episodes]

Saving stats


  6%|6         | 619/10000 [1:56:14<13:04:43,  5.02s/episodes]

Saving stats


  6%|6         | 639/10000 [1:58:03<14:44:49,  5.67s/episodes]

Saving stats


  7%|6         | 659/10000 [1:59:49<13:28:49,  5.20s/episodes]

Saving stats


  7%|6         | 679/10000 [2:01:43<14:04:05,  5.43s/episodes]

Saving stats


  7%|6         | 699/10000 [2:03:25<13:05:31,  5.07s/episodes]

Saving stats


  7%|7         | 719/10000 [2:05:18<13:41:58,  5.31s/episodes]

Saving stats


  7%|7         | 739/10000 [2:07:02<14:11:29,  5.52s/episodes]

Saving stats


  8%|7         | 759/10000 [2:08:54<14:26:30,  5.63s/episodes]

Saving stats


  8%|7         | 779/10000 [2:10:42<13:14:44,  5.17s/episodes]

Saving stats


  8%|7         | 799/10000 [2:12:28<12:43:24,  4.98s/episodes]

Saving stats


  8%|8         | 808/10000 [2:13:28<14:59:26,  5.87s/episodes]

New highscore: 106


  8%|8         | 819/10000 [2:14:34<13:53:33,  5.45s/episodes]

Saving stats


  8%|8         | 839/10000 [2:16:24<12:38:18,  4.97s/episodes]

Saving stats


  9%|8         | 859/10000 [2:18:18<14:28:46,  5.70s/episodes]

Saving stats


  9%|8         | 879/10000 [2:20:11<12:56:07,  5.11s/episodes]

Saving stats


  9%|8         | 899/10000 [2:22:01<13:46:18,  5.45s/episodes]

Saving stats


  9%|9         | 919/10000 [2:23:54<13:49:42,  5.48s/episodes]

Saving stats


  9%|9         | 939/10000 [2:25:42<13:42:46,  5.45s/episodes]

Saving stats


 10%|9         | 959/10000 [2:27:28<12:36:41,  5.02s/episodes]

Saving stats


 10%|9         | 979/10000 [2:29:23<14:27:18,  5.77s/episodes]

Saving stats


 10%|9         | 999/10000 [2:31:14<12:33:25,  5.02s/episodes]

Saving stats


 10%|#         | 1019/10000 [2:33:08<13:32:23,  5.43s/episodes]

Saving stats


 10%|#         | 1039/10000 [2:34:57<12:07:36,  4.87s/episodes]

Saving stats


 11%|#         | 1059/10000 [2:36:45<15:45:17,  6.34s/episodes]

Saving stats


 11%|#         | 1079/10000 [2:38:30<12:07:20,  4.89s/episodes]

Saving stats


 11%|#         | 1099/10000 [2:40:20<15:47:48,  6.39s/episodes]

Saving stats


 11%|#1        | 1119/10000 [2:42:08<12:25:03,  5.03s/episodes]

Saving stats


 11%|#1        | 1139/10000 [2:43:48<12:36:00,  5.12s/episodes]

Saving stats


 11%|#1        | 1140/10000 [2:43:59<21:14:28,  8.63s/episodes]


KeyboardInterrupt: 