In [15]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [39]:
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
THRESHOLD  = 0.1
FEATURE_BASE = "/home/jovyan/Features"
CHECKPOINT = "best_ckpt.pt"

In [17]:
class EmbeddingClassifier(nn.Module):
    def __init__(self, emb_dim, num_classes, hidden_dims=[512,256], dropout=0.3):
        super().__init__()
        self.fc1   = nn.Linear(emb_dim, hidden_dims[0])
        self.bn1   = nn.BatchNorm1d(hidden_dims[0])
        self.drop1 = nn.Dropout(dropout)
        self.fc2   = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.bn2   = nn.BatchNorm1d(hidden_dims[1])
        self.drop2 = nn.Dropout(dropout)
        self.out   = nn.Linear(hidden_dims[1], num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.drop1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.drop2(x)
        return self.out(x)

In [18]:
tax_df     = pd.read_csv("/home/jovyan/Data/birdclef-2025/taxonomy.csv")
classes    = sorted(tax_df['primary_label'].astype(str).tolist())
num_classes = len(classes)

In [33]:
test_df = pd.read_csv(os.path.join(FEATURE_BASE, "manifest_test.csv"))
sample  = test_df.sample(300).iloc[0]
print("Running inference on chunk:", sample.chunk_id)


Running inference on chunk: XC668352_chk5


In [34]:
rel_path = sample.emb_path.lstrip(os.sep)
emb_file = os.path.join(FEATURE_BASE, "embeddings", rel_path)
emb_arr  = np.load(emb_file)["embedding"]   # shape: (n_windows, emb_dim)
emb_dim  = emb_arr.shape[1]

In [35]:
model = EmbeddingClassifier(
    emb_dim    = emb_dim,
    num_classes= num_classes,
    hidden_dims= [512,256],
    dropout    = 0.3
).to(DEVICE)

ckpt = torch.load(CHECKPOINT, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

EmbeddingClassifier(
  (fc1): Linear(in_features=2048, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop1): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop2): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=256, out_features=206, bias=True)
)

In [36]:
sample_emb = emb_arr.mean(axis=0).astype(np.float32)   # (emb_dim,)
x = torch.from_numpy(sample_emb).unsqueeze(0).to(DEVICE)  # (1, emb_dim)

with torch.no_grad():
    logits = model(x)                    # (1, num_classes)
    probs  = torch.sigmoid(logits)[0]    # (num_classes,)

In [40]:
pred_idxs = (probs >= THRESHOLD).nonzero(as_tuple=False).flatten().tolist()
if isinstance(pred_idxs, int):
    pred_idxs = [pred_idxs]

print(f"\nPredicted species (threshold ≥ {THRESHOLD}):")
for i in pred_idxs:
    print(f"  • {classes[i]}: {probs[i]:.3f}")


Predicted species (threshold ≥ 0.1):
  • roahaw: 0.111
  • yeofly1: 0.158


In [38]:
probs

tensor([2.0285e-08, 1.2113e-06, 4.1922e-05, 2.8624e-09, 8.9907e-04, 1.8784e-04,
        3.2474e-02, 8.5911e-10, 1.7500e-06, 1.2957e-08, 9.3187e-09, 6.9883e-09,
        3.3979e-02, 1.7957e-07, 2.3511e-03, 1.2399e-04, 7.6945e-11, 1.1644e-08,
        2.2516e-06, 2.0137e-04, 1.3261e-06, 7.5820e-06, 8.4081e-08, 4.0574e-12,
        2.8809e-13, 9.3497e-08, 5.7027e-08, 2.5522e-04, 1.4659e-08, 2.7401e-03,
        7.4959e-04, 2.7096e-06, 5.8604e-09, 1.4544e-07, 8.1871e-03, 1.6918e-08,
        1.8999e-07, 1.4802e-10, 2.8193e-07, 1.0782e-06, 1.2136e-08, 4.4366e-06,
        3.0804e-04, 2.2671e-04, 7.4001e-06, 3.0907e-04, 2.9765e-06, 1.5725e-08,
        8.5352e-14, 2.9398e-12, 3.5178e-10, 1.2250e-06, 6.1212e-09, 5.4215e-06,
        1.0881e-07, 2.1755e-07, 1.5622e-06, 1.8804e-08, 3.9623e-09, 2.2825e-07,
        6.4498e-06, 6.0333e-04, 3.0089e-04, 6.8096e-05, 3.9351e-03, 1.8950e-05,
        4.5916e-02, 3.2799e-05, 2.3346e-04, 3.6194e-04, 1.0303e-03, 6.2840e-03,
        1.1404e-02, 3.2445e-02, 1.7870e-

In [41]:
topk = probs.cpu().topk(5)
print("\nTop 5 predictions:")
for score, idx in zip(topk.values, topk.indices):
    print(f"  • {classes[idx]}: {score:.4f}")


Top 5 predictions:
  • yeofly1: 0.1579
  • roahaw: 0.1113
  • gohman1: 0.0788
  • thbeup1: 0.0779
  • wbwwre1: 0.0661


In [42]:
print("Checkpoint keys:", ckpt.keys())
print("Output‑layer weight shape:", ckpt["model_state"]["out.weight"].shape)

Checkpoint keys: dict_keys(['epoch', 'model_state', 'optim_state', 'train_loss', 'test_loss'])
Output‑layer weight shape: torch.Size([206, 256])
