In [None]:
import os
import numpy as np
from numpy.typing import ArrayLike
import pandas as pd
import csv

In [None]:
# Mean number of predicted species per quadrat that the threshold search aims for.
TARGET_MEAN_PREDICTION_LENGTH = 5.0

# Lower and upper end of the interval searched for the minimum tile probability.
MIN_MIN_SCORE = 0.001
MAX_MIN_SCORE = 1.0

In [None]:
# Directories containing the .npy files (one `<quadrat_id>.npy` file per quadrat).
# Every directory contributes one entry to `all_quadrat_probs`.
directories = [
    './predictions/probabilities_submission_5h1l_tile_4_5_overlaps_0_0_use_gf_crop_010',
    './predictions/probabilities_submission_hydra_5h1l_s_5h2l_gf_tile_45_00_usegf_crop10',
]

# One dict per directory, in the same order as `directories`.
all_quadrat_probs: list[dict[str, ArrayLike]] = []

for directory in directories:
    # Maps quadrat id (file name without its extension) -> array loaded from disk.
    quadrat_probs: dict[str, ArrayLike] = {}

    # os.listdir() returns entries in arbitrary order; sorting keeps the dict
    # order, and therefore everything derived from it, reproducible across runs.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.npy'):
            continue
        # splitext() removes only the final extension, so an id that itself
        # contains a dot is not truncated at its first dot.
        quadrat_id = os.path.splitext(filename)[0]
        quadrat_probs[quadrat_id] = np.load(os.path.join(directory, filename))

    all_quadrat_probs.append(quadrat_probs)

In [None]:
from src import data

# Load the training-set image metadata; the helper returns two values, the
# second of which is kept as `rare_species`.
plant_data_image_info, rare_species = data.get_plant_data_image_info(
    os.path.join(
        "/mnt/storage1/shared_data/plant_clef_2025/",
        "data/",
        "plant_clef_train_281gb/",
    ),
)

# Class index -> species id: indices follow the sorted order of the distinct
# species ids found in the metadata.
species_index_to_id = dict(
    enumerate(sorted({info.species_id for info in plant_data_image_info}))
)

# Inverse lookup: species id -> class index.
species_id_to_index = {sid: idx for idx, sid in species_index_to_id.items()}

In [None]:
def mean_prediction_length(image_predictions: dict[str, list[int]]) -> float:
    """Return the average number of predictions per entry.

    Args:
        image_predictions: Mapping from an id to its list of predictions.

    Returns:
        The arithmetic mean of the list lengths, as computed by ``np.mean``.
    """
    prediction_counts = list(map(len, image_predictions.values()))
    return np.mean(prediction_counts)

In [None]:
# Every loaded directory must cover exactly the same quadrat ids before averaging.
reference_ids = all_quadrat_probs[0].keys()
assert all(member.keys() == reference_ids for member in all_quadrat_probs[1:]), "All quadrat probabilities should have the same keys"

# Unweighted ensemble: each member's array is scaled by 1 / ensemble size and
# the scaled arrays are summed per quadrat.
ensemble_size = len(all_quadrat_probs)
quadrat_probs = {
    quadrat_id: sum(member[quadrat_id] / ensemble_size for member in all_quadrat_probs)
    for quadrat_id in reference_ids
}

def make_prediction(
    min_probability: float,
    probabilities: "dict[str, ArrayLike] | None" = None,
) -> dict[str, list[int]]:
    """Select the predicted class indices for every quadrat.

    For each tile (each entry along the first axis of a quadrat's array) the
    highest-scoring class is kept when its score is not below
    ``min_probability``. A quadrat for which no tile passes the threshold falls
    back to the single class with the largest score summed over all tiles, so
    every quadrat receives at least one prediction.

    Args:
        min_probability: Minimum score a tile's top class needs to be kept.
        probabilities: Mapping quadrat id -> per-tile score array (assumed to be
            shaped (n_tiles, n_classes); the fallback sums over axis 0). When
            omitted, the module-level ``quadrat_probs`` is used, which keeps
            existing one-argument calls working.

    Returns:
        Mapping quadrat id -> list of distinct class indices as builtin ``int``.
    """
    if probabilities is None:
        # Backward-compatible default: read the notebook-level ensemble.
        probabilities = quadrat_probs

    image_predictions: dict[str, list[int]] = {}

    for quadrat_id, tile_probabilities in probabilities.items():
        top_species: set[int] = set()
        for tile_probs in tile_probabilities:
            # int() turns the NumPy integer into a builtin int so the result
            # matches the declared return type.
            max_index = int(tile_probs.argmax())
            if tile_probs[max_index] < min_probability:
                continue
            top_species.add(max_index)

        if not top_species:
            # No tile was confident enough: use the class with the largest
            # total score across all tiles instead of returning nothing.
            column_sums = np.sum(tile_probabilities, axis=0)
            top_species.add(int(column_sums.argmax()))

        image_predictions[quadrat_id] = list(top_species)

    return image_predictions

In [None]:
# Bracket for the bisection. Each end is a (threshold, mean prediction length)
# pair: the first end holds the high threshold (fewer predictions), the second
# the low threshold (more predictions).
strict_end = (MAX_MIN_SCORE, mean_prediction_length(make_prediction(MAX_MIN_SCORE)))
lenient_end = (MIN_MIN_SCORE, mean_prediction_length(make_prediction(MIN_MIN_SCORE)))
bound = (strict_end, lenient_end)

# The target must lie strictly inside the bracket for the search to make sense.
assert bound[0][1] < TARGET_MEAN_PREDICTION_LENGTH < bound[1][1], f"Target mean prediction length {TARGET_MEAN_PREDICTION_LENGTH} is not between the bounds {bound[0][1]} and {bound[1][1]}"

# Bisect on the threshold: evaluate the midpoint and replace whichever end lies
# on the same side of the target as the midpoint.
for _ in range(50):
    strict_end, lenient_end = bound
    mean_score = (strict_end[0] + lenient_end[0]) / 2
    image_predictions = make_prediction(mean_score)
    mean_length = mean_prediction_length(image_predictions)
    midpoint = (mean_score, mean_length)
    if mean_length < TARGET_MEAN_PREDICTION_LENGTH:
        bound = (midpoint, lenient_end)
    else:
        bound = (strict_end, midpoint)
    if mean_length == TARGET_MEAN_PREDICTION_LENGTH:
        # Exact hit: nothing left to refine.
        break
    print(bound)

# Use the threshold of whichever end finished closer to the target length.
strict_end, lenient_end = bound
strict_gap = abs(strict_end[1] - TARGET_MEAN_PREDICTION_LENGTH)
lenient_gap = abs(lenient_end[1] - TARGET_MEAN_PREDICTION_LENGTH)
if strict_gap < lenient_gap:
    image_predictions = make_prediction(strict_end[0])
else:
    image_predictions = make_prediction(lenient_end[0])

print(mean_prediction_length(image_predictions))

In [None]:
# Replace each predicted class index by the species id it stands for.
# NOTE: the values of `image_predictions` are overwritten in place, so running
# this cell a second time would look up species ids as if they were indices.
for quadrat_id, species_indices in image_predictions.items():
    image_predictions[quadrat_id] = [
        species_index_to_id[index] for index in species_indices
    ]

In [None]:
# One row per quadrat: its id and the list of predicted species ids.
df_run = pd.DataFrame(
    list(image_predictions.items()),
    columns=[
        "quadrat_id",
        "species_ids",
    ],
)

# The list is written as its string form, e.g. "[12, 345]". str(list) uses the
# repr() of every element, and NumPy >= 2 renders scalars as "np.int64(12)", so
# NumPy scalars are converted to builtin Python values first. Elements that are
# already builtin values are left untouched.
df_run["species_ids"] = df_run["species_ids"].apply(
    lambda species_ids: str(
        [
            species_id.item() if isinstance(species_id, np.generic) else species_id
            for species_id in species_ids
        ]
    )
)

# Every field is quoted so the commas inside the list string stay in one column.
df_run.to_csv(
    "./predictions/submission.csv",
    sep=",",
    index=False,
    quoting=csv.QUOTE_ALL,
)