In [1]:
from datetime import datetime
import functools as f
from itertools import product

from numpy.typing import NDArray
import numpy as np

from common_generate_predictions import load_data, grid_search_without_nclusters
import clustering

In [2]:
# Experiment configuration: every downstream path, the run metadata and the
# results directory are built from these values.
wic_data = False
method = "wsbm"
llm = "llama3.1-8B"
dataset = "dwug_en"
# Derive the input directory from `llm` (previously the model name was
# hard-coded here, so switching `llm` kept reading the old model's scores
# while results were saved under the new model's directory).
path_to_data = f"../input/{llm}/{dataset}"
# NOTE(review): the gold file is tied to the English test set; it must be
# updated by hand when `dataset` points at another language — TODO confirm.
path_to_gold_data = "../test_data_en.csv"
# Prompt variants; each one is loaded from a sub-directory of `path_to_data`.
score_paths = ["zs", "fs", "ct"]

In [3]:
def get_clusters(adj_matrix: NDArray[np.float64 | np.int32], hyperparameters: dict):
    """Cluster one adjacency matrix with the WSBM method.

    The matrix is turned into a graph via
    ``clustering._adjacency_matrix_to_nxgraph`` (with
    ``use_disconnected_edges=False``), clustered by
    ``clustering.wsbm_clustering`` and the result is converted back with
    ``clustering._convert_graph_cluster_list_set_to_list``.

    Args:
        adj_matrix: Pairwise score matrix (presumably square, one row/column
            per usage — TODO confirm against ``load_data``).
        hyperparameters: Keyword arguments forwarded unchanged to
            ``wsbm_clustering``, e.g. ``{"distribution": "discrete-poisson"}``.

    Returns:
        The value produced by ``_convert_graph_cluster_list_set_to_list``;
        presumably one cluster label per node — verify in ``clustering``.
    """
    nx_graph = clustering._adjacency_matrix_to_nxgraph(
        adj_matrix, use_disconnected_edges=False
    )
    return clustering._convert_graph_cluster_list_set_to_list(
        nx_graph, clustering.wsbm_clustering(nx_graph, **hyperparameters)
    )

In [4]:
# Candidate values for the WSBM `distribution` hyperparameter (presumably the
# edge-weight model); each becomes one model-hyperparameter dict.
distributions = [
    "discrete-geometric",
    "discrete-poisson",
    "discrete-binomial",
    "real-normal",
    "real-exponential",
]
model_hyperparameter_combinations = [
    {"distribution": distribution} for distribution in distributions
]

In [5]:
fill_diagonal = [True, False]
normalize = [False]
hc = product(fill_diagonal, normalize, score_paths, model_hyperparameter_combinations)


def _distribution_matches_normalize(normalize_flag, model_hparams):
    """Pair "real-*" distributions only with normalize=True and "discrete-*"
    only with normalize=False; any other distribution name is dropped."""
    family = model_hparams["distribution"]
    if family.startswith("real") and normalize_flag is True:
        return True
    return family.startswith("discrete") and normalize_flag is False


# Full grid minus incompatible pairings. With `normalize = [False]` every
# "real-*" distribution is filtered out, leaving only the discrete ones.
hyperparameter_combinations = [
    {
        "fill_diagonal": fd,
        "normalize": nm,
        "prompt": sp,
        "model_hyperparameters": mhc,
    }
    for fd, nm, sp, mhc in hc
    if _distribution_matches_normalize(nm, mhc)
]

In [7]:
# Show the grid size on the first line, followed by the full grid.
print(len(hyperparameter_combinations), hyperparameter_combinations, sep="\n")

18
[{'fill_diagonal': True, 'normalize': False, 'prompt': 'zs', 'model_hyperparameters': {'distribution': 'discrete-geometric'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'zs', 'model_hyperparameters': {'distribution': 'discrete-poisson'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'zs', 'model_hyperparameters': {'distribution': 'discrete-binomial'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'fs', 'model_hyperparameters': {'distribution': 'discrete-geometric'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'fs', 'model_hyperparameters': {'distribution': 'discrete-poisson'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'fs', 'model_hyperparameters': {'distribution': 'discrete-binomial'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'ct', 'model_hyperparameters': {'distribution': 'discrete-geometric'}}, {'fill_diagonal': True, 'normalize': False, 'prompt': 'ct', 'model_hyperparameters': {'distribution': 'discrete-poisson'

In [10]:
# Run-level settings handed to `grid_search_without_nclusters`; results are
# stored under a directory keyed by method / model / dataset.
# NOTE(review): `fill_diagonal` and `normalize` are fixed to True here, while the
# grid varies `fill_diagonal` and only uses normalize=False — confirm whether
# grid_search_without_nclusters reads these or the per-combination values.
metadata = dict(
    fill_diagonal=True,
    normalize=True,
    method=method,
    path_to_gold_data=path_to_gold_data,
    path_to_data=path_to_data,
    llm=llm,
    score_paths=score_paths,
    dataset=dataset,
    wic_data=wic_data,
    path_to_save_results=f"../cv-experiments-lscd-all-hyperparameter/{method}/{llm}/{dataset}",
)

In [11]:
# Run the whole grid and report wall-clock time. The return value is not used;
# results are presumably persisted via metadata["path_to_save_results"].
start_time = datetime.now()

# `load_data` with the data root bound; grid_search_without_nclusters supplies
# the remaining arguments (presumably the prompt sub-directory, per the log).
load_prompt_data = f.partial(load_data, path_to_data)
grid_search_without_nclusters(
    load_prompt_data,
    get_clusters,
    hyperparameter_combinations,
    metadata=metadata,
)

elapsed = datetime.now() - start_time
print(f"Elapsed time: {elapsed}")

2024-09-11 23:34:01,182 - INFO - loading data from ../input/llama3.1-8B/dwug_es/zs ...
2024-09-11 23:34:01,393 - INFO - data loaded ...
2024-09-11 23:34:01,394 - INFO - loading data from ../input/llama3.1-8B/dwug_es/fs ...
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_data["score"] = pd.to_numeric(filtered_data["score"], errors="raise")
2024-09-11 23:34:01,596 - INFO - data loaded ...
2024-09-11 23:34:01,597 - INFO - loading data from ../input/llama3.1-8B/dwug_es/ct ...
2024-09-11 23:34:01,787 - INFO - data loaded ...
2024-09-11 23:34:01,788 - INFO - training wsbm method ...
2024-09-11 23:34:01,788 - INFO -   1/30 - {'fill_diagonal': True, 'normalize': True, 'prompt': 'zs', 'model_hyperparameters': {'distribution': 'real-normal'}}
2024-09-11 23:34:01,792 - INFO 