In [None]:
# The two submissions to compare. Their filenames differ only in the `minscore_*`
# component (presumably two minimum-score thresholds — confirm against the
# pipeline that produced them).
LHS_SUBMISSION_PATH = "./data/submission_5h1l_tile_4_5_overlaps_0_0_topk_1_minscore_002_topn_10_use_gf_crop_010.csv"
RHS_SUBMISSION_PATH = "./data/submission_5h1l_tile_4_5_overlaps_0_0_topk_1_minscore_0015_topn_10_use_gf_crop_010.csv"
submission_files = [LHS_SUBMISSION_PATH, RHS_SUBMISSION_PATH]

# Every downstream cell is strictly pairwise (one left, one right submission).
assert len(submission_files) == 2

In [None]:
from collections import defaultdict
import csv
import pandas as pd
import ast

# Read each submission CSV into its own DataFrame, keeping submission_files order.
# `species_ids` is stored as a Python-literal string (typed downstream as a list
# of ints); ast.literal_eval decodes it without evaluating arbitrary code.
dfs = list(
    pd.read_csv(csv_path, converters={"species_ids": ast.literal_eval})
    for csv_path in submission_files
)

In [None]:
# Group predictions by quadrat: each quadrat_id maps to a list holding one
# species-ID list per submission, appended in the same order as `dfs`.
predictions: dict[str, list[list[int]]] = defaultdict(list)
for submission_df in dfs:
    rows = zip(submission_df["quadrat_id"], submission_df["species_ids"])
    for qid, ids in rows:
        predictions[qid].append(ids)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def _binary_cosine(lhs_ids, rhs_ids) -> float:
    """Cosine similarity of two species-ID collections viewed as binary presence vectors.

    For sets A and B this is exactly |A ∩ B| / sqrt(|A| * |B|); duplicate IDs are
    ignored, as they were in the original vector construction.

    Edge cases (the vector form is undefined / crashes in sklearn here):
      * both empty     -> 1.0 (the two submissions agree exactly: no species)
      * exactly one empty -> 0.0 (no overlap; matches sklearn's zero-norm result)
    """
    lhs_set, rhs_set = set(lhs_ids), set(rhs_ids)
    if not lhs_set and not rhs_set:
        return 1.0
    if not lhs_set or not rhs_set:
        return 0.0
    return len(lhs_set & rhs_set) / (len(lhs_set) * len(rhs_set)) ** 0.5


# Per-quadrat agreement between the two submissions.
similarities: dict[str, float] = {}
for quadrat_id, species_ids_list in predictions.items():
    # A quadrat missing from one file (or duplicated within one) breaks the pairing.
    assert len(species_ids_list) == 2, (
        f"quadrat {quadrat_id!r} has {len(species_ids_list)} prediction rows; "
        "expected exactly one per submission"
    )
    lhs_species_ids, rhs_species_ids = species_ids_list
    similarities[quadrat_id] = _binary_cosine(lhs_species_ids, rhs_species_ids)

In [None]:
# Compute and print summary statistics of the per-quadrat similarities.
import statistics

# Materialize once; every statistic below reads the same sequence.
similarity_values = list(similarities.values())

# variance, stdev and quantiles all require at least two data points.
if len(similarity_values) < 2:
    raise statistics.StatisticsError(
        f"need at least 2 quadrats to summarize similarities, got {len(similarity_values)}"
    )

mean_similarity = statistics.mean(similarity_values)
median_similarity = statistics.median(similarity_values)
min_similarity = min(similarity_values)
max_similarity = max(similarity_values)
variance_similarity = statistics.variance(similarity_values)
stdev_similarity = statistics.stdev(similarity_values)

# Interquartile range (IQR). Quartiles use linear interpolation between order
# statistics (method="inclusive", the same convention as numpy.percentile's
# default) instead of a raw index lookup, which biased q3 upward for small n.
quartiles = statistics.quantiles(similarity_values, n=4, method="inclusive")
q1, q3 = quartiles[0], quartiles[2]
iqr_similarity = q3 - q1

print(f"Mean Similarity: {mean_similarity}")
print(f"Median Similarity: {median_similarity}")
print(f"Min Similarity: {min_similarity}")
print(f"Max Similarity: {max_similarity}")
print(f"Variance: {variance_similarity}")
print(f"Standard Deviation: {stdev_similarity}")
print(f"Interquartile Range (IQR): {iqr_similarity}")