In [None]:
# Clean out TF/Keras (they can grab CUDA memory even if unused)
!pip -q uninstall -y tensorflow keras || true
# Minimal deps
!pip -q install "transformers>=4.40,<5" torchaudio soundfile scikit-learn --upgrade

In [15]:

import os, gc, torch
# Environment flags are set before transformers is imported (next cell) so they are
# visible at import time: stay on the PyTorch backend and quiet TF / W&B / tokenizers.
os.environ.update({
    "USE_TF": "0",
    "TF_CPP_MIN_LOG_LEVEL": "3",
    "WANDB_DISABLED": "true",
    "TOKENIZERS_PARALLELISM": "false",
})

# Release any cached GPU blocks and Python garbage left from earlier runs.
torch.cuda.empty_cache()
gc.collect()

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())


PyTorch: 2.8.0+cu126 | CUDA available: True


In [16]:
from pathlib import Path
import json, random, math
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch, torchaudio
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoFeatureExtractor, AutoConfig, Wav2Vec2Model
from sklearn.metrics import mean_absolute_error, mean_squared_error

# ==== USER SETTINGS ====
# --- data ---
DATA_DIR     = "/content/Data"   # folder with .wav or .mp3 + matching .json
TARGET_KEYS  = ["Valence_best", "Arousal_best", "Submissive_vs._Dominant_best"]
MAX_FILES    = None              # int (e.g. 200) for a quick smoke test; None = use every file
VAL_SPLIT    = 0.1               # fraction held out for validation when there are 2+ pairs

# --- model / audio ---
MODEL_NAME   = "facebook/wav2vec2-base-960h"  # small, stable base encoder
TARGET_SR    = 16_000            # Hz fed to the encoder
MAX_SECONDS  = 8.0               # only the first MAX_SECONDS of each clip are used

# --- training ---
EPOCHS       = 3
LR           = 1e-3              # relatively high because only the small head is trained
WEIGHT_DECAY = 0.0
BATCH_SIZE   = 1                 # training batch size (the val loader always uses 1)
NUM_WORKERS  = 0                 # 0 = load data in the main process (no worker processes)

SEED         = 42

# Reproducibility: seed every RNG the notebook touches.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [17]:
def collect_pairs(data_dir: str, target_keys: List[str], limit: Optional[int]=None):
    """Collect (audio file, label vector) pairs from ``data_dir``.

    Every ``*.wav`` and ``*.mp3`` file found recursively under ``data_dir`` is paired
    with the sibling ``.json`` file of the same stem. When both ``x.wav`` and ``x.mp3``
    exist, only the ``.wav`` is used so a clip is never counted twice. Labels are read
    from the JSON's ``"emotion_annotation"`` object, one float per ``target_keys`` entry.

    A file is skipped (best-effort, not fatal) when its JSON sidecar is missing or
    unreadable, lacks one of ``target_keys``, or holds a non-numeric / non-finite value.

    Args:
        data_dir: root directory searched recursively.
        target_keys: annotation keys to extract, in output order.
        limit: stop once this many usable pairs are collected (None = no limit).

    Returns:
        List of ``{"audio_path": str, "labels": List[float]}`` dicts, ordered by path.

    Raises:
        RuntimeError: if no usable pair is found.
    """
    root = Path(data_dir)
    # Gather both formats. Previously `wav or mp3` ignored every .mp3 whenever
    # at least one .wav existed anywhere under the root.
    wav_files = sorted(root.rglob("*.wav"))
    wav_stems = {p.with_suffix("") for p in wav_files}
    mp3_files = [p for p in sorted(root.rglob("*.mp3")) if p.with_suffix("") not in wav_stems]
    files = sorted(wav_files + mp3_files)

    items = []
    for a in files:
        j = a.with_suffix(".json")
        if not j.exists():
            continue
        try:
            emo = json.loads(j.read_text(encoding="utf-8")).get("emotion_annotation", {})
            labels = [float(emo[k]) for k in target_keys]
        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            # Unreadable / malformed JSON, missing key, or non-numeric value: skip this file.
            continue
        if not all(np.isfinite(labels)):
            continue
        items.append({"audio_path": str(a), "labels": labels})
        if limit is not None and len(items) >= limit:
            break
    if not items:
        raise RuntimeError("No usable (audio,json) pairs found.")
    return items

items = collect_pairs(DATA_DIR, TARGET_KEYS, limit=MAX_FILES)
random.shuffle(items)

# Hold out ~VAL_SPLIT of the shuffled pairs for validation, keeping at least one
# item on each side. A single pair cannot be split, so it goes to training only.
n = len(items)
if n != 1:
    n_val = min(max(1, int(n * VAL_SPLIT)), n - 1)
    val_items = items[:n_val]
    train_items = items[n_val:]
else:
    train_items = items
    val_items = []

print(f"pairs total={n}  train={len(train_items)}  val={len(val_items)}")


pairs total=80  train=72  val=8


In [18]:
MAX_LEN = int(TARGET_SR * MAX_SECONDS)  # clip length in samples for the notebook's default settings
_resamplers: Dict[Tuple[int,int], torchaudio.transforms.Resample] = {}  # cache: (orig_sr, target_sr) -> Resample


def load_first_n_seconds(path: str, target_sr: int, max_seconds: float) -> torch.Tensor:
    """Load the start of an audio file as a fixed-length, peak-normalised mono waveform.

    Only the first ``max_seconds`` of audio (at the file's native rate) are decoded.
    That window is down-mixed to mono and resampled to ``target_sr`` if needed. It is
    then truncated or zero-padded to exactly ``int(target_sr * max_seconds)`` samples
    and divided by its absolute peak. An all-zero clip stays all zeros.

    Args:
        path: audio file readable by ``torchaudio.load``.
        target_sr: output sample rate in Hz.
        max_seconds: window length in seconds.

    Returns:
        1-D float tensor of length ``int(target_sr * max_seconds)``.
    """
    # Output length follows the arguments. Previously the global MAX_LEN was used,
    # so any non-default target_sr / max_seconds produced a wrongly sized clip.
    max_len = int(target_sr * max_seconds)

    # Infer the native sample rate without decoding the whole file.
    try:
        orig_sr = torchaudio.info(path).sample_rate
    except Exception:
        # If info() fails, decode a tiny prefix just to learn the rate.
        _, orig_sr = torchaudio.load(path, frame_offset=0, num_frames=1024)
    frames = int(orig_sr * max_seconds)

    # Decode only the window we need: (C, T<=frames).
    wav, sr = torchaudio.load(path, frame_offset=0, num_frames=frames)

    # Down-mix to mono.
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)

    # Resample just the short window; Resample modules are cached per rate pair.
    if sr != target_sr:
        key = (sr, target_sr)
        if key not in _resamplers:
            _resamplers[key] = torchaudio.transforms.Resample(sr, target_sr)
        wav = _resamplers[key](wav)
    wav = wav.squeeze(0)

    # Exact fixed length, so feature extraction can use padding="do_not_pad".
    if wav.numel() > max_len:
        wav = wav[:max_len]
    elif wav.numel() < max_len:
        wav = torch.nn.functional.pad(wav, (0, max_len - wav.numel()))

    # Peak normalisation; the epsilon avoids division by zero on silence.
    return wav / (wav.abs().max() + 1e-9)


In [19]:
class PathDataset(Dataset):
    """Dataset over pre-collected items that yields paths, not decoded audio.

    Each item is ``{"audio_path": str, "labels": float32 tensor}``. Decoding the
    audio is left to whoever consumes the batches.
    """

    def __init__(self, items: List[Dict[str,Any]]):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        record = self.items[idx]
        audio_path = record["audio_path"]
        label_tensor = torch.tensor(record["labels"], dtype=torch.float32)
        return {"audio_path": audio_path, "labels": label_tensor}

train_ds = PathDataset(train_items)
if len(val_items):
    val_ds = PathDataset(val_items)
else:
    val_ds = None

# Both loaders: main-process loading, no pinned memory, keep the final partial batch.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    drop_last=False,
)
if val_ds:
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=False,
        drop_last=False,
    )
else:
    val_loader = None

print("Loaders ready.")


Loaders ready.


In [20]:
# Pretrained feature extractor + wav2vec2 encoder. The encoder is frozen: it is
# put in eval mode and none of its parameters receive gradients.
fe = AutoFeatureExtractor.from_pretrained(MODEL_NAME, sampling_rate=TARGET_SR)
enc_cfg = AutoConfig.from_pretrained(MODEL_NAME, output_hidden_states=False)
encoder = Wav2Vec2Model.from_pretrained(MODEL_NAME, config=enc_cfg).to(device)
encoder.eval()
encoder.requires_grad_(False)

# Tiny temporal head: BiGRU over encoder frames, attention pooling, linear readout.
class TemporalHead(nn.Module):
    """Regression head over a sequence of encoder frames.

    A single-layer bidirectional GRU contextualises the frames. An additive
    attention scorer (Linear -> Tanh -> Linear) assigns each time step a weight via
    softmax over time. The weighted sum is layer-normalised and projected to
    ``out_dim`` values.

    Args:
        d_model: feature size of each input frame.
        hidden: GRU hidden size per direction; the pooled vector has 2*hidden dims.
        out_dim: number of regression targets.
    """

    def __init__(self, d_model=768, hidden=128, out_dim=3):
        super().__init__()
        pooled_dim = 2 * hidden  # forward and backward GRU states concatenated
        self.gru = nn.GRU(d_model, hidden, num_layers=1, batch_first=True, bidirectional=True)
        self.att = nn.Sequential(
            nn.Linear(pooled_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )
        self.out = nn.Sequential(
            nn.LayerNorm(pooled_dim),
            nn.Linear(pooled_dim, out_dim),
        )

    def forward(self, hs):
        """Map frames ``hs`` of shape (B, T', d_model) to predictions (B, out_dim)."""
        frames, _ = self.gru(hs)                  # (B, T', 2H)
        weights = self.att(frames).softmax(dim=1)  # (B, T', 1), normalised over time
        pooled = torch.sum(weights * frames, dim=1)  # (B, 2H)
        return self.out(pooled)

num_labels = len(TARGET_KEYS)
head = TemporalHead(d_model=encoder.config.hidden_size, hidden=128, out_dim=num_labels).to(device)

# Only the head is trainable; the encoder's parameters were frozen earlier in this cell.
opt = torch.optim.AdamW(head.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
mse = nn.MSELoss()
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler (which
# emitted a FutureWarning). Loss scaling is enabled only on CUDA; when disabled,
# the scaler passes the loss and optimizer step through unchanged.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

print("Encoder frozen. Trainable head params:",
      sum(p.numel() for p in head.parameters() if p.requires_grad))


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Encoder frozen. Trainable head params: 723972


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))


In [22]:
# NOTE(review): this overrides EPOCHS = 3 from the settings cell; consider setting
# it there so all configuration lives in one place.
EPOCHS = 20


def batch_to_inputs(batch):
    """Turn a DataLoader batch into encoder inputs plus a label tensor.

    Accepts either:
      - a list of per-sample dicts (when using a no-op ``collate_fn``), or
      - a dict of batched fields (PyTorch default collate: ``"audio_path"`` is a
        list of str, ``"labels"`` a (B, D) tensor).

    Every sample in the batch is used. The previous version kept only the first
    sample, silently discarding the rest whenever ``BATCH_SIZE > 1``.

    Returns:
        (inputs, labels): ``inputs`` is the feature-extractor output with every
        tensor moved to ``device``; ``labels`` is a float32 (B, D) tensor on ``device``.

    Raises:
        TypeError: if ``batch`` is neither a list nor a dict.
    """
    if isinstance(batch, list):
        # no-op collate_fn: a list of per-sample dicts
        paths = [ex["audio_path"] for ex in batch]
        labels = torch.stack(
            [torch.as_tensor(ex["labels"], dtype=torch.float32) for ex in batch]
        )  # (B, D)
    elif isinstance(batch, dict):
        # default collate: "audio_path" -> list[str], "labels" -> (B, D) tensor
        paths = list(batch["audio_path"])
        labels = batch["labels"]
        if isinstance(labels, torch.Tensor):
            labels = labels.to(torch.float32)
        elif len(labels) and isinstance(labels[0], torch.Tensor):
            # default_collate of per-sample lists yields D tensors of shape (B,)
            labels = torch.stack(list(labels), dim=1).to(torch.float32)
        else:
            labels = torch.tensor(labels, dtype=torch.float32)
    else:
        raise TypeError(f"Unexpected batch type: {type(batch)}")

    # Load only the first MAX_SECONDS of each clip: fixed length, peak-normalised.
    wavs = [load_first_n_seconds(p, TARGET_SR, MAX_SECONDS).numpy() for p in paths]

    # All clips have the same length, so there is no dynamic padding. Passing numpy arrays
    # makes the list recognisable as a batch to the feature extractor.
    # NOTE(review): confirm this still holds if transformers is upgraded.
    feat = fe(wavs, sampling_rate=TARGET_SR, return_tensors="pt", padding="do_not_pad")
    inputs = {k: v.to(device) for k, v in feat.items()}
    labels = labels.to(device)
    return inputs, labels


def encode(inputs):
    """Run the frozen encoder on ``inputs["input_values"]`` with gradients disabled.

    Returns the encoder's ``last_hidden_state`` (B, T', d_model).
    """
    with torch.no_grad():
        return encoder(input_values=inputs["input_values"]).last_hidden_state

def run_epoch(loader, train=True):
    """Make one pass over ``loader``; return the sample-weighted mean MSE.

    With ``train=True`` the head is updated through ``opt``/``scaler`` (mixed precision
    on CUDA). With ``train=False`` the head is put in eval mode and gradient tracking
    is disabled here, so callers no longer need their own ``torch.no_grad()``.
    ``encode`` always runs the frozen encoder without gradients.
    """
    head.train(train)  # train(False) is equivalent to eval()
    use_amp = device.type == "cuda"
    total_loss = 0.0
    n_items = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            inputs, labels = batch_to_inputs(batch)
            hs = encode(inputs)
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                preds = head(hs)
                loss = mse(preds, labels)
            if train:
                opt.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            total_loss += loss.item() * labels.size(0)
            n_items += labels.size(0)
            # Drop references right away and return cached GPU blocks
            # (trades some speed for a lower reserved-memory footprint).
            del inputs, labels, hs, preds, loss
            torch.cuda.empty_cache()
    return total_loss / max(1, n_items)

# Train for EPOCHS passes, reporting train (and, if available, validation) MSE.
for epoch in range(1, EPOCHS + 1):
    tr_loss = run_epoch(train_loader, train=True)
    report = f"Epoch {epoch}/{EPOCHS} | train MSE={tr_loss:.4f}"
    if val_loader:
        with torch.no_grad():
            va_loss = run_epoch(val_loader, train=False)
        report += f" | val MSE={va_loss:.4f}"
    print(report)


  info = torchaudio.info(path)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  return AudioMetaData(
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):


Epoch 1/20 | train MSE=0.1809 | val MSE=0.3153
Epoch 2/20 | train MSE=0.1600 | val MSE=0.2728
Epoch 3/20 | train MSE=0.1587 | val MSE=0.3201
Epoch 4/20 | train MSE=0.1617 | val MSE=0.2984
Epoch 5/20 | train MSE=0.1124 | val MSE=0.3070
Epoch 6/20 | train MSE=0.0964 | val MSE=0.3073
Epoch 7/20 | train MSE=0.0813 | val MSE=0.2352
Epoch 8/20 | train MSE=0.0663 | val MSE=0.3383
Epoch 9/20 | train MSE=0.0438 | val MSE=0.3369
Epoch 10/20 | train MSE=0.0482 | val MSE=0.3097
Epoch 11/20 | train MSE=0.0303 | val MSE=0.3211
Epoch 12/20 | train MSE=0.0258 | val MSE=0.2734
Epoch 13/20 | train MSE=0.0199 | val MSE=0.3329
Epoch 14/20 | train MSE=0.0224 | val MSE=0.2699
Epoch 15/20 | train MSE=0.0168 | val MSE=0.2761
Epoch 16/20 | train MSE=0.0213 | val MSE=0.3073
Epoch 17/20 | train MSE=0.0162 | val MSE=0.2499
Epoch 18/20 | train MSE=0.0107 | val MSE=0.3072
Epoch 19/20 | train MSE=0.0112 | val MSE=0.2754
Epoch 20/20 | train MSE=0.0089 | val MSE=0.2888


In [23]:
def ccc(y_true, y_pred):
    """Concordance correlation coefficient between two 1-D sequences.

    CCC = 2*cov(pred, true) / (var(pred) + var(true) + (mean(pred) - mean(true))**2).
    Population (ddof=0) moments are computed in float64. Returns 0.0 when the
    denominator is not positive, e.g. both inputs constant and equal.
    """
    truth = np.asarray(y_true, np.float64)
    pred = np.asarray(y_pred, np.float64)
    mean_p, mean_t = pred.mean(), truth.mean()
    var_p, var_t = pred.var(), truth.var()
    covariance = ((pred - mean_p) * (truth - mean_t)).mean()
    denominator = var_p + var_t + (mean_p - mean_t) ** 2
    if not denominator > 0:
        return 0.0
    return float(2 * covariance / denominator)

def evaluate_full(loader):
    """Score the frozen encoder + trained head on every batch of ``loader``.

    The head runs in eval mode, without gradients and without autocast. The result
    is a dict of floats with keys in this order: ``MAE_macro``, ``MSE_macro``, then
    ``MAE_<key>``, ``MSE_<key>`` and ``CCC_<key>`` for each key in ``TARGET_KEYS``.
    Macro values are the unweighted mean over targets.
    """
    truth_batches, pred_batches = [], []
    head.eval()
    with torch.no_grad():
        for batch in loader:
            inputs, labels = batch_to_inputs(batch)
            preds = head(encode(inputs))
            truth_batches.append(labels.detach().cpu().numpy())
            pred_batches.append(preds.detach().cpu().numpy())
            # release references promptly and return cached GPU blocks
            del inputs, labels, preds
            torch.cuda.empty_cache()

    # one row per sample, one column per target
    Y = np.concatenate(truth_batches, axis=0)
    P = np.concatenate(pred_batches, axis=0)
    # per-target errors; named so they don't shadow the global `mse` loss
    mae_per_target = mean_absolute_error(Y, P, multioutput="raw_values")
    mse_per_target = mean_squared_error(Y, P, multioutput="raw_values")

    metrics = {
        "MAE_macro": float(mae_per_target.mean()),
        "MSE_macro": float(mse_per_target.mean()),
    }
    for idx, key in enumerate(TARGET_KEYS):
        metrics[f"MAE_{key}"] = float(mae_per_target[idx])
        metrics[f"MSE_{key}"] = float(mse_per_target[idx])
        metrics[f"CCC_{key}"] = ccc(Y[:, idx], P[:, idx])
    return metrics

# Print full validation metrics when a non-empty validation split exists.
if not (val_loader and len(val_ds) > 0):
    print("No validation split; skipped metrics.")
else:
    metrics = evaluate_full(val_loader)
    print("Validation metrics:")
    for k, v in metrics.items():
        print(f"  {k}: {v:.4f}")


  info = torchaudio.info(path)


Validation metrics:
  MAE_macro: 0.4424
  MSE_macro: 0.2889
  MAE_Valence_best: 0.7148
  MSE_Valence_best: 0.5886
  CCC_Valence_best: 0.0448
  MAE_Arousal_best: 0.4734
  MSE_Arousal_best: 0.2439
  CCC_Arousal_best: 0.3702
  MAE_Submissive_vs._Dominant_best: 0.1390
  MSE_Submissive_vs._Dominant_best: 0.0341
  CCC_Submissive_vs._Dominant_best: 0.6429


In [None]:
import joblib, json
SAVE_DIR = "/content/w2v2_temporal_head"
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# Trained head weights only; the frozen encoder is re-created from MODEL_NAME.
torch.save(head.state_dict(), f"{SAVE_DIR}/temporal_head.pt")

# Config needed to rebuild the pipeline later. The head's hidden size is read from
# the live module so it cannot drift from the constructor arguments used above.
export_cfg = {
    "model_name": MODEL_NAME,
    "target_sr": TARGET_SR,
    "max_seconds": MAX_SECONDS,
    "target_keys": TARGET_KEYS,
    "head": {
        "d_model": int(encoder.config.hidden_size),
        "hidden": int(head.gru.hidden_size),
        "out_dim": len(TARGET_KEYS),
    },
}
# `with` guarantees the file is flushed and closed (the old bare open() leaked the handle).
with open(f"{SAVE_DIR}/config.json", "w", encoding="utf-8") as cfg_file:
    json.dump(export_cfg, cfg_file, indent=2)

print("Saved to", SAVE_DIR)
