## Imports

In [None]:
from common import *
from surface import *
from streamlines import *

In [None]:
from ipywidgets import Layout, interact, IntSlider, FloatSlider
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
# Notebook-wide figure defaults: fixed resolution, no grid lines over images.
plt.rcParams.update({"figure.dpi": 100, "axes.grid": False})

## Plotting functions

In [None]:
def select(arr, surf_vals=(V_SB, V_ST, V_SE), region=REGION):
    """Pick the voxels of a 3D or 4D volume that lie on a surface or region.

    The voxel coordinates come from ``get_surface_indices(region, surf_vals)``.
    Note that the defaults for ``surf_vals`` and ``region`` are bound once,
    when this cell is executed.

    Returns ``(i, j, k, v)``: the voxel index arrays along the first three
    axes, and ``v = arr[i, j, k, ...]``. ``v`` holds one scalar per voxel for
    a 3D ``arr``, or one vector per voxel for a 4D ``arr``.
    """
    voxel_indices = get_surface_indices(region, surf_vals)
    # Transposing turns one row per voxel into one row per axis, so the
    # unpacking yields the i, j and k coordinate arrays.
    i, j, k = voxel_indices.T
    return i, j, k, arr[i, j, k, ...]

In [None]:
def barebone(ax, cmap='viridis'):
    """Strip a Matplotlib Axes down to a bare image panel.

    The Axes background is painted with the lowest colour of ``cmap`` and
    the x/y ticks are removed.

    Parameters
    ----------
    ax : matplotlib Axes
        The Axes to modify in place.
    cmap : str or Colormap, optional
        Colormap whose lowest colour fills the background. The default is
        'viridis', presumably chosen to match imshow's default colormap so
        that empty areas blend with the minimum value (confirm if the
        image cmap is changed).
    """
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9. pyplot.get_cmap is available on both older and newer versions.
    ax.set_facecolor(plt.get_cmap(cmap)(0))
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# One entry per panel: (title, sliced axis, image-column axis, image-row axis).
# Indexing a volume at position s along the sliced axis leaves a 2D image
# whose rows and columns are the two remaining axes, in their original order.
_PROJECTIONS = (
    ('Coronal', 0, 2, 1),
    ('Transverse', 1, 2, 0),
    ('Sagittal', 2, 1, 0),
)


def _slice_index(a, size):
    """Map a fraction ``a`` in [0, 1] to a valid index along an axis of length ``size``."""
    return int(np.clip(int(round(size * a)), 0, size - 1))


def _streamlines_near(streamlines, axis, index, window):
    """Return a mask of streamlines having any point within ``window`` of ``index`` along ``axis``."""
    coord = streamlines[:, :, axis]
    return ((index - window <= coord) & (coord <= index + window)).any(axis=1)


def plot_volume(scalar, vector=None, streamlines=None, max_streamlines=10_000,
                all_projections=True, quiver_step=3, streamline_window=2):
    """Interactively display orthogonal slices of a volume.

    A slider ``a`` in [0, 1] sets the slice position as a fraction of each
    axis length. When ``vector`` or ``streamlines`` is given, a checkbox is
    added to toggle the corresponding overlay.

    Parameters
    ----------
    scalar : array
        Volume with at least three axes; slices along its first three axes
        are shown with ``imshow``.
    vector : array of shape (n, m, p, 3), optional
        Vector field drawn as quiver arrows. Only its values at the voxels
        returned by ``select(vector)`` are used.
    streamlines : array, optional
        Streamlines drawn as white polylines. Presumably shaped
        (n_lines, n_points, 3) in voxel coordinates, since the output of
        ``subset`` is indexed as ``[line, point, axis]`` — confirm.
    max_streamlines : int
        Maximum number of streamlines passed through ``subset`` for drawing.
    all_projections : bool
        If True, show all three orthogonal slices; otherwise only the first.
    quiver_step : int
        Draw every ``quiver_step``-th arrow of a slice, to reduce clutter.
    streamline_window : int or float
        A streamline is drawn on a slice if any of its points is within
        this distance of the slice index along the sliced axis.

    Raises
    ------
    ValueError
        If ``scalar`` has fewer than 3 axes, or if ``vector`` is not 4D with
        a last axis of length 3.
    """
    if scalar.ndim < 3:
        raise ValueError(f"scalar must have at least 3 axes, got shape {scalar.shape}")
    sizes = scalar.shape[:3]

    # The nan-aware reductions give the same limits as min/max on NaN-free
    # data, but keep the colour scale usable if the volume contains NaNs.
    imshow_kwargs = dict(interpolation='none', origin='upper',
                         vmin=np.nanmin(scalar), vmax=np.nanmax(scalar))
    quiver_kwargs = dict(scale=30, width=.003, alpha=.35)
    streamlines_kwargs = dict(color='w', lw=2, alpha=.15)
    # NOTE(review): titles are drawn on the figure background rather than the
    # dark Axes, so white text is only visible with a dark or transparent
    # figure background — confirm.
    title_kwargs = dict(color='w')

    # ipywidgets.interact abbreviations: a (min, max, step) tuple becomes a
    # float slider and a bool becomes a checkbox. Overlay toggles are only
    # added when there is something to overlay.
    controls = dict(a=(0.0, 1.0, 0.01))

    if vector is not None:
        if vector.ndim != 4 or vector.shape[3] != 3:
            raise ValueError(f"vector must have shape (n, m, p, 3), got {vector.shape}")
        i, j, k, v = select(vector)
        coords = (i, j, k)
        controls['show_vector'] = True

    if streamlines is not None:
        streamlines_subset = subset(streamlines, max_streamlines)
        controls['show_streamlines'] = True

    projections = _PROJECTIONS if all_projections else _PROJECTIONS[:1]

    # **values (rather than named parameters) lets interact build widgets
    # only for the keys present in `controls`.
    @interact(**controls)
    def f(**values):
        a = values['a']
        show_vector = values.get('show_vector', False)
        show_streamlines = values.get('show_streamlines', False)

        # squeeze=False always returns a 2D array of Axes, so the
        # single-panel case needs no special handling.
        _, axes = plt.subplots(1, len(projections), figsize=(18, 12), squeeze=False)

        for ax, (title, axis, col, row) in zip(axes[0], projections):
            s = _slice_index(a, sizes[axis])

            ax.imshow(np.take(scalar, s, axis=axis), **imshow_kwargs)
            barebone(ax)
            ax.set_title(title, **title_kwargs)

            if show_vector:
                idx = np.nonzero(coords[axis] == s)[0][::quiver_step]
                # origin='upper' flips the y axis, while quiver's default
                # angles='uv' draws arrows in screen space. The row component
                # is therefore negated.
                ax.quiver(coords[col][idx], coords[row][idx],
                          v[idx, col], -v[idx, row], **quiver_kwargs)

            if show_streamlines:
                near = _streamlines_near(streamlines_subset, axis, s, streamline_window)
                ax.plot(streamlines_subset[near, :, col].T,
                        streamlines_subset[near, :, row].T, **streamlines_kwargs)
  

## Loading arrays

In [None]:
# Precomputed arrays for REGION, loaded through the project's load_npy/filepath
# helpers (same call order as before: mask, normal, laplacian, gradient,
# streamlines).
mask = load_npy(filepath(REGION, 'mask'))
normal = get_normal(REGION)
laplacian, gradient, streamlines = (
    load_npy(filepath(REGION, name))
    for name in ('laplacian', 'gradient', 'streamlines')
)

## Plots

In [None]:
# Laplacian slices with the loaded streamlines overlaid (no vector field).
plot_volume(scalar=laplacian, streamlines=streamlines)