In [61]:
import os
import numpy as np
from PIL import Image, ImageSequence
import cv2
from scipy.ndimage import median_filter
from skimage import filters
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import jaccard_score
from skimage.filters import threshold_otsu
import matplotlib.pyplot as plt

In [62]:
def load_image_stack(path):
    frames = []
    with Image.open(path) as img:
        for frame in ImageSequence.Iterator(img):
            image = np.array(frame)
            frames.append(image)
    return np.array(frames)

def denoise_image_stack(image_stack, gaussian_sigma=1.0, median_size=3):
    denoised_frames = []
    for i, image in enumerate(image_stack):
        image_min, image_max = np.min(image), np.max(image)
        image = median_filter(image, size=median_size)
        image = filters.gaussian(image, sigma=gaussian_sigma, preserve_range=True)

        image_norm = (image - image_min) / (image_max - image_min)
        image_norm = image_norm.astype(np.float32)

        height, width = image.shape[:2]
        d = max(5, min(height, width) // 50)
        sigmaColor = np.std(image_norm) * 0.15
        sigmaSpace = d * 1

        image_filtered = cv2.bilateralFilter(image_norm, d, sigmaColor, sigmaSpace)

        image_filtered = image_filtered * (image_max - image_min) + image_min
        image_filtered = np.clip(image_filtered, 0, 65535)

        out_dtype = np.uint16 if image_max > 255 else np.uint8
        final_frame = image_filtered.astype(out_dtype)
        denoised_frames.append(final_frame)

        #print(f"Processed slice {i+1}/{len(image_stack)}")
    return np.array(denoised_frames)

def fuse_stack(image_stack, method='max'):
    if method == 'max':
        return np.max(image_stack, axis=0)
    elif method == 'sum':
        return np.clip(np.sum(image_stack, axis=0), 0, 65535)
    elif method == 'average':
        return np.clip(np.mean(image_stack, axis=0), 0, 65535)
    else:
        raise ValueError("Unknown fusion method. Use 'max', 'sum', or 'average'.")

In [63]:
import os
import numpy as np
from PIL import Image, ImageSequence
import cv2
from scipy.ndimage import median_filter
from skimage import filters
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# Function to load image stack
def load_image_stack(path):
    frames = []
    with Image.open(path) as img:
        for frame in ImageSequence.Iterator(img):
            image = np.array(frame)
            frames.append(image)
    return np.array(frames)

# Function to denoise image stack
def denoise_image_stack(image_stack, gaussian_sigma=1.0, median_size=3):
    denoised_frames = []
    for i, image in enumerate(image_stack):
        image_min, image_max = np.min(image), np.max(image)
        image = median_filter(image, size=median_size)
        image = filters.gaussian(image, sigma=gaussian_sigma, preserve_range=True)

        image_norm = (image - image_min) / (image_max - image_min)
        image_norm = image_norm.astype(np.float32)

        height, width = image.shape[:2]
        d = max(5, min(height, width) // 50)
        sigmaColor = np.std(image_norm) * 0.15
        sigmaSpace = d * 1

        image_filtered = cv2.bilateralFilter(image_norm, d, sigmaColor, sigmaSpace)

        image_filtered = image_filtered * (image_max - image_min) + image_min
        image_filtered = np.clip(image_filtered, 0, 65535)

        out_dtype = np.uint16 if image_max > 255 else np.uint8
        final_frame = image_filtered.astype(out_dtype)
        denoised_frames.append(final_frame)

    return np.array(denoised_frames)

# Function to fuse image stack
def fuse_stack(image_stack, method='max'):
    if method == 'max':
        return np.max(image_stack, axis=0)
    elif method == 'sum':
        return np.clip(np.sum(image_stack, axis=0), 0, 65535)
    elif method == 'average':
        return np.clip(np.mean(image_stack, axis=0), 0, 65535)
    else:
        raise ValueError("Unknown fusion method. Use 'max', 'sum', or 'average'.")

# Function to resize images to match
def resize_images_to_match(image1, image2):
    """
    Resize image1 to match the dimensions of image2.
    Ensures that both images are non-empty before attempting resizing.
    """
    if image1 is None or image2 is None:
        raise ValueError("One or both images are empty.")
    
    # Resize image1 to match the shape of image2
    return cv2.resize(image1, (image2.shape[1], image2.shape[0]))

# Function to normalize image based on percentiles
def percentile_normalization(image, pmin=2, pmax=99.8, axis=None):
    if not (np.isscalar(pmin) and np.isscalar(pmax) and 0 <= pmin < pmax <= 100 ):
        raise ValueError("Invalid values for pmin and pmax")

    low_percentile = np.percentile(image, pmin, axis=axis, keepdims=True)
    high_percentile = np.percentile(image, pmax, axis=axis, keepdims=True)

    if low_percentile == high_percentile:
        print(f"Same min {low_percentile} and high {high_percentile}, image may be empty")
        return image

    return (image - low_percentile) / (high_percentile - low_percentile)

# Evaluation function
def evaluate_fusion(fused_image, reference_image, input_image, do_normalize=True):
    # Ensure images are the same shape by resizing
    reference_image_resized = resize_images_to_match(reference_image, fused_image)
    input_image_resized = resize_images_to_match(input_image, fused_image)

    if do_normalize:
        fused_image = percentile_normalization(fused_image)
        reference_image_resized = percentile_normalization(reference_image_resized)

    # Compute SSIM for input vs. reference (used for prediction SSIM)
    input_ssim, _ = ssim(input_image_resized, reference_image_resized, full=True)

    # Compute SSIM for fused vs. reference (used for reference SSIM)
    fused_ssim, _ = ssim(fused_image, reference_image_resized, full=True)

    # Compute normalized SSIM
    n_ssim = (fused_ssim - input_ssim) / (1 - input_ssim)

    # Compute PSNR
    psnr_value = psnr(fused_image, reference_image_resized)

    # Compute MSE
    mse_value = mean_squared_error(fused_image.flatten(), reference_image_resized.flatten())

    # Print results
    print(f"Prediction SSIM : {input_ssim:.4f}")
    print(f"Reference SSIM  : {fused_ssim:.4f}")
    print(f"N SSIM          : {n_ssim:.4f}")
    print(f"PSNR            : {psnr_value:.4f} dB")
    print(f"MSE             : {mse_value:.4f}")

    return {
        "SSIM": fused_ssim,
        "N SSIM": n_ssim,
        "PSNR": psnr_value,
        "MSE": mse_value
    }

In [64]:
def show_side_by_side(image1, image2, title1='Image 1', title2='Image 2'):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image1, cmap='gray')
    axes[0].set_title(title1)
    axes[0].axis('off')
    axes[1].imshow(image2, cmap='gray')
    axes[1].set_title(title2)
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()

In [67]:
# Main processing loop
nucleus_ids = [169, 170, 173, 174, 177]
membrane_ids = [171, 172, 175, 176, 179]
base_dir = r"C:\Users\gronea\Box\Fuse My Cells Challenge\image_169-180"
fusion_method = 'max'

for image_id in nucleus_ids + membrane_ids:
    if image_id in nucleus_ids:
        type_tag = "nucleus"
    elif image_id in membrane_ids:
        type_tag = "membrane"
    else:
        continue

    print(f"\n=== Processing Image {image_id} ({type_tag}) ===")

    input_path = os.path.join(base_dir, f"image_{image_id}_{type_tag}_angle.tif")
    reference_path = os.path.join(base_dir, f"image_{image_id}_{type_tag}_fused.tif")

    if not os.path.exists(input_path) or not os.path.exists(reference_path):
        print(f"Missing files for image {image_id}. Skipping.")
        continue

    # Load and process
    input_stack = load_image_stack(input_path)
    reference_stack = load_image_stack(reference_path)

    denoised_stack = denoise_image_stack(input_stack)
    fused_input = fuse_stack(denoised_stack, method=fusion_method)
    fused_reference = fuse_stack(reference_stack, method=fusion_method)

    # Evaluate the fusion
    evaluation_results = evaluate_fusion(fused_input, fused_reference, input_stack[0], reference_stack[0], do_normalize=True)

    # Visualize the fused and reference images
    show_side_by_side(fused_input, fused_reference,
                      title1=f'{type_tag.capitalize()} Fused {image_id}',
                      title2=f'{type_tag.capitalize()} Reference {image_id}')


=== Processing Image 169 (nucleus) ===


TypeError: evaluate_fusion() got multiple values for argument 'do_normalize'