 ## Installing Dependencies


In [None]:
!pip install torch torchvision torchaudio

In [None]:
!pip install stable-baselines3[extra] protobuf==3.20.*

In [None]:
!pip install mss pydirectinput pytesseract




In [None]:
from mss import mss
import pydirectinput
import cv2
import numpy as np
import pytesseract
from matplotlib import pyplot as plt 
import time
from gym import Env
from gym.spaces import Box,Discrete

In [None]:
# Point pytesseract at the Tesseract OCR executable (Windows install path).
# Adjust this path if Tesseract is installed somewhere else.
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'


## Creating the game environment

In [None]:
class WebGame(Env):
    """Gym environment that drives an on-screen game through screen capture.

    Observations are grayscale screenshots of ``game_location`` downscaled to
    a ``(1, 83, 100)`` uint8 array. Actions are sent as key presses through
    ``pydirectinput``. The end of an episode is detected by running OCR on
    ``done_location`` and looking for the game-over text.
    """

    # Discrete action index -> key name sent to the game.
    _ACTION_KEYS = {
        0: "space",
        1: "down",
        2: "no_op",
    }
    # Action index that sends no key press at all.
    _NO_OP_ACTION = 2
    # OCR prefixes treated as the game-over text.
    # NOTE(review): "GAHE" is presumably a common OCR misread of "GAME".
    _DONE_STRINGS = ("GAME", "GAHE")

    # Initializing environment
    def __init__(self):
        super().__init__()

        # One-channel 83x100 image with pixel values in [0, 255].
        self.observation_space = Box(low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)

        # Three actions: see _ACTION_KEYS.
        self.action_space = Discrete(3)

        # Screen-capture handle and the two screen regions (in pixels) it grabs:
        # the play area used for observations and the area scanned for game over.
        self.cap = mss()
        self.game_location = {"top": 300, "left": 0, "width": 600, "height": 500}
        self.done_location = {"top": 405, "left": 630, "width": 660, "height": 70}

    # pass in any actions player might do i.e jump, shoot
    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``action`` is an index into the action space. It is coerced with
        ``int()`` so NumPy integer scalars and 0-d arrays are accepted too.
        The reward is 0.1 while the game is running and -1 on game over.
        ``info`` is always an empty dict.
        """
        action = int(action)

        # Send the mapped key unless this is the no-op action.
        if action != self._NO_OP_ACTION:
            pydirectinput.press(self._ACTION_KEYS[action])

        # Check for game over first, then capture the next observation.
        done, _ = self.get_done()
        observation = self.get_observation()

        # setting reward
        reward = .1 if not done else -1

        info = {}

        return observation, reward, done, info

    # Reset the game
    def reset(self):
        """Restart the game and return the first observation.

        Waits one second, clicks at screen position (150, 150), then presses
        space.
        """
        time.sleep(1)
        # NOTE(review): the click presumably gives the game window focus - confirm.
        pydirectinput.click(x=150, y=150)
        pydirectinput.press("space")
        return self.get_observation()

    # visualize the game
    def render(self, mode="human"):
        """Show the current game region in an OpenCV window named "Game".

        ``mode`` is accepted for compatibility with the Gym ``render``
        signature and is not used. Pressing "q" while the window has focus
        closes it.
        """
        # Grab a fresh frame here. The previous implementation read
        # self.current_frame, which was never assigned anywhere.
        cv2.imshow("Game", self._grab_game_frame())
        if cv2.waitKey(1) & 0xFF == ord("q"):
            self.close()

    def close(self):
        """Close every OpenCV window."""
        cv2.destroyAllWindows()

    def _grab_game_frame(self):
        """Return the game region as a uint8 array without the fourth channel."""
        return np.array(self.cap.grab(self.game_location))[:, :, :3].astype(np.uint8)

    # captures the dinosaur
    def get_observation(self):
        """Capture the game region and return it as a ``(1, 83, 100)`` array."""
        raw = self._grab_game_frame()

        # Gray scaling image
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)

        # cv2.resize takes (width, height); the reshape adds the channel axis.
        resized = cv2.resize(gray, (100, 83))
        channel = np.reshape(resized, (1, 83, 100))
        return channel

    # is the game over? captures game over screen
    def get_done(self):
        """Return ``(done, done_cap)`` for the current screen.

        ``done_cap`` is the captured ``done_location`` region as an array.
        ``done`` is True when the first four characters of the OCR output
        match one of ``_DONE_STRINGS``.
        """
        done_cap = np.array(self.cap.grab(self.done_location))
        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in self._DONE_STRINGS
        return done, done_cap
        
        
    
        
    

## Testing the WebGame class

In [None]:
# Create an environment instance for the manual checks below.
env = WebGame()


In [None]:
# get_done() returns (done, done_cap); index 0 is the done flag.
# Should return True if "GAME" is found in screenshot
env.get_done()[0]


In [None]:
# Show the screenshot returned by get_done() (index 1) for visual inspection.
plt.imshow(env.get_done()[1])

## Auto-Run Game

In [None]:
# Play 10 episodes with randomly sampled actions.
for episode in range(10):
    obs = env.reset()
    done = False
    total_reward = 0
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
        total_reward += reward

    # Report the reward accumulated over the episode, not just the last step's.
    print(episode, total_reward)
   
        

### Callback function

In [None]:
import os 
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker


In [None]:
# Run the Stable-Baselines3 environment checker against env.
env_checker.check_env(env)

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` calls.

    Checkpoints are written to ``save_path`` as ``best_model_<n_calls>``.
    When ``save_path`` is None, no directory is created and nothing is saved.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        """Store the checkpoint interval and the target directory.

        Args:
            check_freq: Save a checkpoint every this many callback calls.
            save_path: Directory for checkpoints, or None to disable saving.
            verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if one was configured."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call; always return True."""
        # Same None guard as _init_callback, so a disabled save_path does not
        # make os.path.join raise a TypeError.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [None]:
# Directories (relative to the working directory) used for training output.
# CHECKPOINT_DIR is passed to the callback as its save path;
# LOG_DIR is passed to the model as its TensorBoard log directory.
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs/'

In [None]:
# Save a checkpoint into CHECKPOINT_DIR every 1000 callback calls.
callback = TrainAndLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

## Build and Train

In [None]:
from stable_baselines3 import PPO
from stable_baselines3 import DQN



In [None]:
# Create the environment instance used for training.
env = WebGame()

In [None]:
# Build a DQN agent with the CNN policy on env; TensorBoard logs go to LOG_DIR.
# buffer_size=25000 and learning_starts=0 are passed through to DQN unchanged.
model = DQN('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, buffer_size=25000, learning_starts=0)

In [None]:
# Train for 100000 timesteps, with the checkpoint callback attached.
model.learn(total_timesteps=100000, callback=callback)


## Test Model

In [None]:
# `load` is a classmethod that returns a new model, so its result must be
# assigned; calling it on the instance and discarding the result loads nothing.
model = DQN.load('train/best_model_100000')

In [None]:
# Let the model play five episodes and print each episode's total reward.
for episode in range(5):
    total_reward = 0
    obs = env.reset()
    done = False
    while not done:
        # predict() returns (action, state); only the action is needed here.
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(int(action))
        total_reward += reward
        # Short pause between steps.
        time.sleep(0.01)
    print('Total Reward for episode {} is {}'.format(episode, total_reward))
    # Pause before starting the next episode.
    time.sleep(2)