In [1]:
# %pip (not !pip) installs into the environment of the running kernel.
# Gymnasium >= 0.26 is required: the cells below unpack (obs, info) from
# env.reset() and 5 values (obs, reward, terminated, truncated, info) from env.step().
%pip install --quiet -U "gymnasium[classic-control]>=0.26"

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/958.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.1/958.1 kB[0m [31m5.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m952.3/958.1 kB[0m [31m15.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import gymnasium as gym
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [51]:
gamma = 0.99 # Discount factor for rewards

class Policy(nn.Module):
  """Stochastic MLP policy over a discrete action space, for REINFORCE.

  The network maps an observation vector to one unnormalised logit per action.
  `act` samples from the categorical distribution defined by those logits.
  The log-probability of every sampled action is buffered in `self.log_probs`;
  the caller appends rewards to `self.rewards`. Both buffers are cleared by
  `onpolicy_reset`.

  Args:
    in_dim: Length of the observation vector.
    out_dim: Number of discrete actions.
    hidden_dim: Width of the single hidden layer (default 64).
  """

  def __init__(self, in_dim, out_dim, hidden_dim=64):
    super().__init__()
    layers = [
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
    ]
    # Multi-Layer Perceptron of sizes in_dim -> hidden_dim -> out_dim with a ReLU non-linearity.
    self.model = nn.Sequential(*layers)
    self.onpolicy_reset()
    self.train()  # nn.Module.train(): explicit training mode (modules already start in it)

  def onpolicy_reset(self):
    """Empty the per-episode log-probability and reward buffers."""
    self.log_probs = []
    self.rewards = []

  def forward(self, x):
    """Return unnormalised action logits (not probabilities) for tensor `x`."""
    return self.model(x)

  def act(self, state):
    """Sample an action for the NumPy observation `state`.

    The log-probability of the sampled action is appended to `self.log_probs`
    (still attached to the autograd graph, so a later loss can backpropagate
    through it).

    Returns:
      The sampled action index as a Python int.
    """
    x = torch.from_numpy(state.astype(np.float32))  # float32 to match the Linear layer weights
    logits = self.forward(x)
    pd = Categorical(logits=logits)  # Categorical normalises the logits internally
    action = pd.sample()  # a ~ pi(a|s)
    self.log_probs.append(pd.log_prob(action))  # log pi(a|s), stored for training
    return action.item()

In [52]:
def train(policy, optimizer):
  """Perform one REINFORCE update from the episode buffered on `policy`.

  Algorithm: Williams (1992),
  https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf

  Discounted returns G_t = r_t + gamma * G_{t+1} are built in one backward sweep
  over `policy.rewards`, using the module-level `gamma`. The loss is
  -sum_t log pi(a_t|s_t) * G_t; minimising it with `optimizer` is gradient
  ascent on the expected return.

  Args:
    policy: Object exposing `rewards` (list of per-step rewards) and
      `log_probs` (list of scalar log-probability tensors, one per step).
    optimizer: torch optimizer over the policy's parameters.

  Returns:
    The scalar loss tensor that was backpropagated.
  """
  n_steps = len(policy.rewards)
  returns = np.empty(n_steps, dtype=np.float32)
  running = 0.0
  for step in range(n_steps - 1, -1, -1):
    running = policy.rewards[step] + gamma * running
    returns[step] = running
  weighted = torch.stack(policy.log_probs) * torch.tensor(returns)
  loss = -weighted.sum()  # negated because torch optimizers minimise
  optimizer.zero_grad()  # discard gradients from the previous update before backward()
  loss.backward()
  optimizer.step()
  return loss

In [62]:
import matplotlib.pyplot as plt

def main(num_episodes=300, max_steps=500, lr=0.01, seed=None):
  """Train a REINFORCE `Policy` on CartPole-v1, printing one line per episode.

  Each episode is rolled out with the current policy. `train` then performs a
  single update from that episode, and the on-policy buffers are cleared.

  `solved` compares a single episode's total reward with the environment's
  registered reward threshold (475.0 for CartPole-v1). This is a per-episode
  check, not an average over many episodes.

  Args:
    num_episodes: Number of episodes (and therefore policy updates) to run.
    max_steps: Safety cap on steps per episode; CartPole-v1 itself truncates
      episodes at 500 steps.
    lr: Learning rate for the Adam optimizer.
    seed: Optional seed for torch and for the first `env.reset`. The default
      None leaves both unseeded.
  """
  if seed is not None:
    torch.manual_seed(seed)
  env = gym.make('CartPole-v1', render_mode="rgb_array")
  try:
    in_dim = env.observation_space.shape[0] # 4 observations of state per time-step: cart position, cart velocity, pole angle, pole angular velocity
    out_dim = env.action_space.n # 2 possible actions -- move cart left or right
    # Use the env's registered threshold; the old hard-coded 195.0 was CartPole-v0's.
    spec_threshold = env.spec.reward_threshold if env.spec is not None else None
    solve_threshold = 475.0 if spec_threshold is None else spec_threshold
    policy = Policy(in_dim, out_dim)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    for episode in range(num_episodes):
      # Seed only the first reset; later resets continue from the seeded RNG.
      state, _ = env.reset(seed=seed if episode == 0 else None)
      for _ in range(max_steps):
        action = policy.act(state)
        state, reward, terminated, truncated, _ = env.step(action)
        policy.rewards.append(reward)
        # To visualise, call env.render() here: with render_mode="rgb_array" it
        # returns the current frame as an ndarray (e.g. for plt.imshow).
        # Gymnasium reports episode end via two flags: terminated (a terminal
        # state was reached) and truncated (e.g. the time limit). Either ends it.
        if terminated or truncated:
          break
      loss = train(policy, optimizer) # train per episode
      total_reward = sum(policy.rewards)
      solved = total_reward > solve_threshold
      policy.onpolicy_reset() # REINFORCE is an on-policy algorithm, meaning we clear memory after training
      print(f'Episode {episode+1}, loss: {loss}, total_reward: {total_reward}, solved: {solved}')
  finally:
    env.close()

In [63]:
# Run the full training loop; main() prints one progress line per episode.
main()

Episode 1, loss: 49.319091796875, total_reward: 12.0, solved: False
Episode 2, loss: 117.06388092041016, total_reward: 18.0, solved: False
Episode 3, loss: 84.98805236816406, total_reward: 16.0, solved: False
Episode 4, loss: 336.7796936035156, total_reward: 32.0, solved: False
Episode 5, loss: 247.3609619140625, total_reward: 27.0, solved: False
Episode 6, loss: 80.60472869873047, total_reward: 16.0, solved: False
Episode 7, loss: 236.3660430908203, total_reward: 27.0, solved: False
Episode 8, loss: 41.30571365356445, total_reward: 11.0, solved: False
Episode 9, loss: 60.96812438964844, total_reward: 13.0, solved: False
Episode 10, loss: 135.8144073486328, total_reward: 20.0, solved: False
Episode 11, loss: 60.884666442871094, total_reward: 12.0, solved: False
Episode 12, loss: 314.3273620605469, total_reward: 32.0, solved: False
Episode 13, loss: 138.0590057373047, total_reward: 20.0, solved: False
Episode 14, loss: 190.76080322265625, total_reward: 24.0, solved: False
Episode 15, lo