# OPT reconstruction

Here is a notebook for using [TomoPy](http://tomopy.readthedocs.io/en/latest/) ([citation here](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4181643/pdf/s-21-01188.pdf)) data-cleaning and reconstruction algorithms on image data from the Mesoscopic Imaging Facility.

You may need to [install some packages to get this to work](https://jakevdp.github.io/blog/2017/12/05/installing-python-packages-from-jupyter/). Be sure to [install TomoPy](http://tomopy.readthedocs.io/en/latest/install.html).

In [12]:
#make sure you correctly install all packages you need to run this.

import sys
!conda install --yes --prefix {sys.prefix} -c conda-forge tomopy

Collecting package metadata: ...working... done
Solving environment: ...working... done

# All requested packages already installed.



## Import required packages and modules

This notebook reads in hdf5 files and constructs tomograms using the [TomoPy package](https://tomopy.readthedocs.io/en/latest/about.html).

matplotlib and ipywidgets provide plotting of the result in this notebook. [Paraview](http://www.paraview.org/) or other tools are available for more sophisticated 3D rendering.

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import tomopy

In [3]:
import matplotlib.pyplot as plt
import h5py
import imageio
import numpy as np
import pickle
import datetime
from skimage import transform as transf
from skimage import img_as_float64

%matplotlib inline                 
from ipywidgets import interact  

## Import and examine data

We will import the data from hdf5 files and examine the data using matplotlib to see what cleaning we will need to do.

Set the path to the tomography data to reconstruct and input names of the hdf5 files.

In [28]:
#User must set this, typically put transmission in first, then fluorescence
dirname = r'D:\Data_folder\sample_name'  #raw string so the backslashes are not read as escape sequences

fnames = [dirname+r'\input\trans.h5', 
          dirname+r'\input\cy5.h5', 
          dirname+r'\input\ET.h5',
          dirname+r'\input\YFP.h5'
]

Read hdf5 file structure. 

This extracts the structure in the hdf5 file (names of keys, number of images, size of images, data type) and displays it below.

In [29]:
def traverse_datasets(hdf_file):
    """Yield the path of every dataset stored in an hdf5 file."""

    def h5py_dataset_iterator(g, prefix=''):
        for key in g.keys():
            item = g[key]
            path = '{0}/{1}'.format(prefix, key)
            if isinstance(item, h5py.Dataset): # test for dataset
                yield (path, item)
            elif isinstance(item, h5py.Group): # test for group (go down)
                yield from h5py_dataset_iterator(item, path)

    with h5py.File(hdf_file, 'r') as f:
        for path, _ in h5py_dataset_iterator(f):
            yield path

#One dataset path per input file; a list avoids truncating long paths
dset_array = [None] * len(fnames)
for ix, filename in enumerate(fnames):
    with h5py.File(filename, 'r') as f:
        for dset in traverse_datasets(filename):
            print('Path:', dset)
            dset_array[ix] = dset  #if a file holds several datasets, the last one is kept
            print('Shape:', f[dset].shape)
            shape_var = f[dset].shape
            print('Data type:', f[dset].dtype)

Path: /t0/channel0
Shape: (400, 672, 512)
Data type: uint16
Path: /t0/channel0
Shape: (400, 672, 512)
Data type: uint16
Path: /t0/channel0
Shape: (400, 672, 512)
Data type: uint16
Path: /t0/channel0
Shape: (400, 672, 512)
Data type: uint16


Read the data from each hdf5 file and store it in the variable *proj*.

In [30]:
proj = zeros(shape=(len(fnames), shape_var[0], shape_var[1], shape_var[2]))

for ix in range(len(fnames)):
    print('Reading out channel: ',ix+1)
    filename = fnames[ix]
    #The with-block closes the file once the data has been copied into proj
    with h5py.File(filename, 'r') as f:
        dataset = f[dset_array[ix]]
        proj[ix] = np.array(dataset[:,:,:])
print(proj.shape)
num_channels, num_images, image_height, image_width = proj.shape
print('Complete.')


Reading out channel:  1
Reading out channel:  2
Reading out channel:  3
Reading out channel:  4
(4, 400, 672, 512)
Complete.


If you'd like to import flat-field and/or dark-field images:

In [None]:
flat_bool = False

if flat_bool==True:
    filename_flat = r'\input\cy5_bkgd.tif'
    flat = plt.imread(filename_flat)

In [126]:
dark_bool = False

if dark_bool:
    filename_dark = r'\input\cy5_bkgd.tif'
    dark = plt.imread(filename_dark)

if flat_bool and not dark_bool:
    #We don't often take dark-field images, so you can simply use an array of homogeneous values from the dark part of an image.
    dark_value = proj[1,0,0,0]
    dark = np.full((proj.shape[2],proj.shape[3]), dark_value)

Plot images from the data:

In [31]:
channel_num = 1
def plot_proj(image_num=0):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    ax.imshow(proj[channel_num, image_num, :, :], cmap='Greys_r')
    plt.show()
    
interact(plot_proj, image_num=(0,num_images-1))

interactive(children=(IntSlider(value=0, description='image_num', max=399), Output()), _dom_classes=('widget-i…

<function __main__.plot_proj(image_num=0)>

Plot sinograms:

In [32]:
channel_num = 1
def plot_sin(image_num=100):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    ax.imshow(proj[channel_num,:,image_num, :], cmap='Greys_r')
    plt.show()
    
interact(plot_sin, image_num=(0,num_images-1))

interactive(children=(IntSlider(value=100, description='image_num', max=399), Output()), _dom_classes=('widget…

<function __main__.plot_sin(image_num=100)>
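
A sinogram is one detector row followed through every rotation angle, which is why the cell above slices the stack as `proj[channel, :, row, :]`. The sketch below builds a small synthetic stack in which a single bright point orbits the rotation axis, so its trace in the sinogram follows a cosine. All names and sizes here are made up for illustration and are not taken from the facility data.

```python
import numpy as np

# Hypothetical toy dimensions, far smaller than the real (400, 672, 512) stack.
n_angles, height, width = 8, 4, 16
angles = np.linspace(0, 2 * np.pi, n_angles, endpoint=False)

# One bright point in detector row 2, orbiting 5 pixels from the axis.
row, radius, center = 2, 5.0, width // 2
cols = np.rint(center + radius * np.cos(angles)).astype(int)

stack = np.zeros((n_angles, height, width))
stack[np.arange(n_angles), row, cols] = 1.0

# The sinogram of that row: axis 0 is the angle, axis 1 the detector column.
sinogram = stack[:, row, :]
print(sinogram.shape)
print(cols)
```

Plotting `sinogram` with `imshow` would show the point sweeping left and right across the detector as the angle increases.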

## Begin cleaning data and setting up for reconstruction

We will make sure we have the data in the correct formats for Tomopy.

Set the data collection angles.  

In [33]:
#The following assumes 360 degrees rotation divided evenly between all images.
theta = np.linspace(0,np.pi*2,num_images)
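
Whether the last projection repeats the first angle depends on how the stage was driven, which this notebook does not record. `np.linspace(0, 2*np.pi, num_images)` includes both endpoints, so the first and last images are assigned 0 and 360 degrees. If the acquisition instead stopped one step short of a full turn, `endpoint=False` gives the evenly spaced angles. A small comparison, using a hypothetical count of 4 images:

```python
import numpy as np

num_images_demo = 4  # hypothetical; the dataset above has 400

# Endpoint included: the last angle is 2*pi, the same view as the first.
theta_closed = np.linspace(0, 2 * np.pi, num_images_demo)

# Endpoint excluded: every image is a distinct view, step = 2*pi / num_images.
theta_open = np.linspace(0, 2 * np.pi, num_images_demo, endpoint=False)

print(np.degrees(theta_closed))  # 0, 120, 240, 360
print(np.degrees(theta_open))    # 0, 90, 180, 270
```

With 400 images the difference between the two conventions is small (0.9 versus about 0.902 degrees per step), but it accumulates to a full step by the end of the rotation.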

If you have the flat-field and dark-field images (light in absence of sample, and sample in absence of light, respectively), perform the flat-field correction of transmission data: $$ \frac{proj - dark} {flat - dark} $$

In [10]:
if flat_bool==True:
    proj[0] = tomopy.normalize(proj[0], flat, dark)
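
For reference, the correction formula above can be written directly in NumPy. This is a minimal sketch, not TomoPy's implementation: the helper `flat_field_correct`, its `eps` guard, and the toy values are all assumptions made for illustration.

```python
import numpy as np

def flat_field_correct(proj, flat, dark, eps=1e-6):
    """Return (proj - dark) / (flat - dark) for a stack of projections.

    proj has shape (num_images, height, width); flat and dark are single
    (height, width) images that broadcast across the stack.
    """
    denom = np.asarray(flat, dtype=float) - dark
    # Avoid dividing by zero where the flat and dark images coincide.
    denom = np.where(np.abs(denom) < eps, eps, denom)
    return (np.asarray(proj, dtype=float) - dark) / denom

# Toy example: dark level 100, flat level 1100, two 2x2 projections.
dark_demo = np.full((2, 2), 100.0)
flat_demo = np.full((2, 2), 1100.0)
proj_demo = np.array([np.full((2, 2), 600.0), np.full((2, 2), 1100.0)])
corrected = flat_field_correct(proj_demo, flat_demo, dark_demo)
print(corrected[:, 0, 0])  # 0.5 and 1.0
```

A corrected value of 1.0 means the pixel is as bright as the empty field, and 0.5 means half of the light was transmitted.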

TomoPy provides various methods to [find the rotation center](http://tomopy.readthedocs.io/en/latest/api/tomopy.recon.rotation.html).  **There are a couple of ways to find the rotation center; you may need to try both.**


If you've run this before, the rotation centers were saved in the input folder and will be imported automatically.

In [34]:
image_180 = int(floor(num_images/2)) #assuming 360 degrees of rotation, this is image 180 degrees apart from first image

try:
    with open(dirname+r'\input\rotation.pkl','rb') as rot_file:
        rot_center = pickle.load(rot_file)
    print('File found')
except FileNotFoundError:
    rot_center = np.zeros(num_channels)
    for ch in range(num_channels):
        #rot_center[ch] = tomopy.find_center_vo(proj[ch])
        rot_center[ch] = tomopy.find_center_pc(proj[ch][0],proj[ch][image_180])
        print(rot_center[ch])
    #Pickle for future reference
    with open(dirname+r'\input\rotation.pkl','wb') as rot_file:
        pickle.dump(rot_center, rot_file)
print('Complete.')

256.5
255.75
256.5
255.0
Complete.
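
The idea behind comparing the first image with the one taken 180 degrees later is that the second is a mirror image of the first about the rotation axis, so the shift that best matches the first image to the flipped second one locates the axis. The sketch below is a simplified, integer-pixel stand-in written for illustration. The helper `estimate_center` is hypothetical and is not TomoPy's implementation; it collapses each image to a column profile and uses circular cross-correlation.

```python
import numpy as np

def estimate_center(proj_0, proj_180):
    """Estimate the rotation-axis column from images taken 180 degrees apart."""
    width = proj_0.shape[1]
    # Collapse each image to a 1D profile across detector columns.
    p0 = proj_0.sum(axis=0).astype(float)
    # Mirror the 180-degree view so it differs from p0 by a pure shift.
    q = proj_180[:, ::-1].sum(axis=0).astype(float)
    # Circular cross-correlation via FFT; the peak position is the shift.
    corr = np.fft.ifft(np.fft.fft(p0) * np.conj(np.fft.fft(q))).real
    shift = int(np.argmax(corr))
    if shift > width // 2:
        shift -= width  # map wrapped indices to negative shifts
    return (width - 1 + shift) / 2.0

# Synthetic check: an asymmetric feature mirrored about column 34.
width, true_center = 64, 34
first = np.zeros((3, width))
first[:, 28:33] = [1, 3, 5, 2, 1]
cols = np.arange(width)
second = first[:, (2 * true_center - cols) % width]
print(estimate_center(first, second))  # 34.0
```

With zero shift the estimate is `(width - 1) / 2`, the middle of the detector, which matches the values near 255.5 reported above for 512-pixel-wide images.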


TomoPy can align the image stack using a [re-projection algorithm](https://www.nature.com/articles/s41598-017-12141-9.pdf).  Indicate whether you want it to run tilt correction by setting *tilt_correction = True/False*.

This can take a long time.  If you are confident that the sample is relatively well aligned and the rotation stage is reliable, try the reconstruction without tilt correction first.  If previous reconstructions showed tilt artefacts, set tilt_correction = True and run the algorithm.

If you've run this before, the alignment shifts were saved in the input folder and will be imported automatically.

Choose a channel to run the alignment algorithm on.  It's best to use one with low noise and good contrast.

In [35]:
tilt_correction = False
alignment_channel = 1

if tilt_correction==True:
    #First check if alignment shifts have previously been calculated
    try:
        tilt_file = open(dirname+r'\input\tilt_shift.pkl', 'rb')
        sy, sx = pickle.load(tilt_file)
        tilt_file.close()
        print('File found')

    #Otherwise run alignment module
    except FileNotFoundError:
        print('Tilt file not found. Using TomoPy to calculate alignment.')
        
        # This can take a long time (~5-10 min/iteration).  Often times, ~10 iterations is sufficient.  Ideally err<1.

        align_params = {'algorithm':'mlem', 'iters':10}

        start = datetime.datetime.now()
        print("Aligning channel: ", alignment_channel)
        proj_align_channel, sy , sx, conv_channel = tomopy.prep.alignment.align_joint(proj[alignment_channel], theta, 
                                                                                      center=rot_center[alignment_channel], 
                                                                                      **align_params)
        end = datetime.datetime.now()
        print( int((end - start).total_seconds()/60), 'minutes' )

        #Pickle for future reference
        tilt_file = open(dirname+r'\input\tilt_shift.pkl','wb')
        pickle.dump([sy,sx], tilt_file)
        tilt_file.close()
    print("Complete.")
else:
    print('No alignment correction')

No alignment correction


Apply the alignment to all channels:

In [36]:
# This will take a minute

if tilt_correction==True:
    align_proj = zeros(shape=proj.shape)
    for ch in range(num_channels):
        print('Aligning channel: ', ch+1)
        channel_proj = np.copy(proj[ch])
        for ix in range(num_images):
            tform = transf.SimilarityTransform(translation=(sy[ix], sx[ix]))
            channel_proj[ix] = transf.warp(channel_proj[ix], tform)
        align_proj[ch] = channel_proj
    print('Complete.')
else:
    print('No alignment correction')
    align_proj=np.copy(proj)

No alignment correction


Check that the alignment looks correct before overwriting the projections. If tilt_correction = False, nothing has changed and this shows the original projections.

In [37]:
channel_num = 1
def plot_proj_align(image_num=0):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    ax.imshow(align_proj[channel_num, image_num, :, :], cmap='Greys_r')
    plt.show()
    
interact(plot_proj_align, image_num=(0,num_images-1))

interactive(children=(IntSlider(value=0, description='image_num', max=399), Output()), _dom_classes=('widget-i…

<function __main__.plot_proj_align(image_num=0)>

If it looks reasonable, save the new projections. Again, if tilt_correction = False, this does not change anything.

In [38]:
proj=np.copy(align_proj)

Calculate $$ -\log(proj) $$

In [39]:
#No need to pre-allocate: np.log returns a new array of the same shape
proj_inv = -np.log(proj)

channel_num = 1
def plot_proj(image_num=0):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    ax.imshow(proj_inv[channel_num, image_num, :, :], cmap='Greys_r')
    plt.show()
    
interact(plot_proj, image_num=(0,num_images-1))

interactive(children=(IntSlider(value=0, description='image_num', max=399), Output()), _dom_classes=('widget-i…

<function __main__.plot_proj(image_num=0)>
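
The negative logarithm converts transmitted intensity into attenuation, the quantity that adds up along each ray. It is undefined for zero or negative pixel values, which can occur after background subtraction. A guarded version might look like the sketch below; the helper name `safe_minus_log` and the floor value are assumptions for illustration.

```python
import numpy as np

def safe_minus_log(intensity, floor=1e-6):
    """Return -log(intensity), clipping values below `floor` first."""
    clipped = np.clip(np.asarray(intensity, dtype=float), floor, None)
    return -np.log(clipped)

# Includes a zero and a negative value that plain -np.log cannot handle.
demo = np.array([1.0, np.exp(-2.0), 0.0, -0.5])
result = safe_minus_log(demo)
print(result)
```

The clipped pixels all map to the same large but finite value, `-log(floor)`, instead of `inf` or `nan`.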

## Reconstruct projections
Tomopy offers a number of [algorithms and filters](https://tomopy.readthedocs.io/en/latest/api/tomopy.recon.algorithm.html).

Choose the algorithm and filter according to your needs:
- fbp is standard filtered back projection.
- gridrec is very fast, and reconstructs reasonably well.
- Iterative methods like mlem, art, and sirt generally outperform direct Fourier-based reconstruction methods, but they require many iterations at ~4-6 min/iter.
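
To illustrate why iterative methods need many passes, here is a minimal MLEM sketch on a toy linear system rather than real tomography data. The matrix `A` stands in for the projection operator and is made up for this example; it is not how TomoPy builds or applies its projector.

```python
import numpy as np

def mlem(A, y, num_iter=50):
    """Maximum-likelihood expectation maximization for y ~ A @ x, x >= 0."""
    x = np.ones(A.shape[1])                  # strictly positive starting estimate
    sensitivity = A.T @ np.ones(A.shape[0])  # column sums of A
    for _ in range(num_iter):
        forward = A @ x                      # simulate measurements from the estimate
        ratio = y / forward                  # compare with the measured data
        x = x / sensitivity * (A.T @ ratio)  # multiplicative update
    return x

# Toy system: three "rays" through two "pixels" with true values 2 and 3.
A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
x_true = np.array([2.0, 3.0])
y = A @ x_true

for n in (1, 5, 50):
    print(n, mlem(A, y, num_iter=n))
```

For this particular system the error halves with every iteration, so a single pass gives `[2.25, 2.75]` and 50 passes recover the true values to within floating-point precision. Each pass requires one forward projection and one back projection, which is where the per-iteration cost comes from.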

In [40]:
params = {'algorithm':'mlem', 'num_iter':10}

recon = zeros(shape=(num_channels, image_height, image_width, image_width))
start = datetime.datetime.now()
for ix in range(num_channels):
    print('Reconstructing channel: ', ix+1)
    recon_channel = tomopy.recon(align_proj[ix], theta, center=rot_center[ix], 
                                 **params
                                  )
    recon[ix] = recon_channel
end = datetime.datetime.now()
print( int((end - start).total_seconds()/60), 'minutes' )
print('Complete.')


Reconstructing channel:  1
Reconstructing channel:  2
Reconstructing channel:  3
Reconstructing channel:  4
1173 minutes
Complete.


Mask each reconstructed slice with a circle.

In [41]:
for ix in range(len(recon)):
    recon[ix] = tomopy.circ_mask(recon[ix], axis=0, ratio=0.95)
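
The masking step zeroes the corners of each slice, which lie outside the region seen at every angle. The NumPy sketch below shows the general idea. `circular_mask` is a hypothetical helper, and its radius convention (ratio times half the smaller slice dimension) is an assumption rather than TomoPy's exact definition.

```python
import numpy as np

def circular_mask(volume, ratio=0.95, val=0.0):
    """Set pixels outside a centered circle to `val` in every slice.

    volume has shape (num_slices, ny, nx); the circle lies in the (ny, nx) plane.
    """
    _, ny, nx = volume.shape
    yy, xx = np.ogrid[:ny, :nx]
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
    radius = ratio * min(ny, nx) / 2.0
    inside = (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2
    # Broadcasting applies the same 2D mask to each slice.
    return np.where(inside, volume, val)

masked = circular_mask(np.ones((2, 8, 8)))
print(masked[0].astype(int))
```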

View slices along axis=2 to check whether the reconstructed data resembles the original projection images.

In [42]:
channel_num = 1
def plot_back(image_num=250):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    ax.imshow(recon[channel_num,:, image_num, :], cmap='Greys_r')
    plt.show()
    
interact(plot_back, image_num=(0,recon.shape[2]-1))

interactive(children=(IntSlider(value=250, description='image_num', max=511), Output()), _dom_classes=('widget…

<function __main__.plot_back(image_num=250)>

If the above image looks nothing like the original, there are a few things you can do:

- Check that **rot_center** makes sense.
- Check that theta was calculated correctly (in radians, **not** degrees).
- Check that the files imported correctly.
- Add in flat-field corrections to clean your data.

Then you can re-try the reconstruction.

Now plot image from axis=0 to see reconstructed data.

In [43]:
channel_num = 1
def plot_recon(image_num=250):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    plt.imshow(recon[channel_num,image_num,:, :], cmap='Greys_r')
    plt.show()
    
interact(plot_recon, image_num=(0,recon.shape[1]-1))

interactive(children=(IntSlider(value=250, description='image_num', max=671), Output()), _dom_classes=('widget…

<function __main__.plot_recon(image_num=250)>

You can apply a median filter to the final data to reduce noise. By default no filter is applied (filter_bool = False).

In [44]:
filter_bool = False

if filter_bool == True:  
    filter_axis_list = [1,2]
    filtered = np.copy(recon)
    for i, ax in enumerate(filter_axis_list):
        print('Filtering along axis: ', i+1, '/', len(filter_axis_list))
        for ch in range(len(proj)):
            print('Filtering channel: ', ch+1)
            #Pass the axis value from filter_axis_list, not its position in the list
            filtered[ch] = tomopy.misc.corr.median_filter(filtered[ch], size=3, axis=ax)
else:
    print('No filter applied')
    filtered=np.copy(recon)
print('Complete.')


channel_num = 1
def plot_filter(image_num=250):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    plt.imshow(filtered[channel_num,image_num,:, :], cmap='Greys_r')
    plt.show()
    
interact(plot_filter, image_num=(0,filtered.shape[1]-1))

No filter applied
Complete.


interactive(children=(IntSlider(value=250, description='image_num', max=671), Output()), _dom_classes=('widget…

<function __main__.plot_filter(image_num=250)>
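
A median filter replaces each pixel with the median of its neighbourhood, which suppresses isolated outliers. The following NumPy-only sketch filters a single 2D slice with a 3x3 window. `median_filter_3x3` is a hypothetical helper for illustration, not the TomoPy routine used above, and its edge handling (replicating border pixels) is an assumption.

```python
import numpy as np

def median_filter_3x3(image):
    """Apply a 3x3 median filter to a 2D array, replicating edge pixels."""
    padded = np.pad(image, 1, mode='edge')
    ny, nx = image.shape
    # Stack the nine shifted views that make up each pixel's neighbourhood.
    neighbours = [padded[dy:dy + ny, dx:dx + nx]
                  for dy in range(3) for dx in range(3)]
    return np.median(np.stack(neighbours), axis=0)

noisy = np.full((5, 5), 10.0)
noisy[2, 2] = 1000.0  # a single hot pixel
cleaned = median_filter_3x3(noisy)
print(cleaned[2, 2])
```

The hot pixel is outvoted by its eight neighbours, so the filtered slice is uniform again.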

## Save data in output folder

**Create an output folder**, set it in the output file names below, and we will export the data as hdf5:

In [45]:
#Set folder and names of outputs, must be same number as number of inputs, ideally use same labels
file_output = [dirname+r'\output\trans.h5', 
               dirname+r'\output\cy5.h5',
               dirname+r'\output\etgfp.h5'
               ]

for ix in range(len(file_output)):
    print('Writing channel: ', ix+1)
    file_output_h5 = file_output[ix]
    archive = h5py.File(file_output_h5, 'w')
    archive['recon'] = filtered[ix]
    archive.close()
    
print('Complete.')

Writing channel:  1
Writing channel:  2
Writing channel:  3
Complete.
