In [None]:
import sys
import torch
import gym
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pathlib import Path
from torch.distributions import Categorical
from collections import namedtuple, deque
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math
import random
import cv2

In [None]:
# Gym environment id shared by the training env and the rendering env
# that `show_game` creates later.
env_name = "SpaceInvaders-v4"

# Training environment; observations are requested as grayscale frames.
env = gym.make(
    env_name,
    obs_type = 'grayscale',
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# One replay-memory record: (state, action, next_state, reward).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

In [None]:
class ReplayMemory:
    """Bounded FIFO store of `Transition` records for experience replay.

    Once `capacity` records are held, appending a new one drops the oldest
    (behaviour of `deque(maxlen=...)`).
    """

    def __init__(self, capacity):
        """Create an empty memory that keeps at most `capacity` records."""
        self.memory = deque(maxlen = capacity)

    def push(self, *args):
        """Wrap the positional arguments in a `Transition` and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of `batch_size` distinct stored records, chosen at random."""
        return random.sample(self.memory, k = batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        input_shape: number of input channels for the first convolution
            (despite the name, this is a channel count, not a full shape).
        n_actions: number of Q-values produced per input, one per action.

    The flattened feature size is fixed at 6400 = 64 * 10 * 10, which is what
    the three convolutions below produce for 84x84 inputs
    (84 -> 20 -> 10 -> 10 per spatial side). Other input sizes will fail at
    the first linear layer.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(input_shape, 32, kernel_size = 8, stride = 4,
                      padding = 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(6400, 512),
            nn.ReLU(),
            # Output width follows the action space instead of a literal 6,
            # so the constructor argument is actually honoured.
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """Return the Q-values for a batch `x` of shape (N, input_shape, 84, 84)."""
        return self.network(x)

In [None]:
class Agent:
    """DQN agent: policy/target networks, replay memory and epsilon-greedy control.

    Relies on the notebook-level names `env` (action space), `device`,
    `ReplayMemory` and `DQN`.
    """

    def __init__(self, input_shape = 4):
        """Build networks, optimizer and bookkeeping state.

        Args:
            input_shape: number of input channels passed to `DQN`.
        """
        # Hyper-parameters.
        self.BATCH_SIZE = 128
        self.GAMMA = 0.99            # discount applied to next-state values
        self.EPSILON = 1             # exploration rate at step 0
        self.MIN_EPSILON = 0.1       # asymptotic exploration rate
        self.EPSILON_DECAY = 50000   # exponential decay constant, in steps
        self.TAU = 0.005             # soft-update weight of the policy net
        self.LR = 0.0002

        # Record layout used to transpose sampled batches in optimize_model.
        self.Transition = namedtuple('Transition',
                                     ('state', 'action', 'next_state',
                                      'reward'))
        self.memory = ReplayMemory(100000)
        self.n_actions = env.action_space.n

        self.input_shape = input_shape
        # This used to be `len(state)`, reading a notebook global that does
        # not exist yet when the agent is built on a fresh kernel (NameError).
        # Nothing in this class reads the attribute; it is kept for
        # compatibility and now mirrors the network's input channel count.
        self.n_observations = input_shape

        self.policy_net = DQN(self.input_shape, self.n_actions).to(device)
        self.target_net = DQN(self.input_shape, self.n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.AdamW(self.policy_net.parameters(),
                                     lr = self.LR, amsgrad = True)

        self.steps_done = 0          # exploring calls to select_action so far
        self.num_train_steps = 0     # calls to train_step so far
        self.episode_durations = []
        self.rewards = []

    def get_exploration_rate(self, k = None):
        """Return epsilon after `k` steps (default: after `self.steps_done`).

        epsilon(k) = MIN_EPSILON + (EPSILON - MIN_EPSILON) * exp(-k / EPSILON_DECAY)
        """
        step = self.steps_done if k is None else k
        return (self.MIN_EPSILON +
                ((self.EPSILON - self.MIN_EPSILON) *
                 math.exp(-1.*step/self.EPSILON_DECAY)))

    def select_action(self, state, explore = True):
        """Pick an action for `state` as a (1, 1) long tensor.

        Args:
            state: batch of exactly one state, in the layout `policy_net` expects.
            explore: when True, act epsilon-greedily and advance the
                exploration schedule; when False, act greedily and leave the
                schedule untouched, so evaluation runs do not consume
                training exploration.
        """
        if explore:
            eps_thresh = self.get_exploration_rate()
            self.steps_done += 1
            if random.random() <= eps_thresh:
                return torch.tensor([[env.action_space.sample()]],
                                    device = device, dtype = torch.long)
        with torch.no_grad():
            # Index of the largest Q-value, reshaped to (1, 1).
            return self.policy_net(state).max(1)[1].view(1, 1)

    def _plot_series(self, figure_id, values, title, ylabel):
        """Plot one per-episode series and, from 100 points on, its 100-episode mean."""
        series = torch.tensor(values, dtype = torch.float)
        plt.figure(figure_id)
        plt.title(title)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.plot(series.numpy())

        # The window length is checked on the series being averaged.
        if len(series) >= 100:
            means = series.unfold(0, 100, 1).mean(1).view(-1)
            # Left-pad with zeros so the curve lines up with the episode axis.
            means = torch.cat((torch.zeros(99), means))
            plt.plot(means.numpy())

        plt.show()

    def plot_durations(self):
        """Show the episode-length and episode-reward curves."""
        self._plot_series(1, self.episode_durations, 'Episode_Lengths',
                          'Duration')
        self._plot_series(2, self.rewards, 'Rewards', 'Reward')

    def optimize_model(self):
        """Run one gradient step on a random replay batch.

        Does nothing until the memory holds at least BATCH_SIZE transitions.
        Transitions whose `next_state` is None are treated as terminal: their
        target is the reward alone.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return
        transitions = self.memory.sample(self.BATCH_SIZE)
        # Transpose a list of Transitions into a Transition of tuples.
        batch = self.Transition(*zip(*transitions))

        non_final_mask = torch.tensor(
            tuple(s is not None for s in batch.next_state),
            device = device, dtype = torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.policy_net(state_batch).gather(
            1, action_batch)

        # max_a' Q_target(s', a'); stays 0 for terminal transitions.
        next_state_values = torch.zeros(self.BATCH_SIZE, device = device)
        # torch.cat rejects an empty list, so skip the target-net pass when
        # every sampled transition is terminal.
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_net(
                    torch.cat(non_final_next_states)).max(1)[0]

        expected_state_action_values = reward_batch + (
            next_state_values * self.GAMMA)

        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values,
                         expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping is currently disabled; to enable it:
        # torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

    def train_step(self):
        """Optimize once, then soft-update the target net towards the policy net.

        target <- TAU * policy + (1 - TAU) * target, for every state_dict entry.
        """
        self.optimize_model()

        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()

        for key in policy_net_state_dict:
            target_net_state_dict[key] = (policy_net_state_dict[key]*self.TAU +
                                          target_net_state_dict[key]*(1-self.TAU))
        self.target_net.load_state_dict(target_net_state_dict)
        self.num_train_steps += 1

    def save_model(self, name, checkpoint_name):
        """Save the policy net's state_dict to weights/<name><checkpoint_name>.pth.

        The `weights` directory is created if it does not exist.
        """
        MODEL_PATH = Path('weights')
        MODEL_PATH.mkdir(parents = True, exist_ok = True)

        MODEL_NAME = name + checkpoint_name + '.pth'
        MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

        print(f"Saving model to: {MODEL_SAVE_PATH}")
        torch.save(obj = self.policy_net.state_dict(), f = MODEL_SAVE_PATH)
    
    

In [None]:
def show_game(agent, num_stacked_frames, show_action = False):
    """Play one rendered episode with the agent acting greedily.

    Args:
        agent: object exposing `select_action(state, explore=...)`.
        num_stacked_frames: number of frames kept in the observation stack.
        show_action: when True, print every chosen action index.

    A separate environment is created with `render_mode='human'` and is
    always closed, even if the episode raises. Prints the total reward and
    the number of steps when the episode ends.
    """
    env2 = gym.make(env_name, render_mode = 'human', obs_type = 'grayscale')
    try:
        # Start from an all-zero stack, then push the first real frame.
        stacked_frames = deque([torch.zeros((84,84)).to(device)
                                for i in range(num_stacked_frames)],
                               maxlen = num_stacked_frames)
        state = env2.reset()[0]
        stacked_frames.append(preprocess(state))

        done = False
        steps = 0
        total_reward = 0
        while not done:
            steps += 1
            state_tensor = get_frame_tensor(stacked_frames)
            # explore=False: show the learned policy rather than the
            # epsilon-greedy behaviour policy used during training.
            action = agent.select_action(state_tensor, explore = False)
            if show_action:
                print(action.item())
            next_state, reward, terminated, truncated, _ = env2.step(action.item())
            stacked_frames.append(preprocess(next_state))
            total_reward += reward
            done = terminated or truncated
    finally:
        # Release the render window even when the episode fails midway.
        env2.close()
    print(f'Game Completed# Reward:{total_reward} | Game_Length:{steps}')
    return

In [None]:
def preprocess(img):
    """Resize a frame to 84x84 and return it as a float tensor on `device`.

    The values are divided by 255.0 (presumably mapping 8-bit pixel
    intensities into [0, 1]; confirm against the environment's output).
    """
    resized = cv2.resize(img, (84, 84))
    frame = torch.FloatTensor(resized).to(device)
    # In-place division, as the tensor was created just above.
    return frame.div_(255.0)

In [None]:
def process_frame_skip(action, num_stacked_frames, num_skip_frames,
                       stacked_frames, env):
    """Repeat `action` for up to `num_skip_frames` environment steps.

    Each resulting frame is preprocessed and appended exactly once to
    `stacked_frames` (which is mutated in place, so pass a copy if the
    original must be kept). Stops early when the episode terminates or is
    truncated.

    Args:
        action: action passed to `env.step` on every repeated step.
        num_stacked_frames: unused; kept for call compatibility.
        num_skip_frames: maximum number of environment steps to take.
        stacked_frames: container with `.append`, typically a bounded deque.
        env: environment whose `step` returns
            (obs, reward, terminated, truncated, info).

    Returns:
        (stacked_frames, total_reward, done), with rewards summed over the
        steps taken. `done` is False when `num_skip_frames` is 0.
    """
    total_reward = 0
    # Defined up front so a zero-length skip does not hit an unbound name.
    done = False
    for _ in range(num_skip_frames):
        next_state, reward, terminated, truncated, _info = env.step(action)
        # Append once per step. Appending twice (as before) filled a
        # 4-slot stack with only two distinct frames.
        stacked_frames.append(preprocess(next_state))
        total_reward += reward
        done = terminated or truncated
        if done:
            break
    return stacked_frames, total_reward, done

In [None]:
def get_frame_tensor(stacked_frames):
    """Turn a stack of equally shaped 2-D frames into one network input.

    The frames are stacked along a trailing axis, that axis is moved to the
    front, and a leading batch axis of size 1 is added, giving a tensor of
    shape (1, num_frames, H, W) in the order the frames appear in the stack.
    """
    frames = list(stacked_frames.copy())
    height_width_frames = torch.stack(frames, dim = 2)
    frames_first = height_width_frames.permute(2, 0, 1)
    return frames_first.unsqueeze(0)

In [None]:
def save_model(num_epochs):
    """Save the global `agent`'s policy-net weights under `weights/`.

    The file is named `DQN_SpaceInvaders_<num_epochs>.pth`; the directory is
    created when missing.
    """
    weights_dir = Path('weights')
    weights_dir.mkdir(parents = True, exist_ok = True)

    checkpoint_path = weights_dir / f'DQN_SpaceInvaders_{num_epochs}.pth'

    print(f"Saving model to: {checkpoint_path}")
    torch.save(obj = agent.policy_net.state_dict(), f = checkpoint_path)

In [None]:
# Number of input channels of the Q-network. The training cell stacks
# `num_stacked_frames` frames per state, so the two values must agree.
input_shape = 4
agent = Agent(input_shape = input_shape)

In [None]:
# Training configuration.
num_episodes = 50000
num_steps = 50            # not used by this loop; kept for other cells
num_stacked_frames = 4    # frames per state
num_skip_frames = 4       # environment steps taken per chosen action
train_after_steps = 4     # run a training step every N agent steps

for i_episode in range(1, num_episodes+1):
    total_reward = 0
    steps = 0

    # Start every episode from an all-zero stack plus the first real frame.
    stacked_frames = deque([torch.zeros((84,84)).to(device)
                            for i in range(num_stacked_frames)],
                           maxlen = num_stacked_frames)
    state = env.reset()[0]
    state = preprocess(state)
    stacked_frames.append(state)
    done = False

    while not done:
        steps += 1
        # Rebuild the state from the current frames on every step, so the
        # action is chosen from the latest observation rather than the
        # tensor left over from the previous iteration.
        state_tensor = get_frame_tensor(stacked_frames)
        action = agent.select_action(state_tensor)
        next_stacked_frames, reward, done = process_frame_skip(
            action.item(),
            num_stacked_frames,
            num_skip_frames,
            stacked_frames.copy(),
            env)
        total_reward += reward
        reward = torch.FloatTensor([reward]).to(device)

        # Terminal transitions carry next_state=None: optimize_model masks
        # on `is not None`, so an all-zero tensor would be bootstrapped
        # like any ordinary state.
        next_state_tensor = (None if done
                             else get_frame_tensor(next_stacked_frames))

        agent.memory.push(state_tensor, action, next_state_tensor, reward)

        stacked_frames = next_stacked_frames

        if steps % train_after_steps == 0 and not done:
            print(f"Step: {steps}:{steps//train_after_steps} | CurrentReward:{total_reward}", end = '\r')
            agent.train_step()

        if done:
            print(f"Episode: {i_episode} | Reward: {total_reward} "
                  f"| Duration:{steps} | Epsilon: {agent.get_exploration_rate():.4f} "
                  f"| memory_len: {len(agent.memory.memory)} | train_steps:{agent.num_train_steps}")
            agent.episode_durations.append(steps)
            agent.rewards.append(total_reward)
    if i_episode % 100 == 0:
        agent.plot_durations()
    if i_episode % 500 == 0:
        # Checkpointing and the demo game stay best-effort, but each failure
        # is reported with its cause, and KeyboardInterrupt is no longer
        # swallowed by a bare `except`.
        try:
            save_model(i_episode)
        except Exception as exc:
            print(f'Model Save Error!!!!! {exc!r}')
        try:
            show_game(agent, num_stacked_frames)
        except Exception as exc:
            print(f'Game Show Error!!!!! {exc!r}')
        print('Continuing Training')

print('Training Complete')
agent.plot_durations()