In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pandas as pd

In [None]:
# --- Compute device: use the GPU whenever CUDA reports one, otherwise the CPU ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Input locations ---
FEATURE_BASE = "/home/jovyan/Features"
TAXONOMY_CSV = "/home/jovyan/Data/birdclef-2025/taxonomy.csv"
TEST_MANIFEST = os.path.join(FEATURE_BASE, "manifest_test.csv")

# --- Checkpoint files, one per model, resolved relative to the working directory ---
CKPT_EMB = "best_emb_mlp.pt"       # embedding MLP
CKPT_RES = "best_resnet50.pt"      # ResNet-50
CKPT_EFF = "best_effb3_lora.pt"    # EfficientNet-B3
CKPT_RAW = "best_rawcnn.pt"        # raw-audio CNN
CKPT_META = "best_meta_mlp.pt"     # meta MLP

# --- Minimum probability for a class to be reported as a multi-label prediction ---
THRESHOLD = 0.5

In [None]:
# Build the label vocabulary from the taxonomy table.
tax = pd.read_csv(TAXONOMY_CSV)

# Labels are cast to str and sorted; a label's position in CLASSES is the class
# index used when model outputs are decoded further down.
CLASSES = list(tax["primary_label"].astype(str))
CLASSES.sort()
NUM_CLASSES = len(CLASSES)


In [None]:
class EmbeddingClassifier(nn.Module):
    """MLP that maps an embedding vector to one logit per class.

    Three hidden stages (2048, 1024 and 512 units), each made of
    Linear -> BatchNorm1d -> ReLU -> Dropout(0.5), followed by a final Linear
    layer with NUM_CLASSES outputs. No activation is applied to the output.

    Args:
        emb_dim: size of the input embedding vector.
    """

    def __init__(self, emb_dim):
        super().__init__()
        widths = (emb_dim, 2048, 1024, 512)

        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
        stages.append(nn.Linear(widths[-1], NUM_CLASSES))

        # A flat Sequential keeps the parameter names as net.<index>.*,
        # which is what a saved state dict for this class is keyed by.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw logits produced by the layer stack for `x`."""
        return self.net(x)

def get_resnet50_multilabel():
    """Build a torchvision ResNet-50 with a 1-channel stem and NUM_CLASSES outputs.

    The network is created without pretrained weights (`weights=None`). Its
    first convolution is replaced by a bias-free one that takes a single input
    channel but keeps the original output channels, kernel size, stride and
    padding. The final fully connected layer is replaced by a Linear layer
    with NUM_CLASSES outputs.

    Returns:
        The modified ResNet-50 module.
    """
    from torchvision.models import resnet50

    model = resnet50(weights=None)

    # Single-channel replacement for the stock stem convolution.
    stem = model.conv1
    model.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False,
    )

    # New classification head sized for this task.
    head_in = model.fc.in_features
    model.fc = nn.Linear(head_in, NUM_CLASSES)

    return model

def build_efficientnetb3_lora(pretrained=False):
    """Build the EfficientNet-B3 backbone used for inference on 1-channel input.

    The timm `efficientnet_b3` model is created, its stem convolution is
    replaced by a bias-free 1-channel convolution (same output channels,
    kernel size, stride and padding), and its classifier is replaced by a
    Linear layer with NUM_CLASSES outputs.

    Args:
        pretrained: forwarded to `timm.create_model`. Defaults to False because
            the caller overwrites every parameter with a checkpoint right after
            construction, so fetching the ImageNet weights would only add a
            network download (and fail in an offline environment). Pass True
            to start from the ImageNet weights instead.

    Returns:
        The plain (non-PEFT-wrapped) EfficientNet-B3 module.

    NOTE(review): no LoRA adapter is attached here, so the checkpoint loaded
    into this model is assumed to hold plain/merged weights whose keys match
    this module. If it was saved from a PEFT-wrapped model without merging,
    the state-dict keys will not match - confirm against the training code.
    """
    import timm

    base = timm.create_model("efficientnet_b3", pretrained=pretrained)

    # Wrap forward so that extra keyword arguments are silently dropped.
    # Presumably kept for parity with training, where a PEFT wrapper may pass
    # additional kwargs - TODO confirm; harmless for plain inference calls.
    orig = base.forward

    def fw(x, **kwargs):
        return orig(x)

    base.forward = fw

    # Adapt the input channel count and the classification head.
    stem = base.conv_stem
    base.conv_stem = nn.Conv2d(
        1, stem.out_channels, stem.kernel_size, stem.stride, stem.padding, bias=False
    )
    base.classifier = nn.Linear(base.classifier.in_features, NUM_CLASSES)
    return base

class RawAudioCNN(nn.Module):
    """1-D CNN that maps a raw waveform batch to one logit per class.

    Four Conv1d + BatchNorm1d + ReLU stages (16, 32, 64, 128 channels; kernel
    15, padding 7; stride 4 for the first stage and 2 for the rest), with a
    MaxPool1d(4) after the first stage. The result is average-pooled over time
    and passed to a Linear layer with NUM_CLASSES outputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1 (followed by max pooling in forward)
        self.conv1 = nn.Conv1d(1, 16, kernel_size=15, stride=4, padding=7)
        self.bn1 = nn.BatchNorm1d(16)
        self.pool = nn.MaxPool1d(4)
        # Stage 2
        self.conv2 = nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7)
        self.bn2 = nn.BatchNorm1d(32)
        # Stage 3
        self.conv3 = nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7)
        self.bn3 = nn.BatchNorm1d(64)
        # Stage 4
        self.conv4 = nn.Conv1d(64, 128, kernel_size=15, stride=2, padding=7)
        self.bn4 = nn.BatchNorm1d(128)
        # Head
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, NUM_CLASSES)

    def forward(self, x):
        """Return class logits for a waveform batch `x` of shape [B, T]."""
        h = x.unsqueeze(1)  # [B, T] -> [B, 1, T]

        h = self.pool(F.relu(self.bn1(self.conv1(h))))
        later_stages = (
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in later_stages:
            h = F.relu(norm(conv(h)))

        # Collapse the time axis to a single value per channel.
        pooled = self.global_pool(h).squeeze(-1)
        return self.fc(pooled)

class MetaMLP(nn.Module):
    """MLP that maps a 4 * NUM_CLASSES input vector to NUM_CLASSES logits.

    Two hidden stages (512 and 256 units), each Linear -> ReLU -> Dropout(0.2),
    followed by a Linear layer with NUM_CLASSES outputs.
    """

    def __init__(self):
        super().__init__()
        layers = []
        prev_width = 4 * NUM_CLASSES
        for width in (512, 256):
            layers.extend((nn.Linear(prev_width, width), nn.ReLU(), nn.Dropout(0.2)))
            prev_width = width
        layers.append(nn.Linear(prev_width, NUM_CLASSES))

        # Flat Sequential: parameters stay addressable as net.<index>.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits produced by the layer stack for `x`."""
        return self.net(x)

In [None]:
def _load_for_inference(model, ckpt_path):
    """Move `model` to DEVICE, load the state dict at `ckpt_path`, set eval mode.

    Args:
        model: freshly constructed nn.Module whose parameter names match the
            checkpoint.
        ckpt_path: path of a file written with torch.save(state_dict).

    Returns:
        The same module, on DEVICE and in eval mode.
    """
    model = model.to(DEVICE)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    return model.eval()


# 1) Embedding MLP
# Its input width is taken from the first sample in the test manifest.
test_manifest = pd.read_csv(TEST_MANIFEST)
sample = test_manifest.iloc[0]
# lstrip(os.sep) keeps a leading separator in the manifest value from making
# os.path.join discard the base directory.
emb_path = os.path.join(FEATURE_BASE, "embeddings", sample.emb_path.lstrip(os.sep))
# The context manager closes the archive once the array has been read
# (assumes the file is an .npz archive - confirm against the feature writer).
with np.load(emb_path) as emb_npz:
    # Averaged over the first axis (presumably per-frame embeddings - confirm).
    emb_arr = emb_npz["embedding"].mean(axis=0).astype(np.float32)
emb_model = _load_for_inference(EmbeddingClassifier(emb_dim=emb_arr.shape[0]), CKPT_EMB)

# 2) ResNet50
res_model = _load_for_inference(get_resnet50_multilabel(), CKPT_RES)

# 3) EfficientNet-B3 (used as a plain module at inference)
eff_model = _load_for_inference(build_efficientnetb3_lora(), CKPT_EFF)

# 4) RawAudioCNN
raw_model = _load_for_inference(RawAudioCNN(), CKPT_RAW)

# 5) Meta supervisor
meta_model = _load_for_inference(MetaMLP(), CKPT_META)

# Echo the meta model as the cell output, as the original trailing .eval() did.
meta_model

In [None]:
# Build the four model inputs for `sample`, each with a leading batch dim of 1.

# embedding
emb = torch.from_numpy(emb_arr).unsqueeze(0).to(DEVICE)  # [1,emb_dim]

# mel-aug (ResNet50)
ma_path = os.path.join(FEATURE_BASE, "mel_aug", sample.mel_aug_path.lstrip(os.sep))
# Context manager closes the archive after the array is read
# (assumes an .npz archive - confirm against the feature writer).
with np.load(ma_path) as ma_npz:
    ma_arr = ma_npz["mel"].astype(np.float32)
# Two leading singleton dims are added; assumes the stored mel is 2-D
# [n_mels, n_frames], giving [1,1,n_mels,n_frames] - confirm.
ma = torch.from_numpy(ma_arr).unsqueeze(0).unsqueeze(0).to(DEVICE)

# mel (EffNetB3)
m_path = os.path.join(FEATURE_BASE, "mel", sample.mel_path.lstrip(os.sep))
with np.load(m_path) as m_npz:
    m_arr = m_npz["mel"].astype(np.float32)
m = torch.from_numpy(m_arr).unsqueeze(0).unsqueeze(0).to(DEVICE)  # same layout as `ma`

# raw audio
wav_path = os.path.join(FEATURE_BASE, "denoised", sample.audio_path.lstrip(os.sep))
wav, sr = torchaudio.load(wav_path)   # [channels, T]
# Down-mix to mono. For a single-channel file the mean over one channel returns
# the samples unchanged; for a multi-channel file the previous squeeze(0) was a
# no-op, so the length logic below operated on the channel axis.
wav = wav.mean(dim=0)                 # [T]
T = sr * 10                           # number of samples in 10 seconds
if wav.size(0) < T:
    # Too short: zero-pad at the end up to T samples.
    wav = F.pad(wav, (0, T - wav.size(0)))
else:
    # Long enough: keep only the first T samples.
    wav = wav[:T]
# Standardise to zero mean / unit std; the clamp avoids dividing by zero on
# constant (e.g. silent) audio.
wav = (wav - wav.mean()) / wav.std().clamp_min(1e-6)
wav = wav.unsqueeze(0).to(DEVICE)     # [1,T]

In [None]:
# Stage 1: each base model scores the sample on its own input.
# Stage 2: the meta model maps the concatenated probabilities to final logits.
base_runs = (
    (emb_model, emb),
    (res_model, ma),
    (eff_model, m),
    (raw_model, wav),
)

with torch.no_grad():
    # Per-model sigmoid probabilities, each [1,NUM_CLASSES]
    p1, p2, p3, p4 = [torch.sigmoid(net(inp)) for net, inp in base_runs]

    # Concatenate along the class axis in the fixed order emb, resnet, effnet, raw
    feat = torch.cat((p1, p2, p3, p4), dim=1)
    logits = meta_model(feat)

    # Drop the batch dimension and hand the result to NumPy for reporting
    probs = logits.sigmoid()[0].cpu().numpy()

In [None]:
# Keep every class whose probability reaches THRESHOLD, in class-index order.
ml_preds = []
for idx, label in enumerate(CLASSES):
    if probs[idx] >= THRESHOLD:
        ml_preds.append((label, float(probs[idx])))

print(f"\nMulti‑label predictions (prob ≥ {THRESHOLD}):")
# Fall back to a single placeholder line when nothing clears the threshold.
report_lines = [f"  • {label}: {score:.3f}" for label, score in ml_preds] or ["  • <none>"]
for line in report_lines:
    print(line)


In [None]:
# Top-1 prediction: the class with the highest probability, whether or not it
# reaches THRESHOLD.
primary_idx = int(np.argmax(probs))
primary_label, primary_score = CLASSES[primary_idx], float(probs[primary_idx])

print("\nPrimary‑label (top‑1) prediction:")
print(f"  → {primary_label}: {primary_score:.3f}")