In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Spark reads its scratch directory from this variable, so it is set before
# the session is requested.
os.environ["SPARK_LOCAL_DIRS"] = "h://spark-tmp/"
spark = get_spark(cores=24, memory="40g")

# windowed embeddings together with the raw per-window prediction structs
df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_embeddings/consolidated_v4"
)
df.printSchema()

# one predicted label per (track_name, start_time); kept in memory because
# the processing pipeline joins against it
preds = spark.read.parquet("../data/processed/birdclef-2023/consolidated_v4_with_preds")
preds = preds.repartition(32).cache()
preds.printSchema()

# the competition metadata supplies the primary and secondary labels
birdclef_root = "../data/raw/birdclef-2023"
train_metadata = spark.read.csv(f"{birdclef_root}/train_metadata.csv", header=True)
train_metadata.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_name: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- prediction: string (nullable = true)
 |-- probability: double (nullable = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary

In [3]:
# peek at two rows, one field per line, with long values cut at 80 characters
df.show(2, truncate=80, vertical=True)

-RECORD 0------------------------------------------------------------------------------------------
 species        | grecor                                                                           
 track_stem     | XC629875_part003                                                                 
 track_type     | source0                                                                          
 track_name     | grecor/XC629875_part003_source0.mp3                                              
 embedding      | [0.6731137633323669, 1.1389738321304321, 0.6284520626068115, 0.65399438142776... 
 prediction_vec | [-8.725186347961426, -7.3204827308654785, -9.82101821899414, -10.396224021911... 
 predictions    | [{0, 3026, Tetrastes bonasia_Hazel Grouse, hazgro1, 0.02235649898648262}, {1,... 
 start_time     | 75                                                                               
 energy         | 0.01598571054637432                                                              


In [23]:
import ast
import numpy as np


@F.udf(returnType="array<string>")
def parse_labels(label_str: str):
    """Parse a stringified Python list of labels into an array of strings.

    Parameters
    ----------
    label_str:
        Text of a Python list literal, e.g. ``"['a', 'b']"`` or ``"[]"``.
        May be ``None`` when the source cell is null.

    Returns
    -------
    list
        The parsed labels; an empty list when the cell is null or blank.
    """
    # A null or blank cell carries no labels. ast.literal_eval would raise on
    # it and fail the whole Spark task, so treat it as "no secondary labels".
    if label_str is None or not label_str.strip():
        return []
    # literal_eval only accepts Python literals, so it cannot execute code
    # embedded in the metadata file.
    return ast.literal_eval(label_str)


@F.udf("array<float>")
def embedding_mean(v):
    """Return the element-wise mean of a collection of vectors as a list.

    ``v`` is a sequence of equal-length numeric sequences; they are stacked
    row-wise and averaged over the row axis.
    """
    stacked = np.vstack(v)
    return np.mean(stacked, axis=0).tolist()


def xc_id(track_name):
    """Build a column expression holding the digits that follow ``XC``.

    ``track_name`` is a column (or column name) of strings; when the pattern
    does not match, ``regexp_extract`` yields an empty string.
    """
    pattern = r"XC(\d+)"
    capture_group = 1
    return F.regexp_extract(track_name, pattern, capture_group)


def process(df, train_metadata, predictions=None):
    """Build per-sequence training rows with track-level context and labels.

    Parameters
    ----------
    df:
        Embedding frame. The columns read here are ``track_name``,
        ``track_stem``, ``track_type``, ``start_time``, ``embedding`` and
        ``predictions`` (an array of structs with a ``probability`` field).
    train_metadata:
        Metadata frame with ``primary_label``, ``secondary_labels`` (a
        stringified list) and ``filename``.
    predictions:
        Frame with ``track_name``, ``start_time`` and ``prediction``. When
        omitted, the notebook-level ``preds`` frame is used, so existing
        two-argument calls behave as before.

    Returns
    -------
    DataFrame
        One row per (track_stem, track_type, sequence) with the columns
        ``track_stem``, ``track_type``, ``start_time``, ``primary_label``,
        ``metadata_species``, ``probability``, ``embedding``,
        ``next_embedding`` and ``track_embedding``. Only ``original`` tracks
        and, per track stem, the single best-matching separated source are
        kept.

    Notes
    -----
    Two intermediate frames are cached and not unpersisted here, because the
    returned frame is lazy and still depends on them.
    """
    if predictions is None:
        # Resolved at call time so the dependency on the notebook global is
        # explicit and can be overridden by the caller.
        predictions = preds

    # primary label, primary + secondary label set, and the numeric recording
    # id used to join the metadata against track stems
    labels = train_metadata.select(
        "primary_label",
        F.array_union(F.array("primary_label"), parse_labels("secondary_labels")).alias(
            "metadata_species"
        ),
        xc_id("filename").alias("xc_id"),
    )

    # Keep the windows whose start time is a multiple of 5 or 2 short of one
    # (start_time % 5 in {0, 3}); each pair shares a seq_id, which is
    # start_time / 5 cast to int.
    align_to_window = (
        df.where("start_time % 5 = 0 or (start_time + 2) % 5 = 0")
        .join(predictions, on=["track_name", "start_time"])
        .withColumn("seq_id", (F.col("start_time") / 5).cast("int"))
        # probability of the first entry of the predictions array
        .withColumn("birdnet_probability", F.col("predictions")[0]["probability"])
        .select(
            "track_stem",
            "track_type",
            "start_time",
            "seq_id",
            "embedding",
            "prediction",
            "birdnet_probability",
        )
    ).cache()

    # compute the current and next embedding
    sequence_embedding = (
        align_to_window.groupBy("track_stem", "track_type", "seq_id")
        .agg(
            # collect the embedding as an array sorted by start time
            F.sort_array(F.collect_list(F.struct("start_time", "embedding"))).alias(
                "embedding"
            ),
            F.min("start_time").alias("start_time"),
            F.max("birdnet_probability").alias("probability"),
        )
        # only complete sequences (both windows present) are kept
        .where(F.size("embedding") == 2)
        # concatenate the two window embeddings into one vector
        .withColumn("embedding", F.flatten("embedding.embedding"))
        # also include the next sequence, or the current sequence when this
        # is the last one of the track
        .withColumn(
            "next_embedding",
            F.coalesce(
                F.lead("embedding", 1).over(
                    Window.partitionBy("track_stem", "track_type").orderBy("seq_id")
                ),
                "embedding",
            ),
        )
        .select(
            "track_stem",
            "track_type",
            "seq_id",
            "start_time",
            "probability",
            "embedding",
            "next_embedding",
        )
    ).cache()

    # compute track embedding (global context)
    track_embedding = (
        # this is the average embedding of all of the sequences; collect_list
        # (not collect_set) so identical embeddings are not dropped from the
        # mean
        sequence_embedding.groupBy("track_stem", "track_type")
        .agg(F.collect_list("embedding").alias("embedding"))
        .select(
            "track_stem",
            "track_type",
            embedding_mean("embedding").alias("track_embedding"),
        )
    )

    # the most representative source is the track that has the most predictions
    # with the main class
    most_representative_source = (
        align_to_window.where("track_type <> 'original'")
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .where("primary_label = prediction")
        .groupBy("track_stem", "track_type")
        .count()
        .withColumn(
            "rank",
            F.row_number().over(
                # track_type breaks ties so that the chosen source does not
                # change between runs when two sources have the same count
                Window.partitionBy("track_stem").orderBy(
                    F.desc("count"), F.asc("track_type")
                )
            ),
        )
        .where("rank = 1")
        .select("track_stem", "track_type")
    )

    track_with_fuzzy_annotations = (
        sequence_embedding.join(track_embedding, on=["track_stem", "track_type"])
        # inner join acts as a filter: the best source per stem plus an
        # "original" entry for every stem
        .join(
            most_representative_source.union(
                align_to_window.select(
                    "track_stem", F.lit("original").alias("track_type")
                )
            ).distinct(),
            on=["track_stem", "track_type"],
        )
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .select(
            "track_stem",
            "track_type",
            "start_time",
            "primary_label",
            "metadata_species",
            "probability",
            "embedding",
            "next_embedding",
            "track_embedding",
        )
    )
    return track_with_fuzzy_annotations


# dry-run the pipeline on a single species before processing everything
subset = df.filter("species = 'bltapa1'")
subset = subset.cache()
processed_sample = process(subset, train_metadata)
processed_sample.printSchema()
processed_sample.show(3, truncate=80, vertical=True)

root
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- probability: double (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------
 track_stem       | XC237745                                                                         
 track_type       | original                                                                         
 start_time       | 0                                                                                
 prim

In [24]:
# run the full pipeline and materialise it, replacing any previous output
processed_df = process(df, train_metadata)
processed_df.write.mode("overwrite").parquet(
    "../data/intermediate/train_postprocessed_v6_00"
)

In [26]:
# read the materialised parquet back so the following cells work from disk
# rather than from the lazy pipeline
processed_df = spark.read.parquet("../data/intermediate/train_postprocessed_v6_00")
processed_df.printSchema()
processed_df.show(2, truncate=80, vertical=True)
# length of the per-sequence embedding vector
processed_df.select(F.size("embedding")).show(3)
processed_df.count()

root
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- probability: double (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------
 track_stem       | XC116777                                                                         
 track_type       | source1                                                                          
 start_time       | 0                                                                                
 prim

261006

In [32]:
# number of sequences per primary label whose probability exceeds 0.1,
# rarest labels first
counts = (
    processed_df.filter("probability > 0.1")
    .groupBy("primary_label")
    .count()
    .sort("count")
)
counts.show()
# number of distinct labels that survive the threshold
counts.count()

+-------------+-----+
|primary_label|count|
+-------------+-----+
|      whhsaw1|    1|
|      whctur2|    1|
|      afpkin1|    2|
|      golher1|    4|
|      lotlap1|    4|
|      rehblu1|    5|
|      dotbar1|    8|
|      crefra2|    8|
|      brtcha1|    9|
|      lotcor1|   12|
|      darter3|   12|
|      fatwid1|   12|
|      brcwea1|   13|
|      blksaw1|   17|
|      sacibi2|   19|
|      rostur1|   19|
|      witswa1|   19|
|      yebsto1|   20|
|      joygre1|   21|
|      stusta1|   22|
+-------------+-----+
only showing top 20 rows



264

In [35]:
# list the rows with the smallest probability first
processed_df.sort("probability").show()

+----------------+----------+----------+-------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|      track_stem|track_type|start_time|primary_label|    metadata_species|         probability|           embedding|      next_embedding|     track_embedding|
+----------------+----------+----------+-------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|        XC660390|   source0|        20|      hadibi1|  [hadibi1, somgre1]|4.973471590119516E-8|[1.06636977195739...|[0.86483210325241...|[1.0912713, 1.591...|
|        XC244601|   source2|        30|      refbar2|           [refbar2]|2.461477208726137E-7|[1.29924261569976...|[0.43837857246398...|[0.5527429, 1.294...|
|        XC156754|   source2|        20|      tamdov1|  [tamdov1, combul2]|3.206638155006658E-7|[0.40902122855186...|[0.73934954404830...|[0.7265094, 1.464...|
|        XC423275|  original|        55|

In [42]:
@F.pandas_udf("array<float>", F.PandasUDFType.SCALAR)
def embedding_pair_mean(v1, v2, weight=0.5):
    """Blend two embedding columns as a weighted average.

    Parameters
    ----------
    v1, v2:
        Series of equal-length numeric vectors, combined row by row.
    weight:
        Weight given to ``v1``; ``v2`` receives ``1 - weight``. May be a
        scalar or a Series aligned with ``v1``.

    Returns
    -------
    Series
        ``weight * v1 + (1 - weight) * v2`` for every row.
    """
    # The two weights already sum to one, so this is the complete weighted
    # mean. The previous extra division by two halved every blended vector.
    return v1 * weight + v2 * (1 - weight)


# the 20 rows with the lowest probability serve as the noise pool
noise_examples = (
    processed_df.orderBy("probability")
    .selectExpr("embedding as noise_embedding")
    .limit(20)
)

# we choose this value because we want to give more weight to the original
# example.
weight = F.lit(0.9)

# Seed both random orderings so a re-run over the same input picks the same
# rows. NOTE: Spark's rand is still only repeatable when the partitioning of
# the input is unchanged.
augmentation_seed = 2023

augmented_examples = (
    processed_df.withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("primary_label").orderBy(
                F.rand(seed=augmentation_seed)
            )
        ),
    )
    # at most 20 randomly chosen rows per label
    .where("rank <= 20")
    # now lets cross join with the noise examples
    .crossJoin(noise_examples)
    # now randomly keep a subset of these examples; a different seed is used
    # so this draw is not tied to the first one
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("primary_label").orderBy(
                F.rand(seed=augmentation_seed + 1)
            )
        ),
    )
    .where("rank <= 20")
    # same column order as processed_df, because union below is positional
    .select(
        "track_stem",
        "track_type",
        "start_time",
        "primary_label",
        "metadata_species",
        "probability",
        embedding_pair_mean("embedding", "noise_embedding", weight).alias("embedding"),
        embedding_pair_mean("next_embedding", "noise_embedding", weight).alias(
            "next_embedding"
        ),
        embedding_pair_mean("track_embedding", "noise_embedding", weight).alias(
            "track_embedding"
        ),
    )
)

processed_with_augmented = augmented_examples.union(processed_df)



In [43]:
# persist the original plus augmented rows, replacing any previous v6 output
processed_with_augmented.write.mode("overwrite").parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v6"
)

In [44]:
# read the written dataset back and report its row count
processed_with_augmented = spark.read.format("parquet").load(
    "../data/processed/birdclef-2023/train_postprocessed/v6"
)
processed_with_augmented.count()

266286