# U2Fusion Image Fusion Model Evaluation

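This notebook runs a U2Fusion-style image fusion method on medical image pairs and saves the results to a dedicated folder. The published U2Fusion (Unified Unsupervised Image Fusion) is a deep network trained without ground-truth fused images. The version implemented here is a training-free approximation built from guided-filter base/detail decomposition and saliency-weighted fusion, so it needs no training data or model weights.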

In [None]:
# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import time
import glob
import sys
from tqdm.notebook import tqdm

# For evaluation metrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy import ndimage

# Import our evaluation module
import fusion_evaluation as fe

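# Check Python version
print(f"Python version: {sys.version}")

# Report which interpreter is running; inside a venv, sys.prefix differs from sys.base_prefix
import site
print(f"Using Python from: {sys.executable}")
print(f"Running in a virtual environment: {sys.prefix != sys.base_prefix}")
print(f"Site packages: {site.getsitepackages()}")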

## Define Helper Functions

First, let's define helper functions for locating image pairs and loading them as normalized grayscale arrays.

In [None]:
# Set path to dataset
base_dataset_path = "Medical_Image_Fusion_Methods/Havard-Medical-Image-Fusion-Datasets"
modality_pair = "CT-MRI"  # Can be changed to PET-MRI or SPECT-MRI

def load_image_pair(img_path1, img_path2, resize=True, img_size=256):
    """
    Load a pair of medical images from different modalities
    
    Args:
        img_path1: Path to first image (e.g., CT)
        img_path2: Path to second image (e.g., MRI)
        resize: Whether to resize images
        img_size: Target size for resizing
        
    Returns:
        A tuple containing both images as numpy arrays
    """
    # Read images
    img1 = cv2.imread(img_path1, cv2.IMREAD_GRAYSCALE)
    img2 = cv2.imread(img_path2, cv2.IMREAD_GRAYSCALE)
    
    # Check if images were loaded successfully
    if img1 is None or img2 is None:
        raise ValueError(f"Failed to load images: {img_path1} or {img_path2}")
    
    # Resize each image independently so both end up at img_size x img_size
    if resize:
        if img1.shape[:2] != (img_size, img_size):
            img1 = cv2.resize(img1, (img_size, img_size))
        if img2.shape[:2] != (img_size, img_size):
            img2 = cv2.resize(img2, (img_size, img_size))
    
    # Normalize to [0, 1]
    img1 = img1 / 255.0
    img2 = img2 / 255.0
    
    return img1, img2

def get_image_pairs(dataset_path, modality_pair, count=None):
    """
    Get paths to pairs of medical images
    
    Args:
        dataset_path: Base path to dataset
        modality_pair: Type of modality pair, e.g., 'CT-MRI', 'PET-MRI', 'SPECT-MRI'
        count: Number of pairs to return (None for all)
    
    Returns:
        List of tuples containing paths to image pairs
    """
    # Get full path to specific modality folder
    modality_path = os.path.join(dataset_path, modality_pair)
    
    # Split modality names
    modalities = modality_pair.split('-')
    mod1 = modalities[0].lower()  # e.g., ct
    mod2 = modalities[1].lower()  # e.g., mri
    
    # Get lists of image paths for each modality
    mod1_paths = sorted(glob.glob(os.path.join(modality_path, f"*_{mod1}.png")))
    mod2_paths = sorted(glob.glob(os.path.join(modality_path, f"*_{mod2}.png")))
    
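    # Ensure both modalities have the same number of images
    # (an explicit check, since assert statements are skipped when Python runs with -O)
    if len(mod1_paths) != len(mod2_paths):
        raise ValueError(
            f"Found {len(mod1_paths)} {mod1} images but {len(mod2_paths)} {mod2} images in {modality_path}"
        )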
    
    # Create pairs of image paths
    pairs = list(zip(mod1_paths, mod2_paths))
    
    # Limit number of pairs if specified
    if count is not None:
        pairs = pairs[:min(count, len(pairs))]
    
    return pairs
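`get_image_pairs` pairs files purely by sorted order, so a single missing or extra file silently shifts every later pair. A stricter alternative matches files by their shared `<id>` prefix. The sketch below assumes files are named `<id>_<modality>.png`, the same pattern the glob above relies on. `pair_by_prefix` is a hypothetical helper that is not part of the notebook, and the example paths are made up.

```python
import os

def pair_by_prefix(mod1_paths, mod2_paths, mod1="ct", mod2="mri"):
    """Pair files named <id>_<mod>.png by their shared <id> (hypothetical helper)."""
    def index(paths, mod):
        suffix = f"_{mod}.png"
        # Map each <id> prefix to its full path, ignoring files with other suffixes
        return {os.path.basename(p)[: -len(suffix)]: p
                for p in paths if os.path.basename(p).endswith(suffix)}

    idx1 = index(mod1_paths, mod1)
    idx2 = index(mod2_paths, mod2)
    shared = sorted(idx1.keys() & idx2.keys())
    unmatched = sorted(idx1.keys() ^ idx2.keys())
    if unmatched:
        print(f"Warning: {len(unmatched)} unmatched id(s): {unmatched}")
    return [(idx1[k], idx2[k]) for k in shared]

# Usage with made-up paths: 'c' has no MRI counterpart, so it is reported and skipped
pairs = pair_by_prefix(
    ["data/a_ct.png", "data/b_ct.png", "data/c_ct.png"],
    ["data/a_mri.png", "data/b_mri.png"],
)
print(pairs)
```

With this approach a gap in one modality drops only the affected id instead of misaligning every subsequent pair.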

## Implement the U2Fusion Model

Now, let's implement the key components of the U2Fusion-style method: a guided filter, a saliency map, and a two-scale (base/detail) image decomposition.
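For reference, the guided filter below follows the standard formulation of He et al. Inside each window $\omega_k$ (side $2r+1$ for radius $r$), the output is modelled as a linear function of the guidance image $I$, with coefficients fitted to the input $p$:

$$
a_k = \frac{\frac{1}{|\omega|}\sum_{i \in \omega_k} I_i p_i - \mu_k \bar{p}_k}{\sigma_k^2 + \epsilon}, \qquad
b_k = \bar{p}_k - a_k \mu_k, \qquad
q_i = \bar{a}_i I_i + \bar{b}_i
$$

Here $\mu_k$ and $\sigma_k^2$ are the mean and variance of $I$ in $\omega_k$, $\bar{p}_k$ is the mean of $p$ in $\omega_k$, and $\bar{a}_i$, $\bar{b}_i$ average $a_k$, $b_k$ over all windows containing pixel $i$. In the code, `cov_Ip`, `var_I`, `a`, `b`, `mean_a`, and `mean_b` correspond to these terms. A larger `eps` pushes $a_k$ toward 0 and the output toward the local mean of $p$, which is why it acts as the smoothing strength.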

In [None]:
def guided_filter(p, I, r=5, eps=0.1):
    """
    Edge-preserving smoothing filter used in U2Fusion
    
    Args:
        p: Input image to be filtered
        I: Guidance image (can be the same as p)
        r: Filter radius
        eps: Regularization parameter
        
    Returns:
        Filtered image
    """
    # Convert inputs to float32
    I = I.astype(np.float32)
    p = p.astype(np.float32)
    
    # Box-filter window covering radius r around each pixel: (2r+1) x (2r+1)
    ksize = (2 * r + 1, 2 * r + 1)
    
    # Step 1: Mean filter
    mean_I = cv2.boxFilter(I, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    mean_p = cv2.boxFilter(p, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    
    # Correlation of I and p
    corr_Ip = cv2.boxFilter(I * p, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    
    # Auto-correlation of I
    corr_II = cv2.boxFilter(I * I, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    
    # Step 2: Linear coefficients
    var_I = corr_II - mean_I * mean_I
    cov_Ip = corr_Ip - mean_I * mean_p
    
    # Compute a and b
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    
    # Step 3: Mean filter for a and b
    mean_a = cv2.boxFilter(a, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    mean_b = cv2.boxFilter(b, -1, ksize, normalize=True, borderType=cv2.BORDER_REFLECT)
    
    # Step 4: Output
    q = mean_a * I + mean_b
    
    return q

def calculate_saliency(img, kernel_size=3):
    """
    Calculate visual saliency map
    
    Args:
        img: Input image
        kernel_size: Size of the Gaussian kernel used for the local mean (must be odd)
        
    Returns:
        Saliency map
    """
    # Convert to float32
    img = img.astype(np.float32)
    
    # Apply a small Gaussian filter to get the local mean around each pixel
    local_mean = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
    
    # Calculate saliency as squared difference between image and its local mean
    saliency = (img - local_mean) ** 2
    
    # Apply guided filter to smooth saliency map while preserving edges
    saliency = guided_filter(saliency, img, r=5, eps=0.1)
    
    # Normalize to [0, 1]
    saliency = (saliency - np.min(saliency)) / (np.max(saliency) - np.min(saliency) + 1e-10)
    
    return saliency

def decompose_image(img, r=5, eps=0.1):
    """
    Decompose image into base layer and detail layer using guided filter
    
    Args:
        img: Input image
        r: Filter radius
        eps: Regularization parameter
        
    Returns:
        Base layer and detail layer
    """
    # Base layer (structure) via guided filter
    base_layer = guided_filter(img, img, r, eps)
    
    # Detail layer (texture) via subtraction
    detail_layer = img - base_layer
    
    return base_layer, detail_layer
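A quick way to sanity-check the filter and the decomposition is to test properties that any correct guided filter should satisfy: a constant image passes through unchanged, base + detail reconstructs the input, and the base layer is smoother than the input. The sketch below re-implements the filter in pure NumPy so it runs without OpenCV. It uses a summed-area-table box mean with a (2r+1)-pixel window and symmetric padding, which is assumed here to mirror `cv2.BORDER_REFLECT`. `box_mean` and `guided_filter_np` are illustrative names, not part of the notebook.

```python
import numpy as np

def box_mean(x, r):
    """Mean over a (2r+1) x (2r+1) window using a summed-area table."""
    k = 2 * r + 1
    h, w = x.shape
    padded = np.pad(x, r, mode="symmetric")  # edge-including reflection
    # Summed-area table with a leading row/column of zeros
    sat = np.pad(padded.cumsum(axis=0).cumsum(axis=1), ((1, 0), (1, 0)))
    window_sum = sat[k:k + h, k:k + w] - sat[:h, k:k + w] - sat[k:k + h, :w] + sat[:h, :w]
    return window_sum / (k * k)

def guided_filter_np(p, I, r=5, eps=0.1):
    """Guided filter (He et al.) using only NumPy."""
    I = I.astype(np.float64)
    p = p.astype(np.float64)
    mean_I, mean_p = box_mean(I, r), box_mean(p, r)
    var_I = box_mean(I * I, r) - mean_I ** 2
    cov_Ip = box_mean(I * p, r) - mean_I * mean_p
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    return box_mean(a, r) * I + box_mean(b, r)

rng = np.random.default_rng(0)
img = rng.random((32, 32))

# Property 1: a constant image is left unchanged
const = np.full((32, 32), 0.3)
print(np.allclose(guided_filter_np(const, const), 0.3))

# Property 2: base + detail reconstructs the input (up to float error)
base = guided_filter_np(img, img, r=2, eps=0.1)
detail = img - base
print(np.allclose(base + detail, img))

# Property 3: the base layer varies less between neighbouring pixels than the input
print(np.abs(np.diff(base)).mean() < np.abs(np.diff(img)).mean())
```

Property 2 holds by construction for any filter. Properties 1 and 3 are the informative checks: they catch window-size or border mistakes that would otherwise only show up as subtle artifacts in the fused images.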

In [None]:
def soft_threshold(x, T):
    """
    Soft thresholding function
    
    Args:
        x: Input array
        T: Threshold value
        
    Returns:
        Array with magnitudes shrunk by T (values with |x| <= T become 0)
    """
    return np.sign(x) * np.maximum(np.abs(x) - T, 0)

def fusion_unified_framework(base1, base2, detail1, detail2, saliency1, saliency2):
    """
    Fuse base and detail layers with the U2Fusion-style rule used in this notebook
    
    Args:
        base1: Base layer of first image
        base2: Base layer of second image
        detail1: Detail layer of first image
        detail2: Detail layer of second image
        saliency1: Saliency map of first image
        saliency2: Saliency map of second image
        
    Returns:
        Fused base layer and fused detail layer
    """
    # Normalize saliency maps to create weights
    weight_sum = saliency1 + saliency2 + 1e-10
    weight1 = saliency1 / weight_sum
    weight2 = saliency2 / weight_sum
    
    # Fuse base layers using weighted average
    fused_base = weight1 * base1 + weight2 * base2
    
    # Per-pixel activity: absolute value of each detail coefficient
    abs_detail1 = np.abs(detail1)
    abs_detail2 = np.abs(detail2)
    
    # Adaptive threshold
    T = 0.1 * np.mean(abs_detail1 + abs_detail2)
    
    # Soft thresholding
    soft_detail1 = soft_threshold(detail1, T)
    soft_detail2 = soft_threshold(detail2, T)
    
    # Choose-max rule: keep the soft-thresholded detail from whichever source has the larger raw magnitude
    detail_mask = (abs_detail1 >= abs_detail2).astype(np.float32)
    fused_detail = detail_mask * soft_detail1 + (1 - detail_mask) * soft_detail2
    
    return fused_base, fused_detail

def u2fusion(img1, img2, r=5, eps=0.1):
    """
    Complete U2Fusion method
    
    Args:
        img1: First input image
        img2: Second input image
        r: Filter radius for guided filter
        eps: Regularization parameter for guided filter
        
    Returns:
        Fused image
    """
    # Calculate saliency maps
    saliency1 = calculate_saliency(img1)
    saliency2 = calculate_saliency(img2)
    
    # Decompose images into base and detail layers
    base1, detail1 = decompose_image(img1, r, eps)
    base2, detail2 = decompose_image(img2, r, eps)
    
    # Fusion using unified framework
    fused_base, fused_detail = fusion_unified_framework(
        base1, base2, detail1, detail2, saliency1, saliency2
    )
    
    # Reconstruct fused image
    fused_img = fused_base + fused_detail
    
    # Ensure pixel values are in valid range [0, 1]
    fused_img = np.clip(fused_img, 0, 1)
    
    return fused_img
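To make the rule in `fusion_unified_framework` concrete, here is a hand-checkable example of the same arithmetic on three pixels. It uses saliency-normalized weights for the base layers, and soft-thresholding followed by a choose-max selection for the detail layers. The arrays are made-up toy values, not data from the dataset.

```python
import numpy as np

def soft_threshold(x, T):
    # Shrink each magnitude by T; values with |x| <= T become 0
    return np.sign(x) * np.maximum(np.abs(x) - T, 0)

# Toy 3-pixel layers (made-up values, not from the dataset)
base1, base2 = np.array([0.2, 0.6, 0.4]), np.array([0.8, 0.2, 0.4])
sal1, sal2 = np.array([1.0, 0.0, 0.5]), np.array([0.0, 1.0, 0.5])
d1, d2 = np.array([0.5, -0.2, 0.05]), np.array([-0.1, 0.4, -0.06])

# Base layers: weights are the saliencies normalized to sum to ~1 per pixel
weight_sum = sal1 + sal2 + 1e-10
fused_base = (sal1 / weight_sum) * base1 + (sal2 / weight_sum) * base2

# Detail layers: threshold from mean activity, then choose-max on the raw magnitudes
T = 0.1 * np.mean(np.abs(d1) + np.abs(d2))
mask = np.abs(d1) >= np.abs(d2)
fused_detail = np.where(mask, soft_threshold(d1, T), soft_threshold(d2, T))

print(fused_base)    # approximately [0.2, 0.2, 0.4]
print(T)             # approximately 0.04367
print(fused_detail)  # approximately [0.4563, 0.3563, -0.0163]
```

The third pixel shows the choose-max behavior. The first image's detail there is positive, but the second image's detail is larger in magnitude, so the fused detail takes the second image's negative value, shrunk by `T`. The example also shows that, because of the soft threshold, every selected detail loses `T` in magnitude even when it wins the comparison.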

## Get Dataset and Load Image Pairs

In [None]:
# Get all image pairs
image_pairs = get_image_pairs(base_dataset_path, modality_pair)
print(f"Found {len(image_pairs)} image pairs for {modality_pair}")

# Display a sample pair
if image_pairs:
    img1_path, img2_path = image_pairs[0]
    img1, img2 = load_image_pair(img1_path, img2_path)
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img1, cmap='gray')
    plt.title(os.path.basename(img1_path))
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(img2, cmap='gray')
    plt.title(os.path.basename(img2_path))
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("No image pairs found. Please check the dataset path.")

## Create Directory for U2Fusion Results

In [None]:
# Create directory for U2Fusion results
u2fusion_dir = "fused_images/U2Fusion"
os.makedirs(u2fusion_dir, exist_ok=True)
print(f"Created directory: {u2fusion_dir}")
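The comparison and evaluation cells further down look up saved results at `fused_images/<model>/<img1 stem>_<img2 stem>.png`, where the stem is the file name up to its first dot. This notebook does not show whether `fe.save_fusion_result` writes to exactly that path, so the convention is worth checking. `expected_result_path` below is a hypothetical helper that reproduces the lookup convention those cells use, and the example file names are made up.

```python
import os

def expected_result_path(model_name, img1_path, img2_path, root="fused_images"):
    """Path the comparison cells look for (hypothetical helper mirroring their convention)."""
    img1_name = os.path.basename(img1_path).split('.')[0]
    img2_name = os.path.basename(img2_path).split('.')[0]
    return os.path.join(root, model_name, f"{img1_name}_{img2_name}.png")

# Made-up file names for illustration
path = expected_result_path("U2Fusion", "data/CT-MRI/2004_ct.png", "data/CT-MRI/2004_mri.png")
print(path)  # fused_images/U2Fusion/2004_ct_2004_mri.png on POSIX systems
```

After the processing cell runs, comparing the `output_path` returned by `fe.save_fusion_result` with this path is a quick way to catch a naming mismatch before the comparison cells report missing results.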

## Process Image Pairs with U2Fusion

Now, let's apply U2Fusion to our image pairs and save the results.

In [None]:
# Set the number of image pairs to process (None for all)
max_pairs = 5

# Process image pairs
if image_pairs:
    # Limit the number of pairs if specified
    if max_pairs is not None:
        pairs_to_process = image_pairs[:min(max_pairs, len(image_pairs))]
    else:
        pairs_to_process = image_pairs
    
    # Initialize results storage
    results = []
    
    # Process each pair
    for idx, (img1_path, img2_path) in enumerate(tqdm(pairs_to_process, desc="Processing")):
        # Load images
        img1, img2 = load_image_pair(img1_path, img2_path)
        
        # Apply U2Fusion
        start_time = time.time()
        fused_img = u2fusion(img1, img2, r=5, eps=0.1)
        execution_time = time.time() - start_time
        
        # Save the result using our evaluation module
        output_path = fe.save_fusion_result(fused_img, img1_path, img2_path, "U2Fusion")
        
        # Calculate metrics
        metrics = fe.evaluate_fusion(fused_img, img1, img2)
        metrics['time'] = execution_time
        
        # Store results
        results.append({
            'img1': os.path.basename(img1_path),
            'img2': os.path.basename(img2_path),
            'output': output_path,
            **metrics
        })
        
        # Display the first result
        if idx == 0:
            plt.figure(figsize=(15, 5))
            
            # First input image
            plt.subplot(1, 3, 1)
            plt.imshow(img1, cmap='gray')
            plt.title(os.path.basename(img1_path))
            plt.axis('off')
            
            # Second input image
            plt.subplot(1, 3, 2)
            plt.imshow(img2, cmap='gray')
            plt.title(os.path.basename(img2_path))
            plt.axis('off')
            
            # Fused result
            plt.subplot(1, 3, 3)
            plt.imshow(fused_img, cmap='gray')
            plt.title('U2Fusion Result')
            plt.axis('off')
            
            # Add metrics as text
            metrics_text = (
                f"PSNR: {metrics['psnr']:.2f} dB\n"
                f"SSIM: {metrics['ssim']:.4f}\n"
                f"Time: {execution_time:.3f} s"
            )
            plt.figtext(0.5, 0.01, metrics_text, ha='center', fontsize=12, 
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
            plt.tight_layout()
            plt.show()
    
    # Calculate average metrics
    avg_psnr = np.mean([r['psnr'] for r in results])
    avg_ssim = np.mean([r['ssim'] for r in results])
    avg_time = np.mean([r['time'] for r in results])
    
    print(f"\nProcessed {len(results)} image pairs")
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average execution time: {avg_time:.3f} seconds per image pair")
    print(f"Results saved to {u2fusion_dir}/")
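The loop above times each call with `time.time()`, which reads the system wall clock. That clock can be coarse and can be adjusted while the program runs. For measuring short intervals, Python's `time.perf_counter()` is the intended clock, and taking the median of a few repeats reduces noise from a slow first call. A minimal sketch follows; `time_call` is a hypothetical helper, and `sum` stands in for the fusion call.

```python
import time
import statistics

def time_call(fn, *args, repeats=3, **kwargs):
    """Return (result of the last call, median runtime in seconds over `repeats` calls)."""
    durations = []
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        durations.append(time.perf_counter() - start)
    return result, statistics.median(durations)

# Stand-in workload; in the notebook this could wrap u2fusion(img1, img2, r=5, eps=0.1)
result, seconds = time_call(sum, range(100_000))
print(result, f"{seconds:.6f} s")
```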

## Compare Results from Different Models

Let's create a function to compare all three models: LRD, NSST-PAPCNN, and U2Fusion.

In [None]:
def compare_all_models(img1_path, img2_path):
    """
    Compare fusion results from LRD, NSST-PAPCNN, and U2Fusion models
    
    Args:
        img1_path: Path to first source image
        img2_path: Path to second source image
    """
    # Load images
    img1, img2 = load_image_pair(img1_path, img2_path)
    
    # Get base filenames for finding saved fusion results
    img1_name = os.path.basename(img1_path).split('.')[0]
    img2_name = os.path.basename(img2_path).split('.')[0]
    
    # Saved results are looked up as fused_images/{model}/{pair_name}.png
    pair_name = f"{img1_name}_{img2_name}"
    
    # Load whichever saved results exist; missing models are skipped, not regenerated
    model_results = {}
    model_names = ['LRD', 'NSST_PAPCNN', 'U2Fusion']
    
    for model_name in model_names:
        result_path = os.path.join('fused_images', model_name, f"{pair_name}.png")
        if os.path.exists(result_path):
            # Load saved result
            fused_img = cv2.imread(result_path, cv2.IMREAD_GRAYSCALE) / 255.0
            model_results[model_name] = fused_img
            print(f"Loaded saved result for {model_name}")
    
    # Visualize comparison
    if model_results:
        plt.figure(figsize=(15, 8))
        
        # Input images
        plt.subplot(2, 3, 1)
        plt.imshow(img1, cmap='gray')
        plt.title(os.path.basename(img1_path))
        plt.axis('off')
        
        plt.subplot(2, 3, 2)
        plt.imshow(img2, cmap='gray')
        plt.title(os.path.basename(img2_path))
        plt.axis('off')
        
        # Results from each model
        for i, (model_name, fused_img) in enumerate(model_results.items()):
            plt.subplot(2, 3, i+3)
            plt.imshow(fused_img, cmap='gray')
            plt.title(f'{model_name}')
            plt.axis('off')
            
            # Calculate metrics
            metrics = fe.evaluate_fusion(fused_img, img1, img2)
            metrics_text = f"PSNR: {metrics['psnr']:.2f} dB\nSSIM: {metrics['ssim']:.4f}"
            plt.xlabel(metrics_text, fontsize=10)
        
        plt.tight_layout()
        plt.savefig(f'fused_images/comparison_{pair_name}.png', dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Saved comparison to fused_images/comparison_{pair_name}.png")
    else:
        print("No saved results found. Please run the fusion models first.")

# Compare all models for the first image pair
if image_pairs:
    img1_path, img2_path = image_pairs[0]
    compare_all_models(img1_path, img2_path)
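`fe.evaluate_fusion` is defined in the local `fusion_evaluation` module, which is not shown here, so its exact choice of reference image is unknown. Fusion has no ground-truth image. One common convention is to score the fused image against each source and average the two scores. The NumPy-only sketch below illustrates that convention for PSNR with a data range of 1.0, matching the [0, 1] images used here. `psnr_db` and `fusion_psnr` are hypothetical helpers, not the module's implementation. They are named to avoid shadowing the `psnr` imported from scikit-image at the top of the notebook.

```python
import numpy as np

def psnr_db(reference, test, data_range=1.0):
    """Peak signal-to-noise ratio in dB (inf for identical images)."""
    mse = np.mean((np.asarray(reference, float) - np.asarray(test, float)) ** 2)
    if mse == 0:
        return float("inf")
    return 10 * np.log10(data_range ** 2 / mse)

def fusion_psnr(fused, src1, src2, data_range=1.0):
    """Average PSNR of the fused image against both sources (one common convention)."""
    return 0.5 * (psnr_db(src1, fused, data_range) + psnr_db(src2, fused, data_range))

fused = np.zeros((4, 4))
src1 = np.full((4, 4), 0.1)   # MSE 0.01   -> 20 dB
src2 = np.full((4, 4), 0.01)  # MSE 0.0001 -> 40 dB
print(fusion_psnr(fused, src1, src2))  # approximately 30.0
```

Averaging against two different modalities rewards a fused image that sits between them. It therefore measures similarity to the sources rather than fusion quality as such, which is worth keeping in mind when ranking the three models by PSNR and SSIM.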

## Quantitative Evaluation of All Models

Let's create a function to quantitatively evaluate all three models on multiple image pairs.

In [None]:
def evaluate_all_models(image_pairs, max_pairs=5):
    """
    Evaluate all fusion models on multiple image pairs
    
    Args:
        image_pairs: List of image path pairs
        max_pairs: Maximum number of pairs to evaluate
    """
    # Limit the number of pairs if specified
    if max_pairs is not None:
        pairs_to_process = image_pairs[:min(max_pairs, len(image_pairs))]
    else:
        pairs_to_process = image_pairs
    
    # Model names
    model_names = ['LRD', 'NSST_PAPCNN', 'U2Fusion']
    
    # Initialize results dictionary
    results = {model: {'psnr': [], 'ssim': []} for model in model_names}
    
    # Process each pair
    for img1_path, img2_path in tqdm(pairs_to_process, desc="Evaluating"):
        # Load source images
        img1, img2 = load_image_pair(img1_path, img2_path)
        
        # Get pair name for finding saved results
        img1_name = os.path.basename(img1_path).split('.')[0]
        img2_name = os.path.basename(img2_path).split('.')[0]
        pair_name = f"{img1_name}_{img2_name}"
        
        # Evaluate each model
        for model_name in model_names:
            result_path = os.path.join('fused_images', model_name, f"{pair_name}.png")
            if os.path.exists(result_path):
                # Load saved result
                fused_img = cv2.imread(result_path, cv2.IMREAD_GRAYSCALE) / 255.0
                
                # Calculate metrics
                metrics = fe.evaluate_fusion(fused_img, img1, img2)
                
                # Store metrics
                results[model_name]['psnr'].append(metrics['psnr'])
                results[model_name]['ssim'].append(metrics['ssim'])
    
    # Calculate average metrics for each model
    avg_results = {}
    for model_name in model_names:
        if results[model_name]['psnr']:
            avg_results[model_name] = {
                'avg_psnr': np.mean(results[model_name]['psnr']),
                'avg_ssim': np.mean(results[model_name]['ssim']),
                'std_psnr': np.std(results[model_name]['psnr']),
                'std_ssim': np.std(results[model_name]['ssim']),
            }
    
    # Display results in a table
    if avg_results:
        print("Quantitative Evaluation Results:")
        print("-" * 72)
        print(f"{'Model':<15} | {'Avg PSNR (dB)':<15} | {'Std PSNR':<10} | {'Avg SSIM':<10} | {'Std SSIM':<10}")
        print("-" * 72)
        for model_name in model_names:
            if model_name in avg_results:
                print(f"{model_name:<15} | {avg_results[model_name]['avg_psnr']:<15.2f} | "
                      f"{avg_results[model_name]['std_psnr']:<10.2f} | "
                      f"{avg_results[model_name]['avg_ssim']:<10.4f} | "
                      f"{avg_results[model_name]['std_ssim']:<10.4f}")
        print("-" * 72)
        
        # Visualize results
        plt.figure(figsize=(12, 5))
        
        # PSNR comparison
        plt.subplot(1, 2, 1)
        model_list = list(avg_results.keys())
        psnr_vals = [avg_results[model]['avg_psnr'] for model in model_list]
        psnr_std = [avg_results[model]['std_psnr'] for model in model_list]
        
        plt.bar(model_list, psnr_vals, yerr=psnr_std, alpha=0.8, capsize=10)
        plt.ylabel('PSNR (dB)')
        plt.title('Average PSNR Comparison')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        
        # SSIM comparison
        plt.subplot(1, 2, 2)
        ssim_vals = [avg_results[model]['avg_ssim'] for model in model_list]
        ssim_std = [avg_results[model]['std_ssim'] for model in model_list]
        
        plt.bar(model_list, ssim_vals, yerr=ssim_std, alpha=0.8, capsize=10)
        plt.ylabel('SSIM')
        plt.title('Average SSIM Comparison')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        plt.savefig('fused_images/model_comparison_chart.png', dpi=300, bbox_inches='tight')
        plt.show()
        print("Saved comparison chart to fused_images/model_comparison_chart.png")
    else:
        print("No results found. Please run the fusion models first.")

# Evaluate all models
if image_pairs:
    evaluate_all_models(image_pairs, max_pairs=5)
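`np.std` defaults to `ddof=0`, the population standard deviation. With only five pairs per model, the sample standard deviation (`ddof=1`) is the more conventional choice for error bars, and it is noticeably larger at this sample size. A small sketch follows; `summarize` is a hypothetical helper and the scores are made-up values.

```python
import numpy as np

def summarize(values):
    """Return (n, mean, sample std); std is nan when fewer than 2 values are given."""
    values = np.asarray(values, dtype=float)
    n = values.size
    if n == 0:
        return 0, float("nan"), float("nan")
    std = values.std(ddof=1) if n > 1 else float("nan")
    return n, values.mean(), std

scores = [30.0, 32.0, 34.0, 36.0, 38.0]  # made-up PSNR values in dB
n, mean, std = summarize(scores)
print(n, mean, round(std, 4), round(np.std(scores), 4))  # 5 34.0 3.1623 2.8284
```

Here the sample standard deviation is about 12% larger than the population value (a factor of sqrt(5/4)). That difference matters when deciding whether two models' error bars overlap.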

## Summary of Results

We applied the U2Fusion-style method to the medical image pairs and compared it with the LRD and NSST-PAPCNN methods. Here's what we've accomplished:

1. Set up a Python virtual environment for running image fusion models
2. Created a folder structure for organizing fused images from different models
3. Implemented three different fusion methods:
   - LRD (Laplacian Re-Decomposition)
   - NSST-PAPCNN (Non-Subsampled Shearlet Transform with Parameter-Adaptive Pulse Coupled Neural Network)
   - U2Fusion (Unified Unsupervised Image Fusion), approximated here by a training-free guided-filter method
4. Applied U2Fusion to multiple image pairs (the first `max_pairs` pairs, 5 by default)
5. Saved the results and calculated PSNR and SSIM for each fused image
6. Compared results from the different fusion models both visually and quantitatively

The U2Fusion-style method used here needs no training data or model weights. It weights the base layers of the two inputs by per-pixel saliency and, at each pixel, keeps the soft-thresholded detail coefficient from whichever source has the larger detail magnitude. The published U2Fusion is a deep network trained in an unsupervised manner, so these results characterize this approximation rather than the original model.