In [6]:
from birdclef.utils import get_spark
from pyspark.sql import Window, functions as F
import os

# Point Spark's local scratch space at a dedicated directory; the variable is
# assigned before the session is requested from the project helper.
os.environ["SPARK_LOCAL_DIRS"] = "h://spark-tmp/"
spark = get_spark(cores=8, memory="8g")

birdclef_root = "../data/raw/birdclef-2023"
metadata_csv = f"{birdclef_root}/train_metadata.csv"

# Post-processed training frame; cached because later cells query it repeatedly.
train_df = spark.read.parquet(
    "../data/processed/birdclef-2023/train_postprocessed/v3"
)
train_df = train_df.cache()
train_df.printSchema()

# Competition metadata. The first CSV row supplies the column names; no schema
# is inferred, so every column is read as a string.
train_metadata = spark.read.csv(metadata_csv, header=True)
train_metadata.printSchema()

root
 |-- track_stem: string (nullable = true)
 |-- start_time: long (nullable = true)
 |-- metadata_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- predicted_species_prob: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- next_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- track_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)

root
 |-- primary_label: string (nullable = true)
 |-- secondary_labels: string (nullable = true)
 |-- type: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)
 |-- scientific_name: string (nullable = true)
 |-- common_name: string (nullable = true)
 |-- author: string (nullable = true)
 |-- 

In [5]:
# Peek at the first two rows, one field per line, cutting long values at 80 chars.
train_df.show(2, truncate=80, vertical=True)

-RECORD 0--------------------------------------------------------------------------------------------------
 track_stem             | XC108820                                                                         
 start_time             | 0                                                                                
 metadata_species       | [categr]                                                                         
 predicted_species      | [categr, cohmar1, sichor1]                                                       
 predicted_species_prob | [0.9218195065749416, 0.7636727008861889, 0.6866624362805946]                     
 embedding              | [1.6699009, 1.9233328, 0.37807903, 1.2282478, 1.5084257, 1.9636716, 0.5422926... 
 next_embedding         | [1.6699009, 1.9233328, 0.37807903, 1.2282478, 1.5084257, 1.9636716, 0.5422926... 
 track_embedding        | [1.6699009, 1.9233328, 0.37807903, 1.2282478, 1.5084257, 1.9636716, 0.5422926... 
-RECORD 1-------------------

In [18]:
# let's use a track as an example of everything that is wrong with v2
# NOTE(review): the frame loaded earlier is read from a ".../v3" path — confirm
# which version this remark refers to.

# Total number of predicted species per track, repeated on each of its rows.
with_counts = train_df.withColumn(
    "predict_count",
    F.sum(F.size("predicted_species")).over(Window.partitionBy("track_stem")),
)

# One-row frame holding the largest per-track total. Comparing against it
# replaces a rank() over a window with no partitionBy, which makes Spark move
# every row into a single partition and sort it just to find the maximum.
top_count = with_counts.agg(F.max("predict_count").alias("top_predict_count"))

example = (
    # Keep every row of the track(s) whose total equals the maximum; tied tracks
    # are all kept, as they were with the former `rank = 1` filter.
    with_counts.join(
        F.broadcast(top_count),
        on=F.col("predict_count") == F.col("top_predict_count"),
        how="inner",
    )
    .orderBy("start_time")
    .select(
        "track_stem",
        "start_time",
        "predicted_species",
        "metadata_species",
        "predict_count",
    )
).cache()
example.show()

+----------------+----------+--------------------+----------------+-------------+
|      track_stem|start_time|   predicted_species|metadata_species|predict_count|
+----------------+----------+--------------------+----------------+-------------+
|XC661001_part001|         0|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        10|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        20|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        30|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        40|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        50|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        50|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        50|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part001|        60|[gabgos2, gabgos2...|       [gabgos2]|           92|
|XC661001_part00

In [19]:
# Collect the example rows to the driver as a pandas DataFrame; the pandas
# rendering shows the list columns without the truncation that show() applies.
pdf = example.toPandas()
# Bare last expression so the notebook renders the frame.
pdf

Unnamed: 0,track_stem,start_time,predicted_species,metadata_species,predict_count
0,XC661001_part001,0,"[gabgos2, gabgos2, gabgos2, gabgos2, yefcan]",[gabgos2],92
1,XC661001_part001,10,"[gabgos2, gabgos2, gabgos2, gabgos2]",[gabgos2],92
2,XC661001_part001,20,"[gabgos2, gabgos2, gabgos2, gabgos2]",[gabgos2],92
3,XC661001_part001,30,"[gabgos2, gabgos2, gabgos2, gabgos2, blakit1]",[gabgos2],92
4,XC661001_part001,40,"[gabgos2, gabgos2, gabgos2, gabgos2]",[gabgos2],92
5,XC661001_part001,50,"[gabgos2, gabgos2, gabgos2, barswa, gabgos2]",[gabgos2],92
6,XC661001_part001,50,"[gabgos2, gabgos2, gabgos2, barswa, gabgos2]",[gabgos2],92
7,XC661001_part001,50,"[gabgos2, gabgos2, gabgos2, barswa, gabgos2]",[gabgos2],92
8,XC661001_part001,60,"[gabgos2, gabgos2, gabgos2, gabgos2]",[gabgos2],92
9,XC661001_part001,70,"[gabgos2, gabgos2, gabgos2, gabgos2]",[gabgos2],92


In [8]:
def xc_id(track_name):
    """Build a column expression that pulls the digits following "XC" out of a column.

    Args:
        track_name: Column (or column name) holding a string that contains an
            "XC<digits>" token, e.g. a track stem or a file name.

    Returns:
        A Spark column with the first run of digits that follows "XC",
        without the "XC" prefix.
    """
    pattern = r"XC(\d+)"
    digits_group = 1  # capture group 1 is the numeric part
    return F.regexp_extract(track_name, pattern, digits_group)


# Metadata reduced to the join key (numeric XC id taken from the file name)
# and the label to attach.
label_lookup = train_metadata.select(
    xc_id("filename").alias("xc_id"),
    "primary_label",
)

# Same key derived from the track stem on the training side.
keyed_train = train_df.withColumn("xc_id", xc_id("track_stem"))

# Join on the shared key (default join type) so each training row gains primary_label.
res_df = keyed_train.join(label_lookup, on="xc_id")

res_df.show(2, truncate=80, vertical=True)

-RECORD 0--------------------------------------------------------------------------------------------------
 xc_id                  | 108820                                                                           
 track_stem             | XC108820                                                                         
 start_time             | 0                                                                                
 metadata_species       | [categr]                                                                         
 predicted_species      | [categr, cohmar1, sichor1]                                                       
 predicted_species_prob | [0.9218195065749416, 0.7636727008861889, 0.6866624362805946]                     
 embedding              | [1.6699009, 1.9233328, 0.37807903, 1.2282478, 1.5084257, 1.9636716, 0.5422926... 
 next_embedding         | [1.6699009, 1.9233328, 0.37807903, 1.2282478, 1.5084257, 1.9636716, 0.5422926... 
 track_embedding        | [1