In [1]:
import gymnasium as gym

# Gymnasium environment id; CartPole has a 4-dim observation and 2 discrete actions.
env_name = "CartPole-v1"
env = gym.make(env_name)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import matplotlib.pyplot as plt

from itertools import count
from collections import namedtuple

from IPython import display

# One record per environment step, appended by select_action and consumed by update_model:
#   log_prob - log-probability of the sampled action (still attached to the actor's graph)
#   value    - critic output for the state in which the action was taken
SavedAction = namedtuple('SavedAction', 'log_prob value')


class Policy(nn.Module):
  """Actor-critic network: one shared hidden layer feeding an actor and a critic head.

  Args:
    n_observations: length of the observation vector (default 4, CartPole's size).
    n_actions: number of discrete actions (default 2, CartPole's left/right).
    hidden_size: width of the shared hidden layer (default 128).

  The defaults reproduce the original hard-coded architecture and the layer
  attribute names are unchanged, so previously saved state_dicts still load.
  """

  def __init__(self, n_observations=4, n_actions=2, hidden_size=128):
    super().__init__()
    # shared feature layer
    self.affine1 = nn.Linear(n_observations, hidden_size)
    # actor's layer: one logit per action
    self.action_head = nn.Linear(hidden_size, n_actions)
    # critic's layer: a single state-value estimate
    self.value_head = nn.Linear(hidden_size, 1)

  def forward(self, x):
    """Return (action_probs, state_values) for observation(s) x.

    action_probs: softmax over the last dimension, shape (..., n_actions).
    state_values: critic output, shape (..., 1).
    """
    x = F.relu(self.affine1(x))

    # actor: chooses the action to take from state s_t
    # by returning the probability of each action
    action_prob = F.softmax(self.action_head(x), dim=-1)

    # critic: evaluates being in state s_t
    state_values = self.value_head(x)

    # 1. a tensor with the probability of each action over the action space
    # 2. the critic's value estimate for state s_t
    return action_prob, state_values


# Per-episode buffers: select_action / run append one entry per step,
# update_model consumes and empties them after each episode.
saved_actions, rewards = [], []

def select_action(state):
  state = torch.from_numpy(state).float()
  probs, state_value = model(state)

  # create a categorical distribution over the list of probabilities of actions
  m = Categorical(probs)

  # and sample an action using the distribution
  action = m.sample()

  # save to action buffer
  saved_actions.append(SavedAction(m.log_prob(action), state_value))

  # the action to take (left or right)
  return action.item()


NUM_EPISODE = 1000  # number of training episodes run() plays
GAMMA = 0.99        # discount factor used for the returns in update_model()
MAX_TIME = 3000     # safety cap on steps per episode, checked in run()
SEED = 543          # seed for reproducible weight init, action sampling and env resets

# Seed torch before building the model so weight init and Categorical sampling repeat.
torch.manual_seed(SEED)
# Seed the environment's RNG once; later unseeded env.reset() calls continue this stream.
env.reset(seed=SEED)

model = Policy()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Smallest float32 increment; added to the returns' std in update_model() to avoid dividing by 0.
eps = np.finfo(np.float32).eps.item()

# Length (in steps) of each finished episode; appended by run(), plotted by plot_durations().
episode_durations = []


def plot_durations(show_result=False):
  """Draw episode durations and their 100-episode moving average in figure 1.

  show_result=False: clear the figure, title it 'Training...', and mark the cell
  output to be replaced by the next display, so the live plot updates in place.
  show_result=True: title it 'Result' and leave the figure in the output.
  """
  plt.figure(1)
  durations = torch.tensor(episode_durations, dtype=torch.float)
  if show_result:
    heading = 'Result'
  else:
    plt.clf()
    heading = 'Training...'
  plt.title(heading)
  plt.xlabel('Episode')
  plt.ylabel('Duration')
  plt.plot(durations.numpy())

  # Moving average over a 100-episode window, left-padded with zeros so it lines up
  # with the raw curve; only drawn once enough episodes exist.
  window = 100
  if durations.numel() >= window:
    rolling = durations.unfold(0, window, 1).mean(dim=1).view(-1)
    plt.plot(torch.cat([torch.zeros(window - 1), rolling]).numpy())

  plt.pause(0.001)  # give the GUI a moment so the figure is redrawn
  display.display(plt.gcf())
  if not show_result:
    display.clear_output(wait=True)


def update_model():
  """Perform one actor-critic gradient step on the episode stored in the global buffers.

  Reads `rewards` (one reward per step) and `saved_actions` (one (log_prob, value)
  pair per step, appended by select_action), which must be aligned step by step.

  * Discounted returns G_t = r_t + GAMMA * G_{t+1} are accumulated backwards and,
    when there are at least two steps, standardized to zero mean / unit variance
    (the std of a single value is NaN and would poison every weight).
  * Actor loss:  -sum_t log_prob_t * (G_t - V(s_t)), with V detached so this term
    does not train the critic.
  * Critic loss: sum_t smooth_l1(V(s_t), G_t).

  Both buffers are emptied in place afterwards -- also when an error is raised --
  so the next episode always starts from empty lists. An empty episode is a no-op.

  Raises:
    ValueError: if the two buffers have different lengths.
  """
  try:
    if len(rewards) != len(saved_actions):
      raise ValueError(
          f'update_model: {len(rewards)} rewards but {len(saved_actions)} saved actions')
    if not rewards:
      return  # nothing was collected; skip the update

    # Discounted returns, built backwards (append + reverse avoids O(n^2) inserts).
    returns = []
    running = 0.0
    for r in reversed(rewards):
      running = r + GAMMA * running
      returns.append(running)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)
    if returns.numel() > 1:
      returns = (returns - returns.mean()) / (returns.std() + eps)

    # Flatten so both are shape (T,) whether each entry is 0-d or shape (1,).
    log_probs = torch.cat([lp.reshape(-1) for lp, _ in saved_actions])
    values = torch.cat([v.reshape(-1) for _, v in saved_actions])

    advantages = returns - values.detach()
    policy_loss = -(log_probs * advantages).sum()
    # reduction='sum' matches summing one per-step loss for each step.
    value_loss = F.smooth_l1_loss(values, returns, reduction='sum')

    optimizer.zero_grad()
    (policy_loss + value_loss).backward()
    optimizer.step()
  finally:
    # Empty the shared lists in place; other code keeps referencing these same objects.
    del rewards[:]
    del saved_actions[:]


def run():
  """Train the agent for NUM_EPISODE episodes, updating the model after each one.

  An episode ends when the environment reports `terminated` (a terminal state)
  or `truncated` (its time limit was hit), or after MAX_TIME steps as a safety cap.
  Each episode's length is appended to `episode_durations` and the live plot is
  refreshed; a final 'Result' plot is drawn once training is over.
  """
  for _ in range(NUM_EPISODE):
    state, _info = env.reset()
    for t in count(1):  # t = number of steps taken so far in this episode
      action = select_action(state)
      state, reward, terminated, truncated, _info = env.step(action)
      rewards.append(reward)
      # Gymnasium reports a time-limit cutoff via `truncated`, not `terminated`;
      # ignoring it would keep stepping past the environment's episode limit.
      if terminated or truncated or t >= MAX_TIME:
        episode_durations.append(t)
        plot_durations()
        break

    update_model()

  plot_durations(show_result=True)

# Train the agent for NUM_EPISODE episodes; the model is updated after every episode.
run()

<Figure size 640x480 with 0 Axes>

In [3]:
# Persist only the trained network weights (the optimizer state is not included).
torch.save(obj=model.state_dict(), f='./actor_critic.pt')