In [1]:
# !pip install gymnasium
# !pip install --upgrade swig
# !pip install --upgrade box2D
# !pip3 install box2d box2d-kengz
# !pip install pygame
import gymnasium as gym
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

class MLP(nn.Module):
    """Fully connected network: (Linear -> ReLU) blocks followed by a final Linear.

    Args:
        input_size: number of input features.
        hidden_sizes: widths of the hidden layers, in order. May be empty, in
            which case the network is a single ``Linear(input_size, output_size)``.
        output_size: number of outputs (callers in this notebook treat them as
            action logits).

    The submodule layout (``layers.0``, ``layers.1``, ...) matches the original
    construction, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        # Chain the layer widths: input -> hidden_1 -> ... -> hidden_k -> output.
        widths = [input_size, *hidden_sizes, output_size]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # No activation after the output layer: drop the trailing ReLU.
        layers.pop()

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``input_size``."""
        return self.layers(x)

In [36]:
def rev_cumsum(x):
    """Reverse cumulative sum: entry ``i`` of the result is ``sum(x[i:])``.

    In this notebook it turns an episode's per-step rewards into the
    (undiscounted) reward-to-go for every step. Returns a NumPy array.
    """
    flipped = np.asarray(x)[::-1]
    return flipped.cumsum()[::-1]
    
def compute_loss(policy, obs, act, weight):
    """Policy-gradient surrogate loss for a batch of transitions.

    Args:
        policy: callable mapping ``obs`` to action logits.
        obs: batch of observations.
        act: action taken for each observation.
        weight: per-step weight (the reward-to-go in this notebook) that
            scales each action's log-probability.

    Returns:
        Scalar tensor ``-sum(weight * log pi(act | obs))``. Minimizing it
        increases the weighted log-likelihood of the actions taken.
    """
    action_dist = Categorical(logits=policy(obs))
    weighted_log_probs = weight * action_dist.log_prob(act)
    return -weighted_log_probs.sum()
    
def episode_train(env, policy, t_lim):
    """Roll out one episode with actions sampled from ``policy``.

    The rollout stops when the environment reports ``terminated`` or
    ``truncated``, or after ``t_lim`` steps, whichever comes first.
    (Previously the step counter was never incremented, so ``t_lim`` was
    silently ignored.)

    Args:
        env: Gymnasium-style env: ``reset() -> (obs, info)`` and
            ``step(a) -> (obs, reward, terminated, truncated, info)``.
        policy: module mapping an observation tensor to action logits.
        t_lim: maximum number of steps to take (``float('inf')`` = no cap).

    Returns:
        ``(obs_all, act_all, reward_all)``: equal-length lists where
        ``act_all[t]`` was taken on observation ``obs_all[t]`` and
        ``reward_all[t]`` is the reward returned by that step.
    """
    obs, _ = env.reset()

    obs_all = []
    act_all = []
    reward_all = []

    n_steps = 0
    while n_steps < t_lim:
        # Sampling only: no autograd graph is needed here because the loss is
        # recomputed with gradients from the stored observations later.
        with torch.no_grad():
            logits = policy(torch.as_tensor(obs, dtype=torch.float32))
        action = Categorical(logits=logits).sample().item()

        obs_all.append(obs)
        act_all.append(action)

        obs, reward, terminated, truncated, _ = env.step(action)
        reward_all.append(reward)
        n_steps += 1

        if terminated or truncated:
            break

    return obs_all, act_all, reward_all

def epoch_train(env, policy, optimizer,  n_samples = 1000):
    """Perform one policy-gradient update from a fresh batch of on-policy data.

    Episodes are rolled out with ``episode_train`` until ``n_samples`` steps
    are collected. Each episode is given the remaining budget as its step
    limit. Every step is weighted by its reward-to-go (``rev_cumsum``), and
    one optimizer step is taken on ``compute_loss``.

    Args:
        env: environment passed through to ``episode_train``.
        policy: network producing action logits.
        optimizer: optimizer over ``policy``'s parameters.
        n_samples: number of environment steps to collect for the update.

    Returns:
        Number of episodes started during this epoch.

    Raises:
        ValueError: if ``n_samples`` is less than 1 (no data to train on).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    obs_all = []
    act_all = []
    weight_all = []

    optimizer.zero_grad()

    n_episodes = 0
    while len(obs_all) < n_samples:
        t_lim = n_samples - len(obs_all)
        obs, act, reward = episode_train(env, policy, t_lim)

        obs_all.extend(obs)
        act_all.extend(act)
        weight_all.extend(rev_cumsum(reward))
        n_episodes += 1

    # Stack into one array first: building a tensor from a Python list of
    # ndarrays is a known slow path in PyTorch.
    obs_batch = torch.as_tensor(np.asarray(obs_all, dtype=np.float32))
    act_batch = torch.as_tensor(act_all, dtype=torch.int64)
    # rev_cumsum typically yields float64; cast to float32 so the loss stays in
    # the network's dtype instead of being promoted to float64.
    weight_batch = torch.as_tensor(np.asarray(weight_all, dtype=np.float32))

    loss = compute_loss(policy, obs_batch, act_batch, weight_batch)

    loss.backward()
    optimizer.step()

    return n_episodes
    

In [41]:
# Seed torch so weight initialization and action sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

env = gym.make("CartPole-v1", render_mode=None)

# Derive network sizes from the environment instead of hard-coding them
# (CartPole-v1: 4-dimensional observations, 2 discrete actions).
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

policy = MLP(obs_dim, [32], n_actions)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

In [42]:
# Seed the environment's RNG once. Later unseeded env.reset() calls (as made
# inside episode_train) continue from this seeded generator, so rollouts are
# reproducible. The returned obs/info are not used by the training loop.
obs, info = env.reset(seed=0)

In [43]:
# Train for a fixed number of epochs. Each epoch performs one policy-gradient
# update from n_samples environment steps. The number of episodes per epoch is
# stored in eps_all and reported as n_samples / episodes.
n_epochs = 100
n_samples = 10000
eps_all = []

for epoch in range(n_epochs):
    episodes_this_epoch = epoch_train(env, policy, optimizer,  n_samples)
    eps_all.append(episodes_this_epoch)
    avg_len = n_samples / episodes_this_epoch
    print(f"epoch {epoch}: {avg_len:0.0f} steps per episode")

epoch 0: 25 steps per episode
epoch 1: 25 steps per episode
epoch 2: 28 steps per episode
epoch 3: 31 steps per episode
epoch 4: 31 steps per episode
epoch 5: 34 steps per episode
epoch 6: 36 steps per episode
epoch 7: 39 steps per episode
epoch 8: 43 steps per episode
epoch 9: 42 steps per episode
epoch 10: 45 steps per episode
epoch 11: 49 steps per episode
epoch 12: 47 steps per episode
epoch 13: 55 steps per episode
epoch 14: 55 steps per episode
epoch 15: 57 steps per episode
epoch 16: 59 steps per episode
epoch 17: 61 steps per episode
epoch 18: 60 steps per episode
epoch 19: 66 steps per episode
epoch 20: 69 steps per episode
epoch 21: 71 steps per episode
epoch 22: 76 steps per episode
epoch 23: 83 steps per episode
epoch 24: 86 steps per episode
epoch 25: 92 steps per episode
epoch 26: 94 steps per episode
epoch 27: 109 steps per episode
epoch 28: 119 steps per episode
epoch 29: 135 steps per episode
epoch 30: 175 steps per episode
epoch 31: 161 steps per episode
epoch 32: 192

In [44]:
# Watch the trained policy play one episode in a rendered window. A separate
# env keeps the headless training `env` intact, so the training cells can be
# re-run without rendering.
render_env = gym.make("CartPole-v1", render_mode="human")
try:
    obs, info = render_env.reset()
    episode_return = 0.0

    while True:
        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = policy(torch.as_tensor(obs, dtype=torch.float32))
        action = Categorical(logits=logits).sample().item()

        obs, reward, terminated, truncated, info = render_env.step(action)
        episode_return += reward

        if terminated or truncated:
            break
finally:
    # Release the render window and its resources even if the loop raises.
    render_env.close()

episode_return