In [None]:
# Load libraries
import os
import SimpleITK as sitk
from skimage.transform import radon, iradon
from scipy.ndimage import gaussian_filter
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from perlin_noise import PerlinNoise
import torch
import torch.nn as nn
from torch.optim import Adam
from scipy.ndimage import rotate
import cv2
from skimage.restoration import denoise_tv_chambolle
import torch.nn.functional as F
import matplotlib.image as mpimg

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import the DIP repository, make sure to clone https://github.com/DmitryUlyanov/deep-image-prior
import sys
sys.path.append('C:/Users/Administrator/deep-image-prior')
from models import skip

In [None]:
def generate_sinogram_2d(image_2d, angles):
    """Radon-transform a 2D image after zeroing everything outside the inscribed circle.

    Args:
        image_2d: 2D array (rows, cols).
        angles: projection angles passed to ``radon`` as ``theta``.

    Returns:
        The sinogram returned by ``radon(..., circle=True)``.
    """
    n_rows, n_cols = image_2d.shape
    yy, xx = np.ogrid[:n_rows, :n_cols]
    # Squared distance of every pixel from the (possibly half-integer) image centre
    squared_dist = (xx - n_cols / 2) ** 2 + (yy - n_rows / 2) ** 2
    # Keep only pixels within the largest circle that fits in the image
    inside_circle = squared_dist <= (min(n_rows, n_cols) // 2) ** 2
    return radon(image_2d * inside_circle, theta=angles, circle=True)


def add_poisson_noise(sinogram, scale=1e4, rng=None):
    """Simulate Poisson (photon-counting) noise on a sinogram.

    The sinogram is clipped to be non-negative, multiplied by ``scale`` to get
    expected counts, sampled from a Poisson distribution, and divided by
    ``scale`` again. A smaller ``scale`` therefore means fewer counts and
    relatively stronger noise.

    Args:
        sinogram: array of projection values; negative entries are treated as 0.
        scale: positive factor converting sinogram values to expected counts.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When
            ``None`` the global ``np.random`` state is used (so
            ``np.random.seed`` still controls the result).

    Returns:
        Noisy sinogram with the same shape as the input (float32 counts / scale).

    Raises:
        ValueError: if ``scale`` is not strictly positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    # Poisson rates must be non-negative
    expected_counts = np.clip(sinogram, 0, None) * scale

    sampler = np.random if rng is None else rng
    noisy_counts = sampler.poisson(expected_counts).astype(np.float32)

    # Back to the original intensity scale
    return noisy_counts / scale


def backproject(sinogram, angles):
    """Unfiltered back-projection built from smeared, rotated projections.

    Row ``i`` of ``sinogram`` is treated as the projection belonging to
    ``angles[i]``; it is tiled into a square image, rotated by ``-angles[i]``
    and accumulated. The mean over all angles is returned.

    Note: a later cell of this notebook defines another function that is also
    named ``backproject``; once that cell runs, it replaces this definition.

    Returns:
        Array of shape ``(sinogram.shape[1], sinogram.shape[1])``.
    """
    width = sinogram.shape[1]
    accumulator = np.zeros((width, width))
    for row_idx, theta in enumerate(angles):
        # Smear the 1D projection across the image, one copy per column
        smeared = np.tile(sinogram[row_idx], (width, 1)).T
        accumulator += rotate(smeared, -theta, reshape=False)
    return accumulator / len(angles)


def generate_sinogram(image, num_angles=180):
    """Build a sinogram by rotating the image and summing along axis 0.

    Args:
        image: 2D array.
        num_angles: number of equally spaced angles in [0, 180) degrees.

    Returns:
        ``(sinogram, angles)`` where ``sinogram`` has one row per angle.
    """
    angles = np.linspace(0, 180, num_angles, endpoint=False)
    projections = []
    for theta in angles:
        # Rotate without changing the array shape, then integrate over rows
        projections.append(np.sum(rotate(image, theta, reshape=False), axis=0))
    return np.array(projections), angles

In [None]:
def MLEM_reconstruction(sinogram, angles, reference_image, num_iterations=50):
    """Maximum-likelihood EM reconstruction using skimage's radon/iradon pair.

    Each iteration applies the multiplicative update
    ``x <- x * A^T(y / (A x)) / A^T(1)``, where ``A`` is ``radon`` and ``A^T``
    is the unfiltered ``iradon``. Dividing by the back-projection of ones (the
    sensitivity image) cancels the constant scaling that ``iradon`` applies,
    so the estimate converges to the scale of the measured data.

    Args:
        sinogram: array laid out like ``radon`` output, i.e.
            ``(detector_bins, len(angles))``.
        angles: projection angles in degrees (``theta`` of radon/iradon).
        reference_image: image of shape ``(detector_bins, detector_bins)``
            against which PSNR is computed after every iteration.
        num_iterations: number of EM iterations.

    Returns:
        ``(reconstructed, psnr_values)``: the float32 estimate and the list of
        per-iteration PSNR values from ``cv2.PSNR``.

    Raises:
        ValueError: if the sinogram's second axis does not match ``angles``.
    """
    if sinogram.shape[1] != len(angles):
        raise ValueError(
            f"sinogram has {sinogram.shape[1]} columns but {len(angles)} angles were given"
        )

    # The image side length equals the number of detector bins (axis 0),
    # not the number of projection angles (axis 1).
    n_detectors = sinogram.shape[0]
    reconstructed = np.ones((n_detectors, n_detectors), dtype=np.float32)
    psnr_values = []
    epsilon = 1e-6

    # Sensitivity image A^T(1); it is zero outside the reconstruction circle,
    # so clamp it to avoid dividing by zero there.
    sensitivity = iradon(np.ones(sinogram.shape, dtype=np.float32),
                         theta=angles, filter_name=None, circle=True)
    sensitivity = np.maximum(sensitivity, epsilon)

    reference_f32 = reference_image.astype(np.float32)

    for _ in range(num_iterations):
        # Forward projection of the current estimate
        forward_projection = radon(reconstructed, theta=angles, circle=True)

        # Measured / estimated projections
        ratio = sinogram / (forward_projection + epsilon)

        # Unfiltered back-projection of the ratio, normalised by the sensitivity
        back_projection = iradon(ratio, theta=angles, filter_name=None, circle=True)
        reconstructed = reconstructed * (back_projection / sensitivity)

        # Keep the estimate non-negative and in float32 for cv2.PSNR
        reconstructed = np.maximum(reconstructed, 0).astype(np.float32)

        psnr_values.append(cv2.PSNR(reference_f32, reconstructed))

    return reconstructed, psnr_values

def plot_psnr(psnr_values):
    """Plot PSNR (dB) against the 1-based iteration number and show the figure."""
    iteration_numbers = range(1, len(psnr_values) + 1)
    _, ax = plt.subplots(figsize=(8, 4))
    ax.plot(iteration_numbers, psnr_values, linestyle='-')
    ax.set_xlabel("Number of Iterations")
    ax.set_ylabel("PSNR (dB)")
    ax.set_title("PSNR vs. Number of Iterations for MLEM Reconstruction")
    ax.grid()
    plt.show()

In [None]:
## Dicom reader
def load_dicom_series(folder_path):
    """Read the DICOM series found in ``folder_path`` as one SimpleITK image.

    Raises:
        ValueError: if GDCM finds no series files in the folder.
    """
    series_reader = sitk.ImageSeriesReader()
    series_files = series_reader.GetGDCMSeriesFileNames(folder_path)
    if not series_files:
        raise ValueError(f"No DICOM files found in {folder_path}")
    series_reader.SetFileNames(series_files)
    return series_reader.Execute()

def normalize_image(image):
    """Min-max scale a SimpleITK image to [0, 1] and return it as a float32 array.

    If the intensity range is smaller than 1e-5, an all-zero array of the same
    shape is returned instead of dividing by a near-zero range.
    """
    voxels = sitk.GetArrayFromImage(image).astype(np.float32)
    lowest = np.min(voxels)
    highest = np.max(voxels)
    # Guard clause: (almost) constant image
    if highest - lowest < 1e-5:
        return np.zeros_like(voxels)
    return (voxels - lowest) / (highest - lowest)

In [None]:
## Image registration
def register_pet_ct(ct_image, pet_image, mode="upsample_PET"):
    """Rigidly register PET to CT and resample one image onto the other's grid.

    The registration uses CT as the fixed image and PET as the moving image
    (Mattes mutual information, multi-resolution, gradient descent, Euler3D
    transform). The resulting transform maps points from the fixed (CT) space
    to the moving (PET) space.

    Args:
        ct_image: SimpleITK CT volume (fixed image).
        pet_image: SimpleITK PET volume (moving image).
        mode: ``"upsample_PET"`` resamples PET onto the CT grid;
            ``"downsample_CT"`` resamples CT onto the PET grid.

    Returns:
        ``(ct, pet)`` where exactly one of the two has been resampled.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
        Exception: any error raised by the registration is printed and re-raised.
    """
    # Fail fast: reject a bad mode before running the expensive registration.
    if mode not in ("upsample_PET", "downsample_CT"):
        raise ValueError("Invalid mode selection, mode should be 'upsample_PET' or 'downsample_CT'")

    registration_method = sitk.ImageRegistrationMethod()

    # Multi-resolution pyramid
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Similarity metric, evaluated on a 1% random sample of voxels
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)

    # Rigid transform initialised by aligning the geometric centres
    initial_transform = sitk.CenteredTransformInitializer(
        ct_image,
        pet_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    # Optimizer
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0,
                                                      numberOfIterations=100,
                                                      convergenceMinimumValue=1e-6,
                                                      convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    registration_method.SetInterpolator(sitk.sitkLinear)

    try:
        final_transform = registration_method.Execute(sitk.Cast(ct_image, sitk.sitkFloat32),
                                                      sitk.Cast(pet_image, sitk.sitkFloat32))
        print("Optimizer Converged:", registration_method.GetOptimizerStopConditionDescription())
        print("Final Metric Value:", registration_method.GetMetricValue())
    except Exception as e:
        print(f"\nRegistration failed: {e}")
        raise  # bare raise keeps the original traceback

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)

    if mode == "upsample_PET":
        # Output grid = CT; the resampler needs CT-space -> PET-space, which is
        # exactly the transform returned by the registration.
        resampler.SetTransform(final_transform)
        resampler.SetReferenceImage(ct_image)
        pet_resampled = resampler.Execute(pet_image)
        print("Resampling completed.")
        return ct_image, pet_resampled

    # mode == "downsample_CT": output grid = PET, so the resampler needs
    # PET-space -> CT-space, i.e. the inverse of the registration transform.
    resampler.SetTransform(final_transform.GetInverse())
    resampler.SetReferenceImage(pet_image)
    ct_resampled = resampler.Execute(ct_image)
    print("Resampling completed.")
    return ct_resampled, pet_image

In [None]:
# Design for Bowsher prior Filter 
def apply_bayesian_filter(image_2d, sigma=1.0, zeta=0.5, rho=0.01, alpha=0.001):
    """Weight an image pixel-wise by the Bowsher-like kernel computed from itself.

    ``sigma`` is accepted for signature compatibility but is not used here.

    Returns:
        ``image_2d`` multiplied element-wise by
        ``compute_bowsher_kernel_2d(image_2d, zeta, rho, alpha)``.
    """
    bowsher_weights = compute_bowsher_kernel_2d(image_2d, zeta=zeta, rho=rho, alpha=alpha)
    return image_2d * bowsher_weights

def compute_kernel_2d(ct_2d, sigma=1.0):
    """Edge-aware weight map from the Gaussian-smoothed gradient of a CT slice.

    Args:
        ct_2d: 2D CT image.
        sigma: standard deviation of the Gaussian derivative filters.

    Returns:
        Array of the same shape with values ``exp(-|grad|^2)``: 1 in flat
        regions and smaller at edges.
    """
    ct = np.asarray(ct_2d)
    if not np.issubdtype(ct.dtype, np.floating):
        # gaussian_filter keeps the input dtype, which would truncate
        # derivatives of integer images.
        ct = ct.astype(np.float32)

    # `order` must be given per axis: a scalar order=1 differentiates along
    # BOTH axes at once (a mixed second derivative), not along one direction.
    gradient_y = gaussian_filter(ct, sigma=sigma, order=(1, 0), mode='nearest')  # d/d(row)
    gradient_x = gaussian_filter(ct, sigma=sigma, order=(0, 1), mode='nearest')  # d/d(col)

    # Lower weights where the squared gradient magnitude is large (edges)
    return np.exp(-(gradient_x ** 2 + gradient_y ** 2))

def psi(u, zeta=0.5, rho=0.01):
    """Smooth step in ``u``: close to 1 for ``u < zeta``, close to 0 for ``u > zeta``.

    Computes ``arctan((zeta - u) / rho) / pi + 0.5``; ``rho`` sets how sharp
    the transition around ``zeta`` is. Works on scalars and NumPy arrays.
    """
    scaled_gap = (zeta - u) / rho
    return 0.5 + np.arctan(scaled_gap) / np.pi

def Tq_linear(value):
    """Identity intensity transform; returns ``value`` unchanged."""
    return value

def _axis_pairs(length, shift):
    """Slices for one axis: centre indices whose neighbour at ``+shift`` is in bounds, and those neighbours."""
    start = max(0, -shift)
    stop = max(start, min(length, length - shift))
    return slice(start, stop), slice(start + shift, stop + shift)


def compute_bowsher_kernel_2d(
    ct_image_2d, 
    zeta=0.5, 
    rho=0.01, 
    Tq_func=Tq_linear, 
    epsilon=1e-6,
    alpha = 0.001
):
    """Bowsher-like per-pixel weight map from 8-neighbour intensity differences.

    For each pixel ``j`` and in-bounds neighbour ``k`` the weight is
    ``psi(|T(mu_j) - T(mu_k)| / max((M_j + M_k) / 2, epsilon))``, where ``M_j``
    is the largest absolute difference between pixel ``j`` and any of its
    in-bounds neighbours. Out-of-bounds neighbours contribute weight 0.

    The computation is vectorised over the image: every neighbour offset is
    handled with one pair of array slices instead of per-pixel Python loops.

    Args:
        ct_image_2d: 2D image.
        zeta, rho: threshold and sharpness forwarded to ``psi``.
        Tq_func: intensity transform applied to the whole image first.
        epsilon: lower bound for the normalising denominator.
        alpha: scale applied to the summed weights inside the exponential.

    Returns:
        Array of shape ``(H, W)`` equal to ``exp(alpha * sum_k w_jk)``.
    """
    # Offsets for the 8 neighbors in (dy, dx)
    neighbor_offsets = [
        (-1, -1), (-1, 0), (-1, 1),
        ( 0, -1),          ( 0, 1),
        ( 1, -1), ( 1, 0), ( 1, 1)
    ]

    H, W = ct_image_2d.shape

    # Float copy so differences of unsigned-integer images cannot wrap around
    transformed_ct = np.asarray(Tq_func(ct_image_2d), dtype=np.float64)

    # For every offset: index of the valid centre pixels, index of their
    # neighbours, and the absolute intensity difference between the two.
    pairs = []
    for dy, dx in neighbor_offsets:
        center_y, neighbor_y = _axis_pairs(H, dy)
        center_x, neighbor_x = _axis_pairs(W, dx)
        center = (center_y, center_x)
        neighbor = (neighbor_y, neighbor_x)
        abs_diff = np.abs(transformed_ct[center] - transformed_ct[neighbor])
        pairs.append((center, neighbor, abs_diff))

    # M_j = max |Tq(mu_j) - Tq(mu_k)| over the in-bounds neighbours of j
    M_array = np.zeros((H, W), dtype=np.float32)
    for center, _, abs_diff in pairs:
        M_array[center] = np.maximum(M_array[center], abs_diff)

    # Weight of every (pixel, neighbour) pair; out-of-bounds entries stay 0
    weights_2d = np.zeros((H, W, len(neighbor_offsets)), dtype=np.float32)
    for n_idx, (center, neighbor, abs_diff) in enumerate(pairs):
        denom = np.maximum((M_array[center] + M_array[neighbor]) / 2.0, epsilon)
        weights_2d[center + (n_idx,)] = psi(abs_diff / denom, zeta=zeta, rho=rho)

    kernel = np.sum(weights_2d, axis=-1)
    kernel = np.exp(alpha * kernel)
    return kernel

In [None]:
def project(image, angles):
    """Forward-project ``image`` with skimage's ``radon`` (``circle=True``) at ``angles``."""
    sinogram = radon(image, theta=angles, circle=True)
    return sinogram

def backproject(sinogram, angles, filter_name="ramp"):
    """Inverse Radon transform of a sinogram via skimage's ``iradon``.

    Note: this redefines the ``backproject`` declared in an earlier cell.

    Args:
        sinogram: NumPy array or torch tensor laid out like ``radon`` output.
        angles: projection angles (``theta``).
        filter_name: filter forwarded to ``iradon``. The default ``"ramp"``
            gives filtered back-projection (the previous behaviour); pass
            ``None`` for a plain, unfiltered back-projection.

    Returns:
        The reconstructed 2D NumPy array.
    """
    if isinstance(sinogram, torch.Tensor):
        # detach() first: .numpy() raises on tensors that require grad
        sinogram = sinogram.detach().cpu().numpy()
    return iradon(sinogram, theta=angles, filter_name=filter_name, circle=True)


def apply_filter(image, method="bilateral", sigma=1.0, zeta=0.5, rho=0.01, alpha=0.001, ct_prior=None):
    """Apply one of several smoothing / weighting filters to a 2D image tensor.

    Args:
        image: 2D torch tensor.
        method: ``"bilateral"``, ``"anisotropic"`` (TV-Chambolle),
            ``"gaussian"``, ``"bayesian"`` (Bowsher-like self-weighting) or
            ``"gradient"`` (weighting by a CT gradient kernel).
        sigma: filter strength; its meaning depends on ``method``.
        zeta, rho, alpha: parameters forwarded to the ``"bayesian"`` filter.
        ct_prior: CT image (torch tensor or NumPy array, singleton dimensions
            are squeezed); required for ``"gradient"``.

    Returns:
        float32 tensor on the same device as ``image``.

    Raises:
        ValueError: for an unknown ``method`` or a missing ``ct_prior``.
    """
    image_np = image.detach().cpu().numpy()

    if method == "bilateral":
        filtered_image = cv2.bilateralFilter(image_np.astype(np.float32), d=9, sigmaColor=sigma*50, sigmaSpace=sigma*50)

    elif method == "anisotropic":
        filtered_image = denoise_tv_chambolle(image_np, weight=sigma)

    elif method == "gaussian":
        filtered_image = gaussian_filter(image_np, sigma=sigma)

    elif method == "bayesian":
        filtered_image = apply_bayesian_filter(image_np, sigma=sigma, zeta=zeta, rho=rho, alpha=alpha)

    elif method == "gradient":
        if ct_prior is None:
            raise ValueError("CT prior required for gradient-based kernel filtering.")

        # Accept both tensors and plain arrays for the anatomical prior
        if isinstance(ct_prior, torch.Tensor):
            ct_prior_np = ct_prior.detach().cpu().numpy()
        else:
            ct_prior_np = np.asarray(ct_prior)

        kernel = compute_kernel_2d(np.squeeze(ct_prior_np), sigma=sigma)
        filtered_image = image_np * kernel  # Weight input image by anatomical kernel

    else:
        raise ValueError("Invalid method. Choose 'bilateral', 'anisotropic', 'gaussian', 'bayesian', or 'gradient'.")

    return torch.tensor(filtered_image, dtype=torch.float32).to(image.device)


def KEM_step(image, sinogram, angles, kernel_size, sigma=1.0, filter_method="bilateral", ct_prior=None, num_em_iterations=5):
    """Run a few filtered EM iterations starting from ``image``.

    Each iteration forward-projects the current estimate, back-projects the
    measured/estimated ratio without filtering, normalises it by the
    back-projection of ones, and multiplies it with the filtered estimate.

    The radon/iradon steps go through NumPy, so the result is detached from
    the autograd graph of ``image``.

    Args:
        image: tensor holding one 2D image; singleton dimensions are squeezed,
            so ``(H, W)`` and ``(1, 1, H, W)`` are both accepted.
        sinogram: measured sinogram (array or tensor) laid out like ``radon``
            output.
        angles: projection angles.
        kernel_size: currently unused; kept for interface compatibility.
        sigma, filter_method, ct_prior: forwarded to ``apply_filter``.
        num_em_iterations: number of EM iterations.

    Returns:
        Tensor of shape ``(1, 1, H, W)`` on the same device as ``image``.
    """
    target_device = image.device
    estimate = image.detach().clone().squeeze()  # work on a 2D copy
    measured = torch.as_tensor(sinogram, dtype=torch.float32, device=target_device)

    # Back-projection of ones; zero outside the reconstruction circle, hence the clamp
    sensitivity_np = iradon(np.ones(tuple(measured.shape), dtype=np.float32),
                            theta=angles, filter_name=None, circle=True)
    sensitivity = torch.tensor(sensitivity_np, dtype=torch.float32, device=target_device).clamp(min=1e-8)

    for _ in range(num_em_iterations):
        # Step 1: forward project the current estimate
        forward_np = project(estimate.cpu().numpy(), angles)
        forward_projection = torch.tensor(forward_np, dtype=torch.float32, device=target_device)

        # Step 2: ratio of measured to estimated projections, bounded for stability
        ratio = (measured / (forward_projection + 1e-8)).clamp(min=0, max=10)

        # Step 3: unfiltered back-projection of the ratio (a ramp filter would not
        # return 1 for a perfectly matching estimate)
        back_np = iradon(ratio.cpu().numpy(), theta=angles, filter_name=None, circle=True)
        back_projection = torch.tensor(back_np, dtype=torch.float32, device=target_device)

        # Step 4: filter the current estimate
        smoothed_image = apply_filter(estimate, method=filter_method, sigma=sigma, ct_prior=ct_prior)

        # Step 5: multiplicative update
        estimate = smoothed_image * back_projection / sensitivity

    return estimate.unsqueeze(0).unsqueeze(0)

In [None]:
def run_KEM_reconstruction(sinogram, angles, filter_method="bayesian", num_iterations=5, sigma=1.0, ct_prior=None):
    """Filtered (kernel) EM reconstruction of a sinogram.

    Starts from an all-ones image and repeats: forward-project, back-project
    the measured/estimated ratio without filtering (normalised by the
    back-projection of ones), multiply with the filtered current estimate,
    and clamp to non-negative values.

    Args:
        sinogram: array laid out like ``radon`` output,
            ``(detector_bins, len(angles))``.
        angles: projection angles.
        filter_method: method name for ``apply_filter``; ``None`` disables
            filtering (plain EM update).
        num_iterations: number of iterations.
        sigma: filter strength forwarded to ``apply_filter``. This used to be
            read from a notebook-level global.
        ct_prior: CT prior forwarded to ``apply_filter`` (needed for the
            ``"gradient"`` method). This used to be read from a global.

    Returns:
        2D NumPy array of shape ``(detector_bins, detector_bins)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sinogram_np = np.asarray(sinogram, dtype=np.float32)
    # Image side length = number of detector bins (axis 0 of the sinogram)
    image_shape = (sinogram_np.shape[0], sinogram_np.shape[0])
    print("Sinogram shape:", sinogram_np.shape)
    print("Reconstructed image shape:", image_shape)
    reconstructed = torch.ones(image_shape, dtype=torch.float32, device=device)

    sinogram_torch = torch.tensor(sinogram_np, dtype=torch.float32, device=device)

    # Back-projection of ones; zero outside the reconstruction circle, hence the clamp
    sensitivity_np = iradon(np.ones_like(sinogram_np), theta=angles, filter_name=None, circle=True)
    sensitivity = torch.tensor(sensitivity_np, dtype=torch.float32, device=device).clamp(min=1e-6)

    for _ in range(num_iterations):
        # radon works on NumPy arrays, so move the estimate to host memory first
        forward_np = radon(reconstructed.cpu().numpy(), theta=angles, circle=True)
        forward_projection = torch.tensor(forward_np, dtype=torch.float32, device=device)

        ratio = sinogram_torch / (forward_projection + 1e-6)

        back_projection_np = iradon(ratio.cpu().numpy(), theta=angles, filter_name=None, circle=True)
        back_projection = torch.tensor(back_projection_np, dtype=torch.float32, device=device) / sensitivity

        if filter_method is not None:
            filtered_image = apply_filter(reconstructed, method=filter_method, sigma=sigma, ct_prior=ct_prior)
        else:
            filtered_image = reconstructed  # No filtering: plain EM update

        reconstructed = torch.clamp(back_projection * filtered_image, min=0)  # Ensure non-negative values

    return reconstructed.cpu().numpy()

In [None]:
def run_DIP_reconstruction(sinogram, angles, ct_prior, num_epochs=1000):
    """Fit a Deep-Image-Prior skip network to a KEM initialisation plus CT prior.

    A KEM reconstruction of ``sinogram`` and the CT prior are stacked as a
    two-channel network input; the single-channel output is trained with an
    MSE loss against that same two-channel tensor.

    Args:
        sinogram: measured sinogram passed to ``run_KEM_reconstruction``.
        angles: projection angles.
        ct_prior: 2D CT image (array-like); resized bilinearly if its shape
            differs from the KEM image.
        num_epochs: number of optimisation steps (at least 1).

    Returns:
        2D NumPy array: the network output of the last training step.

    Raises:
        ValueError: if ``num_epochs`` is smaller than 1.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    # One device for inputs AND network; mixing CPU inputs with a CUDA network
    # fails as soon as a GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialise with a KEM reconstruction
    initial_image = run_KEM_reconstruction(sinogram, angles, filter_method="bayesian", num_iterations=5)
    current_image = torch.tensor(initial_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    ct_prior = torch.as_tensor(ct_prior, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    if ct_prior.shape != current_image.shape:
        ct_prior = torch.nn.functional.interpolate(ct_prior, size=current_image.shape[2:], mode="bilinear", align_corners=False)

    # Channel 0: KEM image, channel 1: CT prior
    multi_channel_input = torch.cat((current_image, ct_prior), dim=1)

    input_depth = 2
    net = skip(
        input_depth, 1,  # Single-channel grayscale output
        num_channels_down=[16, 32, 64, 128, 128],  # Downsampling layers
        num_channels_up=[16, 32, 64, 128, 128],  # Upsampling layers
        num_channels_skip=[4, 4, 4, 4, 4],  # Skip connections to maintain spatial info
        upsample_mode='bilinear',  # Interpolation mode for upsampling
        need_sigmoid=True,  # Sigmoid activation on the output
        need_bias=True,  # Adds bias terms
        pad='reflection',  # Padding type
        act_fun='LeakyReLU'  # Activation function in convolution layers
    ).to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=0.008)
    mse_loss = torch.nn.MSELoss()

    for _ in range(num_epochs):
        optimizer.zero_grad()
        output = net(multi_channel_input)  # shape (1, 1, H, W)

        # The target has two channels. Expanding the output makes the previous
        # implicit broadcast explicit: the single output channel is compared
        # with BOTH the KEM image and the CT prior.
        # NOTE(review): confirm this is intended; fitting only channel 0
        # (the KEM image) may be what was meant.
        loss = mse_loss(output.expand_as(multi_channel_input), multi_channel_input)
        loss.backward()
        optimizer.step()

    return output.detach().cpu().numpy().squeeze()


In [None]:
# not used right now

def DIP_to_KEM(sinogram, angles, ct_prior, kernel_size, num_iterations=10, dip_iterations=100, sigma=1.0, filter_method="bilateral", alpha=0.5):
    """Alternate Deep-Image-Prior training with KEM steps.

    A KEM reconstruction and the normalised CT prior form the two-channel
    network input. In every outer iteration the network is trained for
    ``dip_iterations`` steps towards the KEM-updated version of its own output,
    then a KEM step is applied to the network output.

    Args:
        sinogram: measured sinogram.
        angles: projection angles.
        ct_prior: 2D CT image (tensor or array-like).
        kernel_size: forwarded to ``KEM_step``.
        num_iterations: KEM iterations for the initialisation and number of
            outer DIP/KEM alternations.
        dip_iterations: network training steps per outer iteration.
        sigma: filter strength forwarded to ``KEM_step``.
        filter_method: filter used for the initial KEM reconstruction.
        alpha: currently unused; kept for interface compatibility.

    Returns:
        2D NumPy array with the final image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Initial Reconstruction using KEM
    initial_kem = run_KEM_reconstruction(sinogram, angles, filter_method=filter_method, num_iterations=num_iterations)
    plt.imshow(initial_kem, cmap='gray')
    plt.title("Initial KEM as INPUT")
    plt.show()  # must be called; a bare `plt.show` does nothing
    current_image = torch.tensor(initial_kem, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Normalize and resize CT prior (as_tensor also accepts NumPy arrays)
    ct_prior = torch.as_tensor(ct_prior).clone().detach().float().unsqueeze(0).unsqueeze(0).to(device)
    ct_range = (ct_prior.max() - ct_prior.min()).clamp(min=1e-8)  # avoid 0/0 for a constant image
    ct_prior = (ct_prior - ct_prior.min()) / ct_range
    if ct_prior.shape[2:] != current_image.shape[2:]:
        ct_prior = F.interpolate(ct_prior, size=current_image.shape[2:], mode="bilinear", align_corners=False)

    # Define DIP Network
    input_depth = 2
    net = skip(
        input_depth, 1,
        num_channels_down=[16, 32, 64, 128, 128],
        num_channels_up=[16, 32, 64, 128, 128],
        num_channels_skip=[4, 4, 4, 4, 4],
        upsample_mode='bilinear',
        need_sigmoid=True,
        need_bias=True,
        pad='reflection',
        act_fun='LeakyReLU'
    ).to(device)

    # Define optimizer for DIP
    optimizer = torch.optim.Adam(net.parameters(), lr=0.008)
    mse_loss = torch.nn.MSELoss()

    # DIP input: channel 0 = KEM image, channel 1 = CT prior
    multi_channel_input = torch.cat((current_image, ct_prior), dim=1)

    for i in range(num_iterations):
        # DIP training (image-space regularization)
        for _ in range(dip_iterations):
            optimizer.zero_grad()
            dip_output = net(multi_channel_input)

            # KEM update of the network output acts as the data-fidelity target
            kem_output = KEM_step(dip_output, sinogram, angles, kernel_size, sigma, filter_method="bayesian", ct_prior=ct_prior)

            fidelity_loss = mse_loss(dip_output, kem_output)

            # SSIM against the CT is evaluated on detached NumPy copies, so it
            # shifts the loss value but contributes no gradient.
            dip_np = dip_output.squeeze().detach().cpu().numpy()
            ct_np = ct_prior.squeeze().detach().cpu().numpy()
            ssim_loss = 1 - ssim(dip_np, ct_np, data_range=1.0)

            total_loss = fidelity_loss + 0.1 * ssim_loss  # Weighted loss function
            total_loss.backward()

            optimizer.step()

        # Update current_image with the DIP output and apply KEM next
        current_image = dip_output.detach()

        # KEM step (data-space regularization)
        kem_output = KEM_step(current_image, sinogram, angles, kernel_size, sigma, filter_method="bayesian", ct_prior=ct_prior)
        current_image = kem_output.clone()

    # Convert final image to NumPy array for visualization
    final_image = current_image.squeeze().detach().cpu().numpy()

    return final_image


Evaluation: compare a reconstruction with the ground truth using MSE, RMSE, PSNR and SSIM

In [None]:
# Evaluation
def evaluate_reconstruction(ground_truth, reconstructed):
    """Compute image-quality metrics of ``reconstructed`` against ``ground_truth``.

    PSNR comes from ``cv2.PSNR`` on float32 copies, called without an explicit
    peak value. SSIM uses the ground truth's value range as ``data_range``.

    Returns:
        ``(mse, psnr, ssim, rmse)`` in that order.
    """
    gt_f32 = ground_truth.astype(np.float32)
    rec_f32 = reconstructed.astype(np.float32)
    dynamic_range = ground_truth.max() - ground_truth.min()

    mse_val = mse(ground_truth, reconstructed)
    psnr_val = cv2.PSNR(gt_f32, rec_f32)
    ssim_val = ssim(ground_truth, reconstructed, data_range=dynamic_range)

    return mse_val, psnr_val, ssim_val, np.sqrt(mse_val)

Loading functions (only needed if DICOM data is available)

In [None]:
def load_images(ct_folder_path, pet_folder_path):
    """Load CT and PET DICOM series, register them, and simulate a noisy PET sinogram.

    The CT is resampled onto the PET grid (``"downsample_CT"``), the middle
    slice of both volumes is extracted, and the normalised PET slice is
    forward-projected and corrupted with Poisson noise.

    Returns:
        ``(pet_sinogram, ct_slice, ct_slice_norm, initial_kernel_size,
        pet_slice, pet_slice_norm, angles)``, where ``initial_kernel_size`` is
        the same array as ``ct_slice_norm``.

    Raises:
        ValueError: if either volume is not 3D.
    """
    def _report(name, image):
        # Print basic geometry information of a SimpleITK image
        print(f"\n{name} Image Information:")
        print(f"  Size: {image.GetSize()}")
        print(f"  Spacing: {image.GetSpacing()}")
        print(f"  Origin: {image.GetOrigin()}")
        print(f"  Direction: {image.GetDirection()}")
        print(f"  Dimension: {image.GetDimension()}")

    print("Loading CT image...")
    ct_volume = load_dicom_series(ct_folder_path)
    print("Loading PET image...")
    pet_volume = load_dicom_series(pet_folder_path)

    _report("CT", ct_volume)
    _report("PET", pet_volume)

    if not (ct_volume.GetDimension() == 3 and pet_volume.GetDimension() == 3):
        raise ValueError("Both CT and PET images must be 3D.")

    # Register on float32 copies; the CT ends up on the PET grid
    ct_registered, pet_registered = register_pet_ct(
        sitk.Cast(ct_volume, sitk.sitkFloat32),
        sitk.Cast(pet_volume, sitk.sitkFloat32),
        mode="downsample_CT",
    )

    # Arrays are indexed as [slice, height, width]
    ct_array = sitk.GetArrayFromImage(ct_registered)
    pet_array = sitk.GetArrayFromImage(pet_registered)
    mid = ct_array.shape[0] // 2
    ct_slice = ct_array[mid, :, :]
    pet_slice = pet_array[mid, :, :]

    # Normalisation is done on the whole volume, then the middle slice is taken
    ct_slice_norm = normalize_image(ct_registered)[mid, :, :]
    pet_slice_norm = normalize_image(pet_registered)[mid, :, :]

    # One projection angle per pixel of the longer slice side
    angles = np.linspace(0., 180., max(ct_slice.shape), endpoint=False)
    # A smaller scale gives stronger noise
    noisy_sinogram = add_poisson_noise(generate_sinogram_2d(pet_slice_norm, angles), scale=1)

    return noisy_sinogram, ct_slice, ct_slice_norm, ct_slice_norm, pet_slice, pet_slice_norm, angles


Main workflow starts here: load the CT image and the PET image, then create the noisy PET sinogram

In [None]:
# Load an example image DICOM
#ct_folder_path = r"C:\Users\Administrator\OneDrive - stud.hs-mannheim.de\Dokumente\PCC3\NIH\PET\CT_1.3.6.1.4.1.14519.5.2.1.7009.2403.798861112144623421423086363370"
#pet_folder_path = r"C:\Users\Administrator\OneDrive - stud.hs-mannheim.de\Dokumente\PCC3\NIH\PET\PT_1.3.6.1.4.1.14519.5.2.1.7009.2403.291916523874874486349020167447"


In [None]:
#pet_sinogram, ct_slice, ct_slice_norm, initial_kernel_size, pet_slice, pet_slice_norm, angles = load_images(ct_folder_path, pet_folder_path)

#ct_prior_norm = ct_slice_norm
#ct_prior = ct_slice

#Normalize PET slice to [0,1]
#pet_slice_norm = (pet_slice - np.min(pet_slice)) / (np.max(pet_slice) - np.min(pet_slice))
#ground_truth = pet_slice_norm

In [None]:
def _read_grayscale_float32(path):
    """Read an image file as a float32 grayscale array.

    ``cv2.imread`` returns ``None`` instead of raising when the file is missing
    or unreadable, so that case is turned into an explicit error here.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {path}")
    return image.astype(np.float32)


# CT prior and PET ground truth, both as float32 grayscale images
ct_prior = _read_grayscale_float32("ct_image.png")
ground_truth = _read_grayscale_float32("ground_truth.png")

print("CT Prior Shape:", ct_prior.shape)
print("Ground Truth Shape:", ground_truth.shape)

In [None]:
# Default values for reconstruction
kernel_size = 5  # Example kernel size for KEM
num_iterations = 10  # Total iterations for the KEM + DIP process
dip_iterations = 100  # Number of iterations for DIP
sigma = 1.0  # Standard deviation for filtering
angles = np.linspace(0., 180., max(ct_prior.shape), endpoint=False)

# Fixed seed so the simulated Poisson noise (and every result derived from it)
# is reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)

pet_sinogram = generate_sinogram_2d(ground_truth, angles)
# A smaller `scale` means fewer simulated counts and therefore stronger noise;
# 0.1 is extremely noisy.
pet_sinogram = add_poisson_noise(pet_sinogram, scale=0.1)


# Show the three inputs side by side: CT prior, noisy PET sinogram, PET ground truth.
# Each panel gets its own colorbar; axes are hidden because only the images matter.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
im1 = axes[0].imshow(ct_prior, cmap='gray')
axes[0].set_title("CT Prior")
axes[0].axis("off")
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

# aspect='auto' lets the sinogram fill the panel regardless of its shape
im2 = axes[1].imshow(pet_sinogram, cmap='gray', aspect='auto')
axes[1].set_title("PET Sinogram")
axes[1].axis("off")
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

im3 = axes[2].imshow(ground_truth, cmap='gray')
axes[2].set_title("Ground Truth")
axes[2].axis("off")
fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

Run each method standalone for comparison

In [None]:
# Baseline 1: KEM-only reconstruction of the noisy sinogram using the "bayesian" filter.
only_kem_result = run_KEM_reconstruction(pet_sinogram, angles, filter_method="bayesian")

In [None]:
# Baseline 2: DIP reconstruction; internally initialised with a KEM reconstruction
# and given the CT prior as a second input channel.
only_dip_result = run_DIP_reconstruction(pet_sinogram, angles, ct_prior)

In [None]:
# Compare the two standalone baselines side by side (KEM left, DIP right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(only_kem_result, cmap='gray')
axes[0].set_title('Only KEM Reconstruction')
axes[1].imshow(only_dip_result, cmap='gray')
axes[1].set_title('Only DIP Reconstruction with CT as Prior')
plt.show()

In [None]:
def DIP_to_KEM_new(pet_sinogram, ct_prior, angles, filter_method="bayesian", num_iterations=5, num_epochs=1000, alpha=0.5):
    """Two-pass combination of KEM and DIP reconstructions.

    Args:
        pet_sinogram: measured PET sinogram.
        ct_prior: 2D CT image used as DIP prior.
        angles: projection angles.
        filter_method: filter for the KEM reconstructions.
        num_iterations: KEM iterations per KEM call.
        num_epochs: DIP training steps per DIP call.
        alpha: blending weight of the KEM result; ``1 - alpha`` weights the
            DIP result. Values above 0.5 favour KEM, values below favour DIP.

    Returns:
        ``alpha * refined_KEM + (1 - alpha) * refined_DIP`` as a 2D array.
    """
    # Step 1: Initial KEM reconstruction
    kem_result = run_KEM_reconstruction(pet_sinogram, angles, filter_method=filter_method, num_iterations=num_iterations)

    # Step 2: DIP pass
    # NOTE(review): `kem_result` is an image, but it is passed in the position
    # of the `sinogram` argument, so run_DIP_reconstruction treats it as a
    # sinogram. Confirm whether `pet_sinogram` was meant here.
    dip_result = run_DIP_reconstruction(kem_result, angles, ct_prior, num_epochs=num_epochs)

    # Step 3: Combine KEM and DIP outputs with a weighted update
    combined_result = alpha * kem_result + (1 - alpha) * dip_result

    # Step 4: Refine further by passing combined result to KEM again
    # NOTE(review): `combined_result` is an image used as the sinogram
    # argument; see the note on step 2.
    refined_kem_result = run_KEM_reconstruction(combined_result, angles, filter_method=filter_method, num_iterations=num_iterations)

    # Step 5: Second DIP pass on the refined KEM output
    refined_dip_result = run_DIP_reconstruction(refined_kem_result, angles, ct_prior, num_epochs=num_epochs)

    # Step 6: Final weighted combination of KEM and DIP
    final_result = alpha * refined_kem_result + (1 - alpha) * refined_dip_result

    return final_result

In [None]:
# Combined KEM + DIP pipeline on the noisy sinogram.
new_approach = DIP_to_KEM_new(pet_sinogram, ct_prior, angles, filter_method="bayesian")

In [None]:
# Reference method: plain MLEM, tracking PSNR against the ground truth per iteration.
mlem_reconstructed, psnr_values = MLEM_reconstruction(pet_sinogram, angles, ground_truth, num_iterations=50)
plot_psnr(psnr_values)

# Reconstruction, ground truth and their signed difference, each with a colorbar.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))  
im1 = axes[0].imshow(mlem_reconstructed, cmap='gray')#, clim=[0, 0.07])  
axes[0].set_title("MLEM Reconstruction")  
axes[0].axis('off')  
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04) 
 
im2 = axes[1].imshow(ground_truth, cmap='gray')  
axes[1].set_title("Ground Truth")  
axes[1].axis('off')  
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)  

im3 = axes[2].imshow(ground_truth-mlem_reconstructed, cmap='gray')
axes[2].set_title("Difference (Ground Truth - Reconstruction)")
axes[2].axis('off')
fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)

plt.show()

In [None]:
# Compare the combined KEM + DIP result with MLEM (blending weight alpha = 0.5).
# Only two images are shown, so create two axes; a third would stay empty.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
img0 = axes[0].imshow(new_approach, cmap='gray')
axes[0].set_title('New Combination')
axes[0].axis('off')
fig.colorbar(img0, ax=axes[0], fraction=0.046, pad=0.04)

img1 = axes[1].imshow(mlem_reconstructed, cmap='gray')
axes[1].set_title('MLEM')
axes[1].axis('off')
fig.colorbar(img1, ax=axes[1], fraction=0.046, pad=0.04)

plt.show()

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from torchmetrics.image import StructuralSimilarityIndexMeasure

def total_variation_loss(image):
    """Anisotropic total-variation loss of a 4D tensor.

    Returns the mean absolute difference between horizontally adjacent pixels
    (last axis) plus the mean absolute difference between vertically adjacent
    pixels (second-to-last axis).
    """
    horizontal = (image[:, :, :, :-1] - image[:, :, :, 1:]).abs().mean()
    vertical = (image[:, :, :-1, :] - image[:, :, 1:, :]).abs().mean()
    return horizontal + vertical

def gradient_difference_loss(output, ct_prior):
    """Mean absolute difference between the finite-difference gradients of two 4D tensors.

    Horizontal and vertical neighbour differences are computed for both inputs;
    the loss is the sum of the mean absolute mismatch in each direction.
    """
    def _finite_diffs(t):
        # (difference along the last axis, difference along the second-to-last axis)
        return t[:, :, :, :-1] - t[:, :, :, 1:], t[:, :, :-1, :] - t[:, :, 1:, :]

    out_dx, out_dy = _finite_diffs(output)
    ct_dx, ct_dy = _finite_diffs(ct_prior)

    return torch.mean(torch.abs(out_dx - ct_dx)) + torch.mean(torch.abs(out_dy - ct_dy))



def _min_max_normalize(tensor, eps=1e-8):
    """Scale a tensor to [0, 1]; a constant tensor maps to zeros instead of NaN."""
    value_range = (tensor.max() - tensor.min()).clamp(min=eps)
    return (tensor - tensor.min()) / value_range


def run_dip_kem_refinement(kem_image, ct_image, angles, filter_method="bayesian", num_iterations=1000, lr=0.001):
    """Refine a KEM reconstruction with a DIP skip network guided by a CT prior.

    The normalised KEM image and CT prior form a two-channel input. The
    single-channel output is trained with a weighted sum of MSE to the KEM
    image, (1 - SSIM) to the CT prior, a total-variation term and a
    gradient-difference term to the CT prior. The input images and the result
    are plotted before returning.

    Args:
        kem_image: 2D KEM reconstruction (array-like).
        ct_image: 2D CT image (array-like); resized bilinearly if needed.
        angles: currently unused; kept for interface compatibility.
        filter_method: currently unused; kept for interface compatibility.
        num_iterations: number of training steps (at least 1).
        lr: Adam learning rate.

    Returns:
        2D NumPy array: the network output of the last training step.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Convert to (1, 1, H, W) tensors and normalise to [0, 1]
    kem_recon = torch.as_tensor(kem_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    ct_prior = torch.as_tensor(ct_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    kem_recon = _min_max_normalize(kem_recon)
    ct_prior = _min_max_normalize(ct_prior)

    # Resize CT prior if needed
    if ct_prior.shape != kem_recon.shape:
        ct_prior = F.interpolate(ct_prior, size=kem_recon.shape[2:], mode="bilinear", align_corners=False)

    # DIP model: skip network with two input channels (KEM + CT prior)
    net = skip(2, 1,
        num_channels_down=[16, 32, 64, 128, 128],
        num_channels_up=[16, 32, 64, 128, 128],
        num_channels_skip=[4, 4, 4, 4, 4],
        upsample_mode='bilinear',
        need_sigmoid=True, 
        need_bias=True,
        pad='reflection', 
        act_fun='LeakyReLU').to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    mse_loss = torch.nn.MSELoss()
    ssim_loss_fn = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    dip_input = torch.cat((kem_recon, ct_prior), dim=1)

    for _ in range(num_iterations):
        optimizer.zero_grad()

        dip_output = net(dip_input)

        fidelity_loss = mse_loss(dip_output, kem_recon)        # stay close to KEM
        ssim_loss = 1 - ssim_loss_fn(dip_output, ct_prior)     # structural similarity to CT
        tv_loss = total_variation_loss(dip_output)             # smoothness
        edge_loss = gradient_difference_loss(dip_output, ct_prior)  # CT edge agreement

        # Loss weights; tune experimentally
        total_loss = fidelity_loss + 0.3 * ssim_loss + 0.05 * tv_loss + 0.5 * edge_loss

        total_loss.backward()
        optimizer.step()

    # Output of the last training step
    final_image = dip_output.squeeze().detach().cpu().numpy()

    # Display results
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(kem_recon.squeeze().cpu().numpy(), cmap="gray")
    plt.title("Initial KEM")
    plt.subplot(1, 3, 2)
    plt.imshow(ct_prior.squeeze().cpu().numpy(), cmap="gray")
    plt.title("CT Prior")
    plt.subplot(1, 3, 3)
    plt.imshow(final_image, cmap="gray")
    plt.title("Final Refined Image")
    plt.show()

    return final_image


# Refine the standalone KEM result with the CT-guided DIP network (also plots the result).
refined_image = run_dip_kem_refinement(only_kem_result, ct_prior, angles)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


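# Wrap the DIP network in a class that takes only the KEM image as network input and injects the CT prior after the network, as a thresholded edge map added to its output. This keeps the CT from suppressing PET information (tumors) while still transferring the sharp edges of the CT image.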
class DeepDIP_with_CT(nn.Module):
    """DIP skip network driven by the KEM image, with CT edges blended into its output.

    The network itself only ever sees the KEM image. The CT prior is reduced
    to a thresholded, min-max normalised Sobel edge-magnitude map, which is
    added to the network output after being scaled by a gate derived from that
    same output.
    """

    def __init__(self):
        super().__init__()

        # Single-channel in / single-channel out `skip` network (KEM only).
        self.net = skip(
            1,
            1,
            num_channels_down=[32, 64, 128, 256, 256],
            num_channels_up=[32, 64, 128, 256, 256],
            num_channels_skip=[8, 8, 8, 8, 8],
            upsample_mode='bilinear',
            need_sigmoid=True,
            need_bias=True,
            pad='reflection',
            act_fun='LeakyReLU'
        )

        # Frozen 3x3 Sobel filters used to extract edges from the CT prior.
        self.sobel_x = self._frozen_sobel([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        self.sobel_y = self._frozen_sobel([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])

    @staticmethod
    def _frozen_sobel(stencil):
        """Build a bias-free 1->1 channel 3x3 conv whose weight is fixed to `stencil`."""
        conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        weight = torch.tensor([stencil], dtype=torch.float32).unsqueeze(0)  # -> (1, 1, 3, 3)
        conv.weight = nn.Parameter(weight, requires_grad=False)
        return conv

    def forward(self, kem, ct_prior):
        """
        Run the DIP network on `kem` and add gated CT edges to the result.

        Args:
            kem: input tensor passed to the skip network.
            ct_prior: single-channel tensor the Sobel filters are applied to.

        Returns:
            Network output plus `sigmoid(output) * normalised_ct_edges`.
        """
        # Sobel gradient magnitude of the CT prior.
        grad_x = self.sobel_x(ct_prior)
        grad_y = self.sobel_y(ct_prior)
        edge_mag = torch.sqrt(grad_x ** 2 + grad_y ** 2)

        # Suppress weak edges: anything at or below the threshold becomes 0.
        edge_threshold = 0.2  # Experiment with this value
        zero = torch.tensor(0.0).to(edge_mag.device)
        strong_edges = torch.where(edge_mag > edge_threshold, edge_mag, zero)

        # Min-max normalise; the 1e-6 guards against a flat (edge-free) map.
        lowest = strong_edges.min()
        highest = strong_edges.max()
        edge_map = (strong_edges - lowest) / (highest - lowest + 1e-6)

        # DIP prediction from the KEM image alone.
        prediction = self.net(kem)

        # Gate in (0, 1) computed from the prediction scales the edge contribution.
        gate = torch.sigmoid(prediction)
        return prediction + gate * edge_map

# -----------------
# prevent tumor details from getting blurred
def gradient_difference_loss(output, reference):
    """
    L1 mismatch between the neighbouring-pixel differences of two 4-D tensors.

    For both inputs, each pixel minus its next neighbour is taken along
    dim 2 and along dim 3; the mean absolute difference between the two
    tensors' results is computed per direction and the two means are summed.
    """
    def _neighbour_deltas(img):
        # Pixel minus the following pixel along dim 2, then along dim 3.
        along_dim2 = img[:, :, :-1, :] - img[:, :, 1:, :]
        along_dim3 = img[:, :, :, :-1] - img[:, :, :, 1:]
        return along_dim2, along_dim3

    out_d2, out_d3 = _neighbour_deltas(output)
    ref_d2, ref_d3 = _neighbour_deltas(reference)

    return F.l1_loss(out_d2, ref_d2) + F.l1_loss(out_d3, ref_d3)


# --------------

def _unit_range(image, eps=1e-8):
    """Min-max scale a tensor to [0, 1]; a constant tensor maps to zeros instead of NaN."""
    lowest = image.min()
    # Clamp the span so a flat image does not trigger a 0/0 division.
    span = torch.clamp(image.max() - lowest, min=eps)
    return (image - lowest) / span


def _differentiable_ssim(pred, target, data_range=1.0, win_size=7, k1=0.01, k2=0.03):
    """Mean SSIM of two (N, C, H, W) tensors using torch ops only, so gradients reach `pred`.

    Uses a `win_size` x `win_size` uniform window evaluated on fully covered
    positions only, with sample (N-1) covariance and the usual K1/K2 constants.
    Both spatial dimensions must be at least `win_size`.
    """
    n_pix = win_size * win_size
    cov_norm = n_pix / (n_pix - 1)  # sample-covariance correction

    def local_mean(t):
        # stride-1 average pooling == uniform filter restricted to the valid region
        return F.avg_pool2d(t, kernel_size=win_size, stride=1)

    mu_p = local_mean(pred)
    mu_t = local_mean(target)
    var_p = cov_norm * (local_mean(pred * pred) - mu_p * mu_p)
    var_t = cov_norm * (local_mean(target * target) - mu_t * mu_t)
    cov_pt = cov_norm * (local_mean(pred * target) - mu_p * mu_t)

    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    numerator = (2 * mu_p * mu_t + c1) * (2 * cov_pt + c2)
    denominator = (mu_p ** 2 + mu_t ** 2 + c1) * (var_p + var_t + c2)
    return (numerator / denominator).mean()


def model2(kem_image, ct_image, angles, filter_method="bayesian", num_iterations=500, lr=0.001):
    """Refine a KEM reconstruction with `DeepDIP_with_CT`, guided by a CT prior.

    Args:
        kem_image: 2-D array-like KEM reconstruction.
        ct_image: 2-D array-like CT prior; bilinearly resized to the KEM size
            when the shapes differ.
        angles: unused; retained so existing callers keep working.
        filter_method: unused; retained so existing callers keep working.
        num_iterations: number of Adam steps (default 500, the previous
            hard-coded value).
        lr: Adam learning rate (default 0.001, the previous hard-coded value).

    Returns:
        2-D NumPy array produced by the network after the final update.

    Side effects:
        Prints the loss every 50 iterations and shows a three-panel figure
        (initial KEM, CT prior, refined image).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shape both images as (1, 1, H, W) on the training device.
    kem_recon = torch.tensor(kem_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    ct_prior = torch.tensor(ct_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Normalize to [0, 1] (safe for constant images).
    kem_recon = _unit_range(kem_recon)
    ct_prior = _unit_range(ct_prior)

    # Resize CT to match KEM
    if ct_prior.shape != kem_recon.shape:
        ct_prior = F.interpolate(ct_prior, size=kem_recon.shape[2:], mode="bilinear", align_corners=False)

    # Initialize model
    dip_model = DeepDIP_with_CT().to(device)
    optimizer = optim.Adam(dip_model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    # Training loop
    for i in range(num_iterations):
        optimizer.zero_grad()

        dip_output = dip_model(kem_recon, ct_prior)

        fidelity_loss = mse_loss(dip_output, kem_recon)  # Keep PET features
        # SSIM has to be evaluated on the live tensor: computing it on detached
        # NumPy copies yields a constant that contributes no gradient.
        ssim_loss = 1 - _differentiable_ssim(dip_output, ct_prior, data_range=1.0)
        edge_loss = gradient_difference_loss(dip_output, ct_prior)  # Edge preservation

        # Total loss (adjust weights)
        total_loss = fidelity_loss + 0.3 * ssim_loss + 0.5 * edge_loss

        total_loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Iteration {i}/{num_iterations}, Loss: {total_loss.item()}")

    # Evaluate once more so the returned image reflects the final weights
    # (and so that num_iterations == 0 still yields an image).
    with torch.no_grad():
        final_image = dip_model(kem_recon, ct_prior).squeeze().cpu().numpy()

    # Display results
    panels = (
        (kem_recon.squeeze().cpu().numpy(), "Initial KEM"),
        (ct_prior.squeeze().cpu().numpy(), "CT Prior"),
        (final_image, "Final Refined Image"),
    )
    plt.figure(figsize=(10, 5))
    for position, (panel_image, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel_image, cmap="gray")
        plt.title(panel_title)
    plt.show()

    return final_image


# Refine the KEM reconstruction with model2 (which also shows its own
# three-panel figure). NOTE: this rebinds the notebook-global `refined_image`
# that the run_dip_kem_refinement cell assigns as well, so later cells see
# whichever of the two cells ran last.
refined_image = model2(only_kem_result, ct_prior, angles)
# Reference figure: MLEM reconstruction next to the ground truth.
# NOTE(review): `mlem_reconstructed` and `ground_truth` are presumably created
# by earlier cells -- confirm on a fresh kernel.
plt.figure(figsize=(7, 7))
plt.subplot(1, 2, 1)
plt.imshow(mlem_reconstructed, cmap='gray')
plt.title("MLEM")
plt.subplot(1, 2, 2)
plt.imshow(ground_truth, cmap='gray')
plt.title("Ground Truth")
plt.show()

In [None]:
#final_image_bayesian = DIP_to_KEM(
 #   sinogram=torch.tensor(pet_sinogram, dtype=torch.float32),
  #  angles=angles,
   # ct_prior=torch.tensor(ct_prior, dtype=torch.float32),
    #kernel_size=kernel_size,
#    num_iterations=num_iterations,
 #   dip_iterations=dip_iterations,
  #  sigma=sigma, 
   # filter_method="bayesian"
#)

In [None]:
#fig, axes = plt.subplots(1, 3, figsize=(18, 5))
#im1 = axes[0].imshow(final_image_bayesian, cmap='gray')
#axes[0].set_title("Gaussian Kernel")
#axes[0].axis('off')
#fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)  

#im2 = axes[1].imshow(ground_truth, cmap='gray')
#axes[1].set_title("Ground Truth")  
#axes[1].axis('off')  
#fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)  

#im3 = axes[2].imshow(ground_truth-final_image_bayesian, cmap='gray')
#axes[2].set_title("Difference (Ground Truth - Reconstruction)")
#axes[2].axis('off')
#fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)
#plt.show()

In [None]:
# Evaluate the MLEM baseline and the refined image against the ground truth.
# Each result is indexed below as [0]=MSE, [1]=PSNR, [2]=SSIM, [3]=RMSE.
# NOTE(review): `evaluate_reconstruction` is defined outside this cell --
# confirm its return order matches the labels printed here.
metrics_mlem = evaluate_reconstruction(ground_truth, mlem_reconstructed)
#metrics_model2_reversed_bilateral = evaluate_reconstruction(ground_truth, final_image_bilateral)
#metrics_model2_reversed_anisotropic = evaluate_reconstruction(ground_truth, final_image_anisotropic)
#metrics_model2_reversed_gaussian = evaluate_reconstruction(ground_truth, final_image_gaussian)
# NOTE(review): `refined_image` is whatever the most recently executed
# refinement cell produced; the "Bayesian Kernel" label below may not describe
# it -- verify before reporting these numbers.
metrics_model2_reversed_bayesian = evaluate_reconstruction(ground_truth, refined_image)

print("MLEM:")
print(f"MSE: {metrics_mlem[0]:.4f}, PSNR: {metrics_mlem[1]:.2f}, SSIM: {metrics_mlem[2]:.4f}, RMSE: {metrics_mlem[3]:.4f} \n")

#print("Bilateral Kernel:")
#print(f"MSE: {metrics_model2_reversed_bilateral[0]:.4f}, PSNR: {metrics_model2_reversed_bilateral[1]:.2f}, SSIM: {metrics_model2_reversed_bilateral[2]:.4f}, RMSE: {metrics_model2_reversed_bilateral[3]:.4f} \n")

#print("Anisotropic Kernel:")
#print(f"MSE: {metrics_model2_reversed_anisotropic[0]:.4f}, PSNR: {metrics_model2_reversed_anisotropic[1]:.4f}, SSIM: {metrics_model2_reversed_anisotropic[2]:.4f}, RMSE: {metrics_model2_reversed_anisotropic[3]:.4f} \n")

#print("Gaussian Kernel:")
#print(f"MSE: {metrics_model2_reversed_gaussian[0]:.4f}, PSNR: {metrics_model2_reversed_gaussian[1]:.2f}, SSIM: {metrics_model2_reversed_gaussian[2]:.4f}, RMSE: {metrics_model2_reversed_gaussian[3]:.4f} \n")

print("Bayesian Kernel:")
print(f"MSE: {metrics_model2_reversed_bayesian[0]:.4f}, PSNR: {metrics_model2_reversed_bayesian[1]:.2f}, SSIM: {metrics_model2_reversed_bayesian[2]:.4f}, RMSE: {metrics_model2_reversed_bayesian[3]:.4f} \n")