# Balance Score Analysis for Clusters

This notebook computes a balance score, together with its component metrics (imbalance, magnitude factor, and distance), for every pair of users within each cluster in `cluster_categorizations`.

In [None]:
import polars as pl
import numpy as np
from pathlib import Path
import sys

# Add the data_pipeline package to path
sys.path.append(str(Path().absolute().parent))

# Import the balance score calculation function
from data_pipeline.assets.ai_conversations.utils.find_top_k_users import get_approx_bipartite_match

In [None]:
# Define data directory
data_dir = Path().absolute() / ".." / "data"

# Collect the available runs, oldest first. Path.glob yields entries in
# arbitrary (filesystem) order, so sort by modification time to guarantee that
# the last element really is the most recently written run.
dagster_runs = sorted(
    (data_dir / "dagster/cluster_categorizations").glob("*.snappy"),
    key=lambda run_path: run_path.stat().st_mtime,
)

# List available files
print("Available dagster run IDs:")
for run in dagster_runs:
    print(f"- {run.stem}")

# Choose the latest run or a specific one (None when no run files were found)
latest_run_id = dagster_runs[-1].stem if dagster_runs else None
run_id = latest_run_id  # Change this if you want to use a specific run ID

In [None]:
# Display the run ID selected above so the reader can see which file will be loaded
run_id

In [None]:
# Read the categorized clusters written for the selected run
categorized_clusters_path = (
    data_dir / "dagster" / "cluster_categorizations" / f"{run_id}.snappy"
)
categorized_clusters_df = pl.read_parquet(categorized_clusters_path)

# Report the row count and the available columns
print(f"Loaded {categorized_clusters_df.height} categorized clusters")
print("\nColumns:", categorized_clusters_df.columns, sep="\n")

# Preview the first few rows
categorized_clusters_df.head(5)

In [None]:
import math
def calculate_balance_scores(embeddings_current, embeddings_other):
    """Score how well two sets of conversation embeddings balance each other.

    Lower scores are better. The score is the plain sum of three terms:

    * ``imbalance``: ``|log(len_current / len_other)|``, zero when both sides
      have the same number of conversations.
    * ``magnitude_factor``: ``1 / (len_current + len_other)``, which shrinks as
      the combined number of conversations grows.
    * ``dist``: ``1 - get_approx_bipartite_match(...)`` over the two embedding
      sets (presumably a similarity in [0, 1] -- confirm against the helper).

    Args:
        embeddings_current: Sized collection of embeddings for one side.
        embeddings_other: Sized collection of embeddings for the other side.

    Returns:
        A ``(score, components)`` tuple, where ``components`` maps
        ``"imbalance"``, ``"magnitude_factor"`` and ``"dist"`` to their values.
        When either side is empty, returns ``(float("inf"), {})``.
    """
    n_current, n_other = len(embeddings_current), len(embeddings_other)

    # An empty side cannot be matched: push the pair to the back of any ranking.
    if n_current == 0 or n_other == 0:
        return float("inf"), {}

    components = {
        # Symmetric in the two sides; smaller means more evenly split.
        "imbalance": abs(math.log(n_current / n_other)),
        # Inverted so that a larger combined count yields a smaller term.
        "magnitude_factor": 1 / (n_current + n_other),
        # Turn the match score from the helper into a distance-like term.
        "dist": 1 - get_approx_bipartite_match(
            np.array(embeddings_current), np.array(embeddings_other)
        ),
    }

    score = components["imbalance"] + components["magnitude_factor"] + components["dist"]
    return score, components

In [None]:
from sklearn.preprocessing import StandardScaler

def compute_cluster_balance_scores(df):
    """Compute standardized balance scores for every user pair in each cluster.

    For each cluster, every unordered pair of distinct users is scored with
    ``calculate_balance_scores`` to obtain the raw ``imbalance``,
    ``magnitude_factor`` and ``dist`` metrics. The three metrics are then
    standardized across all pairs with ``StandardScaler`` and summed into
    ``balance_score``.

    Args:
        df: polars DataFrame with ``cluster_id``, ``user_id`` and ``embedding``
            columns, and optionally a ``category`` column.

    Returns:
        A polars DataFrame with one row per user pair and the columns
        ``cluster_id``, ``user1``, ``user2``, ``user1_convos``,
        ``user2_convos``, ``imbalance``, ``magnitude_factor``, ``dist``,
        ``category``, ``balance_score``, ``std_imbalance``, ``std_magnitude``
        and ``std_dist``. ``category`` is the first category value found in the
        cluster, or ``None`` when ``df`` has no ``category`` column. The frame
        is empty (but keeps these column names) when no cluster contains at
        least two users.
    """
    result_columns = [
        "cluster_id", "user1", "user2", "user1_convos", "user2_convos",
        "imbalance", "magnitude_factor", "dist", "category",
        "balance_score", "std_imbalance", "std_magnitude", "std_dist",
    ]
    has_category = "category" in df.columns

    # First pass: collect the raw metrics for every user pair.
    pair_data = []

    # maintain_order=True keeps cluster/user iteration order (and therefore the
    # user1/user2 labelling) reproducible between runs.
    for cluster_id in df.get_column("cluster_id").unique(maintain_order=True):
        cluster_convos = df.filter(pl.col("cluster_id") == cluster_id)

        # Filter and materialize each user's embeddings once per cluster
        # instead of once per pair.
        per_user = []
        for user_id in cluster_convos.get_column("user_id").unique(maintain_order=True):
            user_convos = cluster_convos.filter(pl.col("user_id") == user_id)
            if len(user_convos) == 0:
                # A null user_id never satisfies the equality filter; skip it.
                continue
            per_user.append((
                user_id,
                len(user_convos),
                np.array(user_convos.get_column("embedding").to_list()),
            ))

        if len(per_user) < 2:
            continue  # No pair can be formed in this cluster.

        category = cluster_convos["category"][0] if has_category else None

        for i, (user1, user1_count, user1_embeddings) in enumerate(per_user):
            for user2, user2_count, user2_embeddings in per_user[i + 1:]:
                # Reuse the shared helper so the raw metrics are defined in
                # exactly one place; its combined raw score is not needed here.
                _, components = calculate_balance_scores(user1_embeddings, user2_embeddings)
                pair_data.append({
                    "cluster_id": cluster_id,
                    "user1": user1,
                    "user2": user2,
                    "user1_convos": user1_count,
                    "user2_convos": user2_count,
                    "imbalance": components["imbalance"],
                    "magnitude_factor": components["magnitude_factor"],
                    "dist": components["dist"],
                    "category": category,
                })

    # StandardScaler cannot be fitted on zero samples, so return an empty frame
    # that still exposes the expected column names.
    if not pair_data:
        return pl.DataFrame({column: [] for column in result_columns})

    # Standardize the three metrics across all pairs so they contribute on a
    # comparable scale.
    metrics = np.array(
        [[pair["imbalance"], pair["magnitude_factor"], pair["dist"]] for pair in pair_data],
        dtype=float,
    )
    standardized_metrics = StandardScaler().fit_transform(metrics)

    # Second pass: attach the standardized metrics and their sum to each pair.
    results = []
    for pair, (std_imbalance, std_magnitude, std_dist) in zip(pair_data, standardized_metrics.tolist()):
        results.append({
            **pair,
            "balance_score": std_imbalance + std_magnitude + std_dist,
            "std_imbalance": std_imbalance,
            "std_magnitude": std_magnitude,
            "std_dist": std_dist,
        })

    return pl.DataFrame(results)

In [None]:
# Score every user pair in every cluster of the loaded data
balance_scores_df = compute_cluster_balance_scores(categorized_clusters_df)

# Report how many pairs were scored, then preview the first rows
print(f"Computed balance scores for {balance_scores_df.height} user pairs across clusters")
balance_scores_df.head(10)

## Analysis by Category

In [None]:
# Per-category averages of the balance score and its raw components.
# category_stats is always bound (None when there is no category column) so the
# final display expression cannot raise a NameError.
if "category" in balance_scores_df.columns:
    # maintain_order=True keeps the row order stable between runs.
    category_stats = balance_scores_df.group_by("category", maintain_order=True).agg(
        pl.col("balance_score").mean().alias("avg_balance_score"),
        pl.col("imbalance").mean().alias("avg_imbalance"),
        pl.col("magnitude_factor").mean().alias("avg_magnitude_factor"),
        pl.col("dist").mean().alias("avg_dist"),
        pl.len().alias("count")
    )

    # Display category statistics
    print("Balance score statistics by category:")
else:
    category_stats = None
    print("No 'category' column found; skipping per-category statistics.")

category_stats

## Component Analysis

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Create a figure with one panel per raw component
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot histograms for each component
components = ["imbalance", "magnitude_factor", "dist"]
titles = ["Imbalance", "Magnitude Factor", "Distance"]

# Convert to pandas once; the same frame backs every panel.
balance_scores_pd = balance_scores_df.to_pandas()
has_category = "category" in balance_scores_df.columns

for ax, component, title in zip(axes, components, titles):
    if has_category:
        sns.histplot(data=balance_scores_pd, x=component, hue="category",
                     element="step", bins=20, common_norm=False, ax=ax)
    else:
        sns.histplot(data=balance_scores_pd, x=component, bins=20, ax=ax)

    ax.set_title(f"Distribution of {title}")
    ax.set_xlabel(component)

    component_mean = balance_scores_df[component].mean()
    mean_line = ax.axvline(x=component_mean, color='r', linestyle='--',
                           label=f"Mean: {component_mean:.3f}")

    # A bare ax.legend() would replace seaborn's hue legend with one that only
    # lists the mean line. Re-register the hue legend as a plain artist first,
    # then add a second legend for the mean.
    hue_legend = ax.get_legend()
    if hue_legend is not None:
        ax.add_artist(hue_legend)
    ax.legend(handles=[mean_line], loc="upper center")

plt.tight_layout()
plt.show()

## Correlation Analysis

In [None]:
# Pairwise correlations between the composite score and its raw components
correlation_cols = ["balance_score", "imbalance", "magnitude_factor", "dist"]
correlation_df = balance_scores_df.select(correlation_cols).to_pandas()
correlation_matrix = correlation_df.corr()

# Draw the matrix as an annotated heatmap on an explicit axes
corr_fig, corr_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    center=0,
    ax=corr_ax,
)
corr_ax.set_title("Correlation Between Balance Score Components")
plt.tight_layout()
plt.show()

## Top User Pairs by Balance Score (Excluding Coding)

In [None]:
# Ten best (lowest) balance scores outside the "coding" category.
# `!= "coding"` evaluates to null for pairs whose category is null, and filter
# drops null rows, so uncategorized pairs are kept explicitly.
balance_scores_df.filter(
    (pl.col("category") != "coding") | pl.col("category").is_null()
).sort("balance_score").head(10)

## Summary

This notebook analyzed balance scores for pairs of users within each cluster of the categorized data. The balance score is the sum of three components, each standardized (z-scored) across all user pairs, so a lower score is better:

1. **Imbalance**: How evenly distributed conversations are between users (lower is better)
2. **Magnitude Factor**: Inverse of total conversation count (lower means more conversations)
3. **Distance**: Semantic distance between user conversations (lower means more similar interests)

These scores help identify the most promising clusters for generating serendipitous connections between users.