In [5]:
import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Cart Pole
# The argparse-based command line below is disabled by wrapping it in a
# string literal; the hard-coded `args` dict further down is used instead.
'''
parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()
'''

# Settings are read with dict indexing (args['gamma']), not attribute access.
#   gamma        -- discount factor applied in finish_episode(); 1 means the
#                   returns are undiscounted sums of rewards
#   seed         -- value used to seed both the environment and PyTorch
#   log_interval -- main() prints a status line every `log_interval` episodes
args={'gamma':1, 'seed':1234, 'log_interval':1}


# environment shared by the whole script
env = gym.make('CartPole-v0')
# seed the environment and PyTorch's RNG with the same value
env.seed(args['seed'])
torch.manual_seed(args['seed'])

# Per-step record kept for training: the log-probability of the sampled
# action and the critic's value estimate for the state it was sampled in.
SavedAction = namedtuple(
    'SavedAction',
    ('log_prob', 'value'),
)


class Policy(nn.Module):
    """
    Implements both actor and critic in one model.

    A single hidden layer is shared by two linear heads: the actor head
    yields a probability for each action, the critic head yields one
    state-value estimate.

    Args:
        obs_dim: number of input features per state (default: 4).
        hidden_dim: width of the shared hidden layer (default: 128).
        n_actions: number of discrete actions the actor chooses from
            (default: 2).

    Attributes:
        saved_actions: list the caller fills with one entry per step taken.
        rewards: list the caller fills with the reward of each step.
    """
    def __init__(self, obs_dim=4, hidden_dim=128, n_actions=2):
        super(Policy, self).__init__()
        # shared feature layer (layers are created in the same order as
        # before, so seeded initialisation with the defaults is unchanged)
        self.affine1 = nn.Linear(obs_dim, hidden_dim)

        # actor's layer
        self.action_head = nn.Linear(hidden_dim, n_actions)

        # critic's layer
        self.value_head = nn.Linear(hidden_dim, 1)

        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """
        Forward pass of both actor and critic.

        Args:
            x: tensor whose last dimension has size `obs_dim`.

        Returns:
            A tuple `(action_prob, state_values)`:
            1. the probability of each action (softmax over the last
               dimension, so it sums to 1 there);
            2. the value estimate of the state, with a last dimension of 1.
        """
        x = F.relu(self.affine1(x))

        # actor: chooses the action to take from state s_t
        # by returning the probability of each action
        action_prob = F.softmax(self.action_head(x), dim=-1)

        # critic: evaluates being in the state s_t
        state_values = self.value_head(x)

        return action_prob, state_values


# Single network instance used by the whole script; select_action() appends to
# its buffers and finish_episode() trains on them and clears them.
model = Policy()
# Adam over all of the model's parameters (actor and critic are updated together)
optimizer = optim.Adam(model.parameters(), lr=3e-2)
# float32 machine epsilon as a Python float; added to the standard deviation
# in finish_episode() so the normalisation never divides by zero
eps = np.finfo(np.float32).eps.item()


def select_action(state):
    """
    Sample an action for `state` from the current policy.

    `state` is a NumPy array; it is converted to a float tensor and passed
    through the global `model`. The log-probability of the sampled action
    and the critic's value estimate are appended to `model.saved_actions`
    as a `SavedAction`, for later use in training.

    Returns:
        The sampled action as a plain Python number.
    """
    observation = torch.from_numpy(state).float()
    action_probs, value_estimate = model(observation)

    # categorical distribution over the actor's output probabilities
    policy_dist = Categorical(action_probs)
    chosen = policy_dist.sample()

    # remember what training needs for this step
    record = SavedAction(
        log_prob=policy_dist.log_prob(chosen),
        value=value_estimate,
    )
    model.saved_actions.append(record)

    return chosen.item()


def finish_episode():
    """
    Training code. Calculates actor and critic loss and performs backprop.

    Reads the rewards in `model.rewards` and the `(log_prob, value)` pairs in
    `model.saved_actions`, turns the rewards into discounted returns
    (discount factor `args['gamma']`), and takes one optimizer step on the
    sum of the policy loss and the smooth-L1 value loss. Both buffers are
    emptied before returning.

    If either buffer is empty there is nothing to train on: the buffers are
    cleared and no optimizer step is taken.
    """
    saved_actions = model.saved_actions

    # torch.stack() raises on an empty list, so bail out early
    if not saved_actions or not model.rewards:
        del model.rewards[:]
        del model.saved_actions[:]
        return

    policy_losses = [] # list to save actor (policy) loss
    value_losses = [] # list to save critic (value) loss

    # calculate the true value using rewards returned from the environment;
    # accumulate back-to-front with append + one reverse (linear time)
    # instead of inserting at the head of the list on every step
    R = 0
    returns = []
    for r in reversed(model.rewards):
        # calculate the discounted value
        R = r + args['gamma'] * R
        returns.append(R)
    returns.reverse()

    # explicit float dtype so integer rewards do not yield an integer tensor
    returns = torch.tensor(returns, dtype=torch.float32)

    # standardise the returns; the standard deviation of a single value is
    # NaN, which would turn the loss and the weights into NaN, so a one-step
    # episode keeps its raw return
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    for (log_prob, value), R in zip(saved_actions, returns):
        # .item() detaches the critic's estimate, so the policy loss does not
        # backpropagate into the value head
        advantage = R - value.item()

        # calculate actor (policy) loss
        policy_losses.append(-log_prob * advantage)

        # calculate critic (value) loss using L1 smooth loss; reshape the
        # 0-dim return to the one-element shape that select_action() stores
        # for `value`
        value_losses.append(F.smooth_l1_loss(value, R.reshape(1)))

    # reset gradients
    optimizer.zero_grad()

    # sum up all the values of policy_losses and value_losses
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

    # perform backprop
    loss.backward()
    optimizer.step()

    # reset rewards and action buffer
    del model.rewards[:]
    del model.saved_actions[:]


def main():
    """
    Train episode after episode until the running reward beats the threshold.

    Each episode resets the global `env`, acts with `select_action()` for at
    most 9999 steps, records every reward in `model.rewards`, and then calls
    `finish_episode()` to update the model. An exponentially smoothed
    running reward (weight 0.05 on the newest episode, starting from 10) is
    logged every `args['log_interval']` episodes; training stops once it
    exceeds `env.spec.reward_threshold`.
    """
    smoothing = 0.05
    running_reward = 10

    # there is no episode limit: only the "solved" check ends the loop
    for episode_idx in count(1):
        state = env.reset()
        episode_return = 0

        # cap the episode length so a policy that never fails cannot
        # keep a single episode running forever
        for step in range(1, 10000):
            # act, then observe the outcome (rendering is not enabled here)
            state, reward, done, _ = env.step(select_action(state))

            model.rewards.append(reward)
            episode_return += reward
            if done:
                break

        # fold this episode into the smoothed reward
        running_reward = smoothing * episode_return + (1 - smoothing) * running_reward

        # one training update from the buffered episode
        finish_episode()

        if episode_idx % args['log_interval'] == 0:
            status = 'Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'
            print(status.format(episode_idx, episode_return, running_reward))

        # the task counts as solved once the smoothed reward passes the
        # environment's own threshold
        if running_reward > env.spec.reward_threshold:
            solved_msg = ("Solved! Running reward is now {} and "
                          "the last episode runs to {} time steps!")
            print(solved_msg.format(running_reward, step))
            break


# start training only when this file is run as a script, not when imported
if __name__ == '__main__':
    main()

Episode 1	Last reward: 16.00	Average reward: 10.30
Episode 2	Last reward: 12.00	Average reward: 10.38
Episode 3	Last reward: 12.00	Average reward: 10.47
Episode 4	Last reward: 8.00	Average reward: 10.34
Episode 5	Last reward: 9.00	Average reward: 10.28
Episode 6	Last reward: 10.00	Average reward: 10.26
Episode 7	Last reward: 10.00	Average reward: 10.25
Episode 8	Last reward: 8.00	Average reward: 10.14
Episode 9	Last reward: 10.00	Average reward: 10.13
Episode 10	Last reward: 10.00	Average reward: 10.12
Episode 11	Last reward: 9.00	Average reward: 10.07
Episode 12	Last reward: 9.00	Average reward: 10.01
Episode 13	Last reward: 9.00	Average reward: 9.96
Episode 14	Last reward: 10.00	Average reward: 9.96
Episode 15	Last reward: 8.00	Average reward: 9.87
Episode 16	Last reward: 8.00	Average reward: 9.77
Episode 17	Last reward: 10.00	Average reward: 9.78
Episode 18	Last reward: 9.00	Average reward: 9.75
Episode 19	Last reward: 9.00	Average reward: 9.71
Episode 20	Last reward: 10.00	Average 

Episode 191	Last reward: 8.00	Average reward: 9.43
Episode 192	Last reward: 10.00	Average reward: 9.46
Episode 193	Last reward: 10.00	Average reward: 9.48
Episode 194	Last reward: 8.00	Average reward: 9.41
Episode 195	Last reward: 8.00	Average reward: 9.34
Episode 196	Last reward: 9.00	Average reward: 9.32
Episode 197	Last reward: 10.00	Average reward: 9.36
Episode 198	Last reward: 10.00	Average reward: 9.39
Episode 199	Last reward: 9.00	Average reward: 9.37
Episode 200	Last reward: 10.00	Average reward: 9.40
Episode 201	Last reward: 9.00	Average reward: 9.38
Episode 202	Last reward: 9.00	Average reward: 9.36
Episode 203	Last reward: 10.00	Average reward: 9.39
Episode 204	Last reward: 9.00	Average reward: 9.37
Episode 205	Last reward: 10.00	Average reward: 9.40
Episode 206	Last reward: 9.00	Average reward: 9.38
Episode 207	Last reward: 10.00	Average reward: 9.42
Episode 208	Last reward: 10.00	Average reward: 9.44
Episode 209	Last reward: 10.00	Average reward: 9.47
Episode 210	Last rewa

Episode 366	Last reward: 10.00	Average reward: 9.58
Episode 367	Last reward: 10.00	Average reward: 9.60
Episode 368	Last reward: 9.00	Average reward: 9.57
Episode 369	Last reward: 9.00	Average reward: 9.54
Episode 370	Last reward: 10.00	Average reward: 9.57
Episode 371	Last reward: 10.00	Average reward: 9.59
Episode 372	Last reward: 10.00	Average reward: 9.61
Episode 373	Last reward: 8.00	Average reward: 9.53
Episode 374	Last reward: 10.00	Average reward: 9.55
Episode 375	Last reward: 10.00	Average reward: 9.57
Episode 376	Last reward: 9.00	Average reward: 9.55
Episode 377	Last reward: 9.00	Average reward: 9.52
Episode 378	Last reward: 8.00	Average reward: 9.44
Episode 379	Last reward: 9.00	Average reward: 9.42
Episode 380	Last reward: 10.00	Average reward: 9.45
Episode 381	Last reward: 10.00	Average reward: 9.48
Episode 382	Last reward: 10.00	Average reward: 9.50
Episode 383	Last reward: 9.00	Average reward: 9.48
Episode 384	Last reward: 11.00	Average reward: 9.55
Episode 385	Last rew

Episode 548	Last reward: 9.00	Average reward: 9.07
Episode 549	Last reward: 9.00	Average reward: 9.06
Episode 550	Last reward: 10.00	Average reward: 9.11
Episode 551	Last reward: 10.00	Average reward: 9.16
Episode 552	Last reward: 10.00	Average reward: 9.20
Episode 553	Last reward: 11.00	Average reward: 9.29
Episode 554	Last reward: 10.00	Average reward: 9.32
Episode 555	Last reward: 9.00	Average reward: 9.31
Episode 556	Last reward: 9.00	Average reward: 9.29
Episode 557	Last reward: 9.00	Average reward: 9.28
Episode 558	Last reward: 10.00	Average reward: 9.31
Episode 559	Last reward: 9.00	Average reward: 9.30
Episode 560	Last reward: 9.00	Average reward: 9.28
Episode 561	Last reward: 10.00	Average reward: 9.32
Episode 562	Last reward: 10.00	Average reward: 9.35
Episode 563	Last reward: 9.00	Average reward: 9.34
Episode 564	Last reward: 10.00	Average reward: 9.37
Episode 565	Last reward: 9.00	Average reward: 9.35
Episode 566	Last reward: 10.00	Average reward: 9.38
Episode 567	Last rewa

Episode 712	Last reward: 9.00	Average reward: 9.61
Episode 713	Last reward: 9.00	Average reward: 9.58
Episode 714	Last reward: 8.00	Average reward: 9.50
Episode 715	Last reward: 10.00	Average reward: 9.53
Episode 716	Last reward: 10.00	Average reward: 9.55
Episode 717	Last reward: 8.00	Average reward: 9.47
Episode 718	Last reward: 10.00	Average reward: 9.50
Episode 719	Last reward: 10.00	Average reward: 9.52
Episode 720	Last reward: 8.00	Average reward: 9.45
Episode 721	Last reward: 10.00	Average reward: 9.48
Episode 722	Last reward: 10.00	Average reward: 9.50
Episode 723	Last reward: 9.00	Average reward: 9.48
Episode 724	Last reward: 10.00	Average reward: 9.50
Episode 725	Last reward: 9.00	Average reward: 9.48
Episode 726	Last reward: 8.00	Average reward: 9.40
Episode 727	Last reward: 9.00	Average reward: 9.38
Episode 728	Last reward: 10.00	Average reward: 9.41
Episode 729	Last reward: 10.00	Average reward: 9.44
Episode 730	Last reward: 10.00	Average reward: 9.47
Episode 731	Last rewa

Episode 872	Last reward: 9.00	Average reward: 9.29
Episode 873	Last reward: 10.00	Average reward: 9.33
Episode 874	Last reward: 9.00	Average reward: 9.31
Episode 875	Last reward: 8.00	Average reward: 9.24
Episode 876	Last reward: 9.00	Average reward: 9.23
Episode 877	Last reward: 9.00	Average reward: 9.22
Episode 878	Last reward: 10.00	Average reward: 9.26
Episode 879	Last reward: 9.00	Average reward: 9.25
Episode 880	Last reward: 8.00	Average reward: 9.18
Episode 881	Last reward: 10.00	Average reward: 9.22
Episode 882	Last reward: 10.00	Average reward: 9.26
Episode 883	Last reward: 9.00	Average reward: 9.25
Episode 884	Last reward: 9.00	Average reward: 9.24
Episode 885	Last reward: 11.00	Average reward: 9.33
Episode 886	Last reward: 11.00	Average reward: 9.41
Episode 887	Last reward: 9.00	Average reward: 9.39
Episode 888	Last reward: 9.00	Average reward: 9.37
Episode 889	Last reward: 10.00	Average reward: 9.40
Episode 890	Last reward: 9.00	Average reward: 9.38
Episode 891	Last reward:

Episode 1051	Last reward: 9.00	Average reward: 9.02
Episode 1052	Last reward: 9.00	Average reward: 9.02
Episode 1053	Last reward: 9.00	Average reward: 9.02
Episode 1054	Last reward: 9.00	Average reward: 9.02
Episode 1055	Last reward: 8.00	Average reward: 8.97
Episode 1056	Last reward: 9.00	Average reward: 8.97
Episode 1057	Last reward: 10.00	Average reward: 9.02
Episode 1058	Last reward: 10.00	Average reward: 9.07
Episode 1059	Last reward: 10.00	Average reward: 9.12
Episode 1060	Last reward: 10.00	Average reward: 9.16
Episode 1061	Last reward: 8.00	Average reward: 9.10
Episode 1062	Last reward: 10.00	Average reward: 9.15
Episode 1063	Last reward: 10.00	Average reward: 9.19
Episode 1064	Last reward: 10.00	Average reward: 9.23
Episode 1065	Last reward: 9.00	Average reward: 9.22
Episode 1066	Last reward: 9.00	Average reward: 9.21
Episode 1067	Last reward: 10.00	Average reward: 9.25
Episode 1068	Last reward: 11.00	Average reward: 9.34
Episode 1069	Last reward: 10.00	Average reward: 9.37
Ep

Episode 1210	Last reward: 29.00	Average reward: 12.53
Episode 1211	Last reward: 35.00	Average reward: 13.66
Episode 1212	Last reward: 29.00	Average reward: 14.42
Episode 1213	Last reward: 19.00	Average reward: 14.65
Episode 1214	Last reward: 20.00	Average reward: 14.92
Episode 1215	Last reward: 21.00	Average reward: 15.22
Episode 1216	Last reward: 16.00	Average reward: 15.26
Episode 1217	Last reward: 16.00	Average reward: 15.30
Episode 1218	Last reward: 15.00	Average reward: 15.29
Episode 1219	Last reward: 9.00	Average reward: 14.97
Episode 1220	Last reward: 10.00	Average reward: 14.72
Episode 1221	Last reward: 11.00	Average reward: 14.54
Episode 1222	Last reward: 14.00	Average reward: 14.51
Episode 1223	Last reward: 10.00	Average reward: 14.28
Episode 1224	Last reward: 13.00	Average reward: 14.22
Episode 1225	Last reward: 15.00	Average reward: 14.26
Episode 1226	Last reward: 18.00	Average reward: 14.45
Episode 1227	Last reward: 14.00	Average reward: 14.42
Episode 1228	Last reward: 18.

Episode 1364	Last reward: 33.00	Average reward: 42.80
Episode 1365	Last reward: 48.00	Average reward: 43.06
Episode 1366	Last reward: 57.00	Average reward: 43.76
Episode 1367	Last reward: 66.00	Average reward: 44.87
Episode 1368	Last reward: 39.00	Average reward: 44.58
Episode 1369	Last reward: 24.00	Average reward: 43.55
Episode 1370	Last reward: 42.00	Average reward: 43.47
Episode 1371	Last reward: 25.00	Average reward: 42.55
Episode 1372	Last reward: 63.00	Average reward: 43.57
Episode 1373	Last reward: 52.00	Average reward: 43.99
Episode 1374	Last reward: 60.00	Average reward: 44.79
Episode 1375	Last reward: 27.00	Average reward: 43.90
Episode 1376	Last reward: 33.00	Average reward: 43.36
Episode 1377	Last reward: 30.00	Average reward: 42.69
Episode 1378	Last reward: 42.00	Average reward: 42.66
Episode 1379	Last reward: 34.00	Average reward: 42.22
Episode 1380	Last reward: 45.00	Average reward: 42.36
Episode 1381	Last reward: 35.00	Average reward: 41.99
Episode 1382	Last reward: 17

Episode 1513	Last reward: 200.00	Average reward: 193.20
Episode 1514	Last reward: 200.00	Average reward: 193.54
Episode 1515	Last reward: 200.00	Average reward: 193.86
Episode 1516	Last reward: 200.00	Average reward: 194.17
Episode 1517	Last reward: 200.00	Average reward: 194.46
Episode 1518	Last reward: 200.00	Average reward: 194.74
Episode 1519	Last reward: 200.00	Average reward: 195.00
Episode 1520	Last reward: 200.00	Average reward: 195.25
Solved! Running reward is now 195.24981340889315 and the last episode runs to 200 time steps!
