# Transforming data for the Torch model

In [2]:
%load_ext autoreload
%autoreload 2
from pyspark.sql import functions as F

from birdclef.transforms import SpeciesData
from birdclef.utils import get_spark

# Obtain a SparkSession from the project helper; the bare name on the last
# line renders the session summary as this cell's output.
spark = get_spark()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/06 19:44:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/06/06 19:44:14 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Per-chunk embeddings and logits from the asbfly parquet (schema printed
# below). Cached because the frame is scanned several times in this notebook.
path = "gs://dsgt-clef-birdclef-2024/data/processed/birdclef-2024/asbfly.parquet"
df = spark.read.parquet(path)
df = df.cache()
df.printSchema()
df.show(n=20)
df.count()

                                                                                

root
 |-- name: string (nullable = true)
 |-- chunk_5s: long (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- logits: array (nullable = true)
 |    |-- element: float (containsNull = true)



                                                                                

+------------+--------+--------------------+--------------------+
|        name|chunk_5s|           embedding|              logits|
+------------+--------+--------------------+--------------------+
|XC134896.ogg|       0|[-0.01697171, -0....|[4.367853, -15.69...|
|XC134896.ogg|       1|[0.08178271, -0.1...|[5.7584195, -14.0...|
|XC134896.ogg|       2|[0.15756801, -0.1...|[5.41736, -14.507...|
|XC134896.ogg|       3|[0.07789261, -0.1...|[7.383127, -14.17...|
|XC134896.ogg|       4|[0.0338157, -0.11...|[5.162613, -13.20...|
|XC134896.ogg|       5|[0.027517725, -0....|[4.7014565, -15.4...|
|XC164848.ogg|       0|[-0.037761074, 0....|[-1.5938205, -12....|
|XC164848.ogg|       1|[0.04028651, -0.0...|[-0.54023355, -11...|
|XC164848.ogg|       2|[0.020008465, -0....|[0.059392925, -11...|
|XC164848.ogg|       3|[0.011329669, 0.0...|[-8.847201, -14.5...|
|XC175797.ogg|       0|[-0.06643575, 0.0...|[6.4482026, -7.80...|
|XC175797.ogg|       1|[-2.4496057E-4, 0...|[3.09487, -9.9500...|
|XC175797.

924

In [4]:
# Species names used to label each logit in `transform`; presumably in the
# same order as the model's logit vector — TODO confirm.
sp = SpeciesData()
species = sp.read_species_from_file()

In [5]:
def sigmoid_udf(x):
    """Return a Spark column expression for the logistic sigmoid of ``x``.

    Despite its name this is not a Python UDF: it is built from Spark SQL
    functions, ``1 / (1 + exp(-x))``, so no Python round-trip is involved.
    No filtering or thresholding is applied; every row keeps its value
    (null inputs stay null).

    Args:
        x: name of a numeric column, or an existing ``Column`` expression.

    Returns:
        A ``Column`` of doubles.
    """
    # Accept either a column name or a Column; F.col only takes names.
    col = F.col(x) if isinstance(x, str) else x
    return 1 / (1 + F.exp(-col))


def transform(df, species_names=None):
    """Build one row per audio chunk with sigmoid species scores and its embedding.

    Args:
        df: DataFrame with ``name``, ``chunk_5s``, ``logits`` (array) and
            ``embedding`` columns.
        species_names: species names aligned position-by-position with
            ``logits``. Defaults to the notebook-level ``species`` list.

    Returns:
        DataFrame with columns:
          - ``id``: ``"<name>_<chunk_5s>"``
          - ``sigmoid_logits``: sigmoid of every logit, ordered by species name
          - ``embedding``: passed through unchanged
        Rows whose ``logits`` is null are dropped.
    """
    if species_names is None:
        species_names = species

    # Pair each logit with its species name; the struct fields are named
    # after the inputs ("logits", "species").
    pairs = F.arrays_zip(F.col("logits"), F.lit(species_names).alias("species"))
    # Rebuild each pair as (species, sigmoid) so array_sort orders by species
    # name first. Same formula as sigmoid_udf, applied per array element.
    scored = F.transform(
        pairs,
        lambda p: F.struct(
            p["species"].alias("species"),
            (1 / (1 + F.exp(-p["logits"]))).alias("sigmoid"),
        ),
    )
    # Sort inside each row instead of orderBy + groupBy/collect_list: Spark
    # does not guarantee collect_list order after a shuffle, which could
    # misalign scores with species. This also avoids the explode/join round trip.
    sigmoid_logits = F.transform(F.array_sort(scored), lambda p: p["sigmoid"])

    return (
        # explode + inner join used to drop rows with null logits; keep that.
        df.where(F.col("logits").isNotNull())
        .select(
            F.concat_ws("_", "name", "chunk_5s").alias("id"),
            sigmoid_logits.alias("sigmoid_logits"),
            "embedding",
        )
    )

In [6]:
# transform DF
transformed_df = transform(df)
transformed_df.show(truncate=30)



+--------------+------------------------------+------------------------------+
|            id|                sigmoid_logits|                     embedding|
+--------------+------------------------------+------------------------------+
|XC134896.ogg_0|[0.9874803002619361, 1.5216...|[-0.01697171, -0.14812551, ...|
|XC134896.ogg_1|[0.9968538337772187, 8.0886...|[0.08178271, -0.10625668, 0...|
|XC134896.ogg_2|[0.9955807658044297, 5.0082...|[0.15756801, -0.12693903, 0...|
|XC134896.ogg_3|[0.9993787324137937, 6.9467...|[0.07789261, -0.10598332, 0...|
|XC134896.ogg_4|[0.9943058918488539, 1.8360...|[0.0338157, -0.112703614, 0...|
|XC134896.ogg_5|[0.9909997020094458, 1.8662...|[0.027517725, -0.052032292,...|
|XC164848.ogg_0|[0.16884706474015923, 5.010...|[-0.037761074, 0.0027916778...|
|XC164848.ogg_1|[0.3681332536829722, 1.0093...|[0.04028651, -0.0765781, 0....|
|XC164848.ogg_2|[0.514843868091535, 7.68173...|[0.020008465, -0.10142819, ...|
|XC164848.ogg_3|[1.4376290324810643E-4, 4.7...|[0.01

                                                                                

In [7]:
# Sanity check: every chunk should carry exactly one score per species.
n_species = len(species)
n_bad_rows = transformed_df.where(F.size("sigmoid_logits") != n_species).count()
assert n_bad_rows == 0, f"{n_bad_rows} rows do not have exactly {n_species} sigmoid scores"
len(transformed_df.first().sigmoid_logits)

182


### Load the processed training data

In [9]:
# Processed training set: id, sigmoid scores, embedding and sample_id per
# chunk (schema printed below).
root_path = "gs://dsgt-clef-birdclef-2024/data"
emb_path = "processed/birdclef-2024-train-google-embedding/data"
gcs_path = "/".join([root_path, emb_path])
# NOTE(review): this rebinds `df`, replacing the asbfly frame loaded above.
# Re-running the transform cell afterwards would fail, because this frame has
# no name/chunk_5s/logits columns.
df = spark.read.parquet(gcs_path)
df = df.cache()
df.printSchema()
df.show(n=20)
df.count()

root
 |-- id: string (nullable = true)
 |-- sigmoid_logits: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



                                                                                

+---------------+--------------------+--------------------+---------+
|             id|      sigmoid_logits|           embedding|sample_id|
+---------------+--------------------+--------------------+---------+
|XC724148.ogg_15|[0.01680772021377...|[-0.0068491795, 0...|        3|
|XC313835.ogg_10|[6.23079547323694...|[-0.012723277, -0...|        3|
|XC756431.ogg_13|[0.05752642597570...|[-0.062239904, -0...|        3|
|XC724266.ogg_29|[3.85630487725207...|[-0.049252495, 0....|        3|
|XC724148.ogg_28|[6.42797297697613...|[0.03280939, -0.0...|        3|
|XC821773.ogg_11|[0.00577954761659...|[0.019143827, 0.3...|        3|
|XC724266.ogg_11|[0.03024041125066...|[0.110734664, 0.0...|        3|
|XC342037.ogg_22|[7.67365900528079...|[0.0068388185, 0....|        3|
|XC374520.ogg_16|[0.54916403821089...|[0.026325822, 0.0...|        3|
|XC821773.ogg_14|[0.01075345895269...|[-0.016234195, 0....|        3|
|XC305518.ogg_18|[0.64899250523639...|[-0.072603844, 0....|        3|
|XC484489.ogg_14|[0.

924