# Notebook 02h: Hierarchical Transformer (Arm Classification + Conditional Regression)

## Changes from notebook 02g

1. **Explicit left/right classification**: a binary classification head predicts P(right arm). Its training labels are derived from the curvilinear distance `d` (threshold at d=0.5), so each half of the top corridor is assigned to the nearest arm.

2. **Conditional regression**: two specialized regression heads (one per arm) predict (x, y) plus an uncertainty. The final prediction is a **mixture weighted** by the classification probability.

3. **Hierarchical loss**: each regression head is trained only on the examples of its own arm (routing by ground-truth labels).

4. **Spike dropout (15%) + Gaussian noise (std=0.5)**: same data augmentation as notebook 02d.
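
To make item 4 concrete, here is a minimal pure-Python sketch of spike dropout on a padding mask. It mirrors the intent of the model's `_apply_spike_dropout` method (drop active spikes with probability `p`, but never drop every spike of a sequence); the function name `spike_dropout` and the list-of-bools representation are assumptions made for illustration, not the tensor code used for training.

```python
import random

def spike_dropout(padding_mask, p=0.15, rng=random):
    """padding_mask: list of bools, True = padded position.

    Returns a new mask where each active spike is additionally masked
    with probability p. If that would leave no active spike, the
    original mask is returned unchanged.
    """
    # Only active (non-padded) positions are candidates for dropping
    drops = [(not m) and rng.random() < p for m in padding_mask]
    remaining = sum(1 for m, d in zip(padding_mask, drops) if not m and not d)
    if remaining == 0:
        return list(padding_mask)
    return [m or d for m, d in zip(padding_mask, drops)]

mask = [False, False, False, True]  # 3 real spikes, 1 padded slot
print(spike_dropout(mask, p=0.15, rng=random.Random(0)))
```

The guard matters for the masked average pooling used later: a sequence with no active spike would otherwise be pooled over zero elements.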

**U geometry** (coordinates normalized to [0, 1]):
- Left arm: x ∈ [0, 0.3], y ∈ [0, 1]
- Right arm: x ∈ [0.7, 1.0], y ∈ [0, 1]
- Top corridor: x ∈ [0, 1], y ∈ [0.7, 1.0]
- Corridor width: 0.3
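
As a quick illustration of this geometry, the following sketch tests whether a point lies inside the U, taken as the union of the three rectangles listed above. The helper name `in_u_maze` is hypothetical and is not used elsewhere in the notebook.

```python
def in_u_maze(x, y):
    """Return True if (x, y) lies in one of the three rectangles of the U."""
    in_left = 0.0 <= x <= 0.3 and 0.0 <= y <= 1.0
    in_right = 0.7 <= x <= 1.0 and 0.0 <= y <= 1.0
    in_top = 0.0 <= x <= 1.0 and 0.7 <= y <= 1.0
    return in_left or in_right or in_top

print(in_u_maze(0.15, 0.2))  # point in the left arm
print(in_u_maze(0.5, 0.85))  # point in the top corridor
print(in_u_maze(0.5, 0.3))   # point in the gap between the two arms
```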

**Central skeleton**: 3 segments forming a U:
1. (0.15, 0) → (0.15, 0.85): left arm
2. (0.15, 0.85) → (0.85, 0.85): top corridor
3. (0.85, 0.85) → (0.85, 0): right arm

**Arm classification**: d < 0.5 → left arm, d ≥ 0.5 → right arm

**Combined loss**: `L = L_cls (BCE) + L_pos_left (Gaussian NLL, masked) + L_pos_right (Gaussian NLL, masked) + λ_d * L_curvilinear (MSE) + λ_feas * L_feasibility`

## 1. Imports and configuration

In [None]:
import pandas as pd
import numpy as np
import json
import os
import math
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings('ignore')

# Reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f'Device: {DEVICE}')

In [None]:
# --- Data loading ---
LOCAL_DIR = os.path.join(os.path.abspath('..'), 'data')

PARQUET_NAME = "M1199_PAG_stride4_win108_test.parquet"
JSON_NAME = "M1199_PAG.json"

PARQUET_FILE = os.path.join(LOCAL_DIR, PARQUET_NAME)
JSON_FILE = os.path.join(LOCAL_DIR, JSON_NAME)

if not os.path.exists(PARQUET_FILE):
    raise FileNotFoundError(
        f"Data not found in {LOCAL_DIR}/\n"
        f"Run first: python download_data.py"
    )

print(f"Loading from {LOCAL_DIR}/")
df = pd.read_parquet(PARQUET_FILE)
with open(JSON_FILE, "r") as f:
    params = json.load(f)

print(f"Shape: {df.shape}")

nGroups = params['nGroups']
nChannelsPerGroup = [params[f'group{g}']['nChannels'] for g in range(nGroups)]
print(f"nGroups={nGroups}, nChannelsPerGroup={nChannelsPerGroup}")

## 2. Loading and filtering

In [None]:
# speedMask filtering (keep only the examples where the animal is moving).
# dtype=bool guarantees boolean-mask indexing even if the mask is stored as 0/1.
speed_masks = np.array([x[0] for x in df['speedMask']], dtype=bool)
df_moving = df[speed_masks].reset_index(drop=True)
print(f'Moving examples: {len(df_moving)}')

## 3. U-maze geometry, curvilinear distance and arm labels

The U skeleton is defined by 3 segments:
1. Left arm: (0.15, 0) → (0.15, 0.85)
2. Top corridor: (0.15, 0.85) → (0.85, 0.85)
3. Right arm: (0.85, 0.85) → (0.85, 0)

Each position (x, y) is projected onto the nearest segment, and the cumulative curvilinear distance `d` ∈ [0, 1] is measured along the skeleton and normalized by its total length.

**Arm classification**: d < 0.5 → left (label 0), d ≥ 0.5 → right (label 1).
The threshold d=0.5 corresponds to the midpoint of the top corridor.
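
The projection rule can be checked by hand with a small standalone sketch. It uses only the standard library and the skeleton coordinates given above; the names `SEGMENTS`, `LENGTHS`, `TOTAL` and `curvilinear_d` are chosen for this illustration and are separate from the NumPy implementation in the next cell.

```python
import math

# Skeleton from the text: (x1, y1, x2, y2) for each segment
SEGMENTS = [
    (0.15, 0.0, 0.15, 0.85),   # left arm, bottom to top
    (0.15, 0.85, 0.85, 0.85),  # top corridor, left to right
    (0.85, 0.85, 0.85, 0.0),   # right arm, top to bottom
]
LENGTHS = [math.hypot(x2 - x1, y2 - y1) for x1, y1, x2, y2 in SEGMENTS]
TOTAL = sum(LENGTHS)  # 0.85 + 0.70 + 0.85 = 2.40

def curvilinear_d(px, py):
    """Normalized arc length of the closest skeleton point to (px, py)."""
    best_dist, best_d, offset = math.inf, 0.0, 0.0
    for (x1, y1, x2, y2), length in zip(SEGMENTS, LENGTHS):
        dx, dy = x2 - x1, y2 - y1
        # Projection parameter along the segment, clamped to [0, 1]
        t = ((px - x1) * dx + (py - y1) * dy) / (length ** 2)
        t = min(1.0, max(0.0, t))
        dist = math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))
        if dist < best_dist:
            best_dist, best_d = dist, (offset + t * length) / TOTAL
        offset += length
    return best_d

print(f"{curvilinear_d(0.15, 0.40):.3f}")  # left arm:  0.40 / 2.40
print(f"{curvilinear_d(0.50, 0.85):.3f}")  # middle of the top corridor: 1.20 / 2.40
print(f"{curvilinear_d(0.85, 0.40):.3f}")  # right arm: 2.00 / 2.40
```

The middle value shows why d=0.5 is the natural threshold: it is reached exactly at x=0.5 on the top corridor, so each half of the corridor is grouped with the arm on its side.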

In [None]:
# --- U skeleton definition ---
SKELETON_SEGMENTS = np.array([
    [0.15, 0.0, 0.15, 0.85],   # Segment 1: left arm (bottom → top)
    [0.15, 0.85, 0.85, 0.85],  # Segment 2: top corridor (left → right)
    [0.85, 0.85, 0.85, 0.0],   # Segment 3: right arm (top → bottom)
])

CORRIDOR_HALF_WIDTH = 0.15

# Segment lengths and their cumulative sums
SEGMENT_LENGTHS = np.array([
    np.sqrt((s[2]-s[0])**2 + (s[3]-s[1])**2) for s in SKELETON_SEGMENTS
])
TOTAL_LENGTH = SEGMENT_LENGTHS.sum()  # 0.85 + 0.70 + 0.85 = 2.40
CUMULATIVE_LENGTHS = np.concatenate([[0], np.cumsum(SEGMENT_LENGTHS)])

print(f'Segment lengths: {SEGMENT_LENGTHS}')
print(f'Total length of the U: {TOTAL_LENGTH:.2f}')
print(f'Cumulative lengths: {CUMULATIVE_LENGTHS}')


def project_point_on_segment(px, py, x1, y1, x2, y2):
    """Project a point (px, py) onto the segment [(x1, y1), (x2, y2)].
    
    Returns:
        t: projection parameter clamped to [0, 1]
        dist: distance from the point to its projection
        proj_x, proj_y: coordinates of the projection
    """
    dx, dy = x2 - x1, y2 - y1
    seg_len_sq = dx**2 + dy**2
    if seg_len_sq < 1e-12:
        return 0.0, np.sqrt((px - x1)**2 + (py - y1)**2), x1, y1
    t = ((px - x1) * dx + (py - y1) * dy) / seg_len_sq
    t = np.clip(t, 0.0, 1.0)
    proj_x = x1 + t * dx
    proj_y = y1 + t * dy
    dist = np.sqrt((px - proj_x)**2 + (py - proj_y)**2)
    return t, dist, proj_x, proj_y


def compute_curvilinear_distance(x, y):
    """Compute the normalized curvilinear distance d ∈ [0, 1] along the U."""
    best_dist = np.inf
    best_d = 0.0
    
    for i, (x1, y1, x2, y2) in enumerate(SKELETON_SEGMENTS):
        t, dist, _, _ = project_point_on_segment(x, y, x1, y1, x2, y2)
        if dist < best_dist:
            best_dist = dist
            best_d = (CUMULATIVE_LENGTHS[i] + t * SEGMENT_LENGTHS[i]) / TOTAL_LENGTH
    
    return best_d


def compute_distance_to_skeleton(x, y):
    """Minimum distance from the point (x, y) to the U skeleton."""
    best_dist = np.inf
    for x1, y1, x2, y2 in SKELETON_SEGMENTS:
        _, dist, _, _ = project_point_on_segment(x, y, x1, y1, x2, y2)
        best_dist = min(best_dist, dist)
    return best_dist


# --- Compute d for all examples ---
positions = np.array([[x[0], x[1]] for x in df_moving['pos']], dtype=np.float32)
curvilinear_d = np.array([
    compute_curvilinear_distance(x, y) for x, y in positions
], dtype=np.float32)

# --- Arm labels: 0 = left (d < 0.5), 1 = right (d >= 0.5) ---
ARM_THRESHOLD = 0.5
arm_labels = (curvilinear_d >= ARM_THRESHOLD).astype(np.float32)

print(f'Curvilinear d: min={curvilinear_d.min():.4f}, max={curvilinear_d.max():.4f}, mean={curvilinear_d.mean():.4f}')
print(f'\nArm classification (threshold d={ARM_THRESHOLD}):')
print(f'  Left arm  (d < {ARM_THRESHOLD}) : {(arm_labels == 0).sum()} ({(arm_labels == 0).mean():.1%})')
print(f'  Right arm (d >= {ARM_THRESHOLD}) : {(arm_labels == 1).sum()} ({(arm_labels == 1).mean():.1%})')

# Sanity check: distance to the skeleton
dist_to_skel = np.array([compute_distance_to_skeleton(x, y) for x, y in positions])
print(f'\nDistance to skeleton: mean={dist_to_skel.mean():.4f}, max={dist_to_skel.max():.4f}')
print(f'  % inside the corridor (dist < {CORRIDOR_HALF_WIDTH}) : {(dist_to_skel < CORRIDOR_HALF_WIDTH).mean():.1%}')

In [None]:
# --- Visualization: d + arm classification ---
fig, axes = plt.subplots(1, 3, figsize=(21, 6))

# 1. U skeleton over the positions
axes[0].scatter(positions[:, 0], positions[:, 1], c='lightgray', s=1, alpha=0.3)
for x1, y1, x2, y2 in SKELETON_SEGMENTS:
    axes[0].plot([x1, x2], [y1, y2], 'r-', linewidth=3)
axes[0].set_xlabel('X'); axes[0].set_ylabel('Y')
axes[0].set_title('U skeleton over the positions')
axes[0].set_aspect('equal')

# 2. Positions colored by d
sc = axes[1].scatter(positions[:, 0], positions[:, 1], c=curvilinear_d, s=1, alpha=0.5, cmap='viridis')
plt.colorbar(sc, ax=axes[1], label='d (normalized curvilinear distance)')
axes[1].set_xlabel('X'); axes[1].set_ylabel('Y')
axes[1].set_title('Curvilinear distance d along the U')
axes[1].set_aspect('equal')

# 3. Arm classification: left (blue) vs right (red)
left_mask = arm_labels == 0
right_mask = arm_labels == 1
axes[2].scatter(positions[left_mask, 0], positions[left_mask, 1], c='blue', s=1, alpha=0.3, label='Left (d<0.5)')
axes[2].scatter(positions[right_mask, 0], positions[right_mask, 1], c='red', s=1, alpha=0.3, label='Right (d≥0.5)')
axes[2].axvline(x=0.5, color='gray', linestyle='--', alpha=0.5)
axes[2].set_xlabel('X'); axes[2].set_ylabel('Y')
axes[2].set_title(f'Arm classification (threshold d={ARM_THRESHOLD})')
axes[2].legend(markerscale=10)
axes[2].set_aspect('equal')

plt.tight_layout()
plt.show()

# Histogram of d
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(curvilinear_d, bins=100, alpha=0.7, edgecolor='black', linewidth=0.3)
ax.axvline(x=ARM_THRESHOLD, color='red', linestyle='--', linewidth=2, label=f'Threshold d={ARM_THRESHOLD}')
ax.set_xlabel('d (curvilinear distance)'); ax.set_ylabel('Number of examples')
ax.set_title('Distribution of the curvilinear distance')
ax.legend()
plt.tight_layout()
plt.show()

## 4. Preprocessing: reconstructing the chronological sequence

In [None]:
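def reconstruct_sequence(row, nGroups, nChannelsPerGroup, max_seq_len=128):
    """
    Reconstruct the chronological spike sequence of one example.

    Returns a list of (waveform, shank) pairs and the matching list of shank ids.
    Entries whose 1-based index is 0 or out of range are skipped.
    """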
    groups = row['groups']
    length = min(len(groups), max_seq_len)
    
    waveforms = {}
    for g in range(nGroups):
        nCh = nChannelsPerGroup[g]
        raw = row[f'group{g}']
        waveforms[g] = raw.reshape(-1, nCh, 32)
    
    seq_waveforms = []
    seq_shank_ids = []
    
    for t in range(length):
        g = int(groups[t])
        idx = int(row[f'indices{g}'][t])
        if idx > 0 and idx <= waveforms[g].shape[0]:
            seq_waveforms.append((waveforms[g][idx - 1], g))
            seq_shank_ids.append(g)
    
    return seq_waveforms, seq_shank_ids

# Quick test
wf, sids = reconstruct_sequence(df_moving.iloc[0], nGroups, nChannelsPerGroup)
print(f'First example: {len(wf)} real spikes in the sequence')
print(f'Shanks used: {set(sids)}')
print(f'First spike: shank={wf[0][1]}, shape={wf[0][0].shape}')

## 5. PyTorch Dataset (with arm_label)

In [None]:
MAX_SEQ_LEN = 128
MAX_CHANNELS = max(nChannelsPerGroup)  # 6

class SpikeSequenceDataset(Dataset):
    def __init__(self, dataframe, nGroups, nChannelsPerGroup, curvilinear_d, arm_labels, max_seq_len=MAX_SEQ_LEN):
        self.df = dataframe
        self.nGroups = nGroups
        self.nChannelsPerGroup = nChannelsPerGroup
        self.max_seq_len = max_seq_len
        
        # Pre-extract the targets
        self.targets = np.array([[x[0], x[1]] for x in dataframe['pos']], dtype=np.float32)
        self.curvilinear_d = curvilinear_d.astype(np.float32)
        self.arm_labels = arm_labels.astype(np.float32)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        seq, shank_ids = reconstruct_sequence(row, self.nGroups, self.nChannelsPerGroup, self.max_seq_len)
        
        seq_len = len(seq)
        if seq_len == 0:
            seq_len = 1
            waveforms = np.zeros((1, MAX_CHANNELS, 32), dtype=np.float32)
            shank_ids_arr = np.array([0], dtype=np.int64)
        else:
            waveforms = np.zeros((seq_len, MAX_CHANNELS, 32), dtype=np.float32)
            shank_ids_arr = np.array(shank_ids, dtype=np.int64)
            for t, (wf, g) in enumerate(seq):
                nCh = wf.shape[0]
                waveforms[t, :nCh, :] = wf
        
        target = self.targets[idx]
        d = self.curvilinear_d[idx]
        arm = self.arm_labels[idx]
        return {
            'waveforms': torch.from_numpy(waveforms),
            'shank_ids': torch.from_numpy(shank_ids_arr),
            'seq_len': seq_len,
            'target': torch.from_numpy(target),
            'd': torch.tensor(d, dtype=torch.float32),
            'arm': torch.tensor(arm, dtype=torch.float32)
        }


def collate_fn(batch):
    """Collate with dynamic padding (mask is True on padded positions)."""
    max_len = max(item['seq_len'] for item in batch)
    batch_size = len(batch)
    
    waveforms = torch.zeros(batch_size, max_len, MAX_CHANNELS, 32)
    shank_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
    mask = torch.ones(batch_size, max_len, dtype=torch.bool)
    targets = torch.stack([item['target'] for item in batch])
    d_targets = torch.stack([item['d'] for item in batch])
    arm_targets = torch.stack([item['arm'] for item in batch])
    
    for i, item in enumerate(batch):
        sl = item['seq_len']
        waveforms[i, :sl] = item['waveforms']
        shank_ids[i, :sl] = item['shank_ids']
        mask[i, :sl] = False
    
    return {
        'waveforms': waveforms,
        'shank_ids': shank_ids,
        'mask': mask,
        'targets': targets,
        'd_targets': d_targets,
        'arm_targets': arm_targets
    }

print('Dataset and collate_fn defined (with arm_label + curvilinear d).')

## 6. Feasibility loss (penalty for predictions outside the maze)

Identical to notebook 02g.
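
The 02g definition is not reproduced in this section. As an illustration only, a penalty of this kind can be sketched as the squared distance by which a predicted point exceeds the corridor half-width (0.15) around the skeleton. The names `distance_to_skeleton` and `feasibility_penalty`, as well as the quadratic form, are assumptions and not the 02g implementation; the notebook's `feas_loss` operates on PyTorch tensors, whereas this scalar version is only meant for checking values by hand.

```python
import math

# Skeleton and half-width taken from the geometry section
SEGMENTS = [
    (0.15, 0.0, 0.15, 0.85),
    (0.15, 0.85, 0.85, 0.85),
    (0.85, 0.85, 0.85, 0.0),
]
HALF_WIDTH = 0.15

def distance_to_skeleton(px, py):
    """Minimum distance from (px, py) to any skeleton segment."""
    best = math.inf
    for x1, y1, x2, y2 in SEGMENTS:
        dx, dy = x2 - x1, y2 - y1
        t = ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy)
        t = min(1.0, max(0.0, t))
        best = min(best, math.hypot(px - (x1 + t * dx), py - (y1 + t * dy)))
    return best

def feasibility_penalty(points):
    """Assumed form: mean squared excess distance beyond the half-width."""
    excess = [max(0.0, distance_to_skeleton(x, y) - HALF_WIDTH) for x, y in points]
    return sum(e * e for e in excess) / len(excess)

print(feasibility_penalty([(0.15, 0.5)]))  # on the skeleton: no penalty
print(feasibility_penalty([(0.5, 0.3)]))   # in the gap between the arms: positive penalty
```

A penalty that is zero inside the corridor and grows smoothly outside it leaves in-maze predictions untouched while pulling out-of-maze predictions back toward the nearest corridor.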

In [None]:
def train_epoch(model, loader, optimizer, scheduler, criterion_bce, criterion_nll, criterion_d, feas_loss, device):
    model.train()
    total_loss = 0
    total_cls_loss = 0
    total_pos_loss = 0
    total_d_loss = 0
    total_feas_loss = 0
    total_cls_correct = 0
    total_samples = 0
    n_batches = 0
    
    for batch in loader:
        wf = batch['waveforms'].to(device)
        sid = batch['shank_ids'].to(device)
        mask = batch['mask'].to(device)
        targets = batch['targets'].to(device)
        d_targets = batch['d_targets'].to(device)
        arm_targets = batch['arm_targets'].to(device)
        
        optimizer.zero_grad()
        cls_logit, mu_left, sigma_left, mu_right, sigma_right, d_pred = model(wf, sid, mask)
        
        # --- Classification loss (BCE) ---
        loss_cls = criterion_bce(cls_logit.squeeze(-1), arm_targets)
        
        # --- Left and right position losses (Gaussian NLL, masked) ---
        left_mask = (arm_targets == 0)
        right_mask = (arm_targets == 1)
        
        loss_pos = torch.tensor(0.0, device=device)
        if left_mask.any():
            loss_left = criterion_nll(
                mu_left[left_mask], targets[left_mask], sigma_left[left_mask] ** 2
            )
            loss_pos = loss_pos + loss_left
        
        if right_mask.any():
            loss_right = criterion_nll(
                mu_right[right_mask], targets[right_mask], sigma_right[right_mask] ** 2
            )
            loss_pos = loss_pos + loss_right
        
        # --- Curvilinear loss (MSE) ---
        loss_d = criterion_d(d_pred.squeeze(-1), d_targets)
        
        # --- Feasibility loss (on the combined prediction) ---
        p_right = torch.sigmoid(cls_logit)  # (batch, 1), keeps the gradient
        p_left = 1.0 - p_right
        mu_combined = p_right * mu_right + p_left * mu_left
        loss_feas = feas_loss(mu_combined)
        
        # --- Total loss ---
        loss = loss_cls + loss_pos + LAMBDA_D * loss_d + LAMBDA_FEAS * loss_feas
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_pos_loss += loss_pos.item()
        total_d_loss += loss_d.item()
        total_feas_loss += loss_feas.item()
        
        # Classification accuracy
        with torch.no_grad():
            preds_cls = (torch.sigmoid(cls_logit.squeeze(-1)) >= 0.5).float()
            total_cls_correct += (preds_cls == arm_targets).sum().item()
            total_samples += len(arm_targets)
        
        n_batches += 1
    
    cls_acc = total_cls_correct / total_samples
    return (total_loss / n_batches, total_cls_loss / n_batches, total_pos_loss / n_batches,
            total_d_loss / n_batches, total_feas_loss / n_batches, cls_acc)


@torch.no_grad()
def eval_epoch(model, loader, criterion_bce, criterion_nll, criterion_d, feas_loss, device):
    model.eval()
    total_loss = 0
    total_cls_loss = 0
    total_pos_loss = 0
    total_d_loss = 0
    total_feas_loss = 0
    total_cls_correct = 0
    total_samples = 0
    n_batches = 0
    all_mu = []
    all_sigma = []
    all_p_right = []
    all_d_pred = []
    all_targets = []
    all_d_targets = []
    all_arm_targets = []
    
    for batch in loader:
        wf = batch['waveforms'].to(device)
        sid = batch['shank_ids'].to(device)
        mask = batch['mask'].to(device)
        targets = batch['targets'].to(device)
        d_targets = batch['d_targets'].to(device)
        arm_targets = batch['arm_targets'].to(device)
        
        # Combined prediction
        mu, sigma, p_right, d_pred = model.predict(wf, sid, mask)
        
        # Recompute the raw outputs for the loss
        cls_logit, mu_left, sigma_left, mu_right, sigma_right, _ = model(wf, sid, mask)
        
        # Classification loss
        loss_cls = criterion_bce(cls_logit.squeeze(-1), arm_targets)
        
        # Position loss (masked per arm)
        left_mask = (arm_targets == 0)
        right_mask = (arm_targets == 1)
        
        loss_pos = torch.tensor(0.0, device=device)
        if left_mask.any():
            loss_pos = loss_pos + criterion_nll(
                mu_left[left_mask], targets[left_mask], sigma_left[left_mask] ** 2
            )
        if right_mask.any():
            loss_pos = loss_pos + criterion_nll(
                mu_right[right_mask], targets[right_mask], sigma_right[right_mask] ** 2
            )
        
        loss_d = criterion_d(d_pred.squeeze(-1), d_targets)
        loss_feas = feas_loss(mu)
        loss = loss_cls + loss_pos + LAMBDA_D * loss_d + LAMBDA_FEAS * loss_feas
        
        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_pos_loss += loss_pos.item()
        total_d_loss += loss_d.item()
        total_feas_loss += loss_feas.item()
        
        preds_cls = (p_right.squeeze(-1) >= 0.5).float()
        total_cls_correct += (preds_cls == arm_targets).sum().item()
        total_samples += len(arm_targets)
        
        n_batches += 1
        all_mu.append(mu.cpu().numpy())
        all_sigma.append(sigma.cpu().numpy())
        all_p_right.append(p_right.cpu().numpy())
        all_d_pred.append(d_pred.cpu().numpy())
        all_targets.append(targets.cpu().numpy())
        all_d_targets.append(d_targets.cpu().numpy())
        all_arm_targets.append(arm_targets.cpu().numpy())
    
    cls_acc = total_cls_correct / total_samples
    return (total_loss / n_batches, total_cls_loss / n_batches, total_pos_loss / n_batches,
            total_d_loss / n_batches, total_feas_loss / n_batches, cls_acc,
            np.concatenate(all_mu), np.concatenate(all_sigma),
            np.concatenate(all_p_right), np.concatenate(all_d_pred),
            np.concatenate(all_targets), np.concatenate(all_d_targets),
            np.concatenate(all_arm_targets))

## 7. Hierarchical model architecture

**Shared backbone**: identical to 02g (per-shank CNN encoders → Transformer → masked average pooling).

**Specialized heads**:
- `cls_head`: binary classification (left/right) → P(right)
- `mu_head_left` + `log_sigma_head_left`: position regression + uncertainty for the left arm
- `mu_head_right` + `log_sigma_head_right`: same for the right arm
- `d_head`: prediction of d (auxiliary task, retained)

**Combination at inference**:
```
p = P(right)
mu = p * mu_right + (1-p) * mu_left
sigma² = p * (sigma_right² + mu_right²) + (1-p) * (sigma_left² + mu_left²) - mu²
```
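
The variance formula above is the second-moment form of the law of total variance for a two-component mixture. The short sketch below checks numerically, on one coordinate, that it matches the "within + between" decomposition; the function names and the numeric values are illustrative assumptions.

```python
def mixture_moments(p, mu_left, sigma_left, mu_right, sigma_right):
    """Mean and variance of the mixture, using the second-moment form."""
    mu = p * mu_right + (1 - p) * mu_left
    second_moment = (p * (sigma_right ** 2 + mu_right ** 2)
                     + (1 - p) * (sigma_left ** 2 + mu_left ** 2))
    return mu, second_moment - mu ** 2

def total_variance(p, mu_left, sigma_left, mu_right, sigma_right):
    """Same variance written as E[Var] + Var[E]."""
    mu = p * mu_right + (1 - p) * mu_left
    within = p * sigma_right ** 2 + (1 - p) * sigma_left ** 2
    between = p * (mu_right - mu) ** 2 + (1 - p) * (mu_left - mu) ** 2
    return within + between

# x coordinate, classifier fully undecided between the two arms
mu, var = mixture_moments(p=0.5, mu_left=0.15, sigma_left=0.05,
                          mu_right=0.85, sigma_right=0.05)
print(f"mu={mu:.3f}, var={var:.4f}")
print(f"total variance form: {total_variance(0.5, 0.15, 0.05, 0.85, 0.05):.4f}")
```

With p=0.5 the blended mean lands at x=0.5, halfway between the arms, and the variance is dominated by the between-arm term. The combined sigma therefore reflects classification uncertainty as well as the per-head uncertainty.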

In [None]:
class SpikeEncoder(nn.Module):
    """Encode a waveform (MAX_CH, 32) into a vector of dimension embed_dim."""
    
    def __init__(self, n_channels, embed_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(n_channels, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, embed_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
    
    def forward(self, x):
        return self.conv(x).squeeze(-1)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""
    
    def __init__(self, embed_dim, max_len=256):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class SpikeTransformerHierarchical(nn.Module):
    """Hierarchical Transformer: arm classification + conditional regression.
    
    The Transformer backbone is shared. Two specialized regression heads
    (left and right) are combined through the classification probability.
    """
    
    def __init__(self, nGroups, nChannelsPerGroup, embed_dim=64, nhead=4, 
                 num_layers=2, dropout=0.2, spike_dropout=0.15, noise_std=0.5,
                 max_channels=MAX_CHANNELS):
        super().__init__()
        self.nGroups = nGroups
        self.embed_dim = embed_dim
        self.max_channels = max_channels
        self.spike_dropout = spike_dropout
        self.noise_std = noise_std
        
        # One encoder per shank
        self.spike_encoders = nn.ModuleList([
            SpikeEncoder(max_channels, embed_dim) for _ in range(nGroups)
        ])
        
        # Shank embedding
        self.shank_embedding = nn.Embedding(nGroups, embed_dim)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(embed_dim)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=nhead, dim_feedforward=embed_dim * 4,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )
        
        # --- Classification head: P(right arm) ---
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 1)
        )
        
        # --- Left-arm head: position + uncertainty ---
        self.mu_head_left = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 2)
        )
        self.log_sigma_head_left = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 2)
        )
        
        # --- Right-arm head: position + uncertainty ---
        self.mu_head_right = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 2)
        )
        self.log_sigma_head_right = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 2)
        )
        
        # --- Curvilinear distance head ---
        self.d_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 1),
            nn.Sigmoid()
        )
    
    def _apply_spike_dropout(self, mask):
        if not self.training or self.spike_dropout <= 0:
            return mask
        drop_mask = torch.rand_like(mask.float()) < self.spike_dropout
        active = ~mask
        new_drops = drop_mask & active
        remaining = active & ~new_drops
        n_remaining = remaining.sum(dim=1)
        all_dropped = n_remaining == 0
        if all_dropped.any():
            new_drops[all_dropped] = False
        return mask | new_drops
    
    def _apply_waveform_noise(self, waveforms):
        if not self.training or self.noise_std <= 0:
            return waveforms
        noise = torch.randn_like(waveforms) * self.noise_std
        return waveforms + noise
    
    def forward(self, waveforms, shank_ids, mask):
        """
        Returns:
            cls_logit: (batch, 1) - classification logit for P(right)
            mu_left: (batch, 2) - predicted position, left arm
            sigma_left: (batch, 2) - uncertainty, left arm
            mu_right: (batch, 2) - predicted position, right arm
            sigma_right: (batch, 2) - uncertainty, right arm
            d_pred: (batch, 1) - predicted curvilinear distance
        """
        batch_size, seq_len = waveforms.shape[:2]
        
        # Data augmentation
        mask = self._apply_spike_dropout(mask)
        waveforms = self._apply_waveform_noise(waveforms)
        
        # Encode each spike with the encoder of its shank
        embeddings = torch.zeros(batch_size, seq_len, self.embed_dim, device=waveforms.device)
        for g in range(self.nGroups):
            group_mask = (shank_ids == g) & (~mask)
            if group_mask.any():
                group_wf = waveforms[group_mask]
                group_emb = self.spike_encoders[g](group_wf)
                embeddings[group_mask] = group_emb
        
        # Shank embedding + Positional encoding
        shank_emb = self.shank_embedding(shank_ids)
        embeddings = embeddings + shank_emb
        embeddings = self.pos_encoding(embeddings)
        
        # Transformer
        encoded = self.transformer(embeddings, src_key_padding_mask=mask)
        
        # Masked average pooling
        active_mask = (~mask).unsqueeze(-1).float()
        pooled = (encoded * active_mask).sum(dim=1) / (active_mask.sum(dim=1) + 1e-8)
        
        # --- Outputs ---
        # Classification
        cls_logit = self.cls_head(pooled)  # (batch, 1) - raw logit
        
        # Left arm
        mu_left = self.mu_head_left(pooled)
        sigma_left = torch.exp(self.log_sigma_head_left(pooled))
        
        # Right arm
        mu_right = self.mu_head_right(pooled)
        sigma_right = torch.exp(self.log_sigma_head_right(pooled))
        
        # Curvilinear distance
        d_pred = self.d_head(pooled)
        
        return cls_logit, mu_left, sigma_left, mu_right, sigma_right, d_pred
    
    def predict(self, waveforms, shank_ids, mask):
        """Combined prediction: mixture weighted by P(right).
        
        Returns:
            mu: (batch, 2) - combined position
            sigma: (batch, 2) - combined uncertainty
            p_right: (batch, 1) - probability of the right arm
            d_pred: (batch, 1) - curvilinear distance
        """
        cls_logit, mu_left, sigma_left, mu_right, sigma_right, d_pred = self.forward(
            waveforms, shank_ids, mask
        )
        
        p_right = torch.sigmoid(cls_logit)  # (batch, 1)
        p_left = 1.0 - p_right
        
        # Weighted mixture of the means
        mu = p_right * mu_right + p_left * mu_left
        
        # Total variance of the mixture (law of total variance):
        # Var = E[Var] + Var[E] = Σ_k p_k * sigma_k² + Σ_k p_k * (mu_k - mu)²,
        # computed below in the equivalent form Σ_k p_k * (sigma_k² + mu_k²) - mu².
        var_combined = (
            p_right * (sigma_right ** 2 + mu_right ** 2)
            + p_left * (sigma_left ** 2 + mu_left ** 2)
            - mu ** 2
        )
        sigma = torch.sqrt(var_combined.clamp(min=1e-8))
        
        return mu, sigma, p_right, d_pred


# Quick test
SPIKE_DROPOUT = 0.15
NOISE_STD = 0.5
model = SpikeTransformerHierarchical(nGroups, nChannelsPerGroup, embed_dim=64, nhead=4, num_layers=2,
                                      spike_dropout=SPIKE_DROPOUT, noise_std=NOISE_STD)
n_params = sum(p.numel() for p in model.parameters())
print(f'Model created: {n_params:,} parameters')
print(f'Heads: classification (left/right) + 2 conditional regressions + curvilinear d')
print(f'\nComparison with 02g:')
# Estimate the 02g size: remove every left/right/cls head, then add back one mu head and one sigma head
n_params_02g = n_params - sum(
    p.numel() for name, p in model.named_parameters() 
    if 'left' in name or 'right' in name or 'cls_head' in name
) + sum(
    p.numel() for name, p in model.named_parameters() 
    if 'mu_head_left' in name  # proxy: 02g has a single mu head and a single sigma head
) + sum(
    p.numel() for name, p in model.named_parameters() 
    if 'log_sigma_head_left' in name
)
print(f'  02g: ~{n_params_02g:,} parameters (1 position head)')
print(f'  02h: {n_params:,} parameters (+{n_params - n_params_02g:,} for the additional heads)')

## 8. Split and DataLoaders

In [None]:
from sklearn.model_selection import KFold

# Chronological 90/10 split (the last 10% of the examples form the test set)
split_idx = int(len(df_moving) * 0.9)
df_train_full = df_moving.iloc[:split_idx].reset_index(drop=True)
df_test = df_moving.iloc[split_idx:].reset_index(drop=True)
d_train_full = curvilinear_d[:split_idx]
d_test = curvilinear_d[split_idx:]
arm_train_full = arm_labels[:split_idx]
arm_test = arm_labels[split_idx:]

print(f'Train (full) : {len(df_train_full)} examples')
print(f'Test         : {len(df_test)} examples')
print(f'\nArm distribution (train) : left={int((arm_train_full == 0).sum())}, right={int((arm_train_full == 1).sum())}')
print(f'Arm distribution (test)  : left={int((arm_test == 0).sum())}, right={int((arm_test == 1).sum())}')

# KFold
N_FOLDS = 5
kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=41)

BATCH_SIZE = 64

# Test loader
test_dataset = SpikeSequenceDataset(df_test, nGroups, nChannelsPerGroup, d_test, arm_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_fn, num_workers=0)

for fold, (train_idx, val_idx) in enumerate(kf.split(df_train_full)):
    print(f'  Fold {fold+1}: train={len(train_idx)}, val={len(val_idx)}')

print(f'\nTest: {len(test_dataset)} examples, {len(test_loader)} batches')

## 9. Training (5-Fold Cross-Validation)

**Combined loss**:
```
L = L_cls (BCE on the arm classification)
  + L_pos_left (Gaussian NLL, masked to the left arm)
  + L_pos_right (Gaussian NLL, masked to the right arm)
  + λ_d * L_curvilinear (MSE on d)
  + λ_feas * L_feasibility (penalty outside the maze)
```
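
The two masked position terms can be illustrated with a scalar sketch. It assumes that `criterion_nll` behaves like a Gaussian negative log-likelihood averaged over the selected samples, `0.5 * (log(var) + (target - mu)² / var)` with the constant term dropped; the criterion itself is not defined in this section, so treat that as an assumption. The function names are hypothetical and the example uses one coordinate per sample for brevity.

```python
import math

def gaussian_nll(mu, target, var):
    """Per-sample Gaussian NLL, constant term omitted (assumed form)."""
    return 0.5 * (math.log(var) + (target - mu) ** 2 / var)

def routed_position_loss(mu_left, var_left, mu_right, var_right, targets, arms):
    """Sum of the per-arm mean NLLs; each head only sees samples of its own arm."""
    loss = 0.0
    for arm, mus, variances in ((0, mu_left, var_left), (1, mu_right, var_right)):
        terms = [gaussian_nll(m, t, v)
                 for m, t, v, a in zip(mus, targets, variances, arms) if a == arm]
        if terms:  # an arm with no sample in the batch contributes nothing
            loss += sum(terms) / len(terms)
    return loss

targets = [0.10, 0.90]
arms = [0, 1]                      # sample 0 is on the left arm, sample 1 on the right
mu_left, mu_right = [0.10, 0.50], [0.50, 0.90]
ones = [1.0, 1.0]

base = routed_position_loss(mu_left, ones, mu_right, ones, targets, arms)
# Changing the left head's output on a right-arm sample leaves the loss unchanged
moved = routed_position_loss([0.10, -3.0], ones, mu_right, ones, targets, arms)
print(base, moved)
```

Because each arm's term is averaged over its own samples before the two are summed, both arms carry equal weight in the position loss regardless of how many samples of each arm a batch contains.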

In [None]:
# Hyperparameters
EMBED_DIM = 64
NHEAD = 4
NUM_LAYERS = 2
DROPOUT = 0.2
SPIKE_DROPOUT = 0.15
NOISE_STD = 0.5
LR = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 30
PATIENCE = 7

# Weights of the auxiliary losses
LAMBDA_D = 1.0       # weight of the curvilinear loss
LAMBDA_FEAS = 10.0   # weight of the out-of-maze penalty

print(f'Hyperparameters: embed_dim={EMBED_DIM}, nhead={NHEAD}, layers={NUM_LAYERS}, dropout={DROPOUT}')
print(f'Data augmentation: spike dropout={SPIKE_DROPOUT:.0%}, gaussian noise std={NOISE_STD}')
print(f'Loss: BCE(cls) + GaussianNLL(left, masked) + GaussianNLL(right, masked) + {LAMBDA_D}*MSE(d) + {LAMBDA_FEAS}*Feasibility')
print(f'Training: {EPOCHS} epochs max, patience={PATIENCE}, LR={LR}')
print(f'Device: {DEVICE}')

In [None]:
def train_epoch(model, loader, optimizer, scheduler, criterion_bce, criterion_nll, criterion_d, feas_loss, device):
    model.train()
    total_loss = 0
    total_cls_loss = 0
    total_pos_loss = 0
    total_d_loss = 0
    total_feas_loss = 0
    total_cls_correct = 0
    total_samples = 0
    n_batches = 0
    
    for batch in loader:
        wf = batch['waveforms'].to(device)
        sid = batch['shank_ids'].to(device)
        mask = batch['mask'].to(device)
        targets = batch['targets'].to(device)
        d_targets = batch['d_targets'].to(device)
        arm_targets = batch['arm_targets'].to(device)
        
        optimizer.zero_grad()
        cls_logit, mu_left, sigma_left, mu_right, sigma_right, d_pred = model(wf, sid, mask)
        
        # --- Classification loss (BCE) ---
        loss_cls = criterion_bce(cls_logit.squeeze(-1), arm_targets)
        
        # --- Left and right position losses (Gaussian NLL, masked) ---
        left_mask = (arm_targets == 0)
        right_mask = (arm_targets == 1)
        
        loss_pos = torch.tensor(0.0, device=device)
        if left_mask.any():
            loss_left = criterion_nll(
                mu_left[left_mask], targets[left_mask], sigma_left[left_mask] ** 2
            )
            loss_pos = loss_pos + loss_left
        
        if right_mask.any():
            loss_right = criterion_nll(
                mu_right[right_mask], targets[right_mask], sigma_right[right_mask] ** 2
            )
            loss_pos = loss_pos + loss_right
        
        # --- Curvilinear loss (MSE) ---
        loss_d = criterion_d(d_pred.squeeze(-1), d_targets)
        
        # --- Feasibility loss (on the combined prediction) ---
        # Detach the mixing weight so this term only sends gradients to mu_left / mu_right
        p_right = torch.sigmoid(cls_logit).detach()
        mu_combined = p_right * mu_right + (1.0 - p_right) * mu_left
        loss_feas = feas_loss(mu_combined)
        
        # --- Total loss ---
        loss = loss_cls + loss_pos + LAMBDA_D * loss_d + LAMBDA_FEAS * loss_feas
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_pos_loss += loss_pos.item()
        total_d_loss += loss_d.item()
        total_feas_loss += loss_feas.item()
        
        # Classification accuracy
        with torch.no_grad():
            preds_cls = (torch.sigmoid(cls_logit.squeeze(-1)) >= 0.5).float()
            total_cls_correct += (preds_cls == arm_targets).sum().item()
            total_samples += len(arm_targets)
        
        n_batches += 1
    
    cls_acc = total_cls_correct / total_samples
    return (total_loss / n_batches, total_cls_loss / n_batches, total_pos_loss / n_batches,
            total_d_loss / n_batches, total_feas_loss / n_batches, cls_acc)


@torch.no_grad()
def eval_epoch(model, loader, criterion_bce, criterion_nll, criterion_d, feas_loss, device):
    model.eval()
    total_loss = 0
    total_cls_loss = 0
    total_pos_loss = 0
    total_d_loss = 0
    total_feas_loss = 0
    total_cls_correct = 0
    total_samples = 0
    n_batches = 0
    all_mu = []
    all_sigma = []
    all_p_right = []
    all_d_pred = []
    all_targets = []
    all_d_targets = []
    all_arm_targets = []
    
    for batch in loader:
        wf = batch['waveforms'].to(device)
        sid = batch['shank_ids'].to(device)
        mask = batch['mask'].to(device)
        targets = batch['targets'].to(device)
        d_targets = batch['d_targets'].to(device)
        arm_targets = batch['arm_targets'].to(device)
        
        # Combined prediction
        mu, sigma, p_right, d_pred = model.predict(wf, sid, mask)
        
        # Recompute the raw head outputs for the loss
        cls_logit, mu_left, sigma_left, mu_right, sigma_right, _ = model(wf, sid, mask)
        
        # Classification loss
        loss_cls = criterion_bce(cls_logit.squeeze(-1), arm_targets)
        
        # Position loss (masked per arm)
        left_mask = (arm_targets == 0)
        right_mask = (arm_targets == 1)
        
        loss_pos = torch.tensor(0.0, device=device)
        if left_mask.any():
            loss_pos = loss_pos + criterion_nll(
                mu_left[left_mask], targets[left_mask], sigma_left[left_mask] ** 2
            )
        if right_mask.any():
            loss_pos = loss_pos + criterion_nll(
                mu_right[right_mask], targets[right_mask], sigma_right[right_mask] ** 2
            )
        
        loss_d = criterion_d(d_pred.squeeze(-1), d_targets)
        loss_feas = feas_loss(mu)
        loss = loss_cls + loss_pos + LAMBDA_D * loss_d + LAMBDA_FEAS * loss_feas
        
        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_pos_loss += loss_pos.item()
        total_d_loss += loss_d.item()
        total_feas_loss += loss_feas.item()
        
        preds_cls = (p_right.squeeze(-1) >= 0.5).float()
        total_cls_correct += (preds_cls == arm_targets).sum().item()
        total_samples += len(arm_targets)
        
        n_batches += 1
        all_mu.append(mu.cpu().numpy())
        all_sigma.append(sigma.cpu().numpy())
        all_p_right.append(p_right.cpu().numpy())
        all_d_pred.append(d_pred.cpu().numpy())
        all_targets.append(targets.cpu().numpy())
        all_d_targets.append(d_targets.cpu().numpy())
        all_arm_targets.append(arm_targets.cpu().numpy())
    
    cls_acc = total_cls_correct / total_samples
    return (total_loss / n_batches, total_cls_loss / n_batches, total_pos_loss / n_batches,
            total_d_loss / n_batches, total_feas_loss / n_batches, cls_acc,
            np.concatenate(all_mu), np.concatenate(all_sigma),
            np.concatenate(all_p_right), np.concatenate(all_d_pred),
            np.concatenate(all_targets), np.concatenate(all_d_targets),
            np.concatenate(all_arm_targets))
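
The feasibility term in `train_epoch` is computed on the mixture of the two heads, which raises the question of where its gradient should go. The sketch below shows the effect of detaching the mixing weight; `toy_penalty` is a hypothetical stand-in for `FeasibilityLoss`, whose definition is not repeated here.

```python
import torch

torch.manual_seed(0)
cls_logit = torch.randn(4, 1, requires_grad=True)  # one logit per sample
mu_left = torch.rand(4, 2, requires_grad=True)
mu_right = torch.rand(4, 2, requires_grad=True)

def mix(p_right):
    # (B, 1) weights broadcast over the (B, 2) means
    return p_right * mu_right + (1.0 - p_right) * mu_left

def toy_penalty(mu):
    # Hypothetical stand-in for FeasibilityLoss: any differentiable function of mu
    return (mu ** 2).mean()

# Detached weight: the penalty cannot reach the classifier
toy_penalty(mix(torch.sigmoid(cls_logit).detach())).backward()
grad_detached = cls_logit.grad            # stays None
mu_grad_seen = mu_left.grad is not None   # the heads still receive a gradient

# Attached weight: the penalty also moves the classification logit
toy_penalty(mix(torch.sigmoid(cls_logit))).backward()
grad_attached = cls_logit.grad

print(grad_detached is None, mu_grad_seen, grad_attached is not None)  # True True True
```

Computing the sigmoid inside `torch.no_grad()` has the same effect as `.detach()`, provided the resulting tensor is the one used in the mixture.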

In [None]:
# Training loop with KFold
fold_results = []
all_train_losses = {}
all_val_losses = {}

for fold, (train_idx, val_idx) in enumerate(kf.split(df_train_full)):
    print(f'\n{"="*60}')
    print(f'FOLD {fold+1}/{N_FOLDS}')
    print(f'{"="*60}')
    
    # Datasets for this fold
    df_fold_train = df_train_full.iloc[train_idx].reset_index(drop=True)
    df_fold_val = df_train_full.iloc[val_idx].reset_index(drop=True)
    d_fold_train = d_train_full[train_idx]
    d_fold_val = d_train_full[val_idx]
    arm_fold_train = arm_train_full[train_idx]
    arm_fold_val = arm_train_full[val_idx]
    
    fold_train_dataset = SpikeSequenceDataset(df_fold_train, nGroups, nChannelsPerGroup, d_fold_train, arm_fold_train)
    fold_val_dataset = SpikeSequenceDataset(df_fold_val, nGroups, nChannelsPerGroup, d_fold_val, arm_fold_val)
    
    fold_train_loader = DataLoader(fold_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                   collate_fn=collate_fn, num_workers=0)
    fold_val_loader = DataLoader(fold_val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                 collate_fn=collate_fn, num_workers=0)
    
    print(f'  Train: {len(fold_train_dataset)}, Val: {len(fold_val_dataset)}')
    
    # Fresh model for this fold
    model = SpikeTransformerHierarchical(
        nGroups, nChannelsPerGroup,
        embed_dim=EMBED_DIM, nhead=NHEAD, num_layers=NUM_LAYERS,
        dropout=DROPOUT, spike_dropout=SPIKE_DROPOUT, noise_std=NOISE_STD
    ).to(DEVICE)
    
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, epochs=EPOCHS, steps_per_epoch=len(fold_train_loader)
    )
    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_nll = nn.GaussianNLLLoss()
    criterion_d = nn.MSELoss()
    feas_loss_fn = FeasibilityLoss(SKELETON_SEGMENTS, CORRIDOR_HALF_WIDTH).to(DEVICE)
    
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    model_path = f'../outputs/best_transformer_02h_fold{fold+1}.pt'
    
    for epoch in range(EPOCHS):
        t_loss, t_cls, t_pos, t_d, t_feas, t_acc = train_epoch(
            model, fold_train_loader, optimizer, scheduler,
            criterion_bce, criterion_nll, criterion_d, feas_loss_fn, DEVICE
        )
        v_loss, v_cls, v_pos, v_d, v_feas, v_acc, _, _, _, _, _, _, _ = eval_epoch(
            model, fold_val_loader, criterion_bce, criterion_nll, criterion_d, feas_loss_fn, DEVICE
        )
        
        train_losses.append(t_loss)
        val_losses.append(v_loss)
        
        if epoch % 5 == 0 or epoch == EPOCHS - 1:
            print(f'  Epoch {epoch+1:02d}/{EPOCHS} | Train: {t_loss:.4f} (cls={t_cls:.4f}, pos={t_pos:.4f}, d={t_d:.5f}, feas={t_feas:.6f}, acc={t_acc:.1%}) | Val: {v_loss:.4f} (acc={v_acc:.1%})')
        
        if v_loss < best_val_loss:
            best_val_loss = v_loss
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f'  Early stopping at epoch {epoch+1}')
                break
    
    all_train_losses[fold] = train_losses
    all_val_losses[fold] = val_losses
    
    # Reload the best checkpoint and evaluate on this fold's validation split
    model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
    (_, _, _, _, _, val_acc, val_mu, val_sigma, val_p_right, val_d_pred, 
     val_targets, val_d_targets, val_arm_targets) = eval_epoch(
        model, fold_val_loader, criterion_bce, criterion_nll, criterion_d, feas_loss_fn, DEVICE
    )
    val_eucl = np.sqrt((val_targets[:, 0] - val_mu[:, 0])**2 + (val_targets[:, 1] - val_mu[:, 1])**2)
    val_d_mae = np.abs(val_d_targets - val_d_pred.squeeze()).mean()
    
    # Fraction of predictions outside the maze
    val_dist_to_skel = np.array([
        compute_distance_to_skeleton(val_mu[i, 0], val_mu[i, 1]) for i in range(len(val_mu))
    ])
    pct_outside = (val_dist_to_skel > CORRIDOR_HALF_WIDTH).mean()
    
    fold_results.append({
        'fold': fold + 1,
        'best_val_loss': best_val_loss,
        'val_eucl_mean': val_eucl.mean(),
        'val_r2_x': r2_score(val_targets[:, 0], val_mu[:, 0]),
        'val_r2_y': r2_score(val_targets[:, 1], val_mu[:, 1]),
        'val_d_mae': val_d_mae,
        'val_cls_acc': val_acc,
        'val_pct_outside': pct_outside,
        'epochs': len(train_losses),
    })
    print(f'  Best val loss: {best_val_loss:.5f} | Eucl: {val_eucl.mean():.4f} | R2: X={fold_results[-1]["val_r2_x"]:.4f}, Y={fold_results[-1]["val_r2_y"]:.4f}')
    print(f'  d MAE: {val_d_mae:.4f} | Cls acc: {val_acc:.1%} | Outside maze: {pct_outside:.1%}')

# Summary
print(f'\n{"="*60}')
print(f'CROSS-VALIDATION SUMMARY ({N_FOLDS} folds)')
print(f'{"="*60}')
for r in fold_results:
    print(f'  Fold {r["fold"]}: Loss={r["best_val_loss"]:.5f} | Eucl={r["val_eucl_mean"]:.4f} | R2_X={r["val_r2_x"]:.4f} | R2_Y={r["val_r2_y"]:.4f} | d_MAE={r["val_d_mae"]:.4f} | cls_acc={r["val_cls_acc"]:.1%} | outside={r["val_pct_outside"]:.1%}')

mean_eucl = np.mean([r['val_eucl_mean'] for r in fold_results])
std_eucl = np.std([r['val_eucl_mean'] for r in fold_results])
mean_r2_x = np.mean([r['val_r2_x'] for r in fold_results])
mean_r2_y = np.mean([r['val_r2_y'] for r in fold_results])
mean_d_mae = np.mean([r['val_d_mae'] for r in fold_results])
mean_cls_acc = np.mean([r['val_cls_acc'] for r in fold_results])
mean_outside = np.mean([r['val_pct_outside'] for r in fold_results])
print(f'\n  Mean: Eucl={mean_eucl:.4f} (+/- {std_eucl:.4f}) | R2_X={mean_r2_x:.4f} | R2_Y={mean_r2_y:.4f} | d_MAE={mean_d_mae:.4f} | cls_acc={mean_cls_acc:.1%} | outside={mean_outside:.1%}')

In [None]:
# Training curves per fold
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

colors = plt.cm.tab10(np.linspace(0, 1, N_FOLDS))

for fold in range(N_FOLDS):
    axes[0].plot(all_train_losses[fold], color=colors[fold], linewidth=1.5, label=f'Fold {fold+1}')
    axes[1].plot(all_val_losses[fold], color=colors[fold], linewidth=1.5, label=f'Fold {fold+1}')

axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Total loss')
axes[0].set_title('Train loss per fold'); axes[0].legend(); axes[0].grid(True, alpha=0.3)

axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Total loss')
axes[1].set_title('Validation loss per fold'); axes[1].legend(); axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 10. Final evaluation on the test set

In [None]:
# Ensemble evaluation of the N_FOLDS fold models on the test set
criterion_bce = nn.BCEWithLogitsLoss()
criterion_nll = nn.GaussianNLLLoss()
criterion_d = nn.MSELoss()
feas_loss_fn = FeasibilityLoss(SKELETON_SEGMENTS, CORRIDOR_HALF_WIDTH).to(DEVICE)

all_fold_mu = []
all_fold_sigma = []
all_fold_p_right = []
all_fold_d = []

for fold in range(N_FOLDS):
    model_path = f'../outputs/best_transformer_02h_fold{fold+1}.pt'
    model = SpikeTransformerHierarchical(
        nGroups, nChannelsPerGroup,
        embed_dim=EMBED_DIM, nhead=NHEAD, num_layers=NUM_LAYERS,
        dropout=DROPOUT, spike_dropout=SPIKE_DROPOUT, noise_std=NOISE_STD
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
    
    (_, _, _, _, _, fold_acc, fold_mu, fold_sigma, fold_p_right, fold_d,
     y_test, d_test_targets, arm_test_targets) = eval_epoch(
        model, test_loader, criterion_bce, criterion_nll, criterion_d, feas_loss_fn, DEVICE
    )
    all_fold_mu.append(fold_mu)
    all_fold_sigma.append(fold_sigma)
    all_fold_p_right.append(fold_p_right)
    all_fold_d.append(fold_d)
    
    fold_eucl = np.sqrt((y_test[:, 0] - fold_mu[:, 0])**2 + (y_test[:, 1] - fold_mu[:, 1])**2)
    print(f'Fold {fold+1} sur test: Eucl={fold_eucl.mean():.4f}, cls_acc={fold_acc:.1%}')

# Ensemble
all_fold_mu = np.stack(all_fold_mu)
all_fold_sigma = np.stack(all_fold_sigma)
all_fold_p_right = np.stack(all_fold_p_right)
all_fold_d = np.stack(all_fold_d)

y_pred = all_fold_mu.mean(axis=0)
d_pred_ensemble = all_fold_d.mean(axis=0).squeeze()
p_right_ensemble = all_fold_p_right.mean(axis=0).squeeze()

# Ensemble sigma (law of total variance): mean predicted variance + variance of the fold means
mean_var = (all_fold_sigma ** 2).mean(axis=0)
var_mu = all_fold_mu.var(axis=0)
y_sigma = np.sqrt(mean_var + var_mu)

# --- Position metrics ---
mse_x = mean_squared_error(y_test[:, 0], y_pred[:, 0])
mse_y = mean_squared_error(y_test[:, 1], y_pred[:, 1])
mae_x = mean_absolute_error(y_test[:, 0], y_pred[:, 0])
mae_y = mean_absolute_error(y_test[:, 1], y_pred[:, 1])
r2_x = r2_score(y_test[:, 0], y_pred[:, 0])
r2_y = r2_score(y_test[:, 1], y_pred[:, 1])
eucl_errors = np.sqrt((y_test[:, 0] - y_pred[:, 0])**2 + (y_test[:, 1] - y_pred[:, 1])**2)

# --- Metrics on d ---
d_mae = np.abs(d_test_targets - d_pred_ensemble).mean()
d_r2 = r2_score(d_test_targets, d_pred_ensemble)

# --- Classification metrics ---
arm_pred = (p_right_ensemble >= 0.5).astype(float)
cls_accuracy = (arm_pred == arm_test_targets).mean()

# --- Fraction of predictions outside the maze ---
test_dist_to_skel = np.array([
    compute_distance_to_skeleton(y_pred[i, 0], y_pred[i, 1]) for i in range(len(y_pred))
])
pct_outside = (test_dist_to_skel > CORRIDOR_HALF_WIDTH).mean()

# --- Arm confusion rate ---
# Confusion = predicted arm differs from the true arm (arm_pred != arm_true)
arm_confusion_mask = arm_pred != arm_test_targets
arm_confusion_rate = arm_confusion_mask.mean()
eucl_confusion = eucl_errors[arm_confusion_mask].mean() if arm_confusion_mask.any() else 0.0
eucl_correct = eucl_errors[~arm_confusion_mask].mean()

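# --- Error per zone ---
# Assuming d is the arc length normalized by the skeleton length (0.85 + 0.70 + 0.85 = 2.4),
# the two corners of the U sit at d = 0.85/2.4 ≈ 0.354 and d = 1.55/2.4 ≈ 0.646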
left_zone = d_test_targets < 0.354
corridor_zone = (d_test_targets >= 0.354) & (d_test_targets <= 0.646)
right_zone = d_test_targets > 0.646

print(f'\n{"="*60}')
print(f'Transformer 02h: hierarchical, ensemble of {N_FOLDS} folds')
print(f'{"="*60}')
print(f'  MSE  : X={mse_x:.5f}, Y={mse_y:.5f}')
print(f'  MAE  : X={mae_x:.4f}, Y={mae_y:.4f}')
print(f'  R²   : X={r2_x:.4f}, Y={r2_y:.4f}')
print(f'  Eucl : mean={eucl_errors.mean():.4f}, median={np.median(eucl_errors):.4f}, p90={np.percentile(eucl_errors, 90):.4f}')
print(f'\n  Curvilinear d : MAE={d_mae:.4f}, R²={d_r2:.4f}')
print(f'  Arm classification : accuracy={cls_accuracy:.1%}')
print(f'  Outside maze : {pct_outside:.1%}')
print(f'\n  Arm confusion : {arm_confusion_rate:.1%}')
print(f'    Eucl (correct arm)    : {eucl_correct:.4f}')
print(f'    Eucl (confused arm)   : {eucl_confusion:.4f}')
print(f'\n  Error per zone:')
print(f'    Left arm (d<0.354)    : Eucl={eucl_errors[left_zone].mean():.4f} ({left_zone.sum()} points)')
print(f'    Top corridor          : Eucl={eucl_errors[corridor_zone].mean():.4f} ({corridor_zone.sum()} points)')
print(f'    Right arm (d>0.646)   : Eucl={eucl_errors[right_zone].mean():.4f} ({right_zone.sum()} points)')
print(f'\n  Mean sigma : X={y_sigma[:, 0].mean():.4f}, Y={y_sigma[:, 1].mean():.4f}')
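
The ensemble sigma used in the cell above treats the fold models as an equal-weight mixture of Gaussians. The sketch below checks that rule on synthetic numbers (all values are made up for illustration): the analytic variance, mean of the predicted variances plus variance of the fold means, should match the spread of samples drawn from the mixture.

```python
import numpy as np

rng = np.random.default_rng(0)
n_folds, n_samples = 5, 1000

# Synthetic per-fold predictions, shape (folds, samples, 2)
fold_mu = rng.normal(0.5, 0.05, size=(n_folds, n_samples, 2))
fold_sigma = rng.uniform(0.02, 0.08, size=(n_folds, n_samples, 2))

mean_var = (fold_sigma ** 2).mean(axis=0)  # average predicted variance
var_mu = fold_mu.var(axis=0)               # disagreement between folds (ddof=0)
ens_sigma = np.sqrt(mean_var + var_mu)

# Monte Carlo cross-check on one coordinate of one test point:
# pick a fold uniformly at random, then sample from that fold's Gaussian
k = rng.integers(0, n_folds, size=200_000)
samples = rng.normal(fold_mu[k, 0, 0], fold_sigma[k, 0, 0])

print(f'analytic={ens_sigma[0, 0]:.4f}, sampled={samples.std():.4f}')
```

The second term is what makes the ensemble sigma grow when the folds disagree, even if each fold is individually confident.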

## 11. Visualizations

In [None]:
# --- Scatter pred vs true ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].scatter(y_test[:, 0], y_pred[:, 0], s=1, alpha=0.3)
axes[0].plot([0, 1], [0, 1], 'r--', linewidth=2)
axes[0].set_xlabel('True X'); axes[0].set_ylabel('Predicted X')
axes[0].set_title(f'02h - Position X (R²={r2_x:.3f})')
axes[0].set_aspect('equal')

axes[1].scatter(y_test[:, 1], y_pred[:, 1], s=1, alpha=0.3)
axes[1].plot([0, 1], [0, 1], 'r--', linewidth=2)
axes[1].set_xlabel('True Y'); axes[1].set_ylabel('Predicted Y')
axes[1].set_title(f'02h - Position Y (R²={r2_y:.3f})')
axes[1].set_aspect('equal')

axes[2].scatter(d_test_targets, d_pred_ensemble, s=1, alpha=0.3)
axes[2].plot([0, 1], [0, 1], 'r--', linewidth=2)
axes[2].set_xlabel('True d'); axes[2].set_ylabel('Predicted d')
axes[2].set_title(f'02h - Distance curviligne (R²={d_r2:.3f})')
axes[2].set_aspect('equal')

plt.tight_layout()
plt.show()

In [None]:
# --- Trajectory + classification + confusion ---
segment = slice(0, 500)
seg_idx = np.arange(500)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. 2D trajectory with skeleton
axes[0, 0].plot(y_test[segment, 0], y_test[segment, 1], 'b-', alpha=0.5, label='True trajectory', linewidth=1)
axes[0, 0].plot(y_pred[segment, 0], y_pred[segment, 1], 'r-', alpha=0.5, label='Prediction (mu)', linewidth=1)
for x1, y1, x2, y2 in SKELETON_SEGMENTS:
    axes[0, 0].plot([x1, x2], [y1, y2], 'k--', linewidth=1, alpha=0.3)
axes[0, 0].set_xlabel('X'); axes[0, 0].set_ylabel('Y')
axes[0, 0].set_title('Trajectory (first 500 test points)')
axes[0, 0].legend()
axes[0, 0].set_aspect('equal')

# 2. P(right) over time
axes[0, 1].plot(seg_idx, arm_test_targets[segment], 'b-', label='True arm (0=L, 1=R)', linewidth=1.5, alpha=0.5)
axes[0, 1].plot(seg_idx, p_right_ensemble[segment], 'r-', alpha=0.7, label='P(right)', linewidth=1)
axes[0, 1].axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
axes[0, 1].set_xlabel('Index'); axes[0, 1].set_ylabel('P(right)')
axes[0, 1].set_title('Arm classification: P(right) vs ground truth')
axes[0, 1].legend()

# 3. Curvilinear distance
axes[1, 0].plot(seg_idx, d_test_targets[segment], 'b-', label='True d', linewidth=1.5)
axes[1, 0].plot(seg_idx, d_pred_ensemble[segment], 'r-', alpha=0.7, label='Predicted d', linewidth=1)
axes[1, 0].axhline(y=ARM_THRESHOLD, color='gray', linestyle='--', alpha=0.5, label=f'Threshold d={ARM_THRESHOLD}')
axes[1, 0].set_xlabel('Index'); axes[1, 0].set_ylabel('d (curvilinear)')
axes[1, 0].set_title('Curvilinear distance along the U')
axes[1, 0].legend()

# 4. Arm confusion map
correct_mask = ~arm_confusion_mask
axes[1, 1].scatter(y_test[correct_mask, 0], y_test[correct_mask, 1], 
                    c='green', s=1, alpha=0.2, label=f'Correct arm ({correct_mask.mean():.1%})')
if arm_confusion_mask.any():
    axes[1, 1].scatter(y_test[arm_confusion_mask, 0], y_test[arm_confusion_mask, 1], 
                        c='red', s=5, alpha=0.8, label=f'Confused arm ({arm_confusion_rate:.1%})')
for x1, y1, x2, y2 in SKELETON_SEGMENTS:
    axes[1, 1].plot([x1, x2], [y1, y2], 'k--', linewidth=1, alpha=0.3)
axes[1, 1].set_xlabel('X'); axes[1, 1].set_ylabel('Y')
axes[1, 1].set_title('Arm confusion (true positions)')
axes[1, 1].legend(markerscale=5)
axes[1, 1].set_aspect('equal')

plt.tight_layout()
plt.show()

In [None]:
# --- Heatmaps: Euclidean error + predicted sigma + P(right), binned by true position ---
fig, axes = plt.subplots(1, 3, figsize=(21, 7))

nbins = 20
x_edges = np.linspace(0, 1, nbins + 1)
y_edges = np.linspace(0, 1, nbins + 1)

for ax_idx, (title, values, cmap, pos_for_binning) in enumerate([
    ('Mean Euclidean error', eucl_errors, 'RdYlGn_r', y_test),
    ('Mean predicted sigma', (y_sigma[:, 0] + y_sigma[:, 1]) / 2, 'RdYlGn_r', y_test),
    ('Mean P(right)', p_right_ensemble, 'RdBu_r', y_test)
]):
    val_map = np.full((nbins, nbins), np.nan)
    count_map = np.zeros((nbins, nbins))
    
    for i in range(len(pos_for_binning)):
        xi = np.clip(np.searchsorted(x_edges, pos_for_binning[i, 0]) - 1, 0, nbins - 1)
        yi = np.clip(np.searchsorted(y_edges, pos_for_binning[i, 1]) - 1, 0, nbins - 1)
        if np.isnan(val_map[yi, xi]):
            val_map[yi, xi] = 0
        val_map[yi, xi] += values[i]
        count_map[yi, xi] += 1
    
    mean_map = np.where(count_map > 0, val_map / count_map, np.nan)
    
    im = axes[ax_idx].imshow(mean_map, origin='lower', aspect='equal', 
                              cmap=cmap, extent=[0, 1, 0, 1])
    axes[ax_idx].set_xlabel('X'); axes[ax_idx].set_ylabel('Y')
    axes[ax_idx].set_title(title)
    plt.colorbar(im, ax=axes[ax_idx])

plt.tight_layout()
plt.show()

In [None]:
# --- Calibration ---
sigma_mean = (y_sigma[:, 0] + y_sigma[:, 1]) / 2

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(sigma_mean, eucl_errors, s=1, alpha=0.3)
ax.set_xlabel('Mean predicted sigma'); ax.set_ylabel('Actual Euclidean error')
ax.set_title('Calibration: uncertainty vs error')
sigma_range = np.linspace(0, sigma_mean.max(), 100)
ax.plot(sigma_range, 2 * sigma_range, 'r--', label='y = 2*sigma', linewidth=1.5)
ax.legend()
plt.tight_layout()
plt.show()

# Empirical coverage. The expected values assume an isotropic 2D Gaussian error: 1 - exp(-k²/2)
in_1sigma = np.mean(eucl_errors < sigma_mean)
in_2sigma = np.mean(eucl_errors < 2 * sigma_mean)
in_3sigma = np.mean(eucl_errors < 3 * sigma_mean)
print('Uncertainty calibration:')
print(f'  Error < 1*sigma : {in_1sigma:.1%} (expected ~39%)')
print(f'  Error < 2*sigma : {in_2sigma:.1%} (expected ~86%)')
print(f'  Error < 3*sigma : {in_3sigma:.1%} (expected ~99%)')
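
The expected coverages printed above (~39%, ~86%, ~99%) are not the familiar 1D values of 68/95/99.7%. They follow from the Rayleigh distribution of the norm of a 2D isotropic Gaussian error. The sketch below derives them and cross-checks with a simulation; the value `sigma = 0.05` is arbitrary.

```python
import numpy as np

def expected_coverage(k):
    """P(||e|| < k*sigma) when e is a 2D isotropic Gaussian error (Rayleigh CDF)."""
    return 1.0 - np.exp(-k ** 2 / 2.0)

for k in (1, 2, 3):
    print(f'k={k}: {expected_coverage(k):.1%}')  # 39.3%, 86.5%, 98.9%

# Monte Carlo cross-check
rng = np.random.default_rng(0)
sigma = 0.05
err = rng.normal(0.0, sigma, size=(500_000, 2))
r = np.linalg.norm(err, axis=1)
empirical = [(r < k * sigma).mean() for k in (1, 2, 3)]
print([f'{c:.1%}' for c in empirical])
```

Note that the notebook averages the X and Y sigmas into a single `sigma_mean`, so these references are only exact when both coordinates have the same uncertainty.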

## 12. Saving the predictions

In [None]:
np.save('../outputs/preds_transformer_02h.npy', y_pred)
np.save('../outputs/sigma_transformer_02h.npy', y_sigma)
np.save('../outputs/d_pred_transformer_02h.npy', d_pred_ensemble)
np.save('../outputs/y_test_transformer_02h.npy', y_test)
np.save('../outputs/d_test_transformer_02h.npy', d_test_targets)
print(f'Ensemble predictions ({N_FOLDS} folds) saved.')
print(f'  preds_transformer_02h.npy : ensemble mu ({y_pred.shape})')
print(f'  sigma_transformer_02h.npy : ensemble sigma ({y_sigma.shape})')
print(f'  d_pred_transformer_02h.npy : ensemble d ({d_pred_ensemble.shape})')
print(f'  y_test_transformer_02h.npy : targets ({y_test.shape})')
print(f'  d_test_transformer_02h.npy : d targets ({d_test_targets.shape})')

## 13. Interpretation

### Hierarchical approach vs direct approach (02g)

Notebook 02g predicts (x, y) and d simultaneously with a single regression head. The curvilinear distance d indirectly pushes the backbone to distinguish the arms, but there is still only one position head, and it has to handle both arms.

Notebook 02h splits the problem explicitly:

1. **Classification**: the `cls_head` head learns directly to tell left from right. This is a simple binary task that the model is expected to solve with very high accuracy (to be checked against `cls_accuracy` in section 10).

2. **Conditional regression**: each position head only sees the examples from its own arm during training, so it can specialize:
   - The left head mostly learns x ≈ 0.15 with y ranging from 0 to 0.85, plus the left half of the top corridor (d < 0.5)
   - The right head mostly learns x ≈ 0.85 with y ranging from 0 to 0.85, plus the right half of the top corridor (d ≥ 0.5)

3. **Mixing at inference**: the final prediction is `mu = p * mu_right + (1-p) * mu_left`. When the classifier is confident (p ≈ 0 or p ≈ 1), the result is almost exactly the prediction of the correct head. In the transition zone (top corridor), the mixture interpolates between the two heads.
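
The mixing rule can be checked on a few hand-picked values. This is a sketch in NumPy with made-up head outputs; `mix_heads` is a hypothetical helper, not the notebook's `model.predict`.

```python
import numpy as np

def mix_heads(p_right, mu_left, mu_right):
    """mu = p * mu_right + (1 - p) * mu_left, with p broadcast over (x, y)."""
    p = np.asarray(p_right, dtype=float)[:, None]
    return p * mu_right + (1.0 - p) * mu_left

# Rows 0-1: both heads answer from the middle of their own arm.
# Row 2: both heads answer near the centre of the top corridor.
mu_left = np.array([[0.15, 0.40], [0.15, 0.40], [0.45, 0.85]])
mu_right = np.array([[0.85, 0.40], [0.85, 0.40], [0.55, 0.85]])
p_right = np.array([0.02, 0.98, 0.50])

mu = mix_heads(p_right, mu_left, mu_right)
print(mu)
# Row 0: x ≈ 0.164, close to the left head
# Row 1: x ≈ 0.836, close to the right head
# Row 2: x = 0.50, halfway between two nearby corridor predictions
```

The interpolation is only harmless when the two heads already agree, as in the corridor row. With an uncertain classifier and heads far apart, the mixture can land between the arms, which is what the feasibility penalty is meant to discourage.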

### Expected benefits

- **Fewer arm-confusion errors**: explicit classification makes "jumps" between arms less likely
- **Head specialization**: each head can learn a simpler geometry (essentially 1D along its arm)
- **Better-calibrated uncertainty**: in the transition zone, the uncertainty should grow naturally through the variance of the mixture
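
To make the last point concrete, here is the standard variance of a two-component Gaussian mixture, applied per coordinate. The body of `model.predict` is not shown in this section, so this is an assumption about how the combined sigma could be computed, not a description of the notebook's implementation; `mixture_sigma` is a hypothetical helper.

```python
import numpy as np

def mixture_sigma(p_right, mu_left, sigma_left, mu_right, sigma_right):
    """Per-coordinate std of the mixture p*N(mu_R, s_R^2) + (1-p)*N(mu_L, s_L^2)."""
    p = np.asarray(p_right, dtype=float)[:, None]
    within = p * sigma_right ** 2 + (1.0 - p) * sigma_left ** 2   # spread of each head
    between = p * (1.0 - p) * (mu_right - mu_left) ** 2           # disagreement of the heads
    return np.sqrt(within + between)

# Same corridor point, three levels of classifier confidence
mu_left = np.array([[0.45, 0.85]] * 3)
mu_right = np.array([[0.55, 0.85]] * 3)
sigma = np.full((3, 2), 0.03)
p = np.array([0.0, 0.5, 1.0])

s = mixture_sigma(p, mu_left, sigma, mu_right, sigma)
print(np.round(s, 4))
# x-sigma is 0.03 when p is 0 or 1 and about 0.0583 at p = 0.5; y-sigma stays at 0.03
```

The `between` term vanishes when the classifier is certain or when both heads agree, so extra uncertainty appears only where the arm is ambiguous and the heads disagree.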

### Caveats

- **More parameters**: 2 regression heads instead of 1, plus the classification head. The risk of overfitting increases slightly.
- **Transition zone**: the top corridor is the hardest region because both heads have to contribute there through the mixture
- **Error propagation**: if the classifier is wrong, the position error is amplified because the wrong head dominates the mixture
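
A small worked example gives the order of magnitude of that amplification. It assumes an idealized case, with made-up numbers: the animal is in the middle of the left arm, the left head is exact, and the right head answers from the centreline of its own arm.

```python
import numpy as np

true_pos = np.array([0.15, 0.40])
mu_left = np.array([0.15, 0.40])    # assumed perfect left head
mu_right = np.array([0.85, 0.40])   # right head stays on its own arm

errors = {}
for p_right in (0.1, 0.5, 0.9):
    mu = p_right * mu_right + (1.0 - p_right) * mu_left
    errors[p_right] = np.linalg.norm(mu - true_pos)
    print(f'P(right)={p_right:.1f} -> error={errors[p_right]:.2f}')
# P(right)=0.1 -> error=0.07
# P(right)=0.5 -> error=0.35
# P(right)=0.9 -> error=0.63
```

In this setting the error is simply `p_right * 0.7`, the classifier's mistake scaled by the distance between the two arm centrelines, so a confidently wrong classification costs most of the width of the maze.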