In [1]:
!pip install gym torch numpy matplotlib



In [2]:
import gym
import math
import random
import numpy as np
from collections import deque, namedtuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Hyperparameters ------------------------------------------------------
ENV_NAME = 'CartPole-v1'   # environment id passed to gym.make
BATCH_SIZE = 64            # transitions per optimisation minibatch
GAMMA = 0.99               # discount factor applied to the bootstrapped next-state value
EPS_START = 1.0            # epsilon at step 0 of epsilon-greedy exploration
EPS_END = 0.01             # value epsilon decays toward
EPS_DECAY = 500            # time constant (in environment steps) of epsilon's exponential decay
TARGET_UPDATE = 10         # copy policy weights into the target net every this many episodes
MEMORY_SIZE = 10000        # replay capacity; the oldest transitions are evicted beyond this
LR = 1e-3                  # Adam learning rate
NUM_EPISODES = 500         # number of training episodes

# Replay Memory
# One stored environment step. Field order matters: ReplayMemory.push(*args)
# fills these fields positionally.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)

class ReplayMemory:
    """Fixed-capacity FIFO store of ``Transition`` records for experience replay.

    The underlying deque has ``maxlen=capacity``, so once it is full every new
    push silently evicts the oldest stored item.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack ``args`` positionally into a ``Transition`` and store it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored items drawn uniformly without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of items currently stored (never more than ``capacity``)."""
        size = len(self.memory)
        return size

# Dueling DQN Network
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared feature layer feeding a value head and an advantage head.

    Q-values are combined as ``Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))``.
    Subtracting the mean advantage keeps V and A identifiable.

    Args:
        state_size: number of features in one observation.
        action_size: number of discrete actions (width of the Q-value output).
        hidden_size: width of every hidden layer (default 128, the original value).
    """

    def __init__(self, state_size, action_size, hidden_size=128):
        super().__init__()
        # Shared trunk. Submodule names/indices are kept so existing state_dicts still load.
        self.feature = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
        )

        # Value stream: one scalar V(s) per input state.
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        # Advantage stream: one A(s, a) per action.
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, x):
        """Return Q-values with actions along the last dimension.

        Accepts a batch ``(B, state_size)`` -> ``(B, action_size)`` or a single
        unbatched state ``(state_size,)`` -> ``(action_size,)``. The advantage
        mean is taken over the last (action) dimension, so both shapes work.
        """
        features = self.feature(x)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)
        # dim=-1 equals dim=1 for batched input and also supports 1-D input.
        return value + (advantage - advantage.mean(dim=-1, keepdim=True))

# Epsilon Greedy Policy
def select_action(state, policy_net, steps_done, n_actions):
    """Choose an action epsilon-greedily.

    Epsilon decays exponentially from EPS_START toward EPS_END, with time
    constant EPS_DECAY measured in ``steps_done``. If a uniform draw exceeds
    epsilon, the greedy action under ``policy_net`` is taken. Otherwise a
    uniformly random index in ``range(n_actions)`` is returned.

    Args:
        state: input to ``policy_net``. It must produce exactly one row of
            Q-values because of the ``.view(1, 1)`` below; ``main`` passes shape
            (1, state_size).
        policy_net: network mapping states to Q-values over actions on dim 1.
        steps_done: number of actions taken so far (drives the epsilon schedule).
        n_actions: size of the discrete action space.

    Returns:
        A ``torch.long`` tensor of shape (1, 1) holding the chosen action index.
    """
    decay_factor = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay_factor

    if random.random() > eps_threshold:
        # Exploit: acting needs no autograd graph.
        with torch.no_grad():
            q_values = policy_net(state)
        return q_values.max(dim=1).indices.view(1, 1)

    # Explore: uniform random action, created on the training device.
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

# Plotting function
def plot_durations(episode_durations, avg_window=100):
    """Redraw matplotlib figure 1 with the per-episode values in ``episode_durations``.

    Once at least ``avg_window`` episodes exist, a trailing moving average is
    drawn too. It is front-padded with ``avg_window - 1`` zeros so its x-axis
    lines up with episode indices.
    """
    plt.figure(1)
    plt.clf()
    series = torch.tensor(episode_durations, dtype=torch.float)

    plt.title('Training Dueling DQN')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(series.numpy())

    if series.shape[0] >= avg_window:
        # Each row of unfold(...) is one length-avg_window window; average across it.
        rolling_mean = series.unfold(0, avg_window, 1).mean(dim=1)
        leading_zeros = torch.zeros(avg_window - 1)
        plt.plot(torch.cat([leading_zeros, rolling_mean]).numpy())

    # Yield briefly to the GUI event loop so the figure actually refreshes.
    plt.pause(0.001)

def _obs_to_tensor(obs):
    """Convert one environment observation into a float tensor of shape (1, obs_dim) on ``device``."""
    # as_tensor avoids the slow torch.tensor([ndarray]) list-of-arrays path.
    return torch.as_tensor(obs, dtype=torch.float, device=device).unsqueeze(0)


def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Supports both the legacy gym API (``reset() -> obs``) and gym>=0.26 /
    gymnasium (``reset() -> (obs, info)``).
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(obs, reward, terminated, truncated)`` under either gym API.

    ``terminated`` marks a true terminal state, whose next-state value must not
    be bootstrapped. ``truncated`` marks an episode cut short (e.g. by a time
    limit); the episode ends, but the next state still has value.
    """
    result = env.step(action)
    if len(result) == 5:  # gym>=0.26 / gymnasium
        obs, reward, terminated, truncated, _ = result
    else:  # legacy 4-tuple API folds time-limit truncation into ``done``
        obs, reward, done, info = result
        truncated = bool(info.get('TimeLimit.truncated', False))
        terminated = bool(done) and not truncated
    return obs, reward, bool(terminated), bool(truncated)


def _optimize_model(policy_net, target_net, optimizer, memory):
    """Run one DQN gradient step on a random minibatch from ``memory``.

    Does nothing until ``memory`` holds at least BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)                   # (B, obs_dim)
    action_batch = torch.cat(batch.action)                 # (B, 1) long
    reward_batch = torch.cat(batch.reward).unsqueeze(1)    # (B, 1)
    next_state_batch = torch.cat(batch.next_state)         # (B, obs_dim)
    done_batch = torch.cat(batch.done).unsqueeze(1)        # (B, 1); 1.0 = true terminal

    # Q(s_t, a_t) for the actions actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    with torch.no_grad():
        # max_a Q_target(s_{t+1}, a), zeroed only for true terminal states.
        next_state_values = target_net(next_state_batch).max(dim=1, keepdim=True).values
        expected_state_action_values = reward_batch + GAMMA * next_state_values * (1 - done_batch)

    loss = F.mse_loss(state_action_values, expected_state_action_values)

    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1] to prevent explosion.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()


def main():
    """Train a Dueling DQN agent on ENV_NAME and plot the per-episode total reward."""
    env = gym.make(ENV_NAME)
    n_actions = env.action_space.n
    state_size = env.observation_space.shape[0]

    policy_net = DuelingDQN(state_size, n_actions).to(device)
    target_net = DuelingDQN(state_size, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LR)
    memory = ReplayMemory(MEMORY_SIZE)

    steps_done = 0
    episode_durations = []

    try:
        for episode in range(NUM_EPISODES):
            state = _obs_to_tensor(_reset_env(env))
            total_reward = 0.0
            for _ in range(1, 10000):  # hard cap so a non-ending episode cannot hang training
                action = select_action(state, policy_net, steps_done, n_actions)
                steps_done += 1
                obs, reward, terminated, truncated = _step_env(env, action.item())
                total_reward += reward
                next_state = _obs_to_tensor(obs)

                # Store only `terminated` as the done flag: truncation must still bootstrap.
                memory.push(
                    state,
                    action,
                    torch.tensor([reward], device=device, dtype=torch.float),
                    next_state,
                    torch.tensor([terminated], device=device, dtype=torch.float),
                )
                state = next_state

                _optimize_model(policy_net, target_net, optimizer, memory)

                if terminated or truncated:
                    break

            # Record and sync after every episode, even one that hit the step cap.
            episode_durations.append(total_reward)
            if episode % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())
            print(f"Episode {episode}: Total Reward: {total_reward}")

            if episode % 10 == 0:
                plot_durations(episode_durations)
    finally:
        env.close()

    print('Training complete')
    # Show final plot
    plt.ioff()
    plot_durations(episode_durations)
    plt.show()

# Entry point: start training when this code runs as the main module. In a
# notebook cell, __name__ is also '__main__', so running the cell trains immediately.
if __name__ == '__main__':
    main()

Output hidden; open in https://colab.research.google.com to view.