In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [8]:
# Number of discrete actions the agent chooses between.
action_size = 2

# Training array loaded from disk; the length of its second axis is used as the
# width of the state vector fed to the network.
# NOTE(review): assumes X_train is 2-D (samples, features) — TODO confirm.
X_train = np.load('../Data/X_train.npy')
state_size = X_train.shape[1]

In [9]:
class DQN(nn.Module):
    """Fully connected Q-network used for Deep Q-Learning.

    The network maps an input of width ``input_dim`` through two hidden
    layers (256 then 128 units, each followed by ReLU) to ``output_dim``
    raw outputs. No activation is applied to the last layer.
    """

    def __init__(self, input_dim, output_dim):
        """Build the three linear layers and re-initialise their weights.

        Args:
            input_dim: Width of the input vector.
            output_dim: Number of outputs of the last layer.
        """
        super().__init__()
        hidden_wide, hidden_narrow = 256, 128
        self.fc1 = nn.Linear(input_dim, hidden_wide)
        self.fc2 = nn.Linear(hidden_wide, hidden_narrow)
        self.fc3 = nn.Linear(hidden_narrow, output_dim)

        # Replace the default weight initialisation with Kaiming-normal, in
        # layer order. Only the weights are touched; biases keep whatever
        # nn.Linear gave them.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.kaiming_normal_(layer.weight)

    def forward(self, x):
        """Return the output of the last linear layer for input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [10]:
class DQNAgent:
    """DQN agent with a replay memory and epsilon-greedy exploration.

    The agent owns one ``DQN`` network, an Adam optimizer and an MSE
    criterion. The same network is used both to pick actions and to compute
    the bootstrap value of the next state (there is no separate target
    network).

    Attributes:
        memory: Bounded ``deque`` of ``(state, action, reward, next_state,
            done)`` tuples.
        epsilon: Current probability of taking a random action.
        losses: Average loss of every ``replay`` call that actually trained.
        rewards, epsilons: Created empty and only read by ``plot_metrics``;
            this class never appends to them, so the code driving training
            has to.
    """

    def __init__(self, state_size, action_size, *, memory_size=10000,
                 gamma=0.95, epsilon=1.0, epsilon_min=0.01,
                 epsilon_decay=0.995, learning_rate=0.001):
        """Create the network, optimizer and bookkeeping lists.

        Args:
            state_size: Width of a state vector (network input size).
            action_size: Number of discrete actions (network output size).
            memory_size: Maximum number of transitions kept in memory.
            gamma: Discount factor applied to the next state's best Q-value.
            epsilon: Initial exploration rate.
            epsilon_min: Floor below which ``replay`` never decays epsilon.
            epsilon_decay: Multiplicative decay applied once per ``replay``.
            learning_rate: Learning rate of the Adam optimizer.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate

        self.model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        self.losses = []
        self.rewards = []
        self.epsilons = []

    @staticmethod
    def _to_tensor(state):
        """Convert one state to a float32 tensor with a leading axis of size 1."""
        return torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory.

        Once the memory is full, the ``deque`` drops its oldest entry.
        """
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the index of the largest network output.

        Returns:
            int: Action index in ``range(action_size)``.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        # Inference only: no gradient bookkeeping needed.
        with torch.no_grad():
            act_values = self.model(self._to_tensor(state))
            return torch.argmax(act_values[0]).item()

    def replay(self, batch_size):
        """Train on ``batch_size`` transitions sampled from memory.

        Each sampled transition triggers its own optimizer step (the samples
        are not stacked into a single batch). After the loop, epsilon is
        decayed once and the average loss is appended to ``losses``.

        Args:
            batch_size: Number of transitions to sample; must be positive.

        Returns:
            ``0`` when the memory holds fewer than ``batch_size`` transitions
            (nothing is trained and epsilon is left alone), otherwise the
            mean per-sample loss as a float.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Reject this before touching any state: a zero batch would otherwise
        # decay epsilon and then divide by zero when averaging the loss.
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if len(self.memory) < batch_size:
            return 0

        minibatch = random.sample(self.memory, batch_size)
        total_loss = 0.0

        for state, action, reward, next_state, done in minibatch:
            # Terminal transitions use the raw reward; others bootstrap from
            # the best Q-value of the next state.
            target = reward
            if not done:
                with torch.no_grad():
                    next_q = self.model(self._to_tensor(next_state))[0]
                target = reward + self.gamma * torch.max(next_q).item()

            current_q = self.model(self._to_tensor(state))[0]
            # The regression target is a constant copy of the prediction in
            # which only the taken action's entry is replaced, so the other
            # entries contribute zero error and zero gradient.
            target_f = current_q.detach().clone()
            target_f[action] = target

            self.optimizer.zero_grad()
            loss = self.criterion(current_q, target_f)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        # Decay exploration but never step below the floor. An epsilon that
        # is already at or below the floor is left untouched.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        avg_loss = total_loss / batch_size
        self.losses.append(avg_loss)
        return avg_loss

    def plot_metrics(self, save_path='../Data/training_metrics.png'):
        """Plot the loss history and the exploration/reward curves.

        The top panel shows ``losses``. The bottom panel shows ``epsilons``
        and, when ``rewards`` is non-empty, a moving average of the rewards
        over a window of at most 50 entries.

        Args:
            save_path: Where the figure is written before being shown. Pass
                ``None`` to skip saving.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        ax1.plot(self.losses)
        ax1.set_title('Évolution de la perte pendant l\'entraînement')
        ax1.set_xlabel('Mini-batch')
        ax1.set_ylabel('Perte (MSE)')
        ax1.grid(True)

        ax2.plot(self.epsilons, label='Epsilon')
        if self.rewards:
            # mode='valid' keeps only fully covered windows, so the smoothed
            # curve is shorter than the raw reward list.
            window_size = min(50, len(self.rewards))
            smoothed_rewards = np.convolve(self.rewards, np.ones(window_size)/window_size, mode='valid')
            ax2.plot(smoothed_rewards, label='Récompense moyenne (lissée)')

        ax2.set_title('Évolution de l\'exploration et des récompenses')
        ax2.set_xlabel('Épisode')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

    def load(self, name):
        """Load network weights from the checkpoint ``name``.

        Args:
            name: Path or file-like object accepted by ``torch.load``.
        """
        # NOTE(review): torch.load unpickles its input; only load checkpoints
        # that come from a trusted source.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only machine; load_state_dict then copies the values into the
        # model's existing parameters.
        state_dict = torch.load(name, map_location='cpu')
        self.model.load_state_dict(state_dict)

    def save(self, name):
        """Write the network's ``state_dict`` to ``name``.

        Args:
            name: Path or file-like object accepted by ``torch.save``.
        """
        torch.save(self.model.state_dict(), name)

Initialisation de l'agent

In [11]:
# Build the agent: network input width comes from the data, output width from
# the number of actions.
agent = DQNAgent(state_size=state_size, action_size=action_size)

Sauvegarde de l'agent initial

In [12]:
# Persist the freshly initialised (untrained) network weights.
agent.save(name='../Data/initial_dqn_model.pth')