# Cluster Interpretation & Gene Analysis

This notebook provides biological interpretation of the clustering results by analyzing the most discriminant genes for each cluster.

## Objectives:
- Identify top discriminant genes for each cluster
- Analyze gene expression patterns across clusters
- Provide biological interpretation of cluster characteristics
- Create visualizations for cluster-specific gene signatures

## Prerequisites:
Run `preprocessing.ipynb` and `clustering_analysis.ipynb` first to generate the required input files.

## 1. Import Libraries and Load Data

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.feature_selection import VarianceThreshold
from sklearn.metrics import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler

print("Libraries imported successfully!")

In [None]:
# Read the artefacts produced upstream (clustering_analysis.ipynb is assumed
# to have been run beforehand).
X_reduced = np.load("../dataset/X_reduced.npy")
y = pd.read_csv("../dataset/y_labels.csv")

print("X_reduced shape:", X_reduced.shape)
print("Labels:", y["Class"].unique())

# Cluster assignments are not loaded from disk here, so K-Means is fitted
# again with a fixed seed. (Persisting the assignments in
# clustering_analysis.ipynb and loading them would be the cleaner option.)
optimal_k = 5
kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
clusters = kmeans.fit(X_reduced).labels_

# Number of samples per cluster, ordered by cluster id
cluster_counts = pd.Series(clusters).value_counts().sort_index()
print(f"Re-computed cluster assignments: {np.unique(clusters)}")

# 2-D PCA projection kept for visualisation
pca_vis = PCA(n_components=2)
X_vis = pca_vis.fit_transform(X_reduced)

# One row per sample: ground-truth class next to its assigned cluster
cluster_label_df = (
    y[["Class"]]
    .rename(columns={"Class": "True_Label"})
    .assign(Cluster=clusters)
)

print("Data loaded and clustering re-computed for interpretation.")

In [None]:
# Recover the gene names behind the columns of X_reduced by re-applying the
# feature selection (standardise, then variance filter) to the merged dataset.
print("Loading original feature names...")
df_original = pd.read_csv("../dataset/gene_expression_merged.csv")
X_original = df_original.drop(columns=["sample_id", "Class"])

# `preprocessing_metadata` is never loaded in this notebook, so reading the
# threshold from it raised a NameError on a fresh kernel. Use the same value
# as the discriminant-gene cell further down.
# NOTE(review): presumably this equals the threshold used in
# preprocessing.ipynb - confirm; the shape check below fails fast if not.
VARIANCE_THRESHOLD = 0.5

# Reconstruct which features were selected by VarianceThreshold
scaler = StandardScaler()
X_scaled_full = scaler.fit_transform(X_original)

selector = VarianceThreshold(threshold=VARIANCE_THRESHOLD)
selector.fit(X_scaled_full)
selected_mask = selector.get_support()
selected_gene_names = X_original.columns[selected_mask]

# The names are only usable if they line up one-to-one with X_reduced's columns.
if len(selected_gene_names) != X_reduced.shape[1]:
    raise ValueError(
        f"Reconstructed {len(selected_gene_names)} gene names but X_reduced has "
        f"{X_reduced.shape[1]} columns; the feature selection used here does not "
        "match the one applied during preprocessing."
    )

print(f"✓ Original features: {len(X_original.columns):,}")
print(f"✓ Selected features: {len(selected_gene_names):,}")
print(f"✓ Feature names available: {selected_gene_names[:5].tolist()}...")

## 2. Cluster Overview and Characteristics

In [None]:
# Cluster overview, computed from objects that already exist in this notebook.
# The previous version displayed `cluster_summary` before it is built and read
# `clustering_results` / `clustering_metadata`, none of which are loaded
# anywhere, so the cell failed on Restart & Run All.

# Cluster assignments come from the K-Means fit in the data-loading cell.
cluster_assignments = np.asarray(clusters)
true_labels = y["Class"].to_numpy()

# Contingency table: rows are clusters, columns are true cancer types.
cluster_composition = pd.crosstab(
    cluster_assignments, true_labels,
    rownames=['Cluster'], colnames=['True_Label'],
)

print("CLUSTER CHARACTERISTICS OVERVIEW")
print("=" * 80)
display(cluster_composition)

# Share of samples that carry the majority label of their own cluster.
# NOTE(review): "overall accuracy" is defined here as majority-vote agreement,
# which coincides with weighted purity; clustering_analysis.ipynb may have
# used a different cluster-to-class mapping - confirm.
weighted_purity = cluster_composition.max(axis=1).sum() / cluster_composition.to_numpy().sum()

clustering_metadata = {
    'optimal_k': int(optimal_k),
    'adjusted_rand_index': float(adjusted_rand_score(true_labels, cluster_assignments)),
    'normalized_mutual_info': float(normalized_mutual_info_score(true_labels, cluster_assignments)),
    'overall_accuracy': float(weighted_purity),
    'weighted_purity': float(weighted_purity),
}
optimal_k = clustering_metadata['optimal_k']

print(f"\nClustering Performance Summary:")
print(f"  • Adjusted Rand Index: {clustering_metadata['adjusted_rand_index']:.4f}")
print(f"  • Normalized Mutual Information: {clustering_metadata['normalized_mutual_info']:.4f}")
print(f"  • Overall Accuracy: {clustering_metadata['overall_accuracy']:.1%}")
print(f"  • Weighted Purity: {clustering_metadata['weighted_purity']:.1%}")

## 3. Discriminant Gene Analysis

In [None]:
# Rank genes per cluster by how far the cluster centroid sits from the
# overall mean of the (reduced) expression matrix.
print("\n=== TOP DISCRIMINANT GENES PER CLUSTER ===")

# Gene names of the features that survive the variance filter
df_original = pd.read_csv("../dataset/gene_expression_merged.csv")
X_original = df_original.drop(columns=["sample_id", "Class"])

# Refit the selector on the standardised matrix to obtain the boolean mask
selector = VarianceThreshold(threshold=0.5)
selector.fit(StandardScaler().fit_transform(X_original))
selected_features = selector.get_support()
selected_gene_names = X_original.columns[selected_features]

print(f"Selected {len(selected_gene_names)} genes after variance threshold")

# Sample-by-gene table with the cluster id attached to every sample
cluster_gene_means = pd.DataFrame(X_reduced, columns=selected_gene_names)
cluster_gene_means['Cluster'] = clusters

# Per-cluster mean expression (centroids) and the mean over all samples
cluster_centroids = cluster_gene_means.groupby('Cluster').mean()
overall_means = cluster_gene_means.drop(columns='Cluster').mean()

# For each cluster keep the genes whose centroid deviates most, in absolute
# value, from the overall mean.
top_genes_per_cluster = {}
n_top_genes = 10

for cluster_id in range(optimal_k):
    centroid = cluster_centroids.loc[cluster_id]
    deviations = (centroid - overall_means).abs()
    top_genes = deviations.nlargest(n_top_genes)
    top_genes_per_cluster[cluster_id] = top_genes

    print(f"\nCluster {cluster_id} - Top {n_top_genes} Discriminant Genes:")
    for i, (gene, deviation) in enumerate(top_genes.items(), start=1):
        # Arrow shows whether the cluster mean is above or below the overall mean
        if cluster_centroids.loc[cluster_id, gene] > overall_means[gene]:
            direction = "↑"
        else:
            direction = "↓"
        print(f"  {i:2d}. {gene}: {direction} {deviation:.4f} deviation")

In [None]:
# Visualize top discriminant genes across clusters: one heatmap per cluster
# showing that cluster's top 5 genes against the centroids of all clusters.

# Size the grid from optimal_k instead of assuming exactly five clusters.
n_cols = 3
n_rows = int(np.ceil(optimal_k / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
axes = axes.flatten()

for cluster_id in range(optimal_k):
    ax = axes[cluster_id]

    # Get top 5 genes for this cluster
    top_5_genes = list(top_genes_per_cluster[cluster_id].head(5).index)

    # Rows are genes, columns are clusters
    heatmap_data = cluster_centroids[top_5_genes].T

    sns.heatmap(heatmap_data, annot=True, fmt='.2f', cmap='RdBu_r', center=0,
                ax=ax, cbar_kws={'shrink': 0.8})
    ax.set_title(f'Cluster {cluster_id} - Top 5 Discriminant Genes')
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Genes')

# Drop every unused panel (previously only axes[5] was removed, which is
# only correct for exactly five clusters).
for unused_ax in axes[optimal_k:]:
    unused_ax.remove()

# savefig does not create missing directories.
figures_dir = Path("../figures")
figures_dir.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig(figures_dir / "discriminant_genes_heatmap.png", dpi=300, bbox_inches="tight")
plt.show()

# Create summary table of cluster characteristics
print("\n=== CLUSTER INTERPRETATION SUMMARY ===")

# Build one record per cluster and create the frame in a single step,
# rather than pre-filling columns with None and assigning cell by cell.
summary_rows = []
for cluster_id in range(optimal_k):
    cluster_samples = cluster_label_df[cluster_label_df['Cluster'] == cluster_id]
    # Most frequent true label; mode() returns sorted values, so ties
    # resolve to the first label in sort order.
    dominant_type = cluster_samples['True_Label'].mode()[0]
    purity = (cluster_samples['True_Label'] == dominant_type).mean()
    top_3_genes = ', '.join(list(top_genes_per_cluster[cluster_id].head(3).index))

    summary_rows.append({
        'Cluster': cluster_id,
        'Size': int(cluster_counts.loc[cluster_id]),
        'Dominant_Cancer_Type': dominant_type,
        'Purity': f"{purity:.1%}",  # stored as a formatted string, e.g. "97.3%"
        'Top_3_Genes': top_3_genes,
    })

cluster_summary = pd.DataFrame(
    summary_rows,
    columns=['Cluster', 'Size', 'Dominant_Cancer_Type', 'Purity', 'Top_3_Genes'],
)

print(cluster_summary.to_string(index=False))

## 4. Cluster Interpretation & Gene Expression Heatmaps

### Biological Insights from Discriminant Genes:

The analysis above identifies the **top discriminant genes** for each cluster: the genes whose mean expression within the cluster deviates most, in absolute value, from the overall mean across all samples.

### Key Findings:
- **Gene Expression Signatures**: Each cluster has a unique molecular signature defined by specific genes
- **Cluster Purity**: Shows how "pure" each cluster is in terms of cancer type composition
- **Discriminant Genes**: ↑ indicates overexpression, ↓ indicates underexpression relative to population mean

### Interpretation Approach:
1. **Identify characteristic genes** per cluster that distinguish it from others
2. **Map clusters to cancer types** to understand biological meaning
3. **Analyze gene expression patterns** to find potential biomarkers

This analysis bridges **unsupervised clustering** with **biological interpretation**, providing insights into the molecular mechanisms that distinguish different cancer types.

In [None]:
# Create comprehensive heatmap of top discriminant genes
print("Creating gene expression heatmaps...")

# Collect the top 10 genes of every cluster, de-duplicated while keeping
# first-seen order. A plain set gave a different row order on each kernel
# start (string hashing is randomised), so the saved figure was not
# reproducible.
all_top_genes = list(dict.fromkeys(
    gene
    for cid in range(optimal_k)
    for gene in top_genes_per_cluster[cid].head(10).index
))
print(f"Total unique discriminant genes: {len(all_top_genes)}")

# Rows are genes, columns are clusters
heatmap_data = cluster_centroids[all_top_genes].T

# Figure height grows with the number of genes so tick labels stay legible
fig, ax = plt.subplots(figsize=(12, max(8, len(all_top_genes) * 0.3)))
sns.heatmap(heatmap_data,
            annot=False,
            cmap='RdBu_r',
            center=0,
            xticklabels=[f'Cluster {i}' for i in range(optimal_k)],
            yticklabels=all_top_genes,
            cbar_kws={'label': 'Standardized Expression'},
            ax=ax)

ax.set_title(f'Gene Expression Heatmap: Top Discriminant Genes Across Clusters\n({len(all_top_genes)} genes)')
ax.set_xlabel('Cluster')
ax.set_ylabel('Genes')

# savefig does not create missing directories.
figures_dir = Path('../figures')
figures_dir.mkdir(parents=True, exist_ok=True)

fig.tight_layout()
fig.savefig(figures_dir / 'discriminant_genes_comprehensive_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Create individual heatmaps for each cluster's top genes, titled with the
# cluster's dominant cancer type.

# Size the grid from optimal_k instead of assuming it fits in 2 x 3 panels.
n_cols = 3
n_rows = int(np.ceil(optimal_k / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 7 * n_rows), squeeze=False)
axes = axes.flatten()

for cluster_id in range(optimal_k):
    ax = axes[cluster_id]

    # Get top 8 genes for this cluster
    top_genes = list(top_genes_per_cluster[cluster_id].head(8).index)

    # Rows are genes, columns are clusters
    cluster_heatmap_data = cluster_centroids[top_genes].T

    sns.heatmap(cluster_heatmap_data,
                annot=True,
                fmt='.2f',
                cmap='RdBu_r',
                center=0,
                ax=ax,
                xticklabels=[f'C{i}' for i in range(optimal_k)],
                # Strip the 'gene_' prefix to keep tick labels short
                yticklabels=[gene.replace('gene_', '') for gene in top_genes],
                cbar_kws={'shrink': 0.8})

    # Dominant cancer type of this cluster, taken from the summary table
    is_this_cluster = cluster_summary['Cluster'] == cluster_id
    dominant_cancer = cluster_summary.loc[is_this_cluster, 'Dominant_Cancer_Type'].iloc[0]
    ax.set_title(f'Cluster {cluster_id}: {dominant_cancer}\nTop 8 Discriminant Genes')
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Genes')

# Remove every unused panel, not just the first one after the last cluster.
for unused_ax in axes[optimal_k:]:
    unused_ax.remove()

# savefig does not create missing directories.
figures_dir = Path('../figures')
figures_dir.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig(figures_dir / 'cluster_specific_gene_heatmaps.png', dpi=300, bbox_inches='tight')
plt.show()

## 5. Cluster-Specific Gene Signatures

In [None]:
# Create detailed cluster interpretation table: for every cluster, the
# dominant cancer type plus the five most up- and down-regulated genes.
cluster_interpretation = []

for cluster_id in range(optimal_k):
    # Summary row for this cluster (size, dominant type, purity)
    cluster_info = cluster_summary[cluster_summary['Cluster'] == cluster_id].iloc[0]

    # Signed difference between the cluster centroid and the overall mean.
    # This is a difference of means, not a ratio, despite the variable name.
    cluster_expr = cluster_centroids.loc[cluster_id]
    fold_changes = cluster_expr - overall_means

    top_up = fold_changes.nlargest(5)
    top_down = fold_changes.nsmallest(5)

    interpretation = {
        'Cluster': cluster_id,
        'Dominant_Cancer_Type': cluster_info['Dominant_Cancer_Type'],
        'Size': int(cluster_info['Size']),
        # cluster_summary names this column 'Purity' (a string such as
        # "97.3%"); the former 'Purity_Percent' key does not exist and
        # raised a KeyError.
        'Purity': cluster_info['Purity'],
        'Top_Upregulated': ', '.join([f"{gene}({fc:.2f})" for gene, fc in top_up.items()]),
        'Top_Downregulated': ', '.join([f"{gene}({fc:.2f})" for gene, fc in top_down.items()])
    }

    cluster_interpretation.append(interpretation)

# Create interpretation DataFrame
interpretation_df = pd.DataFrame(cluster_interpretation)

print("CLUSTER GENE SIGNATURE ANALYSIS")
print("=" * 100)
for _, row in interpretation_df.iterrows():
    print(f"\nCluster {row['Cluster']} - {row['Dominant_Cancer_Type']} ({row['Purity']} purity, {row['Size']} samples)")
    print(f"  Top Upregulated:   {row['Top_Upregulated']}")
    print(f"  Top Downregulated: {row['Top_Downregulated']}")
    print("-" * 90)

## 6. Gene Expression Distribution Analysis

In [None]:
# Analyze distribution of gene expression across clusters
print("Analyzing gene expression distributions...")

# Gather, for every gene that is a top gene of at least one cluster, the
# deviations recorded for it.
all_deviations = {}
for cluster_id in range(optimal_k):
    for gene, deviation in top_genes_per_cluster[cluster_id].items():
        all_deviations.setdefault(gene, []).append(deviation)

# Rank genes by their largest deviation and keep the top six (one per panel).
n_global_genes = 6
max_deviations = {gene: max(devs) for gene, devs in all_deviations.items()}
top_global_genes = sorted(max_deviations.items(), key=lambda x: x[1], reverse=True)[:n_global_genes]

# Create box plots for top discriminant genes
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

# Build the name list once instead of on every inner-loop iteration.
gene_name_list = list(selected_gene_names)
plot_labels = [f'C{cluster_id}' for cluster_id in range(optimal_k)]
colors = plt.cm.Set3(np.linspace(0, 1, optimal_k))

for idx, (gene, _) in enumerate(top_global_genes):
    ax = axes[idx]
    gene_idx = gene_name_list.index(gene)

    # Expression values of this gene, one array per cluster. `clusters` is
    # the K-Means assignment from the data-loading cell; the previous
    # `cluster_assignments` was only set in a cell that fails on a fresh kernel.
    plot_data = [
        X_reduced[clusters == cluster_id, gene_idx]
        for cluster_id in range(optimal_k)
    ]

    # Tick labels are set on the axis afterwards: boxplot's `labels=`
    # keyword is deprecated since Matplotlib 3.9.
    box_plot = ax.boxplot(plot_data, patch_artist=True)
    ax.set_xticklabels(plot_labels)

    # Color boxes
    for patch, color in zip(box_plot['boxes'], colors):
        patch.set_facecolor(color)

    ax.set_title(f'{gene}\n(Max Deviation: {max_deviations[gene]:.3f})')
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Standardized Expression')
    ax.grid(True, alpha=0.3)

# Hide panels left empty when fewer than six genes are available.
for unused_ax in axes[len(top_global_genes):]:
    unused_ax.remove()

# savefig does not create missing directories.
figures_dir = Path('../figures')
figures_dir.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig(figures_dir / 'gene_expression_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Top {len(top_global_genes)} most discriminant genes globally:")
for gene, deviation in top_global_genes:
    print(f"  • {gene}: {deviation:.4f} max deviation")

## 7. Biological Interpretation & Insights

In [None]:
# Generate biological insights summary
print("BIOLOGICAL INTERPRETATION OF CLUSTERING RESULTS")
print("=" * 80)

# 'Purity' is stored as a formatted string ("97.3%"); parse it once.
purity_pct = interpretation_df['Purity'].str.rstrip('%').astype(float)

# Calculate clustering success metrics
high_purity_clusters = interpretation_df[purity_pct >= 90]
perfect_clusters = interpretation_df[interpretation_df['Purity'] == '100.0%']

# Share of samples whose true label equals the majority label of their
# cluster. Computed here because `clustering_metadata` is not loaded in this
# notebook and referencing it failed on a fresh kernel.
# NOTE(review): assumes "overall accuracy" means majority-vote agreement - confirm
# against clustering_analysis.ipynb.
majority_counts = cluster_label_df.groupby('Cluster')['True_Label'].agg(
    lambda labels: labels.value_counts().iloc[0]
)
overall_accuracy = majority_counts.sum() / len(cluster_label_df)

print(f"\nClustering Success Analysis:")
print(f"  • Total clusters: {optimal_k}")
print(f"  • High purity clusters (≥90%): {len(high_purity_clusters)}")
print(f"  • Perfect clusters (100%): {len(perfect_clusters)}")
print(f"  • Average cluster purity: {purity_pct.mean():.1f}%")

print(f"\nKey Biological Findings:")
print(f"  • Gene expression patterns can distinguish cancer types with {overall_accuracy:.1%} accuracy")
print(f"  • Molecular signatures are highly specific to cancer types")
print(f"  • Clustering identified {len(all_top_genes)} key discriminant genes")
print(f"  • Each cancer type has a unique gene expression profile")

print(f"\nCluster-Cancer Type Mapping:")
for _, row in interpretation_df.iterrows():
    print(f"  • Cluster {row['Cluster']}: {row['Dominant_Cancer_Type']} ({row['Purity']} purity)")

print(f"\nImplications for Cancer Research:")
print(f"  • These discriminant genes could serve as potential biomarkers")
print(f"  • Gene signatures can be used for cancer subtype classification")
print(f"  • Expression patterns reflect underlying molecular mechanisms")
print(f"  • Results support precision medicine approaches")

## 8. Save Interpretation Results

In [None]:
# Save all interpretation results
print("Saving cluster interpretation results...")

# 1. Save discriminant genes for each cluster (one row per cluster/gene pair)
discriminant_genes_data = []
for cluster_id in range(optimal_k):
    for rank, (gene, deviation) in enumerate(top_genes_per_cluster[cluster_id].items(), 1):
        cluster_expr = cluster_centroids.loc[cluster_id, gene]
        overall_expr = overall_means[gene]
        # Signed difference of means (cluster minus overall), not a ratio
        fold_change = cluster_expr - overall_expr
        direction = "UP" if fold_change > 0 else "DOWN"

        discriminant_genes_data.append({
            'cluster': cluster_id,
            'rank': rank,
            'gene': gene,
            'deviation': deviation,
            'cluster_expression': cluster_expr,
            'overall_expression': overall_expr,
            'fold_change': fold_change,
            'direction': direction
        })

discriminant_genes_df = pd.DataFrame(discriminant_genes_data)
discriminant_genes_df.to_csv('../dataset/discriminant_genes.csv', index=False)
print("✓ Saved discriminant genes: ../dataset/discriminant_genes.csv")

# 2. Save cluster interpretation summary
interpretation_df.to_csv('../dataset/cluster_interpretation.csv', index=False)
print("✓ Saved cluster interpretation: ../dataset/cluster_interpretation.csv")

# 3. Save cluster centroids
cluster_centroids.to_csv('../dataset/cluster_centroids.csv')
print("✓ Saved cluster centroids: ../dataset/cluster_centroids.csv")

# 4. Save interpretation metadata
# Values are cast to builtin types: the json module rejects NumPy scalars,
# both as values and as dictionary keys.
average_purity = float(interpretation_df['Purity'].str.rstrip('%').astype(float).mean())
interpretation_metadata = {
    'n_discriminant_genes_per_cluster': int(n_top_genes),
    'total_unique_discriminant_genes': len(all_top_genes),
    'high_purity_clusters': len(high_purity_clusters),
    'perfect_clusters': len(perfect_clusters),
    'average_purity': average_purity,
    'top_global_discriminant_genes': [str(gene) for gene, _ in top_global_genes],
    'cluster_cancer_mapping': {
        # Integer keys are written as JSON strings ("0", "1", ...)
        int(row['Cluster']): {
            'cancer_type': str(row['Dominant_Cancer_Type']),
            'purity': row['Purity'],
            'size': int(row['Size'])
        }
        for _, row in interpretation_df.iterrows()
    }
}

with open('../dataset/interpretation_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(interpretation_metadata, f, indent=2)
print("✓ Saved interpretation metadata: ../dataset/interpretation_metadata.json")

## 9. Final Summary and Next Steps