In [75]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.distributions.categorical import Categorical


In [76]:
# Seed PyTorch's global RNG so weight initialisation and action sampling repeat
# across runs of this notebook.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [98]:
# environment parameters
env_name = "CartPole-v1"
env = gym.make(env_name)
obs_dim = env.observation_space.shape[0]   # length of one observation vector
acts_dim = env.action_space.n              # number of discrete actions
print(f"{obs_dim} obs | {acts_dim} acts")

# mlp parameters
hidden_sizes = [32]
# Layer widths from input to output. The output width must be the action count
# bound above as `acts_dim`; the previous `n_acts` was never defined in this
# notebook, so the cell raised NameError on a fresh kernel.
sizes = [obs_dim] + hidden_sizes + [acts_dim]

# training parameters
epochs = 50          # number of policy-gradient updates
batch_size = 5000    # minimum number of environment steps collected per update
lr = 1e-2            # Adam learning rate

4 obs | 2 acts


In [78]:
class MLP(nn.Module):
    """Fully connected feed-forward network.

    Args:
        sizes: sequence of layer widths, input first and output last. Each
            adjacent pair becomes one ``nn.Linear``.
        activation: module class instantiated after every linear layer except
            the last one.
        output_activation: module class instantiated after the last linear
            layer.
    """

    def __init__(self, sizes, activation=nn.Tanh, output_activation=nn.Identity):
        super().__init__()
        final_idx = len(sizes) - 2  # index of the last (output) linear layer
        modules = []
        for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            nonlinearity = output_activation if idx == final_idx else activation
            modules.append(nn.Linear(n_in, n_out))
            modules.append(nonlinearity())
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the result."""
        return self.layers(x)


In [96]:

def get_policy(obs):
    """Return the action distribution the global ``model`` assigns to ``obs``.

    The model output is interpreted as unnormalised logits of a
    ``Categorical`` distribution.
    """
    return Categorical(logits=model(obs))

def get_action(obs): # only one observation as input
    """Sample one action from the current policy and return it as a Python number."""
    policy = get_policy(obs)
    sampled = policy.sample()
    # .item() only works for a single sample; drop it to act on a batch of obs
    return sampled.item()

# surrogate loss: its gradient, for the right data, is the policy gradient
def compute_loss(obs, act, weights):
    """Return the negated mean of ``weights``-scaled log-probabilities of ``act``.

    Args:
        obs: observations fed to the policy.
        act: actions whose log-probability under the policy is evaluated.
        weights: per-sample multipliers applied to each log-probability.
    """
    policy = get_policy(obs)
    weighted_logp = policy.log_prob(act) * weights
    return -weighted_logp.mean()


In [99]:
model = MLP(sizes)

optimizer = Adam(model.parameters(), lr=lr)

# Training loop: each epoch collects at least `batch_size` environment steps
# with the current policy, then takes one gradient step on the surrogate loss.
# Kept as a flat script on purpose: the save cell below reads the `ep_len`
# this loop leaves behind in the notebook namespace.
for i in range(epochs):
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs, _ = env.reset()    # first obs comes from starting distribution
    ep_rews = []            # list for rewards accrued throughout ep

    # collect experience by acting in the environment with current policy
    while True:
        batch_obs.append(obs)

        # Sampling needs no autograd graph; log-probs are recomputed with
        # gradients inside compute_loss.
        with torch.no_grad():
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))

        # gymnasium's step() returns (obs, reward, terminated, truncated, info).
        # An episode is over when EITHER flag is set; reading only the first
        # one lets episodes run on past a time-limit truncation.
        obs, rew, terminated, truncated, _ = env.step(act)
        done = terminated or truncated

        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # start the next episode
            obs, _ = env.reset()
            ep_rews = []

            # only stop on an episode boundary so every step has a weight
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step using the experience gained
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # stack the per-step arrays first: converting a Python list of
        # ndarrays straight to a tensor is very slow
        obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
        act=torch.as_tensor(batch_acts, dtype=torch.int32),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()

    print("epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f"%
            (i, batch_loss.item(), np.mean(batch_rets), np.mean(batch_lens)))


epoch:   0 	 loss: 22.116 	 return: 24.539 	 ep_len: 24.539
epoch:   1 	 loss: 28.380 	 return: 29.263 	 ep_len: 29.263
epoch:   2 	 loss: 29.597 	 return: 31.854 	 ep_len: 31.854
epoch:   3 	 loss: 35.768 	 return: 37.463 	 ep_len: 37.463
epoch:   4 	 loss: 38.804 	 return: 41.372 	 ep_len: 41.372
epoch:   5 	 loss: 41.675 	 return: 49.176 	 ep_len: 49.176
epoch:   6 	 loss: 39.246 	 return: 48.385 	 ep_len: 48.385
epoch:   7 	 loss: 47.756 	 return: 61.084 	 ep_len: 61.084
epoch:   8 	 loss: 45.948 	 return: 59.094 	 ep_len: 59.094
epoch:   9 	 loss: 56.869 	 return: 68.446 	 ep_len: 68.446
epoch:  10 	 loss: 48.076 	 return: 63.899 	 ep_len: 63.899
epoch:  11 	 loss: 52.226 	 return: 68.284 	 ep_len: 68.284
epoch:  12 	 loss: 45.696 	 return: 67.716 	 ep_len: 67.716
epoch:  13 	 loss: 50.356 	 return: 72.768 	 ep_len: 72.768
epoch:  14 	 loss: 52.602 	 return: 72.657 	 ep_len: 72.657
epoch:  15 	 loss: 59.273 	 return: 80.871 	 ep_len: 80.871
epoch:  16 	 loss: 66.817 	 return: 93.7

In [None]:
# Save the trained policy network to disk.
# The file name embeds `ep_len`, which is whatever value the training cell last
# assigned (the step count of the final episode it finished), so this cell only
# works after the training cell has run in the same kernel.
PATH = f"{env_name}_{ep_len}"
# The whole `model` object is passed here, not `model.state_dict()`.
# NOTE(review): loading this file back presumably requires the MLP class to be
# defined/importable at load time — confirm before sharing the checkpoint.
torch.save(model, PATH)
