In [None]:
# Notebook for advanced stats and visualization.
# Run this after Auto-Cat has generated its output, and set output_dir to that output directory.
# Adjust the figure settings in plot_top_levels_dendrogram to change the resulting image size.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from scipy.cluster.hierarchy import dendrogram
from typing import Dict, Any
from wordcloud import WordCloud

# Set the output directory
# Every input file below is read from this folder, and most of the generated
# figures are written back into it.
output_dir = 'output3'

# Load the data
def load_embeddings(file_path: str):
    """Read a saved embeddings archive into a plain dictionary.

    Args:
        file_path: Path of an ``.npz`` archive containing the arrays
            ``embeddings``, ``ids`` and ``metadatas``.

    Returns:
        A dict with the ``embeddings`` array as stored, plus ``ids`` and
        ``metadatas`` converted to Python lists.

    Raises:
        KeyError: If one of the three arrays is missing from the archive.
    """
    # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which
    # can run arbitrary code. Only open archives that come from a trusted run.
    archive = np.load(file_path, allow_pickle=True)
    try:
        vectors = archive['embeddings']
        item_ids = archive['ids'].tolist()
        item_metadata = archive['metadatas'].tolist()
    finally:
        # Release the underlying file handle even if a key lookup fails.
        archive.close()

    return {
        'embeddings': vectors,
        'ids': item_ids,
        'metadatas': item_metadata,
    }

embeddings_path = os.path.join(output_dir, "embeddings.npz")
embeddings = load_embeddings(embeddings_path)

# Load clustering results
clustering_methods = ["kmeans", "dbscan", "agglomerative"]
all_clusters = {}
for method in clustering_methods:
    clusters_file = os.path.join(output_dir, f"{method}_clusters.npy")
    if not os.path.exists(clusters_file):
        # Cluster files are optional: a method without one is simply left out.
        continue
    all_clusters[method] = np.load(clusters_file)

# Load linkage matrix for agglomerative clustering
# Unlike the per-method files above, this one is loaded unconditionally, so
# np.load raises if it is missing.
linkage_path = os.path.join(output_dir, "agglomerative_linkage.npy")
linkage_matrix = np.load(linkage_path)

# Load category analysis results
def _parse_detailed_report(report_path):
    """Parse one ``detailed_report.txt`` into ``{category_id: info}``.

    Recognised lines:
      * ``Category <id>:`` starts a new category.
      * ``Size: <int>`` and ``Common Words: a, b, c`` fill in its fields.
      * ``Representative Items:`` is followed by one item per line, ending at
        the first blank line or at a line starting with ``Category``.

    Each info dict has the keys ``size``, ``common_words`` and
    ``representative_items``. Field lines that appear before the first
    category header are ignored.
    """
    parsed = {}
    current_category = None
    reading_items = False

    with open(report_path, 'r') as report:
        for line in report:
            stripped = line.strip()

            if reading_items:
                if stripped and not line.startswith("Category"):
                    parsed[current_category]["representative_items"].append(stripped)
                    continue
                # The item list is over. Fall through so the terminating line
                # is still interpreted: a category header that directly
                # follows the items must not be swallowed.
                reading_items = False

            if line.startswith("Category "):
                current_category = int(line.split()[1].strip(':'))
                parsed[current_category] = {"size": 0, "common_words": [], "representative_items": []}
            elif current_category is None:
                # Preamble before any category header: nothing to attach to.
                continue
            elif stripped.startswith("Size:"):
                parsed[current_category]["size"] = int(line.split(":")[1].strip())
            elif stripped.startswith("Common Words:"):
                # Split on the first colon only so words containing ':' survive,
                # and drop blanks so an empty list does not become [''].
                words = line.split(":", 1)[1].strip().split(", ")
                parsed[current_category]["common_words"] = [word for word in words if word]
            elif stripped.startswith("Representative Items:"):
                parsed[current_category]["representative_items"] = []
                reading_items = True

    return parsed


all_categories = {}
for method in clustering_methods:
    category_file = os.path.join(output_dir, f"{method}_report", "detailed_report.txt")
    if os.path.exists(category_file):
        # Reports are optional; methods without one are absent from all_categories.
        categories = _parse_detailed_report(category_file)
        all_categories[method] = categories

# Improved dendrogram function
def plot_top_levels_dendrogram(linkage_matrix: np.ndarray, categories: Dict[int, Dict[str, Any]], n_top_levels: int = None):
    """Draw the hierarchical-clustering dendrogram with a text label on each fork.

    Args:
        linkage_matrix: Linkage matrix handed straight to
            ``scipy.cluster.hierarchy.dendrogram``.
        categories: Mapping of category id to an info dict; only its
            ``"common_words"`` list is used, for the labels.
        n_top_levels: When given, only that many levels of the tree are drawn
            (``truncate_mode='level'``). ``None`` (the default) draws the
            full tree.

    The figure is saved as ``full_dendrogram_with_labels_8k.png`` in the
    current working directory and then shown.
    """
    # 80x45 inches. At the dpi=150 used by savefig below that is roughly
    # 12000x6750 pixels before bbox_inches='tight' cropping (7680x4320 would
    # correspond to 96 DPI).
    plt.figure(figsize=(80, 45))

    def label_for(node_id):
        # Look the id up by membership rather than by comparing it with
        # len(categories): category ids are not guaranteed to be 0..len-1.
        info = categories.get(node_id)
        if info is not None and info["common_words"]:
            return ", ".join(info["common_words"][:3])
        return f"Cluster {node_id}"

    if n_top_levels is None:
        truncation = {"truncate_mode": None}  # Show full tree
    else:
        truncation = {"truncate_mode": "level", "p": n_top_levels}

    R = dendrogram(
        linkage_matrix,
        no_labels=True,
        leaf_rotation=0,
        leaf_font_size=16,
        show_contracted=True,
        **truncation,
    )

    ax = plt.gca()
    leaves = R['leaves']

    for xs, ys, color in zip(R['icoord'], R['dcoord'], R['color_list']):
        # Each link is drawn as a bracket; annotate the middle of its top bar.
        x = 0.5 * sum(xs[1:3])
        y = ys[1]
        if y > 0:  # Only annotate internal nodes
            # Approximate mapping of x-coordinate to node id: this picks the
            # leaf drawn under the fork's midpoint.
            # NOTE(review): that leaf id is then used as a key into
            # `categories`; confirm this is the intended label source.
            node_id = leaves[int(x / 10)]
            label = label_for(node_id)
            ax.plot(x, y, 'o', c=color)
            ax.annotate(label, (x, y), xytext=(0, 5),
                        textcoords='offset points',
                        va='bottom', ha='center',
                        bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.7),
                        fontsize=12, rotation=90)

    plt.title('Hierarchical Clustering Dendrogram with Category Labels', fontsize=24)
    plt.xlabel('Sample Index', fontsize=20)
    plt.ylabel('Distance', fontsize=20)

    # Manually adjusting margins to avoid tight_layout warning
    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)

    # Save at 150 DPI so the labels stay sharp on the very large canvas
    plt.savefig("full_dendrogram_with_labels_8k.png", dpi=150, bbox_inches='tight')
    plt.show()

# Generate the improved dendrogram with labels on forks
# The agglomerative report is optional at load time, so check for it here
# instead of failing with a KeyError.
if 'agglomerative' in all_categories:
    plot_top_levels_dendrogram(linkage_matrix, all_categories['agglomerative'])
else:
    print("No category report loaded for agglomerative. Skipping dendrogram generation.")


# Clustering comparison
def compare_clustering_methods(all_categories: Dict[str, Dict[int, Dict[str, Any]]]):
    """Summarise category counts and sizes per clustering method.

    Args:
        all_categories: Mapping of method name to its categories, where each
            category info dict carries a ``"size"`` entry.

    Returns:
        A DataFrame with one row per method and the columns
        ``num_categories``, ``avg_category_size``, ``max_category_size`` and
        ``min_category_size``. Methods without any category are skipped. If
        nothing is left, an empty DataFrame with those columns is returned
        and no chart is produced.

    Otherwise a bar chart of ``num_categories`` and ``avg_category_size`` is
    saved as ``clustering_comparison.png`` in ``output_dir`` and shown.
    """
    columns = ["num_categories", "avg_category_size", "max_category_size", "min_category_size"]

    comparison = {}
    for method, categories in all_categories.items():
        sizes = [cat["size"] for cat in categories.values()]
        if not sizes:
            # Avoids a division by zero and max()/min() on an empty sequence.
            print(f"No categories found for {method}. Skipping it in the comparison.")
            continue
        comparison[method] = {
            "num_categories": len(sizes),
            "avg_category_size": sum(sizes) / len(sizes),
            "max_category_size": max(sizes),
            "min_category_size": min(sizes),
        }

    if not comparison:
        print("No clustering results to compare. Skipping comparison plot.")
        return pd.DataFrame(columns=columns)

    # Transpose so that methods become rows and statistics become columns.
    df = pd.DataFrame(comparison).T

    plt.figure(figsize=(12, 6))
    df.plot(kind='bar', y=['num_categories', 'avg_category_size'], ax=plt.gca())
    plt.title("Comparison of Clustering Methods")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "clustering_comparison.png"))
    plt.show()

    return df

# Build the per-method summary table (the call also saves and shows the bar
# chart), then print the table as text.
comparison_df = compare_clustering_methods(all_categories)
print(comparison_df)

# Word cloud generation
def generate_word_cloud(categories: Dict[int, Dict[str, Any]], method: str):
    """Render a word cloud from the common words of every category.

    Args:
        categories: Mapping of category id to an info dict with a
            ``"common_words"`` list.
        method: Clustering method name, used in the title and file name.

    The image is saved as ``<method>_word_cloud.png`` in ``output_dir`` and
    shown. Nothing is drawn when no usable word exists.
    """
    # Drop blank entries up front: several categories whose only word is ''
    # would otherwise join into a whitespace-only string that slips past the
    # emptiness check and makes WordCloud raise.
    words = [
        word
        for cat in categories.values()
        for word in cat['common_words']
        if word.strip()
    ]
    all_words = ' '.join(words)

    if not all_words:
        print(f"No common words found for {method}. Skipping word cloud generation.")
        return

    wordcloud = WordCloud(width=800, height=400, background_color='white').generate(all_words)

    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.title(f"Word Cloud for {method.capitalize()} Clustering")
    plt.tight_layout(pad=0)
    plt.savefig(os.path.join(output_dir, f"{method}_word_cloud.png"))
    plt.show()

# Generate word clouds for each clustering method
for method in clustering_methods:
    # Reports are optional at load time, so a method may have no entry.
    if method not in all_categories:
        print(f"No category report loaded for {method}. Skipping word cloud generation.")
        continue
    generate_word_cloud(all_categories[method], method)

# Category size distribution
def plot_category_size_distribution(categories: Dict[int, Dict[str, Any]], method: str):
    """Plot a histogram (with KDE overlay) of the category sizes.

    Args:
        categories: Mapping of category id to an info dict with a ``"size"``
            entry.
        method: Clustering method name, used in the title and file name.

    The figure is saved as ``<method>_category_size_distribution.png`` in
    ``output_dir`` and then shown.
    """
    category_sizes = []
    for info in categories.values():
        category_sizes.append(info['size'])

    plt.figure(figsize=(12, 6))
    sns.histplot(category_sizes, kde=True)

    plt.title(f"Category Size Distribution for {method.capitalize()} Clustering")
    plt.xlabel("Category Size")
    plt.ylabel("Frequency")

    destination = os.path.join(output_dir, f"{method}_category_size_distribution.png")
    plt.savefig(destination)
    plt.show()

# Plot category size distribution for each clustering method
for method in clustering_methods:
    # Reports are optional at load time, so a method may have no entry.
    if method not in all_categories:
        print(f"No category report loaded for {method}. Skipping category size distribution.")
        continue
    plot_category_size_distribution(all_categories[method], method)

# Top categories analysis
def analyze_top_categories(categories: Dict[int, Dict[str, Any]], method: str, top_n: int = 10):
    """Print the largest categories of one clustering method.

    Args:
        categories: Mapping of category id to an info dict with ``"size"``,
            ``"common_words"`` and ``"representative_items"``.
        method: Clustering method name, used in the heading.
        top_n: Maximum number of categories to list, largest first.

    For each category the size, up to five common words and the first 100
    characters of the first representative item are printed. A category
    without representative items prints ``(none)`` instead.
    """
    by_size_desc = sorted(categories.items(), key=lambda entry: entry[1]['size'], reverse=True)
    top_categories = by_size_desc[:top_n]

    print(f"Top {top_n} Categories for {method.capitalize()} Clustering:")
    for rank, (cat_id, cat_info) in enumerate(top_categories, 1):
        print(f"{rank}. Category {cat_id}")
        print(f"   Size: {cat_info['size']}")
        print(f"   Common Words: {', '.join(cat_info['common_words'][:5])}")

        items = cat_info['representative_items']
        if items:
            print(f"   Sample Item: {items[0][:100]}...")
        else:
            # The item list may be empty; indexing [0] would raise IndexError.
            print("   Sample Item: (none)")
        print()

# Analyze top categories for each clustering method
for method in clustering_methods:
    # Reports are optional at load time, so a method may have no entry.
    if method not in all_categories:
        print(f"No category report loaded for {method}. Skipping top categories analysis.")
        continue
    analyze_top_categories(all_categories[method], method)

def plot_category_word_heatmap(categories: Dict[int, Dict[str, Any]], method: str):
    """Draw a binary category-by-word heatmap of the common words.

    Args:
        categories: Mapping of category id to an info dict with a
            ``"common_words"`` list.
        method: Clustering method name, used in the title and file name.

    Cell (i, j) is 1 when word j is listed among the common words of category
    i, otherwise 0. The figure is saved as
    ``<method>_category_word_heatmap.png`` in ``output_dir`` and shown.
    Nothing is drawn when no category has any common word.
    """
    # Sort the vocabulary: iterating a set of strings gives a different order
    # from run to run, which would shuffle the columns between executions.
    # A list is also what the tick-label argument expects.
    unique_words = sorted({
        word
        for cat in categories.values()
        for word in cat['common_words']
    })

    if not unique_words:
        print(f"No common words found for {method}. Skipping heatmap generation.")
        return

    word_index = {word: idx for idx, word in enumerate(unique_words)}
    category_names = list(categories.keys())

    heatmap_data = np.zeros((len(category_names), len(unique_words)))
    for row, cat_info in enumerate(categories.values()):
        for word in cat_info['common_words']:
            heatmap_data[row, word_index[word]] = 1

    plt.figure(figsize=(15, 10))
    sns.heatmap(heatmap_data, cmap="YlGnBu", xticklabels=unique_words, yticklabels=category_names)
    plt.title(f"Heatmap of Categories and Common Words for {method.capitalize()} Clustering")
    plt.xlabel("Common Words")
    plt.ylabel("Categories")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{method}_category_word_heatmap.png"))
    plt.show()

# Generate heatmaps for each clustering method
for method in clustering_methods:
    # Reports are optional at load time, so a method may have no entry.
    if method not in all_categories:
        print(f"No category report loaded for {method}. Skipping heatmap generation.")
        continue
    plot_category_word_heatmap(all_categories[method], method)
    
def plot_sankey_diagram_for_clusters(categories: Dict[int, Dict[str, Any]], method: str):
    """Draw one Sankey sub-diagram per category, sized by the category size.

    Args:
        categories: Mapping of category id to an info dict with a ``"size"``
            entry.
        method: Clustering method name, used in the title and file name.

    The figure is saved as ``<method>_sankey_diagram.png`` in ``output_dir``
    and shown. Nothing is drawn when ``categories`` is empty.
    """
    if not categories:
        print(f"No categories found for {method}. Skipping Sankey diagram generation.")
        return

    # Kept as a function-local import, as in the original, so the dependency
    # is only needed when a diagram is actually drawn.
    from matplotlib.sankey import Sankey

    # Create the axes first and hand them to Sankey. Without an explicit ax,
    # Sankey opens its own figure, and a later plt.subplots() call would make
    # a second, empty figure the one that receives the title and gets saved.
    fig, ax = plt.subplots(figsize=(10, 7), subplot_kw={"xticks": [], "yticks": []})
    sankey = Sankey(ax=ax, unit=None)

    for cat_id, cat_info in categories.items():
        size = cat_info['size']
        # TODO Replace logic to determine flows between categories
        sankey.add(flows=[size, -size], labels=[f"Cat {cat_id} start", f"Cat {cat_id} end"], orientations=[1, -1])

    sankey.finish()
    plt.title(f"Sankey Diagram for {method.capitalize()} Clustering")
    plt.savefig(os.path.join(output_dir, f"{method}_sankey_diagram.png"))
    plt.show()

# The agglomerative report is optional at load time, so check for it here
# instead of failing with a KeyError.
if 'agglomerative' in all_categories:
    plot_sankey_diagram_for_clusters(all_categories['agglomerative'], 'agglomerative')
else:
    print("No category report loaded for agglomerative. Skipping Sankey diagram generation.")
