In [None]:
import os
import time
import cv2
import gym
import collections

import numpy as np
import torch as T
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

%matplotlib inline

In [None]:
# T.cuda.get_device_name()

## **Wrappers**

In [None]:
# PREPROCESS EACH FRAME
class PreprocessFrames(gym.ObservationWrapper):
    """
    Preprocess every raw RGB frame into a normalized grayscale image.

    Input: array (rows, columns, 3) with values in [0, 255].
    1. convert to grayscale (3 channels to 1)   -   (rows, columns)               [0-255]
    2. resize to (new_rows, new_columns)        -   (new_rows, new_columns)       [0-255]
    3. reshape to new_observation_shape         -   array(1, new_rows, new_columns)
    4. scale values to 0-1 as float32           -   array(1, new_rows, new_columns)  [0.0-1.0]

    Args:
        env: environment whose observations are RGB frames.
        new_observation_shape: (channels, new_rows, new_columns) of the output.
    """
    def __init__(self, env, new_observation_shape):
        super().__init__(env)
        self.new_observation_shape = tuple(new_observation_shape)
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.new_observation_shape, dtype=np.float32)

    def observation(self, observation):
        """Return the preprocessed float32 frame of shape new_observation_shape."""
        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes dsize as (width, height) == (columns, rows), the reverse
        # of numpy's (rows, columns) order; passing shape[1:] directly would
        # transpose non-square targets.
        new_rows, new_columns = self.new_observation_shape[1:]
        resized = cv2.resize(gray, (new_columns, new_rows), interpolation=cv2.INTER_AREA)
        new_observation = np.asarray(resized, dtype=np.float32).reshape(self.new_observation_shape)
        # Stay in float32 to match the declared observation_space dtype.
        return new_observation / np.float32(255.0)

# APPLIED AFTER PreprocessFrames (see make_env): EACH OBSERVATION IS ONE PREPROCESSED FRAME
class CustomStep(gym.Wrapper):
    """
    OVERRIDES step() & reset()
    1. step() repeats the same action for `frame_skip` frames and sums the rewards.
    2. step() removes flicker by returning the element-wise max of the 2 most recent frames.
    3. reset() optionally performs a random number (1..no_ops) of NOOP actions.
    4. reset() optionally presses FIRE once after resetting.

    Args:
        env: wrapped environment; its observation_space shape/dtype size the frame buffer.
        frame_skip: number of frames each action is repeated for (must be >= 1).
        clip_reward: if True, each per-frame reward is clipped to [-1, 1].
        no_ops: maximum number of random NOOPs on reset (0 disables them).
        fire_first: if True, take the FIRE action right after reset.
    """
    def __init__(self, env, frame_skip, clip_reward, no_ops, fire_first):
        super().__init__(env)
        if frame_skip < 1:
            raise ValueError("frame_skip must be >= 1")
        self.frame_skip = frame_skip
        self.observation_shape = env.observation_space.shape
        self.observation_dtype = env.observation_space.dtype
        # Holds the two most recent frames for the flicker-removing max.
        self.observation_buffer = self._empty_buffer()
        self.clip_reward = clip_reward
        self.no_ops = no_ops
        self.fire_first = fire_first

    def _empty_buffer(self):
        """Return a zeroed (2, *observation_shape) buffer for the last two frames."""
        return np.zeros((2, *self.observation_shape), dtype=self.observation_dtype)

    def reset(self):
        """Reset the env, apply optional NOOPs / FIRE, and return the starting frame."""
        observation = self.env.reset()
        # Random number of NOOPs so episodes do not all start from the same state.
        n_no_ops = (np.random.randint(self.no_ops) + 1) if (self.no_ops > 0) else 0
        for _ in range(n_no_ops):
            observation, _, done, _ = self.env.step(0)  # 0 - NOOP
            if done:
                observation = self.env.reset()
        if self.fire_first:
            # Action 1 is FIRE (action 0 is NOOP), matching the step() call below.
            assert self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
            observation, _, _, _ = self.env.step(1)  # 1 - FIRE
        self.observation_buffer = self._empty_buffer()
        self.observation_buffer[0] = observation
        return observation

    # RETURN FRAME_SKIPPED & FLICKER REMOVED FRAMES
    def step(self, action):
        """Repeat `action` for frame_skip frames; return (max of last 2 frames, summed reward, done, info)."""
        total_reward = 0.0
        done = False
        info = {}

        for frame in range(self.frame_skip):
            observation, reward, done, info = self.env.step(action)
            # CLIP REWARD (-1,1) IF true
            if self.clip_reward:
                reward = np.clip(reward, -1, 1)
            total_reward += reward
            # Alternate slots so the buffer always holds the two most recent frames.
            self.observation_buffer[frame % 2] = observation
            if done:
                break

        observation_max = np.maximum(self.observation_buffer[0], self.observation_buffer[1])
        return observation_max, total_reward, done, info


# STACK OBSERVATIONS
class StackFrames(gym.ObservationWrapper):
    """
    Returns the `stack_size` most recent observations concatenated along axis 0.

    - reset(): the first observation is repeated `stack_size` times.
    - observation(): the newest observation is appended (the oldest is dropped),
      so the result is it plus the `stack_size - 1` previous observations.
    """
    def __init__(self, env, stack_size):
        super().__init__(env)
        inner_space = env.observation_space
        stacked_low = inner_space.low.repeat(stack_size, axis=0)
        stacked_high = inner_space.high.repeat(stack_size, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high)
        self.stack = collections.deque(maxlen=stack_size)

    def _stacked(self):
        """Flatten the deque of frames into one array of observation_space.shape."""
        return np.array(self.stack).reshape(self.observation_space.shape)

    def reset(self):
        self.stack.clear()
        first_observation = self.env.reset()
        self.stack.extend([first_observation] * self.stack.maxlen)
        return self._stacked()

    def observation(self, observation):
        self.stack.append(observation)
        return self._stacked()

In [None]:
# TIE EVERYTHING TOGETHER
def make_env(env_name, new_observation_shape=(1,84,84), stack_size=4, frame_skip=4, clip_reward=False, no_ops=0, fire_first=False):
    """
    Build the wrapped environment: preprocess -> frame skip / flicker removal -> frame stacking.

    Args:
        env_name: gym environment id.
        new_observation_shape: (channels, rows, columns) of each preprocessed frame.
        stack_size: number of frames stacked along axis 0.
        frame_skip: number of frames each action is repeated for.
        clip_reward: clip per-frame rewards to [-1, 1].
        no_ops: maximum number of random NOOPs on reset.
        fire_first: take the FIRE action after every reset.

    Returns:
        The StackFrames-wrapped env; observations have shape
        (stack_size * new_observation_shape[0], *new_observation_shape[1:]).
    """
    env = gym.make(env_name)
    env = PreprocessFrames(env, new_observation_shape=new_observation_shape)
    # Forward the caller's frame_skip (it was previously hard-coded to 4).
    env = CustomStep(env, frame_skip=frame_skip, clip_reward=clip_reward, no_ops=no_ops, fire_first=fire_first)
    env = StackFrames(env, stack_size=stack_size)
    return env

## **ReplayBuffer**

In [None]:
class ReplayBuffer:
    """
    Fixed-size circular buffer of (state, action, reward, next_state, done) transitions.

    After `mem_size` transitions have been stored, new ones overwrite the oldest.

    Args:
        mem_size: maximum number of stored transitions.
        observation_shape: shape of a single state.
        n_actions: number of actions (not used by the buffer; kept for interface compatibility).
    """
    def __init__(self, mem_size, observation_shape, n_actions):
        self.mem_size = mem_size
        self.mem_counter = 0  # total transitions ever stored (not capped at mem_size)
        # DATA
        self.states = np.zeros((mem_size, *observation_shape), dtype=np.float32)
        self.actions = np.zeros(mem_size, dtype=np.int64)
        # float32 rather than int: rewards arrive as floats (e.g. CustomStep sums them
        # starting from 0.0), and an integer buffer would truncate fractional values.
        self.rewards = np.zeros(mem_size, dtype=np.float32)
        self.states_ = np.zeros((mem_size, *observation_shape), dtype=np.float32)
        self.terminals = np.zeros(mem_size, dtype=bool)

    # STORE TRANSITIONS IN BUFFER
    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next slot, overwriting the oldest once full."""
        index = self.mem_counter % self.mem_size
        self.states[index] = state
        self.actions[index] = action
        self.rewards[index] = reward
        self.states_[index] = state_
        self.terminals[index] = done  # True if the episode ended on this transition
        self.mem_counter += 1

    # UNIFORMLY SAMPLES 'BUFFER' AND RETURNS A 'BATCH' OF batch_size
    def sample_batch(self, batch_size):
        """
        Sample `batch_size` distinct stored transitions uniformly at random.

        Returns:
            (states, actions, rewards, states_, terminals) numpy arrays.

        Raises:
            ValueError: if fewer than `batch_size` transitions are stored.
        """
        n_stored = min(self.mem_counter, self.mem_size)
        if batch_size > n_stored:
            raise ValueError(f"cannot sample {batch_size} transitions: only {n_stored} stored")
        batch_indices = np.random.choice(n_stored, batch_size, replace=False)
        states = self.states[batch_indices]
        actions = self.actions[batch_indices]
        rewards = self.rewards[batch_indices]
        states_ = self.states_[batch_indices]
        terminals = self.terminals[batch_indices]
        return (states, actions, rewards, states_, terminals)

## **Network**

In [None]:
class DuelingDeepQNetwork(nn.Module):
    """
    Dueling Q-network: a shared conv + fully connected trunk feeding two heads,
    a state-value head V (1 output) and an advantage head A (n_actions outputs).

    Args:
        lr: RMSprop learning rate.
        observation_shape: (channels, rows, columns) of one input state.
        n_actions: number of outputs of the advantage head.
        model_name: checkpoint file name.
        model_dir: directory the checkpoint is saved to / loaded from.
    """
    def __init__(self, lr, observation_shape, n_actions, model_name, model_dir):
        super().__init__()
        self.model_dir = model_dir
        self.model_file = os.path.join(model_dir, model_name)

        # Convolutional feature extractor.
        in_channels = observation_shape[0]
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer sized from a dummy pass through the convs.
        flat_size = self.caculate_conv_output_dims(observation_shape)
        self.fc1 = nn.Linear(flat_size, 512)

        # Dueling heads.
        self.V = nn.Linear(512, 1)
        self.A = nn.Linear(512, n_actions)

        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.to(self.device)

    def forward(self, state):
        """Return (V, A) with shapes (batch, 1) and (batch, n_actions)."""
        features = state
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = features.reshape(features.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.V(hidden), self.A(hidden)

    def caculate_conv_output_dims(self, observation_shape):
        """Return the flattened element count of the conv stack's output for a single input."""
        probe = T.zeros((1, *observation_shape))
        for conv in (self.conv1, self.conv2, self.conv3):
            probe = conv(probe)
        return int(np.prod(probe.shape))

    def save_model(self):
        """Write the model and optimizer state dicts to self.model_file."""
        print("[INFO] Saving model")
        T.save({'model_state_dict': self.state_dict(),
                'optimizer_state_dict' : self.optimizer.state_dict()},
               self.model_file)

    def load_model(self, cpu=False):
        """Restore model and optimizer state from self.model_file (tensors mapped to CPU if cpu=True)."""
        print("[INFO] Loading model")
        map_location = None
        if cpu:
            map_location = T.device('cpu')
        checkpoint = T.load(self.model_file, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

## **Agent**

In [None]:
class DuelingDQNAgent:
    """
    Dueling DQN agent: an online network (Q_STEP), a periodically synced target
    network (Q_TARGET), a uniform replay buffer and an epsilon-greedy policy
    whose epsilon decreases linearly after each learning step.

    Args:
        observation_shape: shape of one state as fed to the networks.
        n_actions: number of discrete actions.
        lr: learning rate for both networks' optimizers.
        gamma: discount factor.
        epsilon: initial exploration rate.
        epsilon_min: floor for epsilon.
        epsilon_decay: amount subtracted from epsilon after each learning step.
        mem_size: replay buffer capacity.
        batch_size: minibatch size (learning starts once this many transitions are stored).
        Q_TARGET_replace_interval: copy Q_STEP into Q_TARGET every this many learning steps.
        algo_name: prefix of the checkpoint file names.
        env_name: stored on the agent; not otherwise used in this class.
        model_dir: directory checkpoints are written to / read from.
    """
    def __init__(self, observation_shape, n_actions, lr, gamma, epsilon, epsilon_min, epsilon_decay,
                 mem_size, batch_size, Q_TARGET_replace_interval, algo_name, env_name, model_dir):
        self.observation_shape = observation_shape
        self.n_actions = n_actions
        self.LR = lr
        self.GAMMA = gamma
        self.EPSILON = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        # MEM PARAMS
        self.mem_size = mem_size
        self.batch_size = batch_size
        self.memory = ReplayBuffer(mem_size, observation_shape, n_actions)

        # MODEL PARAMS
        self.learn_counter = 0  # number of learning steps; drives target-network syncing
        self.algo_name = algo_name
        self.env_name = env_name
        self.model_dir = model_dir
        self.Q_TARGET_replace_interval = Q_TARGET_replace_interval
        # Q1: online network that is trained
        self.Q_STEP = DuelingDeepQNetwork(lr, observation_shape, n_actions,
                              model_name = algo_name+'_Q_STEP',
                              model_dir = model_dir)
        # Q2: target network, only ever overwritten from Q_STEP
        self.Q_TARGET = DuelingDeepQNetwork(lr, observation_shape, n_actions,
                              model_name = algo_name+'_Q_TARGET',
                              model_dir = model_dir)

    # e-GREEDY POLICY
    def get_action(self, observation, greedy=False):
        """
        Choose an action for a single observation.

        With probability 1 - EPSILON (always, if greedy=True) return the action with
        the largest advantage from Q_STEP; otherwise a uniformly random action.
        """
        if (np.random.uniform() >= self.EPSILON) or greedy:
            # Inference only: no autograd graph is needed for action selection.
            with T.no_grad():
                state = T.tensor(observation, dtype=T.float32, device=self.Q_STEP.device).unsqueeze(0)
                _, A = self.Q_STEP(state)
            # argmax of A equals argmax of Q: V + (A - mean(A)) shifts every action equally.
            return int(T.argmax(A).item())
        # Sample from the agent's own action count rather than a global `env`.
        return int(np.random.randint(self.n_actions))

    def learn(self):
        """Run one gradient step on a sampled minibatch (no-op until batch_size transitions are stored)."""
        if self.memory.mem_counter < self.batch_size:
            return
        # Sync Q_TARGET from Q_STEP every Q_TARGET_replace_interval learning steps.
        self.update_Q_TARGET()

        states, actions, rewards, states_, terminals = self.sample_batch()

        # PREDICT Q1(s,a) with the dueling aggregation Q = V + (A - mean(A)).
        v1, a1 = self.Q_STEP(states)
        q1 = v1 + (a1 - a1.mean(dim=1, keepdim=True))
        indices = T.arange(len(actions), device=actions.device)
        q1_preds = q1[indices, actions]

        # TARGET r + gamma * max_a' Q2(s', a'); computed without gradients so that
        # backprop only touches Q_STEP.
        with T.no_grad():
            v2_, a2_ = self.Q_TARGET(states_)
            q2_ = v2_ + (a2_ - a2_.mean(dim=1, keepdim=True))
            q2_next = q2_.max(dim=1)[0]
            q2_next[terminals] = 0.0  # no bootstrapping past the end of an episode
            q2_targets = rewards + (self.GAMMA * q2_next)

        # CALC LOSS (input = prediction, target = bootstrapped value) & BACKPROP
        loss = self.Q_STEP.loss(q1_preds, q2_targets)
        self.Q_STEP.optimizer.zero_grad()
        loss.backward()
        self.Q_STEP.optimizer.step()

        self.learn_counter += 1
        self.decay_epsilon()

    def update_Q_TARGET(self):
        """Copy Q_STEP's weights into Q_TARGET when learn_counter is a multiple of the interval."""
        if (self.learn_counter % self.Q_TARGET_replace_interval) == 0:
            self.Q_TARGET.load_state_dict(self.Q_STEP.state_dict())

    def decay_epsilon(self):
        """Decrease EPSILON by epsilon_decay, never going below epsilon_min."""
        self.EPSILON = max(self.EPSILON - self.epsilon_decay, self.epsilon_min)

    def store_transition(self, state, action, reward, state_, done):
        self.memory.store_transition(state, action, reward, state_, done)

    def sample_batch(self):
        """Sample a minibatch from memory and move it to Q_STEP's device as tensors."""
        states, actions, rewards, states_, terminals = self.memory.sample_batch(self.batch_size)
        device = self.Q_STEP.device
        states = T.tensor(states).to(device)
        actions = T.tensor(actions).to(device)
        # Force float so the TD target is floating point whatever the buffer dtype.
        rewards = T.tensor(rewards, dtype=T.float32).to(device)
        states_ = T.tensor(states_).to(device)
        terminals = T.tensor(terminals).to(device)
        return states, actions, rewards, states_, terminals

    def save_models(self):
        self.Q_STEP.save_model()
        self.Q_TARGET.save_model()

    def load_models(self, cpu=False):
        self.Q_STEP.load_model(cpu)
        self.Q_TARGET.load_model(cpu)

## **Training**

In [None]:
## TRAINING ##

In [None]:
N_EPISODES = 300

env_name = 'PongNoFrameskip-v4'
env = make_env(env_name)

# Epsilon drops by epsilon_decay after every learning step until it reaches epsilon_min;
# Q_TARGET is re-synced from Q_STEP every Q_TARGET_replace_interval learning steps.
agent = DuelingDQNAgent(
    observation_shape=env.observation_space.shape,
    n_actions=env.action_space.n,
    lr=1e-4,
    gamma=0.99,
    epsilon=1.0,
    epsilon_min=0.06,
    epsilon_decay=1e-5,
    mem_size=25_000,
    batch_size=128,
    Q_TARGET_replace_interval=1000,
    algo_name='DuelingDQN',
    env_name=env_name,
    model_dir='./',
)

In [None]:
# Per-episode histories, plus the best 100-episode mean reward seen so far.
episode_rewards, episode_lengths, episode_epsilons, mean_rewards = [], [], [], []
best_reward = -np.inf

for episode_n in tqdm(range(N_EPISODES)):
    total_reward, total_moves = 0, 0
    done = False
    observation = env.reset()

    while not done:
        # Epsilon-greedy step, then store the transition and take one learning step.
        action = agent.get_action(observation)
        observation_, reward, done, _ = env.step(action)
        total_reward, total_moves = total_reward + reward, total_moves + 1
        agent.store_transition(observation, action, reward, observation_, done)
        agent.learn()
        observation = observation_

    episode_rewards.append(total_reward)
    episode_lengths.append(total_moves)
    episode_epsilons.append(agent.EPSILON)

    # Checkpoint whenever the rolling mean over the last 100 episodes improves.
    mean_reward = np.mean(episode_rewards[-100:])
    mean_rewards.append(mean_reward)
    if mean_reward > best_reward:
        agent.save_models()
        best_reward = mean_reward

    print("ITER: ",episode_n,"\tRWD: ",total_reward,"\tMEAN_RWD: ",round(mean_reward,2),"\tLEN: ",total_moves,"\tEPS: ",round(agent.EPSILON,4))

## **Testing**

In [None]:
env_name = 'PongNoFrameskip-v4'
env = make_env(env_name)

# Evaluation agent: same architecture and checkpoint names as in training, so
# load_models() finds the saved files. It is not trained in this section, so the
# replay buffer is a minimal placeholder and epsilon is tiny.
agent = DuelingDQNAgent(
    observation_shape=env.observation_space.shape,
    n_actions=env.action_space.n,
    lr=1e-4,
    gamma=0.99,
    epsilon=0.001,
    epsilon_min=0.001,
    epsilon_decay=1e-5,
    mem_size=1,
    batch_size=1,
    Q_TARGET_replace_interval=1000,
    algo_name='DuelingDQN',
    env_name=env_name,
    model_dir='./',
)

In [None]:
# Restore both networks from model_dir, mapping stored tensors onto the CPU.
agent.load_models(True)

In [None]:
# Play one rendered episode with the greedy policy (no exploration, no learning).
with T.no_grad():
    total_reward, total_moves = 0, 0
    done = False
    try:
        observation = env.reset()

        while not done:
            time.sleep(0.0001)
            env.render()

            # GREEDY ACTION (greedy=True bypasses epsilon exploration)
            action = agent.get_action(observation, greedy=True)
            observation_, reward, done, _ = env.step(action)

            total_reward += reward
            total_moves += 1

            observation = observation_
        print("RWD: ",total_reward,"\tLEN: ",total_moves)
    finally:
        # Release env resources (including anything env.render() opened) even if the
        # episode raises or is interrupted.
        env.close()