## <center>Nilearn statistical maps. <br>Brain 3d x, y, z-slice animation</center> ##

In the previous [notebook](https://github.com/empet/Plotly-viz-of-brain-images/blob/master/Nilearn-improved-Interactive-brain-slices.ipynb)
we showed how to select a cut point in MNI coordinates and interactively plot the 2d projections of the orthogonal
brain slices through that point.

In this notebook we illustrate the 3d animation of the x-, y- and z-slices through the brain volume, one slice at a time.

In [None]:
from plotly.offline import download_plotlyjs, init_notebook_mode,  iplot
init_notebook_mode(connected=True)

In [None]:
import copy
import warnings

import numpy as np
from matplotlib import cm as mpl_cm
import matplotlib as mpl
from nilearn import (plotting, _utils)

In [None]:
def mpl_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    ``cmap`` is evaluated at ``pl_entries`` evenly spaced points of [0, 1].
    Each sample becomes a ``[position, 'rgb(r, g, b)']`` pair:
    - the position is rounded to 2 decimals;
    - the RGB channels are scaled to 0-255 and truncated to integers.

    ``pl_entries == 1`` raises ZeroDivisionError (the spacing is
    ``1 / (pl_entries - 1)``).
    """
    step = 1.0 / (pl_entries - 1)

    def _entry(level):
        # Drop the alpha channel and convert the [0, 1] floats to 0-255 ints.
        red, green, blue = (np.uint8(c) for c in np.array(cmap(level)[:3]) * 255)
        return [round(level, 2), f'rgb({red}, {green}, {blue})']

    return [_entry(idx * step) for idx in range(pl_entries)]

In [None]:
def colorscale(cmap, values, threshold=None, symmetric_cmap=True,
               vmax=None, vmin=None):
    """Define a Plotly colorscale from a nilearn or matplotlib colormap.

    Modified from ``nilearn.plotting.js_plotting_utils.colorscale``. Besides
    the colorscale, it also returns the color range and the absolute threshold.

    Parameters
    ----------
    cmap : colormap name or matplotlib colormap
    values : numpy array of stat-map values (flattened by the caller)
    threshold : optional
        Validated against ``values`` with nilearn's ``check_threshold``.
        When given, the band [-abs_threshold, abs_threshold] of the
        colorscale is replaced by a gray ramp.
    symmetric_cmap : bool
        If True the color range is [-vmax, vmax]. It is forced to True, with a
        warning, when ``values`` contains negative numbers.
    vmax, vmin : optional color-range bounds
        ``vmax`` defaults to ``max(|values|)``.

    Returns
    -------
    dict with keys 'colorscale' (list of [position, 'rgb(r, g, b)']),
    'vmin', 'vmax' and 'abs_threshold' (None when no threshold is given).
    """
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the colormap
    # registry there (it accepts both names and Colormap objects).
    get_cmap = getattr(mpl_cm, 'get_cmap', None)
    cmap = get_cmap(cmap) if get_cmap is not None else mpl.colormaps.get_cmap(cmap)
    abs_values = np.abs(values)

    if not symmetric_cmap and (values.min() < 0):
        warnings.warn('you have specified symmetric_cmap=False '
                      'but the map contains negative values; '
                      'setting symmetric_cmap to True')
        symmetric_cmap = True
    if symmetric_cmap and vmin is not None:
        warnings.warn('vmin cannot be chosen when cmap is symmetric')
        vmin = None
    if threshold is not None:
        if vmin is not None:
            warnings.warn('choosing both vmin and a threshold is not allowed; '
                          'setting vmin to 0')
        vmin = 0
    if vmax is None:
        vmax = abs_values.max()
    if symmetric_cmap:
        vmin = - vmax
    if vmin is None:
        vmin = values.min()

    if threshold is None:
        return {
            'colorscale': mpl_to_plotly(cmap, 11), 'vmin': vmin, 'vmax': vmax,
            'abs_threshold': None}

    abs_threshold = _utils.param_validation.check_threshold(
        threshold, values, _utils.extmath.fast_abs_percentile)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    ca = norm(-abs_threshold)  # normalized position of -abs_threshold
    cb = norm(abs_threshold)   # normalized position of +abs_threshold

    n = 11  # samples per segment of the colorscale
    # Left segment: 0 up to just below ca.
    h1 = ca / n
    left = [k * h1 for k in range(n)]
    # Middle segment: ca .. cb (the threshold band).
    h2 = (cb - ca) / (n - 1)
    middle = [ca + k * h2 for k in range(n)]
    # Right segment: just above cb up to 1.
    h3 = (1 - cb - h2) / (n - 1)
    right = [cb + h2 + k * h3 for k in range(n)]
    positions = left + middle + right

    rgb = [cmap(t)[:3] for t in positions]
    # Paint the threshold band with a gray ramp instead of the colormap.
    rgb[n:2 * n] = [mpl_cm.gray(k * 0.1)[:3] for k in range(n)]

    pl_colorscale = []
    for t, color in zip(positions, rgb):
        r, g, b = (np.uint8(c) for c in np.array(color) * 255)
        pl_colorscale.append([round(t, 3), f'rgb({r}, {g}, {b})'])
    return {
        'colorscale': pl_colorscale, 'vmin': vmin, 'vmax': vmax,
        'abs_threshold': abs_threshold}

In [None]:
def pl_view_img(stat_map_img, bg_img='MNI152', cut_coords=None,
                threshold=1e-6,
                cmap=plotting.cm.cold_hot,
                symmetric_cmap=True,
                dim='auto',
                vmax=None,
                vmin=None,
                resampling_interpolation='continuous',
                **kwargs):
    """
    Prepare a stat map and its background image for the 3d slice animation.

    This is a modified version of nilearn's ``view_img``:
    https://github.com/nilearn/nilearn/blob/master/nilearn/plotting/html_stat_map.py#L332
    It also differs slightly from the function of the same name in the
    previous notebook.

    The stat map is masked with ``threshold``, resampled onto the background
    image and converted to arrays. Extra ``**kwargs`` are accepted and ignored.

    Returns
    -------
    list
        [color_info, background data, stat-map data, mask data,
        affine of the resampled stat map, cut_slices], where cut_slices is the
        tuple of voxel indices (i, j, k) nilearn associates with ``cut_coords``.
    """
    html_map = plotting.html_stat_map
    safe_get_data = _utils.niimg._safe_get_data

    # nilearn returns the mask, the image, its data and the threshold to use
    # from here on.
    mask_img, stat_map_img, data, threshold = html_map._mask_stat_map(
        stat_map_img, threshold)
    color_info = colorscale(cmap, data.ravel(), threshold=threshold,
                            symmetric_cmap=symmetric_cmap,
                            vmax=vmax, vmin=vmin)

    # Only the background image is needed; its min/max and the
    # black-background flag are discarded.
    bg_img, _, _, _ = html_map._load_bg_img(stat_map_img, bg_img, dim)
    stat_map_img, mask_img = html_map._resample_stat_map(
        stat_map_img, bg_img, mask_img, resampling_interpolation)

    affine = stat_map_img.affine
    cut_slices = html_map._get_cut_slices(stat_map_img, cut_coords, threshold)

    bg_data, stat_data, mask_data = (safe_get_data(img, ensure_finite=True)
                                     for img in (bg_img, stat_map_img, mask_img))
    return [color_info, bg_data, stat_data, mask_data, affine, cut_slices]


In [None]:
def _slice_mni_coords(A, axis, index, plane_shape):
    """Return the MNI coordinate grids (x, y, z) of the voxels in one slice.

    The slice is the plane where voxel index number ``axis`` (0=i, 1=j, 2=k)
    equals ``index``, and ``plane_shape`` is that plane's 2d shape. The grids
    have the transposed shape (plane_shape[1], plane_shape[0]), matching the
    transposed ``surfacecolor`` arrays.
    """
    u, v = np.meshgrid(np.arange(plane_shape[0]), np.arange(plane_shape[1]))
    fixed = index * np.ones(u.shape)
    voxel_grids = [u, v]
    voxel_grids.insert(axis, fixed)
    homogeneous = np.stack(voxel_grids + [np.ones(u.shape)])
    x, y, z, _ = np.einsum('ik,kjm->ijm', A, homogeneous)
    return x, y, z


def _surface_trace(x, y, z, name, surfacecolor, color_info):
    """Return a Plotly surface trace for one slice, colored per ``color_info``."""
    return dict(type='surface',
                x=x,
                y=y,
                z=z,
                name=name,
                surfacecolor=surfacecolor,
                colorscale=color_info['colorscale'],
                colorbar=dict(thickness=20, ticklen=4, tick0=-7, dtick=2, len=0.75),
                cmin=color_info['vmin'],
                cmax=color_info['vmax'])


def _slice_frames(volume, axis, positions, indices):
    """Return animation frames that move the slice along ``axis``.

    Frame ``n`` places the surface at coordinate ``positions[n]`` and colors
    it with plane ``n`` of ``volume`` along ``axis``.
    """
    coord = 'xyz'[axis]
    frames = []
    for n in indices:
        plane = np.take(volume, n, axis=axis).T
        frames.append(dict(data=[{coord: positions[n] * np.ones(plane.shape),
                                  'surfacecolor': plane}],
                           name=f'frame{n}'))
    return frames


def plotly_slice_traces(bg_img, stat_img, mask_img, A, cut_slices, color_info):
    """
    Build the Plotly surfaces and animation frames for the x-, y- and z-slices.

    Parameters
    ----------
    bg_img : 3d array of background-image values
        Assumed to share the stat data's shape.
    stat_img : 3d array of stat-map values resampled to the background grid
        This is the third item returned by ``pl_view_img``. For backward
        compatibility, if a non-array (e.g. the file path) is passed, the
        notebook-level ``stat_map_img`` array is used instead.
    mask_img : 3d array
        Voxels equal to 1 show the rescaled background image.
    A : 4x4 affine mapping voxel indices (i, j, k, 1) to MNI coordinates
    cut_slices : array of 3 indices as returned by ``pl_view_img``
        1 is subtracted from each before use. NOTE(review): confirm this
        offset against nilearn's ``_get_cut_slices``.
    color_info : dict returned by ``colorscale``
        Uses 'colorscale', 'vmin', 'vmax' and 'abs_threshold'.
        'abs_threshold' must not be None.

    Returns
    -------
    xslice, yslice, zslice : Plotly surface-trace dicts
    xframes, yframes, zframes : lists of animation frames
    xrange, yrange, zrange : [first-voxel, last-voxel] MNI coordinate per axis
    """
    # Prefer the array that was passed in. The notebook originally passed the
    # stat-map file path here while this function read the notebook-level
    # ``stat_map_img`` array; keep that as a fallback so such calls still work.
    stat_data = stat_img if isinstance(stat_img, np.ndarray) else stat_map_img

    # MNI coordinates of the first and last voxel along each axis:
    imax, jmax, kmax = stat_data.shape
    xMNI_min, yMNI_min, zMNI_min = A[:, 3][:3]
    xMNI_max, yMNI_max, zMNI_max, _ = np.dot(A, [imax - 1, jmax - 1, kmax - 1, 1])
    xrange = [xMNI_min, xMNI_max]
    yrange = [yMNI_min, yMNI_max]
    zrange = [zMNI_min, zMNI_max]

    # Voxel indices of the three static slices.
    slice_indices = tuple(np.array(cut_slices - 1, int))

    # Where mask_img == 1, show the background image, rescaled linearly to
    # [-abs_threshold, abs_threshold] (the band colorscale() paints gray).
    abs_threshold = color_info['abs_threshold']
    low, high = -abs_threshold, abs_threshold
    bg_min, bg_max = bg_img.min(), bg_img.max()
    rescaled_bg = low + (high - low) / (bg_max - bg_min) * (bg_img - bg_min)
    new_img = stat_data.copy()
    in_mask = mask_img == 1
    new_img[in_mask] = rescaled_bg[in_mask]

    traces = []
    names = ('x-slice', 'y-slice', 'z-slice')
    for axis, (index, name) in enumerate(zip(slice_indices, names)):
        plane = np.take(new_img, index, axis=axis)
        x, y, z = _slice_mni_coords(A, axis, index, plane.shape)
        traces.append(_surface_trace(x, y, z, name, plane.T, color_info))
    xslice, yslice, zslice = traces

    # Frame positions run from the first to the last voxel along each axis,
    # so position n matches plane n of new_img. Only interior planes are
    # animated (hard-coded margins).
    xframes = _slice_frames(new_img, 0, np.linspace(xMNI_min, xMNI_max, imax),
                            range(8, imax - 9))
    yframes = _slice_frames(new_img, 1, np.linspace(yMNI_min, yMNI_max, jmax),
                            range(8, jmax - 9))
    zframes = _slice_frames(new_img, 2, np.linspace(zMNI_min, zMNI_max, kmax),
                            range(4, kmax - 13))

    return xslice, yslice, zslice, xframes, yframes, zframes, xrange, yrange, zrange

### Plotly plots of slices in 3d space, referenced to the MNI coordinate system

In [None]:
# Common style of the three scene axes; each axis below gets a shallow copy.
axis3d = dict(showbackground=True,
              backgroundcolor="rgb(230, 230,230)",
              gridcolor="rgb(255, 255, 255)",
              zerolinecolor="rgb(255, 255, 255)",
              ticklen=4,
              tickfont=dict(size=11))

# Shared figure layout with a Play button that runs the frames.
# 3d surface traces cannot be transitioned by Plotly.js, so every frame must be
# redrawn (redraw=True); otherwise the surface does not update during playback.
layout = dict(width=600,
              height=600,
              scene=dict(xaxis=dict(axis3d),
                         yaxis=dict(axis3d),
                         zaxis=dict(axis3d),
                         camera=dict(eye=dict(x=1.4, y=1.4, z=1.15))),
              updatemenus=[dict(type='buttons', showactive=False,
                                y=1,
                                x=1.32,
                                xanchor='right',
                                yanchor='top',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Play',
                                              method='animate',
                                              args=[None,
                                                    dict(frame=dict(duration=70,
                                                                    redraw=True),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')])])])

In [None]:
stat_img = 'Data/image_10426.nii.gz'
cut_coords = [-90, -27, -20]
color_info, bg_img, stat_map_img, mask_img, A, cut_slices = pl_view_img(stat_img, cut_coords=cut_coords, threshold=3)

# Pass the resampled stat-map data array (not the file path) to the trace builder.
xslice, yslice, zslice, xframes, yframes, zframes, xrange, yrange, zrange = \
                plotly_slice_traces(bg_img, stat_map_img, mask_img, A, cut_slices, color_info)

Build and plot the x-slice, y-slice and z-slice animations, respectively:

In [None]:
# Deep-copy the shared layout so the range and title set here don't leak into
# the other figures.
figx = dict(data=[xslice], layout=copy.deepcopy(layout), frames=xframes)
figx['layout']['scene']['xaxis'].update(range=xrange)
figx['layout']['title'] = 'x-slice animation'

iplot(figx, validate=False)

In [None]:
# Deep-copy the shared layout so the range and title set here don't leak into
# the other figures.
figy = dict(data=[yslice], layout=copy.deepcopy(layout), frames=yframes)
figy['layout']['scene']['yaxis'].update(range=yrange)
figy['layout']['title'] = 'y-slice animation'

iplot(figy, validate=False)

In [None]:
# Deep-copy the shared layout so the range and title set here don't leak into
# the other figures.
figz = dict(data=[zslice], layout=copy.deepcopy(layout), frames=zframes)
figz['layout']['scene']['zaxis'].update(range=zrange)
figz['layout']['title'] = 'z-slice animation'

iplot(figz, validate=False)

GIF of the x-slice animation:

In [1]:
%%html
<img src='x-sliceg.gif'>