Imports


In [33]:
import math
import random
from collections import namedtuple, deque
from itertools import count

import gymnasium as gym
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display, clear_output

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Plotting and GPU


In [34]:
# Put matplotlib into interactive mode so figures can refresh while cells run.
plt.ion()

# Choose the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Sanity check: move a throwaway tensor to the chosen device and report
# which backend it actually ended up on.
test_tensor = torch.rand(1).to(device)
print(
    "✅ GPU is being used for PyTorch computations!"
    if test_tensor.device.type == "cuda"
    else "✅ MPS (Apple Silicon GPU) is being used!"
    if test_tensor.device.type == "mps"
    else "❌ GPU is NOT being used. PyTorch is running on CPU."
)


Hyperparameters


In [35]:
BATCH_SIZE = 128          # Number of transitions sampled from replay memory per optimisation step
GAMMA = 0.99              # Discount factor applied to the bootstrapped next-state value
EPS_START = 0.9           # Exploration probability (epsilon) when steps_done == 0
EPS_END = 0.05            # Value epsilon decays towards as steps_done grows
EPS_DECAY = 1000          # Time constant, in action-selection steps, of the exponential epsilon decay
TAU = 0.005               # Soft update rate: target <- TAU * policy + (1 - TAU) * target
LR = 1e-4                 # Learning rate for the AdamW optimizer
NUM_EPISODES = 500        # Number of training episodes to run
MAX_STEPS_PER_EPISODE = 1000  # Per-episode step cap (passed to gym.make as max_episode_steps)


Environment creation

In [36]:
# Build the training environment; episodes are cut off after MAX_STEPS_PER_EPISODE steps.
env = gym.make("CartPole-v1", max_episode_steps=MAX_STEPS_PER_EPISODE)

# A first reset provides a sample observation whose length is the network input size.
initial_state, info = env.reset()
n_observations = len(initial_state)
n_actions = env.action_space.n

print(
    f"State dimension: {n_observations}, Action dimension: {n_actions}",
    f"Using device: {device}",
    "Ready to train with DQN on CartPole.\n",
    sep="\n",
)


State dimension: 4, Action dimension: 2
Using device: cuda
Ready to train with DQN on CartPole.



Replay memory

In [37]:
# %% [code]
# One stored experience. next_state may be None for a transition that ends an episode.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])


class ReplayMemory:
    """Bounded FIFO buffer of Transition tuples with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen silently discards the oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

    def push(self, *args):
        """Pack the positional arguments into a Transition and append it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, k=batch_size)


memory = ReplayMemory(50_000)  # Module-level buffer used by optimize_model and the training loop


DQN and network initialisation

In [38]:
# %% [code]
class DQN(nn.Module):
    """Fully connected Q-network: two 128-unit ReLU layers and a linear output per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 128
        # Layer names and creation order are kept so state_dict keys and
        # seeded initialisation stay the same.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)

    def forward(self, x):
        """Map a batch of states to one value per action (no activation on the output)."""
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return self.fc3(x)

# Instantiate networks
# policy_net is the network the optimizer trains. target_net is a second DQN that
# starts with a copy of policy_net's weights and is put into eval mode; it is used
# in optimize_model to compute next-state values and is soft-updated there.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only policy_net's parameters are handed to the optimizer.
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# Global counter incremented on every select_action call; drives the epsilon schedule.
steps_done = 0


Epsilon-greedy action selection and optimisation function

In [39]:
# %% [code]
def select_action(state):
    """
    Choose an action for ``state`` using an epsilon-greedy rule.

    The exploration threshold is
    ``EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)``,
    so it starts at EPS_START and decays exponentially towards EPS_END.
    Each call increments the global ``steps_done`` counter.

    Returns a tensor of shape (1, 1) and dtype long holding the action index:
    the policy network's highest-valued action when exploiting, or a uniformly
    random action when exploring.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: pick any action with equal probability.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

    # Exploit: take the action the policy network currently values most.
    with torch.no_grad():
        return policy_net(state).max(1).indices.view(1, 1)

def optimize_model():
    """
    Run one optimisation step of ``policy_net`` on a batch from replay memory.

    Does nothing until ``memory`` holds at least BATCH_SIZE transitions.
    Otherwise it samples BATCH_SIZE transitions and computes the Huber
    (smooth L1) loss between ``policy_net``'s Q(s, a) and the target
    ``reward + GAMMA * max_a' target_net(s')[a']``. The bootstrap term is zero
    for transitions whose ``next_state`` is None. It then takes one optimizer
    step with element-wise gradient clipping at 100 and soft-updates
    ``target_net`` towards ``policy_net`` with rate TAU.

    Returns None.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose: a list of Transitions becomes one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    # Mask of transitions that have a successor state, plus those successor states.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(state_batch).gather(1, action_batch)

    # Bootstrapped next-state values from the target net; entries without a
    # successor state stay at zero.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat raises on an empty list, so skip the target-net pass when every
    # sampled transition has next_state None.
    if non_final_next_states:
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next_states))
            next_state_values[non_final_mask] = next_q.max(1)[0]
    expected_q_values = reward_batch + (GAMMA * next_state_values)

    # Huber loss; same defaults as nn.SmoothL1Loss() without building a module per call.
    loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    # Soft update of target network: target <- TAU * policy + (1 - TAU) * target
    target_state_dict = target_net.state_dict()
    policy_state_dict = policy_net.state_dict()
    for key in policy_state_dict:
        target_state_dict[key] = TAU * policy_state_dict[key] + (1.0 - TAU) * target_state_dict[key]
    target_net.load_state_dict(target_state_dict)


Training loop

In [40]:
# %% [code]
episode_durations = []  # Steps survived per finished episode; appended by the training loop

def plot_durations(show_result=False):
    """
    Plot episode durations and, once 100 episodes exist, their 100-episode moving average.

    Args:
        show_result: False while training (title "Training...", output is
            cleared on the next update so the plot refreshes in place);
            True for the final plot (title "Result", output is left in place).
    """
    # Keep a handle on the figure we draw into.
    fig = plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.clf()
    plt.title("Training..." if not show_result else "Result")
    plt.xlabel("Episode")
    plt.ylabel("Duration (Steps)")
    plt.plot(durations_t.numpy())
    if len(durations_t) >= 100:
        # Sliding windows of 100 episodes, padded with 99 zeros so the curve
        # lines up with the episode index on the x axis.
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy(), label="100-episode average")
        plt.legend()
    plt.pause(0.001)
    # Display the captured figure rather than plt.gcf(): if the backend closes
    # figure 1 while handling pause() (seen with the notebook inline backend),
    # gcf() would create and show a new, empty figure instead.
    display(fig)
    if not show_result:
        # Only training-time plots are transient; clearing after the final call
        # would wipe the result as soon as anything else is printed.
        clear_output(wait=True)

# -------- More episodes to see a longer training simulation -------- #
# Main DQN training loop: act epsilon-greedily, store each transition in replay
# memory, and run one optimisation step per environment step.
for i_episode in range(1, NUM_EPISODES + 1):
    obs, info = env.reset()
    # Convert the observation array directly and add a batch dimension. Wrapping
    # it in a list first makes torch take the slow list-of-ndarrays path.
    state = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    total_reward = 0.0

    for t in range(MAX_STEPS_PER_EPISODE):
        # Epsilon-greedy action
        action_tensor = select_action(state)
        action_int = action_tensor.item()  # 0 or 1
        next_obs, reward, terminated, truncated, info = env.step(action_int)
        
        # Debug prints: action, immediate reward, position/angle
        if i_episode % 50 == 0:  # Print debug info every 50 episodes
            print(f"Episode {i_episode}, Step {t}, Action: {action_int}, Reward: {reward:.1f}, Obs: {next_obs}")

        total_reward += reward
        # Pin the dtype so the reward always matches the float32 Q-values.
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
        
        done = terminated or truncated
        # Only a genuine termination has no successor state. A time-limit
        # truncation still ends the episode, but the state reached is not
        # terminal, so it is stored and the target can bootstrap from it.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(next_obs, dtype=torch.float32, device=device).unsqueeze(0)

        # Store transition
        memory.push(state, action_tensor, next_state, reward_tensor)
        state = next_state

        # Optimize
        optimize_model()

        if done:
            episode_durations.append(t + 1)
            print(f"Episode {i_episode} finished after {t+1} steps. Total reward: {total_reward:.1f}")
            break

    else:
        # Safety net: runs only if the loop used all MAX_STEPS_PER_EPISODE
        # iterations without the environment reporting done.
        episode_durations.append(MAX_STEPS_PER_EPISODE)
        print(f"Episode {i_episode} reached max steps ({MAX_STEPS_PER_EPISODE}). Total reward: {total_reward:.1f}")
    
    # Update plot
    plot_durations()

env.close()

print("Training complete.")
plot_durations(show_result=True)
plt.ioff()
plt.show()


<Figure size 640x480 with 0 Axes>

Test and Visualisation

In [41]:
# %% [code]
# Separate evaluation environment, capped at 2000 steps per episode. It is created
# with render_mode="rgb_array"; the evaluation loop below collects the values
# returned by render() as animation frames.
test_env = gym.make("CartPole-v1", max_episode_steps=2000, render_mode="rgb_array")

def display_animation(frames, interval=50):
    """
    Build an inline JavaScript animation from a sequence of image frames.

    Args:
        frames: Indexable, non-empty sequence of image arrays; the first
            frame's shape sets the figure size.
        interval: Delay between frames, passed to FuncAnimation.

    Returns:
        An IPython ``HTML`` object wrapping the animation's JS/HTML player.
    """
    first_frame = frames[0]
    height, width = first_frame.shape[0], first_frame.shape[1]
    # At 72 dpi one figure inch is 72 pixels, so the figure matches the frame size.
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    plt.axis('off')
    image = plt.imshow(first_frame)

    def animate(i):
        # Swap in the i-th frame and report the artist that changed.
        image.set_data(frames[i])
        return [image]

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=interval)
    # Close the figure so only the HTML player is shown, not a static image as well.
    plt.close(fig)
    return HTML(anim.to_jshtml())

# Test the trained agent in a single long episode: act greedily with policy_net,
# record one rendered frame per step, then show the frames as an animation.
frames = []
total_reward = 0.0

try:
    state, info = test_env.reset()
    # Convert the observation array directly and add a batch dimension. Wrapping
    # it in a list first makes torch take the slow list-of-ndarrays path.
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for step_count in range(2000):  # Up to 2000 steps for a longer test
        with torch.no_grad():
            # Greedy action (no random exploration for testing)
            action = policy_net(state_tensor).max(1)[1].item()
        obs, reward, terminated, truncated, info = test_env.step(action)
        frames.append(test_env.render())  # Save each frame
        total_reward += reward
        done = terminated or truncated
        if done:
            print(f"Test Episode ended after {step_count+1} steps, total reward: {total_reward:.1f}")
            break
        state_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
finally:
    # Release the environment even if the episode raises part-way through.
    test_env.close()

print("Displaying test animation...")
display(display_animation(frames, interval=25))


Test Episode ended after 2000 steps, total reward: 2000.0
Displaying test animation...
