In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions

import collections
import random
import matplotlib.pyplot as plt
import numpy as np
import gymnasium as gym
import tqdm

In [21]:
train_env = gym.make('CartPole-v1')
test_env = gym.make('CartPole-v1')

In [22]:
# SEED = 333

# train_env.seed(SEED);
# test_env.seed(SEED+1);
# np.random.seed(SEED);
# random.seed(SEED)
# torch.manual_seed(SEED);

In [23]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc_1(x)
        x = F.relu(x)
        x = self.fc_2(x)
        return x

In [24]:
input_dim = train_env.observation_space.shape[0]
hidden_dim = 32
output_dim = train_env.action_space.n

In [25]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

In [35]:
def train(env, policy, optimizer, discount_factor, epsilon, device):
    policy.train()
    
    episode_reward = 0.0
    done = False

    # --- Handle both Gym and Gymnasium reset APIs ---
    reset_output = env.reset()
    if isinstance(reset_output, tuple):
        state, _ = reset_output  # Gymnasium returns (obs, info)
    else:
        state = reset_output     # Old Gym returns obs only

    # --- Ensure proper tensor shape ---
    state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    while not done:
        # --- Epsilon-greedy action selection ---
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_pred = policy(state)
                action = torch.argmax(q_pred, dim=1).item()

        # --- Take step in environment ---
        step_output = env.step(action)
        if len(step_output) == 5:
            next_state, reward, terminated, truncated, _ = step_output
            done = terminated or truncated
        else:
            next_state, reward, done, _ = step_output

        next_state = torch.as_tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

        # --- Update policy ---
        loss = update_policy(policy, state, action, reward, next_state, discount_factor, optimizer)

        # --- Transition ---
        state = next_state
        episode_reward += reward

    return loss, episode_reward, epsilon

In [36]:
def update_policy(policy, state, action, reward, next_state, discount_factor, optimizer):
    
    q_preds = policy(state)

    q_vals = q_preds[:, action]

    with torch.no_grad():
        q_next_preds = policy(next_state)
        q_next_vals = q_next_preds.max(1).values
        targets = reward + q_next_vals * discount_factor

    loss = F.smooth_l1_loss(targets.detach(), q_vals)
    
    optimizer.zero_grad()
    
    loss.backward()

    nn.utils.clip_grad_norm_(policy.parameters(), 0.5)

    optimizer.step()
    
    return loss.item()

In [41]:
def evaluate(env, policy, device):
    
    policy.eval()
    
    done = False
    episode_reward = 0

    state, _ = env.reset()

    while not done:
        
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        
        with torch.no_grad():
        
            q_pred = policy(state)
            
            action = torch.argmax(q_pred).item()

        # state, reward, done, _ = env.step(action)
        state, reward, done, truncated, info = env.step(action)
        
        episode_reward += reward

    return episode_reward

In [42]:
n_runs = 10
n_episodes = 500
discount_factor = 0.8
start_epsilon = 1.0
end_epsilon = 0.01
epsilon_decay = 0.995

train_rewards = torch.zeros(n_runs, n_episodes)
test_rewards = torch.zeros(n_runs, n_episodes)
device = torch.device('cpu')

for run in range(n_runs):
    
    policy = MLP(input_dim, hidden_dim, output_dim)
    policy = policy.to(device)
    policy.apply(init_weights)
    epsilon = start_epsilon

    optimizer = optim.RMSprop(policy.parameters(), lr=1e-6)

    for episode in tqdm.tqdm(range(n_episodes), desc=f'Run: {run}'):

        loss, train_reward, epsilon = train(train_env, policy, optimizer, discount_factor, epsilon, device)

        epsilon *= epsilon_decay
        epsilon = min(epsilon, end_epsilon)

        test_reward = evaluate(test_env, policy, device)
        
        train_rewards[run][episode] = train_reward
        test_rewards[run][episode] = test_reward

Run: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:18<00:00, 27.39it/s]
Run: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:07<00:00, 68.90it/s]
Run: 2:   1%|▊                                                                                                                                      | 3/500 [00:11<32:02,  3.87s/it]


KeyboardInterrupt: 

In [None]:
idxs = range(n_episodes)
fig, ax = plt.subplots(n_runs, figsize=(10,6))
for i, _ax in enumerate(ax):
    _ax.plot(idxs, train_rewards[i], c='red')
    _ax.plot(idxs, test_rewards[i], c='blue')
    _ax.set_ylim(0, 550)
    _ax.set_ylabel('Rewards');
    if i == n_runs - 1:
        _ax.set_xlabel('Episodes')

In [None]:
idxs = range(n_episodes)
fig, ax = plt.subplots(n_runs, figsize=(10,6))
for i, _ax in enumerate(ax):
    _ax.plot(idxs, test_rewards[i])
    _ax.set_ylim(0, 550)
    _ax.set_ylabel('Rewards');
    if i == n_runs - 1:
        _ax.set_xlabel('Episodes')

In [None]:
idxs = range(n_episodes)
fig, ax = plt.subplots(n_runs, figsize=(10,6))
for i, _ax in enumerate(ax):
    _ax.plot(idxs, train_rewards[i])
    _ax.set_ylim(0, 550)
    _ax.set_ylabel('Rewards');
    if i == n_runs - 1:
        _ax.set_xlabel('Episodes')

In [None]:
idxs = range(n_episodes)
fig, ax = plt.subplots(1, figsize=(10,6))
ax.plot(idxs, test_rewards.mean(0))
ax.set_ylim(0, 550)
ax.fill_between(idxs, test_rewards.min(0).values, test_rewards.max(0).values, alpha=0.1)
ax.set_xlabel('Episodes')
ax.set_ylabel('Rewards');

In [None]:
idxs = range(n_episodes)
fig, ax = plt.subplots(1, figsize=(10,6))
ax.plot(idxs, train_rewards.mean(0))
ax.set_ylim(0, 550)
ax.fill_between(idxs, train_rewards.min(0).values, train_rewards.max(0).values, alpha=0.1)
ax.set_xlabel('Episodes')
ax.set_ylabel('Rewards');

In [None]:
n = 500
start_x = 1.0
decay = 0.995
print(decay)
min_x = 0.01
x = start_x
ys = []
for i in range(n):
    x *= decay
    x = max(x, min_x)
    ys.append(x)
plt.plot(ys)
plt.ylim(0,1.1)

In [None]:
np.e

In [None]:
q = collections.deque(maxlen=5)

In [None]:
q

In [None]:
len(q)

In [None]:
q.append(1)

In [None]:
20_000/500

In [None]:
q.maxlen

In [None]:
q.append(1)
q.append(1)
q.append(1)
q.append(1)

In [None]:
len(q)

In [None]:
q.append(1)
q.append(1)
q.append(1)
q.append(1)

In [None]:
q

In [None]:
len(q)