In [34]:
import os
from itertools import chain
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [35]:
# Audio framing: the recordings are cut into 5-second windows at 32 kHz.
SAMPLE_RATE = 32000
WINDOW = 5 * SAMPLE_RATE

# Input recordings and the directory that receives the per-species output.
AUDIO_PATH = Path("data/birdclef-2024/train_audio")
OUT_DIR = Path("data/google_embeddings")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Competition metadata (one row per recording).
df_meta = pd.read_csv("data/birdclef-2024/train_metadata.csv")

# Pretrained classifier plus the label table shipped in its assets folder.
model_path = "https://kaggle.com/models/google/bird-vocalization-classifier/frameworks/TensorFlow2/variations/bird-vocalization-classifier/versions/4"
model = hub.load(model_path)
model_labels_df = pd.read_csv(f"{hub.resolve(model_path)}/assets/label.csv")

In [36]:
# Competition classes in sorted order, and the reverse lookup code -> position.
index_to_label = sorted(df_meta.primary_label.unique())
label_to_index = {label: position for position, label in enumerate(index_to_label)}

# Position of every code in the model's own label table.
model_labels = {
    code: position for position, code in enumerate(model_labels_df.ebird2021)
}

# For each competition class, the matching model output index; -1 marks a
# class the model has no output for.
model_bc_indexes = [model_labels.get(label, -1) for label in index_to_label]

In [37]:
# Number of labels the model knows vs. number of competition classes
# (model_bc_indexes holds one entry per element of index_to_label).
len(model_labels), len(model_bc_indexes)

(10932, 182)

In [38]:
# Competition classes without a model output (marked with index -1).
# Built directly from the two parallel lists so the set holds plain `str`
# values rather than NumPy string scalars.
missing_birds = {
    label
    for label, model_index in zip(index_to_label, model_bc_indexes)
    if model_index == -1
}
missing_birds

{'bkrfla1', 'indrol2'}

In [39]:
# Count recordings per class once instead of recomputing it for every species.
label_counts = df_meta.primary_label.value_counts()
for species in missing_birds:
    count = label_counts[species]
    print(f"{species}: {count}")

indrol2: 35
bkrfla1: 29


In [40]:
def get_embeddings_and_logits(file):
    """Run the model over one recording in consecutive fixed-size windows.

    The first channel of the recording is resampled to ``SAMPLE_RATE`` when
    the file was stored at a different rate, split into non-overlapping
    windows of ``WINDOW`` samples (the last one zero-padded), and each window
    is passed through ``model.infer_tf``.

    Args:
        file: Path of the recording relative to ``AUDIO_PATH``.

    Returns:
        A tuple ``(embeddings, logits)`` of stacked arrays with one row per
        window. ``logits`` has one column per entry of ``model_bc_indexes``;
        columns whose index is -1 are filled with ``-inf``.

    Raises:
        ValueError: If the recording contains no samples.
    """
    waveform, sample_rate = torchaudio.load(AUDIO_PATH / file)
    if sample_rate != SAMPLE_RATE:
        # Windows are sized in samples, so the rate must match SAMPLE_RATE.
        waveform = torchaudio.functional.resample(waveform, sample_rate, SAMPLE_RATE)
    audio = waveform[0].numpy()
    if len(audio) == 0:
        raise ValueError(f"{file}: audio contains no samples")

    embeddings = []
    logits = []
    for start in range(0, len(audio), WINDOW):
        clip = audio[start : start + WINDOW]
        if len(clip) < WINDOW:
            # np.pad keeps the clip's dtype; concatenating float64 zeros
            # would silently promote the last window to float64.
            clip = np.pad(clip, (0, WINDOW - len(clip)))
        result = model.infer_tf(clip[None, :])
        embeddings.append(result[1][0].numpy())
        # Flatten the logits and append a -inf sentinel so that the -1
        # placeholders in model_bc_indexes select -inf.
        clip_logits = np.concatenate([result[0].numpy(), -np.inf], axis=None)
        logits.append(clip_logits[model_bc_indexes])
    return np.stack(embeddings), np.stack(logits)

In [41]:
# Try the extraction on a single recording before processing whole species.
example_embeddings, example_logits = get_embeddings_and_logits("asbfly/XC49755.ogg")

W0000 00:00:1716569685.508780    2067 assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert


In [42]:
# Both arrays have one row per window of the recording.
example_embeddings.shape, example_logits.shape

((11, 1280), (11, 182))

In [43]:
# (row, column) positions of the -inf placeholder logits.
np.where(np.isneginf(example_logits))

(array([ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
         8,  9,  9, 10, 10]),
 array([12, 86, 12, 86, 12, 86, 12, 86, 12, 86, 12, 86, 12, 86, 12, 86, 12,
        86, 12, 86, 12, 86]))

In [44]:
# Translate the -inf logit columns back to species codes. The columns are
# derived from the logits rather than hard-coded, so the check stays valid if
# the label set changes.
tuple(
    index_to_label[column]
    for column in np.unique(np.where(np.isneginf(example_logits))[1])
)

('bkrfla1', 'indrol2')

In [45]:
def embed_single(file_df):
    """Embed one recording and label each of its chunks.

    Args:
        file_df: Row-like object whose ``"filename"`` entry is the recording
            path handed to ``get_embeddings_and_logits``.

    Returns:
        A tuple ``(names, indices, embeddings, logits)`` of equal-length
        sequences with one element per chunk: the recording's file name,
        the chunk number, the chunk's embedding and the chunk's logits.
    """
    file = file_df["filename"]
    embeddings, logits = get_embeddings_and_logits(file)
    n_chunks = embeddings.shape[0]
    indices = range(n_chunks)
    # Path(...).name yields the final path component whatever the directory
    # depth; split("/")[1] failed on bare file names and nested folders.
    names = [Path(file).name] * n_chunks
    return names, indices, list(embeddings), list(logits)


def embed_species(species):
    """Embed every recording of one species and save the result as parquet.

    Args:
        species: Value of ``primary_label`` selecting the rows of ``df_meta``
            to process.

    Returns:
        A DataFrame with one row per chunk and the columns ``name``,
        ``chunk_5s``, ``embedding`` and ``logits``. The same frame is written
        to ``OUT_DIR / f"{species}.parquet"``.

    Raises:
        ValueError: If ``df_meta`` has no rows for ``species``.
    """
    files = df_meta[df_meta["primary_label"] == species]
    if files.empty:
        # Without this guard an unknown label fails later with an unrelated
        # unpacking error.
        raise ValueError(f"no recordings found for species {species!r}")

    tqdm.pandas()  # provides DataFrame.progress_apply
    per_file = files.progress_apply(embed_single, axis=1)

    # Transpose the per-file tuples into four columns and flatten each one
    # into a concrete list (one element per chunk).
    names, indices, embeddings, logits = (
        list(chain.from_iterable(column)) for column in zip(*per_file)
    )
    df = pd.DataFrame(
        {"name": names, "chunk_5s": indices, "embedding": embeddings, "logits": logits}
    )

    out_path = OUT_DIR / f"{species}.parquet"
    df.to_parquet(out_path, index=False)
    return df

In [55]:
# Embed all recordings of one species (also written to OUT_DIR as parquet)
# and preview the first rows.
species = "asiope1"

species_df = embed_species(species)
species_df.head()

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:11<00:00,  2.31s/it]


Unnamed: 0,name,chunk_5s,embedding,logits
0,XC194954.ogg,0,"[-0.031185264, -0.035795283, -0.054329205, 0.0...","[-9.128941, -13.177884, -13.67497, -12.337171,..."
1,XC194954.ogg,1,"[-0.08725081, -0.038439784, -0.017878823, -0.0...","[-10.217005, -11.332575, -11.622447, -9.567158..."
2,XC194954.ogg,2,"[-0.059770834, -0.0416055, -0.05187557, -0.077...","[-11.657135, -12.182461, -12.450273, -13.01046..."
3,XC397761.ogg,0,"[0.022343082, -0.010296039, -0.024152642, 0.03...","[-7.8223104, -10.272348, -12.720812, -11.36624..."
4,XC504755.ogg,0,"[-0.000396855, -0.12520109, -0.013607574, 0.13...","[-10.462232, -15.686266, -9.906326, -10.56645,..."
