In [16]:
import sys
import torch
import gym
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pathlib import Path
from torch.distributions import Categorical
from collections import namedtuple, deque
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from itertools import count
import pandas as pd
import math
import random
import cv2

In [17]:
# Gym environment id; show_game() reads this global to build its own rendering copy.
env_name = "SpaceInvaders-v4"
# Training environment, created with obs_type='grayscale' (no render_mode passed here).
env = gym.make(env_name, obs_type = 'grayscale')

In [18]:
# True when the active matplotlib backend name contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # Only imported when the inline backend is in use.
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

<matplotlib.pyplot._IonContext at 0x1d9f252c2e0>

In [19]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device(type='cuda')

In [20]:
# One replay-buffer record: (state, action, reward, next_state, done).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)

In [21]:
class ReplayMemory:
    """Fixed-capacity FIFO store of Transition records with uniform random sampling."""

    def __init__(self, capacity):
        # A bounded deque silently drops the oldest record once `capacity` is reached.
        self.capacity = capacity
        self.memory = deque(maxlen = capacity)

    def push(self, *args):
        """Append one Transition built from the positional arguments."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Draw `batch_size` distinct records and collate them into batched tensors.

        Returns:
            A 5-tuple (states, actions, rewards, next_states, dones), all moved to
            `device`. The first four are row-stacked with torch.vstack; `dones` is
            converted to a float tensor with a trailing dimension of size 1.
        """
        picked = random.sample(self.memory, batch_size)

        state_rows = [rec.state for rec in picked]
        action_rows = [rec.action for rec in picked]
        reward_rows = [rec.reward for rec in picked]
        next_state_rows = [rec.next_state for rec in picked]
        done_flags = [rec.done for rec in picked]

        done_column = torch.tensor(np.array(done_flags), dtype = torch.float)
        return (
            torch.vstack(state_rows).to(device),
            torch.vstack(action_rows).to(device),
            torch.vstack(reward_rows).to(device),
            torch.vstack(next_state_rows).to(device),
            done_column.to(device).unsqueeze(1),
        )

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [22]:
class DDQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    Args:
        input_shape: number of input channels fed to the first convolution.
        n_actions: size of the output layer (one value per action).

    The flattened feature size is fixed at 64*7*7, which is what the three
    convolutions produce for 84x84 inputs.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        feature_layers = [
            nn.Conv2d(input_shape, 32, kernel_size = 8, stride = 4, padding = 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 4, stride = 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3),
            nn.ReLU(),
            nn.Flatten(),
        ]
        head_layers = [
            nn.Linear(64*7*7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        # Single Sequential so parameter names stay `network.<index>.*`.
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the network output for the batch `x`."""
        q_values = self.network(x)
        return q_values

In [23]:
class Agent:
    """Epsilon-greedy Q-learning agent with replay memory and a soft-updated target network.

    The bootstrap target in `learn` is the maximum of the target network's
    outputs for the next state (the standard DQN target).

    Args:
        env: environment whose `action_space.n` gives the number of discrete actions.
        input_shape: number of input channels of the Q-network.
        BATCH_SIZE: number of transitions sampled per learning step.
        GAMMA: discount factor.
        EPSILON: exploration rate at step 0.
        MIN_EPSILON: asymptotic lower bound of the exploration rate.
        EPSILON_DECAY: time constant (in recorded transitions) of the exponential decay.
        FRAME_STACKING: stored on the agent; not read by the agent itself.
        FRAME_SKIP: stored on the agent; not read by the agent itself.
        TAU: interpolation factor of the soft target update.
        LR: optimizer learning rate.
        update_every: run one learning step every this many recorded transitions.
        memory_capacity: maximum number of transitions kept in the replay memory.
            NOTE(review): every stored transition keeps two state tensors alive, so
            the default may exceed available memory - confirm for your setup.
    """

    def __init__(self, env, 
                 input_shape = 4,
                 BATCH_SIZE = 512,
                 GAMMA = 0.99,
                 EPSILON = 1,
                 MIN_EPSILON = 0.05,
                 EPSILON_DECAY = 10000,
                 FRAME_STACKING=4, 
                 FRAME_SKIP=4,
                 TAU = 0.005,
                 LR = 1e-4,
                 update_every = 1,
                 memory_capacity = 1000000):
        self.BATCH_SIZE = BATCH_SIZE
        self.GAMMA = GAMMA
        # EPSILON holds the current rate; EPSILON_START is the fixed starting point
        # of the decay schedule so the constructor argument is actually honoured.
        self.EPSILON_START = EPSILON
        self.EPSILON = EPSILON
        self.MIN_EPSILON = MIN_EPSILON
        self.EPSILON_DECAY = EPSILON_DECAY
        self.FRAME_STACKING = FRAME_STACKING
        self.FRAME_SKIP = FRAME_SKIP
        self.TAU = TAU
        self.LR = LR
        self.update_every = update_every

        # Kept as an instance attribute for backward compatibility.
        self.Transition = namedtuple('Transition',
                                     ('state', 'action', 'reward', 'next_state', 'done'))

        self.memory = ReplayMemory(memory_capacity)
        self.n_actions = env.action_space.n

        self.input_shape = input_shape
        self.q_eval = DDQN(self.input_shape, self.n_actions).to(device)
        self.q_target = DDQN(self.input_shape, self.n_actions).to(device)
        # Start the target network as an exact copy of the online network;
        # otherwise early targets come from unrelated random weights.
        self.q_target.load_state_dict(self.q_eval.state_dict())

        self.optimizer = optim.AdamW(self.q_eval.parameters(),
                                     lr=self.LR, amsgrad=True)

        self.steps_done = 0        # transitions recorded through step()
        self.num_train_steps = 0   # gradient updates performed by learn()
        self.episode_durations = []
        self.rewards = []
        self.mean_rewards = []

    def get_exploration_rate(self, update = True):
        """Return the exploration rate, optionally recomputing it from `steps_done`.

        With `update=True` the rate is set to
        MIN_EPSILON + (EPSILON_START - MIN_EPSILON) * exp(-steps_done / EPSILON_DECAY).
        """
        if update:
            decay = math.exp(-1.0 * self.steps_done / self.EPSILON_DECAY)
            self.EPSILON = self.MIN_EPSILON + (self.EPSILON_START - self.MIN_EPSILON) * decay
        return self.EPSILON

    def step(self, state, action, reward, next_state, done):
        """Record one transition and learn every `update_every` transitions.

        This is the only place `steps_done` advances, so the epsilon schedule
        and the update cadence both count each transition exactly once.
        """
        self.memory.push(state, action, reward, next_state, done)

        self.steps_done += 1
        if self.steps_done % self.update_every != 0:
            return
        if len(self.memory) > self.BATCH_SIZE:
            experiences = self.memory.sample(self.BATCH_SIZE)
            self.learn(experiences)

    def learn(self, experiences):
        """Run one gradient step on a sampled batch, then soft-update the target network.

        Args:
            experiences: (states, actions, rewards, next_states, dones) as returned
                by the replay memory's `sample`.
        """
        states, actions, rewards, next_states, dones = experiences

        # Targets are constants with respect to the loss, so build them without autograd.
        with torch.no_grad():
            q_target_next = self.q_target(next_states).max(1)[0].unsqueeze(1)
            # (1 - dones) removes the bootstrap term for terminal transitions.
            q_targets = rewards + (self.GAMMA * q_target_next * (1 - dones))

        q_expected = self.q_eval(states).gather(1, actions.view(-1, 1))
        loss = F.mse_loss(q_expected, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.num_train_steps += 1

        # Soft update: target <- TAU * online + (1 - TAU) * target.
        with torch.no_grad():
            for target_param, eval_param in zip(self.q_target.parameters(), self.q_eval.parameters()):
                target_param.copy_(self.TAU * eval_param + (1.0 - self.TAU) * target_param)

    def select_action(self, state, explore=True):
        """Pick an action for `state` as a (1, 1) long tensor.

        With `explore=True` the choice is epsilon-greedy using the current
        schedule value. With `explore=False` the greedy action is returned and
        neither the exploration rate nor any counter is touched.
        """
        if explore:
            eps_thresh = self.get_exploration_rate(update = True)
            if random.random() <= eps_thresh:
                return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)
        with torch.no_grad():
            return self.q_eval(state).max(1)[1].view(1, 1)

    def plot_durations(self):
        """Plot episode durations (figure 1) and per-episode / mean rewards (figure 2)."""
        plt.figure(1)
        plt.clf()  # the numbered figure may be reused between calls; avoid stacking old lines
        durations_t = torch.tensor(self.episode_durations, dtype=torch.float)
        plt.title('Episode_Lengths')
        plt.xlabel('Episode')
        plt.ylabel('Duration')
        plt.plot(durations_t.numpy())

        if len(durations_t) >= 100:
            # 100-episode moving average, left-padded with zeros to align with the episode axis.
            means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
            means = torch.cat((torch.zeros(99), means))
            plt.plot(means.numpy())

        plt.show()

        plt.figure(2)
        plt.clf()
        rewards_t = torch.tensor(self.rewards, dtype=torch.float)
        mean_rewards_t = torch.tensor(self.mean_rewards, dtype=torch.float)
        plt.title('Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.plot(rewards_t.numpy())
        plt.plot(mean_rewards_t.numpy())
        plt.show()


In [24]:
def preprocess(img):
    """Resize a frame to 84x84 and return it as a float tensor on `device`, divided by 255.

    Args:
        img: image array accepted by cv2.resize (presumably a single grayscale
            frame from the environment - confirm against callers).
    """
    resized = cv2.resize(img, (84, 84))
    frame = torch.FloatTensor(resized).to(device)
    frame /= 255.0
    return frame

In [25]:
def process_frame_skip(action, num_stacked_frames, num_skip_frames, 
                       stacked_frames, env):
    """Repeat `action` for up to `num_skip_frames` environment steps.

    Each observed frame is preprocessed and appended to `stacked_frames`
    exactly once, so a bounded deque ends up holding the most recent distinct
    frames. Stepping stops early when the episode terminates or is truncated.

    Args:
        action: action passed unchanged to `env.step` on every repeat.
        num_stacked_frames: unused; kept for call-site compatibility.
        num_skip_frames: maximum number of environment steps to take.
        stacked_frames: deque of frames; mutated in place and returned.
        env: environment whose `step` returns
            (observation, reward, terminated, truncated, info).

    Returns:
        (stacked_frames, total_reward, done) where `total_reward` is the sum of
        the rewards of the steps taken and `done` is True if the last step
        ended the episode (False when no step was taken).
    """
    total_reward = 0
    done = False  # defined up front so num_skip_frames == 0 cannot leave it unbound
    for _ in range(num_skip_frames):
        next_state, reward, terminated, truncated, _info = env.step(action)
        stacked_frames.append(preprocess(next_state))
        total_reward += reward
        done = terminated or truncated
        if done:
            break
    return stacked_frames, total_reward, done

In [26]:
def get_frames_tensor(stacked_frames):
    """Stack a sequence of equally-shaped 2-D frame tensors into one batched tensor.

    Args:
        stacked_frames: iterable (e.g. deque) of N tensors of shape (H, W).
            It is not modified.

    Returns:
        A new contiguous tensor of shape (1, N, H, W) whose channel i is frame i.
    """
    # Stacking on dim 0 gives (N, H, W) directly; no permute needed, and
    # torch.stack already copies the data so the deque itself need not be copied.
    return torch.stack(list(stacked_frames), dim = 0).unsqueeze(0)

In [27]:
def show_game(agent, num_stacked_frames, show_action = False, num_games = 1):
    """Play and render `num_games` episodes with the agent's greedy policy.

    A separate environment is created with render_mode='human' and is always
    closed, even if an error occurs mid-episode. Actions are selected with
    `explore=False`, so the displayed behaviour is the learned policy.

    Args:
        agent: object exposing `select_action(state, explore=...)`.
        num_stacked_frames: number of frames kept in the observation stack.
        show_action: when True, print each chosen action id.
        num_games: number of episodes to play.
    """
    env2 = gym.make(env_name, render_mode = 'human', obs_type = 'grayscale')
    try:
        for _ in range(num_games):
            # Start from an all-zero stack, then push the first real frame.
            stacked_frames = deque([torch.zeros((84,84)).to(device)
                                    for _ in range(num_stacked_frames)],
                                   maxlen = num_stacked_frames)
            first_frame = env2.reset()[0]
            stacked_frames.append(preprocess(first_frame))

            done = False
            steps = 0
            total_reward = 0
            while not done:
                steps += 1
                state_tensor = get_frames_tensor(stacked_frames)
                action = agent.select_action(state_tensor, explore=False)
                if show_action:
                    print(action.item(), end = ' ')
                next_state, reward, terminated, truncated, _ = env2.step(action.item())
                stacked_frames.append(preprocess(next_state))
                total_reward += reward
                done = terminated or truncated
            print()
            print(f'Game Completed# Reward:{total_reward} | Game_Length:{steps}')
    finally:
        env2.close()
    return

In [29]:
# Build the training agent. input_shape is the number of input channels of the
# Q-network, so it should match the number of stacked frames fed to it.
agent = Agent(env,
              input_shape = 4,
              BATCH_SIZE = 64,
              GAMMA = 0.99, 
              EPSILON = 1, 
              MIN_EPSILON = 0.05, 
              EPSILON_DECAY = 100000,
              FRAME_STACKING=4,
              FRAME_SKIP=3,
              TAU = 0.005, 
              LR = 1e-5, 
              update_every = 4)

In [None]:
# Long run on GPU; on CPU only a handful of episodes as a smoke test.
if torch.cuda.is_available():
    num_episodes = 50000
else:
    num_episodes = 50

# Take the frame settings from the agent so this loop cannot drift from its config.
num_stacked_frames = agent.FRAME_STACKING
num_skip_frames = agent.FRAME_SKIP
train_after_steps = 4  # not read in this cell; the learning cadence is agent.update_every

for i_episode in range(num_episodes):
    print(f"Episode: {i_episode}")

    # Start every episode from an all-zero stack, then push the first real frame.
    stacked_frames = deque([torch.zeros((84,84)).to(device)
                            for _ in range(num_stacked_frames)],
                           maxlen = num_stacked_frames)
    frame, info = env.reset()
    stacked_frames.append(preprocess(frame))
    states_tensor = get_frames_tensor(stacked_frames)

    done = False
    steps = 0
    total_reward = 0
    while not done:
        steps += 1
        action = agent.select_action(states_tensor)
        stacked_frames, reward, done = process_frame_skip(action.item(),
                                                          num_stacked_frames,
                                                          num_skip_frames,
                                                          stacked_frames,
                                                          env)
        total_reward += reward
        reward = torch.FloatTensor([reward]).to(device)

        if done:
            # Terminal next-state is a placeholder; its value is masked by `done` in learning.
            stacked_frames = deque([torch.zeros((84,84)).to(device)
                                    for _ in range(num_stacked_frames)],
                                   maxlen = num_stacked_frames)
        next_states_tensor = get_frames_tensor(stacked_frames)
        agent.step(states_tensor, action, reward, next_states_tensor, done)

        # Advance the observation; without this the agent would keep acting on
        # (and storing) the very first state of the episode.
        states_tensor = next_states_tensor

    # Episode bookkeeping. `steps` already counts every decision taken.
    agent.episode_durations.append(steps)
    agent.rewards.append(total_reward)
    mean_reward = np.mean(agent.rewards[-100:])
    agent.mean_rewards.append(mean_reward)
    print(f"\tScore: {total_reward} | Mean: {mean_reward} | Epsilon: {agent.get_exploration_rate(update=False)} | Duration: {steps}")
    if i_episode % 100 == 0 and i_episode > 0:
        try:
            agent.plot_durations()
            if i_episode % 1000 == 0:
                show_game(agent, num_stacked_frames)
        except Exception as exc:
            # Plotting/rendering is best-effort; report the cause and keep training.
            # KeyboardInterrupt is deliberately not caught so the run can be stopped.
            print(f'Game_Show_Error!!!... {exc!r}')
        print("Continuing")


print('Complete')
agent.plot_durations()
plt.ioff()
plt.show()


In [None]:
# Re-draw the duration and reward curves from the history recorded on the agent.
agent.plot_durations()