In [1]:
import gymnasium as gym
import ale_py
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
gym.register_envs(ale_py)

In [4]:
BATCH_SIZE = 32
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_END = 0.02
EPSILON_DECAY = 1000000
TARGET_UPDATE = 1000
MEMORY_SIZE = 4
LEARNING_RATE = 1e-4

In [124]:
def preprocess_observation(obs):
    obs = obs[35:195]  # Crop
    obs = cv2.resize(obs, (84, 84))  # Resize
    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)  # Convert to grayscale
    _, obs = cv2.threshold(obs, 1, 255, cv2.THRESH_BINARY)  # Binary
    return torch.tensor(obs / 255.0, dtype=torch.float32).to(device=device) # Normalize

In [125]:
# Define the DQN model
class DQN(nn.Module):
    def __init__(self, action_space):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc = nn.Linear(7 * 7 * 64, 512)
        self.out = nn.Linear(512, action_space)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        return self.out(x)

In [126]:
env = gym.make("PongNoFrameskip-v4", difficulty=1)
policy_net = DQN(env.action_space.n).to(device)
target_net = DQN(env.action_space.n).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

In [127]:
memory = deque(maxlen=MEMORY_SIZE)

In [128]:
def select_action(state, epsilon, action_space):
    if random.random() < epsilon:
        return random.randrange(action_space)
    else:
        if isinstance(state, np.ndarray):
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            state = state.unsqueeze(0) ## Add batch dimension
            return torch.argmax(policy_net(state)).item()

In [129]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return

    transitions = random.sample(memory, BATCH_SIZE)
    batch_state, batch_action, batch_reward, batch_next_state, batch_done = zip(*transitions)

    # Conversion en tenseurs PyTorch
    batch_state = torch.cat(batch_state).unflatten(0,(BATCH_SIZE,-1)).to(device)
    batch_action = torch.tensor(batch_action).to(device)
    batch_reward = torch.tensor(batch_reward).to(device)
    batch_next_state = torch.cat(batch_next_state).unflatten(0,(BATCH_SIZE,-1)).to(device)
    batch_done = torch.tensor(batch_done, dtype=torch.bool).to(device)

    current_q_values = policy_net(batch_state).gather(1, batch_action.unsqueeze(1)).to(device)
    next_q_values = target_net(batch_next_state).max(1)[0].detach().to(device)
    expected_q_values = batch_reward + (GAMMA * next_q_values) * (~batch_done).to(device)

    loss = nn.SmoothL1Loss()(current_q_values, expected_q_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()



In [130]:
def run_train():
    num_episodes = 5000
    epsilon = EPSILON_START
    writer = SummaryWriter(f'runs/pong_dqn_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}')
    global_step = 0
    for episode in tqdm(range(num_episodes), desc="Training Episodes"):
        obs, _ = env.reset()
        state = preprocess_observation(obs)
        state = torch.stack([state] * 4, axis=0)

        total_reward = 0
        done = False
        steps = 0

        while not done:
            action = select_action(state, epsilon, env.action_space.n)
            next_obs, reward, done, truncated, info = env.step(action)
            done = done or truncated
            total_reward += reward
            next_state = preprocess_observation(next_obs)
            next_state = torch.cat((state[1:, :, :], next_state.unsqueeze(0)), dim=0)
            memory.append((state, action, reward, next_state, done))
            state = next_state
            steps += 1
            global_step += 1  # Incrémente global_step pour chaque étape

            # Optimisation du modèle et enregistrement de la perte
            loss = optimize_model()
            if loss is not None:
                writer.add_scalar("Loss", loss, global_step)

            if steps % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())

            epsilon = max(EPSILON_END, EPSILON_START - steps / EPSILON_DECAY)
            # for name, param in policy_net.named_parameters():
            #     if param.requires_grad:
            #         writer.add_histogram(f"{name}_weights", param, episode)
            #         if param.grad is not None:
            #             writer.add_histogram(f"{name}_gradients", param.grad, episode)
        writer.add_scalar("Epsilon", epsilon, episode)
        writer.add_scalar("Total Reward", total_reward, episode)
    writer.close()  # Ferme le writer après l'entraînement



In [131]:
run_train()

Training Episodes:  19%|█▊        | 932/5000 [13:02<56:54,  1.19it/s]  


KeyboardInterrupt: 

In [118]:
# Load the line_profiler extension
%load_ext line_profiler

# Profile the run_episode function
%lprun -f run_train run_train()

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


Training Episodes:   0%|          | 2/5000 [00:05<3:59:34,  2.88s/it]

*** KeyboardInterrupt exception caught in code being profiled.




Timer unit: 1e-09 s

Total time: 5.70917 s
File: /tmp/ipykernel_259017/3695827517.py
Function: run_train at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def run_train():
     2         1        956.0    956.0      0.0      num_episodes = 5000
     3         1        270.0    270.0      0.0      epsilon = EPSILON_START
     4         1    2662630.0    3e+06      0.0      writer = SummaryWriter(f'runs/pong_dqn_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}')
     5         1        187.0    187.0      0.0      global_step = 0
     6         3    2686663.0 895554.3      0.0      for episode in tqdm(range(num_episodes), desc="Training Episodes"):
     7         3    8358930.0    3e+06      0.1          obs, _ = env.reset()
     8         3     664733.0 221577.7      0.0          state = preprocess_observation(obs)
     9         3     134105.0  44701.7      0.0          state = np.stack([state] * 4, axis=0)
    10  