# XRD-CT reconstruction
## Run this notebook inside the Tomography / Tomorec / GPU environment

This notebook performs a simple tomographic reconstruction of an XRD-CT map (omega-translation flyscan). It requires data from the `.h5` master file and the `_pilatus_integrated.h5` file.

### Load data and plot average XRD for the map

In [None]:
%pylab inline
import os
import h5py as h5
import numpy as np
import matplotlib.pyplot as plt
try:
    import tomopy
except ModuleNotFoundError as err:
    raise Exception('tomopy not found: try changing your server to "Tomography/Tomorec/GPU": in File->Hub control panel')

#To import DanMAX from the folder above:
import sys
sys.path.append('../')
import DanMAX as DM
style = DM.darkMode(style_dic={'figure.figsize':'large'})


In [None]:
# Locate the scan to reconstruct
fname = DM.findScan()

# For each of the two scanned motors: name, nominal and registered motor positions
motor_1, motor_2 = DM.getMotorSteps(fname,proposal=None,visit=None)
M1, nom_1, reg_1 = motor_1
M2, nom_2, reg_2 = motor_2

# A motor is taken to be the rotation stage if its name contains one of these tags
rot_tags = ('ry', 'rx', 'huber')
m1_is_rot = any(tag in M1 for tag in rot_tags)
m2_is_rot = any(tag in M2 for tag in rot_tags)

if m2_is_rot and not m1_is_rot:
    # second motor rotates, first motor translates
    M_rot, M_trans = M2, M1
    rot, trans = reg_2, reg_1
    # data shape  (rotation, translation)
    shape_2d = (nom_2.shape[0],nom_1.shape[0])
else:
    # first motor rotates - either recognised by name or assumed as a fallback
    if not m1_is_rot:
        print(f'Unable to guess rotational motor, assuming {M1} is rotational')
    M_rot, M_trans = M1, M2
    rot, trans = reg_1, reg_2
    # data shape  (rotation, translation)
    shape_2d = (nom_1.shape[0],nom_2.shape[0])
    
# Load the azimuthally binned data that belongs to the scan
aname = DM.getAzintFname(fname)
data = DM.getAzintData(aname)

I = data['I']
# The radial axis is Q when the 'q' entry is present, otherwise 2theta
Q = data['q'] is not None
x = data['q'] if Q else data['tth']

# --- Meta data ---------------------------------------------------------------
# Extra dataset to read alongside the default meta data
custom_keys={'I1':'entry/instrument/albaem-xrd_ch3/data'}

meta = DM.getMetaData(fname,custom_keys=custom_keys)
I0 = meta['I0']
# NOTE(review): assumes getMetaData returns None for 'I1' when the dataset is
# unavailable - confirm against DanMAX.getMetaData
I1 = meta['I1']
exposure = DM.getExposureTime(fname)

# Translation indices treated as air: the first and last five points of each
# translation line. Defined up front because the air background subtraction
# and the sinogram plot use them regardless of whether I1 is available.
lower, upper = 5, -5
roi = np.r_[upper:lower]

# --- Normalization -----------------------------------------------------------
if I1 is not None:
    # select channel index 1 of the I1 data and flatten to one value per point
    I1 = np.squeeze(I1[:,1,:]).astype(np.float32)
    # normalize I1 to the incident intensity
    I1 = (I1.T/I0).T
    I1_2d = I1.reshape(shape_2d)
    # per-projection air level: median over the air region of each rotation step
    I_air = np.median(I1_2d[:,roi],axis=1)
    # scale each projection by its own air level
    I1_2d = (I1_2d.T/I_air).T
    I1 = I1_2d.flatten()
    # normalize the diffraction data to I1
    I = (I.T/I1).T
else:
    # normalize the diffraction data to I0
    I = (I.T/I0).T
# clip negative intensities
I[I<0.]=0.

# discard radial bins that are empty in every pattern
zero_mask = np.nanmean(I,axis=0)>0.
I = I[:,zero_mask]
x = x[zero_mask]

# --- Air background ----------------------------------------------------------
# reshape the diffraction data (rot, trans, radial)
I_2d = I.reshape(*shape_2d,I.shape[-1])

# mean pattern of the air region, subtracted from every pattern
air_bgr = np.mean(I_2d[:,roi,:],axis=(0,1))
I -= air_bgr
# shift so that the global minimum is zero
I -= I.min()

# --- Average pattern ---------------------------------------------------------
I_avg = np.nanmean(I,axis=0)
I_median = np.nanmedian(I,axis=0)

plt.figure()
plt.title(DM.getScan_id(fname))
plt.plot(x,I_avg,label='avg')
plt.plot(x,I_median,label='median')
if Q:
    plt.xlabel(r'Q ($\AA^{-1}$)')
else:
    plt.xlabel(r'2$\theta$ ($\deg$)')
plt.ylabel('Intensity')
plt.legend()

#### Reshape XRD data for tomo reconstruction

In [None]:
# Motor positions on the (rot, trans) grid
rot_2d, trans_2d = rot.reshape(shape_2d), trans.reshape(shape_2d)

# Diffraction data on the (rot, trans, radial) grid
n_radial = I.shape[-1]
I_2d = I.reshape((*shape_2d, n_radial))

# tomopy expects (rot, radial, trans), so swap the last two axes
proj = np.transpose(I_2d, (0,2,1))
# rotation angles in radians, taken from the first translation point of each line
theta = rot_2d[:,0]*np.pi/180

# Prepend the absorption data (1 - normalized I1) as an extra slice at radial index 0
if I1 is not None:
    proj = np.insert(proj, 0, values=1-I1_2d, axis=1)

#### Find the rotation center and plot a sinogram for the maximum point in the average diffraction pattern
For visual confirmation that everything is OK

In [None]:
# Add padding along the translation axis.
# Only pad while proj still has the unpadded translation length, so that
# re-running this cell does not pad an already padded array again.
if proj.shape[-1] == I_2d.shape[1]:
    proj = tomopy.misc.morph.pad(proj, axis=2, mode='edge')

#find rotation center (pick an algorithm)
rot_center = tomopy.find_center(proj, theta)[0]
#rot_center = tomopy.find_center_pc(proj[0], proj[-1], tol=0.5)
#rot_center = tomopy.find_center_vo(proj)

print('The padded rotation center is {:.2f} px'.format(rot_center))

# number of pixels added on each side by the padding
pad = round((proj.shape[-1] - I_2d.shape[1])/2)
# mm per pixel conversion (mean translation step of the first projection)
mm_per_px = np.mean(np.diff(trans_2d[0,:]))
# rotation center in mm, in the coordinates of the unpadded translation axis
rot_cen_mm = (rot_center-pad)*mm_per_px+trans_2d[0,0]

print('The unpadded rotation center is {:.2f} mm'.format(rot_cen_mm))

# sinogram (rot, trans) of the radial bin with the highest average intensity
peak_idx = np.nanargmax(I_avg)
sinogram = I_2d[:,:,peak_idx]
trans_axis = trans_2d[0]

fig = plt.figure()
# initialize grid and subplot with different size-ratios
grid = plt.GridSpec(2,1,height_ratios=[1,6]) #rows,columns
ax0, ax1 = [fig.add_subplot(gr) for gr in grid]

# top panel: translation profile averaged over all rotation steps
ax0.plot(trans_axis,np.mean(sinogram,axis=0))
ax0.set_yscale('log')

# bottom panel: the sinogram itself
ax1.pcolormesh(trans_axis,
               theta*180/np.pi,
               sinogram,
               shading='auto',
               vmin=0,
              )
# reconstructed rotation center
ax1.axvline(rot_cen_mm,color='w',linestyle='dashed')

# mark the limits of the air region in both panels
for ax in (ax0, ax1):
    ax.axvline(trans_axis[lower],color='b',linestyle='dashed')
    ax.axvline(trans_axis[upper],color='b',linestyle='dashed')

ax1.set_xlabel(f'{M_trans} (mm)')
ax1.set_ylabel(f'{M_rot} (deg)')
plt.grid(False)
#plt.colorbar()

#### Perform the reconstruction

In [None]:
# Reconstruct every slice with gridrec; result is (radial, padded shape, padded shape)
recon = tomopy.recon(proj, theta, center=rot_center, algorithm='gridrec', filter_name='parzen')
print(recon.shape)

# crop the padded border from both image axes
inner = slice(pad, -pad)
recon = recon[:, inner, inner]
print(recon.shape)

# scale the reconstructed intensities to the number of projections and exposure time
n_projections = rot_2d.shape[0]
recon = recon*exposure*n_projections

# pixel index grids converted to mm
n_rows, n_cols = recon.shape[1:]
x_map, y_map = np.mgrid[0:n_rows,0:n_cols]*mm_per_px

#### Plot the reconstruction

In [None]:
#add a circular mask
recon = tomopy.circ_mask(recon, axis=0, ratio=0.96)

# recon may hold extra slices in front of the diffraction slices (the
# absorption slice inserted at index 0 when I1 is available). Offset the
# radial index accordingly so the plotted slice really is the strongest bin.
n_extra = recon.shape[0] - x.shape[0]
peak_slice = n_extra + np.nanargmax(I_avg)

plt.figure()
plt.pcolormesh(x_map,
               y_map,
               recon[peak_slice],
               #recon[0],
               cmap='viridis',
               vmin=0)
plt.xlabel('x (mm)')
plt.ylabel('y (mm)')
plt.colorbar()
plt.gca().set_aspect('equal')

### Save the reconstruction as a map

In [None]:
# recon may hold extra slices in front of the diffraction slices (the
# absorption slice inserted at index 0 when I1 is available). Drop them so
# the radial axis of xrd_map has the same length as x_xrd.
# NOTE(review): the absorption slice is not saved here - confirm whether
# DM.mapping.save_maps should receive it under a separate key.
n_extra = recon.shape[0] - x.shape[0]

maps = {'x_map':x_map ,
        'y_map':y_map,
        'xrd_map':recon[n_extra:].T,
        'x_xrd':x,
        'Q':Q,
       }

DM.mapping.save_maps(maps=maps, scans=[fname])

### Plot region-of-interest maps

In [None]:
# Approximate regions of interest, given in the same unit as x
#    label    :   [lower limit, upper limit]
regions = {
    'peak_1' : [5.0,6.0],
    'peak_2' : [9.2,10.1],
}

# Overview of where each region sits on the average pattern
fig_roi, ax_roi = plt.subplots()
ax_roi.set_title(DM.getScan_id(fname))
ax_roi.plot(x,I_avg,label='average pattern')
ax_roi.plot(x,I_median,label='median pattern')

# highlight the points of the average pattern that fall strictly inside each region
for region, (x_lo, x_hi) in regions.items():
    roi = (x>x_lo) & (x<x_hi)
    ax_roi.plot(x[roi],I_avg[roi],'.',label=region)

ax_roi.set_ylabel('I [a.u.]')
ax_roi.set_xlabel(r'Q [$\AA^{-1}$]' if Q else r'2$\theta$ [$\deg$]')
ax_roi.legend()

In [None]:
# Set the number of columns for the figure
cols = 3

# number of rows needed to fit all regions (ceiling division, at least one row)
n_regions = len(regions)
rows = max(1, -(-n_regions // cols))

# initialize figure; squeeze=False guarantees a 2D array of axes even for a
# single subplot, so flatten() always works
fig, axes = plt.subplots(rows,cols,sharex=True,sharey=True,squeeze=False)
#fig.set_size_inches(12,8)
# fig.suptitle(DM.getScan_id(fname))
fig.suptitle(fname.split('raw/')[-1].replace('/',': '))
axes = axes.flatten()

# calculate scale bar values; abs() keeps the bar length positive when the
# translation positions are decreasing
scale_500um = .5/abs(mm_per_px)
offset = recon.shape[1]*0.025

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

# recon may hold extra slices in front of the diffraction slices (the
# absorption slice inserted at index 0 when I1 is available); skip them so
# the radial axis lines up with x
n_extra = recon.shape[0] - x.shape[0]
recon_xrd = recon[n_extra:]

for k,(region, (x_lo, x_hi)) in enumerate(regions.items()):
    # radial bins strictly inside the region limits
    roi = (x>x_lo) & (x<x_hi)

    ax = axes[k]
    ax.set_title(f'{region}')
    ax.grid(False)

    if not roi.any():
        # nothing to integrate - leave the panel empty instead of raising
        print(f'No data points within region {region}, skipping')
        continue

    stack = recon_xrd[roi]
    # per-pixel background: the lower of the two slices at the region edges,
    # clipped at zero
    bgr = np.min(stack[[0,-1]],axis=0)
    bgr[bgr<0]=0
    # background-subtracted mean intensity over the region
    im = np.mean(stack,axis=0)-bgr

    # plot heatmap
    vmin = max(im.min(),0)
    ax.imshow(im,vmin=vmin)

    # 500 um scale bar with label
    ax.plot([offset,offset+scale_500um],
            [offset,offset],
            'w',
            lw=10,
           )
    ax.annotate('500 μm',
                (offset+scale_500um/2,offset*2),
                horizontalalignment='center',
                verticalalignment='top',
               )

# delete surplus plots
for ax in axes[n_regions:]:
    fig.delaxes(ax)
fig.tight_layout()


In [None]:
# Show the first reconstructed slice with a colorbar
fig_first, ax_first = plt.subplots()
first_slice = ax_first.imshow(recon[0])
fig_first.colorbar(first_slice, ax=ax_first)