In [1]:
# reload edited project modules automatically so changes are picked up live
%load_ext autoreload
%autoreload 2
# auto-format code cells with black
%load_ext lab_black

## Binary Classifier (call/no call)

In [2]:
from birdclef.utils import get_spark

# Spark session requested with 16 cores and 20g of memory via the project helper.
spark = get_spark(cores=16, memory="20g")

# Consolidated training embeddings. A "consolidated_v4" directory was also tried
# here; point the path at it to switch dataset versions.
embeddings_path = "../data/processed/birdclef-2023/train_embeddings/consolidated_v3"
df = spark.read.parquet(embeddings_path)
df.printSchema()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/20 16:13:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/05/20 16:13:23 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


                                                                                

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)



### 1. Data Preprocessing

#### 1.1 Positive labels

In [3]:
from pyspark.sql import Window, functions as F

# Find, for every recording, which non-original track_type carries the most energy.

# the text before the first "_" in track_stem identifies the recording
original_track_stem = F.split(F.col("track_stem"), "_").getItem(0)
# rows of one recording, loudest first
energy_window = Window.partitionBy("original_track_stem").orderBy(
    F.col("energy").desc()
)

highest_energy_channel = (
    # the unseparated "original" track is never a candidate
    df.filter(F.col("track_type") != "original")
    .withColumn("original_track_stem", original_track_stem)
    # rank() keeps ties, so several rows may share rank 1; distinct() below
    # collapses them to one row per (species, track_stem, track_type)
    .withColumn("energy_rank", F.rank().over(energy_window))
    .filter(F.col("energy_rank") == 1)
    # NOTE(review): the full track_stem is kept here, not the recording prefix,
    # so only stems that reach the recording-wide maximum survive - confirm
    # this is intended when a recording is split into several stems.
    .select("species", "track_stem", "track_type")
    .distinct()
)

# Pair every embedding on the highest-energy channel with its rank-0 prediction.

# restrict to the (species, track_stem, track_type) triples selected earlier
top_channel_rows = df.join(
    highest_energy_channel,
    on=["species", "track_stem", "track_type"],
    how="inner",
)

# columns carried through unchanged next to the exploded prediction fields
segment_columns = [
    "species",
    "track_stem",
    "track_type",
    "start_time",
    "track_name",
    "embedding",
]

exploded_embeddings = (
    # one output row per entry of the predictions array
    top_channel_rows.select(
        *segment_columns, F.explode("predictions").alias("prediction")
    )
    # flatten the struct into rank / index / label / mapped_label / probability
    .select(*segment_columns, "prediction.*")
    # simplifying assumption: the rank-0 prediction (treated as the most
    # confident one) is taken as the true label
    .filter(F.col("rank") == 0)
    .cache()
)

exploded_embeddings.drop("embedding").show(n=5)

[Stage 4:>                                                          (0 + 1) / 1]

+-------+----------+----------+----------+--------------------+----+-----+--------------------+------------+--------------------+
|species|track_stem|track_type|start_time|          track_name|rank|index|               label|mapped_label|         probability|
+-------+----------+----------+----------+--------------------+----+-----+--------------------+------------+--------------------+
|abythr1|  XC233199|   source0|         0|abythr1/XC233199_...|   0|  639|Chloropsis hardwi...|     orblea1|0.002208352088928...|
|abythr1|  XC233199|   source0|        57|abythr1/XC233199_...|   0| 1151|Erpornis zanthole...|     whbyuh1|0.025502817705273628|
|abythr1|  XC233199|   source0|        27|abythr1/XC233199_...|   0| 3164|Turdus abyssinicu...|     abythr1|0.024902962148189545|
|abythr1|  XC233199|   source0|        30|abythr1/XC233199_...|   0|  639|Chloropsis hardwi...|     orblea1|0.012038093991577625|
|abythr1|  XC233199|   source0|        21|abythr1/XC233199_...|   0| 3185|Turdus leucomela

                                                                                

In [4]:
# rows per species in exploded_embeddings, most frequent species first
counts = (
    exploded_embeddings.groupBy("species")
    .count()
    .withColumnRenamed("count", "n")
    .sort(F.col("n").desc())
)
# five largest classes, then the five smallest
counts.show(n=5)
counts.sort(F.col("n").asc()).show(n=5)

                                                                                

+-------+-----+
|species|    n|
+-------+-----+
|thrnig1|12987|
| wlwwar| 9249|
|combuz1| 7173|
| hoopoe| 6731|
| barswa| 6191|
+-------+-----+
only showing top 5 rows





+-------+---+
|species|  n|
+-------+---+
|afpkin1|  3|
|whhsaw1|  4|
|whctur2|  4|
|golher1|  5|
|lotlap1|  8|
+-------+---+
only showing top 5 rows



                                                                                

In [5]:
# Prepared DF: positive examples, thresholded on the rank-0 probability with a
# stricter cut-off for well-represented species than for rare ones.

# a species with fewer than this many rows in exploded_embeddings counts as rare
rarity_min_count = 100
# minimum probability a row needs to be kept, per group
common_min_probability = 0.4
rare_min_probability = 0.05

rare_species_count = (
    exploded_embeddings.groupBy("species")
    .agg(F.count("*").alias("n"))
    .where(F.col("n") < rarity_min_count)
)
rare_species_count.show(n=5)
rare_species_names = rare_species_count.select("species")

# if there are a lot of examples, we can use a higher threshold
common_species = exploded_embeddings.where(
    F.col("probability") > common_min_probability
).join(rare_species_names, on="species", how="left_anti")
# these ones are less common so we use a lower threshold, aiming to keep at
# least one example for each species
rare_species = exploded_embeddings.where(
    F.col("probability") > rare_min_probability
).join(rare_species_names, on="species", how="inner")

# Project both sides to the same columns and combine by name: a positional
# union would silently depend on the column order each join type happens to emit.
prepared_columns = ["species", "probability", "embedding"]
prepared = common_species.select(*prepared_columns).unionByName(
    rare_species.select(*prepared_columns)
)
prepared.show(n=5)
prepared.count()

                                                                                

+-------+---+
|species|  n|
+-------+---+
|purgre2| 60|
|bubwar2| 90|
|rehwea1| 69|
|kvbsun1| 80|
|equaka1| 63|
+-------+---+
only showing top 5 rows



                                                                                

+-------+------------------+--------------------+
|species|       probability|           embedding|
+-------+------------------+--------------------+
|afghor1|0.9965255856513977|[0.57833033800125...|
|afghor1| 0.511886715888977|[1.00166213512420...|
|afghor1|0.9984956979751587|[0.88829582929611...|
|afghor1|0.9988522529602051|[1.26016914844512...|
|afghor1|0.9997662901878357|[1.16302716732025...|
+-------+------------------+--------------------+
only showing top 5 rows



                                                                                

74490

In [6]:
# sanity check: how many classes survived the thresholds, and how many
# examples each one has
prepared_counts = (
    prepared.groupBy("species")
    .count()
    .withColumnRenamed("count", "n")
    .sort(F.col("n").desc())
)
n_species = prepared_counts.count()
print(f"number of species {n_species}")

# largest classes first, then the smallest
prepared_counts.show(n=5)
prepared_counts.sort("n").show(n=5)

                                                                                

number of species 264


                                                                                

+-------+----+
|species|   n|
+-------+----+
|thrnig1|3833|
| hoopoe|3822|
|eubeat1|3116|
| wlwwar|2687|
| barswa|2603|
+-------+----+
only showing top 5 rows



                                                                                

+-------+---+
|species|  n|
+-------+---+
|afpkin1|  2|
|whctur2|  2|
|rehblu1|  2|
|whhsaw1|  3|
|easmog1|  4|
+-------+---+
only showing top 5 rows



#### 1.2 Negative labels

In [8]:
# Negative calls: (position in prediction_vec, label) of the non-bird classes
# whose strongest segments are used as "no_call" examples.
noise_classes = [
    (1022, "Dog_Dog"),
    (1136, "Engine_Engine"),
    (1141, "Environmental_Environmental"),
    (1219, "Fireworks_Fireworks"),
    (1352, "Gun_Gun"),
    (1449, "Human non-vocal_Human non-vocal"),
    (1450, "Human vocal_Human vocal"),
    (1451, "Human whistle_Human whistle"),
    (1997, "Noise_Noise"),
    (2812, "Siren_Siren"),
]
# plain list of positions; kept under its own name so the labelled pairs above
# are not overwritten by a value of a different shape
noise_indices = [index for index, _label in noise_classes]
# number of top-scoring segments kept for each noise class
negatives_per_class = 250

# Strongest response first within each noise class. Ordering by the raw logit
# gives the same order as ordering by its sigmoid (the sigmoid is monotonic)
# without the ties that appear once the sigmoid saturates to 1.0; track_name and
# start_time break any remaining ties so the cut-off is reproducible across runs.
strongest_first = Window.partitionBy("index").orderBy(
    F.desc("logit"), "track_name", "start_time"
)

# Create negative samples DF
# NOTE(review): a segment that ranks in the top of several noise classes is
# emitted once per class, so duplicates are possible - confirm this is intended.
negative_samples = (
    df
    # explode the predictions with their indices
    .select(
        "track_name",
        "start_time",
        "embedding",
        F.posexplode("prediction_vec").alias("index", "logit"),
    )
    .where(F.col("index").isin(noise_indices))
    # keep the top rows of each class
    .withColumn("rank", F.row_number().over(strongest_first))
    .where(F.col("rank") <= negatives_per_class)
    .select("track_name", "start_time", F.lit("no_call").alias("species"), "embedding")
).cache()
negative_samples.show()



+--------------------+----------+-------+--------------------+
|          track_name|start_time|species|           embedding|
+--------------------+----------+-------+--------------------+
|wlwwar/XC237179_p...|        90|no_call|[0.74351072311401...|
|moccha1/XC382190_...|        27|no_call|[0.56722432374954...|
|somgre1/XC476125_...|        84|no_call|[0.65573865175247...|
|ratcis1/XC307172_...|        30|no_call|[0.50343912839889...|
|yertin1/XC633722_...|         3|no_call|[0.80654186010360...|
|augbuz1/XC493837_...|         6|no_call|[0.82505774497985...|
|cohmar1/XC749488_...|        60|no_call|[1.11857354640960...|
|thrnig1/XC494901_...|        45|no_call|[0.46692180633544...|
|egygoo/XC613342_s...|        63|no_call|[0.93724751472473...|
|ratcis1/XC307172_...|         9|no_call|[0.78493213653564...|
|litegr/XC332727_p...|        93|no_call|[0.41296926140785...|
|thrnig1/XC467330_...|        36|no_call|[0.56190323829650...|
|colsun2/XC188947_...|         6|no_call|[0.61920136213

                                                                                

In [9]:
# total number of "no_call" rows selected across all noise classes
negative_samples.count()

[Stage 185:=>                                                     (5 + 4) / 200]