# In [None]:
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations

_SUPPORTED_METRICS = ('mse', 'correlation', 'cosine', 'ssim')


def _per_image_metric(metric, stack1, stack2):
    """Return one ``metric`` value per image for two float stacks of equal shape.

    ``stack1[k]`` and ``stack2[k]`` are the k-th 2-D heatmaps of each stack; the
    result is a float array with one entry per image.
    """
    n_images = stack1.shape[0]
    # Explicit pixel count (rather than -1) so a zero-image stack still reshapes.
    n_pixels = int(np.prod(stack1.shape[1:]))
    flat1 = stack1.reshape(n_images, n_pixels)
    flat2 = stack2.reshape(n_images, n_pixels)

    if metric == 'mse':
        return np.mean((flat1 - flat2) ** 2, axis=1)

    if metric == 'cosine':
        dots = np.sum(flat1 * flat2, axis=1)
        norms = np.linalg.norm(flat1, axis=1) * np.linalg.norm(flat2, axis=1)
        # An all-zero heatmap has no direction; score it 0 (as sklearn's
        # cosine_similarity does) instead of dividing by zero.
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    if metric == 'correlation':
        # Pearson r per image: covariance over the product of the deviations.
        centred1 = flat1 - flat1.mean(axis=1, keepdims=True)
        centred2 = flat2 - flat2.mean(axis=1, keepdims=True)
        covariance = np.sum(centred1 * centred2, axis=1)
        scale = np.sqrt(np.sum(centred1 ** 2, axis=1) * np.sum(centred2 ** 2, axis=1))
        # A constant heatmap has zero variance, so r is undefined -> nan.
        with np.errstate(invalid='ignore', divide='ignore'):
            r = covariance / scale
        # Guard against round-off pushing |r| slightly above 1.
        return np.clip(r, -1.0, 1.0)

    if metric == 'ssim':
        scores = np.empty(n_images)
        for img_idx, (img1, img2) in enumerate(zip(stack1, stack2)):
            # Joint value range of both images: keeps the score symmetric in its
            # arguments and avoids a zero data_range when one image is constant.
            data_range = max(img1.max(), img2.max()) - min(img1.min(), img2.min())
            scores[img_idx] = ssim(img1, img2, data_range=data_range if data_range > 0 else 1.0)
        return scores

    raise ValueError(f"Unsupported metric {metric!r}")


def _summarize_values(values):
    """Return mean/std/min/max/median of ``values`` (all nan when it is empty)."""
    if values.size == 0:
        return dict.fromkeys(('mean', 'std', 'min', 'max', 'median'), np.nan)
    return {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'median': np.median(values),
    }


def compare_heatmaps(heatmap_arrays, metrics=('mse', 'correlation', 'cosine', 'ssim')):
    """
    Compare similarity between multiple heatmap arrays.

    Every unordered pair of arrays (i, j) with i < j is compared image by
    image: image ``k`` of array ``i`` against image ``k`` of array ``j``.

    Args:
        heatmap_arrays: Sequence of array-likes, each with shape
            (n_images, height, width); all must share the same shape. Values are
            converted to float64 before comparison, so integer heatmaps (e.g.
            uint8) do not wrap around when subtracted.
        metrics: A metric name or an iterable of names from 'mse',
            'correlation', 'cosine', 'ssim'. Duplicates are ignored.

    Returns:
        Dictionary with two entries:
        - 'per_image': {metric: {"i_vs_j": ndarray of shape (n_images,)}}
        - 'summary': {metric: {"i_vs_j": {'mean', 'std', 'min', 'max', 'median'}}}
        Summary statistics are nan when there are no images. Correlation is nan
        for images where either heatmap is constant, and that nan propagates
        into the summary. 'ssim' calls skimage's structural_similarity with the
        joint value range of both images as ``data_range`` (1.0 when both are
        constant); see skimage for its window-size requirements.

    Raises:
        ValueError: if ``heatmap_arrays`` is empty, an unknown metric is
            requested, or the arrays do not all have the same shape.
    """
    if len(heatmap_arrays) == 0:
        raise ValueError("heatmap_arrays must contain at least one array")

    # Accept a single metric name as well as a collection of names.
    if isinstance(metrics, str):
        metrics = (metrics,)
    # De-duplicate while keeping the caller's order.
    metric_names = list(dict.fromkeys(metrics))
    unknown = [name for name in metric_names if name not in _SUPPORTED_METRICS]
    if unknown:
        raise ValueError(
            f"Unsupported metric(s) {unknown}; expected any of {_SUPPORTED_METRICS}"
        )

    stacks = [np.asarray(arr, dtype=np.float64) for arr in heatmap_arrays]
    expected_shape = stacks[0].shape
    for i, stack in enumerate(stacks):
        if stack.shape != expected_shape:
            raise ValueError(f"Array {i} has shape {stack.shape}, expected {expected_shape}")

    results = {metric: {} for metric in metric_names}
    for i, j in combinations(range(len(stacks)), 2):
        pair_key = f"{i}_vs_{j}"
        for metric in metric_names:
            results[metric][pair_key] = _per_image_metric(metric, stacks[i], stacks[j])

    summary = {
        metric: {pair_key: _summarize_values(values) for pair_key, values in per_pair.items()}
        for metric, per_pair in results.items()
    }

    return {'per_image': results, 'summary': summary}

def visualize_similarity_matrix(heatmap_arrays, metric='correlation'):
    """
    Visualize the similarity matrix between all heatmap arrays.

    Off-diagonal cell (i, j) holds the per-image metric between array i and
    array j averaged over all images, as reported in the 'summary' of
    ``compare_heatmaps``. The diagonal is fixed at 0 for 'mse' and 1 for the
    other metrics. The matrix is drawn with seaborn and displayed with
    ``plt.show()``.

    Args:
        heatmap_arrays: List of numpy arrays, each with shape
            (n_images, height, width); all must share the same shape.
        metric: Metric to use for the matrix ('correlation', 'cosine', 'mse', 'ssim')

    Returns:
        numpy array of shape (n_arrays, n_arrays) holding the plotted values.

    Raises:
        ValueError: if ``metric`` is not supported, ``heatmap_arrays`` is
            empty, or (via ``compare_heatmaps``) the arrays' shapes differ.
    """
    supported_metrics = ('correlation', 'cosine', 'mse', 'ssim')
    if metric not in supported_metrics:
        raise ValueError(f"Unsupported metric {metric!r}; expected one of {supported_metrics}")

    n_arrays = len(heatmap_arrays)
    if n_arrays == 0:
        raise ValueError("heatmap_arrays must contain at least one array")

    # compare_heatmaps evaluates each unordered pair (i < j) once, instead of
    # the two passes a full i/j double loop makes; mirror each mean into both
    # triangles so the matrix is symmetric.
    pair_summary = compare_heatmaps(heatmap_arrays, metrics=[metric])['summary'][metric]
    # Same array: 0 error for MSE, perfect similarity (1.0) for the others.
    similarity_matrix = np.full((n_arrays, n_arrays), 0.0 if metric == 'mse' else 1.0)
    for i, j in combinations(range(n_arrays), 2):
        mean_value = pair_summary[f"{i}_vs_{j}"]['mean']
        similarity_matrix[i, j] = similarity_matrix[j, i] = mean_value

    # Plot matrix
    plt.figure(figsize=(10, 8))

    if metric == 'mse':
        # Lower MSE means more similar, so use the reversed colormap.
        cmap = 'YlOrRd_r'
        vmin = 0
        vmax = np.nanmax(similarity_matrix)
        title = 'Mean Squared Error (lower = more similar)'
    else:
        cmap = 'YlGnBu'
        # Correlation spans [-1, 1]. Cosine and SSIM can also be negative, so
        # only clamp the lower bound at 0 when no value is negative; otherwise
        # negative cells would be clipped out of the colour scale.
        vmin = -1 if metric == 'correlation' or np.nanmin(similarity_matrix) < 0 else 0
        vmax = 1
        title = f'{metric.capitalize()} Similarity (higher = more similar)'

    labels = [f'Array {i}' for i in range(n_arrays)]
    sns.heatmap(similarity_matrix, annot=True, fmt=".3f",
                cmap=cmap, square=True, vmin=vmin, vmax=vmax,
                xticklabels=labels,
                yticklabels=labels)

    plt.title(title)
    plt.tight_layout()
    plt.show()

    return similarity_matrix