In [1]:
import os
import sys
import torch
import torch.nn.functional as F
import torchaudio
import pandas as pd
import numpy as np

In [2]:
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FEATURE_BASE  = "/home/jovyan/Features"
MANIFEST_TEST = os.path.join(FEATURE_BASE, "manifest_test.csv")
TAXONOMY_CSV  = "/home/jovyan/Data/birdclef-2025/taxonomy.csv"
CHECKPOINT    = "cnn10_epoch_1.pt"   # path to your best CNN10 checkpoint
THRESHOLD     = 0.5

In [3]:
REPO_DIR = "audioset_tagging_cnn"
if not os.path.isdir(REPO_DIR):
    os.system("git clone https://github.com/qiuqiangkong/audioset_tagging_cnn.git")
sys.path.insert(0, os.path.join(REPO_DIR, "pytorch"))
from models import Cnn10

In [4]:
tax_df  = pd.read_csv(TAXONOMY_CSV)
classes = sorted(tax_df["primary_label"].astype(str).tolist())


In [5]:
model = Cnn10(
    sample_rate=32000,
    window_size=1024,
    hop_size=320,
    mel_bins=64,
    fmin=50,
    fmax=14000,
    classes_num=len(classes)
).to(DEVICE)

ckpt = torch.load(CHECKPOINT, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

Cnn10(
  (spectrogram_extractor): Spectrogram(
    (stft): STFT(
      (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
      (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
    )
  )
  (logmel_extractor): LogmelFilterBank()
  (spec_augmenter): SpecAugmentation(
    (time_dropper): DropStripes()
    (freq_dropper): DropStripes()
  )
  (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_block1): ConvBlock(
    (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block2): ConvBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (c

In [6]:
df     = pd.read_csv(MANIFEST_TEST)
sample = df.sample(1).iloc[0]
print("Running inference on:", sample.chunk_id)

# build full path to denoised .ogg
wav_path = os.path.join(
    FEATURE_BASE,
    "denoised",
    sample.audio_path.lstrip(os.sep)
)

waveform, sr = torchaudio.load(wav_path)   # (channels, samples)
waveform = waveform.mean(dim=0)             # mono

# pad or trim to exactly 10 s @32 kHz
target_len = 10 * 32000
if waveform.shape[0] < target_len:
    waveform = F.pad(waveform, (0, target_len - waveform.shape[0]))
else:
    waveform = waveform[:target_len]

Running inference on: XC197666_chk5


In [7]:
with torch.no_grad():
    out    = model(waveform.unsqueeze(0).to(DEVICE))  # returns dict
    logits = out["clipwise_output"]                   # [1, num_classes]
    probs  = torch.sigmoid(logits)[0].cpu().numpy()   # (num_classes,)

In [8]:
# thresholded predictions
pred_idxs = np.where(probs >= THRESHOLD)[0].tolist()
print(f"\nPredictions (threshold ≥ {THRESHOLD}):")
for i in pred_idxs:
    print(f" • {classes[i]}: {probs[i]:.3f}")


Predictions (threshold ≥ 0.5):
 • 1139490: 0.500
 • 1192948: 0.500
 • 1194042: 0.500
 • 126247: 0.500
 • 1346504: 0.500
 • 134933: 0.500
 • 135045: 0.500
 • 1462711: 0.500
 • 1462737: 0.500
 • 1564122: 0.500
 • 21038: 0.500
 • 21116: 0.500
 • 21211: 0.500
 • 22333: 0.500
 • 22973: 0.500
 • 22976: 0.500
 • 24272: 0.500
 • 24292: 0.500
 • 24322: 0.500
 • 41663: 0.500
 • 41778: 0.500
 • 41970: 0.500
 • 42007: 0.500
 • 42087: 0.500
 • 42113: 0.500
 • 46010: 0.500
 • 47067: 0.500
 • 476537: 0.500
 • 476538: 0.500
 • 48124: 0.500
 • 50186: 0.500
 • 517119: 0.500
 • 523060: 0.500
 • 528041: 0.500
 • 52884: 0.500
 • 548639: 0.500
 • 555086: 0.500
 • 555142: 0.500
 • 566513: 0.500
 • 64862: 0.500
 • 65336: 0.500
 • 65344: 0.500
 • 65349: 0.500
 • 65373: 0.500
 • 65419: 0.500
 • 65448: 0.500
 • 65547: 0.500
 • 65962: 0.500
 • 66016: 0.500
 • 66531: 0.500
 • 66578: 0.500
 • 66893: 0.500
 • 67082: 0.500
 • 67252: 0.500
 • 714022: 0.500
 • 715170: 0.500
 • 787625: 0.500
 • 81930: 0.500
 • 868458: 

In [9]:
# top‑5 overall
print("\nTop‑5 overall:")
top5 = np.argsort(probs)[-5:][::-1]
for idx in top5:
    print(f" • {classes[idx]}: {probs[idx]:.4f}")


Top‑5 overall:
 • bubwre1: 0.7310
 • compau: 0.7310
 • butsal1: 0.7310
 • yebsee1: 0.7281
 • whbant1: 0.7253
