# Pipeline v3 — Conv1D + MDN + Transformer / GRU / Mamba

**Nouveautés :**
1. **Conv1D** sur les waveforms (capture la morphologie du spike)
2. **MDN** (Mixture Density Network, mélange de gaussiennes) pour des prédictions multimodales
3. **3 architectures** comparées : Transformer, GRU, Mamba
4. Utilise le nouveau `dataset.py` (split temporel, collate propre)


In [1]:
%matplotlib inline

import os, time, math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

# Import du nouveau dataset.py
from dataset import create_data_loaders, load_params


## 0. Configuration & chargement des données

In [2]:
# Data configuration: a plain namespace whose attributes are read by
# create_data_loaders.
data_cfg = SimpleNamespace()
data_cfg.dataset_dir = os.path.abspath("dataset")
data_cfg.mouse = "M1199_PAG"
data_cfg.stride = 4
data_cfg.window_size = 108
data_cfg.val_fraction = 0.1      # 10% validation (temporal split)
data_cfg.use_speed_mask = True   # filter on the moments when the mouse is moving
data_cfg.batch_size = 64
data_cfg.num_workers = 0

# Load the parquet file, split it temporally and build the DataLoaders.
train_loader, val_loader, params = create_data_loaders(data_cfg)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"\nDevice: {device}")


Mouse: M1199_PAG | nGroups=4 | channels=[6, 4, 6, 4]
Loading /Users/ippo/Downloads/hackathon/theta-gang-master/dataset/M1199_PAG_stride4_win108_test.parquet...
Loaded in 34.3s -- 62257 rows
SpeedMask: 22974/62257 moving samples (36.9%)
Temporal split: train=20676, val=2298
Train: 20676 samples, 324 batches
Val:   2298 samples, 36 batches

Device: cpu


In [3]:
# Sanity check: pull one training batch and list the shape and dtype of
# every entry it contains.
batch = next(iter(train_loader))
print("Shapes du premier batch :")
for key, tensor in batch.items():
    print(f"  {key:15s} → {tensor.shape}  dtype={tensor.dtype}")


Shapes du premier batch :
  groups          → torch.Size([64, 104])  dtype=torch.int64
  mask            → torch.Size([64, 104])  dtype=torch.float32
  spike_times     → torch.Size([64, 104])  dtype=torch.float32
  indices0        → torch.Size([64, 104])  dtype=torch.int64
  indices1        → torch.Size([64, 104])  dtype=torch.int64
  indices2        → torch.Size([64, 104])  dtype=torch.int64
  indices3        → torch.Size([64, 104])  dtype=torch.int64
  group0          → torch.Size([64, 29, 6, 32])  dtype=torch.float32
  group1          → torch.Size([64, 26, 4, 32])  dtype=torch.float32
  group2          → torch.Size([64, 34, 6, 32])  dtype=torch.float32
  group3          → torch.Size([64, 26, 4, 32])  dtype=torch.float32
  pos             → torch.Size([64, 2])  dtype=torch.float32


---
## 1. Modules partagés

Les 3 architectures partagent :
- **Conv1DWaveformEmbedder** : Conv1D sur les 32 bins temporels (au lieu de Linear)
- **ContinuousPositionalEncoding** : encodage temporel basé sur indexInDat
- **MDNHead** : Mixture Density Network (K gaussiennes)
- **SpikeEncoderBase** : pipeline d'encodage commun (embed → gather → concat → proj)


In [4]:
class Conv1DWaveformEmbedder(nn.Module):
    """
    Embed spike waveforms with 1D convolutions instead of a Linear layer.

    Every electrode channel is fed as one input channel of a 1D signal.
    Two Conv1d + ReLU stages run along the time axis and the result is
    averaged over time, giving one feature vector per spike. The
    convolutions are meant to pick up the spike morphology:
    - slope of the peak (depolarisation)
    - width (neuron type)
    - asymmetry (return to baseline)
    """

    def __init__(self, n_channels, n_features):
        super().__init__()
        # kernel_size=5 with padding=2 leaves the number of time bins unchanged.
        stages = [
            nn.Conv1d(n_channels, n_features, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(n_features, n_features, kernel_size=5, padding=2),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stages)
        # Global average over the time axis.
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, wf):
        """wf: (B, n_spikes, n_ch, T) -> (B, n_spikes, n_features)."""
        batch_size, n_spikes, _, _ = wf.shape
        # Conv1d works on (N, C, T): fold batch and spike axes together.
        per_spike = wf.flatten(0, 1)
        feature_maps = self.conv(per_spike)
        pooled = self.pool(feature_maps).squeeze(-1)
        # Unfold back to one feature vector per (batch item, spike).
        return pooled.reshape(batch_size, n_spikes, -1)


class ContinuousPositionalEncoding(nn.Module):
    """Sinusoidal encoding of the real (continuous) spike timestamp.

    The original note states that timestamps are continuous between 0 and 1;
    negative values are clamped to 0 before encoding.

    Args:
        d_model: size of the returned encoding (last axis).
        max_freq_scale: ratio that sets how fast the frequencies decay; the
            frequencies are exp(-log(max_freq_scale) * i / d_model) for
            i = 0, 2, 4, ...

    The output is always exactly ``d_model`` wide. For an odd ``d_model`` the
    sin and cos halves together hold ``d_model + 1`` values, so the last
    cosine column is dropped.
    """

    def __init__(self, d_model, max_freq_scale=1000.0):
        super().__init__()
        self.d_model = d_model
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(max_freq_scale) / d_model)
        )
        # Buffer (not a parameter): follows .to(device) and is saved in the
        # state_dict, but is never trained.
        self.register_buffer("freqs", freqs)

    def forward(self, spike_times):
        """spike_times: (B, L) -> (B, L, d_model)."""
        t = spike_times.clamp(min=0).unsqueeze(-1)
        angles = t * self.freqs.unsqueeze(0).unsqueeze(0)
        encoding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # No-op for even d_model; trims the extra column when d_model is odd.
        return encoding[..., : self.d_model]


class MDNHead(nn.Module):
    """
    Mixture Density Network head: predicts K 2D Gaussians.

    Outputs per component: pi_k (weight), mu_k (2D centre), sigma_k (2D
    standard deviation). This lets the model say, for example,
    "70% at (3, 5) and 30% at (7, 2)".
    """

    def __init__(self, input_dim, n_components=3, dropout=0.1):
        super().__init__()
        self.n_components = n_components
        # 5 raw values per component: 1 mixture logit, 2 means, 2 log-sigmas.
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(input_dim, 5 * n_components),
        )

    def forward(self, context):
        """context: (B, input_dim) -> pi (B, K), mu (B, K, 2), sigma (B, K, 2)."""
        raw = self.net(context).reshape(context.shape[0], self.n_components, 5)
        logits, mu, log_sigma = torch.split(raw, [1, 2, 2], dim=-1)
        # Mixture weights sum to 1 across the K components.
        pi = torch.softmax(logits.squeeze(-1), dim=-1)
        # Bound the log-sigma to [-5, 5] so sigma stays in [e^-5, e^5].
        sigma = log_sigma.clamp(min=-5, max=5).exp()
        return pi, mu, sigma


def mdn_loss(pi, mu, sigma, target):
    """Negative log-likelihood of a Gaussian mixture (generalises gaussian_nll_loss).

    Each component is a diagonal Gaussian: the per-coordinate log-densities
    are summed over the last axis, weighted by log(pi), and combined across
    components with logsumexp. The result is averaged over the batch.
    """
    # Standardised residual of the target against every component.
    z = (target.unsqueeze(1).expand_as(mu) - mu) / sigma
    log_two_pi = math.log(2.0 * math.pi)
    per_coord_log_pdf = -0.5 * (log_two_pi + 2 * torch.log(sigma) + z ** 2)
    component_log_pdf = per_coord_log_pdf.sum(dim=-1)
    # The 1e-8 keeps log() finite when a mixture weight underflows to 0.
    weighted_log_pdf = torch.log(pi + 1e-8) + component_log_pdf
    return -torch.logsumexp(weighted_log_pdf, dim=-1).mean()

print("✅ Conv1DWaveformEmbedder, ContinuousPositionalEncoding, MDNHead")


✅ Conv1DWaveformEmbedder, ContinuousPositionalEncoding, MDNHead


In [5]:
class SpikeEncoderBase(nn.Module):
    """
    Shared base: encodes the spikes of a batch into a sequence (B, L, n_features).

    Subclasses plug their own sequence model (Transformer / GRU / Mamba)
    between ``encode_sequence`` and ``pool_and_predict``.

    Args:
        params: object exposing ``nGroups`` and ``nChannelsPerGroup``.
        n_features: width of every embedding and of the encoded sequence.
        dropout: dropout rate of the input projection and of the MDN head.
        n_components: number of Gaussian components of the MDN head.
    """

    def __init__(self, params, n_features=128, dropout=0.1, n_components=3):
        super().__init__()
        self.params = params
        self.n_features = n_features
        self.n_groups = params.nGroups

        # One Conv1D waveform embedder per shank (electrode group).
        self.embedders = nn.ModuleList()
        for g in range(self.n_groups):
            self.embedders.append(Conv1DWaveformEmbedder(params.nChannelsPerGroup[g], n_features))

        # Row 0 is the padding row (kept at zero by padding_idx); real group g
        # is stored in row g + 1.
        self.group_embedding = nn.Embedding(self.n_groups + 1, n_features, padding_idx=0)
        self.temporal_pe = ContinuousPositionalEncoding(n_features)
        # Waveform, group and time features are concatenated (3 * n_features)
        # and projected back down to n_features.
        self.input_proj = nn.Sequential(
            nn.Linear(n_features * 3, n_features), nn.ReLU(), nn.Dropout(dropout),
        )
        self.head = MDNHead(n_features, n_components=n_components, dropout=dropout)

    def encode_sequence(self, batch):
        """Build the per-spike input sequence.

        Reads ``groups``, ``mask``, ``spike_times`` and, for every group g,
        ``group{g}`` (waveforms) and ``indices{g}`` from ``batch``.

        Returns:
            (sequence, mask) where sequence is (B, L, n_features) and mask is
            ``batch["mask"]`` passed through unchanged.
        """
        groups_seq = batch["groups"]
        mask = batch["mask"]
        spike_times = batch["spike_times"]
        B, L = groups_seq.shape

        gathered_list = []
        for g in range(self.n_groups):
            wf = batch[f"group{g}"]
            n_spikes = wf.shape[1]
            emb = self.embedders[g](wf)

            # Row 0 is an all-zero "null" embedding, rows 1..n_spikes are the
            # real spikes. new_zeros matches emb's dtype as well as its device.
            null = emb.new_zeros(B, 1, self.n_features)
            full_emb = torch.cat([null, emb], dim=1)

            # Clamp so every index addresses a valid row of full_emb.
            # NOTE(review): presumably index 0 means "this sequence position is
            # not a spike of group g" — confirm against dataset.py.
            raw_idx = batch[f"indices{g}"]
            safe_idx = raw_idx.clamp(min=0, max=n_spikes)
            idx = safe_idx.unsqueeze(-1).expand(-1, -1, self.n_features)
            gathered = torch.gather(full_emb, dim=1, index=idx)
            gathered_list.append(gathered)

        # (B, L, n_groups, n_features): keep, at each position, only the
        # embedding of the group that position belongs to.
        stacked = torch.stack(gathered_list, dim=2)
        g_clamped = groups_seq.clamp(min=0)
        one_hot = F.one_hot(g_clamped, num_classes=self.n_groups).float()
        wf_feat = (stacked * one_hot.unsqueeze(-1)).sum(dim=2)

        # Shift by one BEFORE clamping so that negative group ids land on the
        # reserved padding row 0 instead of being embedded as group 0.
        # NOTE(review): assumes padded positions carry a negative group id —
        # confirm against dataset.py. Non-negative ids are unaffected.
        group_idx = (groups_seq + 1).clamp(min=0)
        group_emb = self.group_embedding(group_idx)
        time_emb = self.temporal_pe(spike_times)

        combined = torch.cat([wf_feat, group_emb, time_emb], dim=-1)
        return self.input_proj(combined), mask

    def pool_and_predict(self, seq_out, mask):
        """Mask-weighted mean over the sequence axis, then the MDN head.

        The denominator is clamped to 1e-7 so an all-zero mask yields a zero
        context vector instead of a division by zero.
        """
        mask_exp = mask.unsqueeze(-1)
        context = (seq_out * mask_exp).sum(dim=1) / mask_exp.sum(dim=1).clamp(min=1e-7)
        return self.head(context)

print("✅ SpikeEncoderBase")


✅ SpikeEncoderBase


---
## 2. Les 3 architectures

| | Transformer | GRU | Mamba |
|---|---|---|---|
| **Mécanisme** | Self-attention (chaque spike voit tous les autres) | Récurrence bidirectionnelle | Conv causale + gate |
| **Complexité** | O(L²) | O(L) | O(L) |
| **Force** | Relations longue distance | Léger, capture l'ordre | Rapide, scalable |


In [6]:
# ============================================================
# TRANSFORMER
# ============================================================
class SpikeTransformerV3(SpikeEncoderBase):
    """Spike encoder whose sequence model is a stack of Transformer encoder layers."""

    def __init__(self, params, n_features=128, n_heads=4, n_layers=2,
                 dropout=0.1, n_components=3):
        super().__init__(params, n_features, dropout, n_components)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_features,
            nhead=n_heads,
            dim_feedforward=4 * n_features,
            dropout=dropout,
            batch_first=True,
        )
        self.sequencer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, batch):
        """Encode the batch, run self-attention, pool and return (pi, mu, sigma)."""
        tokens, mask = self.encode_sequence(batch)
        # src_key_padding_mask is True where attention must ignore a position,
        # i.e. wherever the batch mask is 0.
        is_padding = mask.eq(0)
        encoded = self.sequencer(tokens, src_key_padding_mask=is_padding)
        return self.pool_and_predict(encoded, mask)


# ============================================================
# GRU BIDIRECTIONNEL
# ============================================================
class SpikeGRUV3(SpikeEncoderBase):
    """Spike encoder whose sequence model is a bidirectional GRU."""

    def __init__(self, params, n_features=128, n_gru_layers=2,
                 dropout=0.1, n_components=3):
        super().__init__(params, n_features, dropout, n_components)
        # Dropout is only passed to the GRU when there is more than one layer.
        inter_layer_dropout = dropout if n_gru_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=n_features,
            hidden_size=n_features,
            num_layers=n_gru_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )
        # Both directions are concatenated (2 * n_features); project back down.
        self.gru_proj = nn.Linear(2 * n_features, n_features)

    def forward(self, batch):
        """Encode the batch, run the GRU, pool and return (pi, mu, sigma)."""
        tokens, mask = self.encode_sequence(batch)
        # The GRU runs over the whole sequence; the mask is only applied
        # afterwards, in the pooling step.
        hidden_states = self.gru(tokens)[0]
        return self.pool_and_predict(self.gru_proj(hidden_states), mask)


# ============================================================
# MAMBA-LIKE
# ============================================================
class SimpleMambaBlock(nn.Module):
    """Causal convolution + gate, inspired by Mamba/S4; cost is linear in L.

    Pre-norm residual block: LayerNorm -> (depthwise causal conv branch with
    SiLU) * (sigmoid gate branch) -> output projection -> dropout -> residual.
    """

    def __init__(self, d_model, d_conv=4, expand=2, dropout=0.1):
        super().__init__()
        d_inner = expand * d_model
        self.in_proj = nn.Linear(d_model, d_inner)
        # groups=d_inner gives one filter per channel. padding=d_conv-1 pads
        # both ends; forward() keeps only the first L outputs, so position t
        # never sees inputs later than t.
        self.conv1d = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=d_conv, padding=d_conv - 1, groups=d_inner,
        )
        self.act = nn.SiLU()
        self.gate_proj = nn.Linear(d_model, d_inner)
        self.out_proj = nn.Linear(d_inner, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """x: (B, L, d_model) -> (B, L, d_model); mask: optional (B, L)."""
        seq_len = x.shape[1]
        normed = self.norm(x)

        # Conv1d expects (B, C, L): transpose in, trim the tail, transpose back.
        conv_in = self.in_proj(normed).transpose(1, 2)
        conv_out = self.conv1d(conv_in)[:, :, :seq_len].transpose(1, 2)

        gate = torch.sigmoid(self.gate_proj(normed))
        gated = self.act(conv_out) * gate
        update = self.dropout(self.out_proj(gated))

        # Where mask is 0 the block reduces to the identity.
        if mask is not None:
            update = update * mask.unsqueeze(-1)
        return x + update


class SpikeMambaV3(SpikeEncoderBase):
    """Spike encoder whose sequence model is a stack of SimpleMambaBlock layers."""

    def __init__(self, params, n_features=128, n_layers=4,
                 dropout=0.1, n_components=3):
        super().__init__(params, n_features, dropout, n_components)
        blocks = [SimpleMambaBlock(n_features, dropout=dropout) for _ in range(n_layers)]
        self.mamba_layers = nn.ModuleList(blocks)

    def forward(self, batch):
        """Encode the batch, apply every block in order, pool and return (pi, mu, sigma)."""
        hidden, mask = self.encode_sequence(batch)
        # Each block receives the mask so that it leaves masked positions untouched.
        for block in self.mamba_layers:
            hidden = block(hidden, mask=mask)
        return self.pool_and_predict(hidden, mask)

print("✅ SpikeTransformerV3, SpikeGRUV3, SpikeMambaV3")


✅ SpikeTransformerV3, SpikeGRUV3, SpikeMambaV3


---
## 3. Entraînement


In [7]:
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns the sample-weighted mean MDN loss of the epoch. Reads the
    notebook-level ``device`` to decide where to move each batch.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for cpu_batch in loader:
        batch = {key: tensor.to(device) for key, tensor in cpu_batch.items()}
        target = batch["pos"]

        # The model returns (pi, mu, sigma), which are the first three
        # arguments of mdn_loss.
        loss = mdn_loss(*model(batch), target)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 5.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the mean.
        batch_size = len(target)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    return loss_sum / n_seen


@torch.no_grad()
def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Reads the notebook-level ``device`` to decide where to move each batch.

    Returns a dict with:
        "mse":     sample-weighted mean of F.mse_loss between the mixture mean
                   and the target.
        "nll":     sample-weighted mean MDN negative log-likelihood.
        "preds":   mixture mean per sample, sum_k pi_k * mu_k (on CPU).
        "targets": the targets (on CPU).
        "sigmas":  per-coordinate standard deviation of the mixture around
                   "preds" (on CPU).

    "sigmas" follows the law of total variance,
        var = sum_k pi_k * (sigma_k^2 + (mu_k - mean)^2),
    so it also counts the spread between component centres. A plain
    pi-weighted average of the component sigmas would ignore that term and
    understate the uncertainty of a multimodal prediction.
    """
    model.eval()
    total_mse, total_nll, n = 0.0, 0.0, 0
    all_preds, all_sigmas, all_targets = [], [], []

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        target = batch["pos"]
        pi, mu, sigma = model(batch)
        batch_size = len(target)

        total_nll += mdn_loss(pi, mu, sigma, target).item() * batch_size

        # Point prediction: the mean of the mixture.
        weights = pi.unsqueeze(-1)
        pred = (weights * mu).sum(dim=1)

        # Mixture variance = within-component + between-component variance.
        between = (mu - pred.unsqueeze(1)) ** 2
        mix_sigma = (weights * (sigma ** 2 + between)).sum(dim=1).sqrt()

        total_mse += F.mse_loss(pred, target).item() * batch_size
        n += batch_size

        all_preds.append(pred.cpu())
        all_sigmas.append(mix_sigma.cpu())
        all_targets.append(target.cpu())

    return {
        "mse": total_mse / n, "nll": total_nll / n,
        "preds": torch.cat(all_preds),
        "targets": torch.cat(all_targets),
        "sigmas": torch.cat(all_sigmas),
    }


def run_experiment(name, model, epochs=50, lr=3e-4):
    """Train ``model``, keep the best checkpoint and report final metrics.

    Uses the notebook-level ``train_loader``, ``val_loader`` and ``device``.
    Trains with AdamW and halves the learning rate (ReduceLROnPlateau,
    patience 5) when the validation NLL stops improving. The weights with the
    lowest validation NLL are written to ``best_{name}.pt`` and reloaded
    before the final evaluation.

    Args:
        name: label used in the printed report and in the checkpoint filename.
        model: the network to train; moved to ``device`` in place.
        epochs: number of training epochs.
        lr: initial learning rate.

    Returns:
        dict with "name", "n_params", "hist" (per-epoch "train", "val_mse",
        "val_nll" lists), "final" (output of ``evaluate``) and "dist"
        (per-sample Euclidean error).
    """
    print(f"\n{'='*60}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  {name.upper()} — {n_params:,} params")
    print(f"{'='*60}\n")

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    hist = {"train": [], "val_mse": [], "val_nll": []}
    best_val = float("inf")
    save_path = f"best_{name}.pt"
    # Tracks whether THIS run wrote save_path. A file with the same name may
    # be left over from an earlier run and must not be reloaded by mistake.
    checkpoint_written = False

    for epoch in range(1, epochs + 1):
        t0 = time.time()
        tl = train_one_epoch(model, train_loader, optimizer)
        vr = evaluate(model, val_loader)
        dt = time.time() - t0

        hist["train"].append(tl)
        hist["val_mse"].append(vr["mse"])
        hist["val_nll"].append(vr["nll"])

        # Read the LR before and after the scheduler step to report changes.
        old_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(vr["nll"])
        new_lr = optimizer.param_groups[0]["lr"]

        star = ""
        # A NaN validation NLL never compares as smaller, so it is never saved.
        if vr["nll"] < best_val:
            best_val = vr["nll"]
            torch.save(model.state_dict(), save_path)
            checkpoint_written = True
            star = " ★"

        lr_s = f" (lr {old_lr:.0e}→{new_lr:.0e})" if new_lr != old_lr else ""
        print(f"  {epoch:3d}/{epochs} | train={tl:.4f} | mse={vr['mse']:.6f} | nll={vr['nll']:.4f}{star}{lr_s} ({dt:.1f}s)")

    if checkpoint_written:
        model.load_state_dict(
            torch.load(save_path, map_location=device, weights_only=True)
        )
    else:
        # epochs == 0 or the validation NLL never improved (e.g. NaN).
        print("  ⚠ no checkpoint written during this run — evaluating the current weights")
    final = evaluate(model, val_loader)
    abs_err = torch.abs(final["preds"] - final["targets"])
    dist = torch.sqrt(((final["preds"] - final["targets"]) ** 2).sum(dim=1))

    # Fraction of samples whose error is inside the band on EVERY coordinate
    # (.all(1)). For a calibrated 2D Gaussian with independent coordinates
    # that is about 0.683**2 ≈ 47% at ±1σ, not 68%.
    w1 = (abs_err < final["sigmas"]).all(1).float().mean()
    w2 = (abs_err < 2*final["sigmas"]).all(1).float().mean()

    print(f"\n  → Eucl: mean={dist.mean():.4f} median={dist.median():.4f}")
    print(f"  → MSE={final['mse']:.6f}  NLL={final['nll']:.4f}")
    print(f"  → Calibration: ±1σ={100*w1:.1f}%  ±2σ={100*w2:.1f}%")

    return {"name": name, "n_params": n_params, "hist": hist, "final": final, "dist": dist}


---
## 4. Lancement des 3 architectures


In [8]:
# Hyper-parameters shared by the three experiment cells below.
N_FEATURES: int = 128      # embedding / model width
N_COMPONENTS: int = 3      # number of Gaussians in the MDN head
DROPOUT: float = 0.1
EPOCHS: int = 50
LR: float = 3e-4

# Filled by the experiment cells, keyed by architecture name.
results: dict = dict()


### 4.1 Transformer

In [None]:
# Train and evaluate the Transformer variant with the shared hyper-parameters.
results["transformer"] = run_experiment(
    name="transformer",
    model=SpikeTransformerV3(
        params=params,
        n_features=N_FEATURES,
        n_heads=4,
        n_layers=2,
        dropout=DROPOUT,
        n_components=N_COMPONENTS,
    ),
    epochs=EPOCHS,
    lr=LR,
)



  TRANSFORMER — 806,415 params



  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


    1/50 | train=0.3903 | mse=0.084818 | nll=-0.1637 ★ (869.7s)


### 4.2 GRU

In [None]:
# Train and evaluate the bidirectional GRU variant with the shared hyper-parameters.
results["gru"] = run_experiment(
    name="gru",
    model=SpikeGRUV3(
        params=params,
        n_features=N_FEATURES,
        n_gru_layers=2,
        dropout=DROPOUT,
        n_components=N_COMPONENTS,
    ),
    epochs=EPOCHS,
    lr=LR,
)


### 4.3 Mamba

In [None]:
# Train and evaluate the Mamba-like variant with the shared hyper-parameters.
results["mamba"] = run_experiment(
    name="mamba",
    model=SpikeMambaV3(
        params=params,
        n_features=N_FEATURES,
        n_layers=4,
        dropout=DROPOUT,
        n_components=N_COMPONENTS,
    ),
    epochs=EPOCHS,
    lr=LR,
)


---
## 5. Comparaison


In [None]:
# Final comparison table: one row per trained architecture.
print(f"{'='*80}")
print(f"  COMPARAISON FINALE")
print(f"{'='*80}")
print(f"  {'Arch':<15} {'Params':>10} {'MSE':>10} {'NLL':>10} {'Eucl.':>10} {'±1σ':>8} {'±2σ':>8}")
print(f"  {'-'*73}")

for name, r in results.items():
    final, dist = r["final"], r["dist"]
    sigma = final["sigmas"]
    # Absolute error per coordinate, shared by both calibration bands.
    abs_err = (final["preds"] - final["targets"]).abs()
    # Share of samples inside the band on every coordinate.
    within_1 = (abs_err < sigma).all(1).float().mean()
    within_2 = (abs_err < 2 * sigma).all(1).float().mean()
    row = (
        f"  {name:<15} {r['n_params']:>10,} {final['mse']:>10.6f} {final['nll']:>10.4f} "
        f"{dist.mean():>10.4f} {100*within_1:>7.1f}% {100*within_2:>7.1f}%"
    )
    print(row)


In [None]:
# One fixed colour per architecture, reused on every panel.
colors = {"transformer": "tab:blue", "gru": "tab:orange", "mamba": "tab:green"}

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Panels 1 and 2: validation curves, one line per architecture.
curve_panels = (
    (axes[0], "val_mse", "Val MSE (↓)"),
    (axes[1], "val_nll", "Val NLL MDN (↓)"),
)
for panel, metric, title in curve_panels:
    for name, r in results.items():
        panel.plot(r["hist"][metric], label=name, color=colors[name])
    panel.set_title(title)
    panel.set_xlabel("Epoch")
    panel.legend()
    panel.grid(True, alpha=0.3)

# Panel 3: mean Euclidean error per architecture, with the value written
# just above each bar.
names = list(results)
means = [results[n]["dist"].mean().item() for n in names]
bars = axes[2].bar(names, means, color=[colors[n] for n in names])
for rect, mean_err in zip(bars, means):
    x_center = rect.get_x() + rect.get_width()/2
    axes[2].text(x_center, rect.get_height() + 0.003,
                 f"{mean_err:.4f}", ha="center", fontsize=10)
axes[2].set_title("Erreur euclidienne moyenne")

plt.tight_layout()
plt.show()


---
## 6. Visualisation du meilleur modèle


In [None]:
# Pick the architecture with the lowest final validation MSE.
best_name, r = min(results.items(), key=lambda item: item[1]["final"]["mse"])
print(f"Meilleur modèle : {best_name.upper()}")

preds_np = r["final"]["preds"].numpy()
targets_np = r["final"]["targets"].numpy()
dist_np = r["dist"].numpy()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Target vs predicted positions
ax = axes[0]
for points, label, color in ((targets_np, "cible", "steelblue"), (preds_np, "prédit", "coral")):
    ax.scatter(points[:, 0], points[:, 1], s=3, alpha=0.3, label=label, c=color)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title(f"Cible vs Prédit ({best_name})")
ax.legend()
ax.set_aspect("equal")

# 2. Error by position; the colour scale is clipped to the 2nd-98th
#    percentiles so outliers do not wash it out.
ax = axes[1]
vmin = np.percentile(dist_np, 2)
vmax = np.percentile(dist_np, 98)
sc = ax.scatter(targets_np[:, 0], targets_np[:, 1], s=5, c=dist_np,
                cmap="RdYlGn_r", alpha=0.6, vmin=vmin, vmax=vmax)
plt.colorbar(sc, ax=ax, label="erreur euclidienne")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Erreur par position")
ax.set_aspect("equal")

# 3. Distribution of the errors, with mean and median marked
ax = axes[2]
mean_err = dist_np.mean()
median_err = np.median(dist_np)
ax.hist(dist_np, bins=50, color="teal", alpha=0.7)
ax.axvline(mean_err, color="red", ls="--", label=f"mean={mean_err:.4f}")
ax.axvline(median_err, color="orange", ls="--", label=f"median={median_err:.4f}")
ax.set_xlabel("Erreur euclidienne")
ax.set_title("Distribution des erreurs")
ax.legend()

plt.suptitle(f"Résultats — {best_name.upper()} (Conv1D + MDN, K={N_COMPONENTS})",
             fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()
