In [None]:
import os

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType

In [None]:
# Run on the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Minimum sigmoid probability for a class to be reported as present.
THRESHOLD = 0.5

# Input locations. Defaults reproduce the original JupyterHub layout; set the
# environment variables to run the notebook elsewhere without editing code.
MANIFEST = os.environ.get(
    "BIRDCLEF_MANIFEST", "/home/jovyan/Features/manifest_test.csv"
)
TRAIN_CSV = os.environ.get(
    "BIRDCLEF_TRAIN_CSV", "/home/jovyan/Data/birdclef-2025/train.csv"
)
CKPT_PATH = os.environ.get("BIRDCLEF_CKPT", "effb3_lora_best_epoch.pt")

In [None]:
# Test manifest: one row per audio chunk.
test_df = pd.read_csv(MANIFEST)

# Training metadata -> {recording_id: [secondary labels]}, where recording_id
# is the filename with its ".ogg" suffix removed and the label field is split
# on whitespace (an empty/missing field yields []).
meta = pd.read_csv(TRAIN_CSV, usecols=["filename", "secondary_labels"])
meta["recording_id"] = meta["filename"].str.replace(r"\.ogg$", "", regex=True)
meta["sec_list"] = meta["secondary_labels"].fillna("").str.split()
sec_map = dict(zip(meta["recording_id"], meta["sec_list"]))

# Label vocabulary = every primary label in the manifest plus the secondary
# labels of the recording each chunk was cut from (text before "_chk").
# NOTE(review): len(classes) sizes the model head and its sorted order maps
# logit indices to names, so this must reproduce the training-time label set
# exactly — otherwise load_state_dict fails or predictions get mislabelled.
labels = set()
for chunk_id, primary in zip(test_df["chunk_id"], test_df["primary_label"]):
    labels.add(primary)
    labels.update(sec_map.get(chunk_id.split("_chk")[0], []))
classes = sorted(labels)

In [None]:
def build_efficientnetb3_lora(num_classes):
    """Build the EfficientNet-B3 + LoRA classifier used for inference.

    The timm backbone is created with random weights (``pretrained=False``);
    all parameters are expected to come from the checkpoint loaded afterwards,
    so this construction must mirror the training-time one exactly.

    Args:
        num_classes: Number of output logits of the replacement linear head.

    Returns:
        The timm model wrapped by ``peft.get_peft_model`` with LoRA adapters
        (r=4, lora_alpha=16) on ``conv_stem`` and ``classifier``, configured
        for inference.
    """
    model = timm.create_model("efficientnet_b3", pretrained=False)

    # Replace the default stem with a 1-channel conv (single mel-spectrogram
    # plane), keeping the original stem's width, kernel, stride and padding.
    stem = model.conv_stem
    model.conv_stem = nn.Conv2d(
        in_channels=1,
        out_channels=stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False,
    )

    # Replace the classification head with one sized for our label set.
    in_feat = model.classifier.in_features
    model.classifier = nn.Linear(in_feat, num_classes)

    # No task_type here: PEFT's TaskType enum has no image-classification
    # member (accessing one raises AttributeError), and for a plain timm
    # "custom model" get_peft_model returns a generic PeftModel wrapper.
    peft_cfg = LoraConfig(
        inference_mode=True,
        r=4,
        lora_alpha=16,
        target_modules=["conv_stem", "classifier"],
    )
    return get_peft_model(model, peft_cfg)

# Read the training checkpoint straight onto the target device.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)

# Rebuild the architecture with one logit per class, move it to the device,
# and restore the weights. load_state_dict is strict by default, so a head
# size or key mismatch with the checkpoint raises instead of loading silently.
model = build_efficientnetb3_lora(num_classes=len(classes))
model = model.to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

In [None]:
# Pick one chunk for a smoke-test inference. A fixed seed keeps the choice
# reproducible under Restart & Run All; set SAMPLE_SEED = None for a random pick.
SAMPLE_SEED = 0
sample = test_df.sample(1, random_state=SAMPLE_SEED).iloc[0]
print("Inferring on chunk:", sample.chunk_id)

# Use the clean (non-augmented) mel. NpzFile holds an open file handle, so
# read the array inside a context manager to close it afterwards.
with np.load(sample.mel_path) as npz:
    mel = npz["mel"]                 # shape: (n_mels, n_frames)
assert mel.ndim == 2, f"expected a 2-D mel array, got shape {mel.shape}"

# Add batch and channel axes -> [1, 1, n_mels, n_frames] for the 1-channel stem.
x = torch.from_numpy(mel).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

In [None]:
# Forward pass without autograd tracking.
with torch.no_grad():
    logits = model(x)            # [1, num_classes]

# Sigmoid gives an independent probability per class (several classes can be
# present at once); keep only the first — and only — item of the batch.
probs = logits.sigmoid()[0]      # [num_classes]

In [None]:
# Move probabilities to the CPU once rather than syncing per element.
probs_cpu = probs.detach().cpu()

# Every class whose probability clears the threshold (zero, one, or many).
# .flatten().tolist() on the nonzero indices always yields a plain list.
pred_idxs = torch.nonzero(probs_cpu >= THRESHOLD).flatten().tolist()
# Most confident classes first.
pred_idxs.sort(key=lambda i: probs_cpu[i].item(), reverse=True)

print(f"\nPredictions (threshold ≥ {THRESHOLD}):")
if not pred_idxs:
    # Say so explicitly, with the best score, to help judge the threshold.
    top = probs_cpu.max().item() if probs_cpu.numel() else float("nan")
    print(f"  (none — highest probability was {top:.3f})")
for i in pred_idxs:
    print(f"  • {classes[i]}: {probs_cpu[i].item():.3f}")