### Set-up

To run this example, you will need at least the following libraries: `numpy`, `scipy`, `dask`, `distributed`, `matplotlib`, and `fish`. All except `fish` can be installed via `conda`, e.g. `conda install dask`.

Unfortunately I haven't made `fish` installable with `pip` or `conda`, so for now you have to clone it from GitHub using `git clone https://github.com/d-v-b/fish.git` and add the path to the cloned folder to your `PYTHONPATH` environment variable.

I use `flika` for visualizing 3D data plane-by-plane. It can be installed via `pip install flika`. `flika` needs a desktop environment, e.g. a NoMachine session; attempting to use it without one will likely crash the notebook.

Make sure you have `dask-drmaa` or `dask-jobqueue` installed; if not, run `pip install dask-drmaa` or `pip install dask-jobqueue`. Either library can request workers dynamically; in this example, I use `dask-jobqueue` through a convenience wrapper that can be found here: https://github.com/d-v-b/fish/blob/master/fish/util/distributed.py#L14

You will also need some environment variables set. Put this code block in your `~/.bash_profile` file, then run `source ~/.bash_profile`:
```bash
# Export LSF variables, if available.
# May not be available when using Linux locally or Windows with Git Bash.
if [[ -f /misc/lsf/conf/profile.lsf ]]; then
    source /misc/lsf/conf/profile.lsf
    export LSB_STDOUT_DIRECT='Y'
    export LSB_JOB_REPORT_MAIL='N'
    export LSF_DRMAA_LIBRARY_PATH=/misc/sc/lsf-glibc2.3/lib/libdrmaa.so.0.1.1
    export DRMAA_LIBRARY_PATH=$LSF_DRMAA_LIBRARY_PATH
fi
```

See https://github.com/d-v-b/bash_profile-janelia/blob/master/.bash_profile for an example `.bash_profile`.
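A Jupyter kernel inherits its environment from the process that launched it, so a notebook server started before you ran `source` will not see these variables. A small sketch for checking from inside the notebook; it only tests the two DRMAA variable names exported by the snippet above, and the helper name `missing_lsf_vars` is hypothetical:

```python
import os

# Variable names copied from the ~/.bash_profile snippet above
LSF_VARS = ("LSF_DRMAA_LIBRARY_PATH", "DRMAA_LIBRARY_PATH")

def missing_lsf_vars(env=os.environ):
    """Return the expected variable names that are unset or empty in `env`."""
    return [name for name in LSF_VARS if not env.get(name)]

missing = missing_lsf_vars()
if missing:
    print("Not set in this kernel:", ", ".join(missing))
else:
    print("DRMAA variables are set.")
```

If anything is reported missing, restart the notebook server from a shell where `~/.bash_profile` has been sourced.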

In [None]:
# imports used throughout the notebook
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import dask.array as da
from scipy.ndimage import median_filter

# start flika for plane-by-plane viewing (requires a desktop environment)
import flika as flk
flk.start_flika()
from flika.window import Window as flw

In [None]:
# A parallelizable function that applies a 2D (y, x) shift to each block of a 4D (t, z, y, x) array.
# We use it to apply the registration results. It assumes one timepoint per chunk along axis 0,
# so the block index along axis 0 is the timepoint index into `shifts`.
def shift_yx(im, shifts, block_id=None):
    from scipy.ndimage import shift
    t = block_id[0]
    # order=1: linear interpolation; pixels shifted in from outside the frame are filled with 100
    return shift(im.astype('float32'), (0, 0, *shifts[t]), order=1, cval=100)

# A parallelizable function that estimates ΔF/F along axis 0 (time) of an array
def mydff(v, fs_im):
    from fish.image.vol import dff
    camera_offset = 80      # camera offset, subtracted before computing ΔF/F
    window = 300 * fs_im    # baseline window in frames (300 * volume rate)
    percentile = 20
    offset = 10
    downsample = 10
    return dff((v - camera_offset).clip(1, None), window, percentile, offset, downsample, axis=0).astype('float32')
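To see how `shift_yx` picks the right shift for each block, here is a small sketch on synthetic data (the array sizes and shift values are made up for illustration). It relies on the same assumption as the notebook: one timepoint per chunk along axis 0, so `block_id[0]` is the timepoint index. Here `shifts` is passed as a keyword, which makes it explicit that it is an ordinary argument handed to every block rather than a dask array aligned blockwise.

```python
import numpy as np
import dask.array as da
from scipy.ndimage import shift

def shift_yx(im, shifts, block_id=None):
    # dask fills in block_id because the function declares it; axis-0 index == timepoint
    t = block_id[0]
    return shift(im.astype('float32'), (0, 0, *shifts[t]), order=1, cval=100)

# Synthetic (t, z, y, x) stack: 3 timepoints, 1 plane, 8x8 pixels, one bright pixel per frame
frames = np.zeros((3, 1, 8, 8), dtype='uint16')
frames[:, 0, 4, 4] = 1000
data = da.from_array(frames, chunks=(1, -1, -1, -1))  # one timepoint per chunk

# One (dy, dx) pair per timepoint
shifts = np.array([[0, 0], [1, 0], [0, -2]], dtype='float32')

moved = data.map_blocks(shift_yx, shifts=shifts, dtype='float32').compute(scheduler='threads')
```

The bright pixel stays at (4, 4) in frame 0, moves down one row in frame 1, and moves two columns left in frame 2. If the raw data were chunked with several timepoints per chunk, `block_id[0]` would no longer equal the timepoint and the shifts would be misapplied.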

### Set up the data and connect to the distributed scheduler

In [None]:
# path to raw data
from fish.image.zds import ZDS
base_dir = '/nrs/ahrens/davis/data/spim/raw/20160608/6dpf_cy171xcy221_f1_omr_1_20160608_170933/'

# make a ZDS with the path to raw data.
dset = ZDS(base_dir)
fs_im = dset.metadata['volume_rate']
sample = dset.data[0].compute(scheduler='threads')
plt.imshow(sample.max(0), cmap='gray', clim=(100,250))
print(dset)

In [None]:
# the zds has a `data` property: a dask array with one chunk per file
# for this demo, crop in time (last 500 timepoints) and space (planes 20-29) using slice objects
roi = slice(-500, None), slice(20,30), slice(None), slice(None)
data = dset.data[roi]
print(data)

In [None]:
from fish.util.distributed import get_jobqueue_cluster
from dask.distributed import Client

# to do distributed computation, dask needs an object that lets it talk to the Janelia compute cluster.
# compute clusters use job scheduler software to give users resources; at Janelia, that software is
# LSF, and two dask libraries can bridge dask and LSF: dask-jobqueue and dask-drmaa.
# here I'm using dask-jobqueue via a wrapper I wrote that uses good Janelia-specific default settings
cluster = get_jobqueue_cluster()

# instantiate a dask.distributed.Client object with the cluster object
client = Client(cluster)

# once created, the client registers itself with dask as the default scheduler, overriding 'threads'.
# so calling dask_array.compute() with no scheduler specified will use the distributed scheduler,
# even if no workers have been requested (in which case the computation waits indefinitely).
# we add workers with cluster.start_workers() and remove all of them with cluster.stop_all_jobs()
client
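To try the pipeline without LSF, e.g. on a workstation, a `LocalCluster` from `dask.distributed` can stand in for `get_jobqueue_cluster()`. A minimal sketch (the worker and thread counts are arbitrary); note that `LocalCluster` is resized with `cluster.scale(n)`, which in recent `dask-jobqueue` versions is also the usual way to request workers, rather than the `start_workers`/`stop_all_jobs` methods used in this notebook:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Workers run as threads in this process; no job scheduler involved.
# dashboard_address=None disables the diagnostics dashboard (avoids port clashes).
cluster = LocalCluster(n_workers=1, threads_per_worker=2, processes=False, dashboard_address=None)
client = Client(cluster)  # becomes the default scheduler, as with the LSF cluster

x = da.ones((1000, 1000), chunks=(250, 250))
total = x.sum().compute()  # runs on the local distributed scheduler

client.close()
cluster.close()
```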

### Distributed motion correction

1. Generate a reference image 
2. Make a lazy version of a function that estimates a transform to align two images
3. Apply that lazy function to all images in the dataset
4. Examine (and modify) estimated transform parameters
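The steps above hinge on estimating a translation between each frame and the reference. The implementation of `fish.image.alignment.estimate_translation` is not shown here, so, as an illustration only, this is a minimal phase-correlation sketch (a common technique swapped in for explanation, not necessarily what `fish` does) that recovers integer-pixel (dy, dx) offsets with NumPy FFTs. The function name `estimate_shift_phase_corr` is hypothetical.

```python
import numpy as np

def estimate_shift_phase_corr(ref, mov):
    """Estimate the integer (dy, dx) translation of `mov` relative to `ref`."""
    f_ref = np.fft.fft2(ref)
    f_mov = np.fft.fft2(mov)
    # Normalized cross-power spectrum: keep only the phase difference
    cross_power = f_mov * np.conj(f_ref)
    cross_power /= np.abs(cross_power) + 1e-12
    # Its inverse FFT peaks at the translation
    corr = np.fft.ifft2(cross_power).real
    peak = np.array(np.unravel_index(np.argmax(corr), corr.shape))
    # Peaks past the midpoint wrap around and correspond to negative shifts
    dims = np.array(corr.shape)
    wrap = peak > dims // 2
    peak[wrap] -= dims[wrap]
    return peak

rng = np.random.default_rng(0)
ref = rng.random((64, 64))
mov = np.roll(ref, shift=(3, -5), axis=(0, 1))  # known displacement
est = estimate_shift_phase_corr(ref, mov)
```

To correct `mov`, shift it by the negative of the estimate. Sub-pixel precision would require refining the estimate around the correlation peak.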

In [None]:
from scipy.ndimage import median_filter

# make a median-filtered version of the raw data
# map_blocks applies the function to each 4D (t, z, y, x) chunk, so the filter size must be 4D;
# (1, 1, 5, 5) filters only within each xy plane
data_filt = data.astype('float32').map_blocks(lambda v: median_filter(v, (1,1,5,5)))

# average 10 timepoints around the middle of the recording to form a reference image for registration
# the Client is now the default scheduler and no workers have been requested yet,
# so we explicitly use the local 'threads' scheduler here
anat_ref = data[data.shape[0]//2 + np.arange(-5,5)].compute(scheduler='threads').mean(0)

In [None]:
# visualize the reference image
flw(anat_ref)

In [None]:
# Here I import a function for estimating translation
from fish.image.alignment import estimate_translation
# Here I import a function that makes other functions lazy
from dask import delayed

# make a lazy version of my registration function
lazyreg = delayed(estimate_translation)

# make a lazy version of the reference image
ref_mx = da.from_array(anat_ref.max(0), chunks=(-1,-1))

# make a list of lazy registration calculations
affs = [lazyreg(ref_mx, mx) for mx in data_filt.max(1)]
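`delayed` wraps a function so that calling it records a task instead of running it. Here is a toy sketch of the same pattern, with a hypothetical `estimate_offset` standing in for `estimate_translation` and computed with the local threaded scheduler; `client.compute(affs, sync=True)` in the next cell plays the same role on the cluster.

```python
import numpy as np
import dask
from dask import delayed

def estimate_offset(ref, mov):
    # Hypothetical stand-in for a registration function: difference of mean intensities
    return float(mov.mean() - ref.mean())

ref = np.zeros((4, 4))
frames = [np.full((4, 4), float(t)) for t in range(5)]

lazy_estimate = delayed(estimate_offset)
tasks = [lazy_estimate(ref, frame) for frame in frames]  # builds the task graph only

# Run all tasks; returns a tuple with one result per task, in order
results = dask.compute(*tasks, scheduler='threads')
```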

In [None]:
%%time
# Get 300 workers to estimate the transform parameters. Ideally we only do this once, save the transform parameters
# and load those parameters from disk each time we need to register data.
cluster.start_workers(300)
reg_result = client.compute(affs, sync=True)
shifts = -np.array([r.affine[:-1,-1] for r in reg_result])
cluster.stop_all_jobs()
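The comment above suggests estimating the shifts once and reloading them later. A minimal sketch with `np.save`/`np.load`; the shift values, the file name `shifts.npy`, and the temporary directory are placeholders, and in practice you would write the file next to the dataset:

```python
import os
import tempfile
import numpy as np

def save_shifts(path, shifts):
    # .npy preserves dtype and shape exactly
    np.save(path, np.asarray(shifts))

def load_shifts(path):
    return np.load(path)

# Placeholder (dy, dx) values, one row per timepoint
shifts = np.array([[0.5, -1.0], [1.25, 0.0], [-0.75, 2.0]], dtype='float32')
path = os.path.join(tempfile.mkdtemp(), 'shifts.npy')
save_shifts(path, shifts)
restored = load_shifts(path)
```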

In [None]:
fig, axs = plt.subplots(figsize=(12,4), dpi=200)
shifts_filt = median_filter(shifts, size=(10,1))
axs.plot(shifts)
axs.plot(shifts_filt, color='k')
axs.legend(['dy','dx'])
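A median filter over time suppresses single-frame registration failures while keeping slow drift. A sketch on synthetic traces (the drift and outlier values are made up) shows the effect of `size=(10, 1)`, which filters along time only and never mixes the dy and dx columns:

```python
import numpy as np
from scipy.ndimage import median_filter

# Synthetic (dy, dx) traces: slow drift in dy, flat dx
t = np.arange(100)
shifts = np.stack([0.01 * t, np.zeros_like(t, dtype=float)], axis=1)
shifts[50] = [8.0, -6.0]  # single-frame outlier

# 10-sample window along axis 0 (time), width 1 along axis 1 (dy/dx)
shifts_filt = median_filter(shifts, size=(10, 1))
```

At frame 50 the filtered dy returns to the drift value of about 0.5 and dx returns to 0.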

### Apply registration, downscale, transpose, and estimate Δf/f

1. Apply motion correction
2. Apply a median filter and downscale in xy dimensions
3. Rechunk the data so that each chunk contains all timepoints for a few pixels
4. Apply Δf/f estimation
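For intuition about the Δf/f step, here is a simplified sketch of a running-percentile baseline along time. It is a stand-in, not `fish.image.vol.dff`: it assumes `percentile` is the baseline percentile, `window` the baseline length in frames, and `offset` a constant added to the denominator, and it does not reproduce the `downsample` option. The name `dff_sketch` is hypothetical.

```python
import numpy as np
from scipy.ndimage import percentile_filter

def dff_sketch(v, window, percentile=20, offset=10):
    """ΔF/F along axis 0, using a running-percentile baseline."""
    v = np.asarray(v, dtype='float32')
    size = [1] * v.ndim
    size[0] = int(window)  # filter along time only
    baseline = percentile_filter(v, percentile, size=size, mode='nearest')
    return (v - baseline) / (baseline + offset)

trace = np.full((200, 1), 100.0)  # flat fluorescence at 100
trace[100] = 210.0                # single-frame transient
out = dff_sketch(trace, window=21)
```

A low percentile keeps the baseline near the resting level, so brief transients barely move it: here the baseline stays at 100 and the transient gives (210 - 100) / (100 + 10) = 1.0.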


In [None]:
# shift each timepoint and apply a median filter
data_txf = data.map_blocks(shift_yx, shifts, dtype='float32').map_blocks(lambda v: median_filter(v, size=(1,1,3,3)))

# reduce data size 16-fold by averaging 4x4 pixel blocks in y and x
ds_xy = 4
data_ds = da.coarsen(np.mean, data_txf, {2: ds_xy, 3: ds_xy})

# now rechunk the data so that each chunk contains an entire timeseries. 
rechunked = data_ds.rechunk(chunks=(-1, 'auto', 'auto','auto'))
data_dff = rechunked.map_blocks(lambda v: mydff(v, fs_im=fs_im), dtype='float32')
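A small sketch of the downscale-then-rechunk steps on a synthetic (t, z, y, x) stack (sizes are arbitrary) shows what `da.coarsen` and the rechunk produce. Note that `da.coarsen` needs chunk sizes along the coarsened axes to be divisible by the factor; the notebook's one-frame-per-chunk data satisfies that only if the image height and width are multiples of 4.

```python
import numpy as np
import dask.array as da

# Synthetic stack, one timepoint per chunk like the raw data
x = np.arange(6 * 2 * 8 * 8, dtype='float32').reshape(6, 2, 8, 8)
arr = da.from_array(x, chunks=(1, -1, -1, -1))

ds_xy = 4
arr_ds = da.coarsen(np.mean, arr, {2: ds_xy, 3: ds_xy})  # mean over 4x4 blocks in y, x

# -1 along time: each chunk holds the full timeseries for its pixels
rechunked = arr_ds.rechunk((-1, 'auto', 'auto', 'auto'))

# Same block means computed directly with NumPy, for comparison
expected = x.reshape(6, 2, 2, ds_xy, 2, ds_xy).mean(axis=(3, 5))
result = rechunked.compute(scheduler='threads')
```

Having the whole time axis in one chunk is what lets the Δf/f function see every timepoint of a pixel at once.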

In [None]:
%%time
cluster.start_workers(100)
# for this example, I'm taking the max projection in Z to keep the result small
result = data_dff.max(1).compute()
cluster.stop_all_jobs()

### Visualize the result

In [None]:
# plot the peak-weighted peak times
from fish.util.plot import depth_project
fig, axs = plt.subplots(dpi=200, figsize=(8,8),ncols=2, gridspec_kw={'width_ratios':(20,1)})
cmap='rainbow'
axs[0].imshow(depth_project(result, clim=(.1,1), mode='max', cmap=cmap)[:,:,:-1])
axs[1].imshow(np.arange(result.shape[0]).reshape(-1,1), cmap=cmap, extent=(0,result.shape[0]//20,0,result.shape[0]))
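One way to read a "peak-weighted peak time" image: the hue encodes when each pixel reached its maximum and the brightness encodes how large that maximum was. The sketch below implements that reading; it is a hypothetical `time_colored_max`, not the implementation of `fish.util.plot.depth_project`.

```python
import numpy as np
import matplotlib.pyplot as plt

def time_colored_max(stack, clim, cmap='rainbow'):
    """RGB image: hue = frame index of each pixel's peak, brightness = peak value scaled by clim."""
    stack = np.asarray(stack, dtype=float)
    peak_val = stack.max(axis=0)
    peak_t = stack.argmax(axis=0)
    # Map peak times to [0, 1] and look up colors; drop the alpha channel
    colors = plt.get_cmap(cmap)(peak_t / max(stack.shape[0] - 1, 1))[..., :3]
    lo, hi = clim
    brightness = np.clip((peak_val - lo) / (hi - lo), 0, 1)
    return colors * brightness[..., None]
```

With the notebook's `result`, usage would look like `axs[0].imshow(time_colored_max(result, clim=(.1, 1)))`.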

In [None]:
flw(result)