In [1]:
import os
import random
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from cathedral_rl import cathedral_v0



In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device : {device}')

device : cuda


In [3]:
# Smoke test: play one complete game where every player picks a uniformly random
# legal move on a 7x7 board, printing the rewards reported after each step.
env = cathedral_v0.env(board_size=7, render_mode="text", per_move_rewards=True, final_reward_score_difference=True)
env.reset()

separator = '============================================='
count = 0
print(separator)
while env.agents:
    agent = env.agent_selection
    action_mask = env.observe(agent)["action_mask"]
    # Indices whose mask entry is truthy are the legal actions for `agent`.
    legal_moves = [idx for idx, allowed in enumerate(action_mask) if allowed]
    env.step(random.choice(legal_moves))

    print(f'rewards : {env.rewards}')
    print(f'cumulative reward : {env._cumulative_rewards}')
    print(separator)
    count += 1

print("La partie est terminée.")
print(count)
env.close()


rewards : {'player_0': np.int64(0), 'player_1': 0}
cumulative reward : {'player_0': np.int64(0), 'player_1': 0}
rewards : {'player_0': np.int64(0), 'player_1': np.int64(0)}
cumulative reward : {'player_0': np.int64(0), 'player_1': np.int64(0)}
rewards : {'player_0': np.int64(-2), 'player_1': np.int64(0)}
cumulative reward : {'player_0': np.int64(-2), 'player_1': np.int64(0)}
rewards : {'player_0': np.int64(-2), 'player_1': np.int64(-4)}
cumulative reward : {'player_0': np.int64(-4), 'player_1': np.int64(-4)}
rewards : {'player_0': np.int64(-2), 'player_1': np.int64(-4)}
cumulative reward : {'player_0': np.int64(-6), 'player_1': np.int64(-8)}
rewards : {'player_0': np.int64(-2), 'player_1': np.int64(-2)}
cumulative reward : {'player_0': np.int64(-8), 'player_1': np.int64(-10)}
rewards : {'player_0': np.int64(-4), 'player_1': np.int64(-2)}
cumulative reward : {'player_0': np.int64(-12), 'player_1': np.int64(-12)}
rewards : {'player_0': np.int64(-4), 'player_1': np.int64(-1)}
cumulative r

## DQN: training `player_0` against a random opponent

In [4]:
class ReplayBuffer:
    """Fixed-size FIFO store of transitions for experience replay.

    A transition is the 7-tuple
    (state, action, reward, next_state, done, action_mask, next_action_mask).
    Once `capacity` transitions are stored, pushing a new one evicts the oldest.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done, action_mask, next_action_mask):
        """Append one transition, dropping the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done, action_mask, next_action_mask)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct stored transitions uniformly at random.

        Returns seven numpy arrays, one per transition field, each stacked along a
        new leading batch axis, in the order
        (state, action, reward, next_state, done, action_mask, next_action_mask).
        `random.sample` raises ValueError if batch_size exceeds len(self).
        """
        picked = random.sample(self.buffer, batch_size)
        fields = [np.array(column) for column in zip(*picked)]
        (state, action, reward, next_state, done,
         action_mask, next_action_mask) = fields
        return state, action, reward, next_state, done, action_mask, next_action_mask

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


In [5]:
class DQN(nn.Module):
    """Convolutional Q-network: two 3x3 conv layers followed by a 2-layer MLP head.

    `obs_shape` is read channels-last: (height, width, channels). `forward` takes a
    batch laid out as (batch, height, width, channels) and returns one Q-value per
    action, shape (batch, n_actions).
    """

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        height, width, channels = obs_shape[0], obs_shape[1], obs_shape[2]

        # padding=1 with 3x3 kernels keeps the spatial size, so the flattened
        # output is 64 * height * width.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Size the first linear layer by probing the conv stack with a dummy input.
        probe = torch.zeros(1, channels, height, width)
        conv_out_size = self.conv(probe).shape[1]
        print(f'conv_out_size : {conv_out_size}')

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_actions),
        )

    def forward(self, x):
        # Channels-last (B, H, W, C) -> channels-first (B, C, H, W) for Conv2d.
        features = self.conv(x.permute(0, 3, 1, 2))
        return self.fc(features)


In [6]:
def preprocess_obs(obs):
    """Convert a raw observation (array-like) into a float32 tensor.

    The layout is left unchanged; callers add the batch dimension themselves.
    """
    return torch.tensor(obs, dtype=torch.float32)


def select_action_dqn(model, obs, action_mask, device):
    """Pick the legal action with the highest Q-value (greedy, no exploration).

    Args:
        model: Q-network mapping a (1, *obs.shape) batch to (1, n_actions) Q-values.
        obs: a single observation (without batch dimension).
        action_mask: length-n_actions array-like; truthy entries are legal actions.
        device: torch device the model lives on.

    Returns:
        int index of the best legal action. If no action is legal, every entry
        ties and index 0 is returned.

    The model's train/eval mode is restored afterwards, so calling this from an
    evaluation loop does not silently switch the network back to training mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            obs_tensor = preprocess_obs(obs).unsqueeze(0).to(device)
            q_values = model(obs_tensor).squeeze(0)  # (n_actions,)
            legal = torch.tensor(action_mask, dtype=torch.bool, device=device)
            # -inf (not a finite sentinel) guarantees an illegal action can never
            # beat a legal one, however negative the legal Q-values become.
            q_values = q_values.masked_fill(~legal, float("-inf"))
            action = torch.argmax(q_values).item()
    finally:
        model.train(was_training)
    return action

In [7]:
def get_next_state_for_controlled(env, controlled_agent, current_state, current_mask):
    """Play the other agents' turns until `controlled_agent` is to move again.

    Every other agent plays a uniformly random legal move. Stepping stops when
    it is `controlled_agent`'s turn or when `env.agents` is empty.

    Args:
        env: AEC-style environment (agent_selection, agents, observe, step).
        controlled_agent: name of the agent being trained.
        current_state: observation the controlled agent just acted on; used only
            as a shape/dtype template for the terminal placeholder.
        current_mask: action mask the controlled agent just used; same purpose.

    Returns:
        (next_state, next_action_mask, done_flag). When the controlled agent has
        left `env.agents`, or is flagged terminated/truncated, returns all-zero
        placeholders shaped like the inputs and done_flag=1. Otherwise returns
        its fresh observation, its action mask and done_flag=0.
    """
    def _is_finished(agent):
        # AEC envs can keep a finished agent in env.agents until it is stepped
        # once more; its termination/truncation flag is the reliable signal.
        terminated = getattr(env, "terminations", {}).get(agent, False)
        truncated = getattr(env, "truncations", {}).get(agent, False)
        return bool(terminated or truncated)

    while env.agents and env.agent_selection != controlled_agent:
        mover = env.agent_selection
        mask = env.observe(mover)["action_mask"]
        legal_moves = [i for i, valid in enumerate(mask) if valid]
        # NOTE(review): with no legal move, action 0 is stepped (original behaviour);
        # PettingZoo conventionally expects None for finished agents - confirm for cathedral_v0.
        env.step(random.choice(legal_moves) if legal_moves else 0)

    if controlled_agent in env.agents and not _is_finished(controlled_agent):
        next_obs = env.observe(controlled_agent)
        return next_obs["observation"], next_obs["action_mask"], 0
    return np.zeros_like(current_state), np.zeros_like(current_mask), 1


In [None]:
# --- Hyperparameters ---
num_episodes = 100
buffer_capacity = 1000
batch_size = 64
gamma = 0.999
learning_rate = 1e-3
board_size = 8             # board side length passed to cathedral_v0.env

updates = 10               # gradient steps run after EVERY env-loop iteration once the buffer holds batch_size transitions
target_update_freq = 10    # how often (in episodes) the target network is synced with the policy network

epsilon_start = 0.2
epsilon_final = 0.05
epsilon_decay = 100        # time constant (in episodes) of the exponential epsilon decay

controlled_agent = "player_0"  # agent trained by DQN; the other player moves at random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device : {device}')


def epsilon_by_episode(episode):
    """Exploration rate for `episode` (0-based).

    Decays exponentially from epsilon_start (episode 0) towards epsilon_final
    with time constant epsilon_decay episodes.
    """
    return epsilon_final + (epsilon_start - epsilon_final) * np.exp(-episode / epsilon_decay)
    

device : cuda


In [None]:
def _dqn_update(policy_net, target_net, optimizer, replay_buffer, loss_fn):
    """Run one DQN gradient step on a random minibatch; return the loss as a float.

    Target: r + gamma * max over legal a' of Q_target(s', a') * (1 - done), where
    the max is forced to 0 when no next action is legal.
    Reads the module-level `batch_size`, `gamma` and `device`.
    """
    (states, actions, rewards, next_states, dones,
     _action_masks, next_action_masks) = replay_buffer.sample(batch_size)

    states_tensor = torch.tensor(states, dtype=torch.float32, device=device)
    actions_tensor = torch.tensor(actions, dtype=torch.long, device=device)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=device)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32, device=device)
    dones_tensor = torch.tensor(dones, dtype=torch.float32, device=device)
    next_action_masks_tensor = torch.tensor(next_action_masks, dtype=torch.bool, device=device)

    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        next_q_values = target_net(next_states_tensor)  # (batch_size, n_actions)
        # Illegal next actions get a large but finite negative value so they never win the max.
        next_q_values = next_q_values.masked_fill(~next_action_masks_tensor, -1e8)
        max_next_q_values = next_q_values.max(dim=1).values
        # No legal next action (e.g. the all-zero terminal placeholder): bootstrap value 0.
        max_next_q_values[~next_action_masks_tensor.any(dim=1)] = 0.0
        target_q_values = rewards_tensor + gamma * max_next_q_values * (1 - dones_tensor)

    loss = loss_fn(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train_dqn(name):
    """Train a DQN for `controlled_agent` against an opponent playing random legal moves.

    Uses the module-level hyperparameters (num_episodes, buffer_capacity,
    batch_size, gamma, learning_rate, board_size, updates, target_update_freq,
    the epsilon schedule, device). When training ends, a checkpoint dict with
    keys 'model_state_dict', 'list_reward' and 'list_epsilon' is saved to
    model_weights_DQN/{name}.pth; the directory is created if missing.

    Args:
        name: file stem of the checkpoint.

    Returns:
        (list_reward, list_epsilon):
        - list_reward: per episode, the sum of the rewards read right after each
          of the controlled agent's moves.
        - list_epsilon: the epsilon used at each of the controlled agent's moves.
    """
    env = cathedral_v0.env(board_size=board_size, render_mode="text", per_move_rewards=True, final_reward_score_difference=False)
    env.reset()
    enter_train = False
    n_actions = env.action_space(controlled_agent).n
    print(f'n_actions : {n_actions}')
    obs_shape = env.observe(controlled_agent)["observation"].shape

    policy_net = DQN(obs_shape, n_actions).to(device)
    target_net = DQN(obs_shape, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()
    replay_buffer = ReplayBuffer(buffer_capacity)
    list_reward = []
    list_epsilon = []
    win_count = 0
    for episode in range(num_episodes):
        env.reset()
        total_reward = 0
        losses = []
        while env.agents:
            current_agent = env.agent_selection
            observation = env.observe(current_agent)
            legal_moves = [i for i, valid in enumerate(observation["action_mask"]) if valid]

            if current_agent == controlled_agent:
                state = observation["observation"]
                action_mask = observation["action_mask"]

                # Epsilon-greedy exploration.
                epsilon = epsilon_by_episode(episode)
                list_epsilon.append(epsilon)
                if random.random() < epsilon:
                    action = random.choice(legal_moves)
                else:
                    action = select_action_dqn(policy_net, state, action_mask, device)

                env.step(action)
                # NOTE(review): only the reward read right after our own move is stored;
                # any reward the env assigns to us during the opponent's reply
                # (played inside get_next_state_for_controlled) is not read.
                reward = env.rewards[current_agent]
                total_reward += reward

                # Let the opponent play until it is our turn again (or the game ends).
                next_state, next_action_mask, done_flag = get_next_state_for_controlled(env, controlled_agent, state, action_mask)
                replay_buffer.push(state, action, reward, next_state, done_flag, action_mask, next_action_mask)
            else:
                env.step(random.choice(legal_moves))

            # Runs after every loop iteration (our move + reply, or a lone opponent move).
            if len(replay_buffer) >= batch_size:
                enter_train = True
                for _ in range(updates):
                    losses.append(_dqn_update(policy_net, target_net, optimizer, replay_buffer, loss_fn))

        list_reward.append(total_reward)
        if env.winner == 0:  # controlled_agent ("player_0") wins
            win_count += 1
        elif env.winner == -1:  # draw: counted as half a win
            win_count += 0.5

        if enter_train:
            winner = env.winner
            avg_loss = sum(losses) / len(losses) if losses else float("nan")
            print(f"Episode {episode+1}/{num_episodes} - Reward total: {total_reward:.2f} - Loss: {avg_loss:.4f} - Winner: {winner} - Epsilon: {epsilon_by_episode(episode):.2f}")
        if (episode+1) % target_update_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())
            print("Update target_net")
    env.close()

    # Create the output directory so a long run is not lost at the final save.
    os.makedirs("model_weights_DQN", exist_ok=True)
    torch.save({
        'model_state_dict': policy_net.state_dict(),
        'list_reward': list_reward,
        'list_epsilon': list_epsilon
    }, f"model_weights_DQN/{name}.pth")

    winrate = win_count / num_episodes if num_episodes else float("nan")
    print(f'Winrate : {winrate}')

    return list_reward, list_epsilon

In [20]:
# Train for num_episodes episodes; the checkpoint is written to model_weights_DQN/test2.pth.
list_reward, list_epsilon = train_dqn('test2')

n_actions : 1753
conv_out_size : 4096
conv_out_size : 4096
Episode 7/100 - Reward total: -7.00 - Loss: 0.8255 - Winner: 0 - Epsilon: 0.19
Episode 8/100 - Reward total: -14.00 - Loss: 0.0678 - Winner: 1 - Epsilon: 0.19
Episode 9/100 - Reward total: -19.00 - Loss: 0.0593 - Winner: 1 - Epsilon: 0.19
Episode 10/100 - Reward total: -16.00 - Loss: 0.0309 - Winner: 1 - Epsilon: 0.19
Update target_net
Episode 11/100 - Reward total: -13.00 - Loss: 0.2623 - Winner: 1 - Epsilon: 0.19
Episode 12/100 - Reward total: -9.00 - Loss: 0.0299 - Winner: 0 - Epsilon: 0.18
Episode 13/100 - Reward total: -17.00 - Loss: 0.0248 - Winner: 1 - Epsilon: 0.18
Episode 14/100 - Reward total: -15.00 - Loss: 0.0276 - Winner: 0 - Epsilon: 0.18
Episode 15/100 - Reward total: -15.00 - Loss: 0.0329 - Winner: 1 - Epsilon: 0.18
Episode 16/100 - Reward total: -9.00 - Loss: 0.0152 - Winner: 0 - Epsilon: 0.18
Episode 17/100 - Reward total: -12.00 - Loss: 0.0265 - Winner: 1 - Epsilon: 0.18
Episode 18/100 - Reward total: -11.00 

In [21]:
def evaluate_DQN(name, num_episodes_eval=50):
    """Evaluate a saved DQN checkpoint (or a random policy) against a random opponent.

    Args:
        name: checkpoint stem loaded from model_weights_DQN/{name}.pth, or
            'random' to have controlled_agent play uniformly random legal moves.
        num_episodes_eval: number of games to play.

    Returns:
        Average, over games, of the controlled agent's summed per-move rewards
        (NaN if no game was played). The win rate (draws count as 0.5) is printed.
    """
    list_reward = []
    win_count = 0
    env = cathedral_v0.env(board_size=board_size, render_mode="text", per_move_rewards=True, final_reward_score_difference=False)
    env.reset()

    n_actions = env.action_space(controlled_agent).n
    obs_shape = env.observe(controlled_agent)["observation"].shape

    policy_net = None
    if name != 'random':
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint
        # also stores the reward/epsilon lists); only load files you trust.
        checkpoint = torch.load(f"model_weights_DQN/{name}.pth", map_location=device, weights_only=False)

        policy_net = DQN(obs_shape, n_actions).to(device)
        policy_net.load_state_dict(checkpoint['model_state_dict'])
        policy_net.eval()

    for episode in tqdm(range(num_episodes_eval)):
        env.reset()
        total_reward = 0

        while env.agents:
            current_agent = env.agent_selection
            observation = env.observe(current_agent)
            legal_moves = [i for i, valid in enumerate(observation["action_mask"]) if valid]

            if current_agent == controlled_agent:
                state = observation["observation"]
                action_mask = observation["action_mask"]

                if name == 'random':
                    action = random.choice(legal_moves)
                else:
                    action = select_action_dqn(policy_net, state, action_mask, device)

                env.step(action)
                reward = env.rewards[current_agent]
                total_reward += reward
            else:
                env.step(random.choice(legal_moves))

        list_reward.append(total_reward)
        if env.winner == 0:  # controlled_agent ("player_0") wins
            win_count += 1
        elif env.winner == -1:  # draw: counted as half a win
            win_count += 0.5

    avg_reward = sum(list_reward) / len(list_reward) if list_reward else float("nan")
    winrate = win_count / num_episodes_eval if num_episodes_eval else float("nan")
    print(f"{num_episodes_eval} episodes => Avg Reward : {avg_reward} // Winrate : {winrate}")
    env.close()
    return avg_reward

In [23]:
# Evaluate the checkpoint saved by train_dqn('test2') over the default 50 games vs a random opponent.
avg_reward = evaluate_DQN('test2')

conv_out_size : 4096


100%|██████████| 50/50 [00:21<00:00,  2.38it/s]

50 episodes => Avg Reward : -0.14 // Winrate : 0.94



