# Deep Q-Network (DQN) for Atari-Pong
---
In this notebook, we implement a DQN agent for OpenAI Gym's Pong-v0 environment.
The main challenge is that the agent learns directly from raw pixels rather than from a hand-crafted state representation.

### 1. Import the Necessary Packages

In [1]:
import gym
import random
import torch
import numpy as np
from gym import spaces
import cv2
from collections import deque
import matplotlib.pyplot as plt
from atari_wrappers import *
%matplotlib inline

### 2. Instantiate the Environment and Agent

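Initialize the environment in the code cell below. `wrap_deepmind` applies DeepMind-style preprocessing. In particular, the `WarpFrame` wrapper converts the original RGB frame produced by the OpenAI Gym environment, of shape (210, 160, 3), into an 84x84 grayscale image, as in the original DQN paper, and the four most recent frames are stacked, which gives the (84, 84, 4) state shape printed below.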
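To make the frame preprocessing concrete, here is a NumPy-only sketch of what a `WarpFrame`-style step does to a single frame. It is an approximation, not the wrapper's code: the OpenAI Baselines `WarpFrame` uses OpenCV (`cv2.cvtColor`, then `cv2.resize` with area interpolation), whereas this sketch uses the same BT.601 luminance weights followed by nearest-neighbour sampling. The function name `to_grayscale_84` is hypothetical.

```python
import numpy as np

def to_grayscale_84(frame: np.ndarray) -> np.ndarray:
    """Convert an RGB frame of shape (H, W, 3) to an 84x84 uint8 grayscale image.

    Sketch only: BT.601 luminance weights, then nearest-neighbour resampling
    (the Baselines WarpFrame uses OpenCV with area interpolation instead).
    """
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    # Weighted sum over the colour axis: (H, W, 3) -> (H, W)
    gray = frame.astype(np.float32) @ weights
    h, w = gray.shape
    # Choose 84 evenly spaced source rows and columns
    rows = (np.arange(84) * h / 84).astype(int)
    cols = (np.arange(84) * w / 84).astype(int)
    resized = gray[rows[:, None], cols[None, :]]
    # Guard against float rounding before casting back to 8-bit pixels
    return np.clip(resized, 0, 255).astype(np.uint8)

frame = np.zeros((210, 160, 3), dtype=np.uint8)  # same shape as a raw Pong frame
print(to_grayscale_84(frame).shape)  # (84, 84)
```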

In [2]:
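# Alternatives: make_atari('Pong-v0') or gym.make('PongNoFrameskip-v4').
# Note: Pong-v0 already skips a random 2-4 frames per step and uses sticky actions;
# the usual DeepMind setup is make_atari('PongNoFrameskip-v4') followed by wrap_deepmind.
env = gym.make('Pong-v0')

# DeepMind-style preprocessing: 84x84 grayscale frames, the last 4 stacked
env = wrap_deepmind(env)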

env.seed(2)
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

State shape:  (84, 84, 4)
Number of actions:  6
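The printed state shape (84, 84, 4) is channels-last: four consecutive 84x84 frames stacked along the last axis. A minimal sketch of that stacking, plus the transpose to the channels-first (4, 84, 84) layout that PyTorch convolutions expect, might look like this. `FrameStacker` and `to_channels_first` are hypothetical helpers, and whether `dqn_agent` performs the transpose internally is an assumption this notebook does not let us check.

```python
from collections import deque
import numpy as np

class FrameStacker:
    """Hypothetical helper: keep the k most recent frames and stack them."""

    def __init__(self, k: int = 4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # At the start of an episode, fill the stack with copies of the first frame
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return self.state()

    def step(self, frame: np.ndarray) -> np.ndarray:
        # Once full, the deque discards the oldest frame automatically
        self.frames.append(frame)
        return self.state()

    def state(self) -> np.ndarray:
        # Channels-last (84, 84, k), matching the printed observation shape
        return np.stack(list(self.frames), axis=-1)

def to_channels_first(state: np.ndarray) -> np.ndarray:
    # PyTorch's Conv2d expects (C, H, W), so move the stacking axis to the front
    return np.transpose(state, (2, 0, 1))

stacker = FrameStacker(k=4)
s = stacker.reset(np.zeros((84, 84), dtype=np.uint8))
s = stacker.step(np.ones((84, 84), dtype=np.uint8))
print(s.shape, to_channels_first(s).shape)  # (84, 84, 4) (4, 84, 84)
```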


Initialize the DQN agent.
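The `Agent` class comes from the local `dqn_agent` module, which is not shown here. Assuming its Q-network follows the convolutional stack of the Nature DQN paper (32 filters of 8x8 with stride 4, 64 of 4x4 with stride 2, 64 of 3x3 with stride 1, no padding), the size of the flattened feature vector that feeds the fully connected layers can be worked out from the 84x84 input:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    # Output size of a convolution along one spatial dimension (no dilation)
    return (size + 2 * padding - kernel) // stride + 1

# Assumed Nature-DQN layers: (out_channels, kernel_size, stride)
NATURE_DQN_CONV = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]

def flattened_features(input_size: int, layers) -> int:
    size = input_size
    for _, kernel, stride in layers:
        size = conv_out(size, kernel, stride)  # 84 -> 20 -> 9 -> 7
    # channels of the last layer times the remaining spatial area
    return layers[-1][0] * size * size

print(flattened_features(84, NATURE_DQN_CONV))  # 3136
```

Under these assumptions the flattened vector has 64 x 7 x 7 = 3136 features; a different architecture in `dqn_agent` would change this number.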

In [3]:
from dqn_agent import Agent

# For games such as Pong, we stack 4 grayscale frames together to capture motion, so the input is 84x84x4.
agent = Agent(in_channels=4, action_size=env.action_space.n, seed=2)
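`Agent` is constructed from only the input channels, the action count and a seed, so its experience replay memory is presumably configured inside `dqn_agent`. For reference, a minimal uniform replay buffer of the kind DQN relies on could look like the following; `ReplayBuffer`, `Transition` and their method names are illustrative assumptions, not the module's API.

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

class ReplayBuffer:
    """Minimal fixed-size experience replay (sketch; dqn_agent's buffer may differ)."""

    def __init__(self, capacity: int, seed: int = 0):
        self.memory = deque(maxlen=capacity)  # oldest transitions are evicted first
        self.rng = random.Random(seed)

    def add(self, state, action, reward, next_state, done):
        self.memory.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> Transition:
        # Uniform sampling without replacement breaks up correlated consecutive frames
        batch = self.rng.sample(self.memory, k=batch_size)
        # Regroup into per-field tuples: states, actions, rewards, next_states, dones
        return Transition(*zip(*batch))

    def __len__(self) -> int:
        return len(self.memory)

buf = ReplayBuffer(capacity=5)
for i in range(8):
    buf.add(i, 0, 0.0, i + 1, False)
print(len(buf))  # 5: only the 5 most recent transitions are kept
```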

### 3. Train the Agent with DQN

Run the code cell below to train the agent from scratch. Feel free to change the default parameter values of `dqn` to see whether you can get better performance!

Alternatively, you can skip to the next step below (**4. Watch a Smart Agent!**), to load the saved model weights from a pre-trained agent.
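Before changing the defaults, it helps to know how quickly the exploration rate falls. `dqn` multiplies epsilon by `eps_decay` once per episode and clips it at `eps_end`, so epsilon reaches its floor after $\lceil \ln(\varepsilon_\text{end}/\varepsilon_\text{start}) / \ln(\text{eps\_decay}) \rceil$ episodes. The helper below (hypothetical, not part of the notebook) simply replays that update:

```python
def episodes_until_floor(eps_start: float, eps_end: float, eps_decay: float) -> int:
    """Count episodes until the per-episode update used in dqn() reaches eps_end."""
    eps, episodes = eps_start, 0
    while eps > eps_end:
        eps = max(eps_end, eps_decay * eps)  # same update as in dqn()
        episodes += 1
    return episodes

print(episodes_until_floor(1.0, 0.02, 0.98))  # 194
```

With the defaults, the agent stops exploring beyond the 2% floor after 194 episodes; a larger `eps_decay` (e.g. 0.99) stretches exploration over more episodes.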

In [None]:
def dqn(n_episodes=100000, max_t=1000, eps_start=1.0, eps_end=0.02, eps_decay=0.98):
    """Deep Q-Learning.
    
    Params
    ======
        n_episodes (int): maximum number of training episodes
        max_t (int): maximum number of timesteps per episode
        eps_start (float): starting value of epsilon, for epsilon-greedy action selection
        eps_end (float): minimum value of epsilon
        eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
    """
    scores = []                        # list containing scores from each episode
    scores_window = deque(maxlen=100)  # last 100 scores
    eps = eps_start                    # initialize epsilon
    for i_episode in range(1, n_episodes+1):
        state = env.reset()
        score = 0
        for t in range(max_t):
            action = agent.act(state, eps)
            next_state, reward, done, _ = env.step(action)
            # reward = np.clip(reward, -1, 1)  # optional: Pong rewards are already -1, 0 or +1
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward
            if done:
                break
        scores_window.append(score)       # save most recent score
        scores.append(score)              # append to the full score history
        eps = max(eps_end, eps_decay*eps) # decrease epsilon
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
            torch.save(agent.qnetwork_local.state_dict(), 'checkpoint_periodic.pth')
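        # only declare success once the window holds a full 100 episodes
        if len(scores_window) == scores_window.maxlen and np.mean(scores_window) >= 18.0: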
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
            torch.save(agent.qnetwork_local.state_dict(), 'checkpoint_final.pth')
            break
    return scores

scores = dqn()

# plot the scores
fig, ax = plt.subplots()
ax.plot(np.arange(len(scores)), scores)
ax.set_ylabel('Score')
ax.set_xlabel('Episode #')
plt.show()

Episode 26	Average Score: -10.12
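Inside `agent.step`, a DQN agent periodically samples a minibatch from replay memory and regresses the online network towards a bootstrapped target. Assuming `dqn_agent` uses the standard target-network formulation from the DQN paper (this notebook does not confirm it, nor the discount factor), the target computation for a batch reduces to the following NumPy sketch; `td_targets` and the default `gamma=0.99` are hypothetical.

```python
import numpy as np

def td_targets(rewards: np.ndarray, next_q_target: np.ndarray,
               dones: np.ndarray, gamma: float = 0.99) -> np.ndarray:
    """DQN targets y = r + gamma * max_a' Q_target(s', a') * (1 - done).

    rewards:       (B,)   rewards of the sampled transitions
    next_q_target: (B, A) target-network Q-values for the next states
    dones:         (B,)   1.0 where the episode ended, else 0.0
    """
    max_next_q = next_q_target.max(axis=1)  # greedy value of each next state
    # Terminal transitions do not bootstrap: their target is just the reward
    return rewards + gamma * max_next_q * (1.0 - dones)

rewards = np.array([0.0, 1.0, -1.0])
next_q = np.array([[0.1, 0.5], [0.2, 0.3], [0.4, 0.0]])
dones = np.array([0.0, 0.0, 1.0])
print(td_targets(rewards, next_q, dones, gamma=0.9))
```

The online network is then trained to minimize the squared difference between these targets and its own Q-values for the actions actually taken.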

### 4. Watch a Smart Agent!

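The code cell below loads the most recent periodic checkpoint (`checkpoint_periodic.pth`) and lets the agent play three games. If training reached the target score, `checkpoint_final.pth` holds the final weights instead.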
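The watch loop calls `agent.act(state)` without an epsilon. Assuming `act` defaults to `eps=0` (its signature lives in `dqn_agent` and is not shown), the agent then always picks its highest-valued action. A sketch of epsilon-greedy selection, using the hypothetical name `epsilon_greedy`, shows what that default means:

```python
import random
import numpy as np

def epsilon_greedy(q_values: np.ndarray, eps: float, rng: random.Random) -> int:
    """With probability eps pick a uniformly random action, otherwise the argmax."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return int(np.argmax(q_values))

rng = random.Random(0)
q = np.array([0.1, 0.7, -0.2, 0.0, 0.3, 0.2])  # one Q-value per Pong action
print(epsilon_greedy(q, eps=0.0, rng=rng))  # 1: purely greedy when eps is 0
```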

In [None]:
# load the weights from file
# map_location='cpu' lets GPU-trained weights load on a CPU-only machine
agent.qnetwork_local.load_state_dict(torch.load('checkpoint_periodic.pth', map_location='cpu'))

for i in range(3):
    state = env.reset()
    for j in range(2000):
        action = agent.act(state)
        env.render()
        state, reward, done, _ = env.step(action)
        if done:
            break
            
env.close()