# Neuroscience using `fastplotlib` and `pynapple`

This notebook will build up a complex visualization using `fastplotlib`, in conjunction with `pynapple`, to show how `fastplotlib` can be a powerful tool in analysis and visualization of neural data!

In [1]:
import warnings
warnings.simplefilter('ignore')

In [2]:
# scikit-image provides `measure.find_contours`, used below to outline the ROIs; install it if it is missing
! pip install scikit-image



In [3]:
import fastplotlib as fpl
import pynapple as nap
import numpy as np
from ipywidgets import IntSlider, Layout, VBox, HBox, FloatSlider
from skimage import measure
from sidecar import Sidecar

Available devices:
🯄 (default) | Intel(R) Arc(tm) Graphics (MTL) | IntegratedGPU | Vulkan | Mesa 24.0.8-1
❗ | llvmpipe (LLVM 17.0.6, 256 bits) | CPU | Vulkan | Mesa 24.0.8-1 (LLVM 17.0.6)
❗ | Mesa Intel(R) Arc(tm) Graphics (MTL) | IntegratedGPU | OpenGL | 


In [4]:
import warnings
warnings.simplefilter('ignore')

## Load the data 

#### Recording of a freely moving mouse imaged with a Miniscope (1-photon imaging). The recorded area is the postsubiculum, a region known to contain head-direction cells: cells that fire when the animal's head points in a specific direction.

In [5]:
data = nap.load_file("./data.nwb")

In [6]:
data

data
┍━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys                  │ Type        │
┝━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ position_time_support │ IntervalSet │
│ RoiResponseSeries     │ TsdFrame    │
│ calcium_video         │ TsdTensor   │
│ beh_video             │ TsdTensor   │
│ z                     │ Tsd         │
│ y                     │ Tsd         │
│ x                     │ Tsd         │
│ rz                    │ Tsd         │
│ ry                    │ Tsd         │
│ rx                    │ Tsd         │
┕━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙

### Let's view the behavior and calcium data

Hopefully, by the end of the summer we will have developed a tool ([`pynaviz`](https://github.com/pynapple-org/pynaviz)) that makes these visualizations and synchronizations even easier :D

In [7]:
# behavior shape
behavior_data = data["beh_video"]
behavior_data.shape

(9045, 204, 256)

In [8]:
# calcium shape
calcium_data = data["calcium_video"]
calcium_data.shape

(17886, 136, 166)

#### Restrict our view of the data to the period when both behavior and position data are available:

In [9]:
frame_min = data["position_time_support"]["start"][0]
frame_max = data["position_time_support"]["end"][0]
(frame_min, frame_max)

(7.39305, 1213.22765)
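`frame_min` and `frame_max` are times in seconds, not frame indices. As a minimal sketch of what restricting to this window means, assuming each frame has a sorted timestamp in seconds, the matching frame indices can be found with `np.searchsorted`. The timestamps below are synthetic (a hypothetical 30 Hz clock); with the real data, pynapple time series already carry their timestamps and can be restricted to an `IntervalSet` with their `restrict` method.

```python
import numpy as np

def frames_in_window(timestamps, start, end):
    """Return the indices of frames whose timestamps fall inside [start, end].

    Assumes `timestamps` is sorted in ascending order (seconds).
    """
    first = np.searchsorted(timestamps, start, side="left")
    last = np.searchsorted(timestamps, end, side="right")
    return np.arange(first, last)

# Hypothetical 30 Hz clock covering 0 to 10 s (not the real recording timestamps)
timestamps = np.arange(300) / 30.0
idx = frames_in_window(timestamps, 2.0, 3.0)
print(idx[0], idx[-1], len(idx))
```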

### Create a plot for calcium and behavior video

In [10]:
nap_figure = fpl.Figure(shape=(1,2), names=[["raw", "behavior"]])

calcium_graphic = nap_figure["raw"].add_image(data=calcium_data[0], name="raw_frame", cmap="viridis")
behavior_graphic = nap_figure["behavior"].add_image(data=behavior_data[0], cmap="gray")


#### Create a slider that updates the behavior and calcium videos using `pynapple`

In [11]:
# The slider value is the time in seconds; IntSlider needs integer bounds, so the float limits are truncated
synced_time = IntSlider(min=int(frame_min), max=int(frame_max), step=1, description="s", layout=Layout(width="60%"))

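#### Create a `TimeStore` object that will synchronize our data

`TimeStore` is not imported or defined anywhere in this notebook, so this cell raises the `NameError` shown below; it has to be defined or imported before the cell is run.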

In [12]:
# create a TimeStore
time_store = TimeStore()

# subscribe our slider and calcium/behavior data to the store to be updated
time_store.subscribe(synced_time)
time_store.subscribe(calcium_graphic, calcium_data)
time_store.subscribe(behavior_graphic, behavior_data)

NameError: name 'TimeStore' is not defined
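The `TimeStore` used here is not shown in this notebook. The idea behind such a synchronizer can be sketched in plain Python: keep one current time in seconds and, whenever it changes, convert it to the nearest frame index of every subscribed dataset and notify that subscriber. `MiniTimeStore`, `subscribe`, `set_time`, and the 15 Hz / 30 Hz frame rates below are all hypothetical, not the interface of the real `TimeStore`; the callbacks stand in for updating a slider or an image graphic.

```python
import numpy as np

class MiniTimeStore:
    """Hypothetical sketch of a time synchronizer (not the real TimeStore API)."""

    def __init__(self):
        self.time = 0.0
        self._subscribers = []  # list of (callback, timestamps or None)

    def subscribe(self, callback, timestamps=None):
        # timestamps: sorted frame times in seconds; None means the
        # subscriber receives the raw time value (e.g. a slider)
        ts = None if timestamps is None else np.asarray(timestamps, dtype=float)
        self._subscribers.append((callback, ts))

    def set_time(self, t):
        self.time = float(t)
        for callback, ts in self._subscribers:
            if ts is None:
                callback(self.time)
            else:
                callback(self.nearest_index(ts, self.time))

    @staticmethod
    def nearest_index(ts, t):
        # index of the timestamp closest to t (ties go to the earlier frame)
        i = int(np.searchsorted(ts, t))
        if i == 0:
            return 0
        if i == len(ts):
            return len(ts) - 1
        return i if ts[i] - t < t - ts[i - 1] else i - 1

store = MiniTimeStore()
received = {}
store.subscribe(lambda t: received.update(slider=t))
store.subscribe(lambda i: received.update(calcium=i), np.arange(100) / 15.0)   # hypothetical 15 Hz
store.subscribe(lambda i: received.update(behavior=i), np.arange(100) / 30.0)  # hypothetical 30 Hz
store.set_time(2.0)
print(received)
```

Because the two videos have different numbers of frames over the same recording, one shared time in seconds, rather than a shared frame index, is what keeps them aligned.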

**Here we are going to use `sidecar` to organize our visualizations better :D**

In [13]:
sc = Sidecar()
with sc:
    display(VBox([nap_figure.show(), synced_time]))

In [14]:
# manually set the vmin/vmax of the calcium data
nap_figure["raw"]["raw_frame"].vmax = 205
nap_figure["raw"]["raw_frame"].vmin = 25
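Setting `vmin` and `vmax` sets the contrast limits used when mapping pixel values onto the colormap. Conceptually (this is a sketch of standard linear contrast scaling, not fastplotlib's rendering code), values at or below `vmin` map to the bottom of the colormap, values at or above `vmax` map to the top, and values in between are scaled linearly:

```python
import numpy as np

def normalize_contrast(frame, vmin, vmax):
    """Linearly rescale `frame` to [0, 1] using the contrast limits vmin/vmax."""
    scaled = (np.asarray(frame, dtype=float) - vmin) / (vmax - vmin)
    # anything outside the limits saturates at the ends of the colormap
    return np.clip(scaled, 0.0, 1.0)

frame = np.array([0, 25, 115, 205, 255])
print(normalize_contrast(frame, vmin=25, vmax=205))
```

With `vmin = 25` the dim background collapses to the bottom of the colormap, and with `vmax = 205` every pixel at or above that value saturates.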

#### Calculate the spatial contours and overlay them on the raw calcium data

In [15]:
# get the masks
contour_masks = data.nwb.processing['ophys']['ImageSegmentation']['PlaneSegmentation']['image_mask'].data[:]
# reshape the masks into a list of 2D arrays, one 166 x 136 mask per component (105 in total)
contour_masks = list(contour_masks.reshape((len(contour_masks), 166, 136)))

In [16]:
# calculate each contour from the mask using `scikit-image.measure`
contours = list()

for mask in contour_masks:
    contours.append(np.vstack(measure.find_contours(mask)))
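`measure.find_contours` traces sub-pixel iso-lines with the marching squares method, and a single mask can yield several contour pieces, which `np.vstack` joins into one polyline. If scikit-image is unavailable, a rough, dependency-free substitute (a different technique, shown only for illustration) is to mark the mask pixels that touch at least one 4-connected neighbour outside the mask:

```python
import numpy as np

def boundary_pixels(mask):
    """Return (row, col) coordinates of mask pixels on the mask's edge.

    A pixel counts as boundary if it is inside the mask and at least one of its
    4-connected neighbours is outside (pixels beyond the array edge count as outside).
    """
    m = np.asarray(mask, dtype=bool)
    padded = np.pad(m, 1, constant_values=False)
    # a pixel is interior only if it and all four neighbours are inside the mask
    interior = (
        padded[1:-1, 1:-1]
        & padded[:-2, 1:-1]   # neighbour above
        & padded[2:, 1:-1]    # neighbour below
        & padded[1:-1, :-2]   # neighbour to the left
        & padded[1:-1, 2:]    # neighbour to the right
    )
    return np.argwhere(m & ~interior)

mask = np.zeros((6, 6))
mask[1:5, 1:5] = 1  # a 4 x 4 square ROI
edge = boundary_pixels(mask)
print(len(edge))
```

The points come back in raster order rather than as a traced path, so they suit a scatter overlay better than a line.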

#### Add the calculated contours as an overlay to the calcium video

In [17]:
contours_graphic = nap_figure["raw"].add_line_collection(data=contours, colors="w")

**It is easy to see that many of the identified ROIs are likely "bad" candidates. Let's remove them from the dataset as we go on with our analysis.**

### Select only head-direction neurons

In [18]:
# get the temporal data (calcium transients) from the NWB file
temporal_data = data["RoiResponseSeries"][:]
temporal_data

Time (s)           0        1        2        3        4  ...
-----------  -------  -------  -------  -------  -------  -----
0.0          0        0.43582  2.96331  0        0        ...
0.033333     0        0.43406  2.95294  0        0        ...
0.066667     0        0.43231  2.9426   0        0        ...
0.1          0        0.43057  2.93231  0        0        ...
0.133333     0        0.42883  2.92205  0        0        ...
0.166667     0        0.4271   2.91182  0        0        ...
0.2          0        0.42537  2.90163  0        0        ...
...
1192.166667  2.54202  0.14531  0.44013  0.5681   0.65477  ...
1192.2       2.53029  0.14775  0.43842  0.56657  0.65227  ...
1192.233333  2.51861  0.14962  0.43671  0.56505  0.64979  ...
1192.266667  2.50698  0.15104  0.435    0.56354  0.64731  ...
1192.3       2.49541  0.15209  0.43331  0.56202  0.64485  ...
1192.333333  2.48389  0.15283  0.43162  0.58476  0.64239  ...
1192.366667  2.47242  0.15333  0.42994  0.62802  0.63994  ...

In [19]:
# compute 1D tuning curves based on head angle
head_angle = data["ry"]

tuning_curves = nap.compute_1d_tuning_curves_continuous(temporal_data, head_angle, nb_bins = 120)
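A 1D tuning curve for a continuous signal is, in essence, the mean activity of each component within each bin of the behavioral variable. The sketch below illustrates that idea with NumPy on synthetic data, assuming the signal and the feature are already sampled at the same time points; it is not pynapple's implementation of `compute_1d_tuning_curves_continuous`, which works directly on timestamped pynapple objects (see the pynapple documentation for its options).

```python
import numpy as np

def tuning_curves_1d(signal, feature, nb_bins):
    """Mean of each column of `signal` within equal-width bins of `feature`.

    signal: (n_samples, n_components); feature: (n_samples,), aligned to the
    same time points as `signal`. Empty bins are NaN.
    """
    signal = np.asarray(signal, dtype=float)
    feature = np.asarray(feature, dtype=float)
    edges = np.linspace(feature.min(), feature.max(), nb_bins + 1)
    # assign each sample to a bin; clip so feature.max() lands in the last bin
    bin_ix = np.clip(np.digitize(feature, edges) - 1, 0, nb_bins - 1)
    curves = np.full((nb_bins, signal.shape[1]), np.nan)
    for b in range(nb_bins):
        in_bin = bin_ix == b
        if in_bin.any():
            curves[b] = signal[in_bin].mean(axis=0)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, curves

# Synthetic example: one component that responds most strongly near angle 0
angle = np.linspace(-np.pi, np.pi, 1000)
signal = np.exp(np.cos(angle) - 1)[:, None]
centers, curves = tuning_curves_1d(signal, angle, nb_bins=9)
print(centers[np.nanargmax(curves[:, 0])])
```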

#### Select the top 50 components (largest tuning-curve range)

In [20]:
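# rank components by the peak-to-peak range of their tuning curves (ascending)
order = np.argsort(np.ptp(tuning_curves, axis=0))
# select good components: the 50 with the largest range; the rest are "bad"
good_ixs = list(order[-50:])
bad_ixs = list(order[:-50])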
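The selection above ranks components by the peak-to-peak range (maximum minus minimum) of their tuning curves: a component whose mean activity changes strongly with head angle has a large range. `np.argsort` sorts in ascending order, so the last `k` indices are the top `k`. A toy version with hypothetical values and `k = 2`:

```python
import numpy as np

# rows = angle bins, columns = components (toy values)
toy_curves = np.array([
    [0.1, 0.0, 2.0, 0.5],
    [0.2, 1.5, 0.1, 0.4],
    [0.1, 0.1, 0.2, 0.6],
])
k = 2
ranges = np.ptp(toy_curves, axis=0)   # max - min for each column
order = np.argsort(ranges)            # ascending: weakest tuning first
good_ixs = list(order[-k:])           # the k most strongly tuned components
bad_ixs = list(order[:-k])            # everything else
print(ranges, good_ixs, bad_ixs)
```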

#### Color the "good" and "bad" components

In [21]:
contours_graphic[good_ixs].colors = "green"
contours_graphic[bad_ixs].colors = "red"

### Remove the "bad" neurons

In [22]:
# sorting the "good" neurons based on preferred directions
sorted_ixs = tuning_curves.iloc[:,good_ixs].idxmax().sort_values().index.values

In [23]:
sorted_ixs

array([75, 34, 77, 86, 21, 16,  6,  4, 58, 44, 14, 33, 94, 98, 90, 76,  7,
        5, 82, 28, 15, 88, 45, 39,  0,  8, 20, 13, 24, 60, 18, 27, 10, 78,
        2, 85,  3, 19, 38, 17, 30, 29, 25, 84, 12, 26, 41,  9, 11,  1])
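`idxmax()` returns, for each component, the index label (here, the head-angle bin) at which its tuning curve peaks, i.e. its preferred direction; sorting those values orders the components around the circle. The same idea in plain NumPy, with hypothetical bin centers and curves:

```python
import numpy as np

bin_centers = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # hypothetical angle bins (rad)
# rows = angle bins, columns = components
curves = np.array([
    [0.0, 0.9, 0.1],
    [0.1, 0.2, 0.2],
    [0.2, 0.0, 0.3],
    [0.9, 0.1, 0.4],
    [0.3, 0.0, 0.8],
])
preferred = bin_centers[np.argmax(curves, axis=0)]  # peak angle of each component
sorted_components = np.argsort(preferred, kind="stable")
print(preferred, sorted_components)
```

In the notebook, `.index.values` on the sorted pandas Series returns the original component labels, whereas this sketch returns positions into `curves`.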

In [24]:
# filter the dataset based on the sorted indices
temporal_data = temporal_data[:,sorted_ixs]
contours = [contours[i] for i in sorted_ixs]

### Plot only the "good" components

In [25]:
# only plot the good indices 
nap_figure[0,0].remove_graphic(contours_graphic)
contours_graphic = nap_figure[0,0].add_line_collection(data=contours, colors="w")

## Make a plot of the calcium traces as a `LineStack`

In [124]:
# create a figure
tstack_fig = fpl.Figure(shape=(2,1))

In [125]:
# transpose the temporal data so that its shape is (n_components, n_timepoints)
raw_temporal = temporal_data.to_numpy().T

# use 'hsv' colormap to represent preferred head direction 
tstack_graphic = tstack_fig[0,0].add_line_stack(data=raw_temporal, cmap="hsv", name="temporal-stack")
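Because the components are now ordered by preferred direction, a cyclic colormap such as `hsv` (whose first and last colors are both red) lets neighboring colors stand for neighboring angles, with no seam where the angle wraps around. A minimal sketch of mapping an angle onto an HSV hue with the standard-library `colorsys` module, assuming angles in radians (this illustrates the idea and is not how fastplotlib assigns colormap colors internally):

```python
import colorsys
import math

def angle_to_rgb(angle):
    """Map an angle in radians to an RGB tuple on the HSV color wheel."""
    hue = (angle % (2 * math.pi)) / (2 * math.pi)  # wrap to [0, 1)
    return colorsys.hsv_to_rgb(hue, 1.0, 1.0)

print(angle_to_rgb(0.0))
print(angle_to_rgb(2 * math.pi / 3))
# -pi and pi are the same direction, so they get the same color
print(angle_to_rgb(-math.pi), angle_to_rgb(math.pi))
```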

#### Add a `LinearSelector` that we can map to our behavior and calcium videos

In [126]:
tstack_selector = tstack_graphic.add_linear_selector()

In [127]:
# subscribe selector to timestore
time_store.subscribe(tstack_selector, temporal_data.rate)

NameError: name 'time_store' is not defined
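The line stack is plotted against sample index rather than seconds, which is presumably why the selector is subscribed together with `temporal_data.rate` (samples per second). Converting between a time and a sample index only needs that rate and the time of the first sample; the helper names and the 30 Hz rate below are hypothetical (the traces above are sampled roughly every 0.0333 s):

```python
def time_to_index(t, rate, t0=0.0):
    """Nearest sample index for time t (s), given the sampling rate (Hz) and first-sample time t0."""
    return int(round((t - t0) * rate))

def index_to_time(i, rate, t0=0.0):
    """Time (s) of sample i."""
    return t0 + i / rate

rate = 30.0  # hypothetical sampling rate
print(time_to_index(2.0, rate), index_to_time(60, rate))
```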

#### Let's view everything together

In [128]:
sc = Sidecar()

with sc:
    display(VBox([nap_figure.show(), tstack_fig.show(maintain_aspect=False), synced_time]))

In [129]:
# initialize the conditions

In [130]:
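# start with the first (sorted) component selected
ix = 0

# highlight the selected trace and its spatial contour
tstack_graphic[ix].colors = "w"

contours_graphic[ix].colors = "magenta"

# plot the tuning curve of the selected component
tuning_ix = sorted_ixs[ix]

tuning_curve = tuning_curves.T.iloc[tuning_ix]

tuning_graphic = tstack_fig[1,0].add_line(data=tuning_curve, offset=(0,0,0))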

In [131]:
# add an event handler that lets the up/down arrow keys step through the traces
@tstack_fig.renderer.add_event_handler("key_up")
def update_selected_trace(ev):
    global ix, tuning_graphic

    # ignore every key except the up/down arrows
    if ev.key not in ("ArrowUp", "ArrowDown"):
        return

    n_traces = len(tstack_graphic.graphics)
    if ev.key == "ArrowUp":
        # increment ix, looping back to the first trace after the last one
        ix += 1
        if ix == n_traces:
            ix = 0
    else:
        # decrement ix, looping to the last trace before the first one
        ix -= 1
        if ix < 0:
            ix = n_traces - 1

    # reset all contours to white, then highlight the selected component
    contours_graphic.colors = "w"
    contours_graphic[ix].colors = "magenta"

    # reset the cmap to hsv, then highlight the selected trace
    tstack_graphic.cmap = "hsv"
    tstack_graphic[ix].colors = "w"

    # replace the tuning-curve plot with the selected component's curve
    tuning_ix = sorted_ixs[ix]
    tuning_curve = tuning_curves.T.iloc[tuning_ix]

    tstack_fig[1,0].remove_graphic(tuning_graphic)
    tuning_graphic = tstack_fig[1,0].add_line(data=tuning_curve, offset=(0,0,0))
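The arrow-key handler cycles through the components and wraps around at both ends. That bookkeeping can also be written as a small pure function with the modulo operator, which is easy to test on its own; `step_index` is a hypothetical helper, not part of fastplotlib:

```python
def step_index(ix, key, n):
    """Return the next selected index after an arrow key, wrapping around n items."""
    if key == "ArrowUp":
        return (ix + 1) % n
    if key == "ArrowDown":
        return (ix - 1) % n   # Python's % is never negative for n > 0
    return ix                 # any other key leaves the selection unchanged

n = 50
print(step_index(49, "ArrowUp", n), step_index(0, "ArrowDown", n), step_index(7, "a", n))
```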

In [31]:
# @nap_figure.renderer.add_event_handler("click")
# def click_contour(ev):
#     # reset the contours colors to white
#     contours_graphic.colors = "w"

#     # get the xy position of the click in world space
#     xy = nap_figure[0,0].map_screen_to_world((ev.x, ev.y))[:-1]

#     # calculate the nearest contour
#     nearest = nap_figure[0,0].get_nearest_graphics(tuple(xy), method="center", subset=contours_graphic)

#     # set the nearest contour color to magenta
#     nearest[0].colors = "m"

#     # reset the traces stack colors to "hsv"
#     tstack_graphic.cmap = "hsv"

#     # find index of selected component 
#     ix = np.where(contours_graphic.graphics == nearest[0])[0][0]

#     # set corresponding index in temporal stack graphic to white
#     tstack_graphic.graphics[ix].colors = "w"
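The commented-out click handler selects the contour closest to the click. Whatever fastplotlib helper performs that lookup, the underlying computation can be sketched as: take the centroid (mean point) of each contour and pick the one with the smallest Euclidean distance to the click position. The toy contours and the `nearest_contour` helper below are hypothetical; returning the index directly also avoids searching the collection afterwards.

```python
import numpy as np

def nearest_contour(contours, xy):
    """Index of the contour whose centroid is closest to the point xy."""
    centers = np.array([np.asarray(c, dtype=float).mean(axis=0) for c in contours])
    distances = np.linalg.norm(centers - np.asarray(xy, dtype=float), axis=1)
    return int(np.argmin(distances))

# three toy contours: squares centered at (2, 2), (10, 10) and (20, 5)
square = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]], dtype=float)
toy_contours = [square + [2, 2], square + [10, 10], square + [20, 5]]
print(nearest_contour(toy_contours, (9.0, 11.5)))
```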