## Preliminary Results

**Resources & References**
1. The Game: https://dinosaurgame.app/ 
2. Video Tutorial: https://www.youtube.com/watch?v=vahwuupy81A&t=5517s
3. Stable-Baselines3 Documentation: https://stable-baselines3.readthedocs.io/en/master/
4. Gymnasium Documentation: https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
5. PyTorch DQN Documentation for Gym Retro: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html


### 1. Import Dependencies

In [36]:
# Import Dependencies
import torch # PyToch library for building and training neural networks
from torch import nn
from torch.optim import AdamW
import numpy as np # for numerical calculations
from collections import namedtuple, deque # provides useful data structures may not need
import random # for random sampling 
from mss import mss # for grabbing a screen shot of a monitor 
import pydirectinput # for mouse and keyboard input on windows
import cv2 as cv # for image and video processing
import pytesseract # OCR tool for reading text from images
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import time
from gymnasium import Env
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import env_checker  # Import the environment checker
from collections import deque

In [41]:
class PacMan(Env):
    """Gymnasium environment that plays Pac-Man via screen capture and key presses.

    Observations are the most recent ``FRAME_STACK_SIZE`` grayscale captures of
    the game window, each resized to 80x50 and stacked channel-first into a
    ``(6, 50, 80)`` uint8 array. Actions are sent as arrow-key presses. The
    reward combines a bonus whenever the pellet count read from
    ``pellet_count.txt`` drops and a penalty whenever the on-screen lives count
    drops.

    The capture rectangles and click coordinates are hard-coded for one
    specific monitor layout.
    """

    FRAME_STACK_SIZE = 6   # frames stacked per observation (gives a sense of motion)
    INITIAL_LIVES = 2      # lives count assumed at the start of every episode
    PELLET_REWARD = 30     # reward when the pellet count decreases
    LIFE_PENALTY = 50      # magnitude of the penalty applied when a life is lost
    NO_OP_ACTION = 4       # action index that presses no key
    ACTION_KEYS = {0: 'left', 1: 'right', 2: 'up', 3: 'down'}
    # Number of template-match positions seen for each lives value; presumably
    # calibrated empirically for this capture region and life-icon template
    # (TODO confirm). Any other count is treated as zero lives.
    LIVES_BY_MATCH_COUNT = {684: 2, 344: 1}

    def __init__(self):
        super().__init__()
        # Define spaces
        self.observation_space = Box(low=0, high=255, shape=(self.FRAME_STACK_SIZE, 50, 80), dtype=np.uint8)
        self.action_space = Discrete(5) # number of possible actions

        self.previous_lives = self.INITIAL_LIVES
        self.current_lives = self.previous_lives
        self.previous_score = 0

        # Not read anywhere in this class; presumably the emulator memory
        # address of the pellet counter (TODO confirm).
        self.pellet_address = 0x7268
        self.file_path = "pellet_count.txt"
        self.previous_pellet_count = self.read_pellet_count_from_file()

        # Define capture locations
        self.cap = mss()
        self.game_location = {'top':50, 'left':-2280, 'width':1400, 'height':1300} # defines game viewing location
        self.lives_location = {'top':1070, 'left':-902, 'width':600, 'height':200} # defines lives location
        self.frame_stack = deque(maxlen=self.FRAME_STACK_SIZE) # stack frames to provide a sense of motion

        # Define templates for tracking (loaded as grayscale)
        self.ghost_template = cv.imread('ghost_template.png', 0)
        self.ghost_template2 = cv.imread('ghost_template3.png', 0)
        self.ghost_template3 = cv.imread('ghost_template4.png', 0)
        self.pacman_life_template = cv.imread('pacman_life_icon.png', 0)
        self.pacman_template_left = cv.imread('pacman_template_left.png', 0)
        self.pacman_template_right = cv.imread('pacman_template_right.png', 0)
        self.pacman_template_up = cv.imread('pacman_template_up.png', 0)
        self.pacman_template_down = cv.imread('pacman_template_down.png', 0)
        self.pacman_template_closed = cv.imread('pacman_template_closed.png', 0)

    def get_observation(self):
        """Capture the game window and return it as a (1, 50, 80) grayscale frame."""
        raw = np.array(self.cap.grab(self.game_location))[:,:,:3]
        gray = cv.cvtColor(raw, cv.COLOR_BGR2GRAY)
        # cv.resize takes (width, height), so this yields a 50x80 (rows x cols) image
        resized = cv.resize(gray, (80,50))
        # Add the channel axis first
        return np.reshape(resized, (1,50,80))

    def get_stacked_observation(self):
        """Concatenate the buffered frames along the channel axis."""
        return np.concatenate(list(self.frame_stack), axis=0)

    def get_lives(self):
        """Return the number of lives left (2, 1 or 0), read from the screen.

        The lives area is template-matched against the life icon and the
        number of positions scoring at least 0.8 is mapped to a lives count
        through ``LIVES_BY_MATCH_COUNT``; unrecognised counts map to 0.
        """
        lives_cap = np.array(self.cap.grab(self.lives_location))[:,:,:3]
        lives_gray = cv.cvtColor(lives_cap, cv.COLOR_BGR2GRAY)

        result = cv.matchTemplate(lives_gray, self.pacman_life_template, cv.TM_CCORR_NORMED)
        threshold = 0.8
        match_count = int(np.count_nonzero(result >= threshold))

        return self.LIVES_BY_MATCH_COUNT.get(match_count, 0)

    def get_done(self):
        """Return True when no lives are left."""
        return self.get_lives() == 0

    def read_pellet_count_from_file(self):
        """Return the integer stored in ``self.file_path``, or 0 if it is missing or malformed."""
        try:
            with open(self.file_path, "r") as file:
                return int(file.read().strip())
        except (FileNotFoundError, ValueError):
            return 0

    def reset(self, seed=None, options=None):
        """Restart the game and return ``(stacked_observation, info)``."""
        super().reset(seed=seed)
        # restart the game
        pydirectinput.click(x=-890, y=374) # select game window
        pydirectinput.press('f1') # Start state 1 save

        # Reset per-episode trackers. Without restoring previous_lives it would
        # stay at the value from the end of the last episode and the life-loss
        # penalty in step() would never fire again.
        self.previous_lives = self.INITIAL_LIVES
        self.current_lives = self.previous_lives
        self.previous_pellet_count = self.read_pellet_count_from_file()

        # Refill the frame stack with fresh captures
        self.frame_stack.clear()
        for _ in range(self.FRAME_STACK_SIZE):
            self.frame_stack.append(self.get_observation())

        return self.get_stacked_observation(), {}

    def render(self):
        """Show the annotated game capture in an OpenCV window; pressing 'q' closes it."""
        frame = self.render_positions()

        cv.imshow('Game', frame)

        if cv.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def close(self):
        """Destroy all OpenCV windows."""
        cv.destroyAllWindows()

    @staticmethod
    def _match_locations(gray_image, template, threshold):
        """Return the (x, y) top-left corners where ``template`` scores >= ``threshold``."""
        scores = cv.matchTemplate(gray_image, template, cv.TM_CCOEFF_NORMED)
        rows, cols = np.where(scores >= threshold)
        return list(zip(cols, rows))

    def get_character_positions(self):
        """Locate ghosts and Pac-Man in a fresh capture of the game area.

        Returns:
            ``(ghost_positions, pacman_positions, screen_capture)`` where the
            position lists hold (x, y) top-left corners of template matches and
            ``screen_capture`` is the captured colour image. The capture is
            also written to ``game_capture.png``.
        """
        screen_capture = np.array(self.cap.grab(self.game_location))[:,:,:3]
        cv.imwrite('game_capture.png', screen_capture)
        gray_screen = cv.cvtColor(screen_capture, cv.COLOR_BGR2GRAY)

        pacman_threshold = 0.6 # Adjust this value based on testing
        ghost_threshold = 0.5

        pacman_templates = (
            self.pacman_template_left,
            self.pacman_template_right,
            self.pacman_template_up,
            self.pacman_template_down,
            self.pacman_template_closed,
        )
        ghost_templates = (self.ghost_template, self.ghost_template2, self.ghost_template3)

        pacman_combined_locations = []
        for template in pacman_templates:
            pacman_combined_locations += self._match_locations(gray_screen, template, pacman_threshold)

        ghost_position = []
        for template in ghost_templates:
            ghost_position += self._match_locations(gray_screen, template, ghost_threshold)

        return ghost_position, pacman_combined_locations, screen_capture

    def render_positions(self):
        """Return the game capture with boxes drawn around Pac-Man and ghost matches."""
        ghost_position, pacman_combined_locations, screen_capture = self.get_character_positions()

        # The capture is a sliced view; OpenCV needs a contiguous array to draw on
        screen_capture = np.ascontiguousarray(screen_capture)

        # Every box is sized from one representative template per character type
        pacman_h, pacman_w = self.pacman_template_right.shape[:2]
        ghost_h, ghost_w = self.ghost_template.shape[:2]

        for loc in pacman_combined_locations:
            # Plain ints keep cv.rectangle from rejecting NumPy scalar coordinates
            top_left = (int(loc[0]), int(loc[1]))
            bottom_right = (top_left[0] + pacman_w, top_left[1] + pacman_h)
            cv.rectangle(screen_capture, top_left, bottom_right, (255, 0, 0), 2)

        for loc in ghost_position:
            top_left = (int(loc[0]), int(loc[1]))
            bottom_right = (top_left[0] + ghost_w, top_left[1] + ghost_h)
            cv.rectangle(screen_capture, top_left, bottom_right, (0, 0, 255), 2)

        return screen_capture

    def calculate_distance (self, pos1, pos2):
        """Return the Euclidean distance between two (x, y) points."""
        return np.sqrt((pos1[0] - pos2[0])**2 + ( pos1[1] - pos2[1])**2)

    def get_pellet_reward(self, current_pellet_count):
        """Return the pellet reward if the count dropped since the last drop, else 0.

        The stored previous count is only updated when the count decreases.
        """
        if current_pellet_count < self.previous_pellet_count:
            self.previous_pellet_count = current_pellet_count
            return self.PELLET_REWARD
        return 0

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        Actions 0-3 press left/right/up/down; action 4 presses nothing.
        ``truncated`` is always False and ``info`` is always empty.
        """
        if action != self.NO_OP_ACTION:
            pydirectinput.press(self.ACTION_KEYS[action])

        current_pellet_count = self.read_pellet_count_from_file()
        pellet_reward = self.get_pellet_reward(current_pellet_count)

        # NOTE: a ghost-proximity penalty and an extra game-over penalty were
        # prototyped here earlier but are not part of the reward.

        current_lives = self.get_lives()
        life_penalty = 0
        # Penalize only when a life is lost (and only once per life loss)
        if current_lives < self.previous_lives:
            life_penalty = -self.LIFE_PENALTY
            self.previous_lives = current_lives # update previous lives

        reward = pellet_reward + life_penalty

        # Reuse the lives value read above so the penalty and the terminal flag
        # come from the same screen capture.
        done = current_lives == 0

        # Get the next observation
        self.frame_stack.append(self.get_observation())
        stacked_observation = self.get_stacked_observation()

        return stacked_observation, reward, done, False, {}
    
    

In [27]:
# Create the screen-capture Pac-Man environment used by the cells below.
env = PacMan()

In [24]:
# Print the distance from Pac-Man to every detected ghost match.
ghost_positions, pacman_positions, _ = env.get_character_positions()

# Fall back to the origin when template matching finds no Pac-Man.
pacman_pos = pacman_positions[0] if pacman_positions else (0, 0)

if not ghost_positions:
    # Keep this a list of (x, y) points: iterating a bare (0, 0) tuple would
    # pass integers to calculate_distance, which indexes both arguments.
    ghost_positions = [(0, 0)]

for ghost_pos in ghost_positions:
    distance = env.calculate_distance(pacman_pos, ghost_pos)
    print(distance)


446.5254751971045
447.0212522912082
444.6684157886638
445.16176834943946
445.6568186396344
446.1535610078664
446.65198980862044
442.8148597325975
443.30576355378014
443.79837764462366
444.292696316291
444.7887138855931
445.28642467517466
445.7858230136979
446.2869032360237
441.4532817864196
441.9434352946087
442.4353060052961
442.92888819764283
443.42417615642023
443.9211641721985
444.4198465415333
444.9202175671499
445.4222715581249
445.92600283006595
440.0920358288707
440.58143401645964
441.0725563895355
441.5653971950248
442.05995068542455
442.55621111899444
443.0541727599459
443.5538298786293
444.05517675171853
444.5582076623937
445.0629169005209
445.56929876282993
438.7311249501225
439.21976276119454
439.7101317913882
440.20222625516107
440.69604037249985
441.1915683691156
441.688804476636
442.18774293279546
442.68837798162264
443.1907038736259
443.69471486597627
444.2004052226877
444.7077692147957
445.21680112053275
437.85842460777206
438.34803524140494
438.8393783606936
439.3324

In [42]:

# Sanity-check the environment by playing a few episodes with random actions.
env = PacMan()
try:
    for episode in range(5):
        obs, _ = env.reset()  # reset() returns (observation, info)
        done = False
        total_reward = 0

        while not done:
            obs, reward, done, truncated, info = env.step(env.action_space.sample())
            total_reward += reward
            env.render()
        # total_reward is reset every episode, so report it per episode
        print(f"Reward: {total_reward}")
finally:
    # Tear down the OpenCV window even if the run is interrupted.
    env.close()

KeyboardInterrupt: 

In [43]:
# Close any OpenCV window left open, e.g. after interrupting the loop above.
env.close()

In [None]:
import matplotlib.patches as patches

# Standalone check of the template matching: capture the game window once,
# locate the Pac-Man and ghost templates, and plot the matches with Matplotlib.
cap = mss()
game_location = {'top':50, 'left':-2280, 'width':1400, 'height':1300}
screen_capture = np.array(cap.grab(game_location))[:,:,:3]
cv.imwrite('game_capture.png', screen_capture)

ghost_template = cv.imread('ghost_template.png', 0)
pacman_template_left = cv.imread('pacman_template_left.png', 0)
pacman_template_right = cv.imread('pacman_template_right.png', 0)
pacman_template_up = cv.imread('pacman_template_up.png', 0)
pacman_template_down = cv.imread('pacman_template_down.png', 0)
gray_screen = cv.cvtColor(screen_capture, cv.COLOR_BGR2GRAY)

# Match the templates to find Pac-Man and the ghosts
result_left = cv.matchTemplate(gray_screen, pacman_template_left, cv.TM_CCOEFF_NORMED)
result_right = cv.matchTemplate(gray_screen, pacman_template_right, cv.TM_CCOEFF_NORMED)
result_up = cv.matchTemplate(gray_screen, pacman_template_up, cv.TM_CCOEFF_NORMED)
result_down = cv.matchTemplate(gray_screen, pacman_template_down, cv.TM_CCOEFF_NORMED)
result_ghost = cv.matchTemplate(gray_screen, ghost_template, cv.TM_CCOEFF_NORMED)

threshold = 0.5 # Adjust this value based on testing
locations_left = np.where(result_left >= threshold)
locations_right = np.where(result_right >= threshold)
locations_up = np.where(result_up >= threshold)
locations_down = np.where(result_down >= threshold)
location_ghost = np.where(result_ghost >= threshold)

# np.where returns (rows, cols); reverse to get (x, y) points
combined_locations = (
    list(zip(*locations_left[::-1]))
    + list(zip(*locations_right[::-1]))
    + list(zip(*locations_up[::-1]))
    + list(zip(*locations_down[::-1]))
)
ghost_position = list(zip(*location_ghost[::-1]))

# Set up the Matplotlib figure. The capture is handled as BGR everywhere else
# (cv.imwrite, COLOR_BGR2GRAY) while Matplotlib expects RGB, so convert before
# showing it; otherwise red and blue are swapped.
fig, ax = plt.subplots(1)
ax.imshow(cv.cvtColor(screen_capture, cv.COLOR_BGR2RGB))

# Draw rectangles around matched locations using Matplotlib patches
pacman_height, pacman_width = pacman_template_right.shape[:2]
for top_left in combined_locations:
    rect = patches.Rectangle(top_left, pacman_width, pacman_height,
                             linewidth=2, edgecolor='b', facecolor='none')
    ax.add_patch(rect)

ghost_height, ghost_width = ghost_template.shape[:2]
for top_left in ghost_position:
    rect = patches.Rectangle(top_left, ghost_width, ghost_height,
                             linewidth=2, edgecolor='r', facecolor='none')
    ax.add_patch(rect)

# Display the result
plt.show()


In [37]:

# Record the wall-clock time at which this cell starts executing.
start_time = time.time()

# Designing DQN Model
class DQN(nn.Module):
    """Convolutional Q-network mapping stacked 50x80 frames to one Q-value per action.

    The module owns its optimizer (AdamW) and loss (MSE) and moves itself to
    CUDA when available, otherwise to the CPU.

    Args:
        lr: Learning rate for the AdamW optimizer.
        input_dims: Number of input channels (stacked frames).
        fc1_dims: Width of the hidden fully connected layer.
        fc2_dims: Stored for interface compatibility. The network has a single
            hidden layer, so the output layer is sized from ``fc1_dims``.
        num_actions: Number of discrete actions (output units).
    """

    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, num_actions):
        super().__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.conv1 = nn.Conv2d(self.input_dims, 32, kernel_size=8, stride=4) # 32 filters of size 8x8, stride 4
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) # 64 filters of size 4x4, stride 2
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) # 64 filters of size 3x3, stride 1
        self.fc_input_size = self._calculate_fc_input_size(self.input_dims)
        self.fc1 = nn.Linear(self.fc_input_size, self.fc1_dims)
        # fc2 consumes fc1's output, so its input width must be fc1_dims;
        # sizing it from fc2_dims only worked when both happened to be equal.
        self.fc2 = nn.Linear(self.fc1_dims, num_actions)
        # Optimize this module's own parameters (there is no ``self.model``).
        self.optimizer = AdamW(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def _calculate_fc_input_size(self, input_dims):
        """Return the flattened conv output size for one ``(input_dims, 50, 80)`` input."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_dims, 50, 80)
            x = torch.relu(self.conv1(dummy_input))
            x = torch.relu(self.conv2(x))
            x = torch.relu(self.conv3(x))

            return x.view(1, -1).size(1)  # Flatten and get the size

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for a batch of frame stacks."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1) # Flatten the output from conv layers
        x = torch.relu(self.fc1(x))
        actions = self.fc2(x)  # Output Q-Values for each action

        return actions
    
    
# Creating DQN Agent
class DQNAgent:
    """Epsilon-greedy DQN agent with a preallocated circular replay memory.

    Args:
        gamma: Discount factor weighting future rewards.
        epsilon: Initial probability of taking a random action.
        lr: Learning rate passed to the Q-network.
        input_dims: Channel-first observation shape, e.g. ``(6, 50, 80)``.
        batch_size: Number of transitions sampled per learning step.
        num_actions: Number of discrete actions.
        max_mem_size: Capacity of the replay memory.
        eps_end: Lower bound for epsilon.
        eps_dec: Amount subtracted from epsilon after each learning step.
    """

    def __init__(self, gamma, epsilon, lr, input_dims, batch_size, num_actions,
                 max_mem_size=100000, eps_end=0.01, eps_dec=5e-4):
        self.gamma = gamma # Determines the weighting of future rewards
        self.epsilon = epsilon
        self.epsilon_min = eps_end
        self.epsilon_decay = eps_dec
        self.lr = lr
        self.action_space = list(range(num_actions))
        self.num_actions = num_actions
        self.mem_size = max_mem_size
        self.batch_size = batch_size
        self.mem_cntr = 0

        # The network is built from the channel count only, while the replay
        # arrays below need the full observation shape.
        self.Q_eval = DQN(self.lr, input_dims[0], fc1_dims=512, fc2_dims=512, num_actions=num_actions)

        # Preallocated arrays used as a circular buffer
        self.state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        # np.bool_ is available in every NumPy release; the bare np.bool alias is not
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool_)

    # Method for storing memory
    def store_transition(self, state, action, reward, state_, done):
        """Store one transition, overwriting the oldest entry once the memory is full."""
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.reward_memory[index] = reward
        self.action_memory[index] = action
        self.terminal_memory[index] = done

        self.mem_cntr += 1

    # Method for choosing an action
    def choose_action(self, observation):
        """Return a random action with probability epsilon, else the greedy action."""
        if np.random.random() > self.epsilon:
            # Convert to float32 and add the batch axis; the layers cannot
            # consume uint8 frames directly.
            state = torch.as_tensor(np.asarray(observation), dtype=torch.float32,
                                    device=self.Q_eval.device).unsqueeze(0)
            with torch.no_grad():
                actions = self.Q_eval.forward(state)
            action = torch.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)

        return action

    def learn(self):
        """Run one gradient step on a random minibatch, then decay epsilon.

        Does nothing until at least ``batch_size`` transitions are stored.
        """
        if self.mem_cntr < self.batch_size:
            return
        self.Q_eval.optimizer.zero_grad()

        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, self.batch_size, replace=False)

        device = self.Q_eval.device
        state_batch = torch.as_tensor(self.state_memory[batch], device=device)
        new_state_batch = torch.as_tensor(self.new_state_memory[batch], device=device)
        reward_batch = torch.as_tensor(self.reward_memory[batch], device=device)
        terminal_batch = torch.as_tensor(self.terminal_memory[batch], device=device)
        action_batch = torch.as_tensor(self.action_memory[batch], dtype=torch.int64, device=device)

        # Q-value of the action actually taken in each sampled transition
        q_eval = self.Q_eval.forward(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)

        # The bootstrap target is treated as a constant: no gradient flows through it
        with torch.no_grad():
            q_next = self.Q_eval.forward(new_state_batch)
            q_next[terminal_batch] = 0.0  # no future value after a terminal state
            q_target = reward_batch + self.gamma * torch.max(q_next, dim=1)[0]

        loss = self.Q_eval.loss(q_eval, q_target)
        loss.backward()
        self.Q_eval.optimizer.step()

        # Linear epsilon decay, clamped at the configured minimum
        self.epsilon = max(self.epsilon - self.epsilon_decay, self.epsilon_min)
        
        
        
    # def remember(self, state, action, reward, next_state, done):
    #     self.memory.append((state, action, reward, next_state, done))
    
    # def act(self, state):
    #     if np.random.rand() <= self.epsilon:
    #         return random.randrange(self.num_actions)
    #     state = torch.FloatTensor(state).unsqueezse(0)
    #     q_values = self.model(state)
    #     return torch.argmax(q_values[0]).item()
    
    # def replay(self, batch_size):
    #     if len(self.memory) < batch_size:
    #         return
    #     minibatch = random.sample(self.memory. batch_size)
    #     for state, action, reward, next_state, done in minibatch:
    #         target = reward
    #         if not done:
    #             next_state = torch.FloatTensor(next_state).unsqueeze(0)
    #             target += self.gamma * torch.max(self.model(next_state)).item()
    #         state = torch.FloatTensor(state).unsqueeze(0)
    #         target_f = self.model(state)
    #         target_f[0][action] = target
    #         self.optimizer.zero_grad()
    #         loss = self.criterion(target_f, self.model(state))
    #         loss.backward()
    #         self.optimizer.step()
        
    #     if self.epsilon > self.epsilon_min:
    #         self.epsilon *= self.epsilon_decay

In [None]:
# Hyperparameters
num_episodes = 100
batch_size = 32
gamma = 0.99
epsilon = 0.1 # Exploration rate
buffer_capacity = 10000
learning_rate = 1e-3

# Initialize environment and model
env = PacMan()
input_dims = env.observation_space.shape[0] # number of stacked frames (channels)
num_actions = env.action_space.n

# NOTE: GameNet and ReplayBuffer are not defined in the cells shown in this
# notebook; they must be defined before this cell is run.
model = GameNet(input_dims, num_actions)
target_model = GameNet(input_dims, num_actions)
target_model.load_state_dict(model.state_dict()) # start the target network as a copy of the online network
# Only AdamW is imported by name at the top of the notebook, so reach Adam
# through the torch namespace instead of the undefined bare name.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
replay_buffer = ReplayBuffer(buffer_capacity) # Stores experiences (state, action, reward, next state, done) for training

# Epsilon-greedy action selection (exploration vs. exploitation)
def select_action(state, epsilon):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    With probability ``epsilon`` a random action is sampled from the global
    ``env``; otherwise the action with the highest Q-value under the global
    ``model`` is returned.
    """
    explore = random.random() < epsilon
    if explore:
        return env.action_space.sample() # Random action (Exploration)

    batched_state = torch.FloatTensor(state).unsqueeze(0) # Add batch dimension
    with torch.no_grad():
        scores = model(batched_state)
    best_action = scores.argmax()
    return best_action.item() # Action with highest Q-value

In [None]:
# Training function that drives the interaction with the environment and optimizes the online network
def train_gamenet(env, model, target_model, optimizer, criterion, replay_buffer, num_episodes=10,
                  target_update_interval=100):
    """Train ``model`` with DQN updates for ``num_episodes`` episodes.

    Relies on the module-level ``epsilon``, ``batch_size`` and ``gamma``
    hyperparameters and on ``select_action`` (which reads the global ``model``).

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step()`` returns ``(obs, reward, done, truncated, info)``.
        model: Online Q-network being optimized.
        target_model: Network used to compute bootstrap targets.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss comparing predicted and target Q-values.
        replay_buffer: Buffer supporting ``push``, ``sample`` and ``len``.
        num_episodes: Number of episodes to play.
        target_update_interval: Copy ``model`` into ``target_model`` every this
            many episodes (episode 0 included).
    """
    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        truncated = False
        total_reward = 0

        # Stop on termination or truncation so a time-limited env cannot loop forever
        while not (done or truncated):
            action = select_action(state, epsilon)
            next_state, reward, done, truncated, info = env.step(action)
            total_reward += reward

            # Store experience in replay buffer. Only true termination is
            # stored as ``done`` so truncated episodes still bootstrap.
            replay_buffer.push((state, action, reward, next_state, done))
            state = next_state

            # Perform optimization step
            if len(replay_buffer) >= batch_size:
                batch = replay_buffer.sample(batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)

                states = torch.FloatTensor(np.array(states))
                actions = torch.LongTensor(np.array(actions))
                rewards = torch.FloatTensor(np.array(rewards))
                next_states = torch.FloatTensor(np.array(next_states))
                dones = torch.FloatTensor(np.array(dones))

                # Q-values of the actions that were actually taken
                q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

                # Targets are constants for the update: skip building an
                # autograd graph through the target network.
                with torch.no_grad():
                    next_q_values = target_model(next_states).max(1)[0]
                    target_q_values = rewards + (gamma * next_q_values * (1 - dones))

                # Compute loss
                loss = criterion(q_values, target_q_values)

                # Optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Print progress
        print(f"Episode {episode}: Total Reward = {total_reward}")

        # Update target network
        if episode % target_update_interval == 0:
            target_model.load_state_dict(model.state_dict())
        
# Take the timestamp right before training so the reported duration covers
# only the training run, not everything executed since an earlier cell.
start_time = time.time()
train_gamenet(env, model, target_model, optimizer, criterion, replay_buffer, num_episodes)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Total training time: {elapsed_time // 60:.0f} minutes, {elapsed_time % 60:.0f} seconds")

## Training

In [822]:
# import dependencies
import os
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker

In [824]:
# Check environment for errors.
# env_checker here is the stable_baselines3 module imported just above, which
# shadows the gymnasium env_checker imported at the top of the notebook.
env_checker.check_env(env)

In [823]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves the model every ``check_freq`` callback calls.

    Args:
        check_freq: Save a checkpoint whenever ``n_calls`` is a multiple of this.
        save_path: Directory for checkpoints; ``None`` disables saving.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call; always continue training."""
        # Mirror the None guard in _init_callback: joining a None path would raise
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [825]:
# Declare directories to store files
# CHECKPOINT_DIR is passed to the checkpoint callback as its save path.
CHECKPOINT_DIR = './train/'
# LOG_DIR is passed to the model as its TensorBoard log directory.
LOG_DIR = './logs'

In [826]:
# Save a checkpoint into CHECKPOINT_DIR every 1000 callback calls.
callback = TrainAndLoggingCallback(check_freq=1000, save_path=CHECKPOINT_DIR)

In [827]:
from stable_baselines3 import DQN

In [None]:
# Declare the Stable-Baselines3 DQN model with a CNN policy.
# NOTE(review): with the (6, 50, 80) uint8 observation space, one observation
# is 24,000 bytes, so buffer_size=1200000 implies roughly 29 GB per stored
# observation array. Confirm this fits in memory or lower it.
model = DQN(
    'CnnPolicy',
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=1200000,
    learning_starts=1000,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.1,
    exploration_fraction=0.1,
    learning_rate=0.01,
)

In [None]:
# Start learning: train for 5000 timesteps, invoking the checkpoint callback
# defined above during training.
model.learn(total_timesteps=5000, callback=callback)

In [None]:
# Load the saved checkpoint. load() is a classmethod that returns a new model
# instead of updating the instance it is called on, so the result must be
# assigned; otherwise the loop below would keep using the in-memory model.
model = DQN.load(os.path.join('train', 'best_model_5000'), env=env)


In [None]:
# Let the loaded model play ten episodes and report each episode's return
for episode in range(10):
    obs, _ = env.reset()
    done = False
    total_reward = 0

    while True:
        action, _ = model.predict(obs)
        obs, reward, done, _, info = env.step(int(action))
        total_reward += reward
        if done:
            break

    print(f'Total Reward for episode {episode} is {total_reward}')
    # Short pause before the next episode starts
    time.sleep(2)