### 1. Deps

In [None]:
!pipenv install torch torchvision tensorboard torchaudio git+https://github.com/DLR-RM/stable-baselines3 protobuf==3.20.* mss pydirectinput pytesseract opencv-python gym shimmy

In [160]:
# mss used for screen capture
from mss import mss
# sending commands to the game
import pydirectinput
# opencv allows frame processing
import cv2
# array processing for captured frames
import numpy as np
# OCR for game over extraction
import pytesseract
# visualize captured frames
from matplotlib import pyplot as plt
# pauses
import time
# environment components
from gym import Env
from gym.spaces import Box, Discrete

In [161]:
# Raw string so the Windows backslashes are taken literally instead of relying
# on invalid escape sequences (\P, \T) being passed through unchanged.
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'  # change to your tesseract path

### 2. Building the environment

#### 2.1 Create environment

In [279]:
class WebGame(Env):
    """Gym environment that plays an on-screen game through screen capture.

    Observations are grayscale captures of ``game_location`` resized to
    150x100 and returned channels-first as a ``(1, 100, 150)`` ``uint8`` array.
    Actions are delivered to the game as key presses through ``pydirectinput``.

    Every screen region below is a hard-coded absolute position, so the
    values must be adjusted to match where the game is shown on screen.
    """

    # Name used in ``action_map`` for "do nothing this step".
    NO_OP = "no_op"
    # Reference colours compared against single captured pixels. All three
    # channels are equal, so the capture's channel order does not matter.
    DAY_BACKGROUND_COLOR = (255, 255, 255)
    DONE_COLOR_DAY = (83, 83, 83)
    DONE_COLOR_NIGHT = (172, 172, 172)

    def __init__(self):
        """Set up the action/observation spaces and the capture regions."""
        super().__init__()

        # setup spaces
        self.observation_space = Box(
            low=0, high=255, shape=(1, 100, 150), dtype=np.uint8
        )
        self.action_space = Discrete(2)

        # define extraction parameters for the game
        self.cap = mss()
        self.game_location = {"top": 200, "left": 120, "width": 750, "height": 500}
        self.done_location = {"top": 220, "left": 630, "width": 650, "height": 70}
        self.done_pixel_location = {"top": 254, "left": 678, "width": 5, "height": 2}
        self.day_or_night_observation_space = {"top": 150, "left": 150, "width": 10, "height": 10}
        self.action_map = {0: "space", 1: self.NO_OP}

    def step(self, action):
        """Apply ``action`` to the game and report the resulting state.

        Args:
            action: index into ``action_map``. Converted with ``int`` so numpy
                integer scalars and 0-d arrays are accepted as well.

        Returns:
            ``(observation, reward, done, info)`` where ``reward`` is always 1
            (one point per step survived) and ``info`` is an empty dict.

        Raises:
            KeyError: if ``action`` is not a key of ``action_map``.
        """
        key = self.action_map[int(action)]
        if key != self.NO_OP:
            pydirectinput.press(key)

        # Game-over is checked before the observation is captured.
        done = self.get_done()
        new_observation = self.get_observation()
        # reward - we get a point for every frame we're alive
        reward = 1
        info = {}

        return new_observation, reward, done, info

    def render(self, mode="human"):
        """Show the raw game capture in an OpenCV window; press ``q`` to close.

        Args:
            mode: ignored. Accepted only so callers using the classic gym
                ``render(mode=...)`` signature do not fail.
        """
        cv2.imshow('Game', np.array(self.cap.grab(self.game_location))[:, :, :3])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def close(self):
        """Destroy any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

    def reset(self):
        """Restart the game and return the first observation.

        Waits one second, clicks at screen position (150, 150), then presses
        space before capturing the observation.
        """
        time.sleep(1)
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        return self.get_observation()

    def get_observation(self):
        """Capture the game region as a ``(1, 100, 150)`` grayscale array."""
        # Drop the 4th channel of the capture, keeping the first three.
        raw = np.array(self.cap.grab(self.game_location))[:, :, :3]
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes (width, height), giving a (100, 150) array.
        resized = cv2.resize(gray, (150, 100))
        # add the channel axis first
        return np.reshape(resized, (1, 100, 150))

    def _first_pixel(self, region):
        """Return the first three channels of the top-left pixel of ``region``."""
        return np.array(self.cap.grab(region))[0, 0, :3]

    def get_done(self):
        """Return ``True`` when the game-over pixel shows its expected colour.

        The day/night region is sampled first: a pure white pixel there means
        day mode. The done pixel is then compared against the day or night
        reference colour accordingly.
        """
        is_day = np.array_equal(
            self._first_pixel(self.day_or_night_observation_space),
            self.DAY_BACKGROUND_COLOR,
        )
        expected = self.DONE_COLOR_DAY if is_day else self.DONE_COLOR_NIGHT
        return np.array_equal(self._first_pixel(self.done_pixel_location), expected)

In [280]:
# instantiate the environment (sets up the spaces and the screen-capture handle)
env = WebGame()

In [None]:
# one-off check of the game-over detection against the current screen
done = env.get_done()
print(done)

In [None]:
# The observation's single channel is a 2-D grayscale image, so it has to be
# expanded with GRAY2RGB; BGR2RGB requires a 3- or 4-channel input.
plt.imshow(cv2.cvtColor(env.get_observation()[0], cv2.COLOR_GRAY2RGB))

#### 2.2 Test environment

In [None]:
# play 10 games with random actions and report each episode's total reward
for episode in range(10):
    obs = env.reset()
    total_reward = 0

    # keep sampling random actions until the environment reports game over
    while True:
        obs, reward, done, info = env.step(env.action_space.sample())
        total_reward += reward
        if done:
            break

    print(f'Total reward for episode {episode} is {total_reward}')

### 3. Train the model

#### 3.1 Create Callback

In [226]:
# import os for file path management
import os
# import  base callback for saving models
from stable_baselines3.common.callbacks import BaseCallback
# check environment
from stable_baselines3.common import env_checker

In [None]:
# check that the environment is ok
# (runs the stable-baselines3 environment checker on the env instance)
env_checker.check_env(env)

In [273]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Args:
        check_freq: number of callback calls between two saves.
        save_path: directory that receives the saved models, or ``None`` to
            disable saving.
        verbose: verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the save directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save the model as ``best_model_<n_calls>`` on every ``check_freq``-th call.

        Always returns ``True``.
        """
        # Guard against save_path=None, mirroring _init_callback, so a
        # callback created without a path does not crash in os.path.join.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f'best_model_{self.n_calls}')
            self.model.save(model_path)

        return True

In [274]:
# directory handed to the callback as save_path (saved models go here)
CHECKPOINT_DIR = './train/'
# directory handed to DQN as tensorboard_log
LOG_DIR = './logs/'

In [275]:
# save the model into CHECKPOINT_DIR every 1000 callback calls
callback = TrainAndLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

#### 3.2 Build DQN and Train

In [276]:
# import the DQN algorithm
from stable_baselines3 import DQN

In [None]:
# create the DQN model with a CNN policy on the screen-capture environment;
# tensorboard logs are written under LOG_DIR
model = DQN(
    'CnnPolicy',
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=30000,
    learning_starts=0
)

In [None]:
# train for 100000 timesteps, passing the checkpoint callback to the learner
model.learn(total_timesteps=100000, callback=callback)

### 4. Test