In [1]:
# Pick up edits to imported project modules without restarting the kernel.
%load_ext autoreload
%autoreload 2
# lab_black extension: black code formatting for notebook cells.
%load_ext lab_black

In [1]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Spark scratch directory. A SPARK_LOCAL_DIRS value that is already exported
# wins; the original machine-specific location is only the fallback, so the
# notebook no longer forces a drive that may not exist on another machine.
os.environ.setdefault("SPARK_LOCAL_DIRS", "h://spark-tmp/")

spark = get_spark(cores=24, memory="40g")

# Consolidated embedding frame (schema printed below).
df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_embeddings/consolidated_v4"
)
df.printSchema()

# Prediction frame; repartitioned and cached because later cells reuse it.
preds = (
    spark.read.parquet("../data/processed/birdclef-2023/consolidated_v4_with_preds")
    .repartition(32)
    .cache()
)
preds.printSchema()

# also include the metadata
birdclef_root = "../data/raw/birdclef-2023"
train_metadata = spark.read.csv(f"{birdclef_root}/train_metadata.csv", header=True)
train_metadata.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_name: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- prediction: string (nullable = true)
 |-- probability: double (nullable = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary

In [2]:
# Peek at two embedding rows, one field per line, long values cut at 80 characters.
df.show(n=2, vertical=True, truncate=80)

-RECORD 0------------------------------------------------------------------------------------------
 species        | grecor                                                                           
 track_stem     | XC629875_part003                                                                 
 track_type     | source0                                                                          
 track_name     | grecor/XC629875_part003_source0.mp3                                              
 embedding      | [0.6731137633323669, 1.1389738321304321, 0.6284520626068115, 0.65399438142776... 
 prediction_vec | [-8.725186347961426, -7.3204827308654785, -9.82101821899414, -10.396224021911... 
 predictions    | [{0, 3026, Tetrastes bonasia_Hazel Grouse, hazgro1, 0.02235649898648262}, {1,... 
 start_time     | 75                                                                               
 energy         | 0.01598571054637432                                                              


In [3]:
# Metadata rows per primary label, smallest counts first, keeping only labels
# with more than 10 rows.
train_metadata.groupBy("primary_label").count().orderBy("count").where(
    "count > 10"
).show(n=10)

+-------------+-----+
|primary_label|count|
+-------------+-----+
|      bltapa1|   11|
|      yenspu1|   12|
|      refbar2|   12|
|      grccra1|   12|
|      augbuz1|   12|
|      spfwea1|   12|
|      blksaw1|   13|
|      ccbeat1|   13|
|      yeccan1|   13|
|      spfbar1|   13|
+-------------+-----+
only showing top 10 rows



In [4]:
# Single-species slice ('bltapa1') of the embedding frame; cached because it is
# counted here and fed to process() in a later cell.
subset = df.where("species = 'bltapa1'").cache()
# Number of rows in the slice.
subset.count()

1540

In [5]:
# Row count per predicted label, most frequent first.
preds.groupBy("prediction").count().orderBy(F.desc("count")).show()

+----------+------+
|prediction| count|
+----------+------+
|   thrnig1|367509|
|   no_call|346374|
|   combuz1|192776|
|    wlwwar|140779|
|    barswa|133200|
|   eubeat1|111025|
|    comsan|110426|
|    hoopoe|110322|
|   cohmar1|106396|
|   eaywag1| 93031|
|    woosan| 60836|
|   blakit1| 48823|
|   combul2| 45950|
|   colsun2| 44934|
|   rbsrob1| 43741|
|    litegr| 42906|
|   gnbcam2| 42606|
|   yertin1| 36798|
|   fotdro5| 36190|
|   afecuc1| 35527|
+----------+------+
only showing top 20 rows



First, we only consider rows that match every first and third index.
Then, out of this set, we create labels from the sound-separated tracks to build a multi-label dataset.
We only keep the time steps that are labeled.

In [46]:
import ast
import numpy as np


@F.udf(returnType="array<string>")
def parse_labels(label_str: str):
    """Parse a stringified Python list of labels into an array of strings.

    Args:
        label_str: text holding a Python literal such as ``"['a', 'b']"``.

    Returns:
        The evaluated literal, or an empty list when the value is null or
        blank, so a missing entry does not abort the whole Spark job.
    """
    # Null / blank cells carry no labels; literal_eval would raise on them.
    if label_str is None or not label_str.strip():
        return []
    # use literal eval to parse the string (never eval(): literals only)
    return ast.literal_eval(label_str)


@F.udf("array<float>")
def embedding_mean(v):
    """Element-wise mean of a collection of equally sized vectors.

    Args:
        v: sequence of numeric vectors (one per collected row).

    Returns:
        The per-position mean as a list of floats, or ``None`` when ``v`` is
        null or empty (np.vstack cannot stack zero arrays and would raise).
    """
    if not v:
        return None
    # Stack to (n_vectors, dim) and average over the vector axis.
    return np.vstack(v).mean(axis=0).tolist()


def xc_id(track_name):
    """Build a column expression with the digits that follow ``XC``.

    Args:
        track_name: input column, handed unchanged to ``F.regexp_extract``.

    Returns:
        The column expression for capture group 1 of the ``XC<digits>`` pattern.
    """
    xc_digits = r"XC(\d+)"
    first_group = 1
    return F.regexp_extract(track_name, xc_digits, first_group)


def process(df, train_metadata, preds_df=None):
    """Join embeddings, predictions and metadata into one row per track window.

    Args:
        df: embedding frame; the columns read here are ``track_name``,
            ``track_stem``, ``track_type``, ``start_time`` and ``embedding``.
        train_metadata: metadata frame; ``primary_label``,
            ``secondary_labels`` and ``filename`` are read.
        preds_df: prediction frame with ``track_name``, ``start_time``,
            ``prediction`` and ``probability``. When omitted, the
            notebook-level ``preds`` frame is used, which is what this
            function always read before the parameter existed.

    Returns:
        A DataFrame with the columns ``track_stem``, ``start_time``,
        ``metadata_species``, ``predicted_species``,
        ``predicted_species_prob``, ``embedding``, ``next_embedding`` and
        ``track_embedding``. All joins are inner joins, so a window is only
        returned when it has an 'original' embedding, at least one
        prediction above the threshold and a matching metadata row.
    """
    if preds_df is None:
        # Backward-compatible default: the frame defined at notebook level.
        preds_df = preds

    # Primary + secondary labels per recording, keyed by the XC number found
    # in the metadata filename.
    labels = train_metadata.select(
        F.array_union(F.array("primary_label"), parse_labels("secondary_labels")).alias(
            "metadata_species"
        ),
        xc_id("filename").alias("xc_id"),
    )

    # NOTE(review): the two modulo tests together keep only start_time values
    # divisible by 10, so every seq_id below is even. Confirm this matches the
    # intended "first and third index" selection before changing it.
    align_to_window = (
        df.where("start_time % 2 = 0 and start_time % 5 = 0")
        .join(preds_df, on=["track_name", "start_time"])
        .withColumn("seq_id", (F.col("start_time") / 5).cast("int"))
        .select(
            "track_stem",
            # kept so the 'original' filter below is resolved against this
            # cached projection instead of reaching past it
            "track_type",
            "start_time",
            "seq_id",
            "embedding",
            "prediction",
            "probability",
        )
    ).cache()

    # compute the current and next embedding
    by_track_in_order = Window.partitionBy("track_stem").orderBy("seq_id")
    sequence_embedding = (
        align_to_window.where("track_type = 'original'")
        .groupBy("track_stem", "seq_id")
        .agg(
            F.collect_set("embedding").alias("embedding"),
            F.min("start_time").alias("start_time"),
        )
        .withColumn("embedding", embedding_mean("embedding"))
        # also include the next sequence; the last sequence of a track has no
        # successor, so it falls back to its own embedding
        .withColumn(
            "next_embedding",
            F.coalesce(F.lead("embedding", 1).over(by_track_in_order), "embedding"),
        )
        .select("track_stem", "seq_id", "start_time", "embedding", "next_embedding")
    )

    # compute track embedding (global context)
    track_embedding = (
        # this is the average embedding of all of the sequences
        sequence_embedding.groupBy("track_stem")
        .agg(F.collect_set("embedding").alias("embedding"))
        .select("track_stem", embedding_mean("embedding").alias("track_embedding"))
    )

    # let's figure out what predictions there are at every sequence id; all
    # track types contribute here, not only 'original'
    mixit_predictions = (
        align_to_window.where("prediction != 'no_call' and probability > 0.5")
        .groupBy("track_stem", "seq_id")
        .agg(
            # structs sort by their first field, so the most probable
            # prediction comes first
            F.sort_array(
                F.collect_list(F.struct("probability", "prediction")), asc=False
            ).alias("values"),
        )
        .selectExpr(
            "track_stem",
            "seq_id",
            "values.prediction as predicted_species",
            "values.probability as predicted_species_prob",
        )
    )

    track_with_fuzzy_annotations = (
        sequence_embedding.join(track_embedding, on="track_stem")
        .join(mixit_predictions, on=["track_stem", "seq_id"])
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .select(
            "track_stem",
            "start_time",
            "metadata_species",
            "predicted_species",
            "predicted_species_prob",
            "embedding",
            "next_embedding",
            "track_embedding",
        )
    )
    return track_with_fuzzy_annotations


# Dry run of the pipeline on the cached single-species subset.
processed_sample = process(subset, train_metadata)
processed_sample.printSchema()
processed_sample.show(n=2, vertical=True, truncate=80)

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = false)
 |    |-- element: string (containsNull = true)
 |-- predicted_species_prob: array (nullable = false)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------------
 track_stem             | XC753135                                                                         
 start_time             | 10                                                                               
 metadata_species     

In [47]:
# Run the pipeline on the full embedding frame and persist the result,
# replacing whatever already exists at the output path.
processed_df = process(df, train_metadata)
processed_df.write.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v2", mode="overwrite"
)

In [49]:
# Re-read the written parquet so the checks below reflect what is on disk.
processed_df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v2"
)
processed_df.printSchema()
processed_df.show(n=2, vertical=True, truncate=80)
# Total number of rows written.
processed_df.count()

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species_prob: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------------
 track_stem             | XC108820                                                                         
 start_time             | 0                                                                                
 metadata_species       

71533