# PANNs Revisited: Embeddings From Fine-Tuned Model

Loads `best.pt` and writes embeddings to disk, without retraining.

# 1. Setup

This notebook replicates the PANNs baseline workflow, but inserts a **fine-tuning** stage closer to the one used for MERT. In addition to a projection head for embeddings, it allows **partial adaptation of the backbone** (the last convolutional blocks) to specialize the model to the EGFxSet domain.

Goals:

- Measure total training time.
- Keep the evaluation protocol identical (Top-1/Top-5 by cosine similarity).
- Compare against: (i) pretrained PANNs, (ii) fine-tuning a head on a frozen backbone, (iii) this experiment with a partially trainable backbone.
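The exact Top-1/Top-5 protocol is not spelled out in this notebook. The sketch below assumes a leave-one-out retrieval over a single set of embeddings. A query counts as a hit at k when any of its k most cosine-similar neighbours shares its label, with the query itself excluded. `topk_retrieval_accuracy` is a hypothetical name, not a function from the baseline.

```python
import numpy as np

def topk_retrieval_accuracy(emb, labels, ks=(1, 5)) -> dict:
    """Leave-one-out retrieval accuracy by cosine similarity (assumed protocol).

    A query is a hit at k if at least one of its k nearest neighbours
    (excluding itself) has the same label.
    """
    emb = np.asarray(emb, dtype=np.float64)
    labels = np.asarray(labels)
    # L2-normalise rows so that the dot product equals cosine similarity
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    emb = emb / np.clip(norms, 1e-12, None)
    sim = emb @ emb.T
    # A query must never retrieve itself
    np.fill_diagonal(sim, -np.inf)
    # Neighbour indices sorted by decreasing similarity
    order = np.argsort(-sim, axis=1)
    results = {}
    for k in ks:
        neigh_labels = labels[order[:, :k]]               # (N, k)
        hits = (neigh_labels == labels[:, None]).any(axis=1)
        results[f"top{k}"] = float(hits.mean())
    return results

emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
print(topk_retrieval_accuracy(emb, np.array([0, 0, 1, 1]), ks=(1, 3)))
```

The rows are normalised first, so the matrix product gives cosine similarities directly. For the full 8947-row index, the N×N matrix has about 80 million entries, so a real evaluation may need to process queries in chunks.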

In [1]:
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torchaudio
import soundfile as sf

from tqdm import tqdm



In [3]:
DATA_ROOT = Path("../EGFxSet")
OUT_DIR   = Path("./outputs_01_panns_base_egfxset")

OUT_DIR.mkdir(parents=True, exist_ok=True)

print("DATA_ROOT exists:", DATA_ROOT.exists(), "->", DATA_ROOT)
print("OUT_DIR exists:", OUT_DIR.exists(), "->", OUT_DIR)

DATA_ROOT exists: True -> ../EGFxSet
OUT_DIR exists: True -> outputs_01_panns_base_egfxset


## 2. Seed and output folder

Sets the seed (for reproducibility) and defines the output directory for this experiment (03).

In [4]:
import os, random
from pathlib import Path
import numpy as np
import torch

OUT_DIR_03 = Path("./outputs_03_panns_finetune_egfxset_unfreeze")
OUT_DIR_03.mkdir(parents=True, exist_ok=True)

SEED = 42
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(SEED)
print("OUT_DIR_03:", OUT_DIR_03.resolve())

OUT_DIR_03: /Users/dtenreiro/Documents/TFM/panns_inference/outputs_03_panns_finetune_egfxset_unfreeze


## 3. Device

Selects `mps`, `cuda` or `cpu` (in that order of preference), as in the baseline.

In [5]:
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

if device == "cuda":
    print("CUDA device:", torch.cuda.get_device_name(0))
elif device == "mps":
    print("Apple Silicon MPS enabled")
else:
    print("CPU only")

Apple Silicon MPS enabled


# 2. Loading the dataset index

Reuses the `egfxset_index.csv` generated by the PANNs baseline, so that the rows keep the **same order** and metadata.
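Keeping the same order matters because the embedding matrices written at the end of this notebook are row-aligned with `df`, and `row_id` is simply `np.arange(len(df))`. Here is a minimal sketch of slicing a split out of such a matrix, with a guard on the row order. `select_rows` and its arguments are hypothetical.

```python
import numpy as np

def select_rows(emb_all: np.ndarray, index_paths, written_paths, row_ids) -> np.ndarray:
    """Return the embedding rows of a split, refusing to proceed if the row order differs."""
    if list(index_paths) != list(written_paths):
        raise ValueError("embedding rows are not in the same order as the index CSV")
    # row_ids are positions in the full index (df["row_id"] = np.arange(len(df)))
    return emb_all[np.asarray(row_ids, dtype=int)]

emb_all = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 clips, toy 3-d embeddings
paths = ["a.wav", "b.wav", "c.wav", "d.wav"]              # hypothetical file names
print(select_rows(emb_all, paths, paths, [2, 0]))
```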

In [7]:
import pandas as pd
from pathlib import Path

INDEX_CSV_PATH = Path("../egfxset_index.csv")
assert INDEX_CSV_PATH.exists(), f"Cannot find {INDEX_CSV_PATH}. Run the 01_PANNs_base_egfx notebook first."

df = pd.read_csv(INDEX_CSV_PATH)
print("Index:", df.shape)
df.head()

Index: (8947, 7)


|   | path | tone | pickup | tone_pickup | string | fret | midi_pitch |
|---|---|---|---|---|---|---|---|
| 0 | /Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr... | BluesDriver | Bridge | BluesDriver__Bridge | 1 | 0 | 64 |
| 1 | /Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr... | BluesDriver | Bridge | BluesDriver__Bridge | 1 | 1 | 65 |
| 2 | /Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr... | BluesDriver | Bridge | BluesDriver__Bridge | 1 | 10 | 74 |
| 3 | /Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr... | BluesDriver | Bridge | BluesDriver__Bridge | 1 | 11 | 75 |
| 4 | /Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr... | BluesDriver | Bridge | BluesDriver__Bridge | 1 | 12 | 76 |


## 1. Fine-tuning objective

The fine-tuning task is to **classify `tone_pickup`** (65 classes), as in the MERT fine-tuning.
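`LabelEncoder` assigns ids in the sorted order of the unique labels (its `classes_` attribute is sorted). The same mapping can therefore be reproduced without scikit-learn. Below is a small NumPy sketch. `encode_labels` is a hypothetical helper, and the second label string is an illustrative value, not taken from the index.

```python
import numpy as np

def encode_labels(values):
    """Return (classes, ids): sorted unique labels and each value's index among them."""
    classes, ids = np.unique(np.asarray(values, dtype=str), return_inverse=True)
    return classes, ids.reshape(-1)

classes, ids = encode_labels(["Phaser__Neck", "BluesDriver__Bridge", "Phaser__Neck"])
print(list(classes), ids.tolist())
```

With the real `tone_pickup` strings, the ids follow the alphabetical order of the 65 tone/pickup combinations.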

In [8]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np

LABEL_COL = "tone_pickup"

df["row_id"] = np.arange(len(df), dtype=int)

le = LabelEncoder()
df["label_id"] = le.fit_transform(df[LABEL_COL].astype(str))
N_CLASSES = len(le.classes_)
print("N_CLASSES:", N_CLASSES)

train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=SEED,
    stratify=df["label_id"]
)

train_df = train_df.reset_index(drop=True)
val_df   = val_df.reset_index(drop=True)

train_idx = train_df["row_id"].to_numpy()
val_idx   = val_df["row_id"].to_numpy()

print("Train:", train_df.shape, "| Val:", val_df.shape)
print("train_idx:", train_idx.shape, "| val_idx:", val_idx.shape)

N_CLASSES: 65
Train: (7157, 9) | Val: (1790, 9)
train_idx: (7157,) | val_idx: (1790,)


# 3. PANNs: loading and fine-tuning

We load the pretrained PANNs CNN14 (as in the baseline) and define a classification head for `tone_pickup`.

## 1. Audio parameters and pooling utilities

The same preprocessing is kept: 32 kHz, truncation or zero-padding to 5 s.
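The fixed clip length in samples follows directly from these two constants. The frame count additionally assumes that CNN14's front end uses a centred STFT with the `hop_size=320` passed to `Cnn14` later. A centred STFT yields `1 + L // hop` frames, so treat that number as an estimate.

```python
TARGET_SR = 32000   # Hz, as in the notebook
MAX_SECONDS = 5.0   # seconds, as in the notebook
HOP_SIZE = 320      # samples, as passed to Cnn14(hop_size=320)

max_len = int(TARGET_SR * MAX_SECONDS)   # samples per fixed-length clip
# Assumption: centred STFT (n_fft // 2 padding on both sides) -> 1 + L // hop frames
n_frames = 1 + max_len // HOP_SIZE
hop_ms = 1000 * HOP_SIZE / TARGET_SR     # time step between frames
print(max_len, n_frames, hop_ms)         # → 160000 501 10.0
```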

In [9]:
MAX_SECONDS = 5.0
TARGET_SR = 32000

def load_audio_panns(path: str | Path):
    """Carga wav -> mono -> resample a 32kHz -> recorta a 5s"""
    wav, sr = sf.read(str(path))

    if wav.ndim == 2:
        wav = wav.mean(axis=1)

    wav = torch.from_numpy(wav).float()

    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
        sr = TARGET_SR

    wav = wav[: int(sr * MAX_SECONDS)]
    return wav, sr

def temporal_pool(x: torch.Tensor, mode: str) -> torch.Tensor:
    """
    x: (T, D) ó (B, T, D). Devuelve (D) ó (B, D)
    """
    if x.dim() == 2:
        x0 = x
        if mode == "mean":
            return x0.mean(dim=0)
        elif mode == "max":
            return x0.max(dim=0).values
        elif mode == "stats":
            mu = x0.mean(dim=0)
            sd = x0.std(dim=0, unbiased=False)
            return torch.cat([mu, sd], dim=0)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    if x.dim() == 3:
        if mode == "mean":
            return x.mean(dim=1)
        elif mode == "max":
            return x.max(dim=1).values
        elif mode == "stats":
            mu = x.mean(dim=1)
            sd = x.std(dim=1, unbiased=False)
            return torch.cat([mu, sd], dim=1)  # mu, sd: (B, D) -> (B, 2D)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    raise ValueError(f"Unexpected tensor rank: {x.dim()}")

## 2. Dataset + DataLoaders

Builds a `Dataset` that, for each row:
- Loads the WAV
- Converts it to mono
- Resamples it to 32 kHz
- Truncates or **zero-pads** it to 5 s

Because every clip has the same fixed length, the `collate_fn` can simply stack them into batches.
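For illustration only, here is a NumPy mirror of the padding and stacking logic. The notebook itself does this with PyTorch inside `load_audio_fixed` and `collate_fn`. `pad_or_truncate` and `collate` are hypothetical names.

```python
import numpy as np

MAX_LEN = 160_000  # 5 s at 32 kHz

def pad_or_truncate(wav, max_len: int = MAX_LEN) -> np.ndarray:
    """Zero-pad at the end, or crop, so that every clip has exactly max_len samples."""
    wav = np.asarray(wav, dtype=np.float32)
    if wav.shape[0] < max_len:
        return np.pad(wav, (0, max_len - wav.shape[0]))  # constant zeros by default
    return wav[:max_len]

def collate(batch):
    """Stack (wav, label) pairs into x of shape (B, T) and y of shape (B,)."""
    wavs, ys = zip(*batch)
    x = np.stack([pad_or_truncate(w) for w in wavs], axis=0)
    y = np.asarray(ys, dtype=np.int64)
    return x, y

x, y = collate([(np.ones(10), 3), (np.ones(200_000), 7)])
print(x.shape, y.tolist())
```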

In [10]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

PATH_COL = "path"
LABEL_ID_COL = "label_id"

MAX_LEN = int(TARGET_SR * MAX_SECONDS)

def load_audio_fixed(path: str | Path):
    wav, sr = load_audio_panns(path)

    if wav.numel() < MAX_LEN:
        pad = MAX_LEN - wav.numel()
        wav = torch.nn.functional.pad(wav, (0, pad))
    else:
        wav = wav[:MAX_LEN]
    return wav

class EGFxSetToneDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        wav = load_audio_fixed(row[PATH_COL])
        y = int(row[LABEL_ID_COL])
        return wav, y

def collate_fn(batch):
    wavs, ys = zip(*batch)
    x = torch.stack([w.float() for w in wavs], dim=0)  # (B, T)
    y = torch.tensor(ys, dtype=torch.long)
    return x, y

BATCH_SIZE = 16
train_loader = DataLoader(EGFxSetToneDataset(train_df), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_loader   = DataLoader(EGFxSetToneDataset(val_df),   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_fn)

## 3. Load PANNs CNN14 + checkpoint

Reuses the same block as the baseline to locate and load the `.pth`.
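`find_cnn14_ckpt`, defined further down in this section, prefers an explicit `PANNS_CNN14_CKPT` environment variable. Otherwise, it ranks the `.pth` files it finds with a two-part sort key. The sketch below reproduces that key on hypothetical file names to show its effect.

```python
from pathlib import Path

def ckpt_sort_key(p: Path):
    # Same key as in find_cnn14_ckpt: names containing "cnn14" first (False sorts
    # before True), then shorter file names first
    return ("cnn14" not in p.name.lower(), len(p.name))

candidates = [Path("x/Wavegram_Logmel.pth"), Path("y/Cnn14_mAP=0.431.pth"), Path("z/cnn14_tmp.pth")]
print([p.name for p in sorted(candidates, key=ckpt_sort_key)])
```

Length is the tie-breaker, so a shorter file that merely contains "cnn14" would be picked over the intended checkpoint. Setting `PANNS_CNN14_CKPT` explicitly is the safer option when several checkpoints coexist.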

In [11]:
from pathlib import Path
import urllib.request

OUT = Path("pretrained_models")
OUT.mkdir(parents=True, exist_ok=True)

url = "https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1"
ckpt = OUT / "Cnn14_mAP=0.431.pth"

if not ckpt.exists():
    urllib.request.urlretrieve(url, ckpt)

PANNS_CNN14_CKPT = str(ckpt)
os.environ["PANNS_CNN14_CKPT"] = PANNS_CNN14_CKPT  # lets find_cnn14_ckpt() use this exact file
print("Checkpoint:", PANNS_CNN14_CKPT)

Checkpoint: pretrained_models/Cnn14_mAP=0.431.pth


In [12]:
from pathlib import Path
import os
import inspect
import torch
from panns_inference.models import Cnn14

def find_cnn14_ckpt() -> Path:

    env = os.environ.get("PANNS_CNN14_CKPT", "").strip()
    if env:
        p = Path(env).expanduser()
        if p.exists() and p.is_file():
            return p
        raise FileNotFoundError(f"PANNS_CNN14_CKPT apunta a algo que no es fichero: {p}")


    pkg_root = Path(inspect.getfile(Cnn14)).resolve().parent
    candidates = []
    for rel in [
        "../pretrained_models",
        "../pretrained",
        "../data",
        "../../pretrained_models",
        "../../pretrained",
    ]:
        d = (pkg_root / rel).resolve()
        if d.exists() and d.is_dir():
            candidates += list(d.rglob("*.pth"))

    candidates = sorted(candidates, key=lambda p: (("cnn14" not in p.name.lower()), len(p.name)))
    if candidates:
        return candidates[0]

    raise FileNotFoundError(
        "No PANNs CNN14 .pth checkpoint found. Set PANNS_CNN14_CKPT to its path.\n"
    )

CKPT = find_cnn14_ckpt()
print("Using checkpoint:", CKPT)

panns_model = Cnn14(
    sample_rate=TARGET_SR,
    window_size=1024,
    hop_size=320,
    mel_bins=64,
    fmin=50,
    fmax=14000,
    classes_num=527
)

ckpt = torch.load(CKPT, map_location=device)
state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
panns_model.load_state_dict(state, strict=True)

panns_model.to(device).eval()
print("Loaded PANNs:", panns_model.__class__.__name__, "on", device)

Using checkpoint: /Users/dtenreiro/Documents/TFM/panns_inference/pretrained_models/Cnn14_mAP=0.431.pth
Loaded PANNs: Cnn14 on mps


## 4. Fine-tuning model

Creates a wrapper that:
- Gets the global PANNs embedding (`out['embedding']`).
- Applies `Linear` layers (a projection followed by a classifier) to predict `tone_pickup`.

This way, a light head can be trained and, optionally, a final part of the encoder can be unfrozen.
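To put "light head" in numbers: an `nn.Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases. The quick count below uses the dimensions printed by this notebook: a 2048-d embedding, a 256-d projection and 65 classes.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Trainable parameters of a fully connected layer (weights + optional bias)."""
    return n_in * n_out + (n_out if bias else 0)

EMB_DIM, PROJ_DIM, N_CLASSES = 2048, 256, 65   # values printed by the notebook

direct = linear_params(EMB_DIM, N_CLASSES)      # a single 2048 -> 65 classifier
proj = linear_params(EMB_DIM, PROJ_DIM)         # 2048 -> 256 projection
clf = linear_params(PROJ_DIM, N_CLASSES)        # 256 -> 65 classifier
print(direct, proj, clf, proj + clf)            # → 133185 524544 16705 541249
```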

## Fine-tuning with a *projection head* + partial backbone adaptation

**Goal**: improve the quality of the embedding space for *retrieval* (Top-1/Top-5). PANNs' internal pooling is kept, but a **controlled adaptation** of the encoder's **last layers** is allowed.

**What changes compared to the previous notebook**:
- Besides the `2048 → 256` projection, the last convolutional blocks are **unfrozen** so the model can specialize to the EGFxSet domain. This strategy is closer to the partial fine-tuning applied to MERT.
- The rest of the protocol (split, windows, cosine-similarity evaluation) stays identical, so the results can be compared directly.
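This notebook does not state which blocks were unfrozen. One common way to implement partial unfreezing is to decide per parameter name. The sketch below hypothetically treats `conv_block6` and `fc1` of CNN14, plus the new head, as trainable. The submodule names are assumed from the `panns_inference` CNN14; check them with `named_parameters()`.

```python
def is_trainable(param_name: str, trainable_prefixes: tuple) -> bool:
    """True if a parameter belongs to one of the modules chosen for fine-tuning."""
    return param_name.startswith(trainable_prefixes)

# Hypothetical choice: last conv block + fc1 of the backbone, plus the wrapper's head
TRAINABLE = ("panns.conv_block6.", "panns.fc1.", "proj.", "classifier.")

names = [
    "panns.conv_block1.conv1.weight",
    "panns.conv_block6.conv2.weight",
    "panns.fc1.bias",
    "panns.fc_audioset.weight",
    "proj.weight",
    "classifier.bias",
]
print({n: is_trainable(n, TRAINABLE) for n in names})
```

In PyTorch, this would be applied as `for name, p in ft_model.named_parameters(): p.requires_grad = is_trainable(name, TRAINABLE)`, with only parameters that keep `requires_grad=True` passed to the optimizer. BatchNorm running statistics are buffers, not parameters, so they still update in `train()` mode even inside frozen blocks.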


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PANNsForTonePickup(nn.Module):
    """Wrapper de fine-tuning para PANNs.

    - backbone: PANNs (CNN14) preentrenado
    - proj: proyección entrenable para compactar el embedding (mejor retrieval)
    - classifier: capa final de clasificación

    Nota: el pooling temporal interno de PANNs NO se modifica; aquí solo se adapta el embedding final.
    """
    def __init__(self, panns_backbone: nn.Module, emb_dim: int, proj_dim: int, n_classes: int):
        super().__init__()
        self.panns = panns_backbone
        self.proj = nn.Linear(emb_dim, proj_dim)
        self.classifier = nn.Linear(proj_dim, n_classes)

    def forward(self, x):
        out = self.panns(x, None)
        emb_raw = out["embedding"]
        emb = self.proj(emb_raw)
        emb = F.normalize(emb, p=2, dim=-1)
        logits = self.classifier(emb)
        return logits, emb, emb_raw

panns_model.eval()
with torch.no_grad():
    x0, _ = next(iter(train_loader))
    x0 = x0.to(device)
    out0 = panns_model(x0, None)
    emb_dim = int(out0["embedding"].shape[-1])

PROJ_DIM = 256
ft_model = PANNsForTonePickup(panns_model, emb_dim=emb_dim, proj_dim=PROJ_DIM, n_classes=N_CLASSES).to(device)
print("emb_dim:", emb_dim, "| proj_dim:", PROJ_DIM)

emb_dim: 2048 | proj_dim: 256


## Load the saved fine-tuned checkpoint


In [14]:
BEST_PT = OUT_DIR_03 / "best.pt"
assert BEST_PT.exists(), f"Fine-tuned checkpoint does not exist: {BEST_PT}"

ckpt_ft = torch.load(BEST_PT, map_location=device)
state_dict = ckpt_ft.get("state_dict", ckpt_ft)
ft_model.load_state_dict(state_dict, strict=True)
ft_model.eval()

print("Loaded fine-tuned checkpoint:", BEST_PT)
print("best_epoch:", ckpt_ft.get("best_epoch", None), "| best_val_loss:", ckpt_ft.get("best_val_loss", None))

Loaded fine-tuned checkpoint: outputs_03_panns_finetune_egfxset_unfreeze/best.pt
best_epoch: 30 | best_val_loss: 0.31401769463909407


## Save split snapshots (traceability)


In [15]:
train_df.to_csv(OUT_DIR_03 / "egfxset_index_used_train.csv", index=False)
val_df.to_csv(OUT_DIR_03 / "egfxset_index_used_val.csv", index=False)
print("Saved:", OUT_DIR_03 / "egfxset_index_used_train.csv")
print("Saved:", OUT_DIR_03 / "egfxset_index_used_val.csv")

Saved: outputs_03_panns_finetune_egfxset_unfreeze/egfxset_index_used_train.csv
Saved: outputs_03_panns_finetune_egfxset_unfreeze/egfxset_index_used_val.csv
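These snapshots make it possible to check later that an evaluation uses exactly this split. Here is a minimal sketch of such a check on `row_id` values. `check_split` is a hypothetical helper.

```python
import numpy as np

def check_split(train_ids, val_ids, n_total: int) -> None:
    """Raise ValueError unless train/val are disjoint and together cover 0..n_total-1 once."""
    train_ids = np.asarray(train_ids, dtype=int)
    val_ids = np.asarray(val_ids, dtype=int)
    if np.intersect1d(train_ids, val_ids).size:
        raise ValueError("train and val share row_ids")
    all_ids = np.sort(np.concatenate([train_ids, val_ids]))
    if not np.array_equal(all_ids, np.arange(n_total)):
        raise ValueError("split does not cover every row_id exactly once")

check_split([0, 2, 3], [1, 4], n_total=5)  # passes silently
```

For example, `check_split(train_idx, val_idx, len(df))` is expected to pass here, since 7157 + 1790 = 8947 rows.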


## 2. Embedding extraction function (fine-tuned)

Same format as the baseline, but using the fine-tuned backbone.

## Embedding extraction (for Top-1/Top-5 evaluation)

Two variants are stored:
- `global_raw`: the original PANNs global embedding (2048)
- `global_proj`: the projected, L2-normalized embedding (256), **the main candidate** for retrieval

Framewise variants are also kept, if they exist, by pooling the backbone's framewise output.
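The CNN14 loaded here exposes only the clip-level embedding, as the variant dimensions printed further below show, so the framewise branch stays inactive. A backbone that did return a `(T, D)` frame matrix would give pooled variants with `D`, `D` and `2D` values. A NumPy mirror of `temporal_pool` on a tiny hypothetical input shows this.

```python
import numpy as np

def temporal_pool_np(frames: np.ndarray, mode: str) -> np.ndarray:
    """NumPy mirror of temporal_pool for a (T, D) matrix."""
    if mode == "mean":
        return frames.mean(axis=0)
    if mode == "max":
        return frames.max(axis=0)
    if mode == "stats":
        # Mean and population std (ddof=0, like unbiased=False), concatenated -> 2*D values
        return np.concatenate([frames.mean(axis=0), frames.std(axis=0)])
    raise ValueError(f"Unknown mode: {mode}")

frames = np.array([[1.0, 2.0], [3.0, 6.0]])  # T=2 frames, D=2
for mode in ("mean", "max", "stats"):
    print(mode, temporal_pool_np(frames, mode))
```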


In [16]:
def extract_panns_variants_ft(path):
    """Returns a dict {variant: 1-D CPU embedding}."""
    x, _ = load_audio_panns(path)
    x = x.to(device).unsqueeze(0)  # (1, n)

    variants = {}

    with torch.no_grad():
        # Single backbone pass; projection + L2 norm replicate PANNsForTonePickup.forward
        out_backbone = ft_model.panns(x, None)
        emb_raw = out_backbone["embedding"]
        emb_proj = F.normalize(ft_model.proj(emb_raw), p=2, dim=-1)

    variants["global_raw"]  = emb_raw.squeeze(0).float().cpu()
    variants["global_proj"] = emb_proj.squeeze(0).float().cpu()

    frame = None
    if "framewise_embedding" in out_backbone:
        frame = out_backbone["framewise_embedding"]  # (B, T, D)
    elif "framewise_output" in out_backbone:
        frame = out_backbone["framewise_output"]     # (B, T, D) or (B, T, C)

    if frame is not None:
        frame = frame.squeeze(0).float().cpu()  # (T, D)
        variants["frame_mean"]  = temporal_pool(frame, "mean")
        variants["frame_max"]   = temporal_pool(frame, "max")
        variants["frame_stats"] = temporal_pool(frame, "stats")

    return variants

# 5. Embeddings after fine-tuning

Extracts embeddings for the **whole** dataset and stores them as memmaps in `outputs_03_panns_finetune_egfxset_unfreeze`.

## Bulk embedding generation to disk

`.npy` files are written for every variant returned by `extract_panns_variants_ft`.
In this experiment, the key variant is `global_proj` (256), which is already L2-normalized in `forward`.
Even so, the evaluation applies an extra L2 normalization as a safeguard. This is harmless, because normalizing a unit vector again leaves it unchanged.
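A bare `np.memmap` writes raw bytes with no header, so `np.load` cannot parse such a file even if its name ends in `.npy`. `np.lib.format.open_memmap` writes a proper `.npy` header instead. The sketch below uses a temporary, hypothetical file. It round-trips a small matrix this way and checks that re-normalizing rows that are already L2-normalized changes nothing.

```python
import tempfile
from pathlib import Path

import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise L2 normalisation; eps guards against division by zero."""
    return x / np.clip(np.linalg.norm(x, axis=1, keepdims=True), eps, None)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "embeddings_demo.npy"  # hypothetical file name
    # open_memmap writes the .npy header first, then maps the data region
    mm = np.lib.format.open_memmap(str(path), mode="w+", dtype=np.float32, shape=(3, 4))
    mm[:] = np.random.default_rng(0).normal(size=(3, 4))
    mm.flush()
    del mm  # release the writable map before reopening

    emb = np.load(str(path), mmap_mode="r")  # lazy, read-only view of the file
    shape = emb.shape
    once = l2_normalize(np.asarray(emb, dtype=np.float64))  # copies into memory
    del emb  # release the map so the temporary directory can be removed

twice = l2_normalize(once)
print(shape, np.allclose(once, twice), np.allclose(np.linalg.norm(once, axis=1), 1.0))
```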


In [17]:
from tqdm import tqdm
import numpy as np

paths = df["path"].tolist()
N = len(paths)

test_vars = extract_panns_variants_ft(paths[0])
dims = {k: v.numel() for k, v in test_vars.items()}
print("Variants/dims:", dims)

out_memmaps = {}
for variant, dim in dims.items():
    out_path = OUT_DIR_03 / f"embeddings_panns_ft_proj_{variant}.npy"
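    # open_memmap writes a real .npy header, so the file can be reopened with np.load(path, mmap_mode="r")
    mm = np.lib.format.open_memmap(out_path, mode="w+", dtype=np.float32, shape=(N, dim))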
    out_memmaps[variant] = (mm, out_path)

for variant, vec in test_vars.items():
    out_memmaps[variant][0][0] = vec.numpy().astype(np.float32)

for i, path in enumerate(tqdm(paths[1:], desc="Extracting PANNs FT embeddings"), start=1):
    vars_i = extract_panns_variants_ft(path)
    for variant in dims.keys():
        vec = vars_i[variant]  # fail loudly: a fallback variant would have a different dim
        out_memmaps[variant][0][i] = vec.numpy().astype(np.float32)

for variant, (mm, out_path) in out_memmaps.items():
    mm.flush()
    print("Saved:", out_path)

Variants/dims: {'global_raw': 2048, 'global_proj': 256}


Extracting PANNs FT embeddings: 100%| 8946/8946 [01:44<00:00, 85.60it/s]

Saved: outputs_03_panns_finetune_egfxset_unfreeze/embeddings_panns_ft_proj_global_raw.npy
Saved: outputs_03_panns_finetune_egfxset_unfreeze/embeddings_panns_ft_proj_global_proj.npy



