In [3]:
from birdclef.utils import get_spark
from pyspark.sql import functions as F

# modify cores and memory as needed
spark = get_spark(cores=8, memory="16g")

path = "../../data/processed/birdclef-2023/consolidated_v3/"
df = spark.read.parquet(path)
df.printSchema()
df.count()

                                                                                

root
 |-- species: string (nullable = true)
 |-- track_stem: string (nullable = true)
 |-- track_type: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- prediction_vec: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- rank: long (nullable = true)
 |    |    |-- index: long (nullable = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- mapped_label: string (nullable = true)
 |    |    |-- probability: double (nullable = true)
 |-- start_time: long (nullable = true)
 |-- energy: double (nullable = true)



                                                                                

1198860

In [4]:
from pyspark.sql import Window, functions as F

# keep the track_type for each
highest_energy_channel = (
    df
    # get the track stem without the part
    .withColumn("original_track_stem", F.split(F.col("track_stem"), "_").getItem(0))
    .where("track_type != 'original'")
    # get the track type that has the most energy
    .withColumn(
        "rank",
        F.rank().over(
            Window.partitionBy("original_track_stem").orderBy(F.desc("energy"))
        ),
    )
    # keep the first row
    .where(F.col("rank") == 1)
    # drop the rank column
    .select("species", "track_stem", "track_type")
    .distinct()
)

highest_energy_channel.show()

[Stage 6:>                                                          (0 + 4) / 4]

+-------+----------+----------+
|species|track_stem|track_type|
+-------+----------+----------+
|refcro1|  XC239955|   source0|
|gyhspa1|  XC270259|   source0|
|yebbar1|  XC292826|   source2|
| reccor|  XC312724|   source0|
| litegr|  XC333916|   source2|
|gnbcam2|  XC395639|   source0|
|afpfly1|  XC418708|   source0|
|purgre2|  XC432646|   source2|
|combuz1|  XC463273|   source2|
|walsta1|  XC516711|   source0|
| hoopoe|  XC542705|   source2|
|laudov1|  XC558438|   source2|
|blakit1|  XC572730|   source0|
|afrjac1|  XC585200|   source0|
| comsan|  XC595918|   source2|
|blbpuf2|  XC633870|   source2|
|yertin1|  XC634144|   source1|
|eaywag1|  XC642065|   source0|
|afghor1|  XC720728|   source2|
|afrgos1|  XC147873|   source3|
+-------+----------+----------+
only showing top 20 rows



                                                                                

In [9]:
with open("../../data/models/birdnet-analyzer-pruned/labels.txt") as f:
    labels = [x.strip() for x in f.readlines()]



In [30]:
import json
with open('../../data/models/birdnet-analyzer-pruned/eBird_taxonomy_codes_2021E.json') as json_file:
    mapped = json.load(json_file)

In [43]:
noise = [(i, x) for i,x in enumerate(labels) if "human" in x.lower() or len(x.split(" ")) < 2]
index = [x[0] for x in noise]
print(noise)
print(index)

[(1022, 'Dog_Dog'), (1136, 'Engine_Engine'), (1141, 'Environmental_Environmental'), (1219, 'Fireworks_Fireworks'), (1352, 'Gun_Gun'), (1449, 'Human non-vocal_Human non-vocal'), (1450, 'Human vocal_Human vocal'), (1451, 'Human whistle_Human whistle'), (1997, 'Noise_Noise'), (2812, 'Siren_Siren')]
[1022, 1136, 1141, 1219, 1352, 1449, 1450, 1451, 1997, 2812]


In [42]:
label_df = spark.createDataFrame(
    [
        {
            "label": label,
            "mapped_label": mapped_label,
            "index": i,
        }
        for i, (label, mapped_label) in enumerate(zip(labels, mapped))
    ]
)
label_df.show(n=5)

[Stage 10:>                                                         (0 + 1) / 1]

+-----+--------------------+--------------------+
|index|               label|        mapped_label|
+-----+--------------------+--------------------+
|    0|Abroscopus albogu...|             ostric2|
|    1|Abroscopus superc...|Struthio camelus_...|
|    2|Aburria aburri_Wa...|             ostric3|
|    3|Acanthagenys rufo...|Struthio molybdop...|
|    4|Acanthis cabaret_...|             grerhe1|
+-----+--------------------+--------------------+
only showing top 5 rows



                                                                                

In [47]:
temp = (
    df
    # explode the predictions with their indices
    .select(
        "species",
        "track_name",
        "start_time",
        F.posexplode("prediction_vec").alias("index", "logit"),
    )
    # join with the labels, in case we want to use it for anything
    .join(label_df, on="index", how="inner")
    # now only keep human vocals
    .where("index in (1022, 1136, 1141, 1219, 1352, 1449, 1450, 1451, 1997, 2812)")
    # and convert the logit to a probability via sigmoid
    .withColumn("probability", F.expr("1/(1+exp(-logit))"))
)

temp.show(n=5)

                                                                                

+-----+-------+--------------------+----------+-------------------+-----------+------------+--------------------+
|index|species|          track_name|start_time|              logit|      label|mapped_label|         probability|
+-----+-------+--------------------+----------+-------------------+-----------+------------+--------------------+
| 2812| grecor|grecor/XC629875_p...|        69| -9.145736694335938|Siren_Siren|     stther2|1.066622388732338E-4|
| 2812| grecor|grecor/XC629875_p...|        33| -13.65604305267334|Siren_Siren|     stther2|1.172884773347008...|
| 2812| wlwwar|wlwwar/XC475384_p...|        54|-11.956526756286621|Siren_Siren|     stther2|6.417171116191094E-6|
| 2812| grecor|grecor/XC629875_p...|       126| -8.804743766784668|Siren_Siren|     stther2|1.499972233170989E-4|
| 2812| grecor|grecor/XC629875_p...|       126| -15.09192180633545|Siren_Siren|     stther2|2.790368219906542E-7|
+-----+-------+--------------------+----------+-------------------+-----------+---------

In [49]:
high_prob = temp.where("probability > 0.5")
high_prob.show(n=5)
print(high_prob.count())

                                                                                

+-----+-------+--------------------+----------+-------------------+--------------------+------------+------------------+
|index|species|          track_name|start_time|              logit|               label|mapped_label|       probability|
+-----+-------+--------------------+----------+-------------------+--------------------+------------+------------------+
| 1136| grecor|grecor/XC629875_p...|        39| 0.2663832902908325|       Engine_Engine|     stodov1|0.5662047934675343|
| 1136|combuz1|combuz1/XC579931_...|        51| 1.8115729093551636|       Engine_Engine|     stodov1| 0.859551867154321|
| 1136| grecor|grecor/XC505211_s...|        84| 0.5453298687934875|       Engine_Engine|     stodov1|0.6330514076807028|
| 1450|thrnig1|thrnig1/XC412630_...|        51| 1.8050605058670044|Human vocal_Human...|     obqdov1|0.8587638313808836|
| 1450|afrthr1|afrthr1/XC652884_...|       108|0.11409962177276611|Human vocal_Human...|     obqdov1|0.5284939991900475|
+-----+-------+-----------------

ERROR:root:KeyboardInterrupt while sending command.                 (0 + 0) / 4]
Traceback (most recent call last):
  File "/home/nzhon/.local/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/home/nzhon/.local/lib/python3.10/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/usr/lib/python3.10/socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

