In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Point Spark's local scratch directory at a project-local folder; this is set
# before the session is requested below.
os.environ["SPARK_LOCAL_DIRS"] = "../data/tmp/spark"
spark = get_spark(cores=24, memory="30g")

# embeddings (plus the raw prediction vector) for the training tracks
embeddings_path = "../data/processed/birdclef-2023/train_embeddings/consolidated_v3"
df = spark.read.parquet(embeddings_path)
df.printSchema()

# one predicted label and probability per (track_name, start_time)
preds_path = "../data/processed/birdclef-2023/consolidated_v3_with_preds"
preds = spark.read.parquet(preds_path)
preds.printSchema()

# also include the metadata
birdclef_root = "../data/raw/birdclef-2023"
metadata_csv = f"{birdclef_root}/train_metadata.csv"
train_metadata = spark.read.csv(metadata_csv, header=True)
train_metadata.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_name: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- prediction: string (nullable = true)
 |-- probability: double (nullable = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary

In [3]:
# (index, label) pairs for the non-bird sound classes. The integer is matched
# against the position of each element in `prediction_vec` when the vector is
# exploded further down; the label string is only kept for readability.
# NOTE(review): the labels look like they come from the upstream classifier's
# label file -- confirm the indices still line up if that model changes.
noise_indices = [
    (1022, "Dog_Dog"),
    (1136, "Engine_Engine"),
    (1141, "Environmental_Environmental"),
    (1219, "Fireworks_Fireworks"),
    (1352, "Gun_Gun"),
    (1449, "Human non-vocal_Human non-vocal"),
    (1450, "Human vocal_Human vocal"),
    (1451, "Human whistle_Human whistle"),
    (1997, "Noise_Noise"),
    (2812, "Siren_Siren"),
]


def keep_top_n(df, n=250):
    """Keep the `n` most probable rows of every `index` and label them `no_call`.

    Rows are ranked inside each `index` partition by descending `probability`
    and only ranks 1..n are kept.

    Args:
        df: DataFrame with at least the columns `index`, `probability`,
            `track_name`, `start_time` and `embedding`.
        n: maximum number of rows kept per `index`.

    Returns:
        DataFrame with the columns `track_name`, `start_time`, `species`
        (always the literal "no_call") and `embedding`.
    """
    best_first_per_index = Window.partitionBy("index").orderBy(
        F.desc("probability")
    )
    ranked = df.withColumn("rank", F.row_number().over(best_first_per_index))
    top_rows = ranked.where(f"rank <= {n}")
    return top_rows.select(
        "track_name", "start_time", F.lit("no_call").alias("species"), "embedding"
    )


# positions inside `prediction_vec` that belong to the noise classes
noise_class_ids = [class_id for class_id, _ in noise_indices]

# explode the predictions with their indices: one row per (window, class)
exploded_predictions = df.select(
    "track_name",
    "start_time",
    "embedding",
    F.posexplode("prediction_vec").alias("index", "logit"),
)

exploded_noise = (
    exploded_predictions.where(F.col("index").isin(noise_class_ids))
    # the exploded value is treated as a logit and squashed with a sigmoid
    .withColumn("probability", F.expr("1/(1+exp(-logit))"))
).cache()

# most confident windows of every noise class, relabelled as `no_call`
negative_samples = keep_top_n(exploded_noise).cache()

# environmental and noise, to average with the other embeddings
background_only = exploded_noise.where(F.col("index").isin([1141, 1997]))
noise_samples = keep_top_n(background_only, n=50).cache()

In [4]:
import ast
import numpy as np


@F.udf(returnType="array<string>")
def parse_labels(label_str: str):
    """Parse a stringified Python list of labels into a list of strings.

    Args:
        label_str: text such as "['a', 'b']" or "[]"; may be None when the
            source column is NULL.

    Returns:
        The parsed list, or an empty list for NULL / blank input.
    """
    # A NULL cell reaches the UDF as None; literal_eval would raise on it (and
    # on a blank string) and take the whole Spark task down with it.
    if label_str is None or not label_str.strip():
        return []
    # use literal eval to parse the string
    return ast.literal_eval(label_str)


@F.pandas_udf("array<float>", F.PandasUDFType.GROUPED_AGG)
def embedding_mean(v):
    """Return the element-wise mean of all embeddings in a group as a list."""
    stacked = np.stack(v)  # one row per embedding in the group
    centroid = stacked.mean(axis=0)
    return centroid.tolist()


# (species, filename) pairs taken from the primary label of every recording
primary_labels = train_metadata.select(
    F.col("primary_label").alias("species"), "filename"
)
# (species, filename) pairs taken from the secondary labels, one row per label
secondary_labels = train_metadata.select(
    F.explode(parse_labels("secondary_labels")).alias("species"), "filename"
)

# recordings that carry more than one distinct label
multi_label = (
    primary_labels.union(secondary_labels)
    .distinct()
    .groupBy("filename")
    # collect_list gives no ordering guarantee after a shuffle. The array is
    # used as a grouping key below, so sort it to give every label set exactly
    # one canonical representation.
    .agg(F.sort_array(F.collect_list("species")).alias("species"))
    .where(F.size("species") > 1)
    # extract the xeno-canto id with regex
    .withColumn("xc_id", F.regexp_extract(F.col("filename"), r"XC(\d+)", 1))
)

# one averaged embedding per distinct label set, computed over the windows of
# the original (track_type = 'original') multi-label recordings
track_averaged_samples = (
    df.withColumn("xc_id", F.regexp_extract(F.col("track_name"), r"XC(\d+)", 1))
    # the single-label `species` column is replaced by the label array
    .drop("species")
    .join(multi_label, on="xc_id")
    .where("track_type = 'original'")
    .groupBy("species")
    .agg(embedding_mean("embedding").alias("embedding"))
)



In [5]:
# choose the track name that has the most predictions with the main class

# number of windows per (track, predicted label)
prediction_counts = preds.groupBy("track_name", "prediction").count()

# track_name is split on "/": item 0 is used as the species, item 1 as the file
track_parts = F.split("track_name", "/")
best_count_per_stem = Window.partitionBy("track_stem").orderBy(F.desc("count"))

most_representative_source = (
    prediction_counts.orderBy("track_name", F.desc("count"))
    .withColumn("species", track_parts.getItem(0))
    .where("track_name like '%source%'")
    # the stem is whatever precedes the first "_" of the file part
    .withColumn("track_stem", F.split(track_parts.getItem(1), "_").getItem(0))
    # only count windows where the prediction agrees with the track's species
    .where("species = prediction")
    .withColumn("rank", F.row_number().over(best_count_per_stem))
    .where("rank = 1")
    .select("track_name", "prediction", "count")
)

# windows of the chosen tracks whose prediction is the chosen label
representative_keys = most_representative_source.select("track_name", "prediction")
most_representative_embeddings = preds.join(
    representative_keys, on=["track_name", "prediction"]
)

# join against the most representative source, and only keep the embeddings
# that are actually labeled by the previous model
labeled_windows = most_representative_embeddings.select(
    "track_name", "start_time", "prediction"
)
representative_samples = (
    df.join(labeled_windows, on=["track_name", "start_time"], how="inner").select(
        "track_name", "start_time", "species", "embedding", "prediction"
    )
).cache()

In [10]:
# for each of the representative samples, generate a new sample by averaging it with a
# random noise sample


@F.pandas_udf("array<float>", F.PandasUDFType.SCALAR)
def embedding_pair_mean(v1, v2):
    """Return the row-wise average of two embedding columns."""
    summed = v1 + v2
    return summed / 2


@F.pandas_udf("array<float>", F.PandasUDFType.SCALAR)
def embedding_triple_mean(v1, v2, v3):
    """Return the row-wise average of three embedding columns."""
    summed = v1 + v2
    summed = summed + v3
    return summed / 3


# window specs used by the stages below
# NOTE(review): `prediction` is the predicted label string, so this ordering
# does not rank windows by confidence -- confirm that picking an arbitrary
# window per track is intended.
per_track = Window.partitionBy("track_name").orderBy(F.desc("prediction"))
per_species_random_desc = Window.partitionBy("species").orderBy(
    F.desc("random_rank")
)
per_species_random = Window.partitionBy("species").orderBy("random_rank")
# no partitionBy: numbers every row of the frame 1..N
global_random_desc = Window.orderBy(F.desc("random_rank"))

# generate a subset of rows: a single window per track ...
one_window_per_track = (
    representative_samples.withColumn("rank", F.row_number().over(per_track))
    .where("rank = 1")
    .drop("rank")
)

# ... and a random selection of at most 100 of those per species
track_subset = (
    one_window_per_track.withColumn("random_rank", F.rand())
    .withColumn("rank", F.row_number().over(per_species_random_desc))
    .where("rank <= 100")
    .drop("rank", "random_rank")
)

# cross join against the noise samples
paired_with_noise = track_subset.join(
    noise_samples.selectExpr("embedding as noise_embedding"), how="cross"
)

best_representative_sample_per_track_with_noise = (
    # shuffle within each species and keep at most 1000 rows per class
    paired_with_noise.withColumn("random_rank", F.rand())
    .withColumn("rank", F.row_number().over(per_species_random))
    # 250k samples in total, this is a tad much for actual training so make sure
    # to subsample
    .where("rank <= 1000")
    # blend every sample with its noise partner
    .withColumn("embedding", embedding_pair_mean("embedding", "noise_embedding"))
    # 1-based running number over the whole frame, used later as a join key
    .withColumn("index", F.row_number().over(global_random_desc))
    .select(F.array(F.col("species")).alias("species"), "embedding", "rank", "index")
).cache()

# number of distinct primary species in the metadata
n_classes = primary_labels.select("species").distinct().count()

# ~62k samples: the first 250 (by per-species rank) noise-blended rows per class
first_250_per_species = best_representative_sample_per_track_with_noise.where(
    "rank <= 250"
)
single_label_samples = first_250_per_species.select("species", "embedding")

# rough proportion of single-label, pair-label and triple-label samples:
# 1, 1/3, 1/10

# `index` in the candidate pool is a 1-based row_number(), so the partner has to
# be drawn from [1, number of rows]. Drawing from a guessed 0-based range makes
# the inner join silently drop every row whose draw has no matching `index`.
n_pair_candidates = best_representative_sample_per_track_with_noise.count()

pair_label_samples = (
    # we generate up to 100 per class for about 25k samples
    best_representative_sample_per_track_with_noise.withColumn("random_rank", F.rand())
    # rand() is uniform on [0, 1), so floor(rand * N) + 1 is uniform on 1..N
    .withColumn(
        "random_index", (F.floor(F.rand() * n_pair_candidates) + 1).cast("int")
    )
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(F.desc("random_rank"))
        ),
    )
    .where("rank <= 100")
    # look up the randomly chosen partner row by its running number
    .join(
        best_representative_sample_per_track_with_noise.selectExpr(
            "species as other_species",
            "embedding as other_embedding",
            "index as random_index",
        ),
        on="random_index",
    )
    .select(
        # array_union de-duplicates, so a same-species partner yields one label
        F.array_union(F.col("species"), F.col("other_species")).alias("species"),
        embedding_pair_mean("embedding", "other_embedding").alias("embedding"),
    )
)

# NOTE: this could surely be written better, but eh
# `index` in the candidate pool is a 1-based row_number(), so both partners are
# drawn from [1, number of rows]. A draw without a matching `index` would make
# the inner joins silently drop the row.
n_triple_candidates = best_representative_sample_per_track_with_noise.count()

triple_label_samples = (
    # we generate up to 30 per class for about 7.5k samples
    best_representative_sample_per_track_with_noise.withColumn("random_rank", F.rand())
    # rand() is uniform on [0, 1), so floor(rand * N) + 1 is uniform on 1..N
    .withColumn(
        "random_index_1", (F.floor(F.rand() * n_triple_candidates) + 1).cast("int")
    )
    .withColumn(
        "random_index_2", (F.floor(F.rand() * n_triple_candidates) + 1).cast("int")
    )
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(F.desc("random_rank"))
        ),
    )
    .where("rank <= 30")
    # first randomly chosen partner
    .join(
        best_representative_sample_per_track_with_noise.selectExpr(
            "species as species_1",
            "embedding as embedding_1",
            "index as random_index_1",
        ),
        on="random_index_1",
    )
    # second randomly chosen partner
    .join(
        best_representative_sample_per_track_with_noise.selectExpr(
            "species as species_2",
            "embedding as embedding_2",
            "index as random_index_2",
        ),
        on="random_index_2",
    )
    .select(
        # array_union de-duplicates, so repeated species collapse to one label
        F.array_union(
            F.array_union(F.col("species"), F.col("species_1")), F.col("species_2")
        ).alias("species"),
        embedding_triple_mean("embedding", "embedding_1", "embedding_2").alias(
            "embedding"
        ),
    )
)

In [12]:
# single-label frames carry `species` as a string; wrap it so every part has
# the same (array species, embedding) layout before the positional union
as_label_array = F.array("species").alias("species")

train_parts = [
    track_averaged_samples,
    representative_samples.select(as_label_array, "embedding"),
    negative_samples.select(as_label_array, "embedding"),
    single_label_samples,
    pair_label_samples,
    triple_label_samples,
]

train_df = train_parts[0]
for part in train_parts[1:]:
    train_df = train_df.union(part)
train_df = train_df.select("species", "embedding")

# let's write this to a parquet file, because we'll probably want to load this
# a few times, plus its good to version
postprocessed_path = "../data/processed/birdclef-2023/train_postprocessed/v1"
train_df.write.parquet(postprocessed_path, mode="overwrite")

In [15]:
# read the written dataset back and take a quick look at it
postprocessed_path = "../data/processed/birdclef-2023/train_postprocessed/v1"
train_df = spark.read.parquet(postprocessed_path)
train_df.show(n=5)
train_df.count()

+---------+--------------------+
|  species|           embedding|
+---------+--------------------+
|[gobsta5]|[0.61802852153778...|
|[chespa1]|[0.61116945743560...|
|[golher1]|[0.90223264694213...|
|[marsto1]|[0.61373400688171...|
|[gobwea1]|[0.62747323513031...|
+---------+--------------------+
only showing top 5 rows



239569