# Transforming data for Torch model

In [5]:
%load_ext autoreload
%autoreload 2
from birdclef.utils import get_spark
from pyspark.sql import functions as F

spark = get_spark()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/10 16:06:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/06/10 16:06:06 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [2]:
path = "gs://dsgt-clef-birdclef-2024/data/processed/birdclef-2024/asbfly.parquet"
df = spark.read.parquet(path).cache()
df.printSchema()
df.show()
df.count()

                                                                                

root
 |-- name: string (nullable = true)
 |-- chunk_5s: long (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- logits: array (nullable = true)
 |    |-- element: float (containsNull = true)



                                                                                

+------------+--------+--------------------+--------------------+
|        name|chunk_5s|           embedding|              logits|
+------------+--------+--------------------+--------------------+
|XC134896.ogg|       0|[-0.01697171, -0....|[4.367853, -15.69...|
|XC134896.ogg|       1|[0.08178271, -0.1...|[5.7584195, -14.0...|
|XC134896.ogg|       2|[0.15756801, -0.1...|[5.41736, -14.507...|
|XC134896.ogg|       3|[0.07789261, -0.1...|[7.383127, -14.17...|
|XC134896.ogg|       4|[0.0338157, -0.11...|[5.162613, -13.20...|
|XC134896.ogg|       5|[0.027517725, -0....|[4.7014565, -15.4...|
|XC164848.ogg|       0|[-0.037761074, 0....|[-1.5938205, -12....|
|XC164848.ogg|       1|[0.04028651, -0.0...|[-0.54023355, -11...|
|XC164848.ogg|       2|[0.020008465, -0....|[0.059392925, -11...|
|XC164848.ogg|       3|[0.011329669, 0.0...|[-8.847201, -14.5...|
|XC175797.ogg|       0|[-0.06643575, 0.0...|[6.4482026, -7.80...|
|XC175797.ogg|       1|[-2.4496057E-4, 0...|[3.09487, -9.9500...|
|XC175797.

924

In [3]:
from birdclef.transforms import SpeciesData

sp = SpeciesData()
species = sp.read_species_from_file()

In [4]:
def sigmoid_udf(x):
    """
    compute the sigmoid of the columns and only keep
    the rows that are above a certain threshold
    """
    return 1 / (1 + F.exp(-F.col(x)))


def transform(df):
    exploded = (
        df.select(
            F.concat_ws("_", "name", "chunk_5s").alias("id"),
            F.arrays_zip("logits", F.lit(species).alias("species")).alias("logits"),
        )
        .withColumn("exploded", F.explode("logits"))
        .select("id", "exploded.*")
        .withColumn("sigmoid", sigmoid_udf("logits"))
    ).cache()
    # get create column of an array of sigmoid logits grouped by id
    grouped_df = (
        exploded.orderBy("id", "species")
        .groupBy("id")
        .agg(F.collect_list(F.col("sigmoid")).alias("sigmoid_logits"))
    )
    # get embeddings
    df_id = df.select(F.concat_ws("_", "name", "chunk_5s").alias("id"), "embedding")
    final_df = grouped_df.join(df_id, on="id", how="inner")
    return final_df

In [5]:
# transform DF
transformed_df = transform(df)
transformed_df.show(truncate=30)

[Stage 11:>                                                         (0 + 2) / 2]

+--------------+------------------------------+------------------------------+
|            id|                sigmoid_logits|                     embedding|
+--------------+------------------------------+------------------------------+
|XC134896.ogg_0|[0.9874803002619361, 1.5216...|[-0.01697171, -0.14812551, ...|
|XC134896.ogg_1|[0.9968538337772187, 8.0886...|[0.08178271, -0.10625668, 0...|
|XC134896.ogg_2|[0.9955807658044297, 5.0082...|[0.15756801, -0.12693903, 0...|
|XC134896.ogg_3|[0.9993787324137937, 6.9467...|[0.07789261, -0.10598332, 0...|
|XC134896.ogg_4|[0.9943058918488539, 1.8360...|[0.0338157, -0.112703614, 0...|
|XC134896.ogg_5|[0.9909997020094458, 1.8662...|[0.027517725, -0.052032292,...|
|XC164848.ogg_0|[0.16884706474015923, 5.010...|[-0.037761074, 0.0027916778...|
|XC164848.ogg_1|[0.3681332536829722, 1.0093...|[0.04028651, -0.0765781, 0....|
|XC164848.ogg_2|[0.514843868091535, 7.68173...|[0.020008465, -0.10142819, ...|
|XC164848.ogg_3|[1.4376290324810643E-4, 4.7...|[0.01

                                                                                

In [6]:
print(len(transformed_df.first().sigmoid_logits))
print(len(transformed_df.first().embedding))

182
1280


### import intermediate google embeddings data

In [7]:
root_path = "gs://dsgt-clef-birdclef-2024/data"
emb_path = "intermediate/google_embeddings/v1"
gcs_path = f"{root_path}/{emb_path}"
df = spark.read.parquet(gcs_path).cache()
df.printSchema()
df.show()
df.count()

root
 |-- name: string (nullable = true)
 |-- chunk_5s: long (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- logits: array (nullable = true)
 |    |-- element: float (containsNull = true)





+--------------------+--------+--------------------+--------------------+
|                name|chunk_5s|           embedding|              logits|
+--------------------+--------+--------------------+--------------------+
|blrwar1/XC108661.ogg|       0|[-0.0023240745, 0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       1|[-0.06323654, 0.0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       2|[-0.10394307, 0.0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       3|[-0.02910586, 0.0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       4|[-0.072895475, 0....|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       5|[-0.07493828, -0....|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       6|[-0.033370227, 0....|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       7|[-0.05588576, 0.0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       8|[-0.025152182, 0....|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|       9|[-0.0033844702, 0...|[-Infinity, -Infi...|
|blrwar1/XC108661.ogg|      10|[-0.052

                                                                                

217814

In [8]:
# get parameters for the model
num_features = int(len(df.select("embedding").first()["embedding"]))
num_classes = int(len(df.select("sigmoid_logits").first()["sigmoid_logits"]))
print(f"num features: {num_features}")
print(f"num classes: {num_classes}")

num features: 1280
num classes: 182


In [6]:
import itertools


class HyperparameterGrid:
    def get_hyperparameter_config(self):
        # Model and Loss mappings
        model_params = {
            "linear": "LinearClassifier",
            "two_layer": "TwoLayerClassifier",
        }
        loss_params = {
            "bce": {},
            "asl": {
                "gamma_neg": [0, 2, 4],
                "gamma_pos": [0, 1],
            },
            "sigmoidf1": {
                "S": [-1, -15, -30],
                "E": [0, 1, 2],
            },
        }
        hidden_layers = [64, 128, 256]
        return model_params, loss_params, hidden_layers


def generate_loss_hp_params(loss_params):
    """Generate all combinations of hyperparameters for a given loss function."""
    if not loss_params:
        return [{}]

    keys, values = zip(*loss_params.items())
    combinations = [
        dict(zip(keys, combination)) for combination in itertools.product(*values)
    ]
    return combinations

In [7]:
hp = HyperparameterGrid()
model_params, loss_params, hidden_layers = hp.get_hyperparameter_config()
default_root_dir = "torch-v1-google"
hidden_layer_size = 256

for loss in loss_params:
    for hp_params in generate_loss_hp_params(loss_params[loss]):
        param_log = [f"{k}{v}" for k, v in hp_params.items()]
        if len(param_log) > 0:
            param_name = "-".join(param_log)
            print(param_name)

gamma_neg0-gamma_pos0
gamma_neg0-gamma_pos1
gamma_neg2-gamma_pos0
gamma_neg2-gamma_pos1
gamma_neg4-gamma_pos0
gamma_neg4-gamma_pos1
S-1-E0
S-1-E1
S-1-E2
S-15-E0
S-15-E1
S-15-E2
S-30-E0
S-30-E1
S-30-E2


### Add `1` for the current label of the species

In [9]:
root_path = "gs://dsgt-clef-birdclef-2024/data"
emb_path = "processed/google_embeddings/v1"
gcs_path = f"{root_path}/{emb_path}"
df = spark.read.parquet(gcs_path).cache()
df.printSchema()
df.show()
df.count()

                                                                                

root
 |-- name: string (nullable = true)
 |-- chunk_5s: long (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- logits: array (nullable = true)
 |    |-- element: float (containsNull = true)



                                                                                

+--------------------+--------+--------------------+--------------------+
|                name|chunk_5s|           embedding|              logits|
+--------------------+--------+--------------------+--------------------+
|blrwar1/XC650323.ogg|       5|[0.0041362955, 0....|[-Infinity, -Infi...|
|blrwar1/XC560248.ogg|      10|[-0.037416093, 0....|[-Infinity, -Infi...|
|blrwar1/XC570070.ogg|      11|[-0.033473324, 0....|[-Infinity, -Infi...|
|blrwar1/XC826779.ogg|     108|[0.0074216262, -0...|[-Infinity, -Infi...|
|blrwar1/XC134920.ogg|      23|[-0.028284838, -0...|[-Infinity, -Infi...|
|blrwar1/XC431699.ogg|       3|[-0.06691173, 0.3...|[-Infinity, -Infi...|
|blrwar1/XC826057.ogg|     520|[0.037868567, -0....|[-Infinity, -Infi...|
|blrwar1/XC270114.ogg|       4|[0.20961681, -0.1...|[-Infinity, -Infi...|
|blrwar1/XC808169.ogg|      17|[-0.006478134, -0...|[-Infinity, -Infi...|
|blrwar1/XC480057.ogg|       0|[0.04627442, -0.1...|[-Infinity, -Infi...|
|blrwar1/XC826043.ogg|     228|[-0.002

217814

In [8]:
from birdclef.config import SPECIES
from birdclef.transforms import SpeciesData

sp1 = SPECIES
sp2 = SpeciesData().get_species()
print(len(sp1))
print(len(sp2))

182
182


In [29]:
from pathlib import Path


# UDF to retrieve species inddex
def species_index(name):
    species = Path(name).parent.name
    species_index = SPECIES.index(species)
    return species_index


# get species
filename = "blrwar1/XC650323.ogg"
sp_idx = species_index(name=filename)
sp_idx

20

In [44]:
from pyspark.sql.types import IntegerType

# Register the UDF
species_index_udf = F.udf(species_index, IntegerType())
df_species = df.withColumn("species_idx", species_index("name"))
df_species.show()

ValueError: '' is not in list

In [38]:
print(
    f"\ntrain count: {df_species.count()}\n" f"valid count: {df_species.count()}\n",
)


train count: 217814
valid count: 217814



In [43]:
import random

name = f"{random.choice(SPECIES)}/XC123456.ogg"
name

'goflea1/XC123456.ogg'