# Welcome!
Below, we will learn to implement and train a policy to play atari-pong, using only the pixels as input. We will use convolutional neural nets, multiprocessing, and pytorch to implement and train our policy. Let's get started!

In [1]:
# install the package used for displaying animations inside the notebook
!pip install JSAnimation

# custom utilities for displaying animation, collecting rollouts and more
import pong_utils

%matplotlib inline

# check which device is being used (the value comes from pong_utils).
# I recommend disabling gpu until you've made sure that the code runs
device = pong_utils.device
print("using device: ",device)

Collecting JSAnimation
  Downloading https://files.pythonhosted.org/packages/3c/e6/a93a578400c38a43af8b4271334ed2444b42d65580f1d6721c9fe32e9fd8/JSAnimation-0.1.tar.gz
Building wheels for collected packages: JSAnimation
  Running setup.py bdist_wheel for JSAnimation ... [?25ldone
[?25h  Stored in directory: /root/.cache/pip/wheels/3c/c2/b2/b444dffc3eed9c78139288d301c4009a42c0dd061d3b62cead
Successfully built JSAnimation
Installing collected packages: JSAnimation
Successfully installed JSAnimation-0.1
using device:  cpu


In [3]:
# create the gym environment used in the rest of the notebook
import gym
import time

# NOTE: this notebook was adapted from a Pong exercise, but the environment
# created here is the continuous-control task Pendulum-v0
env = gym.make('Pendulum-v0')

env.seed(2)
print('observation space:', env.observation_space)
print('action space:', env.action_space)


# the printed spaces show a 3-dimensional observation and a single
# continuous action; Policy.act below clamps sampled actions to [-2, 2]
# (no discrete RIGHTFIRE/LEFTFIRE actions are used with this environment)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
observation space: Box(3,)
action space: Box(1,)


# Preprocessing
To speed up training, we can simplify the input by cropping the images and using only every other pixel.



# Policy

## Exercise 1: Implement your policy
 
Here, we define our policy. The input is the stack of two different frames (which captures the movement), and the output is a number $P_{\rm right}$, the probability of moving right. Note that $P_{\rm left}= 1-P_{\rm right}$.

In [201]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


# Gaussian policy for a continuous action: a small fully connected network
# maps the state vector to the mean and standard deviation of a Normal
# distribution from which the action is sampled
class Policy(nn.Module):
    """Fully connected Gaussian policy.

    Args:
        s_size: number of state features fed to the first linear layer.
        h_size: width of the two hidden layers.
        a_size: number of action dimensions (one ``mu``/``std`` pair each).
    """

    def __init__(self, s_size=3, h_size=300, a_size=1):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(s_size, h_size)
        self.fc2 = nn.Linear(h_size, h_size)

        # two separate heads on top of the shared hidden representation
        self.mu = nn.Linear(h_size, a_size)
        self.std = nn.Linear(h_size, a_size)

    def forward(self, x):
        """Return ``(mu, std)`` of the action distribution for states ``x``.

        ``mu`` is squashed by tanh into [-1, 1]; ``std`` is made positive
        with softplus.
        """
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))

        # torch.tanh replaces the deprecated F.tanh
        mu = torch.tanh(self.mu(x))
        std = F.softplus(self.std(x))
        return mu, std

    def act(self, state, batch_mode=False):
        """Sample an action for ``state``.

        Args:
            state: with ``batch_mode=False`` a single numpy state vector; with
                ``batch_mode=True`` a tensor of states that is passed to
                ``forward`` unchanged.
            batch_mode: whether ``state`` is already a batched tensor.

        Returns:
            ``(action, prob, entropy)``. ``action`` is a Python float when
            ``batch_mode`` is False and a tensor otherwise. ``prob`` and
            ``entropy`` are tensors with the same shape as ``mu``.
        """
        if not batch_mode:
            # place the single observation on the same device as the weights
            param_device = next(self.parameters()).device
            state = torch.from_numpy(state).float().unsqueeze(0).to(param_device)
        mu, std = self.forward(state)

        m = Normal(mu, std)
        action = m.sample().clamp(-2,2) # sampled action is limited to [-2, 2]

        # Evaluate the distribution while `action` is still a tensor: once it
        # has been converted with .item() it is a plain float, which the
        # distribution's argument validation does not accept.
        # NOTE(review): cdf gives the cumulative probability P(X <= action),
        # not the likelihood of the action; PPO ratios are normally built from
        # exp(m.log_prob(action)). Kept because clipped_surrogate_continuous
        # consumes this same quantity; change both places together.
        prob = m.cdf(action)
        entropy = m.entropy()

        if not batch_mode:
            action = action.item()

        return action, prob, entropy
# instantiate the policy defined above and move it to the selected device
policy=Policy().to(device)
#policy=pong_utils.Policy().to(device)

# we use the adam optimizer with learning rate 1e-3
# optim.SGD is also possible
import torch.optim as optim
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

In [202]:
policy

Policy(
  (fc1): Linear(in_features=3, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=300, bias=True)
  (mu): Linear(in_features=300, out_features=1, bias=True)
  (std): Linear(in_features=300, out_features=1, bias=True)
)

# Game visualization
pong_utils contains a play function that takes the environment and a policy. An optional preprocess function can be supplied. Here we define a function that plays a game and shows learning progress.

# Function Definitions
Here you will define key functions for training. 

## Exercise 2: write your own function for training
(what I call scalar function is the same as policy_loss up to a negative sign)

### PPO
Later on, you'll implement the PPO algorithm as well, and the scalar function is given by
$\frac{1}{T}\sum^T_t \min\left\{R_{t}^{\rm future}\frac{\pi_{\theta'}(a_t|s_t)}{\pi_{\theta}(a_t|s_t)},R_{t}^{\rm future}{\rm clip}_{\epsilon}\!\left(\frac{\pi_{\theta'}(a_t|s_t)}{\pi_{\theta}(a_t|s_t)}\right)\right\}$

the ${\rm clip}_\epsilon$ function is implemented in pytorch as ```torch.clamp(ratio, 1-epsilon, 1+epsilon)```

In [34]:
# NOTE: `envs` is only created in the training cell further down
# (parallelEnv('Pendulum-v0', n=8, seed=1234)); run that cell first
states =  envs.reset()

In [35]:
# inspect the observations returned by envs.reset()
states

array([[-0.07472608, -0.9972041 ,  0.4784946 ],
       [ 0.97676342, -0.21432038,  0.46877921],
       [ 0.06720443,  0.99773923,  0.09108689],
       [-0.99731068, -0.07328989,  0.77271367],
       [-0.32019129,  0.94735291,  0.33915817],
       [-0.88013181, -0.47472939,  0.44229691],
       [-0.94415689,  0.32949623,  0.03627558],
       [-0.76469761, -0.6443893 ,  0.49232976]])

In [111]:
# NOTE: `test` is assigned in a later cell
# (torch.from_numpy(np.arange(10).reshape((5,2)))); run that cell first
test

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9]])

In [None]:
# collect trajectories for a parallelized parallelEnv object
def collect_trajectories(envs, policy, tmax=200, nrand=5):
    """Roll the policy out in all parallel environments at once.

    Args:
        envs: parallel environment object providing ``reset()`` and
            ``step(actions)``; both are expected to return numpy arrays with
            one row per environment.
        policy: module providing ``act(states, batch_mode=True)`` that returns
            ``(actions, probs, entropy)`` tensors.
        tmax: maximum number of time steps to collect.
        nrand: unused; kept so existing callers keep working.

    Returns:
        ``(prob_list, state_list, action_list, reward_list)``: four lists with
        one entry per collected time step. Entry ``t`` of ``state_list`` is
        the observation the policy saw when it chose entry ``t`` of
        ``action_list``.
    """
    # run the policy on whatever device its weights live on
    policy_device = next(policy.parameters()).device

    #initialize returning lists and start the game!
    state_list=[]
    reward_list=[]
    prob_list=[]
    action_list=[]

    states = torch.from_numpy(envs.reset()).float().to(policy_device)

    for _ in range(tmax):

        # rollouts are data collection only, so no autograd graph is needed
        with torch.no_grad():
            actions, probs, _entropy = policy.act(states, batch_mode=True)

        # hand plain numpy arrays (not tensors) to the environments
        next_states, rewards, dones, _  = envs.step(actions.cpu().numpy())

        # store the state the action was taken in (NOT the state that results
        # from it), so that states, actions and probabilities stay aligned
        state_list.append(states)
        reward_list.append(rewards)
        # drop only the trailing action dimension so that a single parallel
        # environment still yields arrays of shape (1,)
        # assumes one action dimension per environment -- TODO confirm
        prob_list.append(probs.squeeze(-1).cpu().numpy())
        action_list.append(actions.squeeze(-1).cpu().numpy())

        # stop if any of the trajectories is done
        # we want all the lists to be rectangular
        if dones.any():
            break

        states = torch.from_numpy(next_states).float().to(policy_device)

    # return pi_theta, states, actions, rewards
    return prob_list, state_list, \
        action_list, reward_list

In [None]:
# one rollout of at most 300 steps, used to inspect the returned lists
old_probs, states, actions, rewards = collect_trajectories(envs, policy, tmax=300)

In [None]:
# collect_trajectories returns a plain Python list (one array per time step),
# which has no .shape attribute; convert to an array before inspecting it
np.asarray(old_probs).shape

In [None]:
# view the list of per-step rewards as a single array
np.asarray(rewards)

In [None]:
# toy data for checking the discount arithmetic: 5 rows ("time steps") of 2 values
test = torch.from_numpy(np.arange(10).reshape(5, 2))
# one discount factor per row: 0.99**0, 0.99**1, ...
discount = np.power(0.99, np.arange(test.shape[0]))

In [204]:
# display the toy tensor next to its per-row discount factors
test, discount

(tensor([[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7],
         [ 8,  9]]),
 array([ 1.        ,  0.99      ,  0.9801    ,  0.970299  ,  0.96059601]))

In [None]:
def clipped_surrogate(policy, old_probs, states, actions, rewards,
                      discount = 0.995, epsilon=0.1, beta=0.01):
    """Exercise template for the PPO surrogate with discrete pong actions.

    As written, the function only returns the mean of ``beta`` times a
    regularization term. ``rewards``, ``discount`` and ``epsilon`` are not
    used and no clipped probability ratio is computed.

    NOTE(review): ``old_probs`` is passed to ``torch.log`` without being
    converted to a tensor -- presumably the caller must supply a tensor;
    verify before use.
    NOTE(review): relies on ``pong_utils.states_to_prob`` and
    ``pong_utils.RIGHT``, which are not visible here; it looks tied to the
    pong policy rather than the Gaussian ``Policy`` of this notebook --
    confirm.
    """

    ########
    ## 
    ## WRITE YOUR OWN CODE HERE
    ##
    ########
    
    actions = torch.tensor(actions, dtype=torch.int8, device=device)

    # convert states to policy (or probability)
    new_probs = pong_utils.states_to_prob(policy, states)
    # keep new_probs where the action equals pong_utils.RIGHT, else use 1 - new_probs
    new_probs = torch.where(actions == pong_utils.RIGHT, new_probs, 1.0-new_probs)

    # include a regularization term
    # this steers new_policy towards 0.5
    # prevents policy to become exactly 0 or 1 helps exploration
    # add in 1.e-10 to avoid log(0) which gives nan
    entropy = -(new_probs*torch.log(old_probs+1.e-10)+ \
        (1.0-new_probs)*torch.log(1.0-old_probs+1.e-10))

    return torch.mean(beta*entropy)


In [205]:
def clipped_surrogate_continuous(policy, old_probs, states, actions, rewards,
                      discount=0.995,
                      epsilon=0.1, beta=0.01):
    """PPO clipped surrogate for the Gaussian (continuous-action) policy.

    Args:
        policy: module whose forward pass maps a batch of states to
            ``(mu, std)`` of a Normal action distribution.
        old_probs: per-step probabilities recorded during the rollout, as
            returned by ``collect_trajectories``.
        states: list with one state tensor per time step.
        actions: list with one array of the actions actually taken per step.
        rewards: list with one array of rewards per time step.
        discount: per-step discount factor.
        epsilon: clipping range; the ratio is clamped to [1-eps, 1+eps].
        beta: weight of the entropy bonus.

    Returns:
        Scalar tensor: mean over time steps and environments of the clipped
        surrogate plus ``beta`` times the policy entropy. The caller
        maximises it (the training loop minimises its negative).
    """
    target_device = next(policy.parameters()).device

    # discount every reward by its time index, then accumulate from the end
    # to obtain the future return at each step
    discounts = discount**np.arange(len(rewards))
    discounted = np.asarray(rewards)*discounts[:,np.newaxis]
    rewards_future = discounted[::-1].cumsum(axis=0)[::-1]

    # normalise across the parallel environments at every time step
    mean = np.mean(rewards_future, axis=1)
    spread = np.std(rewards_future, axis=1) + 1.0e-10
    rewards_normalized = (rewards_future - mean[:,np.newaxis])/spread[:,np.newaxis]
    rewards = torch.tensor(rewards_normalized, dtype=torch.float, device=target_device)

    old_probs = torch.tensor(np.asarray(old_probs), dtype=torch.float,
                             device=target_device)

    # action distribution of the CURRENT policy at the recorded states
    mu, sigma = policy(torch.stack(states).to(target_device))
    dist = Normal(mu, sigma)

    # Score the actions that were actually taken. Do not call policy.act
    # here: it draws fresh samples, so the ratio would compare different
    # actions. Actions are continuous, so keep them as floats (an integer
    # cast would truncate them).
    actions = torch.tensor(np.asarray(actions), dtype=torch.float,
                           device=target_device).reshape(mu.shape)

    # same quantity that Policy.act records in old_probs
    # NOTE(review): this is the CDF, not a likelihood; switching to
    # exp(log_prob) has to be done here and in Policy.act together.
    new_probs = dist.cdf(actions).reshape(old_probs.shape)

    # ratio for clipping (1.e-10 guards against division by zero)
    ratio = new_probs/(old_probs + 1.e-10)

    # clipped function
    clip = torch.clamp(ratio, 1-epsilon, 1+epsilon)
    clipped_surrogate = torch.min(ratio*rewards, clip*rewards)

    # regularization term: the entropy of the Gaussian policy itself, which
    # rewards wider action distributions and so encourages exploration
    entropy = dist.entropy().reshape(old_probs.shape)

    # this returns an average of all the entries of the tensor
    # effective computing L_sur^clip / T
    # averaged over time-step and number of trajectories
    # this is desirable because we have normalized our rewards
    return torch.mean(clipped_surrogate + beta*entropy)


# Training
We are now ready to train our policy!
WARNING: make sure to turn on GPU, which also enables multicore processing. It may take up to 45 minutes even with GPU enabled, otherwise it will take much longer!

In [None]:
from parallelEnv import parallelEnv

import numpy as np
# keep track of how long training takes
# WARNING: running through all 800 episodes will take 30-45 minutes

optimizer = optim.Adam(policy.parameters(), lr=1e-3)
# training loop max iterations
episode = 2000

# widget bar to display progress
!pip install progressbar
import progressbar as pb
widget = ['training loop: ', pb.Percentage(), ' ', 
          pb.Bar(), ' ', pb.ETA() ]
timer = pb.ProgressBar(widgets=widget, maxval=episode).start()


envs = parallelEnv('Pendulum-v0', n=8, seed=1234)

discount_rate = .99
epsilon = 0.1
beta = .01
tmax = 320
SGD_epoch = 4

# keep track of progress
mean_rewards = []

for e in range(episode):

    # collect trajectories
    old_probs, states, actions, rewards = \
        collect_trajectories(envs, policy, tmax=tmax)
        
    total_rewards = np.sum(rewards, axis=0)


    # gradient ascent step
    for _ in range(SGD_epoch):
        
        # uncomment to utilize your own clipped function!
        # L = -clipped_surrogate(policy, old_probs, states, actions, rewards, epsilon=epsilon, beta=beta)

        #L = -pong_utils.clipped_surrogate(policy, old_probs, states, actions, rewards,
        #                                  epsilon=epsilon, beta=beta)
    
        L = -clipped_surrogate_continuous(policy, old_probs, states, actions, rewards,
                                          epsilon=epsilon, beta=beta)
        #print(L.grad_fn)
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        del L
    
    # the clipping parameter reduces as time goes on
    epsilon*=.999
    
    # the regulation term also reduces
    # this reduces exploration in later runs
    beta*=.995
    
    # get the average reward of the parallel environments
    mean_rewards.append(np.mean(total_rewards))
    
    # display some progress every 20 iterations
    if (e+1)%20 ==0 :
        print("Episode: {0:d}, score: {1:f}".format(e+1,np.mean(total_rewards)))
        print(total_rewards)
        
    # update progress widget bar
    timer.update(e+1)
    
timer.finish()



training loop:   0% |                                          | ETA:  --:--:--

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box au

training loop:   1% |                                           | ETA:  3:02:27

Episode: 20, score: -1231.045602
[ -971.20541383 -1583.80014121 -1069.67028034 -1581.71451825 -1526.56766599
 -1070.46168981  -982.14652612 -1062.79857795]


training loop:   2% |                                           | ETA:  2:58:24

Episode: 40, score: -1253.604319
[-1289.46725265 -1380.48150182 -1190.1387016  -1584.51859791 -1011.45376654
 -1152.25220797 -1454.5734176   -965.94910449]


training loop:   3% |#                                          | ETA:  2:55:40

Episode: 60, score: -1238.770824
[ -975.46196931 -1552.01102686  -969.43518177 -1176.37503771 -1002.19351655
 -1383.68296222 -1254.22733577 -1596.77956141]


training loop:   4% |#                                          | ETA:  2:55:47

Episode: 80, score: -1155.663544
[-1429.67617695 -1596.63135004  -908.36048009 -1290.11951899 -1087.87324337
  -870.51307551 -1066.43260975  -995.70189952]


training loop:   5% |##                                         | ETA:  2:53:06

Episode: 100, score: -1343.730509
[-1566.89704595 -1448.16034138 -1323.2143045  -1276.67092695  -974.92427005
 -1457.85002138  -966.85083899 -1735.27632449]


training loop:   6% |##                                         | ETA:  2:50:48

Episode: 120, score: -1222.853238
[-1099.64696668 -1078.99302066  -961.41498511  -894.00269888 -1431.70230637
 -1035.45372119 -1669.66682132 -1611.94538483]


training loop:   7% |###                                        | ETA:  2:49:16

Episode: 140, score: -1216.158086
[ -898.89993447 -1576.72593061 -1004.98623957 -1322.70277701 -1568.20529525
  -977.90880732 -1391.87698869  -987.9587179 ]


training loop:   8% |###                                        | ETA:  2:47:57

Episode: 160, score: -1240.900161
[-1390.75980305  -965.67007474 -1586.28481828  -981.69889985 -1068.79255876
  -901.93143763 -1694.77478079 -1337.28891176]


training loop:   9% |###                                        | ETA:  2:46:01

Episode: 180, score: -1237.011798
[-1317.82322976 -1209.47325191 -1207.00236628 -1084.51179589 -1183.28393138
 -1087.40258315 -1203.68289419 -1602.91433436]


training loop:  10% |####                                       | ETA:  2:43:48

Episode: 200, score: -1214.106185
[ -891.17693229 -1318.50079417 -1445.96404945 -1314.87429018 -1090.61275462
  -967.29009887  -983.13386396 -1701.29669944]


training loop:  11% |####                                       | ETA:  2:42:38

Episode: 220, score: -1074.677311
[ -998.1851704   -869.76751679  -916.62243176 -1010.02194079 -1260.64651335
  -939.37434818 -1623.04202943  -979.75853896]


training loop:  12% |#####                                      | ETA:  2:40:28

Episode: 240, score: -1352.902071
[-1110.27058167 -1201.97023907 -1472.99470331  -983.37112622 -1793.23879811
 -1077.60582792 -1575.49372517 -1608.27156928]


training loop:  13% |#####                                      | ETA:  2:38:19

Episode: 260, score: -1314.130502
[-1537.69115672 -1465.99357545 -1058.34811404  -972.38793697 -1371.01088351
 -1220.79265877 -1532.34913728 -1354.47055107]


training loop:  14% |######                                     | ETA:  2:36:39

Episode: 280, score: -1206.178833
[-1363.62506288 -1024.62009543 -1028.20107385 -1058.8803659   -981.19621563
 -1692.20564405 -1506.13003009  -994.57217255]


training loop:  15% |######                                     | ETA:  2:35:00

Episode: 300, score: -1261.655925
[ -980.7562829  -1339.21949424 -1076.45688666 -1270.04465957 -1654.41205552
 -1632.03192569 -1120.03632432 -1020.28977119]


training loop:  16% |######                                     | ETA:  2:33:04

Episode: 320, score: -1356.927625
[ -982.06826655  -988.03785255 -1663.68907246  -932.51155674 -1697.37919635
 -1281.91774789 -1764.32113811 -1545.49616718]


training loop:  17% |#######                                    | ETA:  2:31:06

Episode: 340, score: -1272.159060
[ -989.13912396 -1163.65064742 -1164.18399162 -1543.00875518 -1605.40221375
 -1186.80935774  -992.26994833 -1532.80844414]


training loop:  18% |#######                                    | ETA:  2:29:42

Episode: 360, score: -1145.519821
[ -988.56883401 -1694.03247227 -1168.93060334 -1120.988704    -995.60354662
 -1430.00621911  -891.08387037  -874.94431538]


training loop:  19% |########                                   | ETA:  2:27:41

Episode: 380, score: -1069.067353
[-1067.81811742  -998.37176541  -908.27232024 -1581.88583254  -872.59805633
  -881.9657194  -1280.39825867  -961.22875532]


training loop:  20% |########                                   | ETA:  2:25:46

Episode: 400, score: -1149.124252
[-1188.99172649  -973.59159837 -1588.92032409 -1044.63928656 -1144.92893454
 -1076.64660139 -1377.20898714  -798.06655914]


training loop:  21% |#########                                  | ETA:  2:24:04

Episode: 420, score: -1187.553023
[-1586.58855696 -1075.667503    -859.30239651 -1552.90391107  -976.87434857
 -1322.97315955  -869.11262895 -1257.001677  ]


training loop:  22% |#########                                  | ETA:  2:22:14

Episode: 440, score: -1212.730123
[-1014.06798314  -865.74333246 -1298.50192153 -1702.25274586 -1280.43849212
  -849.29804318 -1187.35969908 -1504.17876695]


training loop:  23% |#########                                  | ETA:  2:20:19

Episode: 460, score: -1106.744213
[-1476.61336968 -1043.25355075 -1086.04994936  -980.78318274  -994.203725
 -1176.01549956  -979.42736769 -1117.60705612]


training loop:  24% |##########                                 | ETA:  2:18:23

Episode: 480, score: -1269.213181
[-1335.70017025  -965.84060689 -1402.22642437 -1477.34354171 -1628.06492078
 -1109.01797882 -1385.32001067  -850.19179705]


training loop:  25% |##########                                 | ETA:  2:16:47

Episode: 500, score: -1245.444074
[-1479.01943787 -1002.76672863 -1564.499391   -1425.76222486 -1631.46647512
  -894.1732522   -983.56383932  -982.30123981]


training loop:  26% |###########                                | ETA:  2:14:53

Episode: 520, score: -1305.836121
[-1067.00366122 -1081.58331513 -1151.85310246 -1479.96541074 -1590.33643761
 -1420.86878898 -1002.44065765 -1652.63759151]


training loop:  27% |###########                                | ETA:  2:12:56

Episode: 540, score: -1082.742095
[ -955.56736851  -967.71390943  -834.94822663 -1086.14652868  -873.25379497
 -1211.8569858  -1193.6019797  -1538.84796791]


training loop:  28% |############                               | ETA:  2:11:10

Episode: 560, score: -1186.271833
[-1593.86834202 -1084.28051473 -1051.25131171  -788.1981303  -1541.55336486
  -961.72970706 -1201.95378292 -1267.33950979]


training loop:  29% |############                               | ETA:  2:09:26

Episode: 580, score: -1284.146110
[-1024.55178169 -1026.77097806 -1516.72291729 -1205.0121975  -1480.43615104
 -1629.01355816 -1315.81617787 -1074.84511668]


training loop:  30% |############                               | ETA:  2:07:31

Episode: 600, score: -1136.731451
[-1180.82985039 -1002.25888876  -885.19790868  -966.07803037 -1197.60168905
  -899.57765432 -1368.47397081 -1593.83361503]


training loop:  31% |#############                              | ETA:  2:05:37

Episode: 620, score: -1152.758317
[-1081.59479908 -1485.3402503  -1197.31110261 -1176.28566338  -875.27270034
 -1095.73496579 -1312.66041923  -997.86663791]


training loop:  32% |#############                              | ETA:  2:03:57

Episode: 640, score: -1285.443108
[-1349.52682538 -1388.22189097 -1500.3597072  -1360.83382211 -1219.1449741
  -931.09510624 -1051.62123469 -1482.74130698]


training loop:  33% |##############                             | ETA:  2:02:04

Episode: 660, score: -1316.742350
[-1537.66093553 -1610.68639936 -1088.44862568 -1296.64018297 -1578.9440308
 -1609.4468893   -952.543213    -859.56852297]


training loop:  34% |##############                             | ETA:  2:00:12

Episode: 680, score: -1327.250777
[-1054.97900696  -971.71153439 -1710.32022819 -1442.63633267 -1015.7318176
 -1461.32908427 -1533.74969202 -1427.54851879]


training loop:  35% |###############                            | ETA:  1:58:25

Episode: 700, score: -1209.919486
[-1562.10572067 -1374.75560369 -1070.96280406 -1230.03322017  -998.6241741
 -1085.40686246 -1485.27532305  -872.19218219]


training loop:  36% |###############                            | ETA:  1:56:38

Episode: 720, score: -1323.041588
[-1599.68745628 -1592.25375223 -1518.41612485  -761.64989352 -1007.92681588
 -1746.70241686 -1493.27984658  -864.41639392]


training loop:  36% |###############                            | ETA:  1:54:54

In [None]:
# visualise the trained policy in the single (non-parallel) environment
# NOTE(review): pong_utils.play was written for the pong exercise; confirm
# that it works with the Pendulum environment and the Gaussian policy
pong_utils.play(env, policy, time=200) 

In [None]:
# save your policy!
# the whole module object is saved here (not just its weights), so the
# Policy class must be defined before the file is loaded again
torch.save(policy, 'PPO.policy')

# load policy if needed
# policy = torch.load('PPO.policy')

# try and test out the solution 
# make sure GPU is enabled, otherwise loading will fail
# (the PPO version can win more often than not)!
# NOTE(review): the lines below refer to the pong solution file of the
# original exercise, not to the Pendulum policy trained above
#
# policy_solution = torch.load('PPO_solution.policy')
# pong_utils.play(env, policy_solution, time=2000) 