# Reconstruction of X-ray tomograms (Hamburg)

### Start by defining the dataset and directories

In [1]:



import numpy as np
import os
import sys
import time
import glob

import gc


from maximus48 import var
from maximus48.tomo_proc3 import init_Npad, init_names_custom, F,rotscan
from maximus48 import SSIM_131119 as SSIM_sf 
from maximus48.SSIM_131119 import SSIM_const
from maximus48 import multiCTF2 as multiCTF


from maximus48.tomo_proc3 import rotaxis_rough


from skimage.io import imread, imsave
# from scipy.ndimage import rotate
# from maximus48.tomo_proc3 import rotate,rotrough_compute

from pybdv import make_bdv 

from dask import delayed
import dask.array as da
from dask.array.image import imread
from dask.distributed import Client, progress
from dask_jobqueue import SLURMCluster

import tomopy



In [2]:


# =============================================================================
#           initial parameters for phase retrieval with CTF
# =============================================================================
N_steps = 10                  # number of projections per degree
N_start = 1                   # index of the first file
N_finish = 3600               # index of the last file

pixel = 0.1625 * 1e-6         # pixel size (presumably metres: 0.1625 um)
distance = np.array((6.1, 6.5, 7.1, 8), dtype='float32') * 1e-2   # distances of your measurements (presumably metres)
energy = 18                   # photon energy in keV
beta_delta = 0.15             # passed to multiCTF.multi_distance_CTF
zero_compensation = 0.05      # passed to multiCTF.multi_distance_CTF
ROI = (0, 100, 2048, 2048)    # ROI of the image to be read (x, y, x1, y1 at the image - inverse to numpy!)
cp_count = 48                 # number of cores for multiprocessing (tomopy ncore)
chunksz = 30                  # number of chunks passed to tomopy (nchunk)
inclination = -0.23           # angle to compensate the tilt of the rotation axis

#data_name = 'ew21_5'
data_name = 'Platy-12601'
folder_base = '/scratch/schorb/HH_platy'
folder = os.path.join(folder_base, 'raw/')
folder_temp = os.path.join(folder_base, 'tmp/')
folder_result = os.path.join(folder_base, 'rec/')
distances = (1, 2, 3, 4)      # distance indexes used to build the file names
N_distances = len(distances)  # derived, so it cannot drift from `distances`

# every distance index needs a matching physical distance
if len(distance) != N_distances:
    raise ValueError('`distance` and `distances` must have the same length')


# calculate parameters for phase retrieval
wavelength = var.wavelen(energy)
fresnelN = pixel**2 / (wavelength * distance)


# create output folders if they don't exist (exist_ok avoids a check-then-create race)
os.makedirs(folder_temp, exist_ok=True)
os.makedirs(folder_result, exist_ok=True)


# Variable structure:
            # ['N_files',
            #  'N_start',
            #  'Npad',
            #  'ROI',
            #  'ROI_ff',
            #  'flats',
            #  'im_shape',
            #  'images',
            #  'init_paths']


Npad = init_Npad(ROI, compression=8)


In [None]:

# Support functions
# =============================

def init_paths(data_name, path, distance_indexes):
    """Collect the projection and flat-field file paths for every distance.

    Parameters
    ----------
    data_name : str
        Base name of the dataset, passed to ``init_names_custom``.
    path : str
        Folder that contains the raw images (with or without trailing slash).
    distance_indexes : sequence
        Distance indexes passed to ``init_names_custom``; one entry per distance.

    Returns
    -------
    images, flats : numpy.ndarray of dtype object
        ``images[i]`` / ``flats[i]`` is a list of full paths of the projection /
        flat-field files for the i-th distance, sorted by file name.

    If the data names of the first and last distance differ in length, a
    warning is printed and the interpreter exits with a non-zero status
    (the length-based filename filter below would be unreliable).
    """
    # set data names
    data_names, ff_names = init_names_custom(data_name=data_name,
                                             distance_indexes=distance_indexes)

    # list the files in the folder
    imlist = var.im_folder(path)

    # Files are kept only if their name has exactly the length of
    # "<name>_00000.tiff"; this drops files that share the prefix but have a
    # different name length.
    tail = '_00000.tiff'

    if len(data_names[0]) != len(data_names[-1]):
        print("""
        WARNING! Different distances in your dataset 
        have different naming lengths. 
        File names can't be arranged. 
        Try to reduce the number of distances (<10) or modify the script.
        """)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    data_lencheck = len(data_names[0] + tail)
    ff_lencheck = len(ff_names[0] + tail)

    # set proper paths
    N_distances = len(distance_indexes)
    images = np.zeros(N_distances, 'object')
    flats = np.zeros(N_distances, 'object')

    for i in range(N_distances):
        # Sorted explicitly: callers index these lists by position, and all
        # kept names have equal length, so lexical order is index order.
        images[i] = sorted(os.path.join(path, im) for im in imlist
                           if im.startswith(data_names[i])
                           and not im.startswith('.')
                           and len(im) == data_lencheck)

        flats[i] = sorted(os.path.join(path, im) for im in imlist
                          if im.startswith(ff_names[i])
                          and len(im) == ff_lencheck)

    return images, flats




# =============================


## Holographic reconstruction: flat-field correction and CTF phase retrieval of one projection





def read_flat(j, images=(), ROI_ff=(), ROI=(), flats=(), ff_file='', ffcon_file='',
              distances=(), N_start=0, Npad=0):
    """Flat-field correct, align and phase-retrieve one projection, then save it.

    Parameters
    ----------
    j : int
        Zero-based position of the projection in every ``images[i]`` list;
        the file read is ``images[i][j]`` (``N_start`` is NOT added).
    images : sequence of sequences of str
        ``images[i]`` lists the projection files recorded at the i-th distance.
    ROI_ff : tuple
        (x, y, x1, y1) sub-region used for flat-field matching and shift
        estimation (same convention as ``ROI``).
    ROI : tuple
        (x, y, x1, y1) region read from each image (inverse to numpy order).
    flats : any
        Not used by this function; kept for call-site compatibility.
    ff_file : str
        ``.npy`` file holding the flat-field stack, indexed ``ff[distance][flat]``.
    ffcon_file : str
        ``.npy`` object array holding one ``SSIM_const`` per distance.
    distances : sequence
        Distance indexes; the first occurrence of ``'_' + str(distances[0])``
        is removed from the input basename to build the output file name.
    N_start : int
        Not used by this function; kept for call-site compatibility.
    Npad : int
        Padding width added on each side before CTF phase retrieval.

    Returns
    -------
    str
        Status message ``'done processing image <j>'``.

    Notes
    -----
    Also reads the module-level names ``beta_delta``, ``fresnelN``,
    ``zero_compensation`` and ``folder_temp``; they must be defined wherever
    this function executes.
    """
    # Loaded from disk on each call rather than passed in.
    # NOTE(review): presumably so only file names are shipped to parallel workers.
    ff_con = np.load(ffcon_file, allow_pickle=True)
    ff = np.load(ff_file, allow_pickle=True)

    roi = np.s_[ROI[1]:ROI[3], ROI[0]:ROI[2]]
    roi_ff = np.s_[ROI_ff[1]:ROI_ff[3], ROI_ff[0]:ROI_ff[2]]

    # ---- flat-field correction -------------------------------------------
    filt = []
    for i, files in enumerate(images):
        # NOTE(review): this file imports both skimage.io.imread and
        # dask.array.image.imread; the later import shadows the earlier one,
        # so confirm which reader (2-D vs. stacked output) is in effect here.
        im = imread(files[j])[roi]

        # divide by the flat field with the highest SSIM against this image
        ssim = SSIM_sf.SSIM(SSIM_const(im[roi_ff]), ff_con[i]).ssim()
        maxcorridx = np.argmax(ssim)
        filt.append(im / ff[i][maxcorridx])

    # ---- shift of every distance relative to the first one -----------------
    im_gau0 = var.filt_gauss_laplace(filt[0][roi_ff], sigma=5)
    thisshift = []
    for im_dist in filt:
        im_gau1 = var.filt_gauss_laplace(im_dist[roi_ff], sigma=5)
        thisshift.append(var.shift_distance(im_gau0, im_gau1, 10))

    # ---- align, pad and retrieve the phase ---------------------------------
    filt0 = multiCTF.shift_imageset(filt, thisshift)
    filt0 = np.pad(filt0, ((0, 0), (Npad, Npad), (Npad, Npad)), 'edge')   # padding with border values
    filt0 = multiCTF.multi_distance_CTF(filt0, beta_delta,
                                        fresnelN, zero_compensation)
    # Removes the padding along axis 0 only.
    # NOTE(review): the padding along the last axis is kept; confirm this is
    # intended (e.g. needed by a later step) and not a missing crop.
    filt0 = filt0[Npad:(filt0.shape[0] - Npad), :]

    # output name: input basename with the first '_<distances[0]>' removed
    out_name = os.path.basename(images[0][j]).replace('_' + str(distances[0]), '', 1)
    imsave(os.path.join(folder_temp, out_name), filt0)

    return 'done processing image ' + str(j)




## Run the holographic reconstruction (flat-field preparation and parallel phase retrieval)

In [None]:

# =============================


# RUN  SCRIPT

# =============================


images, flats = init_paths(data_name, folder, distances)

im_shape = (ROI[3] - ROI[1], ROI[2] - ROI[0])

# The flat fields of all distances share one (distance, flat, y, x) array,
# so every distance must provide the same number of flat-field images.
n_flats = len(flats[0])
if any(len(flat_list) != n_flats for flat_list in flats):
    raise ValueError('every distance must have the same number of flat-field images')
shape_ff = (N_distances, n_flats, im_shape[0], im_shape[1])


# read ff-files to memory
ff = np.zeros(shape_ff)

for i in range(N_distances):
    for j, fname in enumerate(flats[i]):
        ff[i][j] = imread(fname)[ROI[1]:ROI[3], ROI[0]:ROI[2]]


# calculate ff-related constants
# make ROI for further flatfield and shift corrections, same logic as for normal ROI
ROI_ff = (ff.shape[3] // 4, ff.shape[2] // 4, 3 * ff.shape[3] // 4, 3 * ff.shape[2] // 4)

# array of classes to store flatfield-related constants
ff_con = np.zeros(N_distances, 'object')
for i in range(N_distances):
    ff_con[i] = SSIM_const(ff[i][:, ROI_ff[1]:ROI_ff[3],
                                 ROI_ff[0]:ROI_ff[2]].transpose(1, 2, 0))

# saved to disk so read_flat can load them by file name
ffcon_file = os.path.join(folder_temp, 'ffcon.npy')
np.save(ffcon_file, ff_con)

ff_file = os.path.join(folder_temp, 'ff.npy')
np.save(ff_file, ff)


#read_flat(j,images=images, ROI_ff=ROI_ff, ROI=ROI,flats=flats,distances=distances,ffcon_file=ffcon_file,ff_file=ff_file, N_start=N_start, Npad=Npad)

#%%
# s1=client.map....


def wait_for_futures(futures, poll_interval=1.0):
    """Block until every future has finished, resubmitting any that errored.

    Parameters
    ----------
    futures : iterable of dask.distributed.Future
    poll_interval : float
        Seconds to sleep between status checks (avoids busy-waiting).

    NOTE(review): a task that fails deterministically is retried forever.
    """
    futures = list(futures)
    while True:
        all_finished = True
        for fut in futures:
            if fut.status == 'error':
                # keys are not always strings, hence str()
                print('retrying ' + str(fut.key))
                fut.retry()
                all_finished = False
            elif fut.status != 'finished':
                all_finished = False
        if all_finished:
            return
        time.sleep(poll_interval)


# `s1` is the list of futures produced by the client.map call above
# (currently commented out); it must exist before this line runs.
wait_for_futures(s1)
    

## Stripe removal and storage of the intermediate result

In [None]:
imfiles = sorted(glob.glob(folder_temp + '*.tiff'))
if not imfiles:
    raise FileNotFoundError('no .tiff files found in ' + folder_temp)

im = imread(imfiles[0])
# One slot per phase-retrieved file instead of a hard-coded 3600.
# NOTE(review): shape[1], shape[2] assume imread returns a stacked (1, H, W)
# array (dask.array.image.imread is the last imread imported in this file).
pshape = (len(imfiles), im.shape[1], im.shape[2])

proj = np.zeros(pshape)

In [None]:
# Fill the projection stack file by file.
for idx, imf in enumerate(imfiles):
    proj[idx, :] = imread(imf)

print('stripe removal\n\n================================\n\n')

# Fourier-wavelet stripe removal (tomopy), using cp_count cores and chunksz chunks
proj = tomopy.prep.stripe.remove_stripe_fw(
    proj,
    level=3,
    wname=u'db25',
    sigma=2,
    pad=False,
    ncore=cp_count,
    nchunk=chunksz,
)

# keep the stripe-corrected stack as an intermediate result on disk
stripe_file = folder_temp + '/stripe.npy'
np.save(stripe_file, proj)

## Read data back in

### Make it parallel: load the stack lazily as a Dask array

In [3]:
proj = np.load(folder_temp+'/stripe.npy', mmap_mode='r')   # read-only memory map of the stripe-corrected stack


In [60]:
proj = np.load(file=folder_temp+'/stripe.npy')   # load the whole stripe-corrected stack into memory

In [19]:
projd = da.from_array(proj, chunks=(1, -1, -1))   # one projection per chunk

In [4]:
import numpy as np
import dask
import dask.array as da


def mmap_load_chunk(filename, shape, dtype, offset, sl):
    '''
    Return the part ``sl`` of a raw binary file viewed as an array.

    The file is opened read-only as a memory map of the given overall
    ``shape`` and ``dtype``, starting ``offset`` bytes into the file, and
    then indexed with ``sl``.

    Parameters
    ----------

    filename : str
        Path of the raw binary file.
    shape : tuple
        Shape of the complete data stored in the file.
    dtype:
        NumPy dtype of the stored data.
    offset : int
        Number of bytes to skip at the start of the file.
    sl:
        Any index/slice object accepted by a NumPy array.

    Returns
    -------

    numpy.memmap or numpy.ndarray
        A view into the memory map when ``sl`` allows one, otherwise the
        ndarray NumPy produces for that index.
    '''
    mapped = np.memmap(filename, dtype=dtype, mode='r', offset=offset, shape=shape)
    selection = mapped[sl]
    return selection


def mmap_dask_array(filename, shape, dtype, offset=0, blocksize=5):
    '''
    Create a Dask array from raw binary data in :code:`filename`
    by memory mapping.

    This method is particularly effective if the file is already
    in the file system cache and if arbitrary smaller subsets are
    to be extracted from the Dask array without optimizing its
    chunking scheme.

    It may perform poorly on Windows if the file is not in the file
    system cache. On Linux it performs well under most circumstances.

    Parameters
    ----------

    filename : str
    shape : tuple (or other sequence of ints)
        Total shape of the data in the file
    dtype:
        NumPy dtype of the data in the file
    offset : int, optional
        Skip :code:`offset` bytes from the beginning of the file
        (e.g. the header of a ``.npy`` file).
    blocksize : int, optional
        Chunk size for the outermost axis. The other axes remain unchunked.

    Returns
    -------

    dask.array.Array
        Dask array matching :code:`shape` and :code:`dtype`, backed by
        memory-mapped chunks.

    Raises
    ------

    ValueError
        If :code:`blocksize` is smaller than 1.
    '''
    if blocksize < 1:
        raise ValueError('blocksize must be a positive integer')

    # a tuple is required for the `(chunk_size,) + shape[1:]` concatenation below
    shape = tuple(shape)
    if shape[0] == 0:
        # nothing to map, and da.concatenate cannot join an empty list
        return da.empty(shape, dtype=dtype)

    load = dask.delayed(mmap_load_chunk)
    chunks = []
    for index in range(0, shape[0], blocksize):
        # Truncate the last chunk if necessary
        chunk_size = min(blocksize, shape[0] - index)
        chunk = da.from_delayed(
            load(
                filename,
                shape=shape,
                dtype=dtype,
                offset=offset,
                sl=slice(index, index + chunk_size)
            ),
            shape=(chunk_size, ) + shape[1:],
            dtype=dtype
        )
        chunks.append(chunk)
    return da.concatenate(chunks, axis=0)


In [5]:

# np.save writes a header in front of the raw data, so the memory map must
# start at the data offset numpy records for the file, not at byte 0;
# shape and dtype are taken from the file itself for the same reason.
stripe_path = folder_temp + '/stripe.npy'
stripe_mm = np.load(stripe_path, mmap_mode='r')
if not stripe_mm.flags.c_contiguous:
    # mmap_load_chunk maps the file in C order
    raise ValueError(stripe_path + ' is not stored in C order')

projd = mmap_dask_array(
    filename=stripe_path,
    shape=stripe_mm.shape,
    dtype=stripe_mm.dtype,
    offset=stripe_mm.offset,
)

In [None]:
projd   # display the Dask array (cell output)

In [6]:
import os
# Restrict native thread pools to a single thread; child processes started
# after this point inherit these environment variables.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
})

In [7]:
client=Client(n_workers=6)   # local Dask cluster with 6 workers (independent of cp_count)

In [8]:
from dask_image.ndinterp import rotate


In [61]:
# Block boundaries for the rotation step: every `rot_step` projections.
# The final boundary is shape[0] (exclusive end) so the last projection is
# included in the last block.
rot_step = 100
slices = np.arange(0, projd.shape[0], rot_step)
slices = np.append(slices, projd.shape[0])


In [62]:
# Lazily rotate each block of projections by `inclination` in the plane of
# axes 1 and 2; reshape=True enlarges the output so the input is fully kept.
ra = []
for idx, chunk in enumerate(slices[0:-1]):
    p1 = rotate(
        da.from_array(proj[chunk:slices[idx + 1]]),
        inclination,
        axes=(1, 2),
        reshape=True,
    )
    ra.append(p1)

In [None]:
# Compute the last rotated block on the cluster and fetch the NumPy result;
# calling p1.compute() first would run it locally and bypass the client.
p2 = client.compute(p1).result()

In [64]:
len(ra)   # number of rotated blocks (cell output)

36

In [None]:
# Compute every rotated block on the cluster (not locally), wrap it with one
# projection per chunk, and join the blocks along the projection axis.
# concatenate keeps the result 3-D (projection, y, x) and accepts a shorter
# final block; da.stack would add a fourth axis and needs equal-sized blocks.
out = [
    da.from_array(client.compute(chunk).result(), chunks=[1, -1, -1])
    for chunk in ra
]

outarr = da.concatenate(out, axis=0)

In [30]:
# Write the rotated projection stack to a Zarr store (da.to_zarr needs the
# array and a target). NOTE(review): output location chosen here; adjust.
rotated_zarr = os.path.join(folder_temp, data_name + '_rotated.zarr')
da.to_zarr(outarr, rotated_zarr)

In [None]:
import matplotlib.pyplot as plt

In [None]:
client.close()   # release the Dask client

In [None]:
from dask_image.ndinterp import rotate


In [None]:
plt.imshow(p2[4])   # show projection 4 of the rotated block p2