In [1]:
from datetime import datetime, timedelta
import functools as f
from itertools import product
import time

from sklearn.cluster import SpectralClustering
from numpy.typing import NDArray
import numpy as np

from common_generate_predictions import load_data, grid_search

In [2]:
# Experiment identity: clustering method, score-producing model, dataset.
method = "sc"
llm = "xl-lexeme"
dataset = "dwug_en"

# Grid-search limits and input options.
max_number_clusters = 5
wic_data = False
score_paths = ["wic1"]

# Input locations; the data directory is derived from `llm` and `dataset`.
path_to_data = f"../input/{llm}/{dataset}"
path_to_gold_data = "../test_data_en.csv"

In [3]:
def get_clusters(
    adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict, seed=456
):
    """Cluster the rows of ``adj_matrix`` with scikit-learn's SpectralClustering.

    Args:
        adj_matrix: matrix passed directly to ``SpectralClustering.fit``. With
            ``affinity="precomputed"`` (the only affinity this notebook
            generates) scikit-learn treats it as a pairwise similarity matrix.
        hyperparameters: mapping that must provide ``"n_clusters"``,
            ``"affinity"`` and ``"strategy"`` (forwarded as ``assign_labels``).
            NOTE(review): ``generate_hyperparameters_for_sc`` does not emit
            ``"n_clusters"``; presumably ``grid_search`` injects it — confirm.
        seed: ``random_state`` for the estimator, so repeated runs give the
            same labels.

    Returns:
        The fitted estimator's ``labels_`` array (one label per sample).
    """
    # Keys are read in this order so a missing key raises the same KeyError
    # as before.
    n_clusters = hyperparameters["n_clusters"]
    affinity = hyperparameters["affinity"]
    label_strategy = hyperparameters["strategy"]

    model = SpectralClustering(
        n_clusters=n_clusters,
        affinity=affinity,
        assign_labels=label_strategy,
        random_state=seed,
    )
    model.fit(adj_matrix)
    return model.labels_

In [4]:
def generate_hyperparameters_for_sc(max_number_clusters: int):
    """List the SpectralClustering settings to cross into the grid search.

    Returns one dict per (affinity, strategy) pair, with keys ``"affinity"``
    and ``"strategy"``. Currently this is a single combination: precomputed
    affinity with k-means label assignment.

    NOTE(review): ``max_number_clusters`` is accepted but not used here. The
    cluster count is not part of these dicts.
    """
    affinities = ["precomputed"]
    strategies = ["kmeans"]
    return [
        {"affinity": affinity, "strategy": strategy}
        for affinity in affinities
        for strategy in strategies
    ]

In [None]:
# Preview the model-specific settings that are crossed into the grid below.
generate_hyperparameters_for_sc(max_number_clusters=max_number_clusters)

In [6]:
# Axes of the grid: preprocessing options x score paths ("prompt") x model settings.
quantile = range(0, 7)
fill_diagonal = [True, False]
normalize = [True, False]
hc = product(
    quantile,
    fill_diagonal,
    normalize,
    score_paths,
    generate_hyperparameters_for_sc(max_number_clusters),
)

# `product` materializes each input once, so every combination would otherwise
# share the *same* model-hyperparameter dict object. Copy it per combination so
# a consumer that mutates one combination's dict (e.g. to add "n_clusters")
# cannot silently change every other combination.
hyperparameter_combinations = [
    {
        "quantile": q,
        "fill_diagonal": fd,
        "normalize": nm,
        "prompt": sp,
        "model_hyperparameters": dict(mhc),
    }
    for q, fd, nm, sp, mhc in hc
]

assert len(
    {id(combo["model_hyperparameters"]) for combo in hyperparameter_combinations}
) == len(hyperparameter_combinations), "model_hyperparameters dicts must not be shared"

In [None]:
# Show every hyperparameter combination, followed by how many there are.
print(hyperparameter_combinations, len(hyperparameter_combinations), sep="\n")

In [8]:
# Run-level settings passed to grid_search alongside the per-combination grid.
# NOTE(review): "fill_diagonal" and "normalize" are fixed to True here while the
# hyperparameter grid also sweeps both values; which one grid_search honours is
# not visible from this notebook — confirm.
metadata = dict(
    fill_diagonal=True,
    normalize=True,
    method=method,
    path_to_gold_data=path_to_gold_data,
    path_to_data=path_to_data,
    path_to_sense_data="../dwug_en_sense.csv",
    llm=llm,
    score_paths=score_paths,
    dataset=dataset,
    wic_data=wic_data,
    max_n_clusters=max_number_clusters,
    path_to_save_results=f"../cv-experiments-lscd-ari-dwug-cleaned/{method}/{llm}/{dataset}",
)

In [None]:
# Run the full grid search and report how long it took.
# Elapsed time uses the monotonic perf_counter: differences of naive
# `datetime.now()` values are wrong if the wall clock jumps (DST change,
# NTP sync) during a long run.
start_counter = time.perf_counter()

grid_search(
    f.partial(load_data, path_to_data),
    get_clusters,
    hyperparameter_combinations,
    metadata=metadata,
)

# Format as a timedelta so the printed output keeps the familiar H:MM:SS.ffffff shape.
elapsed = timedelta(seconds=time.perf_counter() - start_counter)
print(f"Elapsed time: {elapsed}")