In [1]:
import functools as f
import time
from datetime import datetime, timedelta
from itertools import product

import numpy as np
from numpy.typing import NDArray
from sklearn.cluster import AgglomerativeClustering

from common_generate_predictions import grid_search, load_data

In [2]:
# --- Experiment configuration -------------------------------------------
# Clustering method and score source.
method = "ac"
llm = "wic"
wic_data = True
max_number_clusters = 5

# Dataset and input locations.
dataset = "dwug_es"
path_to_data = f"../input/wic-scores/{dataset}_cleaned"
path_to_gold_data = "../test_data_es.csv"

# Score files to evaluate; each one is also used as the "prompt" grid axis.
score_paths = [
    "wic1", "wic2", "wic3", "wic4", "wic5", "wic6", "wic7",
]

In [3]:
def get_clusters(adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict):
    clustering = AgglomerativeClustering(
        n_clusters=hyperparameters["n_clusters"],
        metric=hyperparameters["metric"],
        linkage=hyperparameters["linkage"],
    ).fit(adj_matrix)

    return clustering.labels_
    

In [4]:
def generate_hyparameters_combinations_for_ac(
    max_number_clusters: int,
    metrics: tuple[str, ...] = ("cosine",),
    linkages: tuple[str, ...] = ("complete",),
) -> list[dict[str, str]]:
    """Build the agglomerative-clustering model settings to search over.

    Args:
        max_number_clusters: Currently unused; kept for interface
            compatibility. The number of clusters is not part of the returned
            dicts (presumably supplied by ``grid_search`` — confirm).
        metrics: Distance metrics to try (default: only ``"cosine"``).
        linkages: Linkage criteria to try (default: only ``"complete"``).
            Note that scikit-learn's ``"ward"`` linkage only accepts the
            euclidean metric.

    Returns:
        A new list with one ``{"metric": ..., "linkage": ...}`` dict per
        (metric, linkage) pair, ordered metric-major.
    """
    return [
        {"metric": metric, "linkage": linkage}
        for metric in metrics
        for linkage in linkages
    ]

In [None]:
# Preview the clustering-model settings, using the configured cluster limit.
generate_hyparameters_combinations_for_ac(max_number_clusters)

In [6]:
# Build the full search grid: every combination of preprocessing options,
# prompt (score file) and clustering-model settings.
# NOTE(review): the meaning of the quantile values 0..6 is defined by
# grid_search/load_data, which are not visible here — confirm.
quantile = range(0, 7)
fill_diagonal = [True, False]
normalize = [True, False]
hc = product(
    quantile,
    fill_diagonal,
    normalize,
    score_paths,
    # Use the configured limit instead of repeating the literal 5.
    generate_hyparameters_combinations_for_ac(max_number_clusters),
)

hyperparameter_combinations = [
    {
        "quantile": q,
        "fill_diagonal": fd,
        "normalize": nm,
        "prompt": sp,
        # product() yields the *same* dict object for every combination that
        # shares it; copy it so each entry owns its model settings and a
        # downstream mutation cannot leak into other combinations.
        "model_hyperparameters": dict(mhc),
    }
    for q, fd, nm, sp, mhc in hc
]

In [None]:
# The grid is large (one entry per quantile x fill_diagonal x normalize x
# prompt x model setting); show its size and a small sample, not every entry.
print(f"{len(hyperparameter_combinations)} hyperparameter combinations")
hyperparameter_combinations[:3]

In [8]:
# Run-level settings passed to grid_search alongside the hyperparameter grid.
# NOTE(review): "fill_diagonal" and "normalize" are fixed to True here while the
# hyperparameter grid searches over both values — confirm which one
# grid_search actually uses.
metadata = {
    "fill_diagonal": True,
    "normalize": True,
    "method": method,
    "path_to_gold_data": path_to_gold_data,
    "path_to_data": path_to_data,
    # Derived from `dataset` (identical value for "dwug_es") so switching the
    # dataset cannot silently keep using another dataset's sense file.
    "path_to_sense_data": f"../{dataset}_sense.csv",
    "llm": llm,
    "score_paths": score_paths,
    "dataset": dataset,
    "wic_data": wic_data,
    "max_n_clusters": max_number_clusters,
    "path_to_save_results": f"../cv-experiments-lscd-ari-dwug-cleaned/{method}/{llm}/{dataset}",
}

In [None]:
# Run the full grid search. time.perf_counter() is a monotonic clock, so the
# reported duration is unaffected by wall-clock changes (NTP sync, DST)
# during a long run, unlike datetime.now().
start_time = time.perf_counter()

grid_search(
    f.partial(load_data, path_to_data),
    get_clusters,
    hyperparameter_combinations,
    metadata=metadata,
)

# timedelta keeps the same "H:MM:SS.ffffff" output format as before.
elapsed = timedelta(seconds=time.perf_counter() - start_time)
print(f"Elapsed time: {elapsed}")