In [None]:
#   Copyright 2025 UKRI-STFC

#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
# Authors:
# Franck Vidal (UKRI-STFC)

# RXSolutionsDataReader Laminography Demo

## Data format: RX Solutions

The data is in the format used by devices made by [RX Solutions](https://www.rx-solutions.com/en). The projections are saved as TIFF files; they are flat-field corrected and stored as 16-bit unsigned integers. Metadata is saved in two different files: an XML file that can be used with orbital geometries, and a CSV file that can be used with flexible geometries.

## CIL Version

This notebook was developed using CIL v25.0.0

## Dataset
The data is available from Zenodo: https://doi.org/10.5281/zenodo.?????? (**TODO**: complete the DOI)

It is a laminography dataset of ??? (**TODO**: describe the sample). 
It was acquired with the ???? platform (**TODO**: name the platform) developed by [RX Solutions](https://www.rx-solutions.com/en) for the [MATEIS Laboratory](https://mateis.insa-lyon.fr/en) of [INSA-Lyon](https://www.insa-lyon.fr/en/).

Update `data_path` below to the directory where you saved the dataset, and adjust the acquisition parameters if you use a different dataset:

In [None]:
import os

# Directory containing the RX Solutions dataset (projections and geometry metadata files)
data_path = "/DATA/CT/2025/DTHE"
# Number of full-resolution slices to reconstruct (divided by scaling_factor when the
# image geometry is built). Use 0 to keep the slice count of the automatically computed image geometry.
number_of_slices_to_reconstruct = 500 # Use 0 to compute it automatically
# Pixel pitch in mm, passed to the reader
pixel_pitch_in_mm = (0.15,0.15)
# Down-sampling step applied to the projections via the reader's ROI
scaling_factor = 3
# Angular range passed to the reader (presumably degrees); here it is given from 360 down to 0
first_angle=360
last_angle=0

# Alternative settings for the "suzanne_circular" dataset; uncomment to use them instead
# data_path = "/DATA/CT/2025/RX_Solutions/suzanne_circular"
# number_of_slices_to_reconstruct = 0 # Use 0 to compute it automatically
# pixel_pitch_in_mm = (0.5,0.5)
# scaling_factor = 3
# first_angle=0
# last_angle=360

# Geometry metadata: the XML file (orbital geometries) or the CSV file (flexible geometries)
file_path = os.path.join(data_path, 'unireconstruction.xml')
# file_path = os.path.join(data_path, 'geometry.csv')

In [None]:
import numpy as np
import gc

from cil.utilities.display import show2D, show_geometry, show_system_positions
from cil.processors import TransmissionAbsorptionConverter, Slicer, CentreOfRotationCorrector
from cil.framework import ImageGeometry
from cil.plugins.astra import FBP
from cil.utilities.jupyter import islicer, link_islicer
from cil.io.TIFF import TIFFWriter

import torch

# Run the torch-based denoiser on the GPU when PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Device:", device)

from cil.optimisation.functions import Function
import deepinv

from readers.RXSolutionsDataReader import RXSolutionsDataReader

# Loading Geometry

In [None]:
# Build the reader's region of interest from `scaling_factor`, which is set in the
# configuration cell (it is deliberately NOT re-assigned here, so changing it in the
# configuration cell takes effect).
# NOTE(review): each axis entry looks like a [start, stop, step] slice spec, i.e. every
# `scaling_factor`-th pixel is kept along axis_1 and axis_2 — confirm against the reader.
if scaling_factor == 1:
    roi = None  # read the projections at full resolution
else:
    roi = {"axis_1": [None, None, scaling_factor], "axis_2": [None, None, scaling_factor]}

In [None]:
# Create the reader for the geometry metadata file selected in the configuration cell
# (`file_path`: the XML file for orbital geometries or the CSV file for flexible ones).
reader = RXSolutionsDataReader(
    file_path,
    pixel_pitch_in_mm=pixel_pitch_in_mm,
    first_angle=first_angle,
    last_angle=last_angle,
    last_angle_included=False,
    roi=roi,
)

In [None]:
# Acquisition geometry as provided by the reader
acq_geom = reader.get_geometry()

In [None]:
# Flexible (CONE_FLEX) geometries are displayed with show_system_positions;
# every other geometry type is drawn with show_geometry.
if acq_geom.geom_type == "CONE_FLEX":
    show_system_positions(acq_geom)
else:
    show_geometry(acq_geom)

In [None]:
# Text summary of the acquisition geometry
print(acq_geom)

# Loading Projections

In [None]:
# Load the projections (using the ROI / down-sampling configured on the reader)
acq_data = reader.read()

In [None]:
# Quick look at the loaded projections; the trailing ';' suppresses the returned object's repr
show2D(acq_data, origin='upper-left');

# Pre-processing

In [None]:
# Convert the flat-field corrected transmission projections to absorption for reconstruction
data_exp = TransmissionAbsorptionConverter()(acq_data)

In [None]:
if acq_geom.geom_type == "CONE_FLEX":
    # No centre-of-rotation search for flexible geometries: use the data as-is.
    data_corr = data_exp
else:
    # Estimate the centre-of-rotation offset with CIL's image-sharpness method on the
    # central slice, reconstructing with the TIGRE backend (the TIGRE plugin must be installed).
    processor = CentreOfRotationCorrector.image_sharpness("centre", "tigre")
    processor.set_input(data_exp)
    data_corr = processor.get_output()

# Reorder the data axes for ASTRA. This acts on the container itself; in the CONE_FLEX
# case data_corr and data_exp are the same object, so data_exp is reordered too.
data_corr.reorder(order='astra')

In [None]:
if acq_geom.geom_type == "CONE_FLEX":
    # Flexible geometry: derive the voxel size from the panel pixel size divided by the
    # mean system magnification, and use one voxel per panel pixel.
    mag = data_corr.geometry.magnification
    mean_mag = np.mean(mag)
    print("Mean magnification: ", mean_mag)

    panel = data_corr.geometry.config.panel
    voxel_size_xy = panel.pixel_size[0] / mean_mag
    voxel_size_z = panel.pixel_size[1] / mean_mag

    num_voxel_xy = int(np.ceil(panel.num_pixels[0]))
    num_voxel_z = int(np.ceil(panel.num_pixels[1]))

    image_geometry = ImageGeometry(num_voxel_xy, num_voxel_xy, num_voxel_z,
                                   voxel_size_xy, voxel_size_xy, voxel_size_z)
else:
    # Orbital geometry: start from CIL's default image geometry and make the voxels
    # isotropic using the smallest of the three voxel sizes.
    image_geometry = data_corr.geometry.get_ImageGeometry()
    smallest_voxel = min(image_geometry.voxel_size_x,
                         image_geometry.voxel_size_y,
                         image_geometry.voxel_size_z)
    image_geometry.voxel_size_x = smallest_voxel
    image_geometry.voxel_size_y = smallest_voxel
    image_geometry.voxel_size_z = smallest_voxel

# Optionally restrict the number of slices; the requested count refers to full
# resolution, so it is divided by the down-sampling factor.
if number_of_slices_to_reconstruct > 0:
    image_geometry.voxel_num_z = number_of_slices_to_reconstruct // scaling_factor

print(image_geometry)

# Reconstruction using FDK

In [None]:
# Reconstruct using FDK
# Instantiate the reconstruction algorithm (CIL's ASTRA FBP plugin, which performs FDK
# for cone-beam geometries)
fdk = FBP(image_geometry, data_corr.geometry)
fdk.set_input(data_corr)

# Perform the actual CT reconstruction
FDK_recon = fdk.get_output()

## Release memory

In [None]:
# The raw projections, the absorption data and the reader are no longer needed;
# data_corr is kept for the iterative reconstructions. In the CONE_FLEX case data_corr
# and data_exp are the same object, so that array is not freed by removing the name.
del data_exp, acq_data, reader

gc.collect();

## Save the reconstruction as a stack of TIFF files

In [None]:
# Save the FDK volume as TIFF files in <data_path>/FDK-recon
# (file names presumably derived from the "slice" prefix — see the TIFFWriter docs)
writer = TIFFWriter(FDK_recon, os.path.join(data_path, "FDK-recon/slice"))
writer.write()

## Visualise the reconstruction

In [None]:
# Interactive slice viewer for the FDK reconstruction
islicer(FDK_recon)

# Reconstruction using TV-regularised least squares solved with FISTA

In [None]:
from cil.plugins.astra import ProjectionOperator
from cil.optimisation.functions import LeastSquares
from cil.plugins.ccpi_regularisation.functions import FGP_TV
from cil.optimisation.algorithms import FISTA

# Data fidelity: least squares between the forward-projected volume and the corrected projections
projector = ProjectionOperator(image_geometry, data_corr.geometry)
LS = LeastSquares(A=projector, b=data_corr)

# Regulariser: total variation (FGP) with a non-negativity constraint.
# Run FGP_TV on the GPU only when one is available (PyTorch's CUDA detection from the
# set-up cell is used as a proxy); hard-coding 'gpu' would fail on CPU-only machines.
alpha = 0.05  # TV regularisation strength
TV = FGP_TV(alpha=alpha, nonnegativity=True, device='gpu' if use_cuda else 'cpu')

# Warm-start from the FDK reconstruction; record the objective every 10 iterations
fista_TV = FISTA(initial=FDK_recon, f=LS, g=TV, update_objective_interval=10)

In [None]:
# Objective values recorded so far (every update_objective_interval iterations).
# NOTE(review): before run() is called this is presumably empty or holds only the initial value.
fista_TV.objective

In [None]:
# Use the FDK intensity range for every display so all reconstructions are directly comparable.
fix_range = (FDK_recon.min(), FDK_recon.max())

number_of_blocks = 4
iterations_per_block = 25

for block_index in range(number_of_blocks):
    fista_TV.run(iterations_per_block, verbose=1)
    # Each run() call continues from the previous one, so after this block
    # (block_index + 1) * iterations_per_block iterations have been performed
    # (on the first execution of this cell).
    completed_iterations = (block_index + 1) * iterations_per_block
    show2D(fista_TV.solution, title='Iteration {}'.format(completed_iterations),
           fix_range=fix_range, origin='upper-left', size=(5,5))

In [None]:
# Keep a handle on the final TV iterate so the solver object itself can be deleted
TV_recon = fista_TV.solution

In [None]:
# Free the TV solver and regulariser; TV_recon still references the final iterate.
del fista_TV, TV

gc.collect();

## Save the reconstruction as a stack of TIFF files

In [None]:
# Save the TV volume as TIFF files in <data_path>/TV-recon
# (file names presumably derived from the "slice" prefix — see the TIFFWriter docs)
writer = TIFFWriter(TV_recon, os.path.join(data_path, "TV-recon/slice"))
writer.write()

## Visualise the reconstruction

In [None]:
# Interactive slice viewer for the TV reconstruction
islicer(TV_recon)

# Compare the two reconstructions

In [None]:
# Compare FDK and TV at the quarter, half and three-quarter positions along the first
# axis of the volume (assumed to be 'vertical', matching the slice_list below).
half_number_of_slices = FDK_recon.shape[0] // 2
comparison_slices = (half_number_of_slices // 2,
                     half_number_of_slices,
                     half_number_of_slices + half_number_of_slices // 2)

for slice_index in comparison_slices:
    show2D([FDK_recon, TV_recon], origin='upper-left', fix_range=fix_range,
           slice_list=('vertical', slice_index))

In [None]:
# Two linked interactive slicers, so FDK and TV can be browsed side by side at the same slice
link_islicer(islicer(FDK_recon), islicer(TV_recon))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Line profiles through the centre row of the quarter, half and three-quarter slices,
# comparing the FDK and TV reconstructions.
row_id = FDK_recon.shape[1] // 2
profile_slices = (half_number_of_slices // 2,
                  half_number_of_slices,
                  half_number_of_slices + half_number_of_slices // 2)

for slice_index in profile_slices:
    plt.plot(FDK_recon.as_array()[slice_index, row_id], label="FDK")
    plt.plot(TV_recon.as_array()[slice_index, row_id], label="TV")
    plt.title("Line profile: slice {}, row {}".format(slice_index, row_id))
    plt.xlabel("Voxel index along the last axis")
    plt.ylabel("Reconstructed value")
    plt.legend()
    plt.show()

In [None]:

class DenoiserProximal(Function):
    """
    CIL Function whose proximal operator is a torch-based denoiser.

    Evaluating the function (__call__) always returns 0. When proximal() is called, the
    input CIL data container is converted into a PyTorch tensor, processed by the
    denoiser using the step size ``tau`` as the noise level, and written back into a
    CIL data container. Used as the ``g`` term of FISTA, this gives a plug-and-play
    reconstruction.

    Parameters:
        denoiser: The torch-based denoiser, called as ``denoiser(tensor, noise_level)``
            on tensors of shape (batch, 1, H, W).
        device: The torch device (e.g. 'cuda' or 'cpu') on which input tensors are
            created. This class does not move the denoiser, so the denoiser must
            already live on this device.
    """

    def __init__(self, denoiser, device='cuda'):
        self.device = torch.device(device)
        self.denoiser = denoiser
        super(DenoiserProximal, self).__init__()

    def __call__(self, x):
        # The denoiser has no explicit energy, so the function value is reported as 0.
        return 0

    def cil_to_torch(self, x):
        """
        Convert a CIL data container to a 4D PyTorch tensor of shape (batch, 1, H, W).

        A 2D image becomes a batch of one image; a 3D volume becomes a batch of its
        slices along the first axis, each with a single channel. This also handles a
        volume with a single slice, which must keep its slice axis as the batch axis.
        The data is copied (torch.tensor always copies) onto ``self.device``, so the
        result never aliases the container's memory.

        Parameters:
            x: A CIL data container with an 'array' attribute containing the data.

        Returns:
            torch.Tensor: A (batch, 1, H, W) tensor ready to be passed to the denoiser.

        Raises:
            ValueError: if the data has fewer than 2 dimensions.
        """
        tensor = torch.tensor(x.array, device=self.device)
        if tensor.ndim < 2:
            raise ValueError(
                "DenoiserProximal expects data with at least 2 dimensions, got {}".format(tensor.ndim))
        if tensor.ndim == 2:
            # Single image -> (1, 1, H, W)
            return tensor.reshape(1, 1, *tensor.shape)
        # Volume -> (slices, 1, H, W); leading axes are flattened into the batch axis.
        return tensor.reshape(-1, 1, *tensor.shape[-2:])

    def torch_to_cil(self, x_tens, out):
        """
        Copy a PyTorch tensor produced by the denoiser back into a CIL data container.

        This reverses the layout introduced by cil_to_torch by reshaping the tensor to
        the shape of ``out.array``. The reshape follows the tensor's logical (C) order,
        so it is also correct for non-contiguous (e.g. permuted) tensors.

        Parameters:
            x_tens (torch.Tensor): The processed tensor from the denoiser.
            out: A pre-allocated CIL data container that receives the result in place.
        """
        out.array[:] = x_tens.detach().cpu().numpy().reshape(out.array.shape)

    def proximal(self, x, tau, out=None):
        """
        Apply the proximal operator via the torch-based denoiser to a CIL data container.

        The input is converted with cil_to_torch, passed to the denoiser together with
        the noise level ``tau``, and written back with torch_to_cil. If no output
        container is provided (out is None), a new one is allocated from the geometry
        of x.

        Parameters:
            x: The input CIL data container to be processed.
            tau (float): A scalar noise level passed to the denoiser.
            out: (Optional) A pre-allocated CIL data container for the result. It may be
                 ``x`` itself, because the input is copied before denoising.

        Returns:
            A CIL data container containing the denoised data.
        """
        if out is None:
            out = x.geometry.allocate(None)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            x_torch = self.cil_to_torch(x)
            x_torch = self.denoiser(x_torch, tau)
            self.torch_to_cil(x_torch, out)
        return out
        


In [None]:
class DenoiserProximal3D(DenoiserProximal):
    """
    Variant of DenoiserProximal that denoises a 3D volume along three plane orientations.

    Starting from the (slices, 1, rows, columns) layout that cil_to_torch produces for a
    volume, the 2D denoiser is applied three times:
      - to the tensor as-is,
      - after swapping tensor axes 0 and 2,
      - after swapping tensor axes 0 and 3.
    Each of the three orthogonal families of 2D planes is therefore denoised once. Every
    pass uses the same noise level ``tau``.

    Note: the default device is 'cuda:1' (CUDA device index 1). Pass ``device``
    explicitly on machines with a single GPU.
    """

    def __init__(self, denoiser, device='cuda:1'):
        super().__init__(denoiser, device)

    def proximal(self, x, tau, out=None):
        if out is None:
            out = x.geometry.allocate(None)

        # Each permutation swaps the batch axis with one spatial axis, so it is its own
        # inverse: applying it a second time restores the original layout.
        axis_swaps = ((2, 1, 0, 3), (3, 1, 2, 0))

        with torch.no_grad():
            volume = self.denoiser(self.cil_to_torch(x), tau)
            for swap in axis_swaps:
                volume = self.denoiser(volume.permute(*swap), tau).permute(*swap)
            self.torch_to_cil(volume, out)
        return out


In [None]:
# Single-channel DRUNet from deepinv with pretrained weights (downloaded), placed on `device`
denoiser = deepinv.models.DRUNet(in_channels=1, out_channels=1, pretrained='download', device=device)

In [None]:
# Plug-and-play FISTA: least-squares data fidelity with the 3D DRUNet denoiser as the
# proximal step of g, warm-started from the FDK reconstruction.
x0 = FDK_recon
# Regularisation weight. Via CIL's scaled function, the denoiser presumably receives
# tau * lamb as its noise level.
lamb = 500
Regulariser = lamb * DenoiserProximal3D(denoiser, device)
FISTA_DRUNet3D = FISTA(initial=x0, f=LS, g=Regulariser, update_objective_interval=10)

In [None]:
number_of_blocks = 4
iterations_per_block = 25

for block_index in range(number_of_blocks):
    FISTA_DRUNet3D.run(iterations_per_block, verbose=1)
    # Each run() call continues from the previous one, so after this block
    # (block_index + 1) * iterations_per_block iterations have been performed
    # (on the first execution of this cell).
    completed_iterations = (block_index + 1) * iterations_per_block
    show2D(FISTA_DRUNet3D.solution, title='Iteration {}'.format(completed_iterations),
           fix_range=fix_range, origin='upper-left', size=(5,5))

In [None]:
# Final iterate of the plug-and-play (DRUNet) reconstruction
DRUNet3D_recon = FISTA_DRUNet3D.solution

In [None]:
# Line profiles through the centre row of the quarter, half and three-quarter slices,
# comparing the FDK, TV and DRUNet3D reconstructions.
row_id = FDK_recon.shape[1] // 2
profile_slices = (half_number_of_slices // 2,
                  half_number_of_slices,
                  half_number_of_slices + half_number_of_slices // 2)

for slice_index in profile_slices:
    plt.plot(FDK_recon.as_array()[slice_index, row_id], label="FDK")
    plt.plot(TV_recon.as_array()[slice_index, row_id], label="TV")
    plt.plot(DRUNet3D_recon.as_array()[slice_index, row_id], label="DRUNet3D")
    plt.title("Line profile: slice {}, row {}".format(slice_index, row_id))
    plt.xlabel("Voxel index along the last axis")
    plt.ylabel("Reconstructed value")
    plt.legend()
    plt.show()

In [None]:
# Side-by-side comparison of FDK, TV and DRUNet3D at the quarter, half and
# three-quarter positions along the 'vertical' axis.
comparison_slices = (half_number_of_slices // 2,
                     half_number_of_slices,
                     half_number_of_slices + half_number_of_slices // 2)

for slice_index in comparison_slices:
    show2D([FDK_recon, TV_recon, DRUNet3D_recon], origin='upper-left', fix_range=fix_range,
           size=(5,5), slice_list=('vertical', slice_index))