In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from birdclef.utils import get_spark

spark = get_spark()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/01/03 01:40:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
# lets combine several tables to get all the metadata that we need

birdnet_analyze = spark.read.parquet(
    "../data/processed/birdclef-2022/birdnet-analyze.parquet"
)
birdnet_embeddings = spark.read.parquet(
    "../data/processed/birdclef-2022/birdnet-embeddings.parquet"
)
train_metadata = spark.read.csv(
    "../data/raw/birdclef-2022/train_metadata.csv", header=True
)

birdnet_analyze.printSchema()
birdnet_analyze.show(1, vertical=True, truncate=80)
birdnet_embeddings.printSchema()
birdnet_embeddings.show(1, vertical=True, truncate=80)
train_metadata.printSchema()
train_metadata.show(1, vertical=True, truncate=80)

root
 |-- start_sec: double (nullable = true)
 |-- end_sec: double (nullable = true)
 |-- confidence: double (nullable = true)
 |-- birdnet_label: string (nullable = true)
 |-- birdnet_common_name: string (nullable = true)
 |-- filename: string (nullable = true)

-RECORD 0-----------------------------------
 start_sec           | 0.0                  
 end_sec             | 3.0                  
 confidence          | 0.903                
 birdnet_label       | afrsil1              
 birdnet_common_name | African Silverbill   
 filename            | afrsil1/XC125458.ogg 
only showing top 1 row

root
 |-- start_sec: float (nullable = true)
 |-- end_sec: float (nullable = true)
 |-- filename: string (nullable = true)
 |-- emb: array (nullable = true)
 |    |-- element: float (containsNull = true)

-RECORD 0-------------------------------------------------------------------------------------
 start_sec | 0.0                                                                              
 e

In [7]:
birdnet_analyze.count(), birdnet_embeddings.count(), train_metadata.count()

(280222, 227660, 14852)

In [8]:
# TODO: why is there this difference?
birdnet_analyze.join(
    birdnet_embeddings, on=["filename", "start_sec", "end_sec"]
).count()

276866

In [13]:
from pyspark.sql import functions as F, Window

# joined, also with a canonical rank
birdnet_joined = (
    birdnet_analyze.join(birdnet_embeddings, on=["filename", "start_sec", "end_sec"])
    .join(
        train_metadata.select("filename", "primary_label", "secondary_labels", "type"),
        on="filename",
    )
    .orderBy("filename", "start_sec")
    .withColumn("id", F.monotonically_increasing_id())
)
birdnet_joined.show()



+--------------------+---------+-------+----------+-------------+--------------------+--------------------+-------------+--------------------+--------------------+---+
|            filename|start_sec|end_sec|confidence|birdnet_label| birdnet_common_name|                 emb|primary_label|    secondary_labels|                type| id|
+--------------------+---------+-------+----------+-------------+--------------------+--------------------+-------------+--------------------+--------------------+---+
|afrsil1/XC125458.ogg|      0.0|    3.0|     0.472|       indsil|   Indian Silverbill|[0.50174934, 0.45...|      afrsil1|                  []|['call', 'flight ...|  0|
|afrsil1/XC125458.ogg|      0.0|    3.0|     0.903|      afrsil1|  African Silverbill|[0.50174934, 0.45...|      afrsil1|                  []|['call', 'flight ...|  1|
|afrsil1/XC125458.ogg|      3.0|    6.0|    0.7311|      afrsil1|  African Silverbill|[1.3441721, 0.464...|      afrsil1|                  []|['call', 'flight .

                                                                                

In [14]:
# now lets generate nearest neighbors for each of the embeddings
# we'll use pynndescent to do this
import numpy as np

# convert the embedding to a numpy matrix that we can run nn-descent over
X = np.stack(birdnet_joined.select("emb").toPandas().emb)
X.shape

                                                                                

(276748, 320)

In [15]:
# size in gb
X.nbytes / 1024**3

0.3299093246459961

In [16]:
from pynndescent import NNDescent

n_neighbors = 20
index = NNDescent(X, n_neighbors=n_neighbors, verbose=True)

Tue Jan  3 02:06:11 2023 Building RP forest with 28 trees
Tue Jan  3 02:06:14 2023 NN descent for 18 iterations
	 1  /  18
	 2  /  18
	 3  /  18
	 4  /  18
	 5  /  18
	 6  /  18
	 7  /  18
	Stopping threshold met -- exiting after 7 iterations


In [25]:
neighbors, distances = index.query(X, k=n_neighbors)
neighbors.shape, distances.shape

((276748, 20), (276748, 20))

In [26]:
import pandas as pd


def index_query_to_pandas(neighbors, distances) -> pd.DataFrame:
    """Convert the output of index.query to a pandas dataframe."""
    res = []
    for origin, (neighbor, distance) in enumerate(zip(neighbors, distances)):
        for n, d in zip(neighbor, distance):
            res.append(dict(origin=origin, neighbor=n, distance=d))
    return pd.DataFrame(res)


index_query_to_pandas(neighbors, distances).head()

Unnamed: 0,origin,neighbor,distance
0,0,0,0.0
1,0,1,0.0
2,0,3,6.5489
3,0,198010,7.179422
4,0,150,7.291882


In [34]:
query_df = pd.DataFrame(
    dict(
        neighbors=neighbors.tolist(),
        distances=distances.tolist(),
        id=np.arange(neighbors.shape[0]),
    )
)

birdnet_neighbors = birdnet_joined.join(spark.createDataFrame(query_df), on="id")
birdnet_neighbors.show()

23/01/03 02:19:29 WARN TaskSetManager: Stage 78 contains a task of very large size (5680 KiB). The maximum recommended task size is 1000 KiB.


[Stage 85:===>                                                    (1 + 16) / 17]

+---+--------------------+---------+-------+----------+-------------+--------------------+--------------------+-------------+--------------------+--------------------+--------------------+--------------------+
| id|            filename|start_sec|end_sec|confidence|birdnet_label| birdnet_common_name|                 emb|primary_label|    secondary_labels|                type|           neighbors|           distances|
+---+--------------------+---------+-------+----------+-------------+--------------------+--------------------+-------------+--------------------+--------------------+--------------------+--------------------+
|  0|afrsil1/XC125458.ogg|      0.0|    3.0|     0.472|       indsil|   Indian Silverbill|[0.50174934, 0.45...|      afrsil1|                  []|['call', 'flight ...|[0, 1, 3, 198010,...|[0.0, 0.0, 6.5489...|
|  7|afrsil1/XC175522.ogg|      0.0|    3.0|     0.578|       indsil|   Indian Silverbill|[0.7229609, 1.923...|      afrsil1|['houspa', 'redav...|            ['

                                                                                