In [2]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.distributions.categorical import Categorical


In [3]:
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [5]:
# environment parameters
env_name="CartPole-v1"
env = gym.make(env_name)
obs_dim = env.observation_space.shape[0]
acts_dim = env.action_space.n
print(f"{obs_dim} obs | {acts_dim} acts")

# mlp parameters
hidden_sizes = [32]
sizes = [obs_dim]+hidden_sizes+[acts_dim]

# training parameters
epochs=50
batch_size=5000
lr = 1e-2

4 obs | 2 acts


In [6]:
class MLP(nn.Module):
    def __init__(self, sizes, activation=nn.Tanh, output_activation=nn.Identity):
        super().__init__()
        layers = []
        for j in range(len(sizes)-1):
            act = activation if j < len(sizes)-2 else output_activation
            layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
        self.layers = nn.Sequential(*layers)
        

    def forward(self, x):
        return self.layers(x)


In [7]:

def get_policy(obs):
    logits = model(obs)
    return Categorical(logits=logits)

def get_action(obs): # only one observation as input
    # you can remove .item() to pass batch of obs as input
    return get_policy(obs).sample().item()

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    logp = get_policy(obs).log_prob(act)
    return -(logp * weights).mean()


In [8]:
model = MLP(sizes)

optimizer = Adam(model.parameters(), lr=lr)

# training loop
for i in range(epochs):
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs, _ = env.reset()    # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # collect experience by acting in the environment with current policy
    while True:
        batch_obs.append(obs)
        
        act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done, _, _ = env.step(act)
        
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)
            
            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len
            
            obs, _ = env.reset()
            done = False
            ep_rews = []
            
            if len(batch_obs) > batch_size:
                break
                
    # take a single policy gradient update step using the experience gained
    optimizer.zero_grad()
    batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),
                              act=torch.as_tensor(batch_acts, dtype=torch.int32),
                              weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                              )
    batch_loss.backward()
    optimizer.step()
    
    print("epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f"%
            (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))


  batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),


epoch:   0 	 loss: 16.430 	 return: 19.913 	 ep_len: 19.913
epoch:   1 	 loss: 19.935 	 return: 22.734 	 ep_len: 22.734
epoch:   2 	 loss: 22.266 	 return: 25.646 	 ep_len: 25.646
epoch:   3 	 loss: 27.589 	 return: 28.432 	 ep_len: 28.432
epoch:   4 	 loss: 31.144 	 return: 32.732 	 ep_len: 32.732
epoch:   5 	 loss: 32.936 	 return: 34.628 	 ep_len: 34.628
epoch:   6 	 loss: 37.001 	 return: 40.320 	 ep_len: 40.320
epoch:   7 	 loss: 44.630 	 return: 47.886 	 ep_len: 47.886
epoch:   8 	 loss: 53.061 	 return: 57.575 	 ep_len: 57.575
epoch:   9 	 loss: 51.824 	 return: 59.282 	 ep_len: 59.282
epoch:  10 	 loss: 60.625 	 return: 65.961 	 ep_len: 65.961
epoch:  11 	 loss: 60.103 	 return: 71.014 	 ep_len: 71.014
epoch:  12 	 loss: 56.729 	 return: 70.493 	 ep_len: 70.493
epoch:  13 	 loss: 65.323 	 return: 78.250 	 ep_len: 78.250
epoch:  14 	 loss: 67.315 	 return: 84.867 	 ep_len: 84.867
epoch:  15 	 loss: 75.611 	 return: 104.367 	 ep_len: 104.367
epoch:  16 	 loss: 91.308 	 return: 11

In [9]:
PATH = f"{env_name}_{ep_len}"
torch.save(model, PATH)