In [1]:
# Reload edited project modules automatically so changes are picked up live.
%load_ext autoreload
%autoreload 2
# Load the lab_black extension (presumably formats code cells with black).
%load_ext lab_black

In [2]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Point Spark's scratch space at a local drive. This has to be set before the
# session is created below.
# NOTE(review): "h://spark-tmp/" is a machine-specific path; adjust it when
# running on another machine.
os.environ["SPARK_LOCAL_DIRS"] = "h://spark-tmp/"

# Session sized for this workstation (24 cores, 40g of memory).
spark = get_spark(cores=24, memory="40g")

# Embedding rows (see the schema printed below for the available columns).
df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_embeddings/consolidated_v4"
)
df.printSchema()

# One prediction/probability per (track_name, start_time). The frame is
# repartitioned and cached because `process` joins against it.
preds = (
    spark.read.parquet("../data/processed/birdclef-2023/consolidated_v4_with_preds")
    .repartition(32)
    .cache()
)
preds.printSchema()

# also include the metadata
# The CSV is read with a header but without a schema, so every column is a
# string (the list-valued `secondary_labels` column is parsed later).
birdclef_root = "../data/raw/birdclef-2023"
train_metadata = spark.read.csv(f"{birdclef_root}/train_metadata.csv", header=True)
train_metadata.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_name: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- prediction: string (nullable = true)
 |-- probability: double (nullable = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary

In [3]:
# Peek at two embedding rows, one field per line, cutting long values at 80 chars.
df.show(n=2, vertical=True, truncate=80)

-RECORD 0------------------------------------------------------------------------------------------
 species        | grecor                                                                           
 track_stem     | XC629875_part003                                                                 
 track_type     | source0                                                                          
 track_name     | grecor/XC629875_part003_source0.mp3                                              
 embedding      | [0.6731137633323669, 1.1389738321304321, 0.6284520626068115, 0.65399438142776... 
 prediction_vec | [-8.725186347961426, -7.3204827308654785, -9.82101821899414, -10.396224021911... 
 predictions    | [{0, 3026, Tetrastes bonasia_Hazel Grouse, hazgro1, 0.02235649898648262}, {1,... 
 start_time     | 75                                                                               
 energy         | 0.01598571054637432                                                              


In [6]:
import ast
import numpy as np


@F.udf(returnType="array<string>")
def parse_labels(label_str: str):
    """Parse the textual form of a Python list of labels into an array of strings.

    Parameters
    ----------
    label_str : str
        A Python list literal stored as text, e.g. ``"['a', 'b']"``.

    Returns
    -------
    list
        The parsed labels. A null or blank cell yields an empty list, so that
        an ``array_union`` with the result keeps its other operand instead of
        the whole job failing on one missing value.
    """
    # Nothing to evaluate for a missing/blank cell; literal_eval would raise.
    if label_str is None or not label_str.strip():
        return []
    # literal_eval only accepts Python literals, so no arbitrary code can run.
    return ast.literal_eval(label_str)


@F.udf("array<float>")
def embedding_mean(v):
    """Average a collection of equal-length vectors element-wise.

    Parameters
    ----------
    v : list of list of float
        The vectors to average (for example the result of ``collect_set`` on
        an embedding column).

    Returns
    -------
    list of float or None
        The element-wise mean, or ``None`` (a null cell) when ``v`` is null or
        empty. ``np.vstack`` raises on an empty sequence, which would abort
        the Spark job for a group that collected no embeddings.
    """
    if not v:
        return None
    # Stack to a (n_vectors, dim) matrix and average over the vectors.
    return np.vstack(v).mean(axis=0).tolist()


def xc_id(track_name):
    """Build a column expression with the digits that follow "XC" in a column.

    ``track_name`` is the column (or column name) to search; the result is
    capture group 1 of the pattern, i.e. the numeric part only, without the
    "XC" prefix.
    """
    digits_after_prefix = r"XC(\d+)"
    capture_group = 1
    return F.regexp_extract(track_name, digits_after_prefix, capture_group)


def process(df, train_metadata, predictions=None):
    """Build per-sequence training rows with embeddings and fuzzy species labels.

    Embedding rows are joined to their predictions, bucketed into sequences
    with ``seq_id = int(start_time / 5)``, and combined with the labels found
    in the metadata for the same recording id.

    Parameters
    ----------
    df : DataFrame
        Embedding rows. Columns read here: ``track_name``, ``track_stem``,
        ``track_type``, ``start_time`` and ``embedding``.
    train_metadata : DataFrame
        Metadata rows. Columns read here: ``primary_label``,
        ``secondary_labels`` and ``filename``.
    predictions : DataFrame, optional
        Rows with ``track_name``, ``start_time``, ``prediction`` and
        ``probability``. When omitted, the notebook-level ``preds`` frame is
        used, which is what this function always read.

    Returns
    -------
    DataFrame
        Columns ``track_stem``, ``start_time``, ``primary_label``,
        ``metadata_species``, ``predicted_species``,
        ``predicted_species_prob``, ``embedding``, ``next_embedding`` and
        ``track_embedding``. All joins are inner joins, so a sequence is only
        returned when it has an ``original`` embedding, at least one
        prediction other than ``no_call`` with probability above 0.7, and a
        matching recording id in the metadata.

    Notes
    -----
    The intermediate joined frame is marked for caching and is not
    unpersisted here.
    """
    if predictions is None:
        # Backward-compatible default: fall back to the notebook global.
        predictions = preds

    # metadata_species = primary label followed by any secondary labels
    labels = train_metadata.select(
        "primary_label",
        F.array_union(F.array("primary_label"), parse_labels("secondary_labels")).alias(
            "metadata_species"
        ),
        xc_id("filename").alias("xc_id"),
    )

    align_to_window = (
        df.where("start_time % 5 = 0 or start_time % 2 = 0")
        .join(predictions, on=["track_name", "start_time"])
        .withColumn("seq_id", (F.col("start_time") / 5).cast("int"))
        .select(
            "track_stem",
            # Kept in the projection because the sequence embedding below
            # filters on it; this way the filter resolves directly against
            # this (cached) frame.
            "track_type",
            "start_time",
            "seq_id",
            "embedding",
            "prediction",
            "probability",
        )
    ).cache()

    # compute the current and next embedding
    sequence_embedding = (
        align_to_window.where("track_type = 'original'")
        .groupBy("track_stem", "seq_id")
        .agg(
            F.collect_set("embedding").alias("embedding"),
            F.min("start_time").alias("start_time"),
        )
        .withColumn("embedding", embedding_mean("embedding"))
        # also include the next sequence, or the current sequence when this
        # is the last one of the track
        .withColumn(
            "next_embedding",
            F.coalesce(
                F.lead("embedding", 1).over(
                    Window.partitionBy("track_stem").orderBy("seq_id")
                ),
                "embedding",
            ),
        )
        .select("track_stem", "seq_id", "start_time", "embedding", "next_embedding")
    )

    # compute track embedding (global context)
    track_embedding = (
        # this is the average embedding of all of the sequences
        sequence_embedding.groupBy("track_stem")
        .agg(F.collect_set("embedding").alias("embedding"))
        .select("track_stem", embedding_mean("embedding").alias("track_embedding"))
    )

    # let's figure out what predictions there are at every sequence id
    mixit_predictions = (
        align_to_window.where("prediction != 'no_call' and probability > 0.7")
        .groupBy("track_stem", "seq_id")
        .agg(
            # structs sort by their first field, so this orders the
            # predictions by descending probability
            F.sort_array(
                F.collect_list(F.struct("probability", "prediction")), asc=False
            ).alias("values"),
        )
        .selectExpr(
            "track_stem",
            "seq_id",
            "values.prediction as predicted_species",
            "values.probability as predicted_species_prob",
        )
    )

    track_with_fuzzy_annotations = (
        sequence_embedding.join(track_embedding, on="track_stem")
        .join(mixit_predictions, on=["track_stem", "seq_id"])
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .select(
            "track_stem",
            "start_time",
            "primary_label",
            "metadata_species",
            "predicted_species",
            "predicted_species_prob",
            "embedding",
            "next_embedding",
            "track_embedding",
        )
    )
    return track_with_fuzzy_annotations


# Try the pipeline on a single species first, then inspect the resulting
# schema and a couple of rows.
subset = df.where("species = 'bltapa1'").cache()
processed_sample = process(subset, train_metadata)
processed_sample.printSchema()
processed_sample.show(n=2, vertical=True, truncate=80)

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = false)
 |    |-- element: string (containsNull = true)
 |-- predicted_species_prob: array (nullable = false)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------------
 track_stem             | XC113284                                                                         
 start_time             | 0                                                          

In [7]:
# Run the pipeline on every embedding row and persist the result as an
# intermediate artifact (any existing output at this path is replaced).
processed_df = process(df, train_metadata)
processed_df.write.parquet(
    "../data/intermediate/train_postprocessed_v4_00", mode="overwrite"
)

In [18]:
# Re-read the persisted result so later cells work from the parquet files
# instead of re-running the pipeline, then inspect it and count the rows.
processed_df = spark.read.parquet("../data/intermediate/train_postprocessed_v4_00")
processed_df.printSchema()
processed_df.show(n=2, vertical=True, truncate=80)
processed_df.count()

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species_prob: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------------
 track_stem             | XC109029                                                                         
 start_time             | 15                                                           

135324

In [20]:
@F.pandas_udf("array<float>", F.PandasUDFType.SCALAR)
def embedding_pair_mean(v1, v2, weight=0.5):
    """Blend two embedding columns element-wise as a weighted mean.

    Computes ``v1 * weight + v2 * (1 - weight)`` row by row.

    Parameters
    ----------
    v1, v2 : pandas.Series
        Series of vectors to blend (assumed to hold equal-length numeric
        arrays in each row; confirm against the caller).
    weight : pandas.Series or float, default 0.5
        Share given to ``v1``; ``v2`` receives the remaining ``1 - weight``.

    Returns
    -------
    pandas.Series
        The blended vectors.
    """
    # The two weights already sum to one, so the blend needs no further
    # normalisation; dividing by two on top of this would halve every value.
    return v1 * weight + v2 * (1 - weight)


# Seed for the random subsampling below, so that it can be repeated. Spark
# documents rand() as non-deterministic in general, so this is repeatable only
# as long as the input data and its partitioning stay the same.
AUGMENT_SEED = 42

# Candidate "noise" embeddings: windows whose first prediction (presumably the
# top-ranked one) has a probability below 0.01.
noise_examples = (
    df.withColumn("probability", F.col("predictions")[0]["probability"])
    .where("start_time % 3 = 0 and probability < 0.01")
    .orderBy("probability")
    .limit(1000)
    # take the best example for each track here, i.e. the lowest probability
    # among the rows that survived the limit above
    .withColumn(
        "rank",
        F.row_number().over(Window.partitionBy("track_name").orderBy("probability")),
    )
    .where("rank = 1")
    .drop("rank")
    .select("embedding")
)

# we choose this value because we want to give more weight to the original
# example.
weight = F.lit(0.9)
augmented_examples = (
    # the first entry of metadata_species is used as the species key
    processed_df.withColumn("species", F.col("metadata_species")[0])
    # sample up to 20 rows per species
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(F.rand(seed=AUGMENT_SEED))
        ),
    )
    .where("rank <= 20")
    # now lets cross join with the noise examples
    .crossJoin(noise_examples.selectExpr("embedding as noise_embedding").limit(20))
    # now randomly keep a subset of these examples (again up to 20 per
    # species); a different seed keeps the two draws from being the same
    # sequence
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(F.rand(seed=AUGMENT_SEED + 1))
        ),
    )
    .where("rank <= 20")
    .select(
        "track_stem",
        "start_time",
        "primary_label",
        "metadata_species",
        "predicted_species",
        "predicted_species_prob",
        embedding_pair_mean("embedding", "noise_embedding", weight).alias("embedding"),
        embedding_pair_mean("next_embedding", "noise_embedding", weight).alias(
            "next_embedding"
        ),
        embedding_pair_mean("track_embedding", "noise_embedding", weight).alias(
            "track_embedding"
        ),
    )
)

# Match columns by name rather than by position: the three embedding columns
# share one type, so a positional union would silently swap them if the column
# order of either side ever changed.
processed_with_augmented = augmented_examples.unionByName(processed_df)

In [21]:
# Persist the original + augmented rows (any existing output is replaced).
processed_with_augmented.write.parquet(
    "../data/intermediate/train_postprocessed_v4_01", mode="overwrite"
)

In [22]:
# Re-read the persisted rows so later cells work from the parquet files
# instead of recomputing the augmentation.
processed_with_augmented = spark.read.parquet(
    "../data/intermediate/train_postprocessed_v4_01"
)

In [26]:
# Filtering step: for every primary label, gather the distinct species that
# the metadata lists on any of its rows. This is the set of species that are
# allowed to overlap with that label.
coocurring = (
    processed_df.select(
        F.col("primary_label"),
        F.explode(F.col("metadata_species")).alias("exploded_species"),
    )
    .groupBy(F.col("primary_label"))
    .agg(F.collect_set(F.col("exploded_species")).alias("coocurring"))
)

coocurring.show(truncate=100)

+-------------+----------------------------------------------------------------------------------------------------+
|primary_label|                                                                                          coocurring|
+-------------+----------------------------------------------------------------------------------------------------+
|      abethr1|                                                         [rbsrob1, helgui, eswdov1, rindov, abethr1]|
|      abhori1|[amesun2, afghor1, hadibi1, rbsrob1, norfis1, crheag1, vilwea1, spemou2, somgre1, abhori1, egygoo...|
|      abythr1|                    [strsee1, abhori1, grbcam1, afrgos1, afdfly1, bswdov1, combul2, rindov, abythr1]|
|      afbfly1|                            [piecro1, trobou1, afrgrp1, grbcam1, yertin1, combul2, afbfly1, scrcha1]|
|      afdfly1|                            [yewgre1, amesun2, reftin1, afdfly1, yertin1, combul2, chtapa3, tamdov1]|
|      afecuc1|[grbcam1, bawhor2, yertin1, klacuc1, didcuc1, taf

In [32]:
filtered_train_df = (
    processed_with_augmented.join(coocurring, on=["primary_label"])
    # keep only the predicted species that may co-occur with the primary label
    .withColumn(
        "predicted_species",
        F.array_intersect(F.col("predicted_species"), F.col("coocurring")),
    )
    # when nothing is left, fall back to the primary label alone
    .withColumn(
        "predicted_species",
        F.when(
            F.size(F.col("predicted_species")) == 0,
            F.array(F.col("primary_label")),
        ).otherwise(F.col("predicted_species")),
    )
    # helper and probability columns are not part of the final dataset
    .drop("coocurring", "predicted_species_prob")
)
filtered_train_df.printSchema()
filtered_train_df.show(n=2, vertical=True, truncate=80)

root
 |-- primary_label: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0---------------------------------------------------------------------------------------------
 primary_label     | ccbeat1                                                                          
 track_stem        | XC120632                                                                         
 start_time        | 5                                                                         

DataFrame[track_stem: string, start_time: bigint, predicted_species: array<string>, metadata_species: array<string>, predict_count: bigint]

In [38]:
# Write the filtered training set to its final location (any existing output
# at this path is replaced).
filtered_train_df.write.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v4", mode="overwrite"
)

In [39]:
# Re-read the final dataset from disk and count its rows as a sanity check.
filtered_train_df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v4"
)
filtered_train_df.count()

140604

In [46]:
# Show the track(s) with the largest total number of predicted species,
# one row per sequence, in time order.
example = (
    filtered_train_df.withColumn(
        "predict_count",
        # total number of predicted species over all rows of the same track
        F.sum(F.size(F.col("predicted_species"))).over(
            Window.partitionBy("track_stem")
        ),
    )
    # rank() gives every row tied for the highest count a rank of 1
    .withColumn(
        "rank",
        F.rank().over(Window.orderBy(F.col("predict_count").desc())),
    )
    .filter("rank = 1")
    .orderBy(F.col("start_time"))
    .select(
        F.col("track_stem"),
        F.col("start_time"),
        F.col("predicted_species"),
        F.col("metadata_species"),
        F.col("predict_count"),
    )
)
example.show(n=100)

+----------+----------+--------------------+------------------+-------------+
|track_stem|start_time|   predicted_species|  metadata_species|predict_count|
+----------+----------+--------------------+------------------+-------------+
|  XC462100|         0|  [reccuc1, grbcam1]|[grbcam1, reccuc1]|           65|
|  XC462100|         5|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        10|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        15|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        20|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        25|           [grbcam1]|[grbcam1, reccuc1]|           65|
|  XC462100|        30|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        35|  [grbcam1, reccuc1]|[grbcam1, reccuc1]|           65|
|  XC462100|        40|  [reccuc1, grbcam1]|[grbcam1, reccuc1]|           65|
|  XC462100|        45|  [grbcam1, abhori1]|[grbcam1, reccuc1]| 