### 1. Setup Dependencies

In [None]:
# Report the pip and Python versions in use before installing anything.
!pip --version && python --version
# Install PyTorch (plus torchvision/torchaudio) from the `pytorch` conda channel.
%conda install -y pytorch torchvision torchaudio -c pytorch
# Install Stable-Baselines3 with its `extra` dependency set, and protobuf.
%pip install stable-baselines3[extra] protobuf
# Install the packages imported below for OCR, screen capture, input, image processing and plotting.
%pip install pytesseract mss pydirectinput opencv-python matplotlib

### 2. Import Libraries

In [None]:
from mss import mss
import pydirectinput
import cv2
import numpy as np
import pytesseract
from matplotlib import pyplot as plt
import time
from gym import Env
from gym.spaces import Box, Discrete
import os
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker
from stable_baselines3 import DQN

### 3. Define Environment

#### 3.1 Game Environment

In [137]:
class WebGame(Env):
    """Gym environment that drives an on-screen game via screen capture and key presses.

    Observations are grayscale screenshots of ``game_location`` resized to
    ``OBSERVATION_SHAPE``. An episode is considered finished when OCR on
    ``done_location`` reads one of ``DONE_STRINGS``.

    Args:
        game_location: Optional screen region (``top``/``left``/``width``/``height``)
            captured as the observation. Defaults to ``DEFAULT_GAME_LOCATION``.
        done_location: Optional screen region captured for the game-over check.
            Defaults to ``DEFAULT_DONE_LOCATION``.
    """

    # Shape of a single observation: (channels, height, width)
    OBSERVATION_SHAPE = (1, 83, 100)

    # Key pressed for each action that sends input; any other action does nothing.
    # Action key - 0 = Jump (space), 1 = Do nothing
    ACTION_KEYS = {0: 'space'}

    # Valid done text. We just take the first word, and give some room for
    # OCR failure to speed things up
    DONE_STRINGS = ('GAME', 'GAHE', 'GANE')

    # Default spaces on screen for game action and game over
    DEFAULT_GAME_LOCATION = {
        'top': 310,
        'left': 170,
        'width': 180,
        'height': 45
    }
    DEFAULT_DONE_LOCATION = {
        'top': 240,
        'left': 260,
        'width': 250,
        'height': 35
    }

    # Setup the environment action and observation shapes
    def __init__(self, game_location=None, done_location=None):
        super().__init__()

        # Setup spaces
        self.observation_space = Box(low=0, high=255, shape=self.OBSERVATION_SHAPE, dtype=np.uint8)
        self.action_space = Discrete(2) # We have two actions in total

        # Define extraction parameters for the game
        self.cap = mss()

        # Copy the regions so later changes by the caller (or to the class
        # defaults) do not silently alter this instance.
        if game_location is None:
            game_location = self.DEFAULT_GAME_LOCATION
        if done_location is None:
            done_location = self.DEFAULT_DONE_LOCATION
        self.game_location = dict(game_location)
        self.done_location = dict(done_location)

    # Grab a screen region and keep only its first three channels
    def _grab(self, location):
        return np.array(self.cap.grab(location))[:,:,:3]

    # Called to do something in the game
    def step(self, action):
        """Apply ``action``, then return ``(observation, reward, done, info)``."""
        # int() lets NumPy integer scalars be used as the lookup key
        key = self.ACTION_KEYS.get(int(action))
        if key is not None:
            pydirectinput.press(key)

        # The capture returned alongside the flag is not needed here
        done, _ = self.get_done()
        new_observation = self.get_observation()

        # Reward for every frame we are alive
        reward = 1

        info = {}

        return new_observation, reward, done, info

    def render(self, mode='human'):
        """Show the current game capture in an OpenCV window; pressing 'q' closes it.

        ``mode`` is accepted (and ignored) so that callers passing
        ``render(mode=...)`` do not fail with a ``TypeError``.
        """
        winname = "Game"
        cv2.namedWindow(winname)
        cv2.moveWindow(winname, 1500, 800)
        cv2.imshow(winname, self._grab(self.game_location))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    # Reset the game
    def reset(self):
        """Restart the game and return the first observation."""
        time.sleep(1)
        # Click at a fixed screen position, then press space
        # (presumably to focus the game window and start a new run; verify for your layout)
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        return self.get_observation()

    # Close observation
    def close(self):
        """Destroy all OpenCV windows."""
        cv2.destroyAllWindows()

    # Get the part of the observation of the game that we want
    def get_observation(self):
        """Capture ``game_location`` and return it as a ``OBSERVATION_SHAPE`` grayscale array."""
        # TODO: Adjust screen cap and resize params
        # Get screen capture of game
        raw = self._grab(self.game_location)
        # Grayscale
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # Resize (cv2.resize takes the target size as (width, height))
        _, height, width = self.OBSERVATION_SHAPE
        resized = cv2.resize(gray, (width, height))
        # Add channels
        channel = np.reshape(resized, self.OBSERVATION_SHAPE)
        return channel

    # Get the game over text
    def get_done(self):
        """Return ``(done, cap)``: the game-over flag and the capture that was read."""
        # Get game over screen capture
        cap = self._grab(self.done_location)
        # Apply OCR and compare the first four characters against the accepted readings
        res = pytesseract.image_to_string(cap)[:4]
        done = res in self.DONE_STRINGS

        return done, cap

#### 3.2 Logging Callback

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` callback calls.

    Args:
        check_freq: Save whenever ``n_calls`` is a multiple of this value.
        save_path: Directory the checkpoints are written to. When ``None``,
            no directory is created and no checkpoint is saved.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Make sure the checkpoint directory exists before the first save
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        # `n_calls` and `model` are not defined in this class; they come from BaseCallback.
        # Skip saving when no save_path was given, matching the guard in
        # _init_callback (os.path.join(None, ...) would raise a TypeError).
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        # Always returns True
        return True

### Show observation and done screen cap

In [None]:
env = WebGame()
# get_observation() returns a (1, 83, 100) grayscale array; index 0 picks the 2-D image.
# Use a gray colormap so the frame is not drawn in matplotlib's default
# false-color map; the trailing ';' suppresses the AxesImage repr.
plt.imshow(env.get_observation()[0], cmap='gray');

In [None]:
env = WebGame()
done, cap = env.get_done()
print(done)
# The capture is in BGR channel order (the environment converts such captures
# with COLOR_BGR2GRAY), while imshow expects RGB: reverse the channel axis so
# colors are displayed correctly. The trailing ';' suppresses the AxesImage repr.
plt.imshow(cap[:, :, ::-1]);

### Environment Test run

In [None]:
env = WebGame()

# Run the Stable-Baselines3 environment checks before using the env
env_checker.check_env(env)

# Play a handful of episodes with randomly sampled actions
for episode in range(4):
    obs, done, total_reward = env.reset(), False, 0

    # Keep stepping until the environment reports the episode is over
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
        total_reward = total_reward + reward
    print(f"Total reward for episode {episode} is {total_reward}")

In [156]:
# Directory passed to TrainAndLoggingCallback as `save_path` (model checkpoints are written here).
CHECKPOINT_DIR = './train_final/'
# Directory passed to DQN as `tensorboard_log` when a new model is created.
LOG_DIR = './logs_final/'

#### Train existing model (load)

In [None]:
# Continue training from a previously saved checkpoint
env = WebGame()
callback = TrainAndLoggingCallback(
    save_path=CHECKPOINT_DIR,
    check_freq=1000,
)
model = DQN.load(
    "./out/best_model_30000.zip",
    env=env,
)
# reset_num_timesteps=False is passed so the timestep counter is not reset
model.learn(
    total_timesteps=100_000,
    callback=callback,
    reset_num_timesteps=False,
)

#### Train new model (create)

In [145]:
# Create a fresh DQN agent and train it from scratch
env = WebGame()
callback = TrainAndLoggingCallback(
    save_path=CHECKPOINT_DIR,
    check_freq=1000,
)
model = DQN(
    'CnnPolicy',
    env,
    verbose=1,
    tensorboard_log=LOG_DIR,
    learning_starts=1000,
    buffer_size=70_000,
)
model.learn(
    callback=callback,
    total_timesteps=120_000,
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




Logging to ./logs_final/DQN_1


KeyboardInterrupt: 

### Test Model

In [165]:
# Fresh environment plus the saved checkpoint that will be evaluated
env = WebGame()
model = DQN.load(
    "./out/best_model_49000.zip",
    print_system_info=False,
    env=env,
)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [171]:
# Number of episodes the loaded model plays
episodes = 10

try:
    for ep in range(episodes):
        obs = env.reset()
        done = False
        while not done:
            # Evaluate with greedy actions rather than exploratory ones.
            # The second return value is named explicitly so IPython's `_`
            # (last cell output) is not overwritten.
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(int(action))
            env.render()
finally:
    # Always close the render window, even when the loop is interrupted
    # (e.g. by a KeyboardInterrupt from stopping the cell).
    env.close()
    

KeyboardInterrupt: 

In [172]:
# Manual cleanup: close the environment (this destroys the OpenCV windows),
# e.g. after the test loop above was interrupted before reaching its own env.close().
env.close()