In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms
import timm
from peft import get_peft_model, LoraConfig
from torch.amp import autocast

In [2]:
# ---- Inference configuration -------------------------------------------
# Decision cut-off applied to each per-class sigmoid probability.
THRESHOLD = 0.5
# Name of the array holding the mel spectrogram inside each .npz file.
MEL_KEY = "mel"
# Saved weights of the best LoRA model (path relative to the working dir).
CHECKPOINT = "best_effb3_lora.pt"

# Input locations.
TAXONOMY_CSV = "/home/jovyan/Data/birdclef-2025/taxonomy.csv"
FEATURE_DIR = "/home/jovyan/Features"
MANIFEST_CSV = os.path.join(FEATURE_DIR, "manifest_test.csv")

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [3]:
# Class vocabulary: every primary label of the taxonomy as a string, sorted.
# The position of a label in `classes` is the index of its model output.
tax_df = pd.read_csv(TAXONOMY_CSV)
classes = sorted(tax_df["primary_label"].astype(str))
num_classes = len(classes)

# Passed to LoraConfig(target_modules=...): layer names that get LoRA adapters.
TARGET_MODULES = [
    "conv_pw",
    "conv_dw",
    "conv_pwl",
    "conv_head",
]
# Passed to LoraConfig(modules_to_save=...).
MODULES_TO_SAVE = [
    "classifier",
]


In [4]:
def build_effb3_lora(num_classes, pretrained=True):
    """Build a single-channel EfficientNet-B3 classifier wrapped with LoRA adapters.

    Args:
        num_classes: Number of output logits of the replacement classifier head.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``
            (the previously hard-coded value). Pass ``False`` to skip fetching
            pretrained weights, e.g. when a full checkpoint is loaded into the
            returned model right afterwards.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``. The stem
        convolution and the classifier are newly created layers, so they hold
        freshly initialised weights whatever ``pretrained`` is.

    Raises:
        ValueError: Raised by the patched ``forward`` (at call time) when no
            input tensor is supplied.
    """
    base = timm.create_model("efficientnet_b3", pretrained=pretrained)

    # Patch forward so it tolerates arbitrary keyword arguments.
    # NOTE(review): presumably needed because the PEFT task wrapper calls the
    # base model with transformer-style kwargs such as `input_ids` - confirm.
    orig_fwd = base.forward

    def forward_patch(*args, **kwargs):
        # Prefer the `input_ids` keyword, fall back to the first positional
        # argument; every other keyword is deliberately dropped.
        x = kwargs.pop("input_ids", None)
        if x is None and args:
            x = args[0]
        if x is None:
            # Fail here with a clear message instead of passing None into the
            # network and crashing somewhere inside the first convolution.
            raise ValueError("No input tensor")
        return orig_fwd(x)

    base.forward = forward_patch

    # Adapt the stem to single-channel input, keeping the original geometry.
    stem = base.conv_stem
    base.conv_stem = nn.Conv2d(
        1, stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False
    )

    # Replace the classification head with one sized for `num_classes`.
    in_f = base.classifier.in_features
    base.classifier = nn.Linear(in_f, num_classes)

    # Attach LoRA adapters to the layers named in TARGET_MODULES.
    lora_cfg = LoraConfig(
        r=12, lora_alpha=24, lora_dropout=0.1,
        target_modules=TARGET_MODULES,
        modules_to_save=MODULES_TO_SAVE,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        inference_mode=True
    )
    model = get_peft_model(base, lora_cfg)
    return model

# Rebuild the architecture, then restore the trained weights from disk.
model = build_effb3_lora(num_classes).to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot run arbitrary code while being loaded.
state = torch.load(CHECKPOINT, map_location=DEVICE, weights_only=True)
model.load_state_dict(state)  # strict by default: a key mismatch raises
model.eval()  # switch dropout / batch-norm layers to inference behaviour

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): EfficientNet(
      (conv_stem): Conv2d(1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNormAct2d(
        40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): SiLU(inplace=True)
      )
      (blocks): Sequential(
        (0): Sequential(
          (0): DepthwiseSeparableConv(
            (conv_dw): lora.Conv2d(
              (base_layer): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.1, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Conv2d(40, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              )
              (lora_B): ModuleDict(
                (default): Conv2d(12, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
    

In [5]:
# Draw one random row from the test manifest and locate its mel file on disk.
df = pd.read_csv(MANIFEST_CSV)
row = df.sample(n=1).iloc[0]
chunk_id = row["chunk_id"]

# Strip leading path separators so os.path.join keeps FEATURE_DIR as the root
# instead of treating the manifest entry as an absolute path.
rel_path = row["mel_path"].lstrip(os.sep)
mel_path = os.path.join(FEATURE_DIR, "mel", rel_path)

print(f"Running inference on chunk: {chunk_id}")

Running inference on chunk: iNat329195_chk1


In [6]:
# np.load on an .npz returns a lazily-read archive that keeps its file handle
# open; read the array inside a `with` block so the handle is released at once.
with np.load(mel_path) as data:
    mel = data[MEL_KEY]                      # expected [n_mels, n_frames]

# Add batch and channel axes -> [1, 1, n_mels, n_frames], cast to float32.
x    = torch.from_numpy(mel).unsqueeze(0).unsqueeze(0).float()
x    = x.to(DEVICE, non_blocking=True)


In [7]:
# Forward pass without gradients. Autocast follows the device actually in use
# and is enabled on CUDA only, so a CPU run neither requests CUDA autocast
# (which only triggers a warning) nor switches to CPU reduced precision.
with torch.no_grad(), autocast(device_type=DEVICE.type,
                               enabled=(DEVICE.type == "cuda")):
    logits = model(x)                       # [1, num_classes]
    # Under autocast the logits may be half precision; take the sigmoid in
    # float32 so probabilities are not rounded before thresholding.
    probs  = torch.sigmoid(logits.float())[0].cpu().numpy()

In [8]:
# Collect (label, probability) for every class at or above THRESHOLD,
# in class-index order.
ml_preds = [
    (label, float(probs[k]))
    for k, label in enumerate(classes)
    if probs[k] >= THRESHOLD
]

print(f"\nMulti‑label predictions (prob ≥ {THRESHOLD}):")
# Placeholder line when nothing cleared the threshold; the loop below then
# has nothing to iterate over.
if not ml_preds:
    print("  • <none>")
for lab, sc in ml_preds:
    print(f"  • {lab}: {sc:.3f}")


Multi‑label predictions (prob ≥ 0.5):
  • grbhaw1: 0.803


In [9]:
# Top-1: the single most probable class, reported regardless of THRESHOLD.
idx = int(np.argmax(probs))
print(
    "\nPrimary‑label (top‑1) prediction:",
    f"  → {classes[idx]}: {probs[idx]:.3f}",
    sep="\n",
)


Primary‑label (top‑1) prediction:
  → grbhaw1: 0.803
