In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import collections
import gym
import numpy as np

  for external in metadata.entry_points().get(self.group, []):


In [2]:
SavedAction = collections.namedtuple('SavedAction', ['log_prob', 'value'])

In [3]:
env = gym.make('CartPole-v0') # We make the Cartpole environment here

In [4]:
print("There are {} actions".format(env.action_space.n))

There are 2 actions


In [5]:
# You can move either left or right to balance the pole
# Lets implement the Actor critic network
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128) # 4 because there are 4 parameters as the observation space
        self.actor = nn.Linear(128, 2) # 2 for the number of actions
        self.critic = nn.Linear(128, 1) # Critic is always 1
        self.saved_actions = []
        self.rewards = []
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_prob = F.softmax(self.actor(x), dim=-1)
        state_values = self.critic(x)
        return action_prob, state_values

In [6]:
def select_action(state):
    state = torch.from_numpy(state).float()
    probs, state_value = model(state)
    m = Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.item()

In [7]:
def finish_episode():
    # We calculate the losses and perform backprop in this function
    R = 0
    saved_actions = model.saved_actions
    policy_losses = []
    value_losses =[]
    returns = []
    
    for r in model.rewards[::-1]:
        R = r + 0.99 * R # 0.99 is our gamma number
        returns.insert(0, R)
        
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    
    for (log_prob, value), R in zip(saved_actions, returns):
        advantage = R - value.item()
        
        policy_losses.append(-log_prob * advantage)
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))
    
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    
    loss.backward()
    optimizer.step()
    
    del model.rewards[:]
    del model.saved_actions[:]

In [8]:
model = ActorCritic()
optimizer = optim.Adam(model.parameters(), lr=3e-2)
eps = np.finfo(np.float32).eps.item()

In [9]:
def train():
    running_reward = 10
    for i_episode in range(1, 10): # We need around this much episodes
        state = env.reset()
        print(state)
        ep_reward = 0
        for t in range(1, 10000):
            action = select_action(state)
            state, reward, done, _ = env.step(action)
            model.rewards.append(reward)
            ep_reward += reward
            if done:
                break

        running_reward = 0.05 * ep_reward + (1-0.05) * running_reward
        
        finish_episode()

        if i_episode % 10 == 0: # We will print some things out
            print("Episode {}\tLast Reward: {:.2f}\tAverage reward: {:.2f}".format(
                i_episode, ep_reward, running_reward
            ))
        if running_reward > env.spec.reward_threshold:
            print("Solved, running reward is now {} and the last episode runs to {} time steps".format(
                    running_reward, t
            ))
            break
            # This means that we solved cartpole and training is complete

In [10]:
train()

[-0.04563867 -0.03841383  0.01112975 -0.02404055]
[ 0.0220001  -0.03506866 -0.00972496 -0.03419371]
[-0.0394846  -0.0235162  -0.02118999 -0.01162963]
[-0.01321444  0.02298042 -0.02189541  0.03425798]
[-0.00936118 -0.0059507   0.03475915  0.02425839]
[ 0.04100468  0.01323972 -0.00444313 -0.02121717]
[0.00785876 0.03894312 0.04166164 0.03540104]
[0.04792369 0.03203343 0.04467328 0.00062689]
[-0.0098329  -0.00442624  0.04590943 -0.00255707]


In [24]:
# There. we finished
# Lets see it in action
done = False
cnt = 0

In [37]:
observation = env.reset()
while True:
    cnt += 1
    env.render()
    action = select_action(observation)
    print(action)
    observation, reward, done, _ = env.step(action)
    # Lets see how long it lasts until failing
print(f"Game lasted {cnt} moves")

0
0
0
0
1
0
1
1
0
1
1
1
1
0
0
1
0
1
1
1
1
1
1
0
1
0
1
0
1
0
0
0
0
1
0
0
0
1
1
1
1
0
0
0
0
1
1
1
0
0
1
0
1
1
1
0
1
1
1
0
1
0
0
0
1
1
0
0
0
1
0
0
1
1
1
0
1
1
0
1
1
0
0
1
0
1
0
1
0
1
0
0
0
0
0
1
1
1
1
0
1
0
1
1
0
1
0
1
0
0
1
0
1
0
0
1
0
1
0
1
0
1
0
0
1
0
1
1
0
1
1
1
0
0
0
0
0
1
1
1
0
1
1
1
1
0
1
0
1
0
0
1
1
0
1
0
0
0
1
0
1
0
1
1
0
0
0
0
1
0
0
0
1
1
1
1
1
0
0
1
0
0
0
1
0
1
0
0
1
1
1
1
0
1
0
1
1
1
0
1
0
1
1
1
0
1
0
1
1
0
1
0
0
1
1
0
1
0
0
1
0
0
1
1
0
1
0
0
0
1
0
1
0
0
0
0
0
1
1
1
0
0
0
0
0
0
1
1
1
0
1
1
1
0
0
0
1
1
1
0
1
0
0
1
1
0
1
1
1
1
0
1
0
1
1
0
1
0
0
1
0
1
1
0
1
0
0
0
1
0
0
1
0
1
0
1
0
1
0
1
0
1
0
0
0
0
1
0
1
0
1
1
1
1
0
0
0
1
1
0
1
1
0
1
0
1
0
1
0
0
0
1
1
1
1
0
0
1
0
1
0
0
1
0
0
1
1
1
0
1
1
0
1
0
0
1
1
0
1
0
1
0
1
0
0
1
0
0
1
1
0
0
0
0
1
0
1
1
1
1
1
1
0
0
0
0
1
1
1
1
1
0
0
1
0
1
0
0
0
1
1
1
0
0
1
0
1
1
0
1
0
0
1
1
0
1
0
0
1
0
0
0
1
1
1
1
0
1
0
0
0
1
0
1
1
0
0
1
0
1
1
1
0
0
1
1
0
1
0
0
0
1
1
1
0
1
0
0
0
1
1
0
1
0
1
0
0
1
1
1
0
1
1
0
0
1
1
0
1
0
1
0
1
0
0
1
0
1
1
0
1
0
1
0
1
0
1
0
1
0


KeyboardInterrupt: 