# XAI-MVA Lab (2h): Variance-Based Sensitivity / Sobol Attribution for an Audio Model

**Goal.** Build a small PyTorch model that classifies *synthetic* audio signals from a time–frequency representation, then **interpret** the decision using **variance-based (Sobol) attribution** over spectrogram patches.

> This lab is fully self-contained: no dataset download is needed unless you want to try an existing dataset afterwards.  
> Get the `sobol_attribution_method` implementation from https://github.com/fel-thomas; we will use it for the XAI part.

---

## Outline (indicative)
1. Generate a synthetic audio mini-dataset (sines + noise + AM)   
2. Features: log-magnitude spectrogram  
3. Small CNN + training loop  
4. Evaluation & typical errors 
5. **Sobol attribution** on the spectrogram: interpretation & `grid_size` 
6. Questions / mini-exercises 

---

### Notation
- $x(t)$: time-domain signal  
- $X(f,\tau)$: STFT (time–frequency)  
- Network input: $\log(1+|X|)$ (“image” in time–frequency)  
- Attribution: importance per *patch* (grid) on that image

---



In [1]:
# (Optional) Installs if you are running on a fresh machine
# !pip -q install --upgrade numpy scipy matplotlib
# !pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

import math
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


SEED = 225
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

## 1) Synthetic audio dataset generation 

We generate short waveforms (e.g., 0.5 s at 8 kHz) with:
- either a **low sine** (class 0) *or* a **high sine** (class 1),
- a bit of **noise**,
- a mild **amplitude modulation (AM)** so the task is not “too easy”.
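
As a sanity check for the noise step, the sketch below scales white Gaussian noise to a target SNR and then measures the SNR actually obtained. It is only an illustration under simplifying assumptions: it uses plain Python lists instead of tensors, and `add_noise_at_snr` and `measured_snr_db` are hypothetical helper names that are not part of the lab code.

```python
import math
import random

def add_noise_at_snr(x, snr_db, rng):
    """Add white Gaussian noise to the list x so that the SNR is about snr_db."""
    p_signal = sum(v * v for v in x) / len(x)        # mean signal power
    p_noise = p_signal / (10.0 ** (snr_db / 10.0))   # noise power for the target SNR
    sigma = math.sqrt(p_noise)
    noise = [rng.gauss(0.0, sigma) for _ in x]
    return [v + n for v, n in zip(x, noise)], noise

def measured_snr_db(x, noise):
    """SNR in dB computed from the clean signal and the noise that was added."""
    p_signal = sum(v * v for v in x) / len(x)
    p_noise = sum(n * n for n in noise) / len(noise)
    return 10.0 * math.log10(p_signal / p_noise)

rng = random.Random(0)
sr, dur, f0 = 8000, 0.5, 220.0
x_clean = [math.sin(2 * math.pi * f0 * n / sr) for n in range(int(sr * dur))]
x_noisy, noise = add_noise_at_snr(x_clean, snr_db=10.0, rng=rng)
print(f"measured SNR: {measured_snr_db(x_clean, noise):.2f} dB")
```

The measured value fluctuates slightly around the target because the noise power is estimated from a finite number of samples.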

### Questions (answer in the notebook)
- **Q1.** Why is AM useful to avoid an overly trivial dataset?
- **Q2.** What is the risk if we always generate exactly the same phase / the same amplitude?


In [2]:
# Exercise 1 — Synthetic audio generator (warm-up)
# Fill the TODOs to generate a 1D waveform x (torch.Tensor of shape [T]) and a label in {0,1}.
# Hints:
# - Use torch.sin for pure tones
# - Apply a simple amplitude modulation: (1 + am_depth * sin(2π f_am t))
# - Add white Gaussian noise to match a target SNR in dB
# - Optionally add an interferer tone (a second sine) at a random frequency

import math
import torch
import matplotlib.pyplot as plt

def synth_sample(sr=8000, dur=0.5,
                 f_low=220.0, f_high=880.0,
                 label=None,
                 snr_db=10.0,
                 am_depth=0.3,
                 add_interferer=True):
    """Return (x, label) where x is a torch.Tensor [T] and label in {0,1}.
    Class 0: dominant component at f_low
    Class 1: dominant component at f_high
    """
    T = int(sr * dur)
    t = torch.arange(T, dtype=torch.float32) / sr  # sample times n / sr (spacing exactly 1 / sr)

    # TODO(1): sample / set the label in {0,1}
    # label = ...
    raise NotImplementedError("TODO(1)")

    # TODO(2): choose the main frequency f0 based on the label
    # f0 = ...
    raise NotImplementedError("TODO(2)")

    # TODO(3): generate the main sinusoid s0(t)
    # s0 = ...
    raise NotImplementedError("TODO(3)")

    # TODO(4): apply a simple amplitude modulation (AM)
    # f_am = ... (e.g., 2 to 6 Hz)
    # env = ...
    # x = env * s0
    raise NotImplementedError("TODO(4)")

    # Optional interferer (a second tone)
    if add_interferer:
        # TODO(5): add an interferer sine with random frequency in [300, 1500] Hz
        # fi = ...
        # x = x + 0.3 * torch.sin(2 * math.pi * fi * t + phase)
        raise NotImplementedError("TODO(5)")

    # Add noise for target SNR
    # TODO(6): add white noise so that SNR(x_clean, noise) ≈ snr_db
    # - estimate signal power: P_signal = mean(x^2)
    # - noise power: P_noise = P_signal / (10^(snr_db/10))
    # - noise = sqrt(P_noise) * randn_like(x)
    # x_noisy = x + noise
    raise NotImplementedError("TODO(6)")

    # Optional: normalize to roughly [-1, 1]
    x = x_noisy / (x_noisy.abs().max() + 1e-8)
    return x, int(label)

# Demo: plot a few waveforms
sr = 8000
xs, ys = zip(*[synth_sample(sr=sr, dur=0.5) for _ in range(4)])

plt.figure()
for i, (x, y) in enumerate(zip(xs, ys)):
    plt.plot(x.numpy() + i * 2.2, label=f"y={y}")
plt.title("Synthetic waveforms (vertically shifted)")
plt.legend()
plt.show()


NotImplementedError: TODO(1)

## 2) Feature: log-magnitude spectrogram (STFT)

We turn $x(t)$ into a time–frequency “image” using the STFT.

### Questions
- **Q3.** Why do we often use `log(1 + |X|)` rather than raw `|X|`?
- **Q4.** What trade-off is introduced by the choice of `n_fft` / `hop_length`?
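
To put numbers on Q4, the following sketch computes the size and resolution of the time-frequency grid for a few settings. It assumes a one-sided STFT with centered frames (the signal is padded by `n_fft // 2` on each side), which is the frame count to expect from `torch.stft` with its default `center=True`. `stft_grid` is a hypothetical helper, not part of the lab code.

```python
def stft_grid(n_samples, sr, n_fft, hop_length):
    """Shape and resolution of a one-sided STFT with centered frames."""
    n_bins = n_fft // 2 + 1                 # frequency bins from 0 to sr / 2
    n_frames = 1 + n_samples // hop_length  # time frames
    bin_hz = sr / n_fft                     # spacing between frequency bins
    hop_ms = 1000.0 * hop_length / sr       # spacing between frames
    win_ms = 1000.0 * n_fft / sr            # duration covered by one frame
    return n_bins, n_frames, bin_hz, hop_ms, win_ms

# 0.5 s at 8 kHz = 4000 samples, as in this lab
for n_fft, hop in [(128, 32), (256, 64), (512, 128)]:
    n_bins, n_frames, bin_hz, hop_ms, win_ms = stft_grid(4000, 8000, n_fft, hop)
    print(f"n_fft={n_fft:4d} hop={hop:4d} -> bins={n_bins:4d} frames={n_frames:4d} "
          f"bin={bin_hz:6.2f} Hz hop={hop_ms:5.1f} ms window={win_ms:5.1f} ms")
```

With the lab defaults (`n_fft=256`, `hop_length=64`), this gives 129 frequency bins of 31.25 Hz and 63 frames spaced 8 ms apart.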


In [None]:
# Exercise 2 — STFT log-magnitude features
# Implement a simple log-magnitude spectrogram:
#   feat = log(1 + |STFT(x)|)
# Return shape: [F, TT]

import torch

def stft_logmag(x, n_fft=256, hop_length=64, win_length=256):
    """Returns a tensor [F, TT] (frequency x time) in log-magnitude."""
    # TODO(1): create a Hann window (torch.hann_window)
    # window = ...
    raise NotImplementedError("TODO(1)")

    # TODO(2): compute the complex STFT (torch.stft, return_complex=True)
    # X = ...
    raise NotImplementedError("TODO(2)")

    # TODO(3): magnitude -> log1p
    # mag = ...
    # feat = ...
    raise NotImplementedError("TODO(3)")

    return feat

# Visualization
x0, _ = synth_sample(sr=sr, dur=0.5, label=0)
x1, _ = synth_sample(sr=sr, dur=0.5, label=1)

feat0 = stft_logmag(x0)
feat1 = stft_logmag(x1)

plt.figure()
plt.imshow(feat0.numpy(), aspect="auto", origin="lower")
plt.title("Log-magnitude spectrogram (class 0)")
plt.colorbar()
plt.show()

plt.figure()
plt.imshow(feat1.numpy(), aspect="auto", origin="lower")
plt.title("Log-magnitude spectrogram (class 1)")
plt.colorbar()
plt.show()


## 3) PyTorch Dataset + DataLoader 

We wrap waveforms and spectrograms into a `Dataset`, and use a `DataLoader` for mini-batches.

Tip: in real audio tasks, you may precompute features or compute them on-the-fly depending on I/O constraints.
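
One consequence of on-the-fly generation: if `__getitem__` draws fresh random values on every call, the same index returns a different signal each time, so the validation set changes between epochs. A minimal sketch of one possible remedy is to tie the randomness to the index. `SeededSampler` and its fields are hypothetical names used only for illustration. In the actual dataset, the noise drawn with PyTorch would also need a seeded generator.

```python
import math
import random

class SeededSampler:
    """Hypothetical helper: draws the random parameters of sample `idx` from a
    generator seeded by (base_seed, idx), so an index always maps to the same draw."""

    def __init__(self, base_seed=0, snr_db_range=(0.0, 20.0)):
        self.base_seed = base_seed
        self.snr_db_range = snr_db_range

    def params(self, idx):
        rng = random.Random(self.base_seed * 1_000_003 + idx)
        return {
            "label": rng.randint(0, 1),                 # class in {0, 1}
            "snr_db": rng.uniform(*self.snr_db_range),  # SNR for this sample
            "phase": rng.uniform(0.0, 2.0 * math.pi),   # initial phase of the tone
        }

val_sampler = SeededSampler(base_seed=1)
print(val_sampler.params(5) == val_sampler.params(5))  # same index, same parameters
```

For the training set, resampling at every epoch is usually harmless (it acts as data augmentation). The fixed mapping matters mostly for validation and for comparing explanations on the same example.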


In [None]:
# Exercise 3 — Dataset + DataLoaders
# Implement __getitem__ to return:
#   x_feat: torch.Tensor of shape [1, F, TT]
#   y: int in {0,1}
# Optional: simple time-shift augmentation

import random
import numpy as np
from torch.utils.data import Dataset, DataLoader

class SyntheticAudioTFDataset(Dataset):
    def __init__(self, n=2000, sr=8000, dur=0.5, n_fft=256, hop_length=64,
                 snr_db_range=(0, 20), do_time_shift=False):
        self.n = n
        self.sr = sr
        self.dur = dur
        self.n_fft = n_fft
        self.hop = hop_length
        self.snr_db_range = snr_db_range
        self.do_time_shift = do_time_shift

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # TODO(1): sample an SNR in the range, and a label
        # snr_db = ...
        # label = ...
        raise NotImplementedError("TODO(1)")

        # TODO(2): generate waveform
        # x, y = synth_sample(...)
        raise NotImplementedError("TODO(2)")

        # TODO(3): (optional) random circular time shift
        # if self.do_time_shift: ...
        raise NotImplementedError("TODO(3)")

        # TODO(4): compute log-magnitude STFT features
        # feat = stft_logmag(...)
        raise NotImplementedError("TODO(4)")

        # TODO(5): normalize features (e.g., per-sample mean/std)
        # feat = (feat - feat.mean()) / (feat.std() + 1e-8)
        raise NotImplementedError("TODO(5)")

        # TODO(6): add channel dimension -> [1, F, TT]
        # feat = feat.unsqueeze(0)
        raise NotImplementedError("TODO(6)")

        return feat, y

# Build loaders
train_ds = SyntheticAudioTFDataset(n=2000, do_time_shift=True)
val_ds   = SyntheticAudioTFDataset(n=400,  do_time_shift=False)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=0)

# Quick sanity check
xb, yb = next(iter(train_loader))
print("Batch x:", xb.shape, xb.dtype, "Batch y:", yb.shape, yb[:8].tolist())


## 4) A small CNN (on spectrograms)

We keep the architecture compact so training is fast on CPU.

### Questions
- **Q5.** Why does a CNN naturally “see” local patterns on a spectrogram?
- **Q6.** What is the difference between “time invariance” and “frequency invariance” in this problem?


In [None]:
# Exercise 4 — A small CNN for spectrogram classification
# Implement a simple ConvNet that maps [B, 1, F, TT] -> logits [B, 2]

import torch.nn as nn
import torch

class SmallSpecCNN(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        # TODO(1): define a small convolutional feature extractor (Conv2d/BN/ReLU/Pool)
        # self.net = nn.Sequential(...)
        raise NotImplementedError("TODO(1)")

        # TODO(2): define a final classifier head (e.g., AdaptiveAvgPool2d + Linear)
        # self.head = ...
        raise NotImplementedError("TODO(2)")

    def forward(self, x):
        # TODO(3): forward pass
        # z = self.net(x)
        # z = ...
        # logits = self.head(z)
        raise NotImplementedError("TODO(3)")

# Device + model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallSpecCNN().to(device)
print(model)


## 5) Training

We train with cross-entropy and monitor validation loss/accuracy.

**Tip:** with synthetic data, it is easy to accidentally “leak” information (e.g., fixed amplitudes, fixed noise, etc.). Keep an eye on generalization.


In [None]:
# Exercise 5 — Training loop
# Fill the TODOs in train().

import numpy as np
import torch

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    losses = []
    ce = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = ce(logits, y)
            losses.append(loss.item())
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    return float(np.mean(losses)), correct / total

def train(model, train_loader, val_loader, epochs=6, lr=1e-3):
    # TODO(1): create optimizer (Adam is fine)
    # opt = ...
    raise NotImplementedError("TODO(1)")

    ce = nn.CrossEntropyLoss()
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # TODO(2): forward pass -> loss
            # logits = ...
            # loss = ...
            raise NotImplementedError("TODO(2)")

            # TODO(3): backward + optimizer step
            # opt.zero_grad()
            # loss.backward()
            # opt.step()
            raise NotImplementedError("TODO(3)")

            running_loss += loss.item() * y.numel()
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()

        train_loss = running_loss / total
        train_acc = correct / total
        val_loss, val_acc = evaluate(model, val_loader)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Epoch {ep:02d} | train loss={train_loss:.3f}, acc={train_acc:.3f} | "
              f"val loss={val_loss:.3f}, acc={val_acc:.3f}")

    return history

history = train(model, train_loader, val_loader, epochs=6, lr=1e-3)


## 6) Inspecting errors

We inspect a few misclassified examples to understand what the model focuses on:
- is it confusing classes when SNR is low?
- is it over-relying on artifacts (noise, AM) rather than frequency content?


In [None]:
def collect_misclassified(model, loader, max_items=8):
    model.eval()
    items=[]
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            bad = (pred != y).nonzero(as_tuple=False).squeeze(1)
            for i in bad.tolist():
                items.append((x[i].detach().cpu(), int(y[i].cpu()), int(pred[i].cpu())))
                if len(items) >= max_items:
                    return items
    return items

bad = collect_misclassified(model, val_loader, max_items=6)
print("nb misclassified shown:", len(bad))

for i,(feat, y_true, y_pred) in enumerate(bad):
    plt.figure()
    plt.imshow(feat.squeeze(0).numpy(), origin="lower", aspect="auto")
    plt.title(f"Misclassified #{i} | true={y_true} pred={y_pred}")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.colorbar()
    plt.show()


## 7) Variance-based attribution (Sobol) on the spectrogram

We want to measure the **importance of regions** of the spectrogram for the network’s decision.

Idea:
- split the input into a grid of patches (`grid_size`)
- mask / replace patches according to binary variables
- estimate Sobol indices (via quasi-Monte Carlo sampling) to approximate each patch’s contribution to the output variance.
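
The steps above can be illustrated on a toy function standing in for the model score. The sketch below estimates total-order Sobol indices with the Jansen estimator. It is not the library's implementation: it uses plain Monte Carlo sampling instead of a quasi-Monte Carlo sequence, continuous masks in $[0,1]$, and a hypothetical `toy_score` in which mask 0 acts alone, masks 1 and 2 only matter together, and mask 3 is ignored.

```python
import random

def jansen_total_indices(f, dim, n, rng):
    """Estimate total-order Sobol indices of f on [0, 1]^dim (Jansen estimator),
    using plain Monte Carlo sampling for simplicity."""
    A = [[rng.random() for _ in range(dim)] for _ in range(n)]
    B = [[rng.random() for _ in range(dim)] for _ in range(n)]
    fA = [f(a) for a in A]
    mean = sum(fA) / n
    var = sum((y - mean) ** 2 for y in fA) / (n - 1)   # output variance
    indices = []
    for i in range(dim):
        acc = 0.0
        for a, b, ya in zip(A, B, fA):
            ab = list(a)
            ab[i] = b[i]               # resample only variable i
            acc += (ya - f(ab)) ** 2
        indices.append(acc / (2.0 * n * var))
    return indices

def toy_score(m):
    # Stand-in for "model output given patch masks m"
    return m[0] + m[1] * m[2]

st = jansen_total_indices(toy_score, dim=4, n=20000, rng=random.Random(0))
print([round(s, 3) for s in st])   # analytic values: 12/19, 4/19, 4/19, 0
```

The indices sum to more than 1 here because the interaction between masks 1 and 2 is counted in both of their total-order indices. The cost is `n * (dim + 2)` evaluations of `f`, which is worth keeping in mind when thinking about `grid_size`.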

### Things to explore
- effect of `grid_size` (granularity)
- effect of `nb_samples` (estimator variance)

### Questions
- **Q7.** Why does estimation become harder when `grid_size` increases?
- **Q8.** On an audio spectrogram, which zones should matter to distinguish class 0 vs 1?


In [None]:
# Try to import the library (as in your example notebook)
has_sobol = False
try:
    from sobol_attribution_method.torch_explainer import SobolAttributionMethod
    has_sobol = True
except Exception as e:
    print("SobolAttributionMethod not found:", repr(e))
has_sobol


In [None]:
def show_attr(feat, attr, title=""):
    """Display the input and its attribution map (same dimensions)."""
    feat2 = feat.squeeze(0).numpy()
    attr2 = attr.squeeze(0).numpy()

    plt.figure()
    plt.imshow(feat2, origin="lower", aspect="auto")
    plt.title(title + " — spectrogram")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.colorbar()
    plt.show()

    plt.figure()
    plt.imshow(attr2, origin="lower", aspect="auto")
    plt.title(title + " — attribution (Sobol)")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.colorbar()
    plt.show()

if has_sobol:
    model.eval()
    # take a batch and pick one example
    xb, yb = next(iter(val_loader))
    feat = xb[0].to(device)          # [1,F,T]
    y = int(yb[0].item())
    feat_b = feat.unsqueeze(0)       # [B=1,1,F,T]

    with torch.no_grad():
        logits = model(feat_b)
        pred = int(logits.argmax(dim=1).item())
    print("true:", y, "pred:", pred)

    # Build an explainer. A larger nb_samples is slower but gives more stable estimates.
    # NOTE: the argument and method names below are the ones used in this lab; check them
    # against the version of the library you installed and adapt them if they differ.
    explainer = SobolAttributionMethod(model, grid_size=8, nb_samples=256, batch_size=32)

    # The API typically expects an image (B,C,H,W); our spectrogram is already (B,1,F,T).
    # We explain the probability of the predicted class (or of the true class).
    target_class = pred
    explanation = explainer.explain(feat_b, target=target_class)  # tensor (B,1,F,T) or similar

    show_attr(feat.cpu(), explanation[0].detach().cpu(), title=f"Sobol grid=8 (target={target_class})")


## 8) Exploring `grid_size` and stability

Try several `grid_size` values (e.g., 4, 8, 12) and compare:
- spatial resolution of the attribution map,
- stability across Monte Carlo runs.
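
One simple way to quantify stability is the mean pairwise correlation between the attribution maps of repeated runs. The sketch below is an illustration only. `stability` is a hypothetical helper. `fake_run` replaces real explainer calls with a fixed map plus Gaussian noise whose scale is assumed to shrink like $1/\sqrt{\texttt{nb\_samples}}$. The real estimator may behave differently, so replace `fake_run` with actual explanations before drawing conclusions.

```python
import itertools
import math
import random

def pearson(u, v):
    """Pearson correlation between two equally long lists."""
    mu, mv = sum(u) / len(u), sum(v) / len(v)
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    su = math.sqrt(sum((a - mu) ** 2 for a in u))
    sv = math.sqrt(sum((b - mv) ** 2 for b in v))
    return cov / (su * sv)

def stability(maps):
    """Mean pairwise Pearson correlation between flattened attribution maps."""
    pairs = list(itertools.combinations(maps, 2))
    return sum(pearson(u, v) for u, v in pairs) / len(pairs)

rng = random.Random(0)
true_map = [rng.random() for _ in range(64)]   # e.g. an 8 x 8 grid, flattened

def fake_run(nb_samples):
    # Stand-in for one explainer run (assumed noise model, see the text above)
    sigma = 2.0 / math.sqrt(nb_samples)
    return [t + rng.gauss(0.0, sigma) for t in true_map]

for nb in (64, 256, 1024):
    runs = [fake_run(nb) for _ in range(5)]
    print(nb, round(stability(runs), 3))
```

With real explanations, flatten each map (for instance with `explanation[0].flatten().tolist()`) and pass the list of runs to `stability`.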

### Mini-exercise
- **E5.** Fix `grid_size` and test `nb_samples ∈ {64, 256, 1024}`. What do you observe?


In [None]:
# Exercise 7 — Sobol attribution (global variance-based explanation)
# Fill the TODOs to compute Sobol attributions with different grid sizes.

if has_sobol:
    grid_sizes = [4, 8, 12]
    nb_samples = 256

    for gs in grid_sizes:
        # TODO(1): instantiate the Sobol explainer
        # explainer = SobolAttributionMethod(...)
        raise NotImplementedError("TODO(1)")

        # TODO(2): run explanation for feat_b and a target class
        # explanation = explainer.explain(...)
        raise NotImplementedError("TODO(2)")

        show_attr(feat.cpu(), explanation[0].detach().cpu(),
                  title=f"Sobol grid={gs}, nb_samples={nb_samples}")
else:
    print("Sobol library not installed — skip this section.")


## 9) (Bonus) Compare with a gradient-based attribution

We compute a simple saliency map (absolute gradient of the target logit w.r.t. the input).

### Question
- **Q9.** What does Sobol capture that saliency does not (and vice versa)?


In [None]:
# Exercise 8 (bonus) — Gradient saliency for comparison
# Implement a simple |d logit / d x| saliency map.

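def gradient_attribution(model, x, target_class):
    """Return |d logit[target_class] / d x| for a single-example batch x."""
    model.eval()                                  # inference mode for BatchNorm/Dropout, if any
    x = x.clone().detach().requires_grad_(True)   # leaf tensor that tracks gradients
    logits = model(x)
    score = logits[0, target_class]               # target logit (a score, not a loss)
    score.backward()
    return x.grad.abs().detach()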

# Get a batch (any loader works; val_loader is perhaps not the best choice)
feat_b, y_b = next(iter(val_loader))
feat_b = feat_b.to(device)
y_b = y_b.to(device)

# Pick one example
x_demo = feat_b[0:1]
y_demo = y_b[0].item()

# Compute attribution
attr_grad = gradient_attribution(model, x_demo, y_demo)

# Visualize
plt.figure(figsize=(6, 4))
plt.imshow(attr_grad[0, 0].cpu(), origin="lower", aspect="auto")
plt.title("Gradient-based attribution")
plt.colorbar()
plt.show()



## 10) Wrap-up questions

Answer briefly:

1. **Q10.** How would you choose `grid_size` in a principled way? Give 2 criteria.
2. **Q11.** Give one potential failure case: an attribution map that looks “nice” but is misleading.
3. **Q12.** What if your model exploits noise instead of frequency cues? How would you detect it?


## 11) Going further: more interesting applications (part of the audio project)
1. Try a multi-class frequency task (not only low vs. high frequency).
2. Try other representations (ERB, mel spectrogram) and check the differences.
3. Try more realistic datasets and other architectures. Here is a non-exhaustive list:
   - Speech commands: https://huggingface.co/datasets/google/speech_commands with AST: https://huggingface.co/MIT/ast-finetuned-speech-commands-v2
   - Music genre: https://huggingface.co/datasets/ccmusic-database/music_genre (classic pretrained models: VGG, ViT, ConvNeXt)
4. Provide a complete analysis using the same XAI system and the previous code.
5. Other recommendations: you could also prepare a small test dataset for the pointing game, deletion scores, etc.
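
The two evaluation ideas in the last recommendation can be sketched on a toy grid. This is a minimal illustration under stated assumptions, not a reference implementation. The "model" is a hypothetical `score` function that reads the energy in one frequency row, the ground-truth mask marks that row, and all helper names are made up for the example.

```python
def pointing_game_hit(attr, mask):
    """attr, mask: 2D lists [freq][time]. Hit if the largest attribution lies in the mask."""
    _, i, j = max((v, i, j) for i, row in enumerate(attr) for j, v in enumerate(row))
    return bool(mask[i][j])

def deletion_curve(score_fn, x, attr, baseline=0.0, steps=4):
    """Replace cells of x by `baseline` in decreasing order of attribution and
    record score_fn after each step (the first entry is the unmodified input)."""
    order = sorted(((i, j) for i in range(len(x)) for j in range(len(x[0]))),
                   key=lambda ij: attr[ij[0]][ij[1]], reverse=True)
    x = [row[:] for row in x]                 # work on a copy
    per_step = max(1, len(order) // steps)
    curve = [score_fn(x)]
    for k in range(0, len(order), per_step):
        for i, j in order[k:k + per_step]:
            x[i][j] = baseline
        curve.append(score_fn(x))
    return curve

def auc(curve):
    """Trapezoidal area under the curve, with the x-axis normalized to [0, 1]."""
    n = len(curve) - 1
    return sum((curve[k] + curve[k + 1]) / 2.0 for k in range(n)) / n

# Toy 4 x 4 "spectrogram": the tone lives in frequency row 1
x = [[0.1] * 4, [1.0] * 4, [0.1] * 4, [0.1] * 4]
mask = [[0] * 4, [1] * 4, [0] * 4, [0] * 4]
score = lambda z: sum(z[1]) / 4.0                           # stand-in model score
good = [[0.0] * 4, [1.0] * 4, [0.0] * 4, [0.0] * 4]         # attribution on the tone
bad = [[0.0] * 4, [0.0] * 4, [0.0] * 4, [1.0] * 4]          # attribution elsewhere

print(pointing_game_hit(good, mask), auc(deletion_curve(score, x, good)))
print(pointing_game_hit(bad, mask), auc(deletion_curve(score, x, bad)))
```

A faithful attribution should hit the mask and make the score drop quickly, which gives a low deletion AUC. On the synthetic data of this lab, a natural ground-truth mask is the band of frequency bins around the main tone.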