In [1]:
from datetime import datetime
import functools as f
from itertools import product

from numpy.typing import NDArray
import numpy as np

from common_generate_predictions import load_data, grid_search_without_nclusters
import clustering

In [2]:
# Experiment configuration: every value a reader might tune lives in this cell.
wic_data = False  # forwarded to the grid search through `metadata["wic_data"]`
method = "wsbm"
llm = "llama3.1-8B"
dataset = "dwug_es"
# Derive the input folder from `llm` instead of repeating the model name, so
# the input path and the results path (which also embeds `llm`) cannot drift
# apart when the model is changed.
path_to_data = f"../input/{llm}/{dataset}"
# NOTE(review): the gold file name is fixed to the "es" data while `dataset`
# is configurable -- confirm it should not be derived from `dataset` as well.
path_to_gold_data = "../test_data_es.csv"
# Identifiers of the score sets to search over (presumably prompting
# strategies such as zero-shot / few-shot / chain-of-thought -- TODO confirm).
score_paths = ["zs", "fs", "ct"]

In [3]:
def get_clusters(adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict):
    graph = clustering._adjacency_matrix_to_nxgraph(
        adj_matrix, use_disconnected_edges=False
    )
    clusters = clustering.wsbm_clustering(graph, **hyperparameters)

    return clustering._convert_graph_cluster_list_set_to_list(graph, clusters)

In [4]:
# Edge-weight distributions to try; the "discrete-"/"real-" prefix is used
# later to pair each distribution with a matching `normalize` setting.
distributions = [
    "discrete-geometric",
    "discrete-poisson",
    "discrete-binomial",
    "real-normal",
    "real-exponential",
]

# One keyword-argument dict per distribution, in the order listed above.
model_hyperparameter_combinations = [
    {"distribution": name} for name in distributions
]

In [5]:
fill_diagonal = [True, False]
normalize = [True, False]


def _normalization_fits(distribution: str, normalized: bool) -> bool:
    """Return True when the `normalize` flag is allowed for `distribution`.

    "real*" distributions are kept only with normalize=True and "discrete*"
    distributions only with normalize=False; any other prefix is rejected.
    """
    if distribution.startswith("real"):
        return normalized is True
    if distribution.startswith("discrete"):
        return normalized is False
    return False


# Full grid over (fill_diagonal, normalize, score path, model hyperparameters),
# filtered down to the distribution/normalization pairs accepted above.
# NOTE(review): each entry stores the score path under the key "prompt", while
# the captured output of the run cell is KeyError: 'score_paths' -- confirm
# which key `grid_search_without_nclusters` expects.
hyperparameter_combinations = [
    {
        "fill_diagonal": diag_flag,
        "normalize": norm_flag,
        "prompt": score_path,
        "model_hyperparameters": model_kwargs,
    }
    for diag_flag, norm_flag, score_path, model_kwargs in product(
        fill_diagonal, normalize, score_paths, model_hyperparameter_combinations
    )
    if _normalization_fits(model_kwargs["distribution"], norm_flag)
]

In [6]:
# Folder for this run's results, namespaced by method, model and dataset.
path_to_save_results = (
    f"../cv-experiments-lscd-all-hyperparameter/{method}/{llm}/{dataset}"
)

# Run-level settings passed to the grid search alongside the per-combination
# hyperparameters.
# NOTE(review): "fill_diagonal" and "normalize" are fixed to True here although
# both are also varied per combination -- confirm which value takes precedence.
metadata = dict(
    fill_diagonal=True,
    normalize=True,
    method=method,
    path_to_gold_data=path_to_gold_data,
    path_to_data=path_to_data,
    llm=llm,
    score_paths=score_paths,
    dataset=dataset,
    wic_data=wic_data,
    path_to_save_results=path_to_save_results,
)

In [7]:
# Run the full grid search and report wall-clock duration.
# NOTE(review): the captured output of this cell is KeyError: 'score_paths';
# `metadata` does define that key, so the lookup presumably happens on another
# mapping (or the output is stale) -- re-run on a fresh kernel to confirm.
start_time = datetime.now()

# Bind the input folder so the search receives a loader with the path pre-filled.
data_loader = f.partial(load_data, path_to_data)
grid_search_without_nclusters(
    data_loader,
    get_clusters,
    hyperparameter_combinations,
    metadata=metadata,
)

elapsed = datetime.now() - start_time
print(f"Elapsed time: {elapsed}")

KeyError: 'score_paths'