In [1]:
import numpy as np
import odl
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.ndimage import gaussian_filter
from skimage.data import shepp_logan_phantom
from skimage.transform import resize,radon,iradon
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.transform import radon, iradon
from scipy.optimize import approx_fprime
from scipy.sparse.linalg import eigsh

In [2]:
    #Get objective function
# Shared reconstruction setup. These module-level names are read by the
# functions defined in the cells below, so no `global` statement is needed.

# Image space: 300 x 300 pixels covering [-20, 20] x [-20, 20], single precision.
reco_space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32'
)
# Acquisition sampling: 1800 angular cells on [0, pi], 500 detector cells on [-30, 30].
angle_partition = odl.uniform_partition(0, np.pi, 1800)
detector_partition = odl.uniform_partition(-30, 30, 500)
geometry = odl.tomo.Parallel2dGeometry(angle_partition, detector_partition)
# Create the forward operator and keep a handle on its adjoint.
ray_trafo = odl.tomo.RayTransform(reco_space, geometry)
ray_trafo_adjoint = ray_trafo.adjoint
# Filtered back-projection operator built on the same forward operator
# (Hann filter, frequency scaling 0.8).
fbp = odl.tomo.fbp_op(ray_trafo, filter_type='Hann', frequency_scaling=0.8)

In [3]:
def zero_outside_circle(img):
    """Set every pixel outside the inscribed circle of ``img`` to zero.

    The circle is centred at ``(n - 1) / 2`` along both axes and has that
    same value as its radius, where ``n = img.shape[0]``. The mask is built
    from ``img.shape[0]`` only, so ``img`` is presumably square.

    The array is modified in place and the same object is returned.
    """
    side = img.shape[0]
    radius = (side - 1) / 2.0
    # Signed distance of each index from the centre, along one axis.
    offsets = np.arange(side) - radius
    row_offsets = offsets[:, np.newaxis]
    col_offsets = offsets[np.newaxis, :]
    inside = col_offsets**2 + row_offsets**2 <= radius**2
    img[np.logical_not(inside)] = 0
    return img

In [4]:
def generate_shepp_logan_data(shape=(300,300), seed=None):
    """Build the anatomical phantom, the emission phantom and noisy projections.

    Parameters
    ----------
    shape : tuple of int, optional
        Kept for backward compatibility; it is not used. The phantoms are
        always created on the module-level ``reco_space``.
    seed : int, optional
        Forwarded to ``odl.phantom.poisson_noise`` so the noise realisation
        can be reproduced. ``None`` keeps the previous behaviour.

    Returns
    -------
    anat_phantom
        Modified Shepp-Logan phantom on ``reco_space``.
    modified_phantom
        Copy of the phantom with a disc of radius 2 around (5, 5) set to
        0.8 times the phantom maximum, then the whole image scaled by 0.7.
    pet_data
        ``ray_trafo(modified_phantom)`` with Poisson noise applied.
    """
    # Create the phantom once; the emission image is derived from a copy so
    # the anatomical image is left untouched.
    anat_phantom = odl.phantom.shepp_logan(reco_space, modified=True)
    phantom_array = anat_phantom.asarray().copy()

    # Coordinate grid taken from reco_space rather than repeating its extent
    # and size here. 'ij' indexing puts x on axis 0, matching the array layout.
    n_x, n_y = reco_space.shape
    x_lin = np.linspace(reco_space.min_pt[0], reco_space.max_pt[0], n_x)
    y_lin = np.linspace(reco_space.min_pt[1], reco_space.max_pt[1], n_y)
    x, y = np.meshgrid(x_lin, y_lin, indexing='ij')

    # Insert a hot spot that has no counterpart in the anatomical image.
    hot_spot = ((x - 5)**2 + (y - 5)**2) < 2**2
    phantom_array[hot_spot] = phantom_array.max() * 0.8
    phantom_array *= 0.7
    modified_phantom = reco_space.element(phantom_array)

    # Forward project, then add Poisson noise.
    pet_data = ray_trafo(modified_phantom)
    pet_data = odl.phantom.poisson_noise(pet_data, seed=seed)
    return anat_phantom, modified_phantom, pet_data
# # visualization
# fig, axes = plt.subplots(1, 3)
# axes[0].imshow(anat_phantom, cmap='gray')
# axes[0].set_title("Generated CT Image")
# axes[1].imshow(emission_phantom, cmap='hot')
# axes[1].set_title("Generated PET Image")
# axes[2].imshow(pet_data, cmap='gray')
# axes[2].set_title("Generated PET data")
# plt.show()

In [5]:
# Bowsher weighted function
def psi(u, zeta=0.5, rho=0.01):
    """Smooth decreasing step: ``arctan((zeta - u) / rho) / pi + 0.5``.

    The value is 0.5 at ``u == zeta``; ``rho`` controls how sharp the
    transition around ``zeta`` is. Works element-wise on arrays.
    """
    return (np.arctan((zeta - u) / rho) / np.pi) + 0.5


def bowsher_weights(anat_phantom, voxel, neighbors,zeta=0.5, rho=0.01,epsilon=1e-6):
    """Normalised similarity weights between ``voxel`` and its ``neighbors``.

    Parameters
    ----------
    anat_phantom : indexable
        Anatomical image, indexed with the ``(i, j)`` tuples in ``voxel`` and
        ``neighbors``.
    voxel : tuple
        Index of the central voxel.
    neighbors : sequence of tuple
        Indices of the neighbouring voxels.
    zeta, rho : float
        Threshold and sharpness passed on to ``psi``.
    epsilon : float
        Lower bound on the normalising denominator, which avoids division by
        zero in flat regions.

    Returns
    -------
    numpy.ndarray
        One weight per neighbour, summing to 1. An empty array is returned
        when ``neighbors`` is empty.
    """
    # No neighbours (e.g. a 1 x 1 image): nothing to weight, and np.max would
    # raise on an empty array.
    if len(neighbors) == 0:
        return np.zeros(0)

    central_value = anat_phantom[voxel]
    neighbor_values = np.array([anat_phantom[n] for n in neighbors])

    # Absolute anatomical difference between the centre and each neighbour.
    center_diffs = np.abs(central_value - neighbor_values)
    Mj = np.max(center_diffs)

    # For every neighbour k: largest absolute difference between k and the
    # members of this same neighbour list.
    # NOTE(review): this uses the central voxel's neighbour list, not k's own
    # neighbourhood -- confirm this is the intended definition of M_k.
    pairwise = np.abs(neighbor_values[:, np.newaxis] - neighbor_values[np.newaxis, :])
    Mk_values = np.max(pairwise, axis=1)

    denominator = np.maximum((Mj + Mk_values) / 2, epsilon)

    weights = psi(center_diffs / denominator, zeta, rho)
    return weights / np.sum(weights)

In [6]:
# MAP objective function
def map_objective(recon_image, pet_data, regularization):
    """Return the MAP objective: negative Poisson log-likelihood plus prior.

    Parameters
    ----------
    recon_image : array
        Current image estimate; reshaped to the domain shape of the
        module-level ``ray_trafo`` before projection.
    pet_data
        Measured projection data.
    regularization : float
        Already-evaluated value of the prior term.

    Returns
    -------
    ``-poisson_log_likelihood(ray_trafo(recon_image), pet_data) + regularization``
    """
    # Use the operator's own domain shape instead of a notebook-global `shape`,
    # so the function does not depend on which cells ran before it.
    recon_image = recon_image.reshape(ray_trafo.domain.shape)
    recon_projected = ray_trafo(recon_image)
    likelihood = odl.solvers.iterative.statistical.poisson_log_likelihood(recon_projected,pet_data)
    # The gradient is not needed to evaluate the objective; it was previously
    # computed here (a full per-voxel Python loop) and then discarded.
    return -likelihood + regularization

In [7]:
def compute_map_gradient(recon_image, pet_data, anat_phantom, shape, alpha=0.01, beta=1.0):
    """Gradient of the MAP objective with respect to the image.

    The objective is the negative Poisson log-likelihood plus the weighted
    quadratic penalty ``alpha * sum_j sum_{k in N_j} w_jk * (x_j - x_k)**2``
    evaluated by ``anatomical_regularization``.

    Parameters
    ----------
    recon_image : array
        Current image estimate (reshaped to ``shape``).
    pet_data
        Measured projection data.
    anat_phantom
        Anatomical image used to compute the neighbour weights.
    shape : tuple of int
        2-D image shape.
    alpha : float
        Regularisation strength. Note the default here (0.01) differs from the
        default of ``anatomical_regularization`` (0.1); pass it explicitly when
        the two must agree.
    beta : float
        Forwarded to ``bowsher_weights`` as its ``zeta`` threshold.

    Returns
    -------
    Sum of the likelihood gradient and the regularisation gradient.
    """
    recon_image = recon_image.reshape(shape)

    # Likelihood term: adjoint applied to (1 - y / (A x)); the small constant
    # guards against division by zero in empty projection bins.
    recon_projected = ray_trafo(recon_image)
    likelihood_grad = ray_trafo_adjoint(1 - pet_data / (recon_projected + 1e-6))

    # Regularisation term. Each pair term w_jk * (x_j - x_k)**2 depends on both
    # x_j and x_k, so its derivative is added to voxel j and subtracted from
    # neighbour k. Updating only voxel j (as before) drops the second half and
    # yields a gradient that does not match the penalty being minimised.
    reg_grad = np.zeros_like(recon_image)
    for voxel in np.ndindex(shape[0], shape[1]):
        neighbors = get_neighbors(voxel, shape)
        weights = bowsher_weights(anat_phantom, voxel, neighbors, zeta=beta)
        voxel_value = recon_image[voxel]
        for weight, neighbor in zip(weights, neighbors):
            contribution = 2 * alpha * weight * (voxel_value - recon_image[neighbor])
            reg_grad[voxel] += contribution
            reg_grad[neighbor] -= contribution

    # Combine gradients
    map_gradient = likelihood_grad + reg_grad
    return map_gradient

In [8]:
# Anatomical information regularization term
def anatomical_regularization(parameters, anat_phantom, shape, alpha=0.1, beta=1.0):
    """Anatomically weighted quadratic smoothness penalty.

    For every pixel of a ``shape[0] x shape[1]`` grid, the squared differences
    between ``parameters`` at that pixel and at each of its neighbours are
    weighted by ``bowsher_weights`` (computed on ``anat_phantom``), summed, and
    scaled by ``alpha``. ``beta`` is handed to ``bowsher_weights`` as its
    ``zeta`` argument.

    Returns the accumulated penalty over all pixels.
    """
    total = 0
    n_rows, n_cols = shape[0], shape[1]
    # Row-major sweep over the image grid.
    for pixel in np.ndindex(n_rows, n_cols):
        adjacent = get_neighbors(pixel, shape)
        w = bowsher_weights(anat_phantom, pixel, adjacent, zeta=beta)
        center = parameters[pixel]
        around = np.array([parameters[q] for q in adjacent])
        squared_diffs = (center - around) ** 2
        total += alpha * np.sum(w * squared_diffs)
    return total


In [9]:
# Get 8 neighbor voxels  
def get_neighbors(voxel, shape):
    """List the in-bounds 8-connected neighbours of ``voxel``.

    ``voxel`` is an ``(x, y)`` index pair and ``shape`` gives the grid size
    along the two axes. The voxel itself is excluded, and positions outside
    ``[0, shape[0]) x [0, shape[1])`` are dropped. Neighbours are returned as
    tuples, ordered by first-axis offset and then by second-axis offset.
    """
    row, col = voxel
    n_rows, n_cols = shape[0], shape[1]
    offsets = (-1, 0, 1)
    return [
        (row + d_row, col + d_col)
        for d_row in offsets
        for d_col in offsets
        if (d_row, d_col) != (0, 0)
        and 0 <= row + d_row < n_rows
        and 0 <= col + d_col < n_cols
    ]

In [10]:
class MAPObjective(odl.solvers.Functional):
    """MAP objective as an ODL functional: negative Poisson log-likelihood plus prior.

    Parameters
    ----------
    space
        Domain of the functional (the image space).
    pet_data
        Measured projection data.
    anat_phantom
        Anatomical image passed to the regulariser and to the gradient.
    shape : tuple of int
        2-D image shape used to reshape the iterate.
    regularization_func : callable
        Called as ``regularization_func(image, anat_phantom, shape, alpha, beta)``
        and expected to return the prior value.
    ray_trafo
        Forward projection operator used when evaluating the objective.
    alpha, beta : float
        Forwarded to ``regularization_func`` and ``compute_map_gradient``.
    """

    def __init__(self, space, pet_data, anat_phantom, shape, regularization_func, ray_trafo, alpha=0.01, beta=1.0):
        super().__init__(space)
        self.pet_data = pet_data
        self.anat_phantom = anat_phantom
        self.shape = shape
        self.regularization_func = regularization_func
        self.ray_trafo = ray_trafo
        self.alpha = alpha
        self.beta = beta

    # Implement ODL's `_call` hook instead of overriding `__call__`, so the
    # base class still validates/converts the input and the returned value.
    def _call(self, x):
        """Evaluate the objective at ``x``."""
        recon_image = x.asarray().reshape(self.shape)
        recon_projected = self.ray_trafo(recon_image)
        likelihood = odl.solvers.iterative.statistical.poisson_log_likelihood(
            recon_projected, self.pet_data
        )
        prior = self.regularization_func(recon_image, self.anat_phantom, self.shape, self.alpha, self.beta)
        return -likelihood + prior

    class _GradientOperator(odl.Operator):
        """Operator mapping an image to the MAP gradient at that image."""

        def __init__(self, space, objective):
            super().__init__(space, space)
            self.objective = objective

        def _call(self, x, out):
            recon_image = x.asarray().reshape(self.objective.shape)

            # NOTE: compute_map_gradient projects with the module-level
            # ray_trafo, not with self.objective.ray_trafo.
            grad = compute_map_gradient(
                recon_image,
                self.objective.pet_data,
                self.objective.anat_phantom,
                self.objective.shape,
                self.objective.alpha,
                self.objective.beta
            )

            # Assign through the element itself rather than relying on
            # out.asarray() handing back a writable view of its storage.
            out[:] = np.asarray(grad)

    @property
    def gradient(self):
        """Gradient operator of this functional (domain -> domain)."""
        return self._GradientOperator(self.domain, self)

In [11]:
# Optimization
def optimize_parameters(pet_data, anat_phantom, shape, alpha=0.01, beta=1.0):
    """Reconstruct an image by minimising the MAP objective with nonlinear CG.

    Parameters
    ----------
    pet_data
        Measured projection data.
    anat_phantom
        Anatomical image used by the regulariser.
    shape : tuple of int
        2-D image shape.
    alpha, beta : float
        Regularisation parameters forwarded to ``MAPObjective``.

    Returns
    -------
    The optimised image, an element of a ``uniform_discr`` space of the given
    ``shape`` on [-20, 20] x [-20, 20].
    """
    # Starting point: one MLEM iteration from a constant image; mlem is called
    # for its effect on `initial_params`.
    # (Alternatives tried earlier: FBP clipped at zero, or an all-ones image.)
    initial_params = ray_trafo.domain.one()
    odl.solvers.iterative.statistical.mlem(ray_trafo, initial_params, pet_data, 1)

    space = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20], shape=shape)

    # MAPObjective takes (space, pet_data, anat_phantom, shape,
    # regularization_func, ray_trafo, alpha, beta). The regulariser and the
    # forward operator were previously omitted, which shifted alpha and beta
    # into those two slots.
    objective = MAPObjective(
        space, pet_data, anat_phantom, shape,
        anatomical_regularization, ray_trafo, alpha, beta
    )

    initial_params = space.element(initial_params)
    # The solver is called for its effect on `initial_params`, which is then returned.
    odl.solvers.smooth.nonlinear_cg.conjugate_gradient_nonlinear(objective, initial_params, maxiter=15)
    return initial_params

In [12]:
# Evaluate the image's quality
def evaluate_quality(reference, reconstructed):
    """Return ``(psnr, ssim)`` of ``reconstructed`` measured against ``reference``.

    Both inputs are converted to NumPy arrays. The data range handed to the
    metrics is the value range of ``reference`` (max minus min); it falls back
    to 1.0 when the reference is constant, so the range is never zero.
    """
    reference = np.asarray(reference)
    reconstructed = np.asarray(reconstructed)

    # Use the actual range of the reference image. It was computed before but
    # then ignored in favour of a hard-coded 1.0.
    data_range = reference.max() - reference.min()
    if not data_range > 0:
        data_range = 1.0

    psnr = peak_signal_noise_ratio(reference, reconstructed, data_range=data_range)
    ssim = structural_similarity(reference, reconstructed, data_range=data_range)
    return psnr, ssim

In [None]:
# Main
if __name__ == "__main__":
    # Simulate data and reconstruct. `anat_phantom`, `pet_data` and `shape`
    # stay module-level names, as before.
    anat_phantom, emission_phantom, pet_data = generate_shepp_logan_data()
    shape = emission_phantom.shape
    estimated_params = optimize_parameters(pet_data, anat_phantom, shape)

    # Visualization. Kept inside the guard because it depends on the names
    # defined above; previously it ran unconditionally at module level.
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    im = axes[0].imshow(np.asarray(emission_phantom), cmap='gray')
    axes[0].set_title("PET Image (Shepp-Logan Phantom)")
    fig.colorbar(im, ax=axes[0])

    # aspect='auto' lets the non-square sinogram fill the panel.
    im = axes[1].imshow(np.asarray(pet_data), cmap='gray', aspect='auto')
    axes[1].set_title("PET Data")
    fig.colorbar(im, ax=axes[1])

    im = axes[2].imshow(np.asarray(estimated_params), cmap='gray')
    axes[2].set_title("Reconstructed Image")
    fig.colorbar(im, ax=axes[2])

    fig.tight_layout(w_pad=3.0, h_pad=2.0)
    plt.show()

In [None]:
# # Determine whether the function is non-convex
# def hessian(func, x, epsilon=1e-5):
#     """
#      Hessian matrix
#     """
#     x = np.asarray(x)
#     n = x.shape[0]
#     hess = np.zeros((n, n))
#     for i in range(n):
#         for j in range(n):
#             x_ij1 = x.copy()
#             x_ij2 = x.copy()
#             x_ij3 = x.copy()
#             x_ij4 = x.copy()
#             x_ij1[i] += epsilon
#             x_ij1[j] += epsilon
#             x_ij2[i] += epsilon
#             x_ij2[j] -= epsilon
#             x_ij3[i] -= epsilon
#             x_ij3[j] += epsilon
#             x_ij4[i] -= epsilon
#             x_ij4[j] -= epsilon
#             hess[i, j] = (func(x_ij1) - func(x_ij2) - func(x_ij3) + func(x_ij4)) / (4 * epsilon ** 2)
#     return hess

# # Check if the Hessian is positive semidefinite (smallest eigenvalue >= 0)
# def is_hessian_positive_semidefinite(hessian_matrix):
#     # smallest eigenvalue of the (symmetric) Hessian
#     min_eigenvalue = np.min(np.linalg.eigvalsh(hessian_matrix))
#     return min_eigenvalue >= 0
# def wrapped_map_objective(parameters):
#     emission_phantom = parameters.reshape(shape)
#     regularization = anatomical_regularization(parameters, anat_phantom, shape)
#     return map_objective(emission_phantom, pet_data, regularization)

# shape = emission_phantom.shape
# parameters=emission_phantom
# regularization = anatomical_regularization(parameters, anat_phantom, shape)
# objecticve_function=map_objective(emission_phantom, pet_data, regularization)

# # Hessian matrix
# hess = hessian(wrapped_map_objective, parameters)

# is_convex = is_hessian_positive_semidefinite(hess)
# if is_convex:
#     print("objective function is convex（Hessian matrix is positive definite）。")
# else:
#     print("objective function is not convex（Hessian matrix has minus characteristic value）。")