In [6]:
%load_ext autoreload
%autoreload 2
from animalclef.spark import get_spark
from pyspark.sql import functions as F, Window
from pathlib import Path
import numpy as np
from animalclef.dataset import split_reid_data, summarize_split

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Single root for all AnimalCLEF inputs; change it here if the shared data moves.
DATA_ROOT = Path.home() / "shared" / "animalclef" / "data"

spark = get_spark(cores=4, memory="10g")
display(spark)

metadata = spark.read.parquet(str(DATA_ROOT / "parquet" / "metadata"))
embeddings = spark.read.parquet(str(DATA_ROOT / "embeddings" / "dinov2"))

# Default (inner) join: only images present in both metadata and embeddings survive.
# "token.*" expands the embedding struct into its fields (cls, avg_patch).
# identity_count = number of joined rows sharing the same identity. Rows whose
# identity is null all land in ONE window partition, so for them the count is the
# total number of unlabelled rows, not a per-individual count.
df = metadata.join(embeddings, on="image_id").select(
    "image_id",
    "identity",
    "dataset",
    "token.*",
    F.count("image_id").over(Window.partitionBy("identity")).alias("identity_count"),
)
df.printSchema()
# Collects the full joined frame (including embedding arrays) onto the driver.
pdf = df.toPandas()

25/03/13 02:58:14 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


root
 |-- image_id: integer (nullable = true)
 |-- identity: string (nullable = true)
 |-- dataset: string (nullable = true)
 |-- cls: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- avg_patch: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- identity_count: long (nullable = false)



In [8]:
# Restrict to labelled images of individuals with more than two images,
# then split them into train / validation / test re-identification sets.
has_identity = pdf["identity"].notna()
enough_images = pdf["identity_count"].gt(2)
cond = has_identity & enough_images
train_df, val_df, test_df = split_reid_data(pdf.loc[cond])
summarize_split(train_df, val_df, test_df)

Unnamed: 0,Split,Num Individuals,Num Images,Train Image Overlap,Val Image Overlap,Test Image Overlap,Train Image %,Val Image %,Test Image %,Known Individuals,Unknown Individuals
0,Train,404,3392,3392,0,0,100.0,0.0,0.0,404,0
1,Validation,458,2575,0,2575,0,0.0,100.0,0.0,404,54
2,Test,620,6568,0,0,6568,0.0,0.0,100.0,404,216


### k-nearest-neighbour (k-NN) baseline on DINOv2 CLS embeddings

In [9]:
import faiss
from animalclef.metrics import BAKS, BAUS

X_train = np.stack(train_df.cls.values)
X_val = np.stack(val_df.cls.values)
X_test = np.stack(test_df.cls.values)

# Create a FAISS index for efficient nearest neighbor search
index = faiss.IndexFlatL2(X_train.shape[1])
index.add(X_train)  # Add training embeddings to the index

# Perform a search for the validation set
# use the nearest neighbor for now for voting
k = 1
# Distances and indices for validation set
dist_val, idx_val = index.search(X_val, k)
display(dist_val)
# Calculate the accuracy for validation and test sets
# do the actual prediction

# identities in val not in train
predictions_val = train_df.iloc[idx_val.flatten()]["identity"].values
identity_val_only = sorted(
    set(val_df.identity.unique()) - set(train_df.identity.unique())
)

display(
    BAKS(val_df["identity"].values, predictions_val, identity_val_only),
    BAUS(val_df["identity"].values, predictions_val, identity_val_only, "unknown"),
)

array([[ 321.9624 ],
       [ 952.66895],
       [1168.0774 ],
       ...,
       [1026.0625 ],
       [ 580.9336 ],
       [ 623.37305]], dtype=float32)

np.float64(0.30533625990631974)

np.float64(0.0)