In [None]:
from datetime import datetime, timedelta
import functools as f
import time

from numpy.typing import NDArray
import numpy as np

from common_generate_predictions import load_data, grid_search_without_nclusters
import clustering

In [None]:
# Experiment configuration: clustering method, LLM whose predictions are
# clustered, and the dataset being evaluated.
method = "wsbm"
llm = "mixtral"
dataset = "dwug_en"

# Input locations: per-dataset LLM predictions and the gold labels.
# NOTE(review): "8xtb" may be a typo for Mixtral-8x7B — confirm the directory name on disk.
path_to_data = "../input/mixtral-8xtb-v0.1/" + dataset
# NOTE(review): gold path is fixed to the English test set; keep it in sync with `dataset`.
path_to_gold_data = "../test_data_en.csv"

# Prompt identifiers whose predictions are evaluated.
prompts = ["zs", "fs", "ct"]

In [None]:
def get_clusters(adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict):
    graph = clustering._adjacency_matrix_to_nxgraph(
        adj_matrix, use_disconnected_edges=False
    )
    clusters = clustering.wsbm_clustering(graph, **hyperparameters)

    return clustering._convert_graph_cluster_list_set_to_list(graph, clusters)

In [None]:
# Candidate distributions for the WSBM model; the grid search tries each one.
distributions = [
    "discrete-geometric",
    "discrete-poisson",
    "discrete-binomial",
    "real-normal",
    "real-exponential",
]

# One hyperparameter dict per distribution. get_clusters forwards such a dict
# as keyword arguments to clustering.wsbm_clustering. A comprehension avoids
# leaking a loop variable into the notebook namespace.
model_hyperparameter_combinations = [
    {"distribution": distribution} for distribution in distributions
]

In [None]:
# Run description passed to grid_search_without_nclusters.
# NOTE(review): fill_diagonal / normalize are presumably preprocessing flags for
# the adjacency matrices — confirm their meaning in common_generate_predictions.
metadata = dict(
    fill_diagonal=True,
    normalize=True,
    method=method,
    path_to_gold_data=path_to_gold_data,
    path_to_data=path_to_data,
    llm=llm,
    prompts=prompts,
    dataset=dataset,
)

In [None]:
# Wall-clock start time, kept for reference. The elapsed duration is measured
# with the monotonic perf_counter: datetime.now() is naive local time, so a DST
# change or clock sync during a long grid search would distort it.
start_time = datetime.now()
start_counter = time.perf_counter()

# Evaluate every hyperparameter combination with get_clusters on the data
# loaded from path_to_data.
grid_search_without_nclusters(
    f.partial(load_data, path_to_data),
    get_clusters,
    model_hyperparameter_combinations,
    metadata=metadata,
)

# Wrapping in timedelta keeps the original "H:MM:SS.ffffff" output format.
print(f"Elapsed time: {timedelta(seconds=time.perf_counter() - start_counter)}")