In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import librosa


In [None]:
# --- Runtime -----------------------------------------------------------------
# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Decision rule -----------------------------------------------------------
# A class is reported when its probability is >= this value.
THRESHOLD = 0.5

# --- Inputs ------------------------------------------------------------------
# Manifest of test chunks, the competition metadata CSV, and the fine-tuned
# model checkpoint (relative to the notebook's working directory).
MANIFEST = "/home/jovyan/Features/manifest_test.csv"
TRAIN_CSV = "/home/jovyan/Data/birdclef-2025/train.csv"
CKPT_PATH = "panns_cnn14_best_epoch.pt"

# --- Audio framing -----------------------------------------------------------
# Audio is loaded at SR Hz and padded/truncated to DURATION_S seconds,
# i.e. NUM_SAMPLES samples per clip.
SR = 32_000
DURATION_S = 10.0
NUM_SAMPLES = int(DURATION_S * SR)

In [None]:
# Test manifest: one row per audio chunk (reads chunk_id and primary_label here).
test_df = pd.read_csv(MANIFEST)

# Recording-level metadata, reduced to the two columns needed to look up
# secondary labels by recording id.
meta = pd.read_csv(TRAIN_CSV, usecols=["filename", "secondary_labels"])
# Recording id = filename with a trailing ".ogg" removed.
meta["recording_id"] = meta["filename"].str.replace(r"\.ogg$", "", regex=True)
# NOTE(review): str.split() tokenises on whitespace only. If secondary_labels
# holds stringified Python lists (e.g. "['a', 'b']"), the tokens keep their
# brackets/quotes — confirm the CSV format and that the training-time class
# list was built the same way before changing this.
meta["sec_list"] = meta["secondary_labels"].fillna("").str.split()
sec_map = dict(zip(meta["recording_id"], meta["sec_list"]))

# Class vocabulary = every primary label in the manifest, plus the secondary
# labels of each chunk's parent recording (chunk_id text before "_chk").
labels = set(test_df["primary_label"])
for chunk_id in test_df["chunk_id"]:
    labels.update(sec_map.get(chunk_id.split("_chk")[0], []))

# Sorted so that a class index always maps to the same label.
classes = sorted(labels)

In [None]:
# Build the Cnn14 network via torch.hub with pretrained=False; the weights
# actually used come from the checkpoint loaded further down.
# NOTE(review): classes_num=None is forwarded as-is to the hub entry point —
# confirm that entry point accepts it.
panns = torch.hub.load(
    "qiuqiangkong/audioset_tagging_cnn",
    "Cnn14",
    pretrained=False,
    classes_num=None,
)

# Replace the classification head with a fresh linear layer sized to this
# notebook's class list. The head is `fc2` when the model has that attribute,
# otherwise `fc_audioset`.
head_attr = "fc2" if hasattr(panns, "fc2") else "fc_audioset"
in_f = getattr(panns, head_attr).in_features
setattr(panns, head_attr, nn.Linear(in_f, len(classes)))

# Move to the target device, restore the fine-tuned weights, and switch to
# evaluation mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = panns.to(DEVICE)
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

In [None]:
# Draw one manifest row at random to run inference on.
sample = test_df.sample(n=1).iloc[0]
print("Inferring on chunk:", sample["chunk_id"])

# Decode the chunk at the target sample rate; the returned rate is discarded.
wav, _ = librosa.load(sample["audio_path"], sr=SR)

# Force the clip to exactly NUM_SAMPLES samples: drop anything past the end,
# then zero-pad on the right if the clip was shorter.
wav = wav[:NUM_SAMPLES]
shortfall = NUM_SAMPLES - len(wav)
if shortfall > 0:
    wav = np.pad(wav, (0, shortfall))

# Add a leading batch axis -> shape [1, num_samples], and move to the device.
x = torch.from_numpy(wav)[None, :].to(DEVICE)

In [None]:
# Forward pass with autograd disabled.
# NOTE(review): assumes model(x) returns a tensor of shape [1, num_classes]
# holding raw logits — confirm against the hub model's forward().
with torch.no_grad():
    logits = model(x)

# Independent per-class probabilities for the single clip in the batch.
# `logits` carries no graph, so this does not need to sit inside no_grad.
probs = torch.sigmoid(logits)[0]

In [None]:
# Indices of every class whose probability meets the decision threshold.
# flatten() always yields a 1-D tensor, so tolist() always returns a Python
# list (possibly empty) — no scalar special-case is needed.
pred_idxs = (probs >= THRESHOLD).nonzero(as_tuple=False).flatten().tolist()

print(f"\nPredictions (threshold ≥ {THRESHOLD}):")
if not pred_idxs:
    # Make the empty case explicit instead of printing a bare header.
    print("  (no class reached the threshold)")
for i in pred_idxs:
    print(f"  • {classes[i]}: {probs[i]:.3f}")