In [1]:
import os
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [None]:
@dataclass
class CFG:
    """Run configuration for the folder-of-audio classifier.

    Fields keep their original order so ``CFG(...)`` positional calls still work.
    """

    # Root folder whose subdirectories are the class labels.
    data_root: str = "."
    # Rate (Hz) every clip is resampled to.
    sample_rate: int = 16000
    # Length every clip is padded / trimmed to (fixed-length clips).
    clip_seconds: float = 4.0
    # Samples per DataLoader batch.
    batch_size: int = 16
    # DataLoader worker processes.
    num_workers: int = 4
    # Optimizer learning rate.
    lr: float = 1e-3
    # Number of train + validation passes.
    epochs: int = 5
    # Chosen once, when this class body is executed.
    device: str = ("cpu", "cuda")[bool(torch.cuda.is_available())]

cfg = CFG()

In [4]:
# Show which device (taken from cfg) the training run will use.
cfg.device

'cuda'

In [5]:
# Lower-case file extensions (leading dot included) that are indexed as audio.
AUDIO_EXTS = set(".wav .flac .mp3 .ogg .m4a".split())

def list_class_folders(data_root: str) -> List[str]:
    """Return the sorted names of the visible subdirectories of ``data_root``.

    Each subdirectory is one class; sorting gives a stable class -> index mapping.
    Hidden directories (names starting with ``.``) are skipped. With the default
    ``data_root="."`` inside Jupyter, ``.ipynb_checkpoints`` (or ``.git``) would
    otherwise be picked up as a bogus class.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if ``data_root`` cannot be listed.
    """
    with os.scandir(data_root) as entries:
        # is_dir() follows symlinks, matching the previous os.path.isdir check.
        return sorted(
            entry.name
            for entry in entries
            if entry.is_dir() and not entry.name.startswith(".")
        )

def index_audio_files(
    data_root: str,
    classes: List[str],
    exts: Optional[Iterable[str]] = None,
) -> List[Tuple[str, int]]:
    """Collect ``(file_path, class_index)`` pairs for every audio file under each class.

    Args:
        data_root: directory containing one subfolder per class.
        classes: class folder names; a file's label is its class's index in this list.
        exts: accepted file extensions (leading dot, case-insensitive). Defaults to
            the module-level ``AUDIO_EXTS``.

    Returns:
        A list of ``(path, label)`` tuples in a deterministic order: classes in the
        given order, directories and files visited in sorted order. ``os.walk``
        order is otherwise filesystem-dependent, which would make even a seeded
        ``random_split`` produce different splits on different machines.

    Hidden entries (leading ``.``) are skipped, e.g. macOS ``._clip.wav``
    resource-fork files that carry an audio extension but are not audio.
    """
    wanted = {e.lower() for e in (AUDIO_EXTS if exts is None else exts)}
    class_to_idx = {c: i for i, c in enumerate(classes)}
    items: List[Tuple[str, int]] = []
    for c in classes:
        class_dir = os.path.join(data_root, c)
        for root, dirs, files in os.walk(class_dir):
            # Top-down walk: editing dirs in place prunes and orders the descent.
            dirs[:] = sorted(d for d in dirs if not d.startswith("."))
            for fn in sorted(files):
                if fn.startswith("."):
                    continue
                if os.path.splitext(fn)[1].lower() in wanted:
                    items.append((os.path.join(root, fn), class_to_idx[c]))
    return items

class FolderAudioDataset(Dataset):
    """Audio classification dataset laid out as ``data_root/<class_name>/.../<file>``.

    Classes come from ``list_class_folders(data_root)``; each item's label is the
    index of its class in ``self.classes``. ``__getitem__`` loads one file,
    downmixes it to mono, resamples it to ``sample_rate`` and pads/trims it to
    exactly ``int(sample_rate * clip_seconds)`` samples.

    Raises:
        ValueError: if no class subfolders or no audio files are found.
    """

    def __init__(self, data_root: str, sample_rate: int, clip_seconds: float):
        self.data_root = data_root
        self.sample_rate = sample_rate
        # Target clip length in samples (the float product is truncated).
        self.num_samples = int(sample_rate * clip_seconds)

        self.classes = list_class_folders(data_root)
        if len(self.classes) == 0:
            raise ValueError(f"No class subfolders found in: {data_root}")

        self.items = index_audio_files(data_root, self.classes)
        if len(self.items) == 0:
            raise ValueError(f"No audio files found under: {data_root}")

        # One Resample transform per source rate, created lazily on first use.
        self.resamplers: Dict[int, torchaudio.transforms.Resample] = dict()

    def __len__(self):
        count = len(self.items)
        return count

    def _resample_if_needed(self, wav: torch.Tensor, orig_sr: int) -> torch.Tensor:
        # Same rate: hand back the very same tensor untouched.
        if orig_sr != self.sample_rate:
            resampler = self.resamplers.get(orig_sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
                self.resamplers[orig_sr] = resampler
            wav = resampler(wav)
        return wav

    def _to_mono(self, wav: torch.Tensor) -> torch.Tensor:
        # [channels, time] -> [1, time]; single-channel input is returned as-is.
        return wav if wav.size(0) == 1 else wav.mean(dim=0, keepdim=True)

    def _pad_or_trim(self, wav: torch.Tensor) -> torch.Tensor:
        # [1, time] -> [1, num_samples]: right-pad with zeros or cut the tail.
        length = wav.size(1)
        target = self.num_samples
        if length > target:
            return wav[:, :target]
        if length < target:
            return F.pad(wav, (0, target - length))
        return wav

    def __getitem__(self, idx: int):
        path, label = self.items[idx]
        waveform, orig_sr = torchaudio.load(path)   # waveform: [C, T]
        waveform = self._pad_or_trim(
            self._resample_if_needed(self._to_mono(waveform), orig_sr)
        )
        return waveform, torch.tensor(label, dtype=torch.long)

In [None]:
def build_model(num_classes: int, freeze_backbone: bool = True):
    """Build a classifier from the pretrained torchaudio WAV2VEC2_BASE backbone.

    The head is mean-pooling over frames of the last transformer layer's output,
    followed by LayerNorm and a Linear layer.

    Args:
        num_classes: number of output logits.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``
            and the backbone is kept in eval mode even while the wrapper model is
            in training mode. Dropout inside a frozen backbone would otherwise
            perturb the features during training only (a train/eval mismatch for
            the head), with no benefit since the backbone is not being tuned.

    Returns:
        An ``nn.Module`` mapping waveforms ``[B, 1, T]`` (or ``[B, T]``) to
        logits ``[B, num_classes]``.
    """
    bundle = torchaudio.pipelines.WAV2VEC2_BASE
    backbone = bundle.get_model()  # wav2vec2 feature encoder + transformer

    if freeze_backbone:
        for p in backbone.parameters():
            p.requires_grad = False

    # Infer the hidden size from a short dummy forward instead of relying on
    # private bundle attributes.
    backbone.eval()
    with torch.no_grad():
        dummy = torch.zeros(1, int(bundle.sample_rate))  # one second of silence, [B, T]
        feats, _ = backbone.extract_features(dummy)  # list of per-layer outputs
        hidden_dim = feats[-1].size(-1)  # last layer: [B, frames, hidden]

    # Only a trainable backbone should return to training mode.
    backbone.train(not freeze_backbone)

    class Model(nn.Module):
        """wav2vec2 backbone -> temporal mean-pool -> LayerNorm -> Linear."""

        def __init__(self, backbone, hidden_dim, num_classes, freeze_backbone=False):
            super().__init__()
            self.backbone = backbone
            self.freeze_backbone = freeze_backbone
            self.classifier = nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, num_classes),
            )

        def train(self, mode: bool = True):
            # model.train() would otherwise re-enable dropout in a frozen backbone.
            super().train(mode)
            if self.freeze_backbone:
                self.backbone.eval()
            return self

        def forward(self, wav: torch.Tensor):
            # Dataset yields [B, 1, T]; wav2vec2 expects [B, T].
            if wav.dim() == 3:
                wav = wav.squeeze(1)

            feats, _ = self.backbone.extract_features(wav)
            x = feats[-1]              # [B, frames, hidden]
            x = x.mean(dim=1)          # simple temporal pooling
            return self.classifier(x)

    return Model(backbone, hidden_dim, num_classes, freeze_backbone)

In [7]:
def train_epoch(model, loader, opt, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module mapping a batch of inputs to logits ``[B, num_classes]``.
        loader: iterable of ``(inputs, targets)`` batches; targets are class indices.
        opt: optimizer over the trainable parameters of ``model``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over samples. The loss is computed
        before each optimizer step. Both values are ``nan`` when the loader yields
        no samples, instead of raising ``ZeroDivisionError``.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for wav, y in loader:
        # non_blocking lets host->device copies overlap compute with pinned memory.
        wav = wav.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        logits = model(wav)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()

        batch = y.size(0)
        total_loss += loss.item() * batch
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total


@torch.no_grad()
def eval_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without updating parameters.

    Args:
        model: module mapping a batch of inputs to logits ``[B, num_classes]``.
        loader: iterable of ``(inputs, targets)`` batches; targets are class indices.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over samples. Both values are ``nan``
        when the loader yields no samples, instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    for wav, y in loader:
        # non_blocking lets host->device copies overlap compute with pinned memory.
        wav = wav.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(wav)
        loss = F.cross_entropy(logits, y)

        batch = y.size(0)
        total_loss += loss.item() * batch
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total

In [None]:
def main(split_seed: int = 42, train_fraction: float = 0.8):
    """Train a classifier head on a frozen wav2vec2 backbone and print per-epoch metrics.

    Args:
        split_seed: seed for the train/val ``random_split`` so the split is reproducible.
        train_fraction: fraction of files used for training; the rest is validation.

    Raises:
        FileNotFoundError: if ``cfg.data_root`` is not an existing directory.
        ValueError: if the dataset is too small for both splits to be non-empty.
    """
    if not os.path.isdir(cfg.data_root):
        raise FileNotFoundError(
            f"cfg.data_root={cfg.data_root!r} is not a directory; point it at a "
            "folder containing one subfolder per class"
        )
    ds = FolderAudioDataset(cfg.data_root, cfg.sample_rate, cfg.clip_seconds)

    # TODO: try stratified split
    n = len(ds)
    n_train = int(train_fraction * n)
    n_val = n - n_train
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Cannot split {n} files into non-empty train/val sets "
            f"with train_fraction={train_fraction}"
        )
    # Seeded generator -> the same files land in train/val on every run.
    split_gen = torch.Generator().manual_seed(split_seed)
    train_ds, val_ds = torch.utils.data.random_split(ds, [n_train, n_val], generator=split_gen)

    # Pinned host memory only helps host->GPU copies (and warns on CPU-only runs).
    pin = str(cfg.device).startswith("cuda")
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers, pin_memory=pin)

    model = build_model(num_classes=len(ds.classes), freeze_backbone=True).to(cfg.device)

    # Only the trainable (non-frozen) parameters are optimised.
    opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr)

    print("Classes:", ds.classes)
    for epoch in range(cfg.epochs):
        tr_loss, tr_acc = train_epoch(model, train_loader, opt, cfg.device)
        va_loss, va_acc = eval_epoch(model, val_loader, cfg.device)
        print(f"epoch {epoch+1:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} "
              f"| val loss {va_loss:.4f} acc {va_acc:.3f}")

if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: 'data_root'