# Welcome!
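Below, we will learn to implement and train a policy for the LunarLander-v2 environment, using the environment's observations (two consecutive observations stacked together) as input. We will use a fully connected neural network, multiprocessing (many environments stepped in parallel), and PyTorch to implement and train our policy with PPO. (This notebook is adapted from a version that learned to play Atari Pong from pixels with convolutional nets.) Let's get started!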

In [None]:
# custom utilities for displaying animation, collecting rollouts and more
import pong_utils

import numpy as np
from IPython import display

import matplotlib.pyplot as plt

%matplotlib inline

# pong_utils decides which torch device to use; the rest of the notebook reuses it.
# Advice: keep the GPU disabled until the whole notebook runs without errors.
device = pong_utils.device
print("using device: ", device, sep=" ")

In [None]:
# render ai gym environment
import gym
import time

# A single LunarLander-v2 environment, used below to visualise the policy with play().
# (Training uses separate parallel copies of the same environment, created later.)
env = gym.make('LunarLander-v2')

print("State space: ", env.observation_space)
print("Action space: ", env.action_space.n)


# Preprocessing
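In the pixel-based Pong version of this notebook, training was sped up by cropping the images and using every other pixel. LunarLander-v2 observations are small numeric vectors rather than images, so no preprocessing step is applied here.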



# Policy

## Exercise 1: Implement your policy
 
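Here, we define our policy. The input is the concatenation of two consecutive observations (which captures the movement), and the output is a probability distribution $\pi_\theta(a|s)$ over the 4 discrete actions $a\in\{0,1,2,3\}$, with $\sum_a \pi_\theta(a|s)=1$. (In the two-action Pong version, the output was a single number $P_{\rm right}$, the probability of moving right, with $P_{\rm left}= 1-P_{\rm right}$.)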

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class Policy(nn.Module):
    """Fully connected policy network mapping stacked observations to action probabilities.

    The input is a feature vector of length ``state_size`` (16 by default; in this
    notebook two consecutive environment observations are concatenated so the network
    can see motion). The output is a probability distribution over ``action_size``
    discrete actions (4 by default), normalised over the LAST axis, so inputs may carry
    any number of leading batch/time dimensions.

    Args:
        state_size: length of the input feature vector.
        action_size: number of discrete actions.
    """

    def __init__(self, state_size=16, action_size=4):
        super(Policy, self).__init__()

        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, action_size)

    def _device(self):
        """Return the device holding the network's parameters."""
        return next(self.parameters()).device

    def forward(self, x):
        """Return action probabilities for ``x`` of shape (..., state_size).

        The result has shape (..., action_size) and sums to 1 over the last axis.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # The last layer yields unconstrained logits; a ReLU here would clamp every
        # negative preference to 0 and can collapse the policy to uniform.
        logits = self.fc3(x)
        # Normalise over the action axis. dim=0 would normalise across the batch
        # (or time) axis instead, mixing probabilities of unrelated states.
        return torch.softmax(logits, dim=-1)

    def act(self, state):
        """Sample one action per state.

        Args:
            state: array-like of shape (batch, state_size).

        Returns:
            (actions, probs): ``actions`` is a NumPy array of shape (batch,) holding the
            sampled action indices; ``probs`` is a CPU tensor of shape (batch, 1) with
            the probability the current policy assigned to each sampled action (still
            attached to the autograd graph; callers detach it).
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self._device())
        probs = self.forward(state)
        try:
            # forward() already returns probabilities, so they go in as ``probs``;
            # passing them as ``logits`` would apply a second softmax.
            m = Categorical(probs=probs)
        except ValueError:
            # Categorical may reject invalid input (such as NaN probabilities).
            print("Failed with probabilities", probs)
            raise
        action = m.sample()
        sel_probs = probs.gather(-1, action.unsqueeze(-1))
        # move to CPU so callers can convert to NumPy regardless of the device
        return action.cpu().numpy(), sel_probs.cpu()

    def act_with_existing_actions(self, states, actions):
        """Return the current policy's probabilities for already-chosen actions.

        Args:
            states: array-like of shape (..., state_size), e.g. (T, n, state_size).
            actions: array-like or tensor of integer action indices whose shape matches
                the leading shape of ``states``, e.g. (T, n).

        Returns:
            (actions, sel_probs): ``actions`` is returned unchanged; ``sel_probs`` has
            shape (..., 1), lives on the network's device and is differentiable with
            respect to the network parameters.
        """
        device = self._device()
        probs = self.forward(torch.as_tensor(states, dtype=torch.float32, device=device))
        index = torch.as_tensor(actions, dtype=torch.long, device=device).unsqueeze(-1)
        try:
            sel_probs = probs.gather(-1, index)
        except RuntimeError:
            # e.g. an out-of-range action index or mismatched shapes
            print("Failed with probabilities", probs, "actions", index)
            raise
        return actions, sel_probs


import torch.optim as optim

# Build the policy network on the chosen device, then attach an Adam optimizer
# (learning rate 1e-4) to all of its parameters.
policy = Policy().to(device)
optimizer = optim.Adam(params=policy.parameters(), lr=1e-4)

# Game visualization
pong_utils contains a play function that takes an environment and a policy; an optional preprocess function can be supplied. Here we define our own play function, which runs one rendered game with the current policy so we can see how learning is progressing.

In [None]:
def play(env, policy, time=2000, preprocess=None, nrand=5):
    """Run one rendered episode of ``env`` driven by ``policy``.

    Uses the same stepping scheme as ``collect_trajectories``. It first sends a start
    action, then makes ``nrand`` random warm-up moves. After that, each iteration
    feeds the two most recent observations (concatenated) to ``policy.act``, applies
    the chosen action, and then applies one no-op (action 0) skip step.

    Args:
        env: a gym environment using the 4-tuple ``step`` API.
        policy: object with ``act(states) -> (actions, probs)``.
        time: maximum number of policy decisions (each advances the env up to 2 steps).
        preprocess: unused; kept so the signature matches the pong_utils version.
        nrand: number of random warm-up moves before the policy takes over.

    Returns:
        (rewards, actions): every per-step reward received after the start action,
        and the action arrays returned by ``policy.act``.

    The environment is closed on exit, even if an exception is raised.
    """
    rewards = []
    actions = []
    try:
        env.reset()

        # start the game; the reward of this step is not recorded
        frame, _, is_done, _ = env.step(1)
        # both frames must exist even when nrand == 0
        frame1 = frame2 = frame

        # nrand random moves, each followed by a no-op skip step
        for _ in range(nrand):
            if is_done:
                break
            frame1, r1, is_done, _ = env.step(np.random.choice(4))
            rewards.append(r1)
            if is_done:
                break
            frame2, r2, is_done, _ = env.step(0)
            env.render()
            rewards.append(r2)

        for _ in range(time):
            if is_done:
                break
            frame_input = [np.concatenate([frame1, frame2])]
            action, _ = policy.act(frame_input)
            actions.append(action)

            frame1, r1, is_done, _ = env.step(action[0])
            rewards.append(r1)
            if is_done:
                # never step an episode that has already finished
                break
            # skip step with the no-op action, matching collect_trajectories
            frame2, r2, is_done, _ = env.step(0)
            env.render()
            rewards.append(r2)
    finally:
        env.close()

    return rewards, actions

In [None]:
play(env=env, policy=policy, time=200)  # watch the (untrained) policy for up to 200 decisions

# Function Definitions
Here you will define key functions for training. 

## Exercise 2: write your own function for training
(what I call the scalar function is the same as policy_loss up to a negative sign)

### PPO
Here we implement the PPO algorithm; its scalar function is given by
$\frac{1}{T}\sum^T_t \min\left\{R_{t}^{\rm future}\frac{\pi_{\theta'}(a_t|s_t)}{\pi_{\theta}(a_t|s_t)},R_{t}^{\rm future}{\rm clip}_{\epsilon}\!\left(\frac{\pi_{\theta'}(a_t|s_t)}{\pi_{\theta}(a_t|s_t)}\right)\right\}$

the ${\rm clip}_\epsilon$ function is implemented in pytorch as ```torch.clamp(ratio, 1-epsilon, 1+epsilon)```

In [None]:
def collect_trajectories(envs, policy, tmax=200, nrand=5):
    """Roll out ``policy`` in parallel environments and record the trajectories.

    All environments are stepped in lockstep. Each recorded time step consists of one
    policy action followed by one no-op (action 0) skip step. The next policy input is
    the concatenation of the two resulting observations. Collection stops after
    ``tmax`` recorded steps, or as soon as any environment reports ``done``, so all
    returned lists have the same length for every environment.

    Args:
        envs: parallel environment wrapper exposing ``ps`` (one entry per worker),
            ``reset()`` and a batched 4-tuple ``step(actions)``.
        policy: object with ``act(states) -> (actions, probs)`` where ``probs`` is a
            torch tensor with the sampled actions' probabilities.
        tmax: maximum number of recorded time steps.
        nrand: number of random warm-up moves before recording starts.

    Returns:
        (prob_list, state_list, action_list, reward_list), one entry per recorded step:
        the sampled-action probabilities (NumPy, used later as pi_old), the policy
        inputs, the sampled actions, and the reward collected during that step.
    """
    # number of parallel instances
    n = len(envs.ps)

    state_list = []
    reward_list = []
    prob_list = []
    action_list = []

    envs.reset()

    # start all parallel agents; this observation seeds both frames so that they
    # are defined even when nrand == 0
    fr1, _, _, _ = envs.step([1] * n)
    fr2 = fr1

    # perform nrand random moves, each followed by a no-op
    for _ in range(nrand):
        fr1, _, _, _ = envs.step(np.random.choice(4, n))
        fr2, _, _, _ = envs.step([0] * n)

    for _ in range(tmax):
        # stack the two most recent observations along the feature axis
        # (assumes each frame batch is (n, obs_dim) -> input is (n, 2 * obs_dim))
        batch_input = np.concatenate([fr1, fr2], axis=1)

        # the probabilities are only used as pi_old: no gradient is needed
        actions, action_probs = policy.act(batch_input)

        # take the chosen action
        fr1, re1, done1, _ = envs.step(actions)
        if np.any(done1):
            # an episode ended on the action step: record it and stop without
            # stepping the finished environment again (its done flag would
            # otherwise be overwritten by the skip step)
            reward = re1
            finished = True
        else:
            # skip the game forward with a no-op (0)
            fr2, re2, done2, _ = envs.step([0] * n)
            reward = re1 + re2
            finished = np.any(done2)

        # store the result
        state_list.append(batch_input)
        reward_list.append(reward)
        prob_list.append(action_probs.detach().cpu().numpy())
        action_list.append(actions)

        # stop if any trajectory is done so that all lists stay rectangular
        if finished:
            break

    # return pi_theta, states, actions, rewards
    return prob_list, state_list, action_list, reward_list



In [None]:
def states_to_prob(policy, states, actions):
    """Return the probabilities the current policy assigns to already-taken actions.

    Delegates to ``policy.act_with_existing_actions`` and keeps element 1 of its
    result; element 0 holds the actions that method hands back.
    """
    outputs = policy.act_with_existing_actions(states, actions)
    return outputs[1]

# clipped surrogate function
# similar as -policy_loss for REINFORCE, but for PPO
def clipped_surrogate(policy, old_probs, states, actions, rewards,
                      discount=0.995,
                      epsilon=0.1, beta=0.01):

    discount = discount**np.arange(len(rewards))
    rewards = np.asarray(rewards)*discount[:,np.newaxis]
    
    # convert rewards to future rewards
    rewards_future = rewards[::-1].cumsum(axis=0)[::-1]
    
    mean = np.mean(rewards_future, axis=1)
    std = np.std(rewards_future, axis=1) + 1.0e-10

    rewards_normalized = (rewards_future - mean[:,np.newaxis])/std[:,np.newaxis]
    
    # convert everything into pytorch tensors and move to gpu if available
    actions = torch.tensor(actions, dtype=torch.int8, device=device)
    old_probs = torch.tensor(old_probs, dtype=torch.float, device=device).squeeze(2)
    rewards = torch.tensor(rewards_normalized, dtype=torch.float, device=device)
    # convert states to policy (or probability)
    new_probs = states_to_prob(policy, states, actions).squeeze(2)
    
    # ratio for clipping
    ratio = new_probs/old_probs
    
#    print("Ratio", ratio)

    # clipped function
    clip = torch.clamp(ratio, 1-epsilon, 1+epsilon)
    clipped_surrogate = torch.min(ratio*rewards, clip*rewards)
    
#    print("R*R", ratio*rewards)
#    print("Clip", clip)
#    print("Clipped surrogate", clipped_surrogate)
#    print("Clip mean", torch.mean(clipped_surrogate))

    # include a regularization term
    # this steers new_policy towards 0.5
    # add in 1.e-10 to avoid log(0) which gives nan
    entropy = -(new_probs*torch.log(old_probs+1.e-10)+(1.0-new_probs)*torch.log(1.0-old_probs+1.e-10))
    

    
    # this returns an average of all the entries of the tensor
    # effective computing L_sur^clip / T
    # averaged over time-step and number of trajectories
    # this is desirable because we have normalized our rewards
    return torch.mean(clipped_surrogate + beta*entropy)


# Training
We are now ready to train our policy!
WARNING: make sure to turn on GPU, which also enables multicore processing. It may take up to 45 minutes even with GPU enabled, otherwise it will take much longer!

In [None]:
from parallelEnv import parallelEnv
import numpy as np
# Train the policy with PPO on 32 parallel LunarLander-v2 environments.
# WARNING: this takes a long time (see the note above the cell).

# training loop max iterations
episode = 1000


envs = parallelEnv('LunarLander-v2', n=32, seed=1234)

discount_rate = .99
epsilon = 0.3
beta = .01
tmax = 1600
SGD_epoch = 4

# keep track of progress
mean_rewards = []

for e in range(episode):

    # collect trajectories with the current policy (it acts as pi_old below)
    old_probs, states, actions, rewards = collect_trajectories(envs, policy, tmax=tmax)
    losses = []

    # total reward of each parallel trajectory
    total_rewards = np.sum(rewards, axis=0)

    # several gradient-ascent steps on the same batch of trajectories
    for _ in range(SGD_epoch):

        # discount_rate must be passed explicitly, otherwise the function's default is used
        L = -clipped_surrogate(policy, old_probs, states, actions, rewards,
                               discount=discount_rate,
                               epsilon=epsilon, beta=beta)
        # .item() works for CPU and GPU tensors alike (.numpy() fails on GPU)
        losses.append(L.item())
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        del L

    # the clipping parameter and the entropy weight both decay over time
    epsilon *= .999
    beta *= .995

    mean_rewards.append(np.mean(total_rewards))
    print("Episode: {0:d}, score: {1:f}, losses: {2}".format(e + 1, np.mean(total_rewards), losses))

        


In [None]:
plt.plot(range(len(mean_rewards)), mean_rewards)  # mean score per training iteration

In [None]:
play(env=env, policy=policy, time=200)  # watch the trained policy for up to 200 decisions

In [None]:
# Persist the trained policy (the whole nn.Module object) to disk.
torch.save(obj=policy, f='PPO.policy')

# To restore it later:
# policy = torch.load('PPO.policy')

# To try a reference solution instead (per the original notes, loading may fail
# unless the GPU is enabled):
#
# policy_solution = torch.load('PPO_solution.policy')
# pong_utils.play(env, policy_solution, time=2000)