In [37]:
import numpy as np
import pygame
import random
import torch
import torch.nn as nn

from cathedral_rl import cathedral_v0
from cathedral_rl.game.manual_policy import ManualPolicy

## Commandes 


Liste des touches possibles et leur effet :

- Espace (K_SPACE) : Parcourt la liste des pièces disponibles en passant de la plus grande à la plus petite.
- E (K_e) : Fait tourner la pièce dans le sens horaire (rotation à -90° par incrément, en tenant compte du plateau inversé).
- Q (K_q) : Fait tourner la pièce dans le sens anti-horaire (rotation à +90° par incrément).
- Flèche droite (K_RIGHT) : Déplace la pièce vers la droite, en vérifiant que le déplacement est légal.
- Flèche gauche (K_LEFT) : Déplace la pièce vers la gauche, en vérifiant que le déplacement est légal.
- Flèche haut (K_UP) : Déplace la pièce vers le haut (attention : en pygame, la coordonnée y augmente vers le bas), en vérifiant que le déplacement est légal.
- Flèche bas (K_DOWN) : Déplace la pièce vers le bas, en vérifiant que le déplacement est légal.


## Choose starting player

In [38]:
# Who makes the first move: "human" or "AI".
# The value is compared case-sensitively against "AI"; any other value is treated as "human".
starting_player = "human"

In [39]:
# Agent name used later to query the action space and the observation shape:
# "player_0" when the AI starts, "player_1" otherwise.
controlled_agent = "player_0" if starting_player == "AI" else "player_1"

## Load DQN policy

In [40]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selected device as the cell output

device(type='cuda')

In [41]:
# NOTE(review): the environment cell further down passes board_size=10 directly and
# does not read this variable — confirm which value is intended.
board_size = 8

# NOTE(review): not referenced in the visible cells; presumably a weight applied to
# illegal-action penalties during training — TODO confirm.
factor_illegal_action = 1

parameters_updates = 10    # several parameter updates per episode
target_update_freq = 20    # how often (in episodes) the target network is updated
# NOTE(review): presumably how often (in episodes) the opponent is refreshed — TODO confirm.
opponent_update_freq = 100

# Epsilon schedule read by epsilon_by_episode: starts at epsilon_start and decays
# exponentially toward epsilon_final with a time constant of epsilon_decay episodes.
epsilon_start = 0.3
epsilon_final = 0.1
epsilon_decay = 100    
epsilon_opponent = 0.1  # low exploration for the opponent

# Exploration strategy handed to select_action_dqn: 'eps_greedy' or 'boltzmann'.
method = "eps_greedy"

In [42]:
class DQN(nn.Module):
    """Convolutional Q-network: channels-last board observation in, one Q-value per action out.

    Args:
        obs_shape: observation shape as (height, width, channels); the original note
            gives (10, 10, 5) as the expected value.
        n_actions: number of discrete actions, i.e. the size of the output layer.
    """

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        height, width, channels = obs_shape[0], obs_shape[1], obs_shape[2]

        # Two 3x3 "same"-padded convolutions keep the spatial size, then flatten.
        feature_layers = [
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Push a zero observation through the conv stack to discover the flattened width.
        probe = torch.zeros(1, channels, height, width)
        flat_features = self.conv(probe).shape[1]

        head_layers = [
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for x of shape (batch, height, width, channels)."""
        # Conv2d expects channels-first input, so move the channel axis to position 1.
        channels_first = x.permute(0, 3, 1, 2)
        return self.fc(self.conv(channels_first))

In [43]:
def epsilon_by_episode(episode):
    """Exploration rate for `episode`: exponential decay from epsilon_start toward epsilon_final.

    Reads the module-level constants epsilon_start, epsilon_final and epsilon_decay.
    """
    decay_factor = np.exp(-episode / epsilon_decay)
    return epsilon_final + (epsilon_start - epsilon_final) * decay_factor

def temperature_by_episode(episode):
    """Boltzmann temperature for `episode`; currently the constant 1 whatever the episode."""
    constant_temperature = 1
    return constant_temperature

def select_action_dqn(model, obs, action_mask, legal_moves, episode, device, method, verbose=False):
    """Pick an action for the current observation using the given Q-network.

    Args:
        model: network called as ``model(obs_batch)``; its output is squeezed to a
            1-D tensor of Q-values, one per action.
        obs: a single observation (no batch axis); a batch axis is added here.
        action_mask: per-action legality flags, truthy where the action is legal.
        legal_moves: indices of the legal actions, used for the random exploratory draw.
        episode: episode index fed to the epsilon / temperature schedules.
        device: torch device on which the forward pass runs.
        method: ``'eps_greedy'`` or ``'boltzmann'``.
        verbose: accepted for backward compatibility; currently unused.

    Returns:
        ``(action, not_legal_action)`` where ``not_legal_action`` is 1 when the
        unmasked greedy choice differs from the masked greedy choice (i.e. the
        network's raw preference had to be overridden by the mask) and 0 otherwise.
        It is always 0 for a random epsilon-greedy draw.

    Raises:
        ValueError: if ``method`` is not one of the supported strategies.

    Side effects:
        The model is switched to eval mode for the forward pass and is put in
        train mode before returning.
    """
    if method not in ('eps_greedy', 'boltzmann'):
        # Fail early with a clear message instead of an UnboundLocalError on `action`.
        raise ValueError(
            f"Unknown action selection method {method!r}; expected 'eps_greedy' or 'boltzmann'"
        )

    model.eval()
    with torch.no_grad():
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
        q_values = model(obs_tensor).squeeze(0)  # (n_actions,)

        legal = torch.tensor(action_mask, dtype=torch.bool, device=device)
        # Large negative value (not -inf) so softmax stays finite even if nothing is legal.
        masked_q = q_values.masked_fill(~legal, -1e8)

        if method == 'eps_greedy':
            epsilon = epsilon_by_episode(episode)
            if random.random() < epsilon:
                # Exploration: uniform draw among legal moves.
                action = random.choice(legal_moves)
                not_legal_action = 0
            else:
                raw_greedy = torch.argmax(q_values).item()
                action = torch.argmax(masked_q).item()
                not_legal_action = int(raw_greedy != action)
        else:  # 'boltzmann'
            temperature = temperature_by_episode(episode)
            raw_greedy = torch.argmax(q_values).item()
            masked_greedy = torch.argmax(masked_q).item()
            not_legal_action = int(raw_greedy != masked_greedy)
            # torch.softmax avoids depending on `F`, which this notebook never imports.
            probabilities = torch.softmax(masked_q / temperature, dim=-1)
            action = torch.multinomial(probabilities, num_samples=1).item()

    model.train()
    return action, not_legal_action

In [44]:
# Build the game environment with on-screen rendering and take the raw (unwrapped) env.
env = cathedral_v0.env(
    board_size=10,
    render_mode="human",
    per_move_rewards=True,
    final_reward_score_difference=True,
).unwrapped

env.reset()

# Network dimensions are read from the environment.
n_actions = env.action_space(controlled_agent).n
obs_shape = env.observe(controlled_agent)["observation"].shape  # expected (10, 10, 5)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# SECURITY: weights_only=False unpickles arbitrary Python objects — only load
# checkpoint files from a trusted source.
checkpoint = torch.load(
    "model_weights_DQN/test5.pth", map_location=device, weights_only=False
)

policy_net = DQN(obs_shape, n_actions).to(device)
policy_net.load_state_dict(checkpoint['model_state_dict'])

# Training bookkeeping stored alongside the weights.
list_reward_training = checkpoint['list_reward']
policy_net_checkpoints = checkpoint['policy_net_checkpoints']
num_checkpoints = checkpoint['num_checkpoints']
print(f'Num checkpoints: {num_checkpoints}')

list_reward = []

Num checkpoints: 0


## Play against AI

In [45]:
env.reset()
env.render()

# Turn counter for the log lines. Named `turn` so the builtin `iter` is not
# shadowed for the rest of the kernel session.
turn = 1

# Seat index of the human player: 1 when the AI starts, 0 otherwise.
human_agent_id = 1 if starting_player == "AI" else 0

manual_policy = ManualPolicy(env, agent_id=human_agent_id)  # Policy controlled by player

try:
    while env.agents:
        observation, reward, termination, truncation, info = env.last()
        mask = observation["action_mask"]
        legal_moves = [i for i, valid in enumerate(mask) if valid]
        agent = env.agent_selection

        print(
            f"\nTurn: {turn} | ({agent}) "
            f"Legal pieces : {list(env.legal_pieces[agent])}, "
            f"Legal moves total: {np.count_nonzero(mask)}, "
            f"Remaining pieces: {env.board.unplaced_pieces[agent]}"
        )

        if agent == manual_policy.agent:                # Human action
            action = manual_policy(observation, agent)
        else:                                           # AI action
            state = observation["observation"]
            # NOTE(review): episode=0 makes epsilon_by_episode return epsilon_start, so
            # under 'eps_greedy' the AI still plays random legal moves at that rate —
            # confirm this is intended when playing against a human.
            action, _ = select_action_dqn(policy_net, state, mask, legal_moves, 0, device, method)

        env.step(action)

        print(
            f"Turn: {turn} | "
            f"Action: {action}, "
            f"Piece: {env.board.action_to_piece_map(action)[0]}, "
            f"Position: {env.board.action_to_pos_rotation_mapp(agent, action)[0]}, "
        )
        print(
            f"Turn: {turn} | Reward: {env.rewards[agent]}, "
            f"Cumulative reward: {env._cumulative_rewards[agent]}, "
        )
        # Print the score table once both players have taken the same number of turns.
        if env.turns["player_0"] == env.turns["player_1"]:
            print()
            # Separate loop variable so the acting `agent` above is not overwritten.
            for scored_agent in env.agents:
                print(
                    f"SCORE ({scored_agent}): {env.score[scored_agent]['total']:0.2f}, "
                    f"Squares/turn: {env.score[scored_agent]['squares_per_turn']:0.2f}, "
                    f"Remaining pieces difference: {env.score[scored_agent]['remaining_pieces']}, "
                    f"Territory difference: {env.score[scored_agent]['territory']}"
                )

        turn += 1

    # `termination` holds the flag read by the last env.last() call of the loop.
    print("Terminated" if termination else "Truncated")
    print("\nWINNER: ", env.winner)
    for agent in env.possible_agents:
        print(f"\n{agent} Final reward: {env.rewards[agent]}")
        print(f"{agent} Cumulative reward: {env._cumulative_rewards[agent]}")
        print(
            f"{agent} Final remaining pieces: {[p.name for p in env.final_pieces[agent]]}"
        )
        print(
            f"{agent} Score: {env.score[agent]['total']:0.2f}, "
            f"Squares/turn: {env.score[agent]['squares_per_turn']:0.2f}, "
            f"Remaining pieces difference: {env.score[agent]['remaining_pieces']}, "
            f"Territory difference: {env.score[agent]['territory']}"
        )
finally:
    # Always close the pygame window, even if the game loop raised.
    pygame.quit()


Turn: 1 | (player_0) Legal pieces : [14], Legal moves total: 224, Remaining pieces: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
Turn: 1 | Action: 2772, Piece: 14, Position: (2, 2), 
Turn: 1 | Reward: 0, Cumulative reward: 0, 

SCORE (player_0): 0.00, Squares/turn: 0.00, Remaining pieces difference: 0, Territory difference: 0
SCORE (player_1): 0.00, Squares/turn: 0.00, Remaining pieces difference: 0, Territory difference: 0

Turn: 2 | (player_1) Legal pieces : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], Legal moves total: 2366, Remaining pieces: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Turn: 2 | Action: 1573, Piece: 8, Position: (4, 4), 
Turn: 2 | Reward: -1, Cumulative reward: -1, 

Turn: 3 | (player_0) Legal pieces : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], Legal moves total: 2034, Remaining pieces: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Turn: 3 | Action: 2505, Piece: 13, Position: (1, 5), 
Turn: 3 | Reward: 0, Cumulative reward: 0, 

SCORE (player_0):