In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym

In [12]:
class Policy(nn.Module):
  def __init__(self, in_dim, out_dim):
    super(Policy, self).__init__()

    layers = [
              nn.Linear(in_dim, 128),
              nn.ReLU(),
              nn.Linear(128, out_dim)
    ]
    self.model = nn.Sequential(*layers)
    self.onpolicy_reset()
    self.train()

  def onpolicy_reset(self):
    self.log_probabilities = []
    self.rewards = []

  def forward(self, x):
    return self.model(x)

  def take_action(self, current_state):
    x = torch.from_numpy(current_state.astype('float32'))
    probability_distribution_parameters = self.forward(x)
    probability_distribution = Categorical(logits=probability_distribution_parameters)
    action = probability_distribution.sample()
    log_probability = probability_distribution.log_prob(action)
    self.log_probabilities.append(log_probability)
    return action.item()

In [13]:
gamma = 0.99

def train(policy, optimizer):
  tau = len(policy.rewards)
  returns = np.empty(tau, dtype=np.float32)
  future_returns = 0.0

  for t in reversed(range(tau)):
    future_returns = policy.rewards[t] + gamma * future_returns
    returns[t] = future_returns

  returns = torch.tensor(returns)
  log_probabilities = torch.stack(policy.log_probabilities)

  loss = -log_probabilities * returns
  loss = torch.sum(loss)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  return loss

In [18]:
def main():
  env = gym.make('CartPole-v0')
  in_dim = env.observation_space.shape[0]
  out_dim = env.action_space.n
  policy = Policy(in_dim, out_dim)
  optimizer = optim.Adam(policy.parameters(), lr=0.01)

  for episode in range(300):
    state = env.reset()

    for t in range(200):
      action = policy.take_action(state)
      state, reward, done, _ = env.step(action)
      policy.rewards.append(reward)
      env.render()

      if done:
        break

  loss = train(policy, optimizer)
  total_reward = sum(policy.rewards)
  solved = total_reward > 195.0

  policy.onpolicy_reset()
  print(f'Episode {episode}, loss: {loss}, total_reward: {total_reward}, action: {action}, solved: {solved}')

In [19]:
main()

Episode 299, loss: 380400.375, total_reward: 5751.0, action: 0, solved: True
