In [None]:
# Cell 2 — Imports & Device
import os, glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import timm
from peft import LoraConfig, get_peft_model
from panns_inference import AudioTagging
import librosa
from tqdm.notebook import tqdm
import soundfile as sf
from sklearn.preprocessing import StandardScaler

# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)


2025-04-30 04:39:01.483219: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745987941.502148   91683 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745987941.508027   91683 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1745987941.522846   91683 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745987941.522867   91683 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745987941.522870   91683 computation_placer.cc:177] computation placer alr

Using device: cuda


In [None]:
# Cell 3 — Paths, Hyperparams & Label Map
BEST_PANNS    = './best_panns_mlp_checkpoint.pth'
BEST_RESNET   = './best_resnet50.pth'
BEST_EFF3     = './best_effnetb3_lora.pth'

# Root of the pre-extracted features. Set the FEATURES_DIR environment variable
# to point elsewhere instead of editing this cell; every path below derives from it.
BASE_FEAT      = os.environ.get('FEATURES_DIR', '/home/jovyan/Features')
DEN_DIR        = os.path.join(BASE_FEAT, 'denoised')
MEL_DIR        = os.path.join(BASE_FEAT, 'mel')
EMB_DIR        = os.path.join(BASE_FEAT, 'embeddings')
TRAIN_MANIFEST = os.path.join(BASE_FEAT, 'manifest_train.csv')
TEST_MANIFEST  = os.path.join(BASE_FEAT, 'manifest_test.csv')
TAXONOMY_CSV   = os.path.join(BASE_FEAT, 'taxonomy.csv')

SR            = 32000                 # sample rate (samples per second)
CHUNK_SEC     = 10                    # chunk length in seconds
CHUNK_SAMPLES = SR * CHUNK_SEC

BATCH_SIZE    = 128
NUM_EPOCHS    = 1
LR            = 1e-3
ALPHA         = 0.5                   # default weight of the embedding term in create_augmented_mel
SEED          = 42

# Seed the RNGs used by this notebook (DataLoader shuffling, dropout, weight init)
# so a Restart & Run All reproduces the same numbers.
np.random.seed(SEED)
torch.manual_seed(SEED)

# build label mapping; sorted so class indices are stable from run to run
tax_df      = pd.read_csv(TAXONOMY_CSV)
labels_all  = sorted(tax_df['primary_label'].unique())
label2idx   = {lab: i for i, lab in enumerate(labels_all)}
NUM_CLASSES = len(labels_all)

print("Num classes:", NUM_CLASSES)



Num classes: 206


In [None]:
def create_augmented_mel(log_mel, emb, alpha=None):
    """Add a projection of a clip-level embedding to every frame of a log-mel.

    The embedding is mapped onto the mel axis and z-scored across mel bins. It is
    then broadcast over time and added to ``log_mel`` with weight ``alpha``. The
    mapping is:

    * identity if the sizes match;
    * mean-pooled over equal runs of consecutive dims if ``embed_dim`` is an
      exact multiple of ``n_mels``;
    * truncated if it is larger but not a multiple;
    * zero-padded if it is smaller.

    Args:
        log_mel: np.ndarray of shape (n_mels, T).
        emb:     np.ndarray of shape (embed_dim,).
        alpha:   weight of the embedding term. ``None`` means "use the notebook's
                 ``ALPHA``", read at call time so re-running the config cell
                 takes effect without redefining this function.

    Returns:
        np.ndarray of shape (n_mels, T): ``log_mel + alpha * zscore(proj)``.
        The z-scored term is all zeros when the projection is constant.
    """
    if alpha is None:
        alpha = ALPHA

    n_mels, _ = log_mel.shape
    embed_dim = emb.shape[0]

    if embed_dim == n_mels:
        proj = emb
    elif embed_dim > n_mels and embed_dim % n_mels == 0:
        # average each run of `embed_dim // n_mels` consecutive embedding dims
        proj = emb.reshape(n_mels, embed_dim // n_mels).mean(axis=1)
    elif embed_dim > n_mels:
        proj = emb[:n_mels]
    else:
        proj = np.pad(emb, (0, n_mels - embed_dim))

    # Standardise across mel bins. Standardising each bin over *time* (the former
    # StandardScaler on the tiled matrix) is useless: every bin is constant over
    # time, so it produced (essentially) zeros and the augmentation did nothing.
    proj = proj.astype(np.float64)
    std = proj.std()
    normed = (proj - proj.mean()) / std if std > 0 else np.zeros_like(proj)

    # broadcasting (n_mels, 1) over (n_mels, T) replaces the explicit tile
    return log_mel + alpha * normed[:, None]

In [None]:
# Cell 4 — Load & Freeze Base Models

### 4.1) PANNs embedding extractor
# panns_inference's AudioTagging chooses its device by comparing `device == 'cuda'`,
# which a torch.device object does not satisfy, so it silently fell back to CPU.
# Pass the plain device-type string ('cuda' / 'cpu') instead.
panns_extractor = AudioTagging(checkpoint_path=None, device=device.type)
panns_extractor.model.eval()
### 4.2) PANNs-MLP classifier
class MLPClassifier(nn.Module):
    """Plain feed-forward classifier.

    The network is one ``Linear -> ReLU -> Dropout`` stage per entry of
    ``hidden_dims``, followed by a final ``Linear`` projection to
    ``num_classes`` logits. All layers live in ``self.net`` (an
    ``nn.Sequential``), so state_dict keys are ``net.<i>.weight`` /
    ``net.<i>.bias``.

    Args:
        input_dim:   size of the last dimension of the input.
        hidden_dims: list of hidden-layer widths (may be empty).
        num_classes: number of output logits.
        dropout:     dropout probability after every hidden ReLU.
    """

    def __init__(self, input_dim, hidden_dims, num_classes, dropout=0.5):
        super().__init__()
        widths = [input_dim] + hidden_dims
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ))
        stages.append(nn.Linear(widths[-1], num_classes))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return raw class logits for ``x``."""
        return self.net(x)

# infer embedding dimension from one silent chunk; AudioTagging.inference takes a
# (batch, samples) float array and moves it to the extractor's device itself
dummy_wav = np.zeros((1, CHUNK_SAMPLES), dtype=np.float32)
_, emb = panns_extractor.inference(dummy_wav)
emb_dim = emb.shape[-1]


def _load_frozen(model, ckpt_path, name):
    """Load `ckpt_path` into `model` if present, then switch to eval mode and freeze.

    The checkpoint may be a bare state_dict or a dict wrapping it under
    'model_state_dict'. Returns `model` so construction and loading can be chained.
    """
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt.get('model_state_dict', ckpt))
        print(f"✅ Loaded {name} checkpoint")
    else:
        # Without a checkpoint the model keeps its untrained (or only
        # pretrained) weights, yet is used below as if it were trained.
        print(f"⚠️ {name} checkpoint not found at {ckpt_path}; using initial weights")
    model.eval()
    model.requires_grad_(False)
    return model


panns_mlp = _load_frozen(
    MLPClassifier(emb_dim, [1024, 512], NUM_CLASSES, dropout=0.5).to(device),
    BEST_PANNS, "PANNs-MLP",
)

### 4.3) ResNet-50
resnet50 = _load_frozen(
    models.resnet50(weights=None, num_classes=NUM_CLASSES).to(device),
    BEST_RESNET, "ResNet-50",
)

### 4.4) EfficientNet-B3 + LoRA
base_eff3 = timm.create_model(
    'tf_efficientnet_b3_ns',
    pretrained=True,
    in_chans=3,
    num_classes=NUM_CLASSES
)
lora_cfg = LoraConfig(
    r=12,
    lora_alpha=24,
    target_modules=["conv_pw","conv_dw","conv_pwl","conv_head"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"]
)
effnet_b3 = _load_frozen(
    get_peft_model(base_eff3, lora_cfg).to(device),
    BEST_EFF3, "EfficientNet-B3+LoRA",
)


Checkpoint path: /home/jovyan/panns_data/Cnn14_mAP=0.431.pth


  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Using CPU.
✅ Loaded PANNs-MLP checkpoint


  ckpt = torch.load(BEST_PANNS, map_location=device)
  ckpt = torch.load(BEST_RESNET, map_location=device)
  model = create_fn(


✅ Loaded ResNet-50 checkpoint
✅ Loaded EfficientNet-B3+LoRA checkpoint


  ckpt = torch.load(BEST_EFF3, map_location=device)


In [None]:
class MultiModalDataset(Dataset):
    """Manifest-driven dataset returning every modality of one audio item.

    Each item is ``(wav, emb, mel, aug, label)``:
      * wav   -- mono float32 waveform tensor read from DEN_DIR, resampled to SR
                 when the file reports a different rate;
      * emb   -- float32 tensor from the 'embedding' array of an .npz in EMB_DIR;
      * mel   -- float32 tensor from the 'mel' array of an .npz in MEL_DIR;
      * aug   -- ``create_augmented_mel(mel, emb)`` as a float32 tensor;
      * label -- the raw ``primary_label`` value from the manifest (not a class index).

    Args:
        manifest_fp: CSV with columns ``audio_path``, ``emb_path``, ``mel_path``
            and ``primary_label``. Leading '/' is stripped from the paths, which
            are then joined onto the directories above.
    """

    def __init__(self, manifest_fp):
        self.df = pd.read_csv(manifest_fp)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _load_npz_array(root, rel_path, key):
        """Return array `key` of the .npz at root/rel_path as float32, closing the file."""
        with np.load(os.path.join(root, rel_path.lstrip('/'))) as npz:
            return npz[key].astype(np.float32)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # — Denoised waveform (.ogg) —
        audio_fp = os.path.join(DEN_DIR, row['audio_path'].lstrip('/'))
        wav, sr = sf.read(audio_fp, dtype='float32')
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel files; downmix
            # so every item is 1-D and can be torch.stack-ed by the collate.
            wav = wav.mean(axis=1)
        if sr != SR:
            # bring files recorded at another rate onto the notebook's SR
            wav = librosa.resample(wav, orig_sr=sr, target_sr=SR)
        wav_t = torch.from_numpy(np.ascontiguousarray(wav, dtype=np.float32))

        # — PANNs embedding (.npz) —
        emb_arr = self._load_npz_array(EMB_DIR, row['emb_path'], 'embedding')
        emb_t   = torch.from_numpy(emb_arr)

        # — Mel spectrogram (.npz) —
        mel_arr = self._load_npz_array(MEL_DIR, row['mel_path'], 'mel')
        mel_t   = torch.from_numpy(mel_arr)

        # — Augmented Mel —
        aug_arr = create_augmented_mel(mel_arr, emb_arr)
        aug_t   = torch.from_numpy(aug_arr.astype(np.float32))

        # — Label — (raw manifest value)
        lbl = row['primary_label']

        return wav_t, emb_t, mel_t, aug_t, lbl

def _collate_multimodal(batch):
    """Collate a list of (wav, emb, mel, aug, label) items into one batch.

    Returns a 5-tuple: wav [B, samples], emb [B, emb_dim], mel [B, n_mels, T],
    aug [B, n_mels, T] as stacked tensors, and the labels as a plain list.
    """
    wavs, embs, mels, augs, labels = zip(*batch)
    return (
        torch.stack(wavs),
        torch.stack(embs),
        torch.stack(mels),
        torch.stack(augs),
        list(labels),
    )


# Instantiate datasets & loaders
train_ds = MultiModalDataset(TRAIN_MANIFEST)
test_ds  = MultiModalDataset(TEST_MANIFEST)

# options shared by both loaders; only shuffling differs
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    collate_fn=_collate_multimodal,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
test_loader  = DataLoader(test_ds, shuffle=False, **_loader_opts)

print(f"Train data: {len(train_ds)}, Test data: {len(test_ds)}")

Train chunks: 8988, Val chunks: 2248


In [None]:
# Cell 6 — Meta-MLP Supervisor & Training Loop

class MetaMLP(MLPClassifier):
    """Meta-level ("supervisor") MLP trained on top of the frozen base models.

    Architecturally identical to MLPClassifier: same ``net`` layout, so
    state_dict keys like 'net.0.weight' are unchanged and existing
    ``best_supervisor.pth`` files still load. It subclasses MLPClassifier rather
    than duplicating the layer-building code.

    Args:
        in_dim:      input feature size.
        hidden_dims: list of hidden-layer widths.
        num_classes: number of output logits.
        dropout:     dropout probability after every hidden ReLU.
    """

    def __init__(self, in_dim, hidden_dims, num_classes, dropout=0.5):
        super().__init__(in_dim, hidden_dims, num_classes, dropout=dropout)

meta_in_dim = 3 * NUM_CLASSES   # one class-probability vector per base model
meta_model  = MetaMLP(meta_in_dim, [512, 256], NUM_CLASSES, dropout=0.5).to(device)
print("Meta-model trainable params:", 
      sum(p.numel() for p in meta_model.parameters() if p.requires_grad))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(meta_model.parameters(), lr=LR)


def _to_3ch(spec):
    """Repeat a [B, n_mels, T] spectrogram into a 3-channel image [B, 3, n_mels, T]."""
    return spec.unsqueeze(1).repeat(1, 3, 1, 1)


@torch.no_grad()
def _meta_batch(batch):
    """Turn one loader batch into (meta_features, targets) on `device`.

    The loaders yield (wav, emb, mel, aug, labels) with string labels. The
    meta-model instead needs the frozen base models' predictions concatenated
    ([B, 3 * NUM_CLASSES]) and integer class indices ([B]).
    Raises KeyError for a label missing from the taxonomy.
    """
    _wav, emb, mel, aug, labels = batch
    emb = emb.to(device, non_blocking=True).flatten(1)
    mel = mel.to(device, non_blocking=True)
    aug = aug.to(device, non_blocking=True)
    # NOTE(review): which input feeds which base model (and the channel repeat
    # for the CNNs) is assumed here; it must match how each model was trained.
    base_logits = (
        panns_mlp(emb),
        resnet50(_to_3ch(mel)),
        effnet_b3(_to_3ch(aug)),
    )
    # softmax puts the three models on a common scale before stacking
    feats = torch.cat([lg.softmax(dim=1) for lg in base_logits], dim=1)
    targets = torch.tensor([label2idx[lbl] for lbl in labels],
                           dtype=torch.long, device=device)
    return feats, targets


best_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    # — Train —
    meta_model.train()
    t_corr, t_tot = 0, 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        x, y = _meta_batch(batch)
        optimizer.zero_grad()
        out  = meta_model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        t_corr += (out.argmax(1) == y).sum().item()
        t_tot  += y.size(0)
    train_acc = t_corr / max(t_tot, 1)

    # — Validate — on test_loader (no `val_loader` is defined in this notebook).
    # NOTE(review): choosing the best checkpoint on the test split biases the
    # final test score; consider carving a validation split out of train.
    meta_model.eval()
    v_corr, v_tot = 0, 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Epoch {epoch} Val"):
            x, y   = _meta_batch(batch)
            out    = meta_model(x)
            v_corr += (out.argmax(1) == y).sum().item()
            v_tot  += y.size(0)
    val_acc = v_corr / max(v_tot, 1)

    print(f"\nEpoch {epoch}: train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(meta_model.state_dict(), "best_supervisor.pth")
        print("✅ New best saved")

print(f"\n🏁 Done. Best supervisor val_acc: {best_acc:.4f}")


Meta-model trainable params: 501198


Epoch 1 Train:   0%|          | 0/71 [00:00<?, ?it/s]