# Fine-tuned Checkpoint + Embeddings

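Notebook de preparación de artefactos: carga el checkpoint de MERT fine-tuned (salida del notebook 02) y genera los embeddings de los splits `train` y `val`, junto con los CSV de índice que fijan el orden de las filas de cada split.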


## 1. Configuración

In [1]:
import os, random
import numpy as np
import torch
from pathlib import Path

# Global reproducibility seed, and the folder holding notebook 02's fine-tuning artifacts.
SEED: int = 42
OUT_DIR_02 = Path(".") / "outputs_02_mert_finetune_egfxset"

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``PYTHONHASHSEED`` is exported as well. Note that it only influences Python
    processes started after this call; it does not change the hash seed of the
    already-running kernel.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

seed_everything(SEED)

# Pick the fastest available backend: NVIDIA CUDA first, then Apple-Silicon MPS,
# and fall back to CPU. (Previously CUDA machines silently ran on CPU.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("device:", device, "| torch:", torch.__version__, "| seed:", SEED)


device: mps | torch: 2.9.1 | seed: 42


## 2. Índice

In [2]:
import pandas as pd
from pathlib import Path

# Master EGFxSet index: one row per audio clip (path plus tone/pickup/string/fret/pitch labels).
INDEX_CSV_PATH = Path("..") / "egfxset_index.csv"
assert INDEX_CSV_PATH.exists(), f"No existe el índice: {INDEX_CSV_PATH}"

df = pd.read_csv(INDEX_CSV_PATH)
n_rows = df.shape[0]
column_names = df.columns.tolist()

print("Dataset cargado")
print("Número de muestras:", n_rows)
print("Columnas:", column_names)
display(df.head())


Dataset cargado
Número de muestras: 8947
Columnas: ['path', 'tone', 'pickup', 'tone_pickup', 'string', 'fret', 'midi_pitch']


Unnamed: 0,path,tone,pickup,tone_pickup,string,fret,midi_pitch
0,/Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr...,BluesDriver,Bridge,BluesDriver__Bridge,1,0,64
1,/Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr...,BluesDriver,Bridge,BluesDriver__Bridge,1,1,65
2,/Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr...,BluesDriver,Bridge,BluesDriver__Bridge,1,10,74
3,/Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr...,BluesDriver,Bridge,BluesDriver__Bridge,1,11,75
4,/Users/dtenreiro/Documents/TFM/EGFxSet/BluesDr...,BluesDriver,Bridge,BluesDriver__Bridge,1,12,76


## 3. Split de datos

In [3]:
from sklearn.model_selection import train_test_split

# Target label (tone x pickup) and the column holding each clip's audio path.
LABEL_COL = "tone_pickup"
PATH_COL = "path"

for required_col in (LABEL_COL, PATH_COL):
    assert required_col in df.columns, f"No existe '{required_col}'. Columnas: {list(df.columns)}"

# 80/20 split, stratified on the label so every class keeps its share in both splits.
train_df, val_df = train_test_split(
    df, test_size=0.20, random_state=SEED, stratify=df[LABEL_COL]
)

n_classes_train = train_df[LABEL_COL].nunique()
n_classes_val = val_df[LABEL_COL].nunique()

print("Split OK")
print("Train:", len(train_df), "| Val:", len(val_df))
print("Clases train:", n_classes_train, "| Clases val:", n_classes_val)

Split OK
Train: 7157 | Val: 1790
Clases train: 65 | Clases val: 65


## 4. Audio loader

In [4]:
import torchaudio
import soundfile as sf
from pathlib import Path

# Fixed clip geometry: every clip is resampled to TARGET_SR and padded/cut to MAX_SECONDS.
MAX_SECONDS = 5.0
TARGET_SR = 24000
TARGET_LEN = int(MAX_SECONDS * TARGET_SR)  # samples per clip

def load_audio(path: str | Path):
    wav, sr = sf.read(str(path))

    # mono
    if isinstance(wav, np.ndarray) and wav.ndim == 2:
        wav = wav.mean(axis=1)

    # to torch float32
    wav = torch.tensor(wav, dtype=torch.float32)

    # resample
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
        sr = TARGET_SR

    # sanitize
    wav = torch.nan_to_num(wav, nan=0.0, posinf=0.0, neginf=0.0)

    # pad/trim to fixed length
    if wav.numel() < TARGET_LEN:
        pad = TARGET_LEN - wav.numel()
        wav = torch.cat([wav, torch.zeros(pad, dtype=wav.dtype)], dim=0)
    else:
        wav = wav[:TARGET_LEN]

    return wav, sr

## 5. MERT

In [5]:
from transformers import AutoModel, AutoProcessor

MERT_MODEL_NAME = "m-a-p/MERT-v1-330M"
print("Loading MERT model:", MERT_MODEL_NAME)

# The checkpoint uses a custom model class (MERTModel), so remote code must be trusted.
# NOTE(review): consider pinning `revision=` so the downloaded code cannot change silently.
hub_kwargs = dict(trust_remote_code=True)
mert_model = AutoModel.from_pretrained(MERT_MODEL_NAME, **hub_kwargs).to(device)
processor = AutoProcessor.from_pretrained(MERT_MODEL_NAME, **hub_kwargs)

# Make the backbone also return every intermediate hidden state.
# NOTE(review): the notebook only reads `last_hidden_state`; disabling this may save
# memory — confirm no downstream consumer needs `hidden_states` first.
mert_model.config.output_hidden_states = True

print("MERT loaded on device:", device)

  from .autonotebook import tqdm as notebook_tqdm


Loading MERT model: m-a-p/MERT-v1-330M


Some weights of the model checkpoint at m-a-p/MERT-v1-330M were not used when initializing MERTModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing MERTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MERTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MERTModel were not initialized from the model checkpoint at m-a-p/MERT-v1-330M and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MERT loaded on device: mps


In [6]:
import torch.nn as nn
import torch.nn.functional as F

class MERTForTonePickup(nn.Module):
    """MERT backbone with a max-pooled projection head for tone/pickup classification.

    Head: time max-pool of ``last_hidden_state`` -> dropout -> Linear(hidden, proj_dim)
    -> ReLU -> dropout -> Linear(proj_dim, num_classes). The (dropout-applied)
    post-ReLU projection is exposed as the embedding ``emb``; in eval mode it is
    exactly the post-ReLU projection.

    The attribute names ``mert``, ``dropout``, ``proj`` and ``classifier`` define
    the ``state_dict`` keys of saved checkpoints and must not be renamed.
    """

    def __init__(self, mert_model, num_classes: int, proj_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.mert = mert_model

        cfg_hidden = getattr(self.mert.config, "hidden_size", None)
        if cfg_hidden is None:
            raise ValueError("No se pudo leer hidden_size de mert_model.config.hidden_size")

        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(cfg_hidden, proj_dim)
        self.classifier = nn.Linear(proj_dim, num_classes)

    def forward(self, inputs, labels=None):
        """Run backbone + head on a dict of backbone keyword inputs.

        Returns a dict with ``loss`` (cross-entropy, or None without ``labels``),
        ``logits`` (B, num_classes) and ``emb`` (B, proj_dim).
        """
        frames = self.mert(**inputs).last_hidden_state      # (B, T, H)
        pooled, _ = frames.max(dim=1)                        # temporal max-pool -> (B, H)
        emb = self.dropout(F.relu(self.proj(self.dropout(pooled))))
        logits = self.classifier(emb)
        loss = None if labels is None else F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits, "emb": emb}

## 6. Modelo fine-tuned

### 6.1. Carga de checkpoint y mapeo de clases


In [7]:
import torch
from pathlib import Path

CKPT_PATH = OUT_DIR_02 / "mert_ft_last4layers_headproj.pt"
assert CKPT_PATH.exists(), f"No existe el checkpoint: {CKPT_PATH}"

# Only the metadata is needed here, so keep every tensor on CPU.
ckpt = torch.load(CKPT_PATH, map_location="cpu")

if "class_to_idx" in ckpt:
    class_to_idx = ckpt["class_to_idx"]
elif "label2id" in ckpt:
    class_to_idx = ckpt["label2id"]
else:
    # The checkpoint has no label mapping, so one is rebuilt from the train split.
    # NOTE(review): this ASSUMES the fine-tuning notebook indexed classes in sorted
    # label order — confirm against notebook 02 before trusting logit indices.
    classes = sorted(train_df[LABEL_COL].astype(str).unique().tolist())
    class_to_idx = {c: i for i, c in enumerate(classes)}

    # Persist a copy of the checkpoint that carries the rebuilt mapping.
    ckpt["class_to_idx"] = class_to_idx
    patched_path = CKPT_PATH.with_name(CKPT_PATH.stem + "_patched_with_class_to_idx.pt")
    torch.save(ckpt, patched_path)
    print("Checkpoint parcheado guardado en:", patched_path)

    CKPT_PATH = patched_path

# Sanity checks: the mapping must address exactly the classifier's output units 0..K-1.
if "num_classes" in ckpt:
    assert len(class_to_idx) == int(ckpt["num_classes"]), (
        f"class_to_idx tiene {len(class_to_idx)} clases pero el checkpoint declara {ckpt['num_classes']}"
    )
assert sorted(class_to_idx.values()) == list(range(len(class_to_idx))), "Índices de clase no contiguos"

idx_to_class = {v: k for k, v in class_to_idx.items()}
print("Usando checkpoint:", CKPT_PATH)
print("Num clases:", len(class_to_idx))
print("Ejemplo:", list(class_to_idx.items())[:5])

Checkpoint parcheado guardado en: outputs_02_mert_finetune_egfxset/mert_ft_last4layers_headproj_patched_with_class_to_idx.pt
Usando checkpoint: outputs_02_mert_finetune_egfxset/mert_ft_last4layers_headproj_patched_with_class_to_idx.pt
Num clases: 65
Ejemplo: [('BluesDriver__Bridge', 0), ('BluesDriver__Bridge-Middle', 1), ('BluesDriver__Middle', 2), ('BluesDriver__Middle-Neck', 3), ('BluesDriver__Neck', 4)]


In [8]:
from pathlib import Path

assert CKPT_PATH.exists(), f"No existe el checkpoint: {CKPT_PATH}"

# Load on CPU: load_state_dict copies tensors into the (already on-device) module
# parameters anyway, so mapping them straight to `device` would hold a second full
# copy of the weights in accelerator memory while loading.
ckpt = torch.load(CKPT_PATH, map_location="cpu")

num_classes = int(ckpt["num_classes"])

# NOTE(review): proj_dim/dropout are hard-coded and must match the fine-tuning run;
# load_state_dict raises on any parameter shape mismatch.
ft_model = MERTForTonePickup(
    mert_model,
    num_classes=num_classes,
    proj_dim=256,
    dropout=0.1
).to(device)

# ft_model wraps the shared `mert_model`, so this also overwrites the backbone weights.
ft_model.load_state_dict(ckpt["ft_model_state_dict"], strict=True)
ft_model.eval()  # disable dropout so extracted embeddings are deterministic

print("Loaded fine-tuned checkpoint:", CKPT_PATH)

Loaded fine-tuned checkpoint: outputs_02_mert_finetune_egfxset/mert_ft_last4layers_headproj_patched_with_class_to_idx.pt


In [9]:
# Every label present in either split must have an index in the checkpoint's mapping.
known_classes = set(class_to_idx.keys())
train_classes = set(train_df[LABEL_COL].astype(str).unique())
val_classes   = set(val_df[LABEL_COL].astype(str).unique())

missing_train = train_classes - known_classes
missing_val   = val_classes - known_classes

print("Missing (train):", len(missing_train))
print("Missing (val):", len(missing_val))

# Stop here instead of silently continuing with labels the model cannot represent.
assert not missing_train and not missing_val, (
    f"Clases sin índice en el checkpoint — train: {sorted(missing_train)} | val: {sorted(missing_val)}"
)

Missing (train): 0
Missing (val): 0


## 7. Dataset

In [10]:
from torch.utils.data import Dataset, DataLoader

class FullDFDataset(Dataset):
    """Map-style dataset returning ``load_audio(path)`` for each DataFrame row.

    Only the ``PATH_COL`` column is read. Items are served in DataFrame row order,
    so a non-shuffled loader keeps outputs aligned with the frame's rows.
    """

    def __init__(self, df):
        self.paths = list(df[PATH_COL].astype(str))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # (wav, sr) tuple straight from the loader
        return load_audio(self.paths[idx])

def collate_audio_only(batch):
    """Convert a list of ``(wav, sr)`` pairs into processor inputs (no labels).

    Uses the first item's sample rate for the whole batch and passes
    ``padding=False``, so all clips are expected to share one rate and one length.
    """
    wav_list, sr_list = zip(*batch)
    arrays = [wav.cpu().numpy().astype(np.float32) for wav in wav_list]
    return processor(
        arrays,
        sampling_rate=sr_list[0],
        return_tensors="pt",
        padding=False,
        return_attention_mask=True,
    )

## 8. Embeddings de Train y Val

In [11]:
from tqdm import tqdm

# Output artifacts, all written next to notebook 02's checkpoint.
FT_TRAIN_EMBS_PATH  = OUT_DIR_02 / "mert_ft_embs_train.npy"
FT_VAL_EMBS_PATH    = OUT_DIR_02 / "mert_ft_embs_val.npy"
FT_TRAIN_INDEX_PATH = OUT_DIR_02 / "egfxset_index_used_train.csv"
FT_VAL_INDEX_PATH   = OUT_DIR_02 / "egfxset_index_used_val.csv"

# Snapshot the exact split rows so later sessions can rebuild labels in the same order.
for split_frame, csv_path in ((train_df, FT_TRAIN_INDEX_PATH), (val_df, FT_VAL_INDEX_PATH)):
    split_frame.to_csv(csv_path, index=False)

def extract_embeddings_df(df_subset, out_embs_path, batch_size=8):
    """Embed every row of ``df_subset`` with ``ft_model`` and save the matrix as ``.npy``.

    Rows are fed in DataFrame order (no shuffling), so row ``i`` of the saved
    array corresponds to row ``i`` of ``df_subset``.

    Args:
        df_subset: DataFrame with a ``PATH_COL`` column.
        out_embs_path: destination passed to ``np.save``.
        batch_size: clips per forward pass.

    Returns:
        The stacked embeddings, one row per input row.
    """
    ft_model.eval()  # dropout off -> deterministic embeddings
    batches = DataLoader(
        FullDFDataset(df_subset),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_audio_only,
    )

    chunks = []
    with torch.no_grad():
        for batch_inputs in tqdm(batches):
            on_device = {name: tensor.to(device) for name, tensor in batch_inputs.items()}
            chunks.append(ft_model(on_device, labels=None)["emb"].detach().cpu().numpy())

    embeddings = np.vstack(chunks)
    np.save(out_embs_path, embeddings)
    print("Saved:", out_embs_path, "| shape:", embeddings.shape)
    return embeddings

BATCH_SIZE = 8

# Extract train first, then val (same order as the saved index snapshots).
embs_train, embs_val = (
    extract_embeddings_df(split, emb_path, batch_size=BATCH_SIZE)
    for split, emb_path in ((train_df, FT_TRAIN_EMBS_PATH), (val_df, FT_VAL_EMBS_PATH))
)

print("Index snapshots saved:")
for snapshot in (FT_TRAIN_INDEX_PATH, FT_VAL_INDEX_PATH):
    print(" -", snapshot)

100%|█████████████████████████████████████████| 895/895 [06:34<00:00,  2.27it/s]


Saved: outputs_02_mert_finetune_egfxset/mert_ft_embs_train.npy | shape: (7157, 256)


100%|█████████████████████████████████████████| 224/224 [01:54<00:00,  1.96it/s]

Saved: outputs_02_mert_finetune_egfxset/mert_ft_embs_val.npy | shape: (1790, 256)
Index snapshots saved:
 - outputs_02_mert_finetune_egfxset/egfxset_index_used_train.csv
 - outputs_02_mert_finetune_egfxset/egfxset_index_used_val.csv





In [12]:
import numpy as np
import pandas as pd
from pathlib import Path

# Restart point: paths are re-declared so this section runs after a kernel restart
# without re-extracting embeddings.
OUT_DIR_02 = Path("./outputs_02_mert_finetune_egfxset")

FT_TRAIN_EMBS_PATH  = OUT_DIR_02 / "mert_ft_embs_train.npy"
FT_VAL_EMBS_PATH    = OUT_DIR_02 / "mert_ft_embs_val.npy"
FT_TRAIN_INDEX_PATH = OUT_DIR_02 / "egfxset_index_used_train.csv"
FT_VAL_INDEX_PATH   = OUT_DIR_02 / "egfxset_index_used_val.csv"

# Report which artifact is missing, like the other path checks in this notebook.
for artifact in (FT_TRAIN_EMBS_PATH, FT_VAL_EMBS_PATH, FT_TRAIN_INDEX_PATH, FT_VAL_INDEX_PATH):
    assert artifact.exists(), f"No existe el artefacto: {artifact}"

print("OK paths")

OK paths


In [13]:
LABEL_COL = "tone_pickup"

train_df = pd.read_csv(FT_TRAIN_INDEX_PATH)
val_df   = pd.read_csv(FT_VAL_INDEX_PATH)

X_train = np.load(FT_TRAIN_EMBS_PATH)
X_val   = np.load(FT_VAL_EMBS_PATH)

# Embedding rows must line up 1:1 with the index rows.
for frame, mat in ((train_df, X_train), (val_df, X_val)):
    assert len(frame) == mat.shape[0], (len(frame), mat.shape)

# Label -> integer id over the union of both splits, so train and val share one encoding.
# NOTE(review): built independently of the checkpoint's `class_to_idx`; the two agree
# only if the checkpoint also used sorted label order.
all_labels = pd.concat([train_df[LABEL_COL], val_df[LABEL_COL]])
classes = sorted(all_labels.unique().tolist())
class2id = dict(zip(classes, range(len(classes))))

y_train, y_val = (frame[LABEL_COL].map(class2id).to_numpy() for frame in (train_df, val_df))

num_classes = len(classes)
chance = 1.0 / num_classes  # accuracy of uniform random guessing

print("X_train:", X_train.shape, "| X_val:", X_val.shape)
print("num_classes:", num_classes, "| chance:", chance)

X_train: (7157, 256) | X_val: (1790, 256)
num_classes: 65 | chance: 0.015384615384615385
