In [5]:
import os
import json
import torch
from sklearn.cluster import KMeans
from tqdm import tqdm
import pandas as pd
import numpy as np

# Paths. Every input/output lives under one working directory. It can be
# overridden with the REDBOOK_SCRIPT_DIR environment variable; when that is
# unset the original hard-coded location is used, so the resulting paths are
# exactly the same as before.
SCRIPT_NEXT_DIR = os.environ.get('REDBOOK_SCRIPT_DIR', '/data1/dxw_data/llm/redbook_final/script_next')
embedding_folder = os.path.join(SCRIPT_NEXT_DIR, 'concatenated_embeddings_tag')
text_embedding_folder = os.path.join(SCRIPT_NEXT_DIR, 'text_embeddings_tag')
output_folder = os.path.join(SCRIPT_NEXT_DIR, 'combined_seg_img_pure_094_cluster_imagebind3')
csv_file = os.path.join(SCRIPT_NEXT_DIR, 'matching_records.csv')
os.makedirs(output_folder, exist_ok=True)

# Load the CSV to get post_tag information (read below via the
# 'poster_id', 'post_id' and 'post_tag' columns)
df = pd.read_csv(csv_file)

# Function to load embeddings
def load_embeddings(embedding_files):
    """Load tensors saved with ``torch.save`` and concatenate them row-wise on the CPU.

    Args:
        embedding_files: Iterable of paths to ``.pt`` files, each holding one
            tensor. A 2-D tensor contributes all of its rows; a 1-D tensor is
            treated as a single row.

    Returns:
        A CPU tensor made by concatenating the loaded tensors along dim 0, in
        the order of ``embedding_files``.

    Raises:
        ValueError: if ``embedding_files`` is empty (there is nothing to
            concatenate).
    """
    embeddings = []
    for embedding_file in tqdm(embedding_files, desc="Loading embeddings"):
        # map_location='cpu' keeps tensors that were saved from a GPU off the
        # GPU and lets the files load on machines without CUDA; the result is
        # returned on the CPU anyway.
        embedding = torch.load(embedding_file, map_location='cpu')
        if embedding.dim() == 1:
            # Without this, bare vectors would be joined end-to-end into one
            # long vector instead of each becoming its own row.
            embedding = embedding.unsqueeze(0)
        embeddings.append(embedding)
    if not embeddings:
        raise ValueError("load_embeddings: no embedding files were given")
    return torch.cat(embeddings, dim=0).cpu()  # Move to CPU

# Load concatenated embeddings.
# os.listdir returns entries in an arbitrary, filesystem-dependent order; sort
# so the row order fed to KMeans (and therefore the labels) is reproducible.
concatenated_embedding_files = sorted(
    os.path.join(embedding_folder, fname)
    for fname in os.listdir(embedding_folder)
    if fname.endswith('.pt')
)
concatenated_embeddings = load_embeddings(concatenated_embedding_files)

# Labels are mapped back to files by position below, which is only valid when
# every file contributed exactly one row.
assert concatenated_embeddings.shape[0] == len(concatenated_embedding_files), (
    f"expected one embedding row per file, got {concatenated_embeddings.shape[0]} rows "
    f"for {len(concatenated_embedding_files)} files"
)

# Set number of clusters
n_clusters = 100

# Perform clustering on concatenated embeddings (k-means++ init, fixed seed).
# n_init is passed explicitly: leaving it unset triggers scikit-learn's
# FutureWarning about the default changing from 10 to 'auto'; 10 is the value
# that was in effect, so results are unchanged.
kmeans = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, random_state=0)
concatenated_labels = kmeans.fit_predict(concatenated_embeddings.numpy())  # embeddings are on CPU

# Save clustering labels to JSON file: {embedding file name: cluster id}
labels_json = {
    os.path.basename(path): int(label_id)
    for path, label_id in zip(concatenated_embedding_files, concatenated_labels)
}
with open(os.path.join(output_folder, 'concatenated_labels.json'), 'w', encoding='utf-8') as f:
    json.dump(labels_json, f)

# Build a (poster_id, post_id) -> post_tag lookup once instead of scanning the
# whole DataFrame for every file. IDs parsed from file names are strings, so the
# CSV columns are compared as strings too: if pandas inferred them as numbers, a
# direct `==` against a str would never match and every cluster would be empty.
# NOTE(review): IDs with leading zeros would additionally need dtype=str when
# reading the CSV.
post_tag_lookup = {}
for poster_key, post_key, tag in zip(
    df['poster_id'].astype(str), df['post_id'].astype(str), df['post_tag']
):
    # setdefault keeps the first matching row, like the original `.values[0]`.
    post_tag_lookup.setdefault((poster_key, post_key), tag)

# Collect the post_tags belonging to each cluster
cluster_post_tags = {i: [] for i in range(n_clusters)}
for embedding_file, label in zip(concatenated_embedding_files, concatenated_labels):
    # File names look like '<poster_id>_<post_id>_concatenated_embedding.pt'
    poster_id, post_id = os.path.basename(embedding_file).replace('_concatenated_embedding.pt', '').split('_')
    tag = post_tag_lookup.get((poster_id, post_id))
    # Skip posts missing from the CSV and empty tags (a NaN would be written as
    # the non-standard JSON token `NaN`).
    if tag is not None and not pd.isna(tag):
        cluster_post_tags[int(label)].append(tag)

# Keep up to the first 10 post_tags of each cluster as examples (a head, not a
# random sample). Every cluster is written, including empty ones.
sampled_cluster_post_tags = {cluster: tags[:10] for cluster, tags in cluster_post_tags.items()}

# Save the sampled post_tag information for each cluster to a JSON file.
# encoding='utf-8' is required with ensure_ascii=False: non-ASCII tags are
# written verbatim and would fail on a non-UTF-8 default locale.
output_json_path = os.path.join(output_folder, 'most_populated_cluster_post_tags.json')
with open(output_json_path, 'w', encoding='utf-8') as f:
    json.dump(sampled_cluster_post_tags, f, ensure_ascii=False, indent=4)

print(f"Post tags for each cluster saved to {output_json_path}")


Loading embeddings: 100%|██████████| 5912/5912 [00:02<00:00, 2824.67it/s]
  super()._check_params_vs_input(X, default_n_init=10)


Post tags for each cluster saved to /data1/dxw_data/llm/redbook_final/script_next/combined_seg_img_pure_094_cluster_imagebind3/most_populated_cluster_post_tags.json
