# Medical Image Fusion Models Comparison

This notebook compares three different medical image fusion models:
1. LRD (Laplacian Re-Decomposition)
2. NSST-PAPCNN (Non-Subsampled Shearlet Transform with Parameter-Adaptive Pulse Coupled Neural Network)
3. U2Fusion (Unified Unsupervised Image Fusion)

We will compare their performance using quantitative metrics and visual assessment.

In [None]:
# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import time
import glob
import sys
from tqdm.notebook import tqdm
import seaborn as sns

# For evaluation metrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Import our evaluation module
import fusion_evaluation as fe

# Report the interpreter and its package search path so that results can be
# traced back to the environment (e.g. a virtualenv) that produced them.
print(f"Python version: {sys.version}")

import site

interpreter = sys.executable
print(f"Using Python from: {interpreter}")
site_dirs = site.getsitepackages()
print(f"Site packages: {site_dirs}")

## Helper Functions for Loading Images

In [None]:
# Set path to dataset
base_dataset_path = "Medical_Image_Fusion_Methods/Havard-Medical-Image-Fusion-Datasets"

def load_image_pair(img_path1, img_path2, resize=True, img_size=256):
    """
    Load a pair of medical images from different modalities as grayscale.

    Args:
        img_path1: Path to first image (e.g., CT)
        img_path2: Path to second image (e.g., MRI)
        resize: Whether to resize each image to ``img_size`` x ``img_size``
        img_size: Target side length (in pixels) for resizing

    Returns:
        A tuple ``(img1, img2)`` of float numpy arrays with values in [0, 1].

    Raises:
        ValueError: If either image cannot be read by OpenCV.
    """
    img1 = cv2.imread(img_path1, cv2.IMREAD_GRAYSCALE)
    img2 = cv2.imread(img_path2, cv2.IMREAD_GRAYSCALE)

    # cv2.imread signals failure by returning None rather than raising.
    if img1 is None or img2 is None:
        raise ValueError(f"Failed to load images: {img_path1} or {img_path2}")

    # Check each image on its own: the two modalities may be stored at
    # different resolutions, and fusion/metrics need identically sized inputs.
    if resize:
        target = (img_size, img_size)
        if img1.shape[:2] != target:
            img1 = cv2.resize(img1, target)
        if img2.shape[:2] != target:
            img2 = cv2.resize(img2, target)

    # Normalize 8-bit intensities to [0, 1]
    img1 = img1 / 255.0
    img2 = img2 / 255.0

    return img1, img2

def get_image_pairs(dataset_path, modality_pair, count=None):
    """
    Get paths to pairs of medical images.

    Images are looked up directly under ``<dataset_path>/<modality_pair>`` and
    must be named ``*_<mod>.png`` with a lowercase modality suffix (e.g.
    ``1_ct.png`` / ``1_mri.png``). Pairs are formed by zipping the two sorted
    path lists, so the filename prefixes must sort identically for both
    modalities.

    Args:
        dataset_path: Base path to dataset
        modality_pair: Type of modality pair, e.g., 'CT-MRI', 'PET-MRI', 'SPECT-MRI'
        count: Maximum number of pairs to return (None for all)

    Returns:
        List of ``(mod1_path, mod2_path)`` tuples.

    Raises:
        ValueError: If the two modalities have different numbers of images.
    """
    modality_path = os.path.join(dataset_path, modality_pair)

    modalities = modality_pair.split('-')
    mod1 = modalities[0].lower()  # e.g., ct
    mod2 = modalities[1].lower()  # e.g., mri

    # Escape the directory part so characters like '[' in a path are taken
    # literally rather than as glob syntax.
    search_root = glob.escape(modality_path)
    mod1_paths = sorted(glob.glob(os.path.join(search_root, f"*_{mod1}.png")))
    mod2_paths = sorted(glob.glob(os.path.join(search_root, f"*_{mod2}.png")))

    # Explicit check instead of `assert` (which is stripped under `python -O`):
    # zipping lists of different lengths would silently drop images and shift
    # every pair after the missing file.
    if len(mod1_paths) != len(mod2_paths):
        raise ValueError(
            "Number of images in both modalities should be the same "
            f"({mod1}: {len(mod1_paths)}, {mod2}: {len(mod2_paths)})"
        )

    pairs = list(zip(mod1_paths, mod2_paths))

    # Slicing past the end is safe, so no clamp against len(pairs) is needed.
    if count is not None:
        pairs = pairs[:count]

    return pairs

## Load Fused Results from Each Model

First, let's check what results are available from each model.

In [None]:
def get_available_results(models=None, results_dir='fused_images'):
    """
    Collect the fused-image files produced by each fusion model.

    Args:
        models: Model names to look for; each is expected to have its own
            sub-directory of ``results_dir``. Defaults to
            ``['LRD', 'NSST_PAPCNN', 'U2Fusion']``.
        results_dir: Directory containing one sub-directory per model.

    Returns:
        Dictionary mapping each model whose directory exists to a sorted list
        of its ``*.png`` result paths (possibly empty). Models without a
        directory are omitted.
    """
    if models is None:
        models = ['LRD', 'NSST_PAPCNN', 'U2Fusion']

    available_results = {}
    for model in models:
        model_dir = os.path.join(results_dir, model)
        # isdir rather than exists: a stray *file* named after the model must
        # not be treated as a results directory.
        if os.path.isdir(model_dir):
            pattern = os.path.join(glob.escape(model_dir), '*.png')
            result_files = sorted(glob.glob(pattern))
            available_results[model] = result_files
            print(f"{model}: {len(result_files)} results available")
        else:
            print(f"{model}: Directory not found")

    return available_results

# Get available results
available_results = get_available_results()

## Find Common Image Pairs Processed by All Models

For fair comparison, we'll identify image pairs that have been processed by all models.

In [None]:
def find_common_pairs(available_results):
    """
    Find the image-pair names for which every model has a result.

    A pair name is a result file's base name without its extension, e.g.
    ``fused_images/LRD/1_ct_1_mri.png`` -> ``"1_ct_1_mri"``. Models whose
    result list is empty are ignored rather than forcing an empty intersection.

    Args:
        available_results: Dictionary mapping model name to list of result paths.

    Returns:
        Sorted list of pair names present for every model that has results
        (empty if no model has results).
    """
    def pair_names(paths):
        # splitext strips only the final extension, so names that contain
        # dots (e.g. "v.1_ct_v.1_mri.png") are kept intact.
        return {os.path.splitext(os.path.basename(path))[0] for path in paths}

    name_sets = [pair_names(results) for results in available_results.values() if results]
    if not name_sets:
        return []
    return sorted(set.intersection(*name_sets))

# Restrict the comparison to pairs that every model with output has fused;
# the later cells iterate over this list.
common_pairs = find_common_pairs(available_results)
print(
    f"Found {len(common_pairs)} common image pairs "
    "processed by all models"
)

## Load Source Images and Fusion Results

Now, let's load the source images and fusion results for common pairs.

In [None]:
def extract_modality_from_filename(filename):
    """
    Split a fused-result name into its two modalities and a combined pair ID.

    The name must consist of exactly four underscore-separated fields,
    ``{id1}_{mod1}_{id2}_{mod2}`` (e.g. ``"1_ct_1_mri"``).

    Args:
        filename: Pair name without directory or extension.

    Returns:
        ``(MOD1, MOD2, "id1_id2")`` with the modalities upper-cased, or
        ``(None, None, None)`` when the name does not have four fields.
    """
    fields = filename.split('_')
    if len(fields) != 4:
        return None, None, None
    first_id, first_mod, second_id, second_mod = fields
    return first_mod.upper(), second_mod.upper(), f"{first_id}_{second_id}"

def get_source_images(pair_name, modality_pair="CT-MRI", dataset_path=None):
    """
    Locate and load the two source images a fused result was built from.

    A name of the form ``{id1}_{mod1}_{id2}_{mod2}`` is resolved to the files
    ``{id1}_{mod1}.png`` and ``{id2}_{mod2}.png`` (modality suffix lowercase)
    inside ``<dataset_path>/<MOD1>-<MOD2>``. Any other name falls back to the
    folder given by ``modality_pair`` and uses its first (sorted) image of
    each modality, with a printed warning, because the true sources are unknown.

    Args:
        pair_name: Name of the image pair (fused result file stem).
        modality_pair: Modality folder to use when ``pair_name`` cannot be parsed.
        dataset_path: Dataset root; defaults to the module-level
            ``base_dataset_path``.

    Returns:
        ``(img1, img2, img1_path, img2_path)`` with the images as returned by
        ``load_image_pair``, or ``(None, None, None, None)`` if the source
        files cannot be found.
    """
    if dataset_path is None:
        dataset_path = base_dataset_path

    mod1, mod2, pair_id = extract_modality_from_filename(pair_name)
    if mod1 is None:
        mod1, mod2 = modality_pair.split('-')

    modality_path = os.path.join(dataset_path, f"{mod1}-{mod2}")
    search_root = glob.escape(modality_path)
    # Sorted so that the fallback below is deterministic across file systems.
    mod1_files = sorted(glob.glob(os.path.join(search_root, f"*_{mod1.lower()}.png")))
    mod2_files = sorted(glob.glob(os.path.join(search_root, f"*_{mod2.lower()}.png")))

    if pair_id:
        id1, id2 = pair_id.split('_')
        wanted1 = f"{id1}_{mod1.lower()}.png"
        wanted2 = f"{id2}_{mod2.lower()}.png"
        # Exact basename match: a substring test would let ID "1" match
        # "10_ct.png", "21_ct.png", ... and silently load unrelated sources.
        img1_path = next((f for f in mod1_files if os.path.basename(f) == wanted1), None)
        img2_path = next((f for f in mod2_files if os.path.basename(f) == wanted2), None)
        if img1_path is None or img2_path is None:
            print(f"Warning: source images for '{pair_name}' not found in {modality_path}")
            return None, None, None, None
    elif mod1_files and mod2_files:
        # The name carries no IDs, so the correct sources cannot be identified.
        print(f"Warning: cannot parse '{pair_name}'; using first {mod1}/{mod2} images")
        img1_path, img2_path = mod1_files[0], mod2_files[0]
    else:
        return None, None, None, None

    img1, img2 = load_image_pair(img1_path, img2_path)
    return img1, img2, img1_path, img2_path

def load_all_results(pair_name, available_results):
    """
    Load every model's fused image for one pair.

    A result file belongs to ``pair_name`` when its base name without the
    extension equals ``pair_name`` exactly.

    Args:
        pair_name: Name of the image pair (result file stem).
        available_results: Dictionary mapping model name to result paths.

    Returns:
        Dictionary mapping model name to its fused image as a float array
        with values in [0, 1]. Models with no file for this pair are omitted.

    Raises:
        ValueError: If a matching result file cannot be read.
    """
    fused_images = {}
    for model, results in available_results.items():
        # Exact stem comparison: a substring test would let "1_ct_1_mri"
        # match "11_ct_1_mri.png" (which sorts before it) and load the wrong image.
        match = next(
            (path for path in results
             if os.path.splitext(os.path.basename(path))[0] == pair_name),
            None,
        )
        if match is None:
            continue
        fused_img = cv2.imread(match, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None on failure instead of raising.
        if fused_img is None:
            raise ValueError(f"Failed to load fused image: {match}")
        fused_images[model] = fused_img / 255.0

    return fused_images

# Preview the first common pair: both sources plus every model's fusion.
if not common_pairs:
    print("No common pairs found.")
else:
    example_pair = common_pairs[0]
    print(f"Example pair: {example_pair}")

    # Sources come from the dataset, fusions from each model's result folder.
    img1, img2, img1_path, img2_path = get_source_images(example_pair)
    fused_images = load_all_results(example_pair, available_results)

    have_sources = img1 is not None and img2 is not None
    if not (have_sources and fused_images):
        print("Failed to load source images or fusion results.")
    else:
        plt.figure(figsize=(15, 10))

        # Panels 1-2 of the 2x3 grid: the two source modalities.
        panels = [
            ("Source 1", img1, img1_path),
            ("Source 2", img2, img2_path),
        ]
        for slot, (label, image, path) in enumerate(panels, start=1):
            plt.subplot(2, 3, slot)
            plt.imshow(image, cmap='gray')
            plt.title(f"{label}: {os.path.basename(path)}")
            plt.axis('off')

        # Panels 3 onward: one per model that produced a result for this pair.
        for slot, (model, fused_img) in enumerate(fused_images.items(), start=3):
            plt.subplot(2, 3, slot)
            plt.imshow(fused_img, cmap='gray')
            plt.title(f"Fused: {model}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()

## Quantitative Comparison of Models

Now, let's perform a quantitative comparison of all models using various metrics.

In [None]:
def compare_models_quantitatively(common_pairs, available_results):
    """
    Compute fusion-quality metrics for every model on every common pair.

    Pairs whose source images cannot be loaded are skipped.

    Args:
        common_pairs: List of common pair names
        available_results: Dictionary with available results for each model

    Returns:
        DataFrame with one row per (pair, model) and columns ``pair``,
        ``model``, ``psnr`` and ``ssim``. The columns are present even when
        no row could be computed.
    """
    rows = []
    for pair_name in tqdm(common_pairs, desc="Evaluating"):
        img1, img2, _, _ = get_source_images(pair_name)
        if img1 is None or img2 is None:
            continue

        for model, fused_img in load_all_results(pair_name, available_results).items():
            # Only the 'psnr' and 'ssim' entries of fe.evaluate_fusion's result are kept.
            metrics = fe.evaluate_fusion(fused_img, img1, img2)
            rows.append({
                'pair': pair_name,
                'model': model,
                'psnr': metrics['psnr'],
                'ssim': metrics['ssim'],
            })

    # Explicit columns keep the schema stable for downstream groupby/plotting,
    # even if every pair was skipped.
    return pd.DataFrame(rows, columns=['pair', 'model', 'psnr', 'ssim'])

# Compare models quantitatively
if common_pairs:
    df_results = compare_models_quantitatively(common_pairs, available_results)

    if not df_results.empty:
        # Display results
        print("\nMetrics for each model and pair:")
        print(df_results)

        # Aggregate only the numeric metric columns: including the string
        # 'pair' column makes groupby(...).agg(['mean', 'std', ...]) raise
        # TypeError on pandas >= 2.0.
        grouped = df_results.groupby('model')
        summary = grouped[['psnr', 'ssim']].agg(['mean', 'std', 'min', 'max'])
        print("\nSummary statistics:")
        print(summary)

        # Plot results
        plt.figure(figsize=(15, 10))

        # PSNR comparison
        plt.subplot(2, 2, 1)
        sns.boxplot(x='model', y='psnr', data=df_results)
        plt.title('PSNR Distribution')
        plt.xlabel('Model')
        plt.ylabel('PSNR (dB)')
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # SSIM comparison
        plt.subplot(2, 2, 2)
        sns.boxplot(x='model', y='ssim', data=df_results)
        plt.title('SSIM Distribution')
        plt.xlabel('Model')
        plt.ylabel('SSIM')
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Barplot for average PSNR (error bars: one standard deviation)
        plt.subplot(2, 2, 3)
        avg_psnr = grouped['psnr'].mean()
        std_psnr = grouped['psnr'].std()
        avg_psnr.plot(kind='bar', yerr=std_psnr, capsize=4, rot=0)
        plt.title('Average PSNR')
        plt.xlabel('Model')
        plt.ylabel('PSNR (dB)')
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Barplot for average SSIM (error bars: one standard deviation)
        plt.subplot(2, 2, 4)
        avg_ssim = grouped['ssim'].mean()
        std_ssim = grouped['ssim'].std()
        avg_ssim.plot(kind='bar', yerr=std_ssim, capsize=4, rot=0)
        plt.title('Average SSIM')
        plt.xlabel('Model')
        plt.ylabel('SSIM')
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig('fused_images/model_comparison_detailed.png', dpi=300, bbox_inches='tight')
        plt.show()
        print("Saved detailed comparison chart to fused_images/model_comparison_detailed.png")

        # Save results to CSV
        df_results.to_csv('fused_images/model_comparison_metrics.csv', index=False)
        print("Saved metrics to fused_images/model_comparison_metrics.csv")
    else:
        print("No results to compare.")
else:
    print("No common pairs found.")

## Visual Comparison of Models

Let's create a function to visually compare the fusion results from all models.

In [None]:
def visual_comparison(common_pairs, available_results, num_examples=3):
    """
    Plot source images and every model's fusion for the first few common pairs.

    Each row shows one pair: the two source images followed by one panel per
    model in ``available_results``. Each model panel is titled with the PSNR
    and SSIM returned by ``fe.evaluate_fusion``. The figure is saved to
    ``fused_images/visual_comparison.png`` and then displayed.

    Args:
        common_pairs: List of common pair names
        available_results: Dictionary with available results for each model
        num_examples: Maximum number of pairs (rows) to show
    """
    examples = common_pairs[:num_examples]
    if not examples:
        # Nothing to draw; don't overwrite a previous figure with an empty one.
        print("No pairs available for visual comparison.")
        return

    models = list(available_results.keys())

    # Size the grid to the pairs actually shown rather than the requested
    # count, so fewer available pairs don't leave blank rows in the figure.
    num_rows = len(examples)
    num_cols = 2 + len(models)  # two sources + one column per model

    plt.figure(figsize=(20, 5 * num_rows))

    for row, pair_name in enumerate(examples):
        img1, img2, img1_path, img2_path = get_source_images(pair_name)
        # Leave the row blank if the source images cannot be found.
        if img1 is None or img2 is None:
            continue

        fused_images = load_all_results(pair_name, available_results)
        first_slot = row * num_cols + 1  # subplot indices are 1-based

        plt.subplot(num_rows, num_cols, first_slot)
        plt.imshow(img1, cmap='gray')
        plt.title(f"Source 1: {os.path.basename(img1_path)}")
        plt.axis('off')

        plt.subplot(num_rows, num_cols, first_slot + 1)
        plt.imshow(img2, cmap='gray')
        plt.title(f"Source 2: {os.path.basename(img2_path)}")
        plt.axis('off')

        for j, model in enumerate(models):
            if model not in fused_images:
                continue
            plt.subplot(num_rows, num_cols, first_slot + 2 + j)
            plt.imshow(fused_images[model], cmap='gray')

            metrics = fe.evaluate_fusion(fused_images[model], img1, img2)
            plt.title(f"{model}\nPSNR: {metrics['psnr']:.2f} dB, SSIM: {metrics['ssim']:.4f}")
            plt.axis('off')

    plt.tight_layout()
    plt.savefig('fused_images/visual_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("Saved visual comparison to fused_images/visual_comparison.png")

# Show side-by-side fusions for the first three common pairs.
if not common_pairs:
    print("No common pairs found.")
else:
    visual_comparison(common_pairs, available_results, num_examples=3)

## Comparison of Computation Time

Let's compare the computation time of each model by running them on the same image pair.

In [None]:
def compare_computation_time(img1, img2, num_runs=5):
    """
    Compare computation time of different fusion models.

    Each model is run ``num_runs`` times on the same inputs, in the order
    LRD, NSST-PAPCNN, U2Fusion, and the mean elapsed time is reported.

    Args:
        img1: First source image
        img2: Second source image
        num_runs: Number of timed runs for each model (must be >= 1)

    Returns:
        Dictionary mapping model name ('LRD', 'NSST_PAPCNN', 'U2Fusion') to
        its average computation time in seconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    # Without this check, np.mean([]) would silently report NaN.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Import fusion models (project-local modules)
    from run_lrd_fusion import lrd_fusion
    from run_nsst_papcnn_fusion import nsst_papcnn_fusion
    from run_u2fusion_fusion import u2fusion

    fusion_fns = {
        'LRD': lrd_fusion,
        'NSST_PAPCNN': nsst_papcnn_fusion,
        'U2Fusion': u2fusion,
    }

    computation_times = {}
    for model, fuse in fusion_fns.items():
        times = []
        for _ in range(num_runs):
            # perf_counter is a monotonic, high-resolution clock meant for
            # measuring durations; time.time() can jump with clock adjustments.
            start = time.perf_counter()
            fuse(img1, img2)
            times.append(time.perf_counter() - start)
        computation_times[model] = np.mean(times)

    return computation_times

# Time each fusion model on the first common pair and chart the averages.
if not common_pairs:
    print("No common pairs found.")
else:
    sample_pair = common_pairs[0]
    img1, img2, _, _ = get_source_images(sample_pair)

    if img1 is None or img2 is None:
        print("Failed to load source images.")
    else:
        print("Comparing computation time...")
        computation_times = compare_computation_time(img1, img2, num_runs=3)

        print("\nAverage computation time:")
        for model_name, seconds in computation_times.items():
            print(f"{model_name}: {seconds:.4f} seconds")

        # One bar per model, in the order the timings were recorded.
        model_names = list(computation_times.keys())
        mean_seconds = list(computation_times.values())
        plt.figure(figsize=(10, 6))
        plt.bar(model_names, mean_seconds)
        plt.xlabel('Model')
        plt.ylabel('Computation Time (seconds)')
        plt.title('Average Computation Time Comparison')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig('fused_images/computation_time_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()
        print("Saved computation time comparison to fused_images/computation_time_comparison.png")

## Detailed Analysis of Best Model

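Based on the quantitative metrics, let's identify the best-performing model. The ranking below min-max normalizes each model's mean PSNR and mean SSIM across models and averages them with equal weight. The visual comparison and computation time are not part of this score.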

In [None]:
def identify_best_model(df_results):
    """
    Identify the best-performing model based on PSNR and SSIM.

    Each model's mean PSNR and mean SSIM are min-max normalized across models
    to [0, 1] and averaged with equal weight. The model with the highest
    combined score wins. The normalized and combined scores are printed.

    Args:
        df_results: DataFrame with at least ``model``, ``psnr`` and ``ssim``
            columns (one row per model and pair).

    Returns:
        Name of the best model (the first in sorted model order on ties).
    """
    # Average only the metric columns. The string 'pair' column would make
    # GroupBy.mean() raise TypeError on pandas >= 2.0.
    avg_metrics = df_results.groupby('model')[['psnr', 'ssim']].mean()

    def normalize(series):
        # Min-max scale to [0, 1]. A zero range (a single model, or identical
        # means) would divide by zero and yield NaN, so treat it as a tie at 1.0.
        lowest = series.min()
        spread = series.max() - lowest
        if spread == 0:
            return pd.Series(1.0, index=series.index)
        return (series - lowest) / spread

    normalized_psnr = normalize(avg_metrics['psnr'])
    normalized_ssim = normalize(avg_metrics['ssim'])

    # Combined score (equal weight for PSNR and SSIM)
    combined_score = normalized_psnr * 0.5 + normalized_ssim * 0.5

    best_model = combined_score.idxmax()

    # Print scores for all models
    print("Normalized Scores (higher is better):")
    print("\nPSNR:")
    for model, score in normalized_psnr.items():
        print(f"{model}: {score:.4f}")

    print("\nSSIM:")
    for model, score in normalized_ssim.items():
        print(f"{model}: {score:.4f}")

    print("\nCombined Score:")
    for model, score in combined_score.items():
        print(f"{model}: {score:.4f}")

    print(f"\nBest Model: {best_model}")

    return best_model

# Rank the models only if the quantitative cell produced metrics.
if 'df_results' not in locals() or df_results.empty:
    print("No results to analyze.")
else:
    best_model = identify_best_model(df_results)

## Summary and Conclusion

In this notebook, we have compared three different medical image fusion models:

1. LRD (Laplacian Re-Decomposition)
2. NSST-PAPCNN (Non-Subsampled Shearlet Transform with Parameter-Adaptive Pulse Coupled Neural Network)
3. U2Fusion (Unified Unsupervised Image Fusion)

We evaluated their performance using quantitative metrics (PSNR, SSIM) and visual assessment. We also compared their computation time.

Based on our analysis, we identified the best-performing model from the fusion-quality metrics (PSNR and SSIM). Computation time was measured separately and should be weighed alongside that ranking.

The results of this comparison provide valuable insights for choosing the most appropriate fusion model for medical image fusion applications, depending on specific requirements such as fusion quality, computation time, and the type of medical images being processed.

In [None]:
# Combine per-model quality metrics and timings into one table and save it.
if 'df_results' in locals() and not df_results.empty and 'computation_times' in locals():
    summary_rows = []
    for model in df_results['model'].unique():
        per_model = df_results[df_results['model'] == model]
        summary_rows.append({
            'Model': model,
            'Average PSNR': per_model['psnr'].mean(),
            'Average SSIM': per_model['ssim'].mean(),
            # Models that were not timed get NaN rather than being dropped.
            'Computation Time': computation_times.get(model, float('nan')),
        })

    summary_df = pd.DataFrame(
        summary_rows,
        columns=['Model', 'Average PSNR', 'Average SSIM', 'Computation Time'],
    )
    summary_df.to_csv('fused_images/model_comparison_summary.csv', index=False)

    print("\nSummary of Findings:")
    print(summary_df)
    print("\nSaved summary to fused_images/model_comparison_summary.csv")