In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import radon, iradon, resize
from scipy.ndimage import rotate
from skimage.draw import disk
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def EST_Angles(num_projections):
    """
    Return projection angles (in degrees) whose sines are evenly spaced.

    sin(theta) runs linearly from -1 to 1, so the angles span [-90, 90]
    degrees with both endpoints included. They are packed most densely
    around 0 degrees and most sparsely near +/-90 degrees.

    NOTE(review): these are arcsin-spaced angles, not the tan-based
    "equally sloped" angles of pseudo-polar EST.

    Parameters:
    - num_projections: int, how many angles to produce.

    Returns:
    - theta: 1D array of angles in degrees, in increasing order.
    """
    evenly_spaced_sines = np.linspace(-1, 1, num_projections)
    angles_rad = np.arcsin(evenly_spaced_sines)
    return angles_rad * 180 / np.pi

def make_phantom(size, num_disks, rng=None):
    """
    Create a foam-like phantom made of overlapping random disks.

    Each disk gets:
    - a random integer centre anywhere in the image (disks may be cut off
      at the border);
    - a random integer radius;
    - a random intensity in [0, 1).

    Overlapping intensities add up, and the final image is clipped to [0, 1].

    Parameters:
    - size: int, height and width of the square phantom image; must be > 0.
    - num_disks: int, number of disks to draw.
    - rng: optional numpy Generator or integer seed. Pass one to make the
      phantom reproducible. The default (None) uses a fresh unseeded
      generator, as before.

    Returns:
    - phantom: 2D float array of shape (size, size) with values in [0, 1].

    Raises:
    - ValueError: if size is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")

    # default_rng passes an existing Generator through unchanged and seeds a
    # new one from an int / None.
    rng = np.random.default_rng(rng)

    # Radii are drawn from [5, size // 10). For size < 60 that range is empty
    # and rng.integers would raise, so guarantee at least radius 5.
    min_radius = 5
    max_radius_exclusive = max(size // 10, min_radius + 1)

    phantom = np.zeros((size, size))
    for _ in range(num_disks):
        row, col = rng.integers(0, size, size=2)
        radius = rng.integers(min_radius, max_radius_exclusive)
        intensity = rng.random()

        rr, cc = disk((row, col), radius, shape=phantom.shape)
        phantom[rr, cc] += intensity

    # Every increment is non-negative, so clipping once here gives exactly the
    # same image as clipping after each disk, without a full-image pass per disk.
    return np.clip(phantom, 0, 1)

def make_sino(phantom, theta):
    """
    Forward-project a phantom into a sinogram with one row per angle.

    Parameters:
    - phantom: 2D array, the image to project.
    - theta: 1D array, projection angles in degrees.

    Returns:
    - sinogram: 2D array with len(theta) rows.
      - It is skimage's radon output (detectors x angles) rotated 90 degrees
        counter-clockwise with np.rot90.
      - As a result, the rows are in REVERSE order of theta: row 0 holds
        theta[-1].
      - The detector axis keeps radon's ordering.
      - Use np.rot90(sinogram, -1) to get back radon's layout, e.g. before
        calling iradon.
    """
    # circle=True: skimage treats the object as lying inside the image's
    # inscribed circle.
    radon_columns = radon(phantom, theta=theta, circle=True)
    return np.rot90(radon_columns, k=1)

def downsample_sino(sinogram, target_num_projections):
    """
    Resample a sinogram to a new number of rows (projections).

    The detector axis (axis 1) is left at its original length. Only axis 0
    is interpolated, with anti-aliasing enabled.

    NOTE(review): this is plain image interpolation. It treats rows as
    evenly spaced samples, which arcsin-spaced EST_Angles rows are not.

    Parameters:
    - sinogram: 2D array, the sinogram to resample.
    - target_num_projections: int, number of rows wanted in the result.

    Returns:
    - downsampled_sino: 2D array of shape
      (target_num_projections, sinogram.shape[1]).
    """
    num_detectors = sinogram.shape[1]
    target_shape = (target_num_projections, num_detectors)
    return resize(sinogram, target_shape, anti_aliasing=True)

def _double_conv(x, filters):
    """Apply two 3x3 'same'-padded ReLU convolutions (one U-Net stage)."""
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)


def _up_block(x, skip, filters):
    """Upsample x by 2, concatenate the encoder skip tensor, then refine it."""
    x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = layers.concatenate([x, skip])
    return _double_conv(x, filters)


def unet_SRM(input_shape, output_activation=None, learning_rate=0.0001):
    """
    Build and compile a U-Net for sinogram-to-sinogram regression.

    Architecture:
    - three encoder stages, each two 3x3 ReLU convs followed by a 2x2
      max-pool, with 64/128/256 filters;
    - a 512-filter bottleneck;
    - three decoder stages, each a 2x2 stride-2 transposed conv, a
      concatenation with the matching encoder output, and two 3x3 ReLU
      convs;
    - a final 1x1 single-channel conv.

    The output has the same height and width as the input.

    Parameters:
    - input_shape: tuple (height, width, channels). Height and width must be
      multiples of 8 (or None), so the three poolings and the three
      upsamplings line up with the skip connections.
    - output_activation: activation of the final 1x1 conv. Defaults to None,
      i.e. linear output. The previous hard-coded 'sigmoid' capped
      predictions at 1, but sinogram values are line integrals far above 1,
      and training stalled. Pass 'sigmoid' explicitly to get the old head,
      e.g. for data scaled to [0, 1].
    - learning_rate: Adam learning rate (default 1e-4, unchanged).

    Returns:
    - model: a compiled keras Model with mean-squared-error loss.

    Raises:
    - ValueError: if height or width is not a multiple of 8.
    """
    for dim_name, dim in (("height", input_shape[0]), ("width", input_shape[1])):
        if dim is not None and dim % 8 != 0:
            raise ValueError(
                f"unet_SRM needs {dim_name} divisible by 8 (3 pooling levels), got {dim}"
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = _double_conv(inputs, 64)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    c2 = _double_conv(p1, 128)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    c3 = _double_conv(p2, 256)
    p3 = layers.MaxPooling2D((2, 2))(c3)

    # Bottleneck
    c4 = _double_conv(p3, 512)

    # Decoder with skip connections back to the encoder stages
    c5 = _up_block(c4, c3, 256)
    c6 = _up_block(c5, c2, 128)
    c7 = _up_block(c6, c1, 64)

    outputs = layers.Conv2D(1, (1, 1), activation=output_activation)(c7)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='mean_squared_error')
    return model

def augment_data(low_res_sinos, high_res_sinos, num_augmentations=5, seed=None):
    """
    Augment paired training sinograms, applying the SAME random transform to
    both members of each pair.

    Previously the low-res and high-res images were drawn from two
    independently seeded generators. Each input therefore got a different
    rotation/shift/flip than its target, which corrupts a paired regression
    dataset. Both flows now share one seed per augmentation.

    NOTE(review): rotations and shifts of a sinogram do not correspond to a
    physical acquisition; confirm they actually help.

    Parameters:
    - low_res_sinos: iterable of 2D low-resolution sinograms.
    - high_res_sinos: iterable of 2D target sinograms, paired with
      low_res_sinos by position.
    - num_augmentations: int, augmented copies per pair (default 5, as before).
    - seed: optional int. Makes the whole augmentation reproducible; None
      draws fresh randomness.

    Returns:
    - augmented_low_res_sinos: array of augmented low-resolution sinograms.
    - augmented_high_res_sinos: array of the matching augmented targets.
      Both arrays are ordered pair by pair, num_augmentations entries per pair.
    """
    datagen = ImageDataGenerator(rotation_range=10, width_shift_range=0.1,
                                 height_shift_range=0.1, horizontal_flip=True)
    # One master RNG hands out a fresh seed for each augmentation; the pair
    # shares it, so both images receive identical transform parameters.
    seed_rng = np.random.default_rng(seed)

    augmented_low_res_sinos = []
    augmented_high_res_sinos = []

    for low_res, high_res in zip(low_res_sinos, high_res_sinos):
        # flow() takes a (batch, rows, cols, channels) array.
        low_batch = np.asarray(low_res)[np.newaxis, ..., np.newaxis]
        high_batch = np.asarray(high_res)[np.newaxis, ..., np.newaxis]

        for _ in range(num_augmentations):
            pair_seed = int(seed_rng.integers(0, 2**31 - 1))
            # NOTE(review): Keras' seeded iterators appear to reseed NumPy's
            # global np.random state — confirm if other code relies on it.
            augmented_low = next(datagen.flow(low_batch, batch_size=1, seed=pair_seed))
            augmented_high = next(datagen.flow(high_batch, batch_size=1, seed=pair_seed))
            augmented_low_res_sinos.append(augmented_low[0, ..., 0])
            augmented_high_res_sinos.append(augmented_high[0, ..., 0])

    return np.array(augmented_low_res_sinos), np.array(augmented_high_res_sinos)

def train_SRM(model, low_res_sinos, high_res_sinos, epochs=500, batch_size=32,
              validation_split=0.1):
    """
    Train the super-resolution model and return its training history.

    Parameters:
    - model: a compiled Keras model, e.g. one built by unet_SRM.
    - low_res_sinos: sequence/array of 2D input sinograms, all the same shape.
    - high_res_sinos: sequence/array of 2D target sinograms, paired by position.
    - epochs: int, number of epochs for training.
    - batch_size: int, batch size for training.
    - validation_split: fraction of samples held out for validation (default
      0.1, as before). Keras takes these from the END of the data, before
      shuffling.

    Returns:
    - history: the object returned by model.fit. It was previously
      discarded, so loss curves (e.g. a stalled val_loss) could not be
      inspected.
    """
    # Append the single channel axis the network's Input expects.
    inputs = np.asarray(low_res_sinos)[..., np.newaxis]
    targets = np.asarray(high_res_sinos)[..., np.newaxis]
    return model.fit(inputs, targets, epochs=epochs, batch_size=batch_size,
                     validation_split=validation_split)

def apply_SRM(model, sinogram):
    """
    Run a trained super-resolution network on a single 2D sinogram.

    Parameters:
    - model: trained Keras model, e.g. one built by unet_SRM.
    - sinogram: 2D array to enhance. NOTE(review): it should have the same
      (rows, cols) layout the model was trained on — confirm at call sites.

    Returns:
    - super_res_sino: 2D array, the network output with the batch and
      channel axes stripped.
    """
    # (H, W) -> (1, H, W, 1): a batch containing one single-channel image.
    batch = np.expand_dims(np.expand_dims(sinogram, axis=0), axis=-1)
    prediction = model.predict(batch)
    return prediction[0, ..., 0]

def _forward_project(image, rotation_angles, num_detectors):
    """Rotate image by each scipy angle and sum down the columns: one row per angle."""
    projections = np.zeros((len(rotation_angles), num_detectors))
    for row, angle in enumerate(rotation_angles):
        column_sums = np.sum(rotate(image, angle, reshape=False, order=1), axis=0)
        # Crop to the detector count, or leave the zero padding if shorter.
        n = min(column_sums.shape[0], num_detectors)
        projections[row, :n] = column_sums[:n]
    return projections


def _back_project(error, rotation_angles, image_shape, num_detectors):
    """Smear each residual row across the image, undo its rotation, and sum (unfiltered)."""
    backprojected = np.zeros(image_shape)
    for row, angle in enumerate(rotation_angles):
        smeared = np.zeros(image_shape)
        # The 1D residual row is broadcast down every image row.
        smeared[:, :num_detectors] = error[row, :]
        backprojected += rotate(smeared, -angle, reshape=False, order=1)
    return backprojected


def EST_reco(sinogram, theta, num_iterations=300, tolerance=1e-5, learning_rate=0.0001,
             verbose=True):
    """
    Refine an FBP reconstruction with a projection/backprojection gradient
    iteration (Landweber/SIRT-style).

    NOTE: despite the name, this is not the pseudo-polar-FFT "Equally Sloped
    Tomography" algorithm.

    Procedure:
    1. Start from the filtered backprojection (skimage iradon).
    2. Each iteration:
       - forward-project the estimate with scipy rotations;
       - subtract the result from the measured sinogram;
       - backproject the residual without a filter;
       - rescale that backprojection to a peak magnitude of 1;
       - add learning_rate / num_projections times it;
       - clamp negative pixels to 0.
    3. Stop early when the maximum absolute residual changes by less than
       `tolerance` between iterations.

    Parameters:
    - sinogram: 2D array in make_sino's layout,
      (num_projections, num_detectors), with rows in REVERSE order of theta.
    - theta: 1D array, the projection angles in degrees (same array given to
      make_sino).
    - num_iterations: int, maximum number of iterations.
    - tolerance: float, early-stop threshold on the change of the max residual.
    - learning_rate: float, step size for the updates.
    - verbose: bool, print the residual every iteration (default True, as before).

    Returns:
    - reconstructed_image: 2D array, the reconstructed image.

    NOTE(review): because the backprojection is normalised to a peak of 1,
    the gradient step changes any pixel by at most learning_rate /
    num_projections per iteration. With the defaults, results therefore stay
    close to the (non-negativity-clamped) FBP starting point.
    """
    num_projections, num_detectors = sinogram.shape
    theta = np.asarray(theta, dtype=float)

    # Angle convention:
    # - make_sino stores rot90(radon(...)), so row i holds the projection at
    #   theta[-1 - i].
    # - skimage's radon at angle a matches summing axis 0 after
    #   scipy.ndimage.rotate(image, -a), up to a sub-pixel rotation-centre
    #   difference. Check at a = 90: both give the image's row sums reversed.
    # - So row i needs the scipy angle -theta[-1 - i].
    # The old code used theta[i]. That is only correct for angle sets
    # symmetric about 0 (such as EST_Angles), for which the two agree up to
    # round-off.
    rotation_angles = -theta[::-1]

    # Initial guess: filtered backprojection of the de-rotated sinogram.
    estimated_image = iradon(np.rot90(sinogram, -1), theta=theta, circle=True)

    previous_error = np.inf
    for iteration in range(num_iterations):
        projections = _forward_project(estimated_image, rotation_angles, num_detectors)

        error = sinogram - projections
        max_error = np.max(np.abs(error))
        if verbose:
            print(f"Iteration {iteration + 1}/{num_iterations} - Max error: {max_error}")

        # Early stopping if error change is below the tolerance
        if np.abs(previous_error - max_error) < tolerance:
            if verbose:
                print(f"Convergence reached at iteration {iteration + 1}")
            break
        previous_error = max_error

        backprojected_error = _back_project(error, rotation_angles,
                                            estimated_image.shape, num_detectors)

        # Normalise so the update size is governed by learning_rate alone.
        peak = np.max(np.abs(backprojected_error))
        if peak != 0:
            backprojected_error /= peak

        estimated_image += learning_rate * (backprojected_error / num_projections)
        estimated_image = np.maximum(estimated_image, 0)  # non-negativity constraint

    return estimated_image

# Example usage
if __name__ == "__main__":
    # Generate multiple foam-like phantoms and their corresponding sinograms
    num_samples = 50  # phantoms used for training (before augmentation)
    size = 128
    num_disks = 30

    # radon(..., circle=True) on a size x size phantom gives `size` detector
    # bins per projection, so the detector count is tied to the phantom size.
    num_detectors = size
    low_res_num_projections = 64
    high_res_num_projections = 256

    # The angle sets do not depend on the phantom: compute them once.
    low_res_theta = EST_Angles(low_res_num_projections)
    high_res_theta = EST_Angles(high_res_num_projections)

    low_res_sinos = []
    high_res_sinos = []
    for _ in range(num_samples):
        phantom = make_phantom(size, num_disks)
        low_res_sinos.append(make_sino(phantom, low_res_theta))
        high_res_sino = make_sino(phantom, high_res_theta)
        high_res_sinos.append(downsample_sino(high_res_sino, low_res_num_projections))

    # Sinogram values are line integrals far above 1 (the logged MSE was ~500).
    # A bounded (sigmoid) output layer can never match them. Scale all
    # sinograms by one shared factor into [0, 1], and undo the scaling on the
    # prediction before reconstruction.
    sino_scale = max(float(np.max(s)) for s in low_res_sinos + high_res_sinos)
    if sino_scale <= 0:
        sino_scale = 1.0
    low_res_sinos = [s / sino_scale for s in low_res_sinos]
    high_res_sinos = [s / sino_scale for s in high_res_sinos]

    # Augment the training data
    augmented_low_res_sinos, augmented_high_res_sinos = augment_data(low_res_sinos, high_res_sinos)

    # Create and train the super-resolution model
    input_shape = (low_res_num_projections, num_detectors, 1)
    super_res_model = unet_SRM(input_shape)
    train_SRM(super_res_model, augmented_low_res_sinos, augmented_high_res_sinos, epochs=500, batch_size=32)

    # Apply super-resolution to a new low-resolution sinogram
    test_phantom = make_phantom(size, num_disks)
    test_low_res_sino = make_sino(test_phantom, low_res_theta)

    # The network was trained on sinograms in make_sino's (projections,
    # detectors) layout, so feed it that same layout. The previous transpose
    # handed it a (detectors, projections) array it had never seen.
    enhanced_sino = apply_SRM(super_res_model, test_low_res_sino / sino_scale) * sino_scale

    # The U-Net preserves its input shape, so build an angle set matching the
    # number of rows actually produced.
    enhanced_theta = EST_Angles(enhanced_sino.shape[0])

    # Perform EST reconstruction on the enhanced sinogram
    reconstructed_image_est = EST_reco(enhanced_sino, enhanced_theta, num_iterations=300, learning_rate=0.0001)

    # Perform FBP reconstruction on the enhanced sinogram
    reconstructed_image_fbp_enhanced = iradon(np.rot90(enhanced_sino, -1), theta=enhanced_theta, circle=True)

    # Perform FBP reconstruction on the original low-resolution sinogram
    reconstructed_image_fbp_low_res = iradon(np.rot90(test_low_res_sino, -1), theta=low_res_theta, circle=True)

    # Plot results. imshow's vmin/vmax replace the never-imported Normalize,
    # which raised a NameError here.
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Reconstruction Results', fontsize=20)

    panels = [
        (axes[0, 0], test_phantom, "Foam-like Phantom", {'vmin': 0, 'vmax': 1}),
        (axes[0, 1], test_low_res_sino, "Low-Res Sinogram",
         {'aspect': 'auto', 'vmin': np.min(test_low_res_sino), 'vmax': np.max(test_low_res_sino)}),
        (axes[0, 2], enhanced_sino, "Enhanced Sinogram",
         {'aspect': 'auto', 'vmin': np.min(enhanced_sino), 'vmax': np.max(enhanced_sino)}),
        (axes[1, 0], reconstructed_image_fbp_low_res, "FBP Recon (Low-Res)", {'vmin': 0, 'vmax': 1}),
        (axes[1, 1], reconstructed_image_est, "EST Recon (Enhanced)", {'vmin': 0, 'vmax': 1}),
        (axes[1, 2], reconstructed_image_fbp_enhanced, "FBP Recon (Enhanced)", {'vmin': 0, 'vmax': 1}),
    ]
    for ax, image, title, imshow_kwargs in panels:
        ax.imshow(image, cmap='gray', **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

2024-06-20 13:07:37.427906: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-20 13:07:37.432542: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-20 13:07:37.488965: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  warn(
2024-06-20 13:08:00.159544: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices..

Epoch 1/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 2s/step - loss: 491.1054 - val_loss: 512.6760
Epoch 2/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 467.3536 - val_loss: 511.0824
Epoch 3/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 470.6678 - val_loss: 511.0820
Epoch 4/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 478.6335 - val_loss: 511.0820
Epoch 5/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 465.3578 - val_loss: 511.0820
Epoch 6/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 476.0562 - val_loss: 511.0820
Epoch 7/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 475.9238 - val_loss: 511.0820
Epoch 8/500
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 2s/step - loss: 469.6108 - val_loss: 511.0820
Epoch 9/500
[1m8/8[0m [32m━━━