In [None]:
import os
import gym
import numpy as np
import random
import torch
from torch import nn
from torch.nn import functional as F
from PIL import Image
import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

## Exercise 5 

Reinforcement learning: the game of pong with proximal policy optimisation.

The goals are:
1. understand how reinforcement learning connects with the neural networks we saw in previous exercises
2. understand the concept of a "policy" network
3. understand the data collection/training loop

<b> In this exercise you have to understand what is going on and build a training loop and let the agent score at least 5 points. </b>

In [None]:
%load_ext autoreload
%autoreload 2

### Loading the environment 

The positions of the paddles and the ball define the state of our environment. The agent receives the state as input and outputs the action to take. The action is taken and the state changes (giving a reward to the agent). This cycle repeats itself!

The environment gives feedback to the agent in the form of a reward (+1 if the player scores, -1 if the computer scores, and 0 otherwise).

<img src="RL_architecture.jpeg" width="800" height="400">

In [None]:
!python3 -m atari_py.import_roms roms

In [None]:
# Create the Pong environment that every later cell interacts with through `env`.
env = gym.make('PongNoFrameskip-v4')
# Reset once so the environment is in a valid starting state; the trailing `;`
# suppresses the notebook's display of the returned observation.
env.reset();

In [None]:
gym.__version__

In [None]:
env.action_space

In [None]:
# set of all the actions, we will assume only RIGHT and LEFT are relevant!

env.unwrapped.get_action_meanings()

In [None]:
# Play one game with uniformly random actions and keep every raw frame in
# `screens`, which the animation and preprocessing cells below consume.
state = env.reset()
screens = []

# 190000 is only an upper bound on the number of steps; the loop normally
# exits as soon as the environment reports the game is over.
for t in range(190000):
    random_action = env.action_space.sample()
    step_result = env.step(random_action)
    next_state, reward, done, info = step_result
    screens.append(next_state)
    if done:
        break

The following cells are for creating the animation!

In [None]:
# Build an animation of the recorded random game from the frames in `screens`.
fig, ax = plt.subplots()
ims = []

for i, screen in enumerate(screens):
    # Only every 5th frame is drawn, which keeps the animation light.
    if i % 5 == 0:
        im = ax.imshow(screen, animated=True)
        if i == 0:
            # Draw the first frame once more as a regular (non-animated) image.
            ax.imshow(screen)
        ims.append([im])

        # Stop once the frame index has gone past 5000.
        if i > 5000:
            break

ani = animation.ArtistAnimation(fig, ims, interval=20, blit=False)

In [None]:
HTML(ani.to_jshtml())

### Preprocessing

We want to cut away all the unneeded information from the screen and turn the image into a binary image showing only the paddles and the ball.

In [None]:
from model import PreProcess, PolicyNetwork

In [None]:
preprocess = PreProcess()

In [None]:
# Show one recorded frame next to its preprocessed version.
fig, ax = plt.subplots(1, 2, figsize=(6, 3), dpi=120)
raw_ax, processed_ax = ax

# Index of the frame to inspect; `idx` is reused by the cells below.
idx = 1892

raw_ax.imshow(screens[idx])
processed_ax.imshow(preprocess(screens[idx]))

plt.show()

In [None]:
preprocess(screens[idx]).shape

In [None]:
preprocess(screens[idx]).view(-1).unsqueeze(0).shape

## Neural Network

The network takes 2 states (the current and the previous one, to give some sense of motion) and outputs the logits, the unnormalised scores that are converted into the probabilities of picking each of the two actions (left and right).

We give 2 states to give a sense of motion to the neural network (where the ball is going). From the output we generate a categorical distribution, a discrete probability distribution that describes the possible results of a random variable that can take on one of K possible categories, with the probability of each category separately specified.

In [None]:
net = PolicyNetwork()

In [None]:
# Build the two network inputs from consecutive frames: each preprocessed
# frame is flattened and given a leading dimension of size 1.
previous_state = preprocess(screens[idx - 1]).view(-1).unsqueeze(0)
state = preprocess(screens[idx]).view(-1).unsqueeze(0)

# Logits produced by the (still untrained) policy network
net(state, previous_state)

In [None]:
# Converting the logits to probabilities
torch.softmax(net(state,previous_state),dim=1)

In [None]:
# Picks an action and report its probability
net.sample_action(state,previous_state)

## DataLoader and PolicyLoss

At every step of the loop I save the current state, the previous state, the action, its probability and the reward I get. I let the agent play several games without training the network, while saving all the actions and rewards.

After the end of the game I give a delayed reward (a reward that knows the future). The idea is to give a positive reward (that decays in time) to all the actions that allowed us to score a point. Similarly, I give a negative reward to all the actions that led to us losing a point!

The loss function takes as input the state, the action, the probability of the action, the reward and the delayed reward. We look at the action and the delayed reward: if the delayed reward is positive I want to increase the action probability (and vice versa).
Defining $a$ as the action and $dr$ as the delayed reward, the loss is

\begin{equation}
- \frac{P_{new} \left( a \right)}{P_{old} \left( a \right)} \cdot dr  
\end{equation}

where $P_{new}$ is the probability that the NN assigns to the action now and $P_{old}$ is the probability it had when the games were played. We divide by the old probability in order not to allow a drastic change!

In [None]:
from dataloader import GamesMemoryBank
from policy_loss import PolicyLoss

from torch.utils.data import Dataset, DataLoader, RandomSampler

### Training Loop

In [None]:
# Fresh objects for training. This rebinds `net`, replacing the network
# instance created in the demo cells above.
net = PolicyNetwork()
loss_func = PolicyLoss()
# Stores the events recorded while the agent plays; the training loop samples batches from it.
memory_bank = GamesMemoryBank()
# Adam over all parameters of the policy network.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm
from IPython.display import clear_output

In [None]:
if not os.path.exists('trained_model.pt'):
    n_epochs = 100
    games_per_epoch = 10
    steps_per_game = 190000
    batch_size = 24000
    num_batches = 5

    points_scored_per_game = []

    for epoch in range(n_epochs):

        # We clear the memory bank
        memory_bank.clear_memory()

        # First part is about letting the agent play and store all the action it takes
        net.eval()
        net.cpu()

        for game_i in tqdm( range(games_per_epoch) ):

            state, previous_state = env.reset(), None

            state = preprocess(state).view(-1).unsqueeze(0)
            previous_state = preprocess(previous_state).view(-1).unsqueeze(0)

            points_in_game = 0

            for t in range(steps_per_game):
                with torch.no_grad():

                    action, action_prob = net.sample_action(state,previous_state)

                new_state, reward, done, info = env.step(action+2) # +2 is because in the set of actions left and right are idx 2 and 3

                memory_bank.add_event(...)

                previous_state = state
                state = ...

                if reward > 0:
                    points_in_game+=1

                if done:
                    points_scored_per_game.append(points_in_game)
                    break

        # We compute the rewards based on the history of actions
        memory_bank.compute_reward_history()

        clear_output(wait=True)
        plt.title('epoch '+ str(epoch) + ', mean points per last 10 games ' + str(np.mean(points_scored_per_game[-10:])))
        plt.plot(points_scored_per_game)
        plt.xlim(0,1000)
        plt.ylim(-1,21)
        plt.xlabel('n_epochs')
        plt.ylabel('points scored')
        plt.show()

        # Training phase
        net.train()

        for batch_i in range(num_batches):

            optimizer...

            state, previous_state, action, action_prob, reward, discounted_reward = memory_bank.get_sample(batch_size)

            # Be careful of the shape
            logits = ...
            loss = ...

            loss...
            optimizer..

        torch.save(net.state_dict(), 'trained_model.pt')

### Another game :)

I want to see the improvements made by the neural network!

In [None]:
net.load_state_dict(torch.load('trained_model.pt',map_location='cpu'))

In [None]:
# Play one full game with the trained policy and record every raw frame in
# `screens` for the animation cell below.
state, previous_state = env.reset(), None

screens = []

state = preprocess(state).view(-1).unsqueeze(0)
# NOTE(review): `previous_state` is None here, so this relies on PreProcess
# accepting None - confirm in model.py.
previous_state = preprocess(previous_state).view(-1).unsqueeze(0)

# Same setup as the data-collection phase of the training loop: evaluation
# mode, on the CPU, and no gradient tracking while playing.
net.eval()
net.cpu()

for t in range(190000):  # upper bound only; the game normally ends via `done`

    with torch.no_grad():
        action, action_prob = net.sample_action(state,previous_state)

    # +2 because in the set of actions left and right are idx 2 and 3.
    # The environment is stepped exactly once per iteration, with the action
    # chosen by the policy, and the frame the policy will see next is the one
    # that gets recorded.
    new_state, reward, done, info = env.step(action+2)
    screens.append(new_state)

    previous_state = state
    state = preprocess(new_state).view(-1).unsqueeze(0)

    if done:
        break

In [None]:
# Build an animation of the game played by the trained policy from the frames
# in `screens`.
fig, ax = plt.subplots()
ims = []

for i, screen in enumerate(screens):
    # Only every 5th frame is drawn, which keeps the animation light.
    if i % 5 == 0:
        im = ax.imshow(screen, animated=True)
        if i == 0:
            # Draw the first frame once more as a regular (non-animated) image.
            ax.imshow(screen)
        ims.append([im])

        # Stop once the frame index has gone past 5000.
        if i > 5000:
            break

ani_results = animation.ArtistAnimation(fig, ims, interval=20, blit=False)

In [None]:
HTML(ani_results.to_jshtml())

The performance clearly increased, but we can do better!