Install and import dependencies

In [None]:
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

In [None]:
!pip list

In [None]:
!pip install stable-baselines3[extra] protobuf==3.20.*

In [None]:
!pip install mss pydirectinput pytesseract

In [None]:
#MSS  used for screen capturing
from mss import mss

#sending commands
import pydirectinput

#it allows us to do frame processing
import cv2

import numpy as np

#OCR for game over extraction
import pytesseract

from matplotlib import pyplot as plt

import time 

#environment components
from gym import Env
from gym.spaces import Box, Discrete

Build the environment

In [None]:
#point pytesseract at the Tesseract executable (Windows install path; adjust for other machines)
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

In [None]:
# create environment
class WebGame(Env):
    """Gym environment that drives an on-screen game via screen capture and key presses.

    Observations are grayscale captures of ``game_location`` resized to
    ``OBSERVATION_SHAPE``. Actions are sent with ``pydirectinput`` and the end
    of an episode is detected by running OCR over ``done_location``.
    """

    #channels-first shape of every observation: (channels, height, width)
    OBSERVATION_SHAPE = (1,83,100)
    #action keys  --->  0=space(up),1=duck(down),2=no action
    ACTION_MAP = {
        0:'space',
        1:'down',
        2:'no_op'
    }
    #action index for which no key is pressed
    NO_OP_ACTION = 2
    #first four OCR characters accepted as the game-over text
    #('GAHE' is presumably an OCR misreading of 'GAME')
    DONE_STRINGS = ('GAME','GAHE')
    #default capture regions, in screen pixels
    DEFAULT_GAME_LOCATION = {'top':350,'left':0,'width':400,'height':200}
    DEFAULT_DONE_LOCATION = {'top':300,'left':300,'width':500,'height':80}

    #setup the environment action and observation shapes
    def __init__(self, game_location=None, done_location=None):
        """Create the spaces, the screen grabber and the capture regions.

        Args:
            game_location: optional region dict with 'top', 'left', 'width'
                and 'height' keys for the play field; defaults to
                ``DEFAULT_GAME_LOCATION``.
            done_location: optional region dict with the same keys for the
                game-over text; defaults to ``DEFAULT_DONE_LOCATION``.
        """
        #subclass model
        super().__init__()
        #setup spaces
        self.observation_space = Box(low=0,high=255,shape=self.OBSERVATION_SHAPE,dtype=np.uint8)
        self.action_space = Discrete(len(self.ACTION_MAP))
        #define extraction parameters for the game
        self.cap = mss()
        #copy the regions so neither the class defaults nor the caller's dicts are shared
        self.game_location = dict(self.DEFAULT_GAME_LOCATION if game_location is None else game_location)
        self.done_location = dict(self.DEFAULT_DONE_LOCATION if done_location is None else done_location)

    #what is called to do something in the game
    def step(self,action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        The done check runs before the observation is captured. The reward is
        always 1, i.e. one point for every step taken.

        Raises:
            KeyError: if ``action`` is not a key of ``ACTION_MAP``.
        """
        if action != self.NO_OP_ACTION:
            pydirectinput.press(self.ACTION_MAP[action])

        #checking whether the game is done or not (the captured image is not needed here)
        done,_ = self.get_done()
        #get the next observation
        new_observation = self.get_observation()
        #reward - we get a point for every frame we're alive
        reward = 1
        #info dictionary
        info = {}

        return new_observation,reward,done,info

    #visualize the game
    def render(self, mode='human'):
        """Show the current game region, or return it as an RGB array.

        Args:
            mode: 'human' (default) displays the frame in an OpenCV window and
                closes the window when 'q' is pressed; 'rgb_array' returns the
                frame instead of displaying it. The parameter mirrors the
                ``gym.Env.render`` signature so wrappers that pass ``mode``
                keep working.
        """
        #first three channels are treated as BGR, as in get_observation
        frame = np.array(self.cap.grab(self.game_location))[:,:,:3]
        if mode == 'rgb_array':
            return cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
        cv2.imshow('Game', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    #restart the game
    def reset(self):
        """Restart the game and return the first observation."""
        #wait a moment before interacting with the screen
        time.sleep(1)
        #click at a fixed screen position, then press space to start a new run
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        return self.get_observation()

    #this closes down the render window
    def close(self):
        """Destroy any OpenCV windows; the screen grabber stays usable afterwards."""
        cv2.destroyAllWindows()

    # get the part of the game observation that we want
    def get_observation(self):
        """Capture the game region as a uint8 array of shape ``OBSERVATION_SHAPE``."""
        _,height,width = self.OBSERVATION_SHAPE
        #get screen capture of the game (drop the fourth channel)
        raw = np.array(self.cap.grab(self.game_location))[:,:,:3]
        #grayscale
        gray = cv2.cvtColor(raw,cv2.COLOR_BGR2GRAY)
        #resize - cv2 expects the target size as (width, height)
        resized = cv2.resize(gray,(width,height))
        #add channels first
        channel = np.reshape(resized,self.OBSERVATION_SHAPE)
        return channel

    #get the done text using OCR
    def get_done(self):
        """Return ``(done, done_cap)`` for the current screen.

        ``done`` is True when the first four characters that OCR reads from
        the done region are in ``DONE_STRINGS``; ``done_cap`` is the captured
        image of that region.
        """
        #get done screen
        done_cap = np.array(self.cap.grab(self.done_location))
        #apply ocr and keep only the first four characters
        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in self.DONE_STRINGS

        return done, done_cap

In [None]:
#instantiate the environment with its built-in capture regions
env = WebGame()

In [None]:
#the observation is a single-channel grayscale frame, so expand it to RGB for matplotlib
plt.imshow(cv2.cvtColor(env.get_observation()[0],cv2.COLOR_GRAY2RGB))

In [None]:
#run the OCR game-over check once and keep both the flag and the captured image
done,done_cap = env.get_done()

In [None]:
#show the region that OCR reads, to check that it covers the game-over text
#NOTE(review): the capture is displayed as-is; channel order is presumably BGRA, so colours may look swapped - confirm
plt.imshow(done_cap)

In [None]:
#display the flag returned by get_done() above
done

In [None]:
#open the OpenCV preview window showing the captured game region
env.render()

In [None]:
#close the OpenCV preview window
env.close()

In [None]:
#play 10 episodes with random actions to check that the environment works end to end
for episode in range(10):
    obs = env.reset()
    done = False
    total_reward = 0
    #keep stepping until the environment reports game over
    while True:
        random_action = env.action_space.sample()
        obs, reward, done, info = env.step(random_action)
        total_reward = total_reward + reward
        if done:
            break
    print('Total Reward for episode {} is {}'.format(episode, total_reward))


In [None]:
# Import os for file path management
import os 
# Import Base Callback for saving models
from stable_baselines3.common.callbacks import BaseCallback
# Check Environment    
from stable_baselines3.common import env_checker

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Checkpoints are written to ``save_path`` as ``best_model_<n_calls>``.
    When ``save_path`` is None nothing is created or saved.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        """Store the checkpoint settings.

        Args:
            check_freq: save a checkpoint whenever ``n_calls`` is a multiple
                of this value.
            save_path: directory for the checkpoints, or None to disable saving.
            verbose: verbosity level passed to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call; always return True."""
        #skip saving when no directory was configured, matching _init_callback
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [None]:
#directory passed to the callback as save_path for model checkpoints
CHECKPOINT_DIR = './train/'
#directory passed to DQN as tensorboard_log
LOG_DIR = './logs/'

In [None]:
#save a checkpoint into CHECKPOINT_DIR every 300 callback calls
callback = TrainAndLoggingCallback(check_freq=300, save_path=CHECKPOINT_DIR)

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

In [None]:
#fresh environment instance for training
env = WebGame()

In [None]:
#DQN agent with a CNN policy on the screen-capture environment; TensorBoard logs go to LOG_DIR
#NOTE(review): learning_starts=1000 equals the total_timesteps passed to learn() below, so
#presumably little or no training happens in that run - confirm against the SB3 DQN docs
model = DQN('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1,
            buffer_size=12000, learning_starts=1000)

In [None]:
#train for 1000 timesteps; the callback writes checkpoints along the way
model.learn(total_timesteps=1000, callback=callback)

In [None]:
#load a saved checkpoint; 'best_model_900' is the name the callback produces at n_calls=900
#(check_freq=300), so update it if check_freq or total_timesteps change
model= DQN.load(os.path.join('train','best_model_900')) 

In [None]:
#play 5 episodes with the loaded model choosing the actions
for episode in range(5):
    obs = env.reset()
    done = False
    total_reward = 0
    #keep stepping until the environment reports game over
    while True:
        #predict returns (action, state); only the action is needed
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(int(action))
        #short pause between steps
        time.sleep(0.01)
        total_reward = total_reward + reward
        if done:
            break
    print('Total Reward for episode {} is {}'.format(episode, total_reward))
    #pause before starting the next episode
    time.sleep(2)