In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

os.environ["SPARK_LOCAL_DIRS"] = "h://spark-tmp/"

spark = get_spark(cores=24, memory="40g")
df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_embeddings/consolidated_v4"
)
df.printSchema()

preds = (
    spark.read.parquet("../data/processed/birdclef-2023/consolidated_v4_with_preds")
    .repartition(32)
    .cache()
)
preds.printSchema()

# also include the metadata
birdclef_root = "../data/raw/birdclef-2023"
train_metadata = spark.read.csv(f"{birdclef_root}/train_metadata.csv", header=True)
train_metadata.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_name: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- prediction: string (nullable = true)
 |-- probability: double (nullable = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary

In [3]:
df.show(n=2, vertical=True, truncate=80)

-RECORD 0------------------------------------------------------------------------------------------
 species        | grecor                                                                           
 track_stem     | XC629875_part003                                                                 
 track_type     | source0                                                                          
 track_name     | grecor/XC629875_part003_source0.mp3                                              
 embedding      | [0.6731137633323669, 1.1389738321304321, 0.6284520626068115, 0.65399438142776... 
 prediction_vec | [-8.725186347961426, -7.3204827308654785, -9.82101821899414, -10.396224021911... 
 predictions    | [{0, 3026, Tetrastes bonasia_Hazel Grouse, hazgro1, 0.02235649898648262}, {1,... 
 start_time     | 75                                                                               
 energy         | 0.01598571054637432                                                              


In [44]:
import ast
import numpy as np


@F.udf(returnType="array<string>")
def parse_labels(label_str: str):
    # use literal eval to parse the string
    return ast.literal_eval(label_str)


@F.udf("array<float>")
def embedding_mean(v):
    return np.vstack(v).mean(axis=0).tolist()


def xc_id(track_name):
    return F.regexp_extract(track_name, r"XC(\d+)", 1)


def process(df, train_metadata):
    labels = train_metadata.select(
        "primary_label",
        F.array_union(F.array("primary_label"), parse_labels("secondary_labels")).alias(
            "metadata_species"
        ),
        xc_id("filename").alias("xc_id"),
    )

    align_to_window = (
        (
            df.where("start_time % 5 = 0 or (start_time + 2) % 5 = 0")
            .join(preds, on=["track_name", "start_time"])
            .withColumn("seq_id", (F.col("start_time") / 5).cast("int"))
            .withColumn("birdnet_probability", F.col("predictions")[0]["probability"])
            .select(
                "track_stem",
                "track_type",
                "start_time",
                "seq_id",
                "embedding",
                "prediction_vec",
                "prediction",
                "birdnet_probability",
            )
        )
        .repartition(200)
        .cache()
    )

    # compute the current and next embedding
    sequence_embedding = (
        align_to_window.groupBy("track_stem", "track_type", "seq_id")
        .agg(
            # collect the embedding as an array sorted by start time
            F.sort_array(F.collect_list(F.struct("start_time", "embedding"))).alias(
                "embedding"
            ),
            F.sort_array(
                F.collect_list(F.struct("start_time", "prediction_vec"))
            ).alias("prediction_vec"),
            F.min("start_time").alias("start_time"),
            F.max("birdnet_probability").alias("probability"),
        )
        .where(F.size("embedding") == 2)
        .withColumn("embedding", F.flatten("embedding.embedding"))
        .withColumn("prediction_vec", embedding_mean("prediction_vec.prediction_vec"))
        .select(
            "track_stem",
            "track_type",
            "seq_id",
            "start_time",
            "probability",
            "embedding",
            "prediction_vec",
        )
        .repartition(200)
    ).cache()

    # the most representative source is the track that has the most predictions
    # with the main class
    most_representative_source = (
        align_to_window.where("track_type <> 'original'")
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .where("primary_label = prediction")
        .groupBy("track_stem", "track_type")
        .count()
        .withColumn(
            "rank",
            F.row_number().over(
                Window.partitionBy("track_stem").orderBy(F.desc("count"))
            ),
        )
        .where("rank = 1")
        .select("track_stem", "track_type")
    )

    track_with_fuzzy_annotations = (
        sequence_embedding.join(
            most_representative_source.union(
                align_to_window.select(
                    "track_stem", F.lit("original").alias("track_type")
                )
            ).distinct(),
            on=["track_stem", "track_type"],
        )
        .withColumn("xc_id", xc_id("track_stem"))
        .join(labels, on=["xc_id"])
        .select(
            "track_stem",
            "track_type",
            "start_time",
            "primary_label",
            "metadata_species",
            "probability",
            "embedding",
            "prediction_vec",
        )
    )
    return track_with_fuzzy_annotations


subset = df.where("species = 'bltapa1'").cache()
processed_sample = process(subset, train_metadata)
processed_sample.printSchema()
processed_sample.show(n=3, vertical=True, truncate=80)
processed_sample.groupBy("track_type").count().show()

root
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- probability: double (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------
 track_stem       | XC239961                                                                         
 track_type       | source3                                                                          
 start_time       | 0                                                                                
 primary_label    | bltapa1                                                                        

In [45]:
processed_df = process(df, train_metadata)
processed_df.write.parquet(
    "../data/intermediate/train_postprocessed_v7", mode="overwrite"
)

In [47]:
processed_df = spark.read.parquet("../data/intermediate/train_postprocessed_v7")
processed_df.printSchema()
processed_df.show(n=2, vertical=True, truncate=80)
processed_df.select(F.size("embedding"), F.size("prediction_vec")).show(n=3)
processed_df.count()

root
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- primary_label: string (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- probability: double (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0--------------------------------------------------------------------------------------------
 track_stem       | XC116777                                                                         
 track_type       | source1                                                                          
 start_time       | 15                                                                               
 primary_label    | ratcis1                                                                        

261006

In [49]:
processed_df.groupBy("track_type").count().show()

+----------+------+
|track_type| count|
+----------+------+
|   source1|  4798|
|   source2| 13313|
|   source0| 54257|
|   source3| 55316|
|  original|133322|
+----------+------+



In [58]:
def process_for_training(df):
    count = (
        df.where("probability > 0.5")
        .groupBy("primary_label")
        .count()
        .where("count > 10")
    )
    common = (
        df.join(count.select("primary_label"), on="primary_label")
        .where("track_type <> 'original'")
        .withColumn(
            "species",
            F.when(F.expr("probability > 0.5"), F.col("primary_label")).otherwise(
                F.lit("no_call")
            ),
        )
    )
    rare = (
        df.join(count.select("primary_label"), on="primary_label", how="left_anti")
        .where("track_type <> 'original'")
        .withColumn(
            "species",
            F.when(F.expr("probability > 0.1"), F.col("primary_label")).otherwise(
                F.lit("no_call")
            ),
        )
    )
    res = common.union(rare).select("track_stem", "start_time", "species")
    # now also use the original tracks
    return df.join(res, on=["track_stem", "start_time"]).select(
        "track_stem",
        "track_type",
        "start_time",
        "species",
        "embedding",
        "prediction_vec",
    )


process_no_call = process_for_training(processed_df)

In [59]:
process_no_call.write.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v7", mode="overwrite"
)

In [60]:
process_no_call = spark.read.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v7"
)
process_no_call.count()

255372

In [63]:
process_no_call.groupBy("track_type").count().show()

+----------+------+
|track_type| count|
+----------+------+
|   source1|  4798|
|   source2| 13313|
|   source0| 54257|
|   source3| 55318|
|  original|127686|
+----------+------+



In [64]:
counts = process_no_call.groupBy("species").count().orderBy(F.desc("count"))
counts.show(n=5)
counts.orderBy("count").show(n=5)
counts.count()

+-------+------+
|species| count|
+-------+------+
|no_call|114882|
|thrnig1| 11664|
|eubeat1|  8362|
| hoopoe|  7604|
| wlwwar|  5930|
+-------+------+
only showing top 5 rows

+-------+-----+
|species|count|
+-------+-----+
|rostur1|    2|
|lotlap1|    2|
|litwea1|    2|
|afpkin1|    2|
|witswa1|    2|
+-------+-----+
only showing top 5 rows



264

In [65]:
process_no_call.printSchema()
process_no_call.show(n=2, vertical=True, truncate=80)

root
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- species: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0------------------------------------------------------------------------------------------
 track_stem     | XC213642                                                                         
 track_type     | original                                                                         
 start_time     | 55                                                                               
 species        | afmdov1                                                                          
 embedding      | [1.7688512802124023, 0.8326520323753357, 0.4971006214618683, 0.76335829496383... 
 prediction_vec | [-14.449501, -12.099122, -15.6948