## Sirius/LNLS - Scientific Computing Group
## Tomography pre-processing pipeline

This notebook shows an example of how to pre-process phase-contrast images before tomographic reconstruction.
The sinogram alignment functions are called from the **sscRaft** package.

## Imports

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import json, h5py

import sscCdi
print(f'sscCdi version: {sscCdi.__version__}')

import sscRaft
print(f'sscRaft version: {sscRaft.__version__}')

# Default colormap for every image shown in this notebook (change it here)
plt.rcParams.update({'image.cmap': 'cividis'})
# Parameter dictionary, filled section by section below and passed to sscCdi/sscRaft calls
dic = dict()

## Load data

In [None]:
# Input frames and their acquisition angles
obj = np.load('../data/tomo/' + 'data.npy')
angles = np.load('../data/tomo/' + 'angles.npy')  # 2-D table; column 1 holds the angles in degrees

In [None]:
# Quick look at magnitude and phase along axis 0
sscCdi.visualize_magnitude_and_phase(
    obj,
    axis=0,
    cmap='gray',
    aspect_ratio='',
)

In [None]:
# Interactive viewer of the loaded data.
# NOTE(review): an earlier comment listed the accepted `type` values as amplitude, phase,
# real or imag, yet 'abs' is passed here; confirm which spellings deploy_visualizer accepts.
sscCdi.deploy_visualizer(
    obj,
    type='abs',
    title='Original sinogram',
    cmap='gray',
)

In [None]:
# Frames to discard (presumably indices along axis 0 of `obj`, matching rows of `angles`)
dic["bad_frames"] = []
dic['sinogram_filepath'] = "" # path to save new sinogram ("" presumably disables saving; confirm)

# Pass the configured path instead of a duplicated literal so the setting above takes effect.
# NOTE: `obj` and `angles` are overwritten in place; re-running this cell with a non-empty
# bad_frames list would remove frames from the already-filtered arrays.
obj, angles = sscCdi.remove_frames_from_sinogram(obj,angles,dic["bad_frames"],ordered_object_filepath=dic['sinogram_filepath'])

## Crop data

In [None]:
# Number of pixels trimmed from each border of every frame
dic.update({
    "top_crop": 2,     # n of pixels to remove on top
    "bottom_crop": 2,  # n of pixels to remove on bottom
    "left_crop": 2,    # n of pixels to remove on left
    "right_crop": 2,   # n of pixels to remove on right
})

# crop_volume takes the four margins positionally, in top/bottom/left/right order
cropped_data = sscCdi.processing.crop_volume(
    obj,
    *(dic[side] for side in ("top_crop", "bottom_crop", "left_crop", "right_crop")),
    cropped_savepath='',
    crop_mode=0,
)

## Sort data

In [None]:
# Reorder frames together with their angles (presumably ascending angle); empty paths skip saving, TODO confirm.
# Later steps must use `sorted_angles`, not `angles`, to stay consistent with `sorted_data`.
sorted_data, sorted_angles = sscCdi.sort_sinogram_by_angle(
    cropped_data,
    angles,
    object_savepath='',
    angles_savepath='',
)

## Alignment Cross Correlation (CC) 

In [None]:
dic.update({
    # If True, aligns the variance (derivative) of the images. Only use True if phase wrapping is present!
    "CC_correlate_via_variance_field": True,
    # If True, returns images containing only the ROI common to all slices
    "CC_return_common_valid_region": True,
    # If True, removes the null borders that correspond to the biggest shift in each direction
    "CC_remove_null_borders": True,
    # Downsampling applied before alignment. Recommended is 4 (1 presumably means no downsampling).
    "CC_downscaling_factor": 1,
    # Reciprocal-space upsampling for sub-pixel alignment. Recommended is 10.
    # See: https://opg.optica.org/ol/abstract.cfm?uri=ol-33-2-156
    "CC_fft_upsampling": 10,
})

# Cross-correlation shifts between neighbouring frames and their accumulated total
neighbor_shifts, total_shift = sscRaft.alignment.get_shifts_of_local_variance_parallel(
    sorted_data,
    fft_upsampling=dic["CC_fft_upsampling"],
    downscaling_factor=dic["CC_downscaling_factor"],
    use_gradient=dic["CC_correlate_via_variance_field"],
    plot=True,
)

In [None]:
# Apply the accumulated CC shifts and crop according to the settings above
aligned_data_CC = sscRaft.alignment.shift_and_crop_volume(
    sorted_data,
    total_shift,
    return_common_valid_region=dic["CC_return_common_valid_region"],
    remove_null_borders=dic["CC_remove_null_borders"],
)

In [None]:
# Phase of the CC-aligned frames; change `axis` to browse another direction
sscCdi.misc.deploy_visualizer(
    aligned_data_CC,
    type='phase',
    title='CC',
    cmap='gray',
    axis=0,
)

## Alignment Vertical Mass Fluctuation (VMF)

In [None]:
dic["VMF_vertical_region"] = (0,100)       # (top, bottom). Select a region with good data (no phase wrapping is better) avoiding the borders
dic["VMF_use_phase_gradient"] = False         # if True, will align over the gradient of the phase images, making it indifferent to phase ramps
dic["VMF_filter_sigma"] = 0                  # sigma of a Gaussian filter applied to the curves prior to alignment. May help to reduce the influence of noise and fine features when overlapping curves
dic["VMF_return_common_valid_region"] = True # if True, will return images containing only common ROI for all slices
dic["VMF_remove_null_borders"] = True        # if True, removes the null borders of the image which represent the biggest shift in each direction
dic["VMF_plot"] = None                       # if True, shows VMF curves of the VMF_vertical_region before and after alignment (None presumably acts like False; confirm)

# The VMF shifts are stored in `total_shift_VMF` rather than `total_shift`, so the CC shifts that
# shift_and_crop_volume consumes in the CC section are not clobbered if that cell is re-run.
# If data is not equalized, the phase gradient should be used.
aligned_data_VMF, curves, total_shift_VMF = sscRaft.alignment_vertical_mass_fluctuation(
    aligned_data_CC,
    curve_portion=dic["VMF_vertical_region"],
    use_phase_gradient=dic["VMF_use_phase_gradient"],
    filter_sigma=dic["VMF_filter_sigma"],
    return_common_valid_region=dic["VMF_return_common_valid_region"],
    remove_null_borders=dic["VMF_remove_null_borders"],
    plot=dic["VMF_plot"],
)

In [None]:
# Phase of the frames after the VMF alignment step
sscCdi.misc.deploy_visualizer(
    aligned_data_VMF,
    type='phase',
    title='VMF',
    cmap='gray',
    axis=0,
)

#### Visualize vertical mass fluctuation (VMF) curves before and after VMF alignment

In [None]:
# VMF curves (phase gradient, no filtering, presumably the full height since curve_portion=None)
# computed on the data before (CC only) and after the VMF alignment, for visual comparison.
curves_CC = sscRaft.get_VMF_curves(aligned_data_CC,use_phase_gradient=True,filter_sigma=0,curve_portion=None)
curves_VMF = sscRaft.get_VMF_curves(aligned_data_VMF,use_phase_gradient=True,filter_sigma=0,curve_portion=None)

fig, ax = plt.subplots(1,2)
ax[0].imshow(curves_CC.T)
ax[0].set_title('After CC')
ax[1].imshow(curves_VMF.T)
ax[1].set_title('After VMF')
fig.suptitle('Vertical mass fluctuation curves')
plt.show()  # avoid echoing the AxesImage repr as cell output

## Unwrap

In [None]:
# Wrapped phase of the VMF-aligned frames (np.angle of the presumably complex values)
data_to_unwrap = np.angle(aligned_data_VMF)

# An empty save path presumably disables writing to disk; TODO confirm
unwrapped_sinogram = sscCdi.processing.unwrap_sinogram(
    data_to_unwrap,
    unwrapped_savepath="",
)

In [None]:
# The unwrapped phase is real-valued, hence type='real'; title describes what is actually shown
sscCdi.misc.deploy_visualizer(unwrapped_sinogram,type='real',title='Unwrapped sinogram',cmap='gray',axis=0)

## 2D Equalization / Phase-ramp removal

In [None]:
dic["CPUs"] = 32  # CPU count; presumably read from `dic` by equalize_sinogram

dic.update({
    # Invert phase shift signal from negative to positive
    "equalize_invert": True,
    # Region of interest of the null region around the sample, used for phase-ramp and offset corrections
    # (element order presumably [row_start, row_end, col_start, col_end]; confirm)
    "equalize_ROI": [0,10,0,10],
    # If True, removes a fitted phase ramp. Per the original note: with equalize_ROI = [],
    # the best plane fit is subtracted using the whole image.
    "equalize_remove_phase_gradient": True,
    # Number of times the gradient fitting is performed
    "equalize_remove_phase_gradient_iterations": 5,
    # Remove offset of each frame from the mean of the ROI
    "equalize_local_offset": True,
    # [minimum, maximum] threshold values for the whole volume; [] leaves values untouched
    "equalize_set_min_max": [],
    # Turn any remaining negative values to zero
    "equalize_non_negative": False,
})

In [None]:
# Apply the equalization settings stored in `dic`; save=False presumably keeps the result in memory only
equalized_sinogram = sscCdi.processing.equalize_sinogram(
    dic,
    unwrapped_sinogram,
    save=False,
)

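### Equalization (Alternative Method #1)

This subsection overwrites the `equalized_sinogram` computed above. Skip it to keep the result of `equalize_sinogram`.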

In [None]:
# Columns treated as background (no sample) for the optimization-based equalization below.
# Adjust them to your field of view; if the frames are narrower than MASK_RIGHT_START,
# the right-hand background band is empty.
MASK_LEFT_END = 20      # columns [0, MASK_LEFT_END) are background
MASK_RIGHT_START = 200  # columns [MASK_RIGHT_START, end) are background

# Magnitude of the sum over axis 0, used only to check visually where the sample is
projection = np.abs(unwrapped_sinogram.sum(0))
mask = np.zeros_like(projection)

mask[:, :MASK_LEFT_END] = 1
mask[:, MASK_RIGHT_START:] = 1

fig, ax = plt.subplots(1,3)
ax[0].imshow(projection)
ax[0].set_title('|sum over axis 0|')
ax[1].imshow(mask)
ax[1].set_title('Background mask')
ax[2].imshow(mask*projection)
ax[2].set_title('Masked')
plt.show()

In [None]:
# Alternative equalization: fit against the background `mask` with SciPy's Nelder-Mead.
# NOTE: this overwrites `equalized_sinogram` from the equalize_sinogram cell; downstream
# sections use whichever equalization cell ran last.
# max_iter=1 runs a single optimizer iteration (presumably a quick-test value); raise it for a real fit.
equalized_sinogram = sscCdi.equalize_scipy_optimization_parallel(
    unwrapped_sinogram,
    mask,
    initial_guess=(0,0,0),
    method='Nelder-Mead',
    max_iter=1,
    processes=dic["CPUs"],  # same CPU count configured in the equalization settings
)

In [None]:
# Equalized frames, browsed along the default axis
sscCdi.deploy_visualizer(
    equalized_sinogram,
    title="Equalized sinogram",
    cmap='gray',
)

In [None]:
# Same data browsed along axis 1, with automatic aspect ratio
sscCdi.deploy_visualizer(
    equalized_sinogram,
    title="Equalized sinogram",
    cmap='gray',
    axis=1,
    aspect_ratio='auto',
)

## Rotation axis adjustment

In [None]:
slice_to_reconstruct = 100 # select which slice in the vertical direction (axis 1) to reconstruct
sinogram = equalized_sinogram

# Fail early with a clear message instead of an IndexError inside the reconstruction loop
assert 0 <= slice_to_reconstruct < sinogram.shape[1], "slice_to_reconstruct is outside the frame height"

max_displacement = 20  # evaluate every integer shift in [-max_displacement, max_displacement]
displacements = np.arange(-max_displacement, max_displacement + 1) # list of displacement values to evaluate
print("Displacements: ",displacements)

In [None]:
# Reconstruct one slice for every candidate rotation-axis displacement.
# The frames were reordered by sort_sinogram_by_angle, so the matching `sorted_angles` must be
# used; `angles` is still in its original (unsorted) order.
# Assumes `sorted_angles` keeps the layout of `angles` (column 1 = degrees); TODO confirm.
dic["algorithm_dic"] = { # if FBP: filtered back-projection
    'algorithm': "FBP",
    'gpu': [0],
    'filter': 'lorentz', # 'gaussian','lorentz','cosine','rectangle'
    'angles': np.deg2rad(sorted_angles[:,1]),
    'paganin regularization': 0, # 0 <= regularization <= 1; use for smoothening
}

# Each reconstruction comes from a 2-D sinogram of shape (frames, width), so its side presumably
# equals the width (last axis). The previous max(height, width) size broke when frames were taller than wide.
tomo_side = sinogram.shape[-1]
tomos = np.empty((len(displacements),tomo_side,tomo_side))

for i, dx in enumerate(displacements):
    # Shift the selected slice horizontally to emulate moving the rotation axis
    shifted_sino = np.roll(sinogram[:,slice_to_reconstruct,:],shift=dx,axis=1)
    tomo = sscRaft.fbp(shifted_sino, dic["algorithm_dic"])
    tomos[i] = tomo

In [None]:
# Index into `displacements` (= frame number along axis 0 of the `tomos` viewer below)
# of the sharpest reconstruction; inspect the viewer and update it.
chosen_index = 23
chosen_dx = displacements[chosen_index]
print("Chosen displacement: ", chosen_dx)

# Apply the chosen shift to every frame along the column axis (axis 2), matching the
# axis=1 roll used on the single-slice sinogram during the scan.
sinogram_adjusted_axis = np.roll(sinogram,shift=chosen_dx,axis=2)

sscCdi.misc.deploy_visualizer(tomos,type='real',title='',cmap='gray',axis=0)

## Alignment (Iterative Reprojection) 

In [None]:
# Angles must follow the angle-sorted frame order of the sinogram (see sort_sinogram_by_angle).
# Assumes `sorted_angles` keeps the layout of `angles` (column 1 = degrees); TODO confirm.
dic["algorithm_dic"] = { # if FBP: filtered back-projection
    'angles': np.deg2rad(sorted_angles[:,1]),
    'algorithm': "FBP",
    'gpu': [0],
    'filter': 'lorentz', # 'gaussian','lorentz','cosine','rectangle'
    'paganin regularization': 0.1, # 0 <= regularization <= 1; use for smoothening
}

# Input is the rotation-axis-corrected sinogram; previously that correction was computed and discarded.
# The output sinogram is stored as `aligned_sinogram` so the `sinogram` used by the
# rotation-axis scan is not overwritten.
aligned_tomo, aligned_sinogram, cumulative_shifts = sscRaft.iterative_reprojection(
    sinogram_adjusted_axis,
    dic,
    max_iterations=3,
    downsampling=2,
    plot=True,
    find_shift_method='correlation_parallel',
    apply_shift_method='scipy',
    tomo_method='raft',
    radon_method='raft',
    n_cpus=dic["CPUs"],  # same CPU count as the equalization step
)