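In this notebook we parameterize (train) a word2vec model on audio tokens: each MFCC frame is replaced by the id of its nearest tokenizer centroid, and each file's token sequence is treated as a sentence.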

In [None]:
from pathlib import Path

scratch = Path("~/scratch/birdclef/2025").expanduser()
! tree {scratch}/mel2vec

[01;34m/storage/home/hcoda1/8/amiyaguchi3/scratch/birdclef/2025/mel2vec[0m
├── [01;34mtokenizer[0m
│   └── centroids.npy
└── [01;34mtokenizer_pca[0m
    ├── centroids.npy
    └── pca.bin

2 directories, 3 files


In [2]:
import faiss
import numpy as np
import polars as pl

# Tokenizer = a matrix of centroids, one row per token id (presumably k-means
# cluster centres fitted in an earlier notebook — TODO confirm). An exact
# (brute-force) L2 index over them maps any MFCC frame to its nearest centroid.
centroids = np.load(scratch / "mel2vec" / "tokenizer" / "centroids.npy")
index = faiss.IndexFlatL2(centroids.shape[1])
index.add(centroids)

In [9]:
# Load the MFCC frames for the word2vec training split (part < 80, i.e. the
# 80% of the data reserved for learning word2vec), ordered by file and time.
df = (
    pl.scan_parquet(f"{scratch}/mfcc-soundscape/data")
    .filter(pl.col("part") < 80)
    .sort("file", "timestamp")
)

# Materialize the plan exactly once. Attaching an eager Series to the lazy plan
# would re-run the scan + sort on every later collect(), and the token ids would
# only line up if that re-run returned rows in the same order (not guaranteed
# for rows tied on file/timestamp).
frames = df.collect()
X = np.stack(frames.get_column("mfcc").to_numpy())

# Nearest-centroid lookup: column 0 of `indices` is each frame's token id.
distances, indices = index.search(X, 1)
ids = pl.Series("token", indices.flatten())

# Keep token_df lazy so downstream code that calls .collect() on it still works;
# it now scans the in-memory frame rather than re-reading the parquet files.
token_df = frames.with_columns(ids).lazy()
token_df

In [19]:
from tqdm.auto import tqdm
from gensim.models.callbacks import CallbackAny2Vec


class TqdmCallback(CallbackAny2Vec):
    """Gensim training callback that shows epoch progress and loss in a tqdm bar.

    gensim's ``get_latest_training_loss()`` returns the loss accumulated since
    the current ``train()`` call started, not the loss of the last epoch, so
    this callback displays the per-epoch difference. Loss values are only
    meaningful when the model is trained with ``compute_loss=True``.

    Args:
        total_epochs: number of epochs expected per ``train()`` call; used as
            the progress-bar total.
    """

    def __init__(self, total_epochs):
        self.total_epochs = total_epochs
        self.epoch_count = 0
        self.pbar = None
        # Cumulative loss observed at the end of the previous epoch.
        self.previous_loss = 0.0

    def _close_pbar(self):
        """Close and forget the progress bar, if one is open."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None

    def on_train_begin(self, model):
        # Reset per-run state so the same instance can be reused for another
        # train() call, including after an interrupted run that never reached
        # on_train_end and left a stale bar behind.
        self._close_pbar()
        self.epoch_count = 0
        self.previous_loss = model.get_latest_training_loss()

    def on_epoch_begin(self, model):
        if self.pbar is None:
            self.pbar = tqdm(total=self.total_epochs, desc="Epochs")
        self.epoch_count += 1
        self.pbar.set_description(
            f"Training Epoch {self.epoch_count}/{self.total_epochs}"
        )

    def on_epoch_end(self, model):
        if self.pbar is None:
            return
        self.pbar.update(1)
        # Convert gensim's running total into this epoch's contribution.
        cumulative_loss = model.get_latest_training_loss()
        epoch_loss = cumulative_loss - self.previous_loss
        self.previous_loss = cumulative_loss
        self.pbar.set_postfix_str(f"Loss: {epoch_loss:.4f}", refresh=True)

    def on_train_end(self, model):
        self._close_pbar()

In [None]:
from gensim.models import Word2Vec


def token_generator(df, limit=-1):
    """Yield each audio file's token sequence as one word2vec "sentence".

    Args:
        df: polars LazyFrame with ``file``, ``timestamp``, ``token`` and
            ``part`` columns; it is collected once, on the first ``next()``.
        limit: when positive, only rows with ``part < limit`` are used;
            otherwise every row is kept.

    Yields:
        list: the ``token`` values of one file, ordered by ``timestamp``.
    """
    subset = df.filter(pl.col("part") < limit) if limit > 0 else df
    for file_rows in subset.collect().partition_by("file"):
        ordered_rows = file_rows.sort("timestamp")
        yield ordered_rows.get_column("token").to_list()


# Skip-gram word2vec over per-file token sequences. Only parts < 10 of the
# loaded split are used here (a subset, for a faster run).
# Single source of truth for the epoch count, shared with the progress bar.
EPOCHS = 5
model = Word2Vec(
    sentences=list(token_generator(token_df, limit=10)),
    vector_size=128,
    min_count=1,  # keep every token id, however rare
    # Max distance (in tokens) between target and context. At 8 frames per
    # second, 40 tokens ~ 5 s and 80 tokens ~ 10 s of audio on each side;
    # shrink_windows=True samples a smaller effective window per target.
    window=80,
    sg=1,
    negative=10,
    ns_exponent=0.75,
    sample=1e-3,
    workers=8,
    compute_loss=True,  # required for the loss shown by TqdmCallback
    shrink_windows=True,
    epochs=EPOCHS,
    callbacks=[TqdmCallback(total_epochs=EPOCHS)],
)

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]