In [1]:
import random
import collections
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# --- DQN hyperparameters ---
GAMMA = 0.99             # discount factor used in the TD target
LR = 1e-3                # Adam learning rate
BATCH_SIZE = 64          # minibatch size per gradient step
BUFFER_SIZE = 50_000     # replay capacity; the oldest transitions are evicted first
MIN_REPLAY_SIZE = 1_000  # random-policy transitions collected before learning starts

# Epsilon-greedy schedule (see EpsilonScheduler):
#   eps(t) = EPS_END + (EPS_START - EPS_END) * exp(-t / EPS_DECAY)
# With EPS_DECAY = 500, eps is within ~0.007 of EPS_END after ~2,500 steps,
# i.e. very early in the MAX_STEPS budget. Raise EPS_DECAY for longer exploration.
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 500

TARGET_UPDATE_FREQ = 1000  # env steps between hard copies of the online net into the target net
TRAIN_FREQ = 4             # env steps between gradient updates
MAX_STEPS = 200_000        # total environment steps of training
EVAL_EPISODES = 10         # greedy episodes run after training
SEED = 0                   # seeds Python, NumPy and torch global RNGs for reproducible runs

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every global RNG the notebook draws from (random.random / torch init / np).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(DEVICE)

cuda


In [3]:
# One environment step as stored in the replay buffer; field order is the
# positional order expected by ReplayBuffer.push.
Transition = collections.namedtuple(
    "Transition", ("state", "action", "reward", "next_state", "done")
)


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random minibatch sampling."""

    def __init__(self, capacity: int):
        # A deque with maxlen drops its oldest element once `capacity` is reached.
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition; positional args follow Transition's field order."""
        self.buffer.append(Transition(*args))

    def sample(self, batch_size: int):
        """Draw `batch_size` distinct transitions and return them as tensors on DEVICE.

        Returns (states, actions, rewards, next_states, dones). Actions, rewards and
        dones get a trailing singleton dimension so they line up with gathered Q-values.
        """
        # zip(*...) transposes a list of Transitions into one Transition of tuples.
        columns = Transition(*zip(*random.sample(self.buffer, batch_size)))

        def to_tensor(values, dtype):
            return torch.as_tensor(values, dtype=dtype, device=DEVICE)

        states = to_tensor(np.array(columns.state), torch.float32)
        next_states = to_tensor(np.array(columns.next_state), torch.float32)
        actions = to_tensor(columns.action, torch.int64).unsqueeze(-1)
        rewards = to_tensor(columns.reward, torch.float32).unsqueeze(-1)
        dones = to_tensor(columns.done, torch.float32).unsqueeze(-1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


In [4]:
@dataclass
class EpsilonScheduler:
    """Exponentially decaying exploration rate for epsilon-greedy action selection.

    value(step) = eps_end + (eps_start - eps_end) * exp(-step / eps_decay)
    """

    eps_start: float = EPS_START  # epsilon at step 0
    eps_end: float = EPS_END      # asymptotic floor
    eps_decay: int = EPS_DECAY    # steps for the excess over eps_end to shrink by a factor e

    def value(self, step: int) -> float:
        """Return epsilon for the given (0-based) step count."""
        span = self.eps_start - self.eps_end
        return self.eps_end + span * np.exp(-step / self.eps_decay)


In [19]:
class DQN(nn.Module):
    """MLP mapping an observation vector to one Q-value per discrete action.

    Architecture: obs_dim -> hidden_dim -> hidden_dim -> n_actions, with ReLU
    between the linear layers. Layers live in ``self.net`` (an ``nn.Sequential``),
    so state-dict keys are ``net.0.*``, ``net.2.*`` and ``net.4.*`` regardless of
    ``hidden_dim`` (weights can be copied between instances of the same size).

    Args:
        obs_dim: Length of the flat observation vector.
        n_actions: Number of actions, i.e. the output width.
        hidden_dim: Width of both hidden layers (default 128, the original size).
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (..., n_actions) for input of shape (..., obs_dim)."""
        return self.net(x)


In [20]:
env = gym.make("CartPole-v1")
# The observation space is fixed at construction, so it can be inspected before reset().
print(env.observation_space.shape)

obs, _ = env.reset()  # reset() returns (observation, info); info is unused here
obs

(4,)


array([-0.0086954 ,  0.01555407,  0.02404243, -0.01035482], dtype=float32)

In [8]:
# Number of actions; train_dqn uses this as the DQN output width.
env.action_space.n

2

In [13]:
# Toy batch: 4 chosen actions as a (4, 1) int64 column, matching the shape of
# `actions` returned by ReplayBuffer.sample, plus random Q-values for 2 actions.
actions = torch.tensor([1, 0, 0, 1]).unsqueeze(-1)
qs = torch.rand(4, 2)

In [14]:
# gather(1, actions) picks, per row, the Q-value in the column given by `actions`.
# This is the same op the training loop uses to read Q(s, a) for the taken actions.
qs.gather(1, actions)

tensor([[0.0436],
        [0.3663],
        [0.6396],
        [0.1610]])

In [15]:
# Full Q-value table, for comparison with the gathered values from the previous cell.
qs

tensor([[0.3254, 0.0436],
        [0.3663, 0.1533],
        [0.6396, 0.0343],
        [0.6579, 0.1610]])

In [18]:
next_qs = torch.rand(4, 2)
# Row-wise max over actions returns (values, indices); the DQN target takes
# element [0], i.e. the values. sep="\n" reproduces two separate print calls.
print(next_qs, next_qs.max(dim=1), sep="\n")

tensor([[0.0375, 0.3798],
        [0.5424, 0.4772],
        [0.3766, 0.0095],
        [0.5031, 0.3305]])
torch.return_types.max(
values=tensor([0.3798, 0.5424, 0.3766, 0.5031]),
indices=tensor([1, 0, 0, 0]))


In [None]:
def _greedy_action(q_net: nn.Module, obs) -> int:
    """Return the index of the highest Q-value for a single (unbatched) observation."""
    with torch.no_grad():
        # Add a leading batch dimension of 1, then argmax over the action dimension.
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        return int(torch.argmax(q_net(obs_tensor), dim=1).item())


def _fill_replay_buffer(env, replay_buffer, seed=None) -> None:
    """Collect uniform-random-policy transitions until MIN_REPLAY_SIZE are stored."""
    obs, _ = env.reset(seed=seed)
    while len(replay_buffer) < MIN_REPLAY_SIZE:
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # Store `terminated`, not `terminated or truncated`: a time-limit cut-off is
        # not a true terminal state, so its TD target must still bootstrap.
        replay_buffer.push(obs, action, reward, next_obs, terminated)
        obs = next_obs
        if terminated or truncated:
            obs, _ = env.reset()


def _optimize_step(q_net, target_net, optimizer, replay_buffer) -> None:
    """Take one gradient step on the one-step TD loss for a uniformly sampled minibatch."""
    states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

    # Q(s, a) for the actions actually taken: gather picks one column per row -> (B, 1).
    q_values = q_net(states).gather(1, actions)

    # TD target from the frozen target network; (1 - dones) removes the bootstrap
    # term only for transitions that were stored as terminal.
    with torch.no_grad():
        next_q_values = target_net(next_states).max(dim=1, keepdim=True)[0]
        targets = rewards + GAMMA * (1 - dones) * next_q_values

    loss = nn.functional.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
    optimizer.step()


def _evaluate(q_net, eval_env, n_episodes: int, seed=None) -> list:
    """Run `n_episodes` greedy episodes, printing and returning their undiscounted returns."""
    returns = []
    for ep in range(n_episodes):
        # Only the first reset is seeded; later resets pass seed=None.
        obs, _ = eval_env.reset(seed=seed if ep == 0 else None)
        done = False
        ep_ret = 0.0
        while not done:
            action = _greedy_action(q_net, obs)
            obs, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated
            ep_ret += reward
        returns.append(ep_ret)
        print(f"Eval episode {ep+1}: return = {ep_ret}")
    return returns


def train_dqn(env_id: str = "CartPole-v1", seed=None):
    """Train a DQN agent on a Gymnasium environment, then evaluate it greedily.

    Uses the hyperparameters (GAMMA, LR, BATCH_SIZE, ...) and the DQN,
    ReplayBuffer and EpsilonScheduler definitions from earlier cells.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Optional int. When given, seeds Python's ``random``, torch, the
            training env (first reset and action space) and the eval env, so runs
            are repeatable (up to nondeterministic GPU kernels). ``None`` leaves
            seeding to whatever global state the caller has set up.

    Returns:
        The trained online Q-network. Callers that ignore the result (the
        function used to return None) are unaffected.
    """
    env = gym.make(env_id)
    eval_env = gym.make(env_id)
    try:
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
            env.action_space.seed(seed)

        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        q_net = DQN(obs_dim, n_actions).to(DEVICE)
        target_net = DQN(obs_dim, n_actions).to(DEVICE)
        target_net.load_state_dict(q_net.state_dict())
        target_net.eval()

        optimizer = optim.Adam(q_net.parameters(), lr=LR)
        replay_buffer = ReplayBuffer(BUFFER_SIZE)
        eps_sched = EpsilonScheduler()

        print("Filling replay buffer with random policy...")
        _fill_replay_buffer(env, replay_buffer, seed)

        print("Starting training...")
        obs, _ = env.reset()
        episode_reward = 0.0
        episode = 0

        total_steps = 0
        while total_steps < MAX_STEPS:
            eps = eps_sched.value(total_steps)
            if random.random() < eps:
                action = env.action_space.sample()
            else:
                action = _greedy_action(q_net, obs)

            next_obs, reward, terminated, truncated, _ = env.step(action)
            # Bootstrap mask uses `terminated` only; `truncated` just ends the episode.
            replay_buffer.push(obs, action, reward, next_obs, terminated)

            obs = next_obs
            episode_reward += reward
            total_steps += 1

            if total_steps % TRAIN_FREQ == 0:
                _optimize_step(q_net, target_net, optimizer, replay_buffer)

            # Hard update: copy the online weights into the target network.
            if total_steps % TARGET_UPDATE_FREQ == 0:
                target_net.load_state_dict(q_net.state_dict())

            # End of episode (either a real terminal state or a time-limit truncation).
            if terminated or truncated:
                episode += 1
                print(f"Episode {episode:4d} | Steps {total_steps:7d} | "
                      f"Return {episode_reward:6.1f} | eps={eps:.3f}")
                obs, _ = env.reset()
                episode_reward = 0.0

        # -----------------------
        # Evaluation
        # -----------------------
        print("\nEvaluating greedy policy...")
        eval_seed = None if seed is None else seed + 1
        returns = _evaluate(q_net, eval_env, EVAL_EPISODES, eval_seed)
        print(f"\nAverage return over {EVAL_EPISODES} eval episodes: {np.mean(returns):.2f}")
        return q_net
    finally:
        # Close both envs even if training is interrupted (e.g. KeyboardInterrupt).
        env.close()
        eval_env.close()
