In [2]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.distributions.categorical import Categorical


In [3]:
torch.manual_seed(42)  # seed torch's global RNG (network init and Categorical sampling)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device

device(type='cpu')

In [16]:
# --- environment ---
env_name = "CartPole-v1"
env = gym.make(env_name)
obs_dim = env.observation_space.shape[0]  # length of the observation vector
acts_dim = env.action_space.n             # number of discrete actions
print(f"{obs_dim} obs | {acts_dim} acts")

# --- policy network: obs_dim -> hidden layers -> one logit per action ---
hidden_sizes = [32]
sizes = [obs_dim, *hidden_sizes, acts_dim]

# --- training ---
epochs = 200       # number of policy-gradient updates
batch_size = 5000  # collection stops (at an episode boundary) once more steps than this are gathered
lr = 1e-2          # Adam learning rate

4 obs | 2 acts


In [6]:
class MLP(nn.Module):
    """Fully connected network built from a list of layer widths.

    Each consecutive pair in `sizes` becomes an `nn.Linear`. Every Linear except
    the last is followed by `activation()`. The last one is followed by
    `output_activation()`.

    Args:
        sizes: layer widths, input first and output last (e.g. [obs_dim, 32, acts_dim]).
        activation: module class instantiated after each hidden Linear layer.
        output_activation: module class instantiated after the final Linear layer.
    """

    def __init__(self, sizes, activation=nn.Tanh, output_activation=nn.Identity):
        super().__init__()
        last_idx = len(sizes) - 2  # index of the final Linear layer
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            nonlinearity = output_activation if idx == last_idx else activation
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nonlinearity())
        # keep the attribute name `layers`: state_dict keys (layers.0.weight, ...) depend on it
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run `x` through the stacked Linear/activation layers."""
        return self.layers(x)


In [7]:

def get_policy(obs):
    """Return the categorical action distribution for `obs` under the global `model`.

    The network output is treated as unnormalised log-probabilities (logits).
    """
    return Categorical(logits=model(obs))

def get_action(obs):  # expects a single observation
    """Sample one action from the current policy and return it as a Python number.

    `.item()` only works on a one-element tensor, so pass a single observation;
    drop `.item()` to sample actions for a batch of observations instead.
    """
    distribution = get_policy(obs)
    sampled = distribution.sample()
    return sampled.item()

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    """
    Surrogate objective whose gradient, for on-policy data, is the policy gradient.

    Despite the name, this is not a loss in the supervised-learning sense:
    1. The data distribution depends on the parameters being optimised.
    2. Its value does not measure performance; only its gradient is meaningful.

    Args:
        obs: batch of observations fed to the policy.
        act: action taken for each observation.
        weights: per-sample weight applied to each log-probability.

    Returns:
        Scalar tensor equal to -mean(log pi(act | obs) * weights).
    """
    log_probs = get_policy(obs).log_prob(act)
    weighted = log_probs * weights
    return -weighted.mean()


# Train

In [17]:
# Vanilla policy gradient: each epoch collects a batch of complete on-policy
# episodes, then takes a single gradient step on the R(tau)-weighted surrogate loss.
model = MLP(sizes)

optimizer = Adam(model.parameters(), lr=lr)

# Seed the environment's RNG once (same seed as torch.manual_seed above) so rollouts
# are reproducible under Restart & Run All; later env.reset() calls continue this stream.
env.reset(seed=42)

# training loop
for i in range(epochs):
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs, _ = env.reset()    # first obs comes from starting distribution
    ep_rews = []            # list for rewards accrued throughout ep

    # collect experience by acting in the environment with current policy
    while True:
        batch_obs.append(obs)

        # Sampling needs no autograd graph; gradients come from compute_loss below.
        with torch.no_grad():
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, terminated, truncated, _ = env.step(act)
        # Gymnasium reports "episode over" as two flags: terminated (the task ended)
        # and truncated (the env's time limit). Both must end the episode, otherwise
        # episodes keep running past the time limit.
        done = terminated or truncated

        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            obs, _ = env.reset()
            ep_rews = []

            # only stop at an episode boundary, so every sample has a full-episode R(tau)
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step using the experience gained
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # stack into one ndarray first: building a tensor from a list of ndarrays is slow
        obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
        act=torch.as_tensor(batch_acts, dtype=torch.int64),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()

    print("epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f" %
          (i, batch_loss.item(), np.mean(batch_rets), np.mean(batch_lens)))

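# Why does the "loss" increase? This is expected. With R(tau) as the weight, the loss is
# -mean(log pi(a|s) * R(tau)). Since -log pi(a|s) >= 0 and R(tau) > 0 here, its value roughly
# scales with the episode returns, so it grows as the policy improves. It is not a
# performance metric; track the average return instead.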

epoch:   0 	 loss: 15.016 	 return: 17.847 	 ep_len: 17.847
epoch:   1 	 loss: 16.153 	 return: 19.023 	 ep_len: 19.023
epoch:   2 	 loss: 18.478 	 return: 21.778 	 ep_len: 21.778
epoch:   3 	 loss: 20.816 	 return: 23.787 	 ep_len: 23.787
epoch:   4 	 loss: 23.985 	 return: 27.321 	 ep_len: 27.321
epoch:   5 	 loss: 30.422 	 return: 30.951 	 ep_len: 30.951
epoch:   6 	 loss: 30.160 	 return: 35.582 	 ep_len: 35.582
epoch:   7 	 loss: 31.069 	 return: 37.193 	 ep_len: 37.193
epoch:   8 	 loss: 28.830 	 return: 36.577 	 ep_len: 36.577
epoch:   9 	 loss: 32.102 	 return: 40.724 	 ep_len: 40.724
epoch:  10 	 loss: 34.394 	 return: 43.138 	 ep_len: 43.138
epoch:  11 	 loss: 42.489 	 return: 52.896 	 ep_len: 52.896
epoch:  12 	 loss: 34.518 	 return: 46.229 	 ep_len: 46.229
epoch:  13 	 loss: 35.304 	 return: 48.212 	 ep_len: 48.212
epoch:  14 	 loss: 37.580 	 return: 48.462 	 ep_len: 48.462
epoch:  15 	 loss: 38.751 	 return: 52.135 	 ep_len: 52.135
epoch:  16 	 loss: 39.578 	 return: 54.2

In [19]:
# Tag the checkpoint with the mean episode length of the final training batch,
# a steadier summary than the length of whichever episode happened to finish last.
# torch.save(model, ...) pickles the whole module, so loading it requires the MLP class
# to be defined. NOTE(review): recent torch.load versions default to weights_only=True,
# which rejects full-module pickles; pass weights_only=False when loading (trusted files only).
PATH = f"{env_name}_{int(round(np.mean(batch_lens)))}"
torch.save(model, PATH)

# Test

In [20]:
# model_name = "CartPole-v1_1276"
# model = torch.load(f"models\\{model_name}")

# Watch the trained policy play a few episodes in a rendered window.
# A separate name keeps the training `env` intact, so re-running the training cell
# after this one does not pick up a closed, human-rendered environment.
render_env = gym.make(env_name, render_mode="human")
num_episodes = 10

try:
    for e in range(num_episodes):
        state, _ = render_env.reset()
        done = False
        score = 0

        while not done:
            # Act on the current state; `state` is updated after every step.
            with torch.no_grad():
                action = get_action(torch.as_tensor(state, dtype=torch.float32))
            # With render_mode="human", Gymnasium draws each frame inside step(),
            # so no explicit render() call is needed.
            state, reward, terminated, truncated, _ = render_env.step(action)
            # End on termination or on the time limit; ignoring `truncated` lets
            # a good policy run past the time limit.
            done = terminated or truncated
            score += reward

        print(f"Episode {e}, score {score}")
finally:
    render_env.close()