In [None]:
# Reload imported project modules automatically before each cell runs, so edits to
# the local package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with Black (JupyterLab extension).
%load_ext lab_black

In [None]:
from birdclef.utils import get_spark

# Spark session handle from the project helper; used by every cell below.
# NOTE(review): "2g" is presumably the driver memory — confirm in birdclef.utils.get_spark.
spark = get_spark(memory="2g")

In [None]:
# Parquet dataset to explore; judging by the path it holds BirdNET embeddings
# together with precomputed neighbor lists.
embeddings_path = "../data/processed/birdclef-2022/birdnet-embeddings-with-neighbors/v1"
df = spark.read.parquet(embeddings_path)

# Inspect the schema and a single record (vertical layout keeps wide rows readable).
df.printSchema()
df.show(n=1, vertical=True)

In [None]:
# Number of rows per primary_label, most frequent labels first.
label_counts = df.groupby("primary_label").count()
label_counts.sort("count", ascending=False).show(10)

In [None]:
import pyspark.sql.functions as F
import matplotlib.pyplot as plt

# 1 when birdnet_label equals primary_label for a row, 0 otherwise.
is_match = F.expr("birdnet_label = primary_label").cast("integer")

# Per primary_label, the fraction of rows where the two labels agree.
mean_matching = (
    df.select(is_match.alias("matching_labels"), "primary_label")
    .groupby("primary_label")
    .agg(F.avg("matching_labels").alias("matched"))
    .orderBy("matched", ascending=False)
)

mean_matching.show(10)

# One row per label, so collecting to pandas for plotting is cheap.
pdf = mean_matching.toPandas()
ax = plt.gca()
ax.hist(pdf["matched"], bins=20)
ax.set_title("Histogram of mean of matching labels between birdnet and xeno-canto")
ax.set_xlabel("Mean of matching labels")
ax.set_ylabel("Count")
plt.show()

In [None]:
# Pair each neighbor id with its distance, then emit one row per (id, neighbor).
neighbor_pairs = F.explode(F.arrays_zip("neighbors", "distances")).alias("neighbor")

# Flatten the zipped struct so the result has columns: id, neighbors, distances.
exploded_neighborhood = df.select("id", neighbor_pairs).select("id", "neighbor.*")

exploded_neighborhood.show(10)

In [None]:
# Ids are sorted by filename (which includes species and file) and then by time,
# so the gap between a row's id and its neighbor's id tells us how far apart the
# two entries sit in the dataset ordering.

# log(|gap| + 1): the +1 keeps a zero gap finite, the log compresses the range.
log_id_gap = F.expr("log(abs(neighbors - id)+1)").alias("diff")
diffs = exploded_neighborhood.select(log_id_gap)
pdf = diffs.toPandas()

ax = plt.gca()
ax.hist(pdf["diff"], bins=20)
ax.set_title("Histogram of log differences between neighbors and id")
ax.set_xlabel("Difference")
ax.set_ylabel("Count")
plt.show()

We take the log of the id differences because the raw values are otherwise heavily skewed to one side.
It's interesting to see that there is mass in the long tail, where neighbors come from some distant cluster.
This means that there are definitely related species (or sounds) in the dataset.

One improvement that we can make is to sort species by projecting all the entries onto a single dimension (a line) and to find differences on that line.
Actually, let's go ahead and try out that idea.

In [None]:
import numpy as np
import umap

X = np.stack(df.select("emb").toPandas().emb)
reducer = umap.UMAP(n_components=1, verbose=True)
emb = reducer.fit_transform(X)

In [None]:
# what does this look like? a histogram?
# a histogram probably has highest density around clusters

plt.hist(emb, bins=20)
plt.title("Histogram of UMAP embedding")
plt.xlabel("UMAP embedding")
plt.ylabel("Count")
plt.show()

# what about a line plot? This will show the difference in ordering between elements.
# I'd expect
plt.scatter(np.arange(len(emb)), emb, s=1, alpha=0.1)
plt.title("Line plot of UMAP embedding")
plt.xlabel("Index")
plt.ylabel("UMAP embedding")

In [None]:
# now create an index that's based on the ordering of the embedding
import pandas as pd

index_df = (
    pd.DataFrame({"line_emb": emb[:, 0], "id": df.select("id").toPandas().id})
    .sort_values("line_emb")
    .reset_index()
    .drop("index", axis=1)
    .reset_index()
    .rename(columns={"index": "index_id"})
)
index_df

In [None]:
# Move the pandas lookup table built in index_df (it carries both id and index_id)
# into Spark so it can be joined against the exploded neighbor table.
mapping = spark.createDataFrame(index_df)
mapping.show(10)

In [None]:
# Lookup giving the index_id of the query row, exposed as mapped_id.
query_lookup = mapping.select("id", F.col("index_id").alias("mapped_id"))

# The same lookup keyed by the neighbor's id, exposed as mapped_neighbors.
neighbor_lookup = mapping.select(
    F.col("id").alias("neighbors"), F.col("index_id").alias("mapped_neighbors")
)

# Attach both positions to every (id, neighbor) pair.
mapped_neighbors = exploded_neighborhood.join(query_lookup, on="id").join(
    neighbor_lookup, on="neighbors"
)
mapped_neighbors.show()

In [None]:
# Same neighbor-gap histogram as the raw-id one earlier in the notebook, but the gap
# is measured between mapped_id and mapped_neighbors (the index_id values from
# `mapping`) instead of between raw ids.
diffs = mapped_neighbors.select(
    F.expr("log(abs(mapped_neighbors - mapped_id)+1)").alias("diff")
)
pdf = diffs.toPandas()

# Title and x-label name the mapped quantities so this figure cannot be confused
# with the raw-id histogram, which previously carried the identical title.
fig, ax = plt.subplots()
ax.hist(pdf["diff"], bins=20)
ax.set_title("Histogram of log differences between mapped neighbors and mapped id")
ax.set_xlabel("log(|mapped_neighbors - mapped_id| + 1)")
ax.set_ylabel("Count")
plt.show()

In [None]:
# Distribution of the neighbor distances themselves (log-scaled), for comparison
# with the index-gap histograms.
diffs = exploded_neighborhood.select(F.expr("log(distances+1)").alias("diff"))
pdf = diffs.toPandas()

# The x-axis shows log(distance + 1), not an id difference, so label it as such.
fig, ax = plt.subplots()
ax.hist(pdf["diff"], bins=20)
ax.set_title("Histogram of log distances between neighbors and id")
ax.set_xlabel("log(distance + 1)")
ax.set_ylabel("Count")
plt.show()