How closely can we get our mel2vec to fit the pseudo-labels generated by Perch? Hopefully, quite closely...

In [1]:
import polars as pl
from pathlib import Path
import numpy as np
from gensim.models import KeyedVectors
import faiss


# we'll join the two datasets on start time with a udf
def get_start_time(timestamp, interval=5, max_time=100) -> int:
    """Return the start of the fixed-width window that contains ``timestamp``.

    Windows are half-open, ``[start, start + interval)``, and their starts are
    ``0, interval, 2 * interval, ...`` strictly below ``max_time``. With the
    defaults, 4.99 maps to 0, 5.0 maps to 5, and 99.9 maps to 95.

    Args:
        timestamp: Value to bucket, in the same units as ``interval``.
        interval: Positive integer window width.
        max_time: Exclusive upper bound on window starts (was hard-coded to 100).

    Returns:
        The window start as an ``int``, or ``-1`` if ``timestamp`` is negative,
        NaN/infinite, or lies beyond the last window.

    Raises:
        TypeError: if ``interval`` is not an integer (as ``range`` did before).
        ValueError: if ``interval`` is not positive.
    """
    import operator  # local import keeps the notebook's import cell untouched

    interval = operator.index(interval)
    if interval <= 0:
        raise ValueError(f"interval must be positive, got {interval}")
    # Floor division finds the bucket directly instead of scanning every window.
    # For +/-inf it yields NaN, which (like a NaN timestamp) fails the range check.
    start = (timestamp // interval) * interval
    if not 0 <= start < max_time:
        return -1
    return int(start)


shared_root = Path("~/shared/birdclef").expanduser()
scratch_root = Path("~/scratch/birdclef").expanduser()

# Lazily scan the MFCC soundscape frames and tag each frame with the start of
# the window it falls into (computed row-by-row by get_start_time), so the
# frames can be joined on start time with the other dataset.
_start_time_expr = (
    pl.col("timestamp")
    .map_elements(get_start_time, return_dtype=pl.Int64)
    .alias("start_time")
)
mfcc = pl.scan_parquet(f"{scratch_root}/2025/mfcc-soundscape/data").with_columns(
    _start_time_expr
)
display(mfcc.collect_schema())

Schema([('index', Int64),
        ('file', String),
        ('timestamp', Float64),
        ('mfcc', List(Float32)),
        ('part', Int64),
        ('start_time', Int64)])

In [2]:
# Tokenizer: an exact (flat) L2 faiss index over the saved centroids. Searching
# it maps an MFCC frame to the id of its nearest centroid.
_mel2vec_root = f"{scratch_root}/2025/mel2vec"
centroids = np.load(f"{_mel2vec_root}/tokenizer/centroids.npy")
index = faiss.IndexFlatL2(centroids.shape[1])
index.add(centroids)

# Word vectors for one word2vec run; `prefix` is the run's hyperparameter path.
prefix = (
    "tokenizer=tokenizer/vector_size=256/window=80"
    "/ns_exponent=0.75/sample=0.0001/epochs=100"
)
word_vectors = KeyedVectors.load(
    f"{_mel2vec_root}/word2vec/{prefix}/word2vec.wordvectors"
)
display(word_vectors.index_to_key[:10])


def mfcc_to_wv(mfcc: list) -> list:
    """Map one MFCC frame to the word vector of its nearest tokenizer centroid.

    The frame is flattened to a single query row and searched against the
    module-level faiss ``index``. The id of the closest centroid is the token
    that is looked up in the module-level ``word_vectors``.

    Args:
        mfcc: One MFCC frame. Its flattened length must equal the centroid
            dimension the index was built with.

    Returns:
        The token's word vector as a Python list of floats.

    Raises:
        KeyError: if the nearest centroid's token is not a key of the word2vec
            vocabulary (e.g. it was pruned during training).
    """
    X = np.array(mfcc).reshape(1, -1)
    _, indices = index.search(X, 1)  # shape (1, 1): id of the closest centroid
    token = int(indices[0][0])
    # Look the token up by key explicitly. ``word_vectors[int]`` silently falls
    # back to treating an unknown integer key as a *row position*, which would
    # return a different token's vector whenever this token is out of vocab.
    row = word_vectors.key_to_index.get(token)
    if row is None:
        raise KeyError(
            f"token {token} (nearest centroid) is not in the word2vec vocabulary"
        )
    return word_vectors.vectors[row].tolist()


def aggregate_mfcc(group: pl.DataFrame) -> pl.DataFrame:
    """Collapse one group of frames into a single summary row.

    The returned one-row frame holds:
      * ``file`` and ``start_time``, taken from the group's first row;
      * ``mfcc_stats``, the per-coefficient mean of the ``mfcc`` frames followed
        by their per-coefficient population (ddof=0) standard deviation;
      * ``word_vector``, the element-wise mean of the frames' word vectors.
    """

    def _as_matrix(column: str) -> np.ndarray:
        # Stack the list-valued column into a 2-D array, one row per frame.
        return np.stack(group.get_column(column).to_numpy())

    mfcc_matrix = _as_matrix("mfcc")
    wv_matrix = _as_matrix("word_vector")
    first_file = group.get_column("file").to_numpy()[0]
    first_start = group.get_column("start_time").to_numpy()[0]
    mean_then_std = (
        mfcc_matrix.mean(axis=0).tolist() + mfcc_matrix.std(axis=0).tolist()
    )
    mean_wv = wv_matrix.mean(axis=0).tolist()
    return pl.DataFrame(
        {
            "file": first_file,
            "start_time": first_start,
            "mfcc_stats": [mean_then_std],
            "word_vector": [mean_wv],
        }
    )

[6122, 13688, 1185, 9798, 4637, 10836, 6358, 9107, 10603, 12453]

In [None]:
# Attach a word vector to every MFCC frame (one faiss lookup per row), then
# summarize each (file, start_time) window with aggregate_mfcc.
_word_vector_expr = (
    pl.col("mfcc")
    .map_elements(mfcc_to_wv, return_dtype=pl.List(pl.Float64))
    .alias("word_vector")
)
_aggregated_schema = pl.Schema(
    {
        "file": pl.Utf8,
        "start_time": pl.Int64,
        "mfcc_stats": pl.List(pl.Float64),
        "word_vector": pl.List(pl.Float64),
    }
)
processed = (
    mfcc.with_columns(_word_vector_expr)
    .group_by("file", "start_time")
    .map_groups(aggregate_mfcc, schema=_aggregated_schema)
    .sort("file", "start_time")
)

# Execute the lazy query and write the result as zstd-compressed parquet.
processed.sink_parquet(
    f"{scratch_root}/2025/mel2vec/mfcc-word-vector",
    compression="zstd",
)

We moved this cell into a standalone script because it takes a while to run.