In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import gym
from collections import deque
import os

# ===============================
#  Q-Network (standard network)
# ===============================
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# ===============================
#  Replay Buffer
# ===============================
class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.array, zip(*batch))
        return (
            torch.tensor(state, dtype=torch.float32),
            torch.tensor(action, dtype=torch.int64),
            torch.tensor(reward, dtype=torch.float32),
            torch.tensor(next_state, dtype=torch.float32),
            torch.tensor(done, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.buffer)


# ===============================
#  Epsilon-greedy policy
# ===============================
def select_action(q_network, state, epsilon, action_dim):
    if random.random() < epsilon:
        return random.randint(0, action_dim - 1)
    with torch.no_grad():
        q_values = q_network(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
    return q_values.argmax().item()


# ===============================
#  Train Step
# ===============================
def train_step(q_network, target_network, optimizer, replay_buffer, batch_size, gamma=0.99):
    if len(replay_buffer) < batch_size:
        return 0.0

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    # Compute Q targets
    with torch.no_grad():
        target_q = rewards + gamma * (1 - dones) * target_network(next_states).max(1)[0]

    current_q = q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = F.mse_loss(current_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


# ===============================
#  Main Curriculum Training Loop
# ===============================
def train_curriculum_long_to_short(
    total_episodes=1000,
    num_bins=10,
    buffer_capacity=100000,
    batch_size=64,
    gamma=0.99,
    lr=1e-3,
    epsilon_start=1.0,
    epsilon_end=0.01,
    epsilon_decay=0.995,
    target_update_freq=1000,
    render=False,
    save_path="weights/dqn_curriculum_long_to_short.pth"
):

    env = gym.make("CartPole-v1", render_mode="human" if render else None)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    q_network = QNetwork(state_dim, action_dim)
    target_network = QNetwork(state_dim, action_dim)
    target_network.load_state_dict(q_network.state_dict())
    optimizer = optim.Adam(q_network.parameters(), lr=lr)
    replay_buffer = ReplayBuffer(capacity=buffer_capacity)

    # Curriculum setup
    lengths = np.linspace(1.8, 0.4, num_bins)
    episodes_per_bin = total_episodes // num_bins
    epsilon = epsilon_start

    all_rewards = []
    total_steps = 0

    for episode in range(total_episodes):
        # Determine current bin (progress linearly)
        bin_idx = min(episode // episodes_per_bin, num_bins - 1)
        current_length = lengths[bin_idx]

        env.unwrapped.length = current_length
        state, _ = env.reset()
        done = False
        episode_reward = 0

        while not done:
            action = select_action(q_network, state, epsilon, action_dim)
            next_state, reward, done, _, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            loss = train_step(q_network, target_network, optimizer, replay_buffer, batch_size, gamma)

            total_steps += 1
            if total_steps % target_update_freq == 0:
                target_network.load_state_dict(q_network.state_dict())

        epsilon = max(epsilon * epsilon_decay, epsilon_end)
        all_rewards.append(episode_reward)

        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(all_rewards[-100:])
            print(f"[Episode {episode+1}/{total_episodes}] "
                  f"Length={current_length:.2f} | Eps={epsilon:.3f} | "
                  f"AvgReward(100)={avg_reward:.1f}")

    env.close()

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(q_network.state_dict(), save_path)
    print(f"Training complete. Model saved to {save_path}")


# ===============================
#  Entry Point
# ===============================
if __name__ == "__main__":
    train_curriculum_long_to_short(
        total_episodes=1000,
        num_bins=10,
        buffer_capacity=100000,
        batch_size=64,
        gamma=0.99,
        lr=1e-3,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=0.995,
        target_update_freq=1000,
        render=False,
        save_path="weights/dqn_curriculum_short_to_long.pth"
    )

[Episode 100/1000] Length=1.80 | Eps=0.606 | AvgReward(100)=54.4
[Episode 200/1000] Length=1.64 | Eps=0.367 | AvgReward(100)=108.4
[Episode 300/1000] Length=1.49 | Eps=0.222 | AvgReward(100)=116.2
[Episode 400/1000] Length=1.33 | Eps=0.135 | AvgReward(100)=109.2
[Episode 500/1000] Length=1.18 | Eps=0.082 | AvgReward(100)=423.1
[Episode 600/1000] Length=1.02 | Eps=0.049 | AvgReward(100)=1142.1
[Episode 700/1000] Length=0.87 | Eps=0.030 | AvgReward(100)=817.7
[Episode 800/1000] Length=0.71 | Eps=0.018 | AvgReward(100)=732.7
[Episode 900/1000] Length=0.56 | Eps=0.011 | AvgReward(100)=224.0
[Episode 1000/1000] Length=0.40 | Eps=0.010 | AvgReward(100)=1046.4
Training complete. Model saved to weights/dqn_curriculum_short_to_long.pth
