In [None]:
import gymnasium as gym
import ale_py
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from tqdm import tqdm


In [69]:
gym.register_envs(ale_py)

In [70]:
BATCH_SIZE = 32
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_END = 0.02
EPSILON_DECAY = 1000000
TARGET_UPDATE = 1000
MEMORY_SIZE = 10000
LEARNING_RATE = 1e-4

In [71]:
def preprocess_observation(obs):
    obs = obs[35:195]  # Crop
    obs = cv2.resize(obs, (84, 84))  # Resize
    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)  # Convert to grayscale
    _, obs = cv2.threshold(obs, 1, 255, cv2.THRESH_BINARY)  # Binary
    return obs / 255.0  # Normalize


In [72]:
# Define the DQN model
class DQN(nn.Module):
    def __init__(self, action_space):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc = nn.Linear(7 * 7 * 64, 512)
        self.out = nn.Linear(512, action_space)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        return self.out(x)

In [73]:
memory = deque(maxlen=MEMORY_SIZE)

In [74]:
def select_action(state, epsilon, action_space):
    if random.random() < epsilon:
        return random.randrange(action_space)
    else:
        if isinstance(state, np.ndarray):
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return torch.argmax(policy_net(state)).item()

In [75]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return

    transitions = random.sample(memory, BATCH_SIZE)
    batch_state, batch_action, batch_reward, batch_next_state, batch_done = zip(*transitions)

    # Conversion en tenseurs PyTorch
    batch_state = torch.cat([torch.tensor(s, dtype=torch.float32).unsqueeze(0) for s in batch_state])
    batch_action = torch.tensor(batch_action)
    batch_reward = torch.tensor(batch_reward)
    batch_next_state = torch.cat([torch.tensor(s, dtype=torch.float32).unsqueeze(0) for s in batch_next_state])
    batch_done = torch.tensor(batch_done, dtype=torch.bool)

    current_q_values = policy_net(batch_state).gather(1, batch_action.unsqueeze(1))
    next_q_values = target_net(batch_next_state).max(1)[0].detach()
    expected_q_values = batch_reward + (GAMMA * next_q_values) * (~batch_done)

    loss = nn.MSELoss()(current_q_values, expected_q_values.unsqueeze(1))    
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()



In [76]:
env = gym.make("PongNoFrameskip-v4", difficulty=1)
policy_net = DQN(env.action_space.n)
target_net = DQN(env.action_space.n)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

In [None]:
num_episodes = 5000
epsilon = EPSILON_START

for episode in tqdm(range(num_episodes), desc="Training Episodes"):
    obs, _ = env.reset()
    state = preprocess_observation(obs)
    state = np.stack([state] * 4, axis=0)
    state = torch.tensor(state, dtype=torch.float32)

    total_reward = 0
    done = False
    steps = 0

    while not done:
        action = select_action(state, epsilon, env.action_space.n)
        next_obs, reward, done, truncated, _ = env.step(action)
        total_reward += reward
        next_state = preprocess_observation(next_obs)
        next_state = np.concatenate((state[1:, :, :], np.expand_dims(next_state, 0)), axis=0)
        memory.append((state, action, reward, next_state, done))
        state = next_state

        optimize_model()

        if steps % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

        steps += 1
        epsilon = max(EPSILON_END, EPSILON_START - steps / EPSILON_DECAY)

    tqdm.write(f"Episode {episode}, Total Reward: {total_reward}")


  batch_state = torch.cat([torch.tensor(s, dtype=torch.float32).unsqueeze(0) for s in batch_state])
Training Episodes:   0%|          | 1/5000 [04:00<333:27:22, 240.14s/it]

Episode 0, Total Reward: -20.0


Training Episodes:   0%|          | 2/5000 [07:56<330:16:48, 237.90s/it]

Episode 1, Total Reward: -20.0


Training Episodes:   0%|          | 3/5000 [12:20<346:35:35, 249.70s/it]

Episode 2, Total Reward: -20.0
