In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import onnx
import onnxruntime as ort
import pandas as pd
import import_ipynb

from peft import get_peft_model, LoraConfig
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [2]:
FEATURE_BASE  = "/mnt/BirdCLEF/features_sampled"
TEST_MANIFEST = os.path.join(FEATURE_BASE, "manifest_test.csv")

# Load the manifest file that lists test samples
test_manifest = pd.read_csv(TEST_MANIFEST)

# Inspect a few entries
print("Test manifest loaded:")
print(test_manifest.head())
print(f"Total samples: {len(test_manifest)}")


Test manifest loaded:
          chunk_id                   audio_path primary_label  \
0    CSA34200_chk0   /1564122/CSA34200_chk0.ogg       1564122   
1  iNat320679_chk0  /126247/iNat320679_chk0.ogg        126247   
2    CSA18793_chk0   /1346504/CSA18793_chk0.ogg       1346504   
3    CSA34196_chk0   /1564122/CSA34196_chk0.ogg       1564122   
4    CSA18792_chk0   /1346504/CSA18792_chk0.ogg       1346504   

                      mel_path                         emb_path  \
0   /1564122/CSA34200_chk0.npz   /1564122/CSA34200_chk0_emb.npz   
1  /126247/iNat320679_chk0.npz  /126247/iNat320679_chk0_emb.npz   
2   /1346504/CSA18793_chk0.npz   /1346504/CSA18793_chk0_emb.npz   
3   /1564122/CSA34196_chk0.npz   /1564122/CSA34196_chk0_emb.npz   
4   /1346504/CSA18792_chk0.npz   /1346504/CSA18792_chk0_emb.npz   

                  mel_aug_path  
0   /1564122/CSA34200_chk0.npz  
1  /126247/iNat320679_chk0.npz  
2   /1346504/CSA18793_chk0.npz  
3   /1564122/CSA34196_chk0.npz  
4   /1346504/CSA187

In [3]:
import torch.nn.functional as F
from torch import nn

# Define wrapper for ONNX export
class FusionONNXWrapper(nn.Module):
    def __init__(self, emb_model, res_model, eff_model, raw_model, meta_model):
        super().__init__()
        self.emb_model = emb_model
        self.res_model = res_model
        self.eff_model = eff_model
        self.raw_model = raw_model
        self.meta_model = meta_model

    def forward(self, emb, ma, m, wav):
        p1 = torch.sigmoid(self.emb_model(emb))
        p2 = torch.sigmoid(self.res_model(ma))
        p3 = torch.sigmoid(self.eff_model(m))
        p4 = torch.sigmoid(self.raw_model(wav))
        feat = torch.cat([p1, p2, p3, p4], dim=1)
        return torch.sigmoid(self.meta_model(feat))


In [9]:
import torch.nn as nn
import torch.nn.functional as F
import timm
from peft import get_peft_model, LoraConfig

# === Sub-models ===
class EmbeddingClassifier(nn.Module):
    def __init__(self, emb_dim, num_cls):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, 2048), nn.BatchNorm1d(2048), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(2048, 1024),    nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(1024, 512),     nn.BatchNorm1d(512),  nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, num_cls)
        )
    def forward(self, x): return self.net(x)

def get_resnet50_multilabel(num_classes):
    m = torch.hub.load('pytorch/vision:v0.14.0', 'resnet50', pretrained=False)
    m.conv1 = nn.Conv2d(1, m.conv1.out_channels,
                        kernel_size=m.conv1.kernel_size,
                        stride=m.conv1.stride,
                        padding=m.conv1.padding,
                        bias=False)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

TARGET_MODULES  = ["conv_pw", "conv_dw", "conv_pwl", "conv_head"]
MODULES_TO_SAVE = ["classifier"]
def build_efficientnetb3_lora(num_classes):
    base = timm.create_model("efficientnet_b3", pretrained=True)
    orig_fwd = base.forward
    def forward_patch(*args, input_ids=None, **kwargs):
        x = input_ids if input_ids is not None else args[0]
        return orig_fwd(x)
    base.forward = forward_patch
    base.conv_stem = nn.Conv2d(1, base.conv_stem.out_channels,
                                kernel_size=base.conv_stem.kernel_size,
                                stride=base.conv_stem.stride,
                                padding=base.conv_stem.padding,
                                bias=False)
    base.classifier = nn.Linear(base.classifier.in_features, num_classes)
    lora_cfg = LoraConfig(
        r=12, lora_alpha=24,
        target_modules=TARGET_MODULES,
        lora_dropout=0.1, bias="none",
        modules_to_save=MODULES_TO_SAVE,
        task_type="FEATURE_EXTRACTION",
        inference_mode=False
    )
    return get_peft_model(base, lora_cfg)

class RawAudioCNN(nn.Module):
    def __init__(self, num_cls):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=15, stride=4, padding=7)
        self.bn1 = nn.BatchNorm1d(16)
        self.pool = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7)
        self.bn3 = nn.BatchNorm1d(64)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=15, stride=2, padding=7)
        self.bn4 = nn.BatchNorm1d(128)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_cls)
    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.relu(self.bn1(self.conv1(x))); x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.global_pool(x).squeeze(-1)
        return self.fc(x)

class MetaMLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, dropout):
        super().__init__()
        layers, dims = [], [in_dim] + hidden_dims
        for i in range(len(hidden_dims)):
            layers += [
                nn.Linear(dims[i], dims[i+1]),
                nn.BatchNorm1d(dims[i+1]),
                nn.ReLU(),
                nn.Dropout(dropout)
            ]
        layers.append(nn.Linear(dims[-1], NUM_CLASSES))
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)

class FusionONNXWrapper(nn.Module):
    def __init__(self, emb_model, res_model, eff_model, raw_model, meta_model):
        super().__init__()
        self.emb_model = emb_model
        self.res_model = res_model
        self.eff_model = eff_model
        self.raw_model = raw_model
        self.meta_model = meta_model
    def forward(self, emb, ma, m, wav):
        p1 = torch.sigmoid(self.emb_model(emb))
        p2 = torch.sigmoid(self.res_model(ma))
        p3 = torch.sigmoid(self.eff_model(m))
        p4 = torch.sigmoid(self.raw_model(wav))
        feat = torch.cat([p1, p2, p3, p4], dim=1)
        return torch.sigmoid(self.meta_model(feat))


In [11]:
import pandas as pd

# Load class taxonomy
TAXONOMY_CSV = "/mnt/BirdCLEF/taxonomy.csv"
tax = pd.read_csv(TAXONOMY_CSV)
CLASSES = sorted(tax["primary_label"].astype(str).tolist())
NUM_CLASSES = len(CLASSES)
DEVICE = torch.device("cpu")  # You can switch to "cuda" if desired

# Checkpoint paths

CKPT_EMB    = "/mnt/BirdCLEF/Models/best_emb_mlp.pt"
CKPT_RES    = "/mnt/BirdCLEF/Models/best_resnet50.pt"
CKPT_EFF    = "/mnt/BirdCLEF/Models/best_effb3_lora.pt"
CKPT_RAW    = "/mnt/BirdCLEF/Models/best_rawcnn.pt"
CKPT_META   = "/mnt/BirdCLEF/Models/best_meta_mlp.pt"

# Instantiate models
emb_model  = EmbeddingClassifier(2048, NUM_CLASSES).to(DEVICE)
res_model  = get_resnet50_multilabel(NUM_CLASSES).to(DEVICE)
eff_model  = build_efficientnetb3_lora(NUM_CLASSES).to(DEVICE)
raw_model  = RawAudioCNN(NUM_CLASSES).to(DEVICE)
meta_model = MetaMLP(NUM_CLASSES * 4, [1024, 512], dropout=0.3).to(DEVICE)

# Load weights
emb_model.load_state_dict(torch.load(CKPT_EMB, map_location=DEVICE))
res_model.load_state_dict(torch.load(CKPT_RES, map_location=DEVICE))
eff_model.load_state_dict(torch.load(CKPT_EFF, map_location=DEVICE))
raw_model.load_state_dict(torch.load(CKPT_RAW, map_location=DEVICE))
meta_model.load_state_dict(torch.load(CKPT_META, map_location=DEVICE))

# Wrap fusion model
fusion_model = FusionONNXWrapper(
    emb_model, res_model, eff_model, raw_model, meta_model
).to(DEVICE).eval()

# Dummy inputs for ONNX export
dummy_emb = torch.randn(1, 2048).to(DEVICE)             # Embedding vector
dummy_ma  = torch.randn(1, 1, 64, 313).to(DEVICE)        # Mel-spectrogram (ResNet)
dummy_m   = torch.randn(1, 1, 64, 313).to(DEVICE)        # Mel-spectrogram (EffNet)
dummy_wav = torch.randn(1, 320000).to(DEVICE)           # Raw audio (20 sec @ 16kHz)


Using cache found in /home/jovyan/.cache/torch/hub/pytorch_vision_v0.14.0
  emb_model.load_state_dict(torch.load(CKPT_EMB, map_location=DEVICE))
  res_model.load_state_dict(torch.load(CKPT_RES, map_location=DEVICE))
  eff_model.load_state_dict(torch.load(CKPT_EFF, map_location=DEVICE))
  raw_model.load_state_dict(torch.load(CKPT_RAW, map_location=DEVICE))
  meta_model.load_state_dict(torch.load(CKPT_META, map_location=DEVICE))


In [12]:
# === Cell 6: Export Fusion Model to ONNX ===

onnx_model_path = "fusion_birdclef.onnx"

torch.onnx.export(
    fusion_model,
    (dummy_emb, dummy_ma, dummy_m, dummy_wav),
    onnx_model_path,
    input_names=["emb", "mel_aug", "mel", "wav"],
    output_names=["probabilities"],
    dynamic_axes={
        "emb":          {0: "batch"},
        "mel_aug":      {0: "batch"},
        "mel":          {0: "batch"},
        "wav":          {0: "batch"},
        "probabilities": {0: "batch"}
    },
    export_params=True,
    opset_version=17,
    do_constant_folding=True
)

print(f"ONNX model exported to: {onnx_model_path}")


ONNX model exported to: fusion_birdclef.onnx


In [13]:
def benchmark_fusion_onnx(ort_session):
    print(f"Execution provider: {ort_session.get_providers()}")

    # === Accuracy is not benchmarked here due to fusion input complexity ===

    # === Benchmark latency on single sample ===
    num_trials = 100
    dummy_emb  = np.random.randn(1, 2048).astype(np.float32)
    dummy_ma   = np.random.randn(1, 1, 64, 313).astype(np.float32)
    dummy_m    = np.random.randn(1, 1, 64, 313).astype(np.float32)
    dummy_wav  = np.random.randn(1, 320000).astype(np.float32)

    # Warm-up
    ort_session.run(None, {
        "emb": dummy_emb,
        "mel_aug": dummy_ma,
        "mel": dummy_m,
        "wav": dummy_wav
    })

    latencies = []
    for _ in range(num_trials):
        start = time.time()
        ort_session.run(None, {
            "emb": dummy_emb,
            "mel_aug": dummy_ma,
            "mel": dummy_m,
            "wav": dummy_wav
        })
        latencies.append(time.time() - start)

    latencies = np.array(latencies)
    print(f"Inference Latency (single sample, median): {np.percentile(latencies, 50) * 1000:.2f} ms")
    print(f"Inference Latency (single sample, 95th percentile): {np.percentile(latencies, 95) * 1000:.2f} ms")
    print(f"Inference Latency (single sample, 99th percentile): {np.percentile(latencies, 99) * 1000:.2f} ms")
    print(f"Inference Throughput (single sample): {num_trials / np.sum(latencies):.2f} FPS")

    # === Benchmark batch throughput ===
    num_batches = 50
    batch_size = 32

    dummy_emb  = np.random.randn(batch_size, 2048).astype(np.float32)
    dummy_ma   = np.random.randn(batch_size, 1, 64, 313).astype(np.float32)
    dummy_m    = np.random.randn(batch_size, 1, 64, 313).astype(np.float32)
    dummy_wav  = np.random.randn(batch_size, 320000).astype(np.float32)

    # Warm-up
    ort_session.run(None, {
        "emb": dummy_emb,
        "mel_aug": dummy_ma,
        "mel": dummy_m,
        "wav": dummy_wav
    })

    batch_times = []
    for _ in range(num_batches):
        start = time.time()
        ort_session.run(None, {
            "emb": dummy_emb,
            "mel_aug": dummy_ma,
            "mel": dummy_m,
            "wav": dummy_wav
        })
        batch_times.append(time.time() - start)

    batch_fps = (batch_size * num_batches) / np.sum(batch_times)
    print(f"Batch Throughput: {batch_fps:.2f} FPS")


In [14]:
onnx_path = "fusion_birdclef.onnx"
ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
benchmark_fusion_onnx(ort_session)


Execution provider: ['CPUExecutionProvider']
Inference Latency (single sample, median): 46.45 ms
Inference Latency (single sample, 95th percentile): 46.66 ms
Inference Latency (single sample, 99th percentile): 47.08 ms
Inference Throughput (single sample): 21.52 FPS
Batch Throughput: 75.84 FPS


In [15]:
onnx_model_path = "models/food11.onnx"
ort_session = ort.InferenceSession(onnx_model_path, providers=['CUDAExecutionProvider'])
benchmark_session(ort_session)
ort.get_device()

NoSuchFile: [ONNXRuntimeError] : 3 : NO_SUCHFILE : Load model from models/food11.onnx failed:Load model models/food11.onnx failed. File doesn't exist