# In [None]:
from kLLMmeans import kLLMmeans, get_embeddings, summarize_cluster
from experiment_utils import load_dataset, cluster_metrics, avg_closest_distance
from sklearn.cluster import KMeans
from sklearn_extra.cluster import KMedoids

import numpy as np
import json, pickle

import warnings
warnings.filterwarnings("ignore")

# In [None]:
# Offline clustering benchmark.
#
# For every dataset, embedding type and random seed this script runs
# k-means, k-medoids and kLLMmeans (over a small grid of
# force_context_length / max_llm_iter values), scores each run with
# cluster_metrics, and checkpoints the accumulated results to disk.

# Iteration cap passed to KMeans, KMedoids and kLLMmeans alike.
max_iter = 120

data_list = ['clinic', 'bank77', 'massive_I', 'massive_D', 'goemo']

for data in data_list:
    # Results for the current dataset, nested as
    #   results_dict[emb_type][seed]['kmeans' | 'kmedoids']
    #   results_dict[emb_type][seed][force_context_length][max_llm_iter]
    results_dict = {}

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # files under processed_data/ that were produced by a trusted pipeline.
    with open("processed_data/data_" + data + ".pkl", "rb") as f:
        data_dict = pickle.load(f)

    labels = data_dict['labels']
    num_clusters = data_dict['num_clusters']
    documents = data_dict['documents']
    text_features = data_dict['embeddings']
    #oracle_summaries = data_dict['summaries']
    prompt = data_dict['prompt']
    text_type = data_dict['text_type']

    # The ground-truth labels double as the "oracle" cluster assignment.
    oracle_cluster_assignments = labels

    for emb_type in ['distilbert', 'openai', 'e5-large', 'sbert']:

        # Must be a dict: it is indexed by seed below. (An empty string here
        # would raise TypeError on the first per-seed assignment.)
        results_dict[emb_type] = {}
        emb_data = text_features[emb_type]

        # Oracle centroids: mean embedding of each ground-truth cluster.
        # Clusters with no members get None instead of a centroid.
        # Assumes labels are integers in range(num_clusters) -- TODO confirm.
        oracle_clustered_embeddings = {i: [] for i in range(num_clusters)}
        for embedding, cluster in zip(emb_data, oracle_cluster_assignments):
            oracle_clustered_embeddings[cluster].append(embedding)
        oracle_centroids = [
            np.mean(oracle_clustered_embeddings[i], axis=0)
            if oracle_clustered_embeddings[i] else None
            for i in range(num_clusters)
        ]
        # Embedding the oracle summaries is disabled; the oracle centroids
        # stand in for the oracle summary embeddings.
        #oracle_summary_embeddings = get_embeddings(oracle_summaries, emb_type = emb_type)
        oracle_summary_embeddings = oracle_centroids

        for seed in range(10):
            results_dict[emb_type][seed] = {}

            # --- k-means baseline ---
            kmeans = KMeans(n_clusters=num_clusters, max_iter=max_iter, random_state=seed)
            kmeans_assignments = kmeans.fit_predict(emb_data)
            kmeans_centroids = kmeans.cluster_centers_
            # NOTE(review): the baselines call cluster_metrics with five
            # arguments while the kLLMmeans call below passes six; presumably
            # the last parameter is optional -- verify against experiment_utils.
            results = cluster_metrics(np.array(labels), kmeans_assignments,
                                      oracle_centroids, kmeans_centroids,
                                      oracle_summary_embeddings)
            data_results = {'assignments': kmeans_assignments,
                            'final_centroids': kmeans_centroids,
                            'results': results}

            results_dict[emb_type][seed]['kmeans'] = data_results

            print([data, emb_type, seed, 'kmeans', results])

            # --- k-medoids baseline ---
            kmedoids = KMedoids(n_clusters=num_clusters, max_iter=max_iter, random_state=seed)
            kmedoids.fit(emb_data)
            kmedoids_assignments = kmedoids.labels_
            kmedoids_indices = kmedoids.medoid_indices_
            # The medoids are actual data points, so the "centroids" are the
            # embeddings at the medoid indices.
            # Assumes emb_data supports integer-array indexing (e.g. a NumPy
            # array) -- TODO confirm.
            kmedoids_centroids = emb_data[kmedoids_indices]
            results = cluster_metrics(np.array(labels), kmedoids_assignments,
                                      oracle_centroids, kmedoids_centroids,
                                      oracle_summary_embeddings)
            data_results = {'assignments': kmedoids_assignments,
                            'final_centroids': kmedoids_centroids,
                            'results': results}

            results_dict[emb_type][seed]['kmedoids'] = data_results

            print([data, emb_type, seed, 'kmedoids', results])

            # --- kLLMmeans over the hyper-parameter grid ---
            for force_context_length in [0, 10]:
                results_dict[emb_type][seed][force_context_length] = {}

                for max_llm_iter in [1, 5]:

                    (assignments, final_summaries, final_summary_embeddings,
                     final_centroids, summaries_evolution,
                     centroids_evolution) = kLLMmeans(
                        documents,
                        prompt=prompt, text_type=text_type,
                        num_clusters=num_clusters,
                        force_context_length=force_context_length,
                        max_llm_iter=max_llm_iter,
                        max_iter=max_iter, tol=1e-4, random_state=seed,
                        emb_type=emb_type,
                        text_features=emb_data)

                    results = cluster_metrics(np.array(labels), assignments,
                                              oracle_centroids, final_centroids,
                                              oracle_summary_embeddings, final_summary_embeddings)

                    data_results = {'assignments': assignments,
                                    'final_summaries': final_summaries,
                                    'final_summary_embeddings': final_summary_embeddings,
                                    'final_centroids': final_centroids,
                                    'summaries_evolution': summaries_evolution,
                                    'centroids_evolution': centroids_evolution,
                                    'results': results}

                    results_dict[emb_type][seed][force_context_length][max_llm_iter] = data_results

                    print([data, emb_type, seed, force_context_length, max_llm_iter, results])

                    # Checkpoint after every kLLMmeans run so a crash loses at
                    # most one configuration.
                    # NOTE(review): results_dict holds every embedding type
                    # processed so far for this dataset, so files for later
                    # emb_type values also contain the earlier ones.
                    with open("results/sims_offline_results_" + emb_type + '_' + data + ".pkl", "wb") as f:
                        pickle.dump(results_dict, f)
            