In [None]:
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
import functools as f
import itertools
import time

from sklearn.cluster import AgglomerativeClustering
import numpy as np
from numpy.typing import NDArray

from common_generate_predictions import grid_search, load_data

In [None]:
# --- Experiment identity -----------------------------------------------------
dataset = "dwug_es"
method = "ac"  # agglomerative clustering (see get_clusters)
llm = "wic"
wic_data = True
prompts = [f"wic{idx}" for idx in range(1, 6)]  # wic1 ... wic5

# --- Search space ------------------------------------------------------------
# Largest n_clusters value explored by the hyperparameter grid (inclusive).
max_number_clusters = 5

# --- Input locations ---------------------------------------------------------
path_to_data = f"../input/wic-scores/{dataset}"
# NOTE(review): the gold file is hard-coded to Spanish; update it together
# with `dataset` if a different language split is used.
path_to_gold_data = "../test_data_es.csv"

In [None]:
def get_clusters(adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict):
    """Cluster a pairwise matrix with scikit-learn agglomerative clustering.

    Args:
        adj_matrix: matrix handed straight to ``AgglomerativeClustering.fit``.
            NOTE(review): with ``metric="precomputed"`` scikit-learn interprets
            the input as a *distance* matrix. If this holds similarity scores
            (the "wic-scores" data path suggests it might), it may need to be
            converted to distances first -- confirm against ``load_data``.
        hyperparameters: mapping that must contain "n_clusters", "metric" and
            "linkage"; these are forwarded to the estimator constructor
            (missing keys raise ``KeyError``).

    Returns:
        The ``labels_`` array of the fitted estimator (one label per row).
    """
    estimator_kwargs = {
        name: hyperparameters[name] for name in ("n_clusters", "metric", "linkage")
    }
    estimator = AgglomerativeClustering(**estimator_kwargs)
    estimator.fit(adj_matrix)
    return estimator.labels_
    

In [None]:
def generate_hyparameters_combinations_for_ac(
    max_number_clusters: int,
    *,
    linkages: Sequence[str] = ("complete", "average", "single"),
    metrics: Sequence[str] = ("precomputed",),
    min_number_clusters: int = 2,
) -> list[dict]:
    """Build the hyperparameter grid for agglomerative clustering.

    Every combination of cluster count, metric and linkage is produced. The
    cluster count varies slowest and the linkage fastest, which matches the
    original nested-loop order.

    Args:
        max_number_clusters: largest ``n_clusters`` value to try (inclusive).
        linkages: linkage criteria to try. "ward" is deliberately absent from
            the default: scikit-learn only accepts it with the euclidean
            metric, which conflicts with the default "precomputed" metric.
        metrics: ``metric`` values to try.
        min_number_clusters: smallest ``n_clusters`` value to try (inclusive).

    Returns:
        A list of dicts with keys "n_clusters", "metric" and "linkage", in the
        shape ``get_clusters`` reads. It is empty when
        ``max_number_clusters < min_number_clusters``.
    """
    cluster_counts = range(min_number_clusters, max_number_clusters + 1)
    return [
        {"n_clusters": n_clusters, "metric": metric, "linkage": linkage}
        for n_clusters, metric, linkage in itertools.product(
            cluster_counts, metrics, linkages
        )
    ]
        

In [None]:
# Run settings handed to grid_search alongside the hyperparameter grid.
# NOTE(review): fill_diagonal / normalize are presumably preprocessing flags
# applied to the loaded matrices by grid_search/load_data -- confirm there.
metadata = dict(
    fill_diagonal=True,
    normalize=True,
    method=method,
    path_to_gold_data=path_to_gold_data,
    path_to_data=path_to_data,
    llm=llm,
    prompts=prompts,
    dataset=dataset,
    wic_data=wic_data,
)

In [None]:
# Time the run with a monotonic clock. time.perf_counter() is unaffected by
# system clock changes (NTP sync, DST), whereas differences of datetime.now()
# values can be skewed by them.
start_time = time.perf_counter()

hyperparameter_grid = generate_hyparameters_combinations_for_ac(
    max_number_clusters=max_number_clusters
)

grid_search(
    f.partial(load_data, path_to_data),
    get_clusters,
    hyperparameter_grid,
    metadata=metadata,
)

# timedelta keeps the same "H:MM:SS.ffffff" rendering as the previous
# datetime-difference output.
print(f"Elapsed time: {timedelta(seconds=time.perf_counter() - start_time)}")