# __Diffusion Tensor Imaging__
#### __Last updated on:__ 29/05/2020

## DTI reconstruction

### Import the libraries used

In [None]:
import os
import numpy as np
import nibabel as nib
import timeit; timeit.timeit()
import math
from skimage import io #用于读取保存或显示图片或者视频

from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import TensorModel
from dipy.reconst.dti import fractional_anisotropy
from dipy.reconst.dti import color_fa

### __DWI data path__

In [None]:
# Input location for one dHCP subject/session. data_path is the folder holding
# the preprocessed diffusion data; the names below are the DWI series, its
# brain mask, and the b-value / b-vector gradient files (relative to data_path).
data_path = "/home/erjun/Documents/dHCP/dhcp_dmri_pipeline/sub-CC00060XX03/ses-12501/dwi"
(
    dwi_file,
    brainmask_file,
    bval,
    bvec,
) = (
    'sub-CC00060XX03_ses-12501_desc-preproc_dwi.nii.gz',
    'sub-CC00060XX03_ses-12501_desc-preproc_space-dwi_brainmask.nii.gz',
    'sub-CC00060XX03_ses-12501_desc-preproc_dwi.bval',
    'sub-CC00060XX03_ses-12501_desc-preproc_dwi.bvec',
)

### __Tensor fitting and export of DTI maps__

In [None]:
# Fit a diffusion tensor model to the preprocessed DWI data and write the
# scalar maps (FA, MD, RD, AD) plus the colour-coded FA map to <data_path>/DTI/.
# The globals created here (FA, MD, RD, AD, RGB, ...) are used by later cells.
import time

# Wall-clock start for the elapsed-time report at the end of the cell.
# (timeit.timeit() benchmarks one million executions of `pass`; it is not a stopwatch.)
start_time = time.perf_counter()

# Change directory
os.chdir(data_path)

#--------------------------------------------------
# load DWI data files
#--------------------------------------------------
img1 = nib.load(os.path.join(data_path, dwi_file))
# Image.get_data() is deprecated and removed in nibabel 5.0; nibabel documents
# np.asanyarray(img.dataobj) as the way to keep get_data() behaviour
# (on-disk dtype, no forced up-cast to float64).
data = np.asanyarray(img1.dataobj)

img2 = nib.load(os.path.join(data_path, brainmask_file))
# Any non-zero voxel is treated as inside the mask.
brainmask = np.asanyarray(img2.dataobj).astype(bool)

# Both gradient files live in data_path; build absolute paths for both rather
# than relying on the working directory set above.
bvals, bvecs = read_bvals_bvecs(os.path.join(data_path, bval),
                                os.path.join(data_path, bvec))
gtab = gradient_table(bvals, bvecs=bvecs)

#--------------------------------------------------------------
#               Fit diffusion tensor model
#--------------------------------------------------------------
print('Fitting diffusion tensor model')

ten_model = TensorModel(gtab)
# Restrict the fit to voxels inside the brain mask.
ten_fit = ten_model.fit(data, mask=brainmask)

#--------------------------------------------------------------
#               Save DTI parametric maps
#--------------------------------------------------------------
# Output directory "<data_path>/DTI/" (trailing separator kept), created if missing.
output_path = os.path.join(data_path, 'DTI', '')
os.makedirs(output_path, exist_ok=True)

FA = ten_fit.fa
AD = ten_fit.ad
RD = ten_fit.rd
MD = ten_fit.md

for map_name, param_map in (('FA', FA), ('MD', MD), ('RD', RD), ('AD', AD)):
    nib.save(nib.Nifti1Image(param_map, img1.affine),
             os.path.join(output_path, map_name + '.nii.gz'))

#--------------------------------------------------------------
#               Save FA RGB map
#--------------------------------------------------------------
fa = fractional_anisotropy(ten_fit.evals)
# Zero NaNs and clip to [0, 1] BEFORE colour-coding: NaN or out-of-range values
# would make the uint8 cast of 255 * RGB undefined / overflow.
FA = np.clip(np.nan_to_num(fa), 0, 1)
RGB = color_fa(FA, ten_fit.evecs)
cfa = RGB  # alias: same colour-coded FA array

nib.save(nib.Nifti1Image(np.array(255 * RGB, 'uint8'), img1.affine),
         os.path.join(output_path, 'FA_RGB.nii.gz'))

print(f'Elapsed time: {time.perf_counter() - start_time:.1f} s')
print('Done!!')

## Visualization of MRI

### Basic output maps

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

# set plot background
plt.style.use('seaborn-dark')

# plot paramter maps          
fig, [ax0, ax1, ax2, ax3, ax4] = plt.subplots(1,5,figsize=(15,8),subplot_kw={'xticks': [], 'yticks': []})
ax0.imshow(RGB[:,30,:,:]); ax0.set_title('Color coded FA',fontweight='bold',size=10)
ax1.imshow(FA[:,30,:]); ax1.set_title('Fractional anisotropy',fontweight='bold',size=10)
ax2.imshow(MD[:,30,:]); ax2.set_title('Mean diffusivity',fontweight='bold',size=10)
ax3.imshow(RD[:,30,:]); ax3.set_title('Radial diffusivity',fontweight='bold',size=10)
ax4.imshow(AD[:,30,:]); ax4.set_title('Axial diffusivity',fontweight='bold',size=10)

In [None]:
# Dimensions of the FA volume (shown as the cell output).
FA.shape

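### 3D MRI: single slice of the colour-coded FA volume
- RGB (colour-coded FA), slice 20 along the third array axis
- With a data cursor (currently disabled in the code)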

In [None]:
# plot new figure: one large view of the colour-coded FA map,
# slice 20 along the third array axis.
from matplotlib.widgets import Cursor

_rgb_fig, rgb_ax = plt.subplots(figsize=(12, 12))

# plot parameter map without ticks or frame
rgb_ax.set_title('Color Coded Fractional anisotropy', fontweight='bold', fontsize=16)
rgb_ax.axis('off')
ax8 = rgb_ax.imshow(RGB[:, :, 20, :])

# NOTE: the data cursor is disabled. If enabled, a matplotlib Cursor must be
# attached to the Axes (rgb_ax), not to the AxesImage returned by imshow (ax8),
# and it only reacts under an interactive backend.

plt.show()


### Slider
- Mean diffusivity (MD) map, browsed slice by slice
- From slice 0 to slice 63

In [None]:
# Import data
import time
import numpy as np

from skimage import io  # used to read, save, or display images or videos

# Volume browsed slice by slice: the mean-diffusivity (MD) map.
vol = MD
# Largest value in the array, used as the colour upper limit (cmax).
# nanmax ignores NaN voxels, which would otherwise turn cmax into NaN.
colormax = np.nanmax(vol)
# Transpose so that the first index selects a slice along the original last axis.
volume = vol.T
# All slices share one 2-D shape; read it from the middle slice.
r, c = volume[len(volume) // 2].shape
# Define frames
import plotly.graph_objects as go
# One animation frame per slice, so every slice from the top one down to
# slice 0 is shown.
nb_frames = len(volume)
top_slice = len(volume) - 1

fig = go.Figure(frames=[go.Frame(data=go.Surface(
    z=(top_slice - k) * np.ones((r, c)),
    surfacecolor=volume[top_slice - k],
    cmin=0, cmax=colormax
    ),
    name=str(k)  # you need to name the frame for the animation to behave properly
    )
    for k in range(nb_frames)])

# Add data to be displayed before animation starts
fig.add_trace(go.Surface(
    z=top_slice * np.ones((r, c)),
    surfacecolor=volume[top_slice],
    cmin=0, cmax=colormax,
    colorbar=dict(thickness=20, ticklen=4)
    ))


def frame_args(duration):
    """Return plotly ``animate`` arguments for one step of the animation.

    Both the frame display time and the linear transition use ``duration``
    (milliseconds in plotly's animation API). Playback jumps immediately and
    resumes from the current frame.
    """
    timing = {"duration": duration}
    return dict(
        frame=dict(timing),
        mode="immediate",
        fromcurrent=True,
        transition=dict(timing, easing="linear"),
    )

# Slider: one step per animation frame; choosing a step jumps straight to that
# frame with a zero-length transition.
slider_steps = [
    {
        "args": [[frame.name], frame_args(0)],
        "label": str(idx),
        "method": "animate",
    }
    for idx, frame in enumerate(fig.frames)
]
sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": slider_steps,
    }
]

# Play button: run all frames at 50 ms each. Pause button: animate to [None]
# with zero duration, which stops playback.
play_button = {
    "args": [None, frame_args(50)],
    "label": "&#9654;",  # play symbol
    "method": "animate",
}
pause_button = {
    "args": [[None], frame_args(0)],
    "label": "&#9724;",  # pause symbol
    "method": "animate",
}
playback_menu = {
    "buttons": [play_button, pause_button],
    "direction": "left",
    "pad": {"r": 10, "t": 70},
    "type": "buttons",
    "x": 0.1,
    "y": 0,
}

# Layout: fixed z range covering every slice height used by the frames.
fig.update_layout(
    title='Slices in volumetric data',
    width=600,
    height=600,
    scene=dict(
        zaxis=dict(range=[-1, len(volume) - 1], autorange=False),
        aspectratio=dict(x=1, y=1, z=1),
    ),
    updatemenus=[playback_menu],
    sliders=sliders,
)

fig.show()

### Dropdown

- Option 1: map to display (RGB, FA, MD, RD, AD)
- Option 2: slice index (0–63)
