In [3]:
from env import MazeEnv

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

#réseau policy
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)  # Pour des actions discrètes
        )
        
    def forward(self, state):
        return self.net(state)

#réseau critic (q-value)
class Critic(nn.Module):
    def __init__(self, total_state_dim, total_action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(total_state_dim + total_action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.net(x)


In [1]:
class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)
    
    def add(self, transition):
        self.buffer.append(transition)
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return map(np.array, zip(*batch))
    
    def size(self):
        return len(self.buffer)

In [4]:
env = MazeEnv()

In [7]:
num_agents = 4
state_dim = env.single_agent_state_size
action_dim = env.action_space.n


actors = [Actor(state_dim, action_dim) for _ in range(num_agents)]
critics = [Critic(state_dim * num_agents, action_dim * num_agents) for _ in range(num_agents)]

actor_optimizers = [optim.Adam(actor.parameters(), lr=1e-3) for actor in actors]
critic_optimizers = [optim.Adam(critic.parameters(), lr=1e-3) for critic in critics]

buffer = ReplayBuffer()


In [8]:
episodes = 10000
batch_size = 128
gamma = 0.95

for episode in range(episodes):
    states, _ = env.reset()
    
    done = False
    while not done:
        actions = []
        
        #sélection des actions pour chaque agent
        for i, actor in enumerate(actors):
            state_tensor = torch.FloatTensor(states[i])
            probs = actor(state_tensor) #distribution de probabilité des actions
            action = torch.multinomial(probs, 1).item() #tirage aléatoire d'une action
            actions.append(action)
        
        next_states, rewards, terminated, truncated, _ = env.step(actions)
        done = terminated or truncated
        
        buffer.add((states, actions, rewards, next_states))
        states = next_states
        
        #aprentissage si le buffer est assez rempli
        if buffer.size() >= batch_size:
            batch_states, batch_actions, batch_rewards, batch_next_states = buffer.sample(batch_size)
            
            #préparation des tenseurs batch
            batch_states = torch.FloatTensor(batch_states)
            batch_actions = torch.FloatTensor(batch_actions)
            batch_rewards = torch.FloatTensor(batch_rewards)
            batch_next_states = torch.FloatTensor(batch_next_states)
            
            for i in range(num_agents):
                #critic update
                with torch.no_grad():
                    next_actions = []
                    for j, actor in enumerate(actors):
                        next_action = actor(batch_next_states[:, j, :])
                        next_action = torch.multinomial(next_action, 1).squeeze(-1)
                        next_actions.append(next_action)
                    next_actions = torch.stack(next_actions, dim=1).float()
                    
                    target_q = critics[i](batch_next_states.view(batch_size, -1), next_actions.view(batch_size, -1))
                    y = batch_rewards[:, i].unsqueeze(1) + gamma * target_q

                current_q = critics[i](batch_states.view(batch_size, -1), batch_actions.view(batch_size, -1))
                critic_loss = nn.MSELoss()(current_q, y)

                critic_optimizers[i].zero_grad()
                critic_loss.backward()
                critic_optimizers[i].step()

                #actor update
                curr_actions = []
                for j, actor in enumerate(actors):
                    action_prob = actor(batch_states[:, j, :])
                    action = torch.multinomial(action_prob, 1).squeeze(-1)
                    curr_actions.append(action)
                curr_actions = torch.stack(curr_actions, dim=1).float()
                
                actor_loss = -critics[i](batch_states.view(batch_size, -1), curr_actions.view(batch_size, -1)).mean()
                
                actor_optimizers[i].zero_grad()
                actor_loss.backward()
                actor_optimizers[i].step()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x172 and 196x256)