In [1]:
# Cell 2 — Imports & Device
import os, glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import timm
from peft import LoraConfig, get_peft_model
from panns_inference import AudioTagging
import librosa
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


2025-04-30 04:39:01.483219: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745987941.502148   91683 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745987941.508027   91683 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1745987941.522846   91683 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745987941.522867   91683 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745987941.522870   91683 computation_placer.cc:177] computation placer alr

Using device: cuda


In [19]:
# Cell 3 — Paths, Hyperparams & Label Map
TRAIN_CSV     = '/home/jovyan/Data/birdclef-2025/train.csv'
DEN_DIR       = '/home/jovyan/Features/denoised'
BEST_PANNS    = './best_panns_mlp_checkpoint.pt'
BEST_RESNET   = './best_resnet50.pth'
BEST_EFF3     = './best_effnetb3_lora.pth'

SR            = 32000
CHUNK_SEC     = 10
CHUNK_SAMPLES = SR * CHUNK_SEC

BATCH_SIZE    = 128
NUM_EPOCHS    = 5
LR            = 1e-3

# build label mapping
meta = pd.read_csv(TRAIN_CSV)
labels = sorted(meta['primary_label'].unique())
label2idx = {lab:i for i,lab in enumerate(labels)}
meta['label_idx'] = meta['primary_label'].map(label2idx)
NUM_CLASSES = len(labels)
print("Num classes:", NUM_CLASSES)


Num classes: 206


In [20]:
# Cell 4 — Load & Freeze Base Models

### 4.1) PANNs embedding extractor
panns_extractor = AudioTagging(checkpoint_path=None, device=device)
panns_extractor.model.eval()

### 4.2) PANNs-MLP classifier
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, num_classes, dropout=0.5):
        super().__init__()
        layers = []
        dims = [input_dim] + hidden_dims
        for i in range(len(hidden_dims)):
            layers += [
                nn.Linear(dims[i], dims[i+1]),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ]
        layers.append(nn.Linear(dims[-1], num_classes))
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)

# infer embedding dimension
with torch.no_grad():
    dummy_wav = torch.zeros(1, CHUNK_SAMPLES, device=device)
    _, emb = panns_extractor.inference(dummy_wav)
    emb_dim = emb.shape[-1]

panns_mlp = MLPClassifier(emb_dim, [1024, 512], NUM_CLASSES, dropout=0.5).to(device)
if os.path.exists(BEST_PANNS):
    ckpt = torch.load(BEST_PANNS, map_location=device)
    sd   = ckpt.get('model_state_dict', ckpt)
    panns_mlp.load_state_dict(sd)
    print("✅ Loaded PANNs-MLP checkpoint")
panns_mlp.eval()
for p in panns_mlp.parameters(): p.requires_grad = False

### 4.3) ResNet-50
resnet50 = models.resnet50(weights=None, num_classes=NUM_CLASSES).to(device)
if os.path.exists(BEST_RESNET):
    ckpt = torch.load(BEST_RESNET, map_location=device)
    sd   = ckpt.get('model_state_dict', ckpt)
    resnet50.load_state_dict(sd)
    print("✅ Loaded ResNet-50 checkpoint")
resnet50.eval()
for p in resnet50.parameters(): p.requires_grad = False

### 4.4) EfficientNet-B3 + LoRA
base_eff3 = timm.create_model(
    'tf_efficientnet_b3_ns',
    pretrained=True,
    in_chans=1,
    num_classes=NUM_CLASSES
)
lora_cfg = LoraConfig(
    r=12,
    lora_alpha=24,
    target_modules=["conv_pw","conv_dw","conv_pwl","conv_head"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"]
)
effnet_b3 = get_peft_model(base_eff3, lora_cfg).to(device)
if os.path.exists(BEST_EFF3):
    ckpt = torch.load(BEST_EFF3, map_location=device)
    sd   = ckpt.get('model_state_dict', ckpt)
    effnet_b3.load_state_dict(sd)
    print("✅ Loaded EfficientNet-B3+LoRA checkpoint")
effnet_b3.eval()
for p in effnet_b3.parameters(): p.requires_grad = False


Checkpoint path: /home/jovyan/panns_data/Cnn14_mAP=0.431.pth


  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Using CPU.
✅ Loaded PANNs-MLP checkpoint


  ckpt = torch.load(BEST_PANNS, map_location=device)
  ckpt = torch.load(BEST_RESNET, map_location=device)
  model = create_fn(


✅ Loaded ResNet-50 checkpoint
✅ Loaded EfficientNet-B3+LoRA checkpoint


  ckpt = torch.load(BEST_EFF3, map_location=device)


In [21]:
# Cell 5 — Stacker Dataset emitting concatenated logits, all chunks (with separate image transforms)

class StackerDataset(Dataset):
    def __init__(self,
                 den_dir,
                 panns_extractor, panns_mlp,
                 resnet50, effnet_b3,
                 mel_transform=None,
                 resnet_transform=None,
                 effnet_transform=None):
        # 1) Collect all npz files
        self.files = sorted(glob.glob(f"{den_dir}/**/*.npz", recursive=True))
        # 2) Read labels once
        self.labels = [int(np.load(f)['label']) for f in self.files]
        self.panns_extractor = panns_extractor
        self.panns_mlp       = panns_mlp
        self.resnet50        = resnet50
        self.effnet_b3       = effnet_b3
        self.mel_transform   = mel_transform       # function(wave_np) -> mel_db array
        self.resnet_transform= resnet_transform    # transform for ResNet (3-ch)
        self.effnet_transform= effnet_transform    # transform for EffNet-B3 (1-ch)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        lbl  = self.labels[idx]
        data = np.load(path)
        wave = data['waveform'].astype(np.float32)  # [CHUNK_SAMPLES]

        # --- 1) PANNs logits ---
        wv = torch.from_numpy(wave).to(device)
        with torch.no_grad():
            _, emb = self.panns_extractor.inference(wv.unsqueeze(0))
            if isinstance(emb, np.ndarray):
                emb_t = torch.from_numpy(emb).to(device)
            else:
                emb_t = emb.to(device)
            logits_p = self.panns_mlp(emb_t.squeeze(0))

        # --- 2) MEL spectrogram → image tensor ---
        mel_db = self.mel_transform(wave)           # [N_MELS, T]
        img    = torch.from_numpy(mel_db.astype(np.float32)).unsqueeze(0)  # [1,H,W]

        # 2a) for ResNet: 3-channel
        img_r = self.resnet_transform(img)          # [3,224,224]
        img_r = img_r.unsqueeze(0).to(device)       # [1,3,224,224]

        # 2b) for EffNet-B3: 1-channel
        img_e = self.effnet_transform(img)          # [1,224,224]
        img_e = img_e.unsqueeze(0).to(device)       # [1,1,224,224]

        # --- 3) ResNet & EffNet logits ---
        with torch.no_grad():
            logits_r = self.resnet50(img_r).squeeze(0)
            logits_e = self.effnet_b3(img_e).squeeze(0)

        # --- 4) Concatenate & return ---
        stacked = torch.cat([logits_p, logits_r, logits_e], dim=-1)
        return stacked.cpu(), lbl

# --- Usage: re-instantiate DataLoaders ---

# Define your mel-spectrogram helper
def your_mel_transform(wave_np):
    S = librosa.feature.melspectrogram(y=wave_np, sr=SR,
                                       n_fft=2048, hop_length=512,
                                       n_mels=128)
    return librosa.power_to_db(S, ref=np.max)

# ResNet transform: repeat to 3 channels + ImageNet normalize
resnet_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.Lambda(lambda x: x.repeat(3,1,1)),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

# EffNet-B3 transform: keep 1 channel + custom normalize
effnet_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.Normalize([0.5], [0.5])
])

# Build the full dataset
full_ds = StackerDataset(
    den_dir         = DEN_DIR,
    panns_extractor = panns_extractor,
    panns_mlp       = panns_mlp,
    resnet50        = resnet50,
    effnet_b3       = effnet_b3,
    mel_transform   = your_mel_transform,
    resnet_transform= resnet_tf,
    effnet_transform= effnet_tf
)

# Split into train/val by indices
n       = len(full_ds)
n_train = int(0.8 * n)
train_ds, val_ds = random_split(full_ds, [n_train, n-n_train])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=0, pin_memory=True)

print(f"Train chunks: {len(train_ds)}, Val chunks: {len(val_ds)}")


Train chunks: 8988, Val chunks: 2248


In [None]:
# Cell 6 — Meta-MLP Supervisor & Training Loop

class MetaMLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, num_classes, dropout=0.5):
        super().__init__()
        layers = []
        dims   = [in_dim] + hidden_dims
        for i in range(len(hidden_dims)):
            layers += [
                nn.Linear(dims[i], dims[i+1]),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ]
        layers.append(nn.Linear(dims[-1], num_classes))
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)

meta_in_dim = 3 * NUM_CLASSES
meta_model  = MetaMLP(meta_in_dim, [512,256], NUM_CLASSES, dropout=0.5).to(device)
print("Meta-model trainable params:", 
      sum(p.numel() for p in meta_model.parameters() if p.requires_grad))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(meta_model.parameters(), lr=LR)

best_acc = 0.0
for epoch in range(1, NUM_EPOCHS+1):
    # — Train —
    meta_model.train()
    t_corr, t_tot = 0, 0
    for x,y in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out  = meta_model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        preds = out.argmax(1)
        t_corr += (preds==y).sum().item()
        t_tot  += x.size(0)
    train_acc = t_corr/t_tot

    # — Validate —
    meta_model.eval()
    v_corr, v_tot = 0, 0
    with torch.no_grad():
        for x,y in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            x, y   = x.to(device), y.to(device)
            out    = meta_model(x)
            v_corr += (out.argmax(1)==y).sum().item()
            v_tot  += x.size(0)
    val_acc = v_corr/v_tot

    print(f"\nEpoch {epoch}: train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(meta_model.state_dict(), "best_supervisor.pth")
        print("✅ New best saved")

print(f"\n🏁 Done. Best supervisor val_acc: {best_acc:.4f}")


Meta-model trainable params: 501198


Epoch 1 Train:   0%|          | 0/71 [00:00<?, ?it/s]