In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [19]:
class CartPoleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        out = self.model(x)
        return out
    
    def get_action(self, x):
        out = self(x)
        return torch.argmax(out, dim=1)

model = CartPoleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [20]:
from tqdm import tqdm

def run_episode(env_name="CartPole-v1"):
    env = gym.make(env_name)  # Remove render_mode for training
    episode = torch.tensor([])
    observation, info = env.reset()
    terminated = False
    truncated = False
    total_reward = 0
    
    while not terminated and not truncated:
        obs_input = torch.tensor(observation, dtype=torch.float32)[None, :]
        action = model.get_action(obs_input)
        out = model(obs_input)
        log_prob = torch.log(out[:, action])
        observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
        total_reward += reward
        episode = torch.cat([episode, log_prob])
    
    rewards_to_go = torch.arange(len(episode), 0, -1)
    episode_loss = (rewards_to_go * episode).sum() # Negative since we want to maximize reward
    env.close()
    return (episode_loss, total_reward)

# Run episodes sequentially instead of in parallel
for episode_num in tqdm(range(1000)):
    episode_losses = torch.zeros(1000)
    total_rewards = torch.zeros(1000)
    for i in range(1000):  # Still run 10 episodes per batch
        loss, total_reward = run_episode()
        episode_losses[i] = loss
        total_rewards[i] = total_reward
    
    # Sum up losses from all episodes
    total_loss = episode_losses.mean()
    total_reward = total_rewards.mean()
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # Keep track of last 5 episodes in output
    if episode_num >= 5:
        print("\033[F\033[K", end="")  # Clear the first line
        print("\033[F\033[K", end="")  # Clear the second line
    print(f"Episode {episode_num} loss: {total_loss.item()}")
    print(f"Episode {episode_num} reward: {total_rewards.mean().item()}")
    print(f"Episode {episode_num} avg reward over last 100: {total_rewards[-100:].mean().item()}")

Episode 0 loss: -250.66653442382812
Episode 0 reward: 9.380000114440918
Episode 1 loss: -263.9714660644531
Episode 1 reward: 9.388999938964844
Episode 2 loss: -272.09307861328125
Episode 2 reward: 9.32800006866455
Episode 3 loss: -552.5525512695312
Episode 3 reward: 9.99899959564209
Episode 4 loss: -16533.763671875
Episode 4 reward: 31.892000198364258
Episode 5 loss: -5763.046875
Episode 5 reward: 24.17099952697754
Episode 6 loss: -3034.561279296875
Episode 6 reward: 18.097999572753906
Episode 7 loss: -373.2899475097656
Episode 7 reward: 9.531999588012695
Episode 8 loss: -299.8049011230469
Episode 8 reward: 9.35099983215332
Episode 9 loss: -299.3340759277344
Episode 9 reward: 9.371000289916992
Episode 10 loss: -297.77880859375
Episode 10 reward: 9.362000465393066
Episode 11 loss: -301.0747375488281
Episode 11 reward: 9.369000434875488
Episode 12 loss: -303.4683837890625
Episode 12 reward: 9.359999656677246
Episode 13 loss: -307.1150817871094
Episode 13 reward: 9.348999977111816
Episode

KeyboardInterrupt: 

: 