In [1]:
# Reload imported modules automatically before running code, so edits to
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load the lab_black extension (presumably formats cells with black; confirm).
%load_ext lab_black

In [9]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Spark scratch directory. It is set before get_spark(...) is called so the new
# session can pick it up. `setdefault` keeps any value that is already exported
# in the environment, so this machine-specific drive path is only a fallback
# rather than a hard override.
os.environ.setdefault("SPARK_LOCAL_DIRS", "h://spark-tmp/")

spark = get_spark(cores=24, memory="40g")

# Consolidated embedding table (v4); its schema is printed for reference.
df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_embeddings/consolidated_v4"
)
df.printSchema()

# Post-processed training table (v2); its schema is printed for reference.
train_df = spark.read.parquet("../data/processed/birdclef-2023/train_postprocessed/v2")
train_df.printSchema()

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = true)
 |    |

In [23]:
@F.pandas_udf("array<float>", F.PandasUDFType.SCALAR)
def embedding_pair_mean(v1, v2, weight=0.5):
    """Blend two embedding columns as a weighted mean.

    Computes ``weight * v1 + (1 - weight) * v2`` for every row.

    Args:
        v1: pandas Series with the first embedding of each row.
        v2: pandas Series with the second embedding of each row.
        weight: share given to ``v1``; ``v2`` receives ``1 - weight``. May be
            passed as a column (e.g. ``F.lit(0.75)``), in which case it arrives
            as a Series, or left at the scalar default.

    Returns:
        pandas Series of blended embeddings, declared to Spark as
        ``array<float>``.
    """
    # assumes each Series element is an array-like that supports element-wise
    # arithmetic (e.g. a numpy array) -- TODO confirm
    # The two weights already sum to 1, so this is the weighted mean as-is; the
    # earlier extra "/ 2" halved the magnitude of every blended embedding.
    return v1 * weight + v2 * (1 - weight)


# Attach the probability of the first entry in `predictions` to every row.
scored_windows = df.withColumn("probability", F.col("predictions")[0]["probability"])

# Among rows whose start_time is a multiple of 3 and whose probability is
# below 0.01, keep the 1000 with the lowest probability.
quietest_windows = (
    scored_windows.where("start_time % 3 = 0 and probability < 0.01")
    .orderBy("probability")
    .limit(1000)
)

# take the best example for each track here
lowest_per_track = Window.partitionBy("track_name").orderBy("probability")
noise_examples = (
    quietest_windows.withColumn("rank", F.row_number().over(lowest_per_track))
    .where("rank = 1")
    .drop("rank")
    .select("embedding")
)

# Explicit seeds for the two random subsampling passes below, so that re-running
# the notebook draws the same rows (best effort: Spark's seeded rand() still
# depends on how the input is partitioned). Two distinct seeds keep the passes
# from reusing the same random stream.
AUGMENT_SEED_PER_SPECIES = 2023
AUGMENT_SEED_AFTER_JOIN = 2024

# Share given to the original embedding when blending it with a noise embedding.
weight = F.lit(0.75)
augmented_examples = (
    # label every row with the first entry of metadata_species
    train_df.withColumn("species", F.col("metadata_species")[0])
    # randomly keep at most 20 rows per species
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(
                F.rand(seed=AUGMENT_SEED_PER_SPECIES)
            )
        ),
    )
    .where("rank <= 20")
    # now lets cross join with the noise examples
    # NOTE(review): limit(20) is applied to an unordered frame, so which 20
    # noise embeddings are used is not pinned down by this code.
    .crossJoin(noise_examples.selectExpr("embedding as noise_embedding").limit(20))
    # now randomly keep a subset of these examples (at most 20 per species)
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("species").orderBy(
                F.rand(seed=AUGMENT_SEED_AFTER_JOIN)
            )
        ),
    )
    .where("rank <= 20")
    # keep the original identifying columns and blend all three embedding
    # columns with the same noise embedding
    .select(
        "track_stem",
        "start_time",
        "metadata_species",
        "predicted_species",
        "predicted_species_prob",
        embedding_pair_mean("embedding", "noise_embedding", weight).alias("embedding"),
        embedding_pair_mean("next_embedding", "noise_embedding", weight).alias(
            "next_embedding"
        ),
        embedding_pair_mean("track_embedding", "noise_embedding", weight).alias(
            "track_embedding"
        ),
    )
)

# Append the original training rows to the augmented ones. Columns are matched
# by name instead of by position: the three embedding columns share a type, so
# a positional union would silently swap them if train_df stores them in a
# different order than the select above.
processed_df = augmented_examples.unionByName(train_df)

In [24]:
# Persist the combined table as v3; "overwrite" replaces output from earlier runs.
v3_output_path = "../data/processed/birdclef-2023/train_postprocessed/v3"
processed_df.write.parquet(v3_output_path, mode="overwrite")

In [25]:
# Reload the v3 table from disk so the checks below run on what was written.
processed_df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v3"
)

# Row count per first-listed species, smallest groups first.
first_species = F.col("metadata_species")[0]
species_counts = processed_df.groupBy(first_species).count()
species_counts.orderBy("count").show()

# Total number of rows; kept as the last expression so the cell displays it.
processed_df.count()

+-------------------+-----+
|metadata_species[0]|count|
+-------------------+-----+
|            whctur2|   21|
|            afpkin1|   21|
|            whhsaw1|   21|
|            golher1|   22|
|            brtcha1|   22|
|            rehblu1|   23|
|            lotlap1|   23|
|            crefra2|   24|
|            dotbar1|   24|
|            lotcor1|   24|
|            brcwea1|   25|
|            yebsto1|   26|
|            palpri1|   27|
|            fatwid1|   28|
|            gobsta5|   29|
|            darter3|   30|
|            sacibi2|   32|
|            witswa1|   34|
|            rostur1|   34|
|            bltbar1|   34|
+-------------------+-----+
only showing top 20 rows



76813