# Visualization Gallery for SUBMARIT

This notebook showcases various visualization techniques for analyzing and presenting clustering results from SUBMARIT. We'll cover:
- Substitution matrix visualizations
- Cluster visualization techniques
- Statistical result plots
- Interactive visualizations
- Publication-quality figures
- Custom visualization functions

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm
from matplotlib.patches import Rectangle, Circle
from matplotlib.collections import PatchCollection
import plotly.graph_objects as go
import plotly.express as px
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
import networkx as nx
import warnings
warnings.filterwarnings('ignore')

# Import SUBMARIT modules
from submarit.algorithms import KSMLocalSearch
from submarit.evaluation import ClusterEvaluator, EvaluationVisualizer
from submarit.validation import RandIndex, run_clusters_topk

# Seed NumPy's legacy global RNG first: every np.random.* draw in the
# notebook (synthetic data, bootstrap noise) becomes reproducible.
np.random.seed(42)

# Notebook-wide plotting theme (the palette is applied after the style,
# so the husl colour cycle is the one that wins).
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("Libraries imported successfully!")

## 1. Substitution Matrix Visualizations

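Let's explore different ways to visualize substitution matrices. The examples use a synthetic 30-product substitution matrix with three planted clusters, partitioned with `KSMLocalSearch`.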

In [None]:
# Create sample data
def create_sample_matrix(n_products=30, n_clusters=3, random_state=None):
    """Create a symmetric synthetic substitution matrix with planted clusters.

    Products are split into ``n_clusters`` contiguous index blocks of
    ``n_products // n_clusters`` products each; the last block absorbs any
    remainder. Within-block scores are drawn from U(0.7, 0.95), every entry is
    floored by U(0, 0.2) background noise, the matrix is symmetrized by
    averaging it with its transpose, and the diagonal is set to zero.

    Parameters
    ----------
    n_products : int
        Number of products (rows/columns of the matrix).
    n_clusters : int
        Number of planted clusters; must satisfy ``1 <= n_clusters <= n_products``.
    random_state : None, int, numpy.random.Generator, optional
        Source of randomness. ``None`` (default) draws from NumPy's global legacy
        RNG, so results follow ``np.random.seed``. Any other value is passed to
        ``np.random.default_rng``.

    Returns
    -------
    numpy.ndarray
        ``(n_products, n_products)`` float matrix.

    Raises
    ------
    ValueError
        If ``n_clusters`` is outside ``[1, n_products]``.
    """
    if not 1 <= n_clusters <= n_products:
        raise ValueError(
            f"n_clusters must be between 1 and n_products ({n_products}); got {n_clusters}"
        )
    # None keeps the original behaviour: draws come from the global legacy RNG,
    # in exactly the same order as before.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    matrix = np.zeros((n_products, n_products))
    cluster_size = n_products // n_clusters

    for i in range(n_clusters):
        start = i * cluster_size
        # The last cluster absorbs the remainder when n_products % n_clusters != 0
        end = (i + 1) * cluster_size if i < n_clusters - 1 else n_products

        # High within-cluster substitution
        block = rng.uniform(0.7, 0.95, (end - start, end - start))
        matrix[start:end, start:end] = block

    # Low between-cluster substitution: every entry is at least U(0, 0.2) noise
    noise = rng.uniform(0, 0.2, (n_products, n_products))
    matrix = np.maximum(matrix, noise)

    # Make symmetric and remove diagonal (no self-substitution)
    matrix = (matrix + matrix.T) / 2
    np.fill_diagonal(matrix, 0)

    return matrix

# Synthetic data: 30 products in 3 planted clusters, then a KSM local-search
# fit with a fixed random_state so the partition is reproducible.
substitution_matrix = create_sample_matrix(n_products=30, n_clusters=3)

search = KSMLocalSearch(n_clusters=3, random_state=42)
result = search.fit(substitution_matrix)

In [None]:
# 1.1 Basic Heatmap with Cluster Boundaries
# Cluster ids actually present in the result (robust to non-contiguous labels)
cluster_ids = np.unique(result.labels)
n_found = len(cluster_ids)
cluster_names = [f'Cluster {c}' for c in cluster_ids]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Original matrix
ax = axes[0]
sns.heatmap(substitution_matrix, cmap='YlOrRd', ax=ax, cbar_kws={'label': 'Substitution Score'})
ax.set_title('Original Substitution Matrix', fontsize=14)
ax.set_xlabel('Product')
ax.set_ylabel('Product')

# Reordered matrix: sort products by cluster label so clusters form diagonal blocks
ax = axes[1]
sorted_indices = np.argsort(result.labels)
reordered_matrix = substitution_matrix[sorted_indices][:, sorted_indices]
im = ax.imshow(reordered_matrix, cmap='YlOrRd', aspect='auto')

# Cluster boundaries sit at the cumulative cluster sizes (the final total is the matrix edge)
cluster_sizes = [np.sum(result.labels == c) for c in cluster_ids]
boundaries = np.cumsum(cluster_sizes)[:-1]
for boundary in boundaries:
    ax.axhline(y=boundary - 0.5, color='blue', linewidth=3)
    ax.axvline(x=boundary - 0.5, color='blue', linewidth=3)

ax.set_title('Reordered Matrix with Cluster Boundaries', fontsize=14)
ax.set_xlabel('Product (reordered)')
ax.set_ylabel('Product (reordered)')
plt.colorbar(im, ax=ax, label='Substitution Score')

# Cluster-averaged matrix: mean substitution between every pair of clusters,
# including the within-cluster (diagonal) averages.
ax = axes[2]
cluster_avg_matrix = np.zeros((n_found, n_found))
for a, cluster_a in enumerate(cluster_ids):
    mask_a = result.labels == cluster_a
    for b, cluster_b in enumerate(cluster_ids):
        block = substitution_matrix[np.ix_(mask_a, result.labels == cluster_b)]
        if a == b:
            # Exclude self-pairs (the matrix diagonal) from the within-cluster mean;
            # a singleton cluster has no pairs, so its cell is left blank (NaN).
            off_diagonal = block[~np.eye(len(block), dtype=bool)]
            cluster_avg_matrix[a, b] = off_diagonal.mean() if off_diagonal.size else np.nan
        else:
            cluster_avg_matrix[a, b] = block.mean()

sns.heatmap(cluster_avg_matrix, annot=True, fmt='.3f', cmap='YlOrRd', ax=ax,
            xticklabels=cluster_names,
            yticklabels=cluster_names)
ax.set_title('Cluster-Averaged Substitution', fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# 1.2 Matrix with Annotations and Dendrograms
# Convert similarities to distances. The substitution matrix has a zero diagonal,
# so 1 - S has ones there; squareform() requires a zero diagonal, so reset it.
distance_matrix = 1 - substitution_matrix
np.fill_diagonal(distance_matrix, 0)
condensed_distances = squareform(distance_matrix)
# Average linkage is well defined for arbitrary dissimilarities, whereas 'ward'
# is only correct for Euclidean distances (which 1 - similarity is not in general).
linkage_matrix = linkage(condensed_distances, method='average')

# Create figure with dendrograms
fig = plt.figure(figsize=(12, 10))

# Dendrogram on the left (leaves are laid out bottom-to-top)
ax1 = fig.add_axes([0.09, 0.1, 0.2, 0.6])
Z1 = dendrogram(linkage_matrix, orientation='left', ax=ax1)
ax1.set_xticks([])

# Dendrogram on top (leaves are laid out left-to-right)
ax2 = fig.add_axes([0.3, 0.71, 0.6, 0.2])
Z2 = dendrogram(linkage_matrix, ax=ax2)
ax2.set_xticks([])
ax2.set_yticks([])

# Reorder matrix rows/columns to follow the dendrogram leaf order
idx1 = Z1['leaves']
idx2 = Z2['leaves']
reordered_matrix = substitution_matrix[idx1][:, idx2]

# Main heatmap. origin='lower' puts row 0 at the bottom, matching the
# bottom-to-top leaf order of the left dendrogram so rows line up with leaves.
axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])
im = axmatrix.matshow(reordered_matrix, aspect='auto', origin='lower', cmap='YlOrRd')
axmatrix.set_xticks([])
axmatrix.set_yticks([])

# Color bar
axcolor = fig.add_axes([0.91, 0.1, 0.02, 0.6])
plt.colorbar(im, cax=axcolor)

fig.suptitle('Hierarchical Clustering Dendrogram with Heatmap', fontsize=16)
plt.show()

## 2. Cluster Visualization Techniques

In [None]:
# 2.1 Network Graph Visualization
# Undirected graph: one node per product (tagged with its cluster), one edge
# for every product pair whose substitution score exceeds `threshold`.
threshold = 0.5  # Only show strong substitution relationships
G = nx.Graph()
n_items = len(substitution_matrix)

G.add_nodes_from((node, {'cluster': result.labels[node]}) for node in range(n_items))

# Upper-triangle pairs (r < c) visited in the same row-major order as a double loop
upper_rows, upper_cols = np.triu_indices(n_items, k=1)
G.add_edges_from(
    (r, c, {'weight': substitution_matrix[r, c]})
    for r, c in zip(upper_rows.tolist(), upper_cols.tolist())
    if substitution_matrix[r, c] > threshold
)

# Deterministic force-directed layout
pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

plt.figure(figsize=(12, 8))

# One node collection per cluster, so each gets its own colour and legend entry
colors = ['red', 'blue', 'green']
for cluster, node_color in enumerate(colors):
    members = [node for node, label in G.nodes(data='cluster') if label == cluster]
    nx.draw_networkx_nodes(G, pos, nodelist=members, node_color=node_color,
                           node_size=500, label=f'Cluster {cluster}')

# Edge thickness scales with substitution strength
edges = G.edges()
weights = [G[u][v]['weight'] for u, v in edges]
nx.draw_networkx_edges(G, pos, width=[3 * w for w in weights], alpha=0.3)

nx.draw_networkx_labels(G, pos, font_size=10)

plt.title('Product Substitution Network', fontsize=16)
plt.legend(loc='upper right')
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# 2.2 Chord Diagram (Circular Plot)
from matplotlib.patches import Arc, Patch

fig, ax = plt.subplots(figsize=(10, 10))

n_products = len(substitution_matrix)
radius = 1
# Cluster arcs sit at r = 1.3, outside the product labels drawn at r = 1.15,
# so the thick arcs do not overprint the labels.
arc_diameter = 2.6 * radius

# Place products around the circle grouped by cluster (a stable sort keeps index
# order within each cluster), so every cluster occupies one contiguous sector and
# its arc never spans products that belong to another cluster.
circle_order = np.argsort(result.labels, kind='stable')
angles = np.empty(n_products)
angles[circle_order] = np.linspace(0, 2*np.pi, n_products, endpoint=False)

# Plot products as points on circle
x = radius * np.cos(angles)
y = radius * np.sin(angles)

# Color by cluster
colors = ['red', 'blue', 'green']
for i in range(n_products):
    cluster = result.labels[i]
    ax.scatter(x[i], y[i], s=200, c=colors[cluster], zorder=5)

    # Product labels; those on the left half are flipped so they read left-to-right
    angle_deg = angles[i] * 180 / np.pi
    if 90 < angle_deg < 270:
        ax.text(x[i]*1.15, y[i]*1.15, f'P{i}', ha='right', va='center',
                rotation=angle_deg-180, fontsize=8)
    else:
        ax.text(x[i]*1.15, y[i]*1.15, f'P{i}', ha='left', va='center',
                rotation=angle_deg, fontsize=8)

# Draw connections for high substitution scores
threshold = 0.7
for i in range(n_products):
    for j in range(i+1, n_products):
        if substitution_matrix[i, j] > threshold:
            # Stronger substitution -> more opaque line
            alpha = substitution_matrix[i, j]
            if result.labels[i] == result.labels[j]:
                # Same cluster - use cluster color
                color = colors[result.labels[i]]
                linewidth = 2
            else:
                # Different clusters - use gray
                color = 'gray'
                linewidth = 1

            ax.plot([x[i], x[j]], [y[i], y[j]], color=color,
                    alpha=alpha*0.5, linewidth=linewidth)

# Cluster arcs spanning each cluster's (contiguous) sector
for cluster in range(3):
    cluster_indices = np.where(result.labels == cluster)[0]
    if len(cluster_indices) > 0:
        cluster_angles = np.degrees(angles[cluster_indices])
        arc = Arc((0, 0), arc_diameter, arc_diameter, angle=0,
                  theta1=cluster_angles.min(), theta2=cluster_angles.max(),
                  color=colors[cluster], linewidth=5)
        ax.add_patch(arc)

ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Circular Product Substitution Diagram', fontsize=16, pad=20)

# Add legend
legend_elements = [Patch(facecolor=colors[i], label=f'Cluster {i}') for i in range(3)]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

## 3. Statistical Result Visualizations

In [None]:
# 3.1 Evaluation Metrics Dashboard
# Run evaluation
evaluator = ClusterEvaluator()
eval_result = evaluator.evaluate(substitution_matrix, result.labels)

# Create dashboard
fig = plt.figure(figsize=(15, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# 1. Silhouette plot: per-product silhouette values, sorted within each cluster
ax1 = fig.add_subplot(gs[0, :])
from sklearn.metrics import silhouette_samples
silhouette_vals = silhouette_samples(substitution_matrix, result.labels)
y_lower = 10
colors_list = ['red', 'blue', 'green']

for i in range(3):
    cluster_silhouette_vals = silhouette_vals[result.labels == i]
    cluster_silhouette_vals.sort()

    size_cluster_i = cluster_silhouette_vals.shape[0]
    y_upper = y_lower + size_cluster_i

    ax1.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals,
                      facecolor=colors_list[i], edgecolor=colors_list[i], alpha=0.7)
    ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
    # 10-unit gap between consecutive clusters
    y_lower = y_upper + 10

# Black, not red: a red line would be indistinguishable from cluster 0's colour
ax1.axvline(x=eval_result.silhouette_score, color="black", linestyle="--",
            label=f'Average: {eval_result.silhouette_score:.3f}')
ax1.set_xlabel('Silhouette Coefficient')
ax1.set_ylabel('Cluster')
# y positions are only stacking offsets; clusters are identified by the text labels
ax1.set_yticks([])
ax1.set_title('Silhouette Analysis', fontsize=14)
ax1.legend()

# 2. Metric comparison (all mapped so that higher is better)
ax2 = fig.add_subplot(gs[1, 0])
metrics = ['Silhouette', 'Davies-Bouldin\n(inverted)', 'Calinski-Harabasz\n(normalized)']
values = [eval_result.silhouette_score,
          1/(1+eval_result.davies_bouldin_index),
          min(1, eval_result.calinski_harabasz_score/1000)]
bars = ax2.bar(metrics, values, color=['skyblue', 'lightcoral', 'lightgreen'])
# Silhouette lies in [-1, 1]; extend below 0 so a negative score stays visible
ax2.set_ylim(min(0, min(values)), 1)
ax2.set_ylabel('Score')
ax2.set_title('Clustering Quality Metrics', fontsize=14)
for bar, val in zip(bars, values):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{val:.3f}', ha='center', va='bottom')

# 3. Cluster sizes pie chart
ax3 = fig.add_subplot(gs[1, 1])
cluster_sizes = [np.sum(result.labels == i) for i in range(3)]
ax3.pie(cluster_sizes, labels=[f'Cluster {i}' for i in range(3)],
        colors=colors_list, autopct='%1.1f%%', startangle=90)
ax3.set_title('Cluster Size Distribution', fontsize=14)

# 4. Within vs Between cluster similarity
ax4 = fig.add_subplot(gs[1, 2])
categories = ['Within-cluster', 'Between-cluster']
similarities = [eval_result.within_cluster_similarity, eval_result.between_cluster_similarity]
bars = ax4.bar(categories, similarities, color=['green', 'orange'])
ax4.set_ylabel('Average Similarity')
ax4.set_title('Cluster Cohesion vs Separation', fontsize=14)
ax4.set_ylim(0, 1)
for bar, val in zip(bars, similarities):
    ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{val:.3f}', ha='center', va='bottom')

# 5. Stability analysis (simulated): re-cluster noisy copies and compare to the original
ax5 = fig.add_subplot(gs[2, :])
n_bootstrap = 20
stability_scores = []
for _ in range(n_bootstrap):
    # Add Gaussian noise, keep scores in [0, 1], restore symmetry and zero diagonal
    noisy_matrix = substitution_matrix + np.random.normal(0, 0.05, substitution_matrix.shape)
    noisy_matrix = np.clip(noisy_matrix, 0, 1)
    noisy_matrix = (noisy_matrix + noisy_matrix.T) / 2
    np.fill_diagonal(noisy_matrix, 0)

    # Re-cluster
    noisy_result = search.fit(noisy_matrix)

    # Calculate similarity with original
    rand_calc = RandIndex()
    rand_result = rand_calc.compute(result.labels, noisy_result.labels)
    stability_scores.append(rand_result.adjusted_rand_index)

ax5.hist(stability_scores, bins=10, color='purple', alpha=0.7, edgecolor='black')
ax5.axvline(np.mean(stability_scores), color='red', linestyle='--',
            label=f'Mean: {np.mean(stability_scores):.3f}')
ax5.set_xlabel('Adjusted Rand Index')
ax5.set_ylabel('Frequency')
ax5.set_title('Clustering Stability (Bootstrap Analysis)', fontsize=14)
ax5.legend()

plt.suptitle('Clustering Evaluation Dashboard', fontsize=16)
plt.tight_layout()
plt.show()

## 4. Interactive Visualizations with Plotly

In [None]:
# 4.1 Interactive Heatmap
# Order products by cluster so clusters appear as diagonal blocks
sorted_indices = np.argsort(result.labels)
reordered_matrix = substitution_matrix[sorted_indices][:, sorted_indices]
cluster_labels = result.labels[sorted_indices]
product_names = [f'Product {idx}' for idx in sorted_indices]
n_cells = len(reordered_matrix)

# Per-cell hover text: original product ids, their clusters, and the score
hover_text = [
    [
        f'Products: {sorted_indices[r]}, {sorted_indices[c]}<br>'
        f'Clusters: {cluster_labels[r]}, {cluster_labels[c]}<br>'
        f'Substitution: {reordered_matrix[r,c]:.3f}'
        for c in range(n_cells)
    ]
    for r in range(n_cells)
]

heatmap = go.Heatmap(
    z=reordered_matrix,
    x=product_names,
    y=product_names,
    colorscale='YlOrRd',
    hovertext=hover_text,
    hoverinfo='text',
    colorbar=dict(title='Substitution Score'),
)
fig = go.Figure(data=heatmap)

# One vertical + one horizontal blue separator at every internal cluster boundary
cluster_sizes = [np.sum(cluster_labels == k) for k in range(3)]
boundaries = np.cumsum([0] + cluster_sizes[:-1])
far_edge = n_cells - 0.5

shapes = []
for boundary in boundaries[1:]:
    cut = boundary - 0.5
    shapes.append(dict(type='line', x0=cut, x1=cut, y0=-0.5, y1=far_edge,
                       line=dict(color='blue', width=3)))
    shapes.append(dict(type='line', x0=-0.5, x1=far_edge, y0=cut, y1=cut,
                       line=dict(color='blue', width=3)))

fig.update_layout(
    title='Interactive Substitution Matrix Heatmap',
    xaxis_title='Product',
    yaxis_title='Product',
    shapes=shapes,
    width=800,
    height=800,
)

fig.show()

In [None]:
# 4.2 3D Network Visualization
# Re-lay out the substitution network G (built in section 2.1) in three dimensions
pos_3d = nx.spring_layout(G, dim=3, k=3, iterations=50, seed=42)

# Node coordinates: one list per axis, in G's node order
node_x, node_y, node_z = ([pos_3d[node][axis] for node in G.nodes()] for axis in range(3))

# Edge coordinates as polylines; None breaks the line between consecutive edges
edge_x, edge_y, edge_z = [], [], []
for u, v in G.edges():
    (x0, y0, z0), (x1, y1, z1) = pos_3d[u], pos_3d[v]
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]
    edge_z += [z0, z1, None]

edge_trace = go.Scatter3d(
    x=edge_x, y=edge_y, z=edge_z,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines',
)

# Node colours reuse the cluster colour list from the circular diagram cell
node_colors = [colors[result.labels[node]] for node in G.nodes()]
node_text = [f'Product {node}<br>Cluster {result.labels[node]}' for node in G.nodes()]

node_trace = go.Scatter3d(
    x=node_x, y=node_y, z=node_z,
    mode='markers+text',
    hoverinfo='text',
    text=[str(node) for node in G.nodes()],
    hovertext=node_text,
    marker=dict(
        size=10,
        color=node_colors,
        line=dict(width=2, color='white'),
    ),
)

hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
fig = go.Figure(data=[edge_trace, node_trace])
fig.update_layout(
    title='3D Product Substitution Network',
    showlegend=False,
    hovermode='closest',
    margin=dict(b=0, l=0, r=0, t=40),
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
    ),
    width=800,
    height=600,
)

fig.show()

## 5. Publication-Quality Figures

In [None]:
# 5.1 Publication-ready matrix visualization
publication_rc = {
    'font.size': 12,
    'font.family': 'serif',
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 12,
    'figure.dpi': 300,
}

# Apply the publication style only inside this context. rc_context restores the
# previous rcParams on exit, so the notebook-wide style chosen at the top
# (seaborn whitegrid + husl palette) survives for later cells. Resetting with
# rcParams.update(rcParamsDefault) would discard that style and may also reset
# the backend, which can stop inline figures from rendering in Jupyter.
with plt.rc_context(publication_rc):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: Substitution matrix ordered by cluster
    sorted_indices = np.argsort(result.labels)
    reordered_matrix = substitution_matrix[sorted_indices][:, sorted_indices]

    im1 = ax1.imshow(reordered_matrix, cmap='RdBu_r', aspect='auto', vmin=0, vmax=1)

    # Cluster labels, centred on each diagonal block, outside the top and left edges
    cluster_sizes = [np.sum(result.labels == i) for i in range(3)]
    boundaries = np.cumsum([0] + cluster_sizes)
    cluster_centers = [(boundaries[i] + boundaries[i+1]) / 2 for i in range(3)]

    for i, center in enumerate(cluster_centers):
        ax1.text(-2, center-0.5, f'C{i+1}', ha='right', va='center', fontweight='bold')
        ax1.text(center-0.5, -2, f'C{i+1}', ha='center', va='bottom', fontweight='bold')

    # Grid lines at internal cluster boundaries (skip the outer edges)
    for boundary in boundaries[1:-1]:
        ax1.axhline(y=boundary-0.5, color='black', linewidth=2)
        ax1.axvline(x=boundary-0.5, color='black', linewidth=2)

    ax1.set_xlim(-0.5, len(reordered_matrix)-0.5)
    ax1.set_ylim(len(reordered_matrix)-0.5, -0.5)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xlabel('Product Index', labelpad=15)
    ax1.set_ylabel('Product Index', labelpad=15)
    ax1.set_title('(a) Clustered Substitution Matrix', pad=20)

    cbar1 = plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
    cbar1.set_label('Substitution Probability', rotation=270, labelpad=20)

    # Right panel: Cluster quality metrics, all mapped so that higher is better
    metrics_names = ['Silhouette\nCoefficient', 'Calinski-\nHarabasz', 'Davies-\nBouldin']
    metrics_values = [eval_result.silhouette_score,
                      min(1, eval_result.calinski_harabasz_score/1000),  # scaled and capped at 1, as in the dashboard
                      1/(1+eval_result.davies_bouldin_index)]  # inverted for consistency

    x_pos = np.arange(len(metrics_names))
    bars = ax2.bar(x_pos, metrics_values, color=['#1f77b4', '#ff7f0e', '#2ca02c'],
                   edgecolor='black', linewidth=1.5)

    # Value labels above each bar
    for bar, value in zip(bars, metrics_values):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                 f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(metrics_names)
    ax2.set_ylabel('Normalized Score')
    # Silhouette can be negative; keep such a bar visible
    ax2.set_ylim(min(0, min(metrics_values)), 1.2)
    ax2.set_title('(b) Clustering Quality Metrics', pad=20)
    ax2.grid(axis='y', alpha=0.3, linestyle='--')

    # Horizontal reference line
    ax2.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Acceptable threshold')
    ax2.legend(loc='upper right')

    plt.tight_layout()
    plt.savefig('clustering_results_publication.pdf', dpi=300, bbox_inches='tight')
    plt.show()

## 6. Custom Visualization Functions

In [None]:
# 6.1 Custom visualization class
class SubmaritVisualizer:
    """Custom visualization utilities for SUBMARIT results.

    Parameters
    ----------
    style : str
        Stored as ``self.style``; not currently used by the plotting methods.
    """

    def __init__(self, style='seaborn'):
        self.style = style
        # 10-colour husl palette, indexed (cyclically) by cluster position
        self.colors = sns.color_palette("husl", 10)

    def _color(self, index):
        """Return the palette colour for ``index``, cycling past the palette length."""
        return self.colors[index % len(self.colors)]

    def plot_cluster_evolution(self, matrices, labels_list, titles=None):
        """Plot evolution of clustering across multiple time periods or parameters.

        Each panel shows one matrix with rows/columns sorted by its cluster
        labels and blue lines at the cluster boundaries.

        Parameters
        ----------
        matrices : sequence of square 2-D array-likes
            One substitution matrix per panel.
        labels_list : sequence of 1-D array-likes
            Cluster labels matching each matrix.
        titles : sequence of str, optional
            Panel titles; defaults to ``'Time 1'``, ``'Time 2'``, ...

        Returns
        -------
        matplotlib.figure.Figure
        """
        n_plots = len(matrices)
        # squeeze=False always yields a 2-D array of axes, even for a single panel
        fig, axes = plt.subplots(1, n_plots, figsize=(5*n_plots, 5), squeeze=False)

        for i, (matrix, labels, ax) in enumerate(zip(matrices, labels_list, axes[0])):
            matrix = np.asarray(matrix)
            labels = np.asarray(labels)  # allows plain lists as labels

            # Reorder matrix so clusters form diagonal blocks
            sorted_idx = np.argsort(labels)
            ax.imshow(matrix[np.ix_(sorted_idx, sorted_idx)], cmap='YlOrRd', aspect='auto')

            # Boundary after every cluster except the last
            sorted_labels = labels[sorted_idx]
            for label in np.unique(labels)[:-1]:
                boundary = np.sum(sorted_labels <= label)
                ax.axhline(y=boundary-0.5, color='blue', linewidth=2)
                ax.axvline(x=boundary-0.5, color='blue', linewidth=2)

            ax.set_title(titles[i] if titles else f'Time {i+1}')
            ax.set_xticks([])
            ax.set_yticks([])

        fig.tight_layout()
        return fig

    @staticmethod
    def _alluvial_layout(labels1, labels2, gap=0.05):
        """Compute band widths and stacked bar positions for an alluvial diagram.

        Parameters
        ----------
        labels1, labels2 : 1-D array-likes of equal length
            Cluster label of each item under the first and second partition.
            Labels may be any sortable values; they need not be 0..k-1.
        gap : float
            Vertical space between stacked cluster bars; shrunk when there are
            many clusters so the bars always fill at least half of [0, 1].

        Returns
        -------
        clusters1, clusters2 : numpy.ndarray
            Sorted unique labels of each partition (bar order, bottom to top).
        widths : numpy.ndarray, shape (len(clusters1), len(clusters2))
            Thickness of the band from ``clusters1[i]`` to ``clusters2[j]``,
            proportional to the number of items making that transition and on
            the same scale on both sides.
        bottoms1, bottoms2 : numpy.ndarray
            Lower y coordinate of each cluster bar on the left / right side.

        Raises
        ------
        ValueError
            If the label arrays are not 1-D, differ in length, or are empty.
        """
        labels1 = np.asarray(labels1)
        labels2 = np.asarray(labels2)
        if labels1.ndim != 1 or labels1.shape != labels2.shape:
            raise ValueError('labels1 and labels2 must be 1-D arrays of the same length')
        if labels1.size == 0:
            raise ValueError('labels1 and labels2 must not be empty')

        # Map arbitrary label values to 0..k-1 positions before counting transitions
        clusters1, idx1 = np.unique(labels1, return_inverse=True)
        clusters2, idx2 = np.unique(labels2, return_inverse=True)
        counts = np.zeros((len(clusters1), len(clusters2)))
        np.add.at(counts, (idx1, idx2), 1)

        # One shared scale for both sides so each band keeps a constant thickness
        n_gaps = max(len(clusters1), len(clusters2)) - 1
        if n_gaps > 0:
            gap = min(gap, 0.5 / n_gaps)
        widths = counts / labels1.size * (1.0 - gap * n_gaps)

        bottoms1 = np.concatenate(([0.0], np.cumsum(widths.sum(axis=1)[:-1] + gap)))
        bottoms2 = np.concatenate(([0.0], np.cumsum(widths.sum(axis=0)[:-1] + gap)))
        return clusters1, clusters2, widths, bottoms1, bottoms2

    def plot_alluvial_diagram(self, labels1, labels2, label1_name='Time 1', label2_name='Time 2'):
        """Create an alluvial diagram showing cluster membership changes.

        Each side shows one bar per cluster, stacked bottom-up with height
        proportional to cluster size. A band joins cluster ``a`` (left) to
        cluster ``b`` (right) with thickness proportional to the number of
        items labelled ``a`` in ``labels1`` and ``b`` in ``labels2``. Bands are
        stacked inside each bar so they do not overlap where they attach, and
        take the colour of their source cluster.

        Parameters
        ----------
        labels1, labels2 : 1-D array-likes of equal length
            Cluster labels of the same items under the two partitions.
        label1_name, label2_name : str
            Tick labels for the left and right side.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If the label arrays are not 1-D, differ in length, or are empty.
        """
        # Validate and lay out before creating the figure, so bad input leaves no empty figure
        clusters1, clusters2, widths, bottoms1, bottoms2 = self._alluvial_layout(labels1, labels2)

        fig, ax = plt.subplots(figsize=(10, 8))

        # Smoothstep easing gives every band an S-shaped path from left to right
        x = np.linspace(0, 1, 100)
        ease = x * x * (3 - 2 * x)

        # Running fill level inside each bar, so consecutive bands stack
        left_level = bottoms1.copy()
        right_level = bottoms2.copy()
        for i in range(len(clusters1)):
            for j in range(len(clusters2)):
                width = widths[i, j]
                if width <= 0:
                    continue
                lower = left_level[i] + (right_level[j] - left_level[i]) * ease
                ax.fill_between(x, lower, lower + width, alpha=0.6,
                                color=self._color(i), linewidth=0)
                left_level[i] += width
                right_level[j] += width

        # Cluster bars. facecolor/edgecolor are set separately: passing `color`
        # would override `ec` and drop the black outline.
        for i, (bottom, height) in enumerate(zip(bottoms1, widths.sum(axis=1))):
            ax.add_patch(plt.Rectangle((-0.05, bottom), 0.05, height,
                                       facecolor=self._color(i), edgecolor='black'))
            ax.text(-0.1, bottom + height/2, f'C{clusters1[i]}', ha='right', va='center')

        for j, (bottom, height) in enumerate(zip(bottoms2, widths.sum(axis=0))):
            ax.add_patch(plt.Rectangle((1, bottom), 0.05, height,
                                       facecolor=self._color(j), edgecolor='black'))
            ax.text(1.1, bottom + height/2, f'C{clusters2[j]}', ha='left', va='center')

        ax.set_xlim(-0.2, 1.2)
        ax.set_ylim(-0.1, 1.1)
        ax.set_xticks([0, 1])
        ax.set_xticklabels([label1_name, label2_name])
        ax.set_yticks([])
        ax.set_title('Cluster Membership Flow', fontsize=16)

        return fig

# Example usage: re-cluster increasingly noisy copies of the matrix and
# visualise how the partition drifts as the noise grows.
visualizer = SubmaritVisualizer()

noise_levels = [0, 0.1, 0.2]
matrices_evolution, labels_evolution = [], []

for sigma in noise_levels:
    # Gaussian perturbation, kept in [0, 1], re-symmetrized, zero diagonal
    perturbed = substitution_matrix + np.random.normal(0, sigma, substitution_matrix.shape)
    perturbed = np.clip(perturbed, 0, 1)
    perturbed = (perturbed + perturbed.T) / 2
    np.fill_diagonal(perturbed, 0)

    fitted = search.fit(perturbed)

    matrices_evolution.append(perturbed)
    labels_evolution.append(fitted.labels)

# Side-by-side reordered matrices, one per noise level
fig = visualizer.plot_cluster_evolution(
    matrices_evolution,
    labels_evolution,
    titles=[f'Noise Level: {level}' for level in noise_levels],
)
plt.show()

# Membership flow between the noiseless and the noisiest partition
fig = visualizer.plot_alluvial_diagram(
    labels_evolution[0],
    labels_evolution[-1],
    'No Noise',
    'High Noise',
)
plt.show()

## Summary

This notebook demonstrated various visualization techniques for SUBMARIT results:

1. **Matrix Visualizations**: Heatmaps, dendrograms, reordered matrices
2. **Cluster Visualizations**: Networks, chord diagrams, alluvial plots
3. **Statistical Visualizations**: Silhouette plots, metric dashboards, stability analysis
4. **Interactive Visualizations**: Plotly heatmaps, 3D networks
5. **Publication Quality**: Vector (PDF) figures with serif fonts, labelled panels, and consistent sizing
6. **Custom Functions**: Reusable visualization utilities

### Best Practices:

- Choose visualizations that best communicate your specific insights
- Use color consistently across related plots
- Add appropriate annotations and labels
- Consider your audience (technical vs. general)
- Save high-resolution versions for publications
- Make interactive versions for exploration
- Document visualization parameters for reproducibility