In [None]:
from datetime import datetime
import functools as f

from sklearn.cluster import SpectralClustering
from numpy.typing import NDArray
import numpy as np

from common_generate_predictions import load_data, grid_search

In [None]:
# --- Input locations -------------------------------------------------------
# `dataset` is interpolated into the score directory path and is also
# recorded in the run metadata further down the notebook.
dataset = "dwug_es"
path_to_data = f"../input/wic-scores/{dataset}"
path_to_gold_data = "../test_data_es.csv"

# --- Run description -------------------------------------------------------
# These values are only stored in the `metadata` dict by this notebook; how
# they are interpreted is up to `grid_search` (defined in
# common_generate_predictions) -- TODO confirm their exact meaning there.
method = "sc"
llm = "wic"
wic_data = True
prompts = [
    "wic1",
    "wic2",
    "wic3",
    "wic4",
    "wic5",
]

# --- Search space ----------------------------------------------------------
# Inclusive upper bound for `n_clusters` when the hyperparameter grid is
# generated (the grid starts at 2 clusters).
max_number_clusters = 5

In [None]:
def get_clusters(
    adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict, seed=456
):
    clustering = SpectralClustering(
        n_clusters=hyperparameters["n_clusters"],
        affinity=hyperparameters["affinity"],
        assign_labels=hyperparameters["strategy"],
        random_state=seed,
    ).fit(adj_matrix)

    return clustering.labels_

In [None]:
def generate_hyperparameters_for_sc(max_number_clusters: int):
    """Enumerate the full hyperparameter grid for spectral clustering.

    The grid is the Cartesian product of cluster counts
    ``2..max_number_clusters`` (inclusive), four affinity options and three
    label-assignment strategies. The cluster count varies slowest and the
    strategy varies fastest.

    Args:
        max_number_clusters: Largest ``n_clusters`` value to include. Values
            below 2 yield an empty grid.

    Returns:
        A list of dicts, each with the keys ``"n_clusters"``, ``"affinity"``
        and ``"strategy"`` (in that order).
    """
    affinity_options = (
        "precomputed",
        "nearest_neighbors",
        "precomputed_nearest_neighbors",
        "rbf",
    )
    strategy_options = ("kmeans", "discretize", "cluster_qr")
    cluster_counts = range(2, max_number_clusters + 1)

    return [
        {"n_clusters": count, "affinity": affinity, "strategy": strategy}
        for count in cluster_counts
        for affinity in affinity_options
        for strategy in strategy_options
    ]

In [None]:
# Run metadata handed to `grid_search` via its `metadata=` keyword.
# Key insertion order is kept stable: flags first, then the run description.

# Literal flags; their effect is decided inside `grid_search`/`load_data`
# -- NOTE(review): presumably adjacency-matrix pre-processing switches, confirm.
metadata = {"fill_diagonal": True, "normalize": False}

# Values taken from the notebook's configuration cell.
metadata.update(
    {
        "method": method,
        "path_to_gold_data": path_to_gold_data,
        "path_to_data": path_to_data,
        "llm": llm,
        "prompts": prompts,
        "dataset": dataset,
        "wic_data": wic_data,
    }
)

In [None]:
# Wall-clock timestamp taken immediately before the search so the elapsed
# time printed at the end covers the whole `grid_search` call.
start_time = datetime.now()

# Positional arguments, in order:
#   1. `load_data` with `path_to_data` pre-bound as its first argument,
#   2. the clustering callable defined in this notebook,
#   3. the list of hyperparameter dicts to try.
# What `grid_search` does with them (scoring, persistence, return value) is
# defined in common_generate_predictions and is not visible here; its return
# value is discarded by this cell.
grid_search(
    f.partial(load_data, path_to_data),
    get_clusters,
    generate_hyperparameters_for_sc(max_number_clusters=max_number_clusters),
    metadata=metadata,
)

# Prints a `timedelta` (difference of two `datetime.now()` readings).
print(f"Elapsed time: {datetime.now() - start_time}")