<a href="https://colab.research.google.com/github/bantin/PhoRC/blob/master/examples/phorc_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PhoRC Colab Demo
This notebook demonstrates **Pho**tocurrent **R**emoval with **C**onstraints (PhoRC), a software tool for removing direct photocurrent artifacts. It accompanies the preprint:
> _Removing direct photocurrent artifacts in optogenetic connectivity mapping data via constrained matrix factorization._ (2023) B. Antin\*, M. Sadahiro\*, M. A. Triplett, M. Gajowa, H. Adesnik, and L. Paninski

Following photocurrent subtraction, the corrected traces should be ready for connectivity inference using the algorithm of your choice. We recommend the CAVIaR algorithm, described in detail in the preprint:
> _Rapid learning of neural circuitry from holographic ensemble stimulation enabled by model-based compressed sensing_. (2022). M. A. Triplett\*, M. Gajowa\*, B. Antin, M. Sadahiro, H. Adesnik, and L. Paninski.



Below, we install PhoRC and Circuitmap, along with a few small utility libraries for plotting and visualization.
It's safe to ignore the output of this cell unless there's a glaring error.

In [1]:
#@title Install dependencies (double-click to show code)
!pip install "jax[cuda11_cudnn805]"==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!git clone https://github.com/marcustriplett/circuitmap
!pip install ./circuitmap
!git clone https://github.com/bantin/PhoRC
!pip install ./PhoRC
!git clone https://github.com/bantin/matplotlib-ephys
!pip install ./matplotlib-ephys
!pip install matplotlib_scalebar
!pip install gdown

import matplotlib.pyplot as plt
import numpy as np
import circuitmap as cm
import phorc
from circuitmap import NeuralDemixer
from circuitmap.simulation import simulate_continuous_experiment
from circuitmap.viz import plot_checkerboard, plot_spike_inference_comparison
from sklearn.metrics import r2_score
import phorc.utils

import os
import h5py


plt.rcParams.update({'font.size': 7, 'lines.markersize': np.sqrt(5), 'lines.linewidth': 0.5, 'lines.markeredgewidth': 0.25})

import matplotlib as mpl
mpl.rcParams['axes.spines.left'] = True
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.bottom'] = True

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_cudnn805]==0.3.15
  Downloading jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
Collecting jaxlib==0.3.15+cuda11.cudnn805 (from jax[cuda11_cudnn805]==0.3.15)
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15%2Bcuda11.cudnn805-cp310-none-manylinux2014_x86_64.whl (250.6 MB)
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.3.15-py3-none-any.whl size=1201902 sha256=41d6aa89980d8a705ca3480d108029b22b635c9559f14764c0cf48b9a7e7785c
  Stored in directory: /root/.cache/pip/wheels/42/65/08/12200413b

# Download example dataset

We've placed an example grid dataset in Google Drive. Below, we use the `gdown` utility to download it for further processing.

In [2]:
import gdown
url = "https://drive.google.com/uc?id=1v-cNKRMdCcVX4NFLiQt1v2Bb4eTlA7LS"
filename = "220308_chrome2f_cell1.mat"
gdown.download(url, filename, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1v-cNKRMdCcVX4NFLiQt1v2Bb4eTlA7LS
To: /content/220308_chrome2f_cell1.mat
100%|██████████| 231M/231M [00:04<00:00, 55.0MB/s]


'220308_chrome2f_cell1.mat'

# Load example grid dataset
Below, we'll load the dataset using h5py. Assume that the dataset contains K trials, each of which lasts for T timesteps. Let N be the total number of presynaptic targets. In our case, the data were recorded by stimulating a 26 x 26 grid across five planes, for a total of N = 26 x 26 x 5 = 3380 targets. Additionally, this dataset was recorded using three laser powers.

**Data format**:   
The dataset contains the following arrays:

- `pscs`: a K x T matrix in which each row is a recorded trace following optical stimulation. For this dataset, stimulation lasts 5 milliseconds, beginning at frame 100 and ending at frame 200.
- `stim_matrix`: an N x K matrix where `stim_matrix[i,j]` is the power used to stimulate target `i` on trial `j`.
- `targets`: an N x 3 array of target locations in space, ordered to match the rows of `stim_matrix`. Together, `stim_matrix` and `targets` tell us which point in 3D space was stimulated on each trial, and the corresponding row of the `pscs` matrix gives the postsynaptic cell's response on that trial.
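To make the layout concrete, here is a minimal sketch that builds toy arrays with the same shapes and looks up what happened on one trial. The sizes, the random values, and the variable names `stim_target`, `trial`, `target_idx`, `power`, `location`, and `trace` are illustrative assumptions, not part of the dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumed for illustration); the real dataset has N = 3380 targets.
N, K, T = 4, 6, 900

targets = rng.uniform(0, 100, size=(N, 3))   # N x 3 target locations
stim_target = rng.integers(0, N, size=K)     # which target each trial hits
stim_matrix = np.zeros((N, K))               # N x K matrix of laser powers
stim_matrix[stim_target, np.arange(K)] = 30.0
pscs = rng.normal(size=(K, T))               # K x T recorded traces

# Look up everything we know about a single trial.
trial = 2
target_idx = np.argmax(stim_matrix[:, trial])  # row with nonzero power
power = stim_matrix[target_idx, trial]         # laser power on this trial
location = targets[target_idx]                 # (x, y, z) of the stimulated spot
trace = pscs[trial]                            # response on this trial
```

The key point is that columns of `stim_matrix` and rows of `pscs` share the same trial index, while rows of `stim_matrix` and rows of `targets` share the same target index.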

In [3]:
# Load experiment
f = h5py.File(filename)

In [4]:
# unpack dataset (transpose so that pscs is K x T and stim_matrix is N x K)
pscs = np.array(f['pscs']).T
stim_matrix = np.array(f['stimulus_matrix']).T
targets = np.array(f['targets']).T
powers = np.max(stim_matrix, axis=0)

# Remove trials with no stim
stimulated = powers > 0
stim_matrix = stim_matrix[:, stimulated]
pscs = pscs[stimulated]
powers = powers[stimulated]

# Record dimensions after filtering so that K counts only stimulated trials
N, K = stim_matrix.shape
_, T = pscs.shape

## Visualizing the raw data
Running the two cells below will show visualizations of the raw data.

**Grid Maps**: The left figure shows grid maps, a visual representation of the total synaptic response across different planes. The region of interest (ROI) is highlighted within these maps.
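As a rough illustration of what a grid map summarizes, the sketch below sums each trace over time and averages those sums per target, leaving NaN where a target was never stimulated. It is a simplified stand-in, not the implementation of `phorc.utils.traces_tensor_to_map`; the function name `mean_response_map` and the flat `target_idx` encoding are assumptions made for this example.

```python
import numpy as np

def mean_response_map(traces, target_idx, grid_shape):
    """Average the time-summed response of each trace into its target's grid cell.

    traces:     (n_trials, T) array of recorded traces
    target_idx: (n_trials,) flat index of the stimulated target on each trial
    grid_shape: shape of the target grid, e.g. (nx, ny, nz)
    """
    n_targets = int(np.prod(grid_shape))
    responses = traces.sum(axis=1)                      # total response per trial
    totals = np.bincount(target_idx, weights=responses, minlength=n_targets)
    counts = np.bincount(target_idx, minlength=n_targets)
    mean_map = np.full(n_targets, np.nan)               # NaN marks unstimulated targets
    hit = counts > 0
    mean_map[hit] = totals[hit] / counts[hit]
    return mean_map.reshape(grid_shape)

# Three toy trials on a 2 x 2 x 1 grid: two hit target 0, one hits target 3.
toy_traces = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
toy_map = mean_response_map(toy_traces, np.array([0, 0, 3]), (2, 2, 1))
```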

**Photocurrent Traces**: On the right, you'll find photocurrent traces that are associated with the selected ROI. For each unique power level, a subplot is displayed showing the top 10 traces with the largest response sum. Vertical bars denote onset and offset of the laser.

**Adjust ROI with sliders**: Below the figures are three sliders labeled 'X', 'Y', and 'Z'. These are used to adjust the dimensions of the ROI in the grid maps, which in turn changes the photocurrent traces displayed. The figures automatically update to reflect the selected ROI when the sliders are adjusted.
Note that the sliders will not update until the mouse is released -- it may take a few seconds for the plot to render after changing the sliders.


**Explore the data by adjusting the sliders to select different regions of interest.**
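The trace selection described above can be sketched as a standalone function. It follows the same steps as the plotting cell (slice the ROI, flatten to traces, drop NaN-padded entries, sort by summed response), but the function name `top_traces_in_roi` and the toy tensor are assumptions for illustration.

```python
import numpy as np

def top_traces_in_roi(psc_tensor, power_idx, roi_x, roi_y, roi_z, n_top=10):
    """Return the n_top traces with the largest summed response inside an ROI.

    psc_tensor is indexed as (power, x, y, z, trial, time); each roi_* is an
    inclusive (min, max) pair, matching the values an IntRangeSlider returns.
    """
    T = psc_tensor.shape[-1]
    block = psc_tensor[power_idx,
                       roi_x[0]:roi_x[1] + 1,
                       roi_y[0]:roi_y[1] + 1,
                       roi_z[0]:roi_z[1] + 1]
    traces = block.reshape(-1, T)
    traces = traces[~np.isnan(traces[:, 0])]          # drop empty (NaN-padded) slots
    order = np.argsort(traces.sum(axis=1))[::-1]      # largest response first
    return traces[order[:n_top]]

# Toy tensor: 1 power, 3 x 3 x 2 grid, 2 trial slots, 5 timesteps, mostly empty.
toy = np.full((1, 3, 3, 2, 2, 5), np.nan)
toy[0, 0, 0, 0, 0] = 1.0     # inside the ROI, summed response 5
toy[0, 1, 1, 0, 0] = 3.0     # inside the ROI, summed response 15
toy[0, 2, 2, 1, 0] = 100.0   # outside the ROI
selected = top_traces_in_roi(toy, 0, roi_x=(0, 1), roi_y=(0, 1), roi_z=(0, 0))
```

Reading the trace length from the tensor, rather than hard-coding it, keeps the sketch usable for recordings of other lengths.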



In [5]:
#@title Define plotting function for gridmaps
from matplotlib import gridspec

def plot_gridmaps(fig, mean_maps, depth_idxs,
                  cmaps='viridis', vmin=None, vmax=None, zs=None, zlabels=None,
                  powers=None, roi_bounds=None, map_names=None):

    # allow option to pass separate cmaps for each grid plot
    if not isinstance(cmaps, list):
        cmaps = len(mean_maps) * [cmaps]

    # Create an outer grid
    outer_grid = gridspec.GridSpec(
        1, len(mean_maps) + 1, width_ratios=[1]*len(mean_maps) + [0.05])

    # Calculate global min_val and max_val across all mean_maps if vmin and vmax are not provided
    if vmin is None:
        min_val = np.nanmin([np.nanmin(mean_map) for mean_map in mean_maps])
    else:
        min_val = vmin
    if vmax is None:
        max_val = np.nanmax([np.nanmax(mean_map) for mean_map in mean_maps])
    else:
        max_val = vmax

    for mean_idx, mean_map, cmap in zip(range(len(mean_maps)), mean_maps, cmaps):

        num_powers, _, _, num_planes = mean_map.shape
        num_planes_to_plot = len(depth_idxs)
        assert num_planes_to_plot <= num_planes

        # use subgrid for each ImageGrid
        subgrid = gridspec.GridSpecFromSubplotSpec(
            num_planes_to_plot, num_powers, subplot_spec=outer_grid[mean_idx], wspace=0.05, hspace=0.05)

        # Set title for the map if provided
        if map_names:
            fig.suptitle(map_names[mean_idx])


        for j in range(num_planes_to_plot * num_powers):
            ax = plt.Subplot(fig, subgrid[j])
            fig.add_subplot(ax)

            ax.set_xticks([])
            ax.set_yticks([])
            plt.axis('off')

            row = j // num_powers
            col = j % num_powers

            im = ax.imshow(mean_map[col, :, :, depth_idxs[row]],
                           origin='lower', vmin=min_val, vmax=max_val, cmap=cmap)

            # optionally add labels
            if (zs is not None) and col == 0 and mean_idx == 0:
                ax.set_ylabel('%d ' % zs[depth_idxs[row]] + r'$\mu m $')
            elif (zlabels is not None) and col == 0 and mean_idx == 0:
                ax.set_ylabel(zlabels[row])

            # optionally add power label as white text on top of the image
            if powers is not None and row == 0:
                ax.annotate('%d mW' % powers[col], xy=(1.0, 1.3), xycoords='axes fraction',
                            horizontalalignment='right', verticalalignment='top')

            # Draw ROI if bounds are provided and the current z index is within the ROI bounds
            if roi_bounds is not None:
                xmin, xmax = roi_bounds[0]
                ymin, ymax = roi_bounds[1]
                zmin, zmax = roi_bounds[2]

                if zmin <= depth_idxs[row] <= zmax:
                    roi = plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')
                    ax.add_patch(roi)

        if mean_idx == len(mean_maps) - 1:
            colorbar_grid = gridspec.GridSpecFromSubplotSpec(
                num_planes // 2, 1, subplot_spec=outer_grid[-1],)
            cbar = plt.colorbar(im, cax=plt.subplot(colorbar_grid[0]))


# The following line reshapes the data to create a 6-dimensional array (power, x, y, z, trials, time)
# containing all recorded current traces. We then average over the last two dimensions to create a visual map.
psc_tensor = phorc.utils.make_psc_tensor_multispot(
            pscs,
            powers,
            targets,
            stim_matrix
        )
raw_map = phorc.utils.traces_tensor_to_map(psc_tensor)

In [6]:
#@title Make interactive plot of gridmaps and traces
import ipywidgets as widgets
from IPython.display import display

# Assuming raw_map and psc_tensor are defined
num_planes = raw_map.shape[-1]
n_powers = len(np.unique(powers))

# Define the sliders
x_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[1], description='X:', continuous_update=False)
y_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[2], description='Y:', continuous_update=False)
z_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[3], description='Z:', continuous_update=False)

fig = plt.figure(figsize=(3,2.8), dpi=300)
def interactive_plot(roi_x, roi_y, roi_z):
    global fig
    plt.clf()
    fig = plt.figure(figsize=(3,2.8), dpi=300)
    subfigs = fig.subfigures(1, 2, width_ratios=[1,0.6], wspace=0.07)

    plot_gridmaps(subfigs[0],
        [raw_map],
        np.arange(num_planes),
        cmaps='magma',
        roi_bounds=(roi_x, roi_y, roi_z),
        powers=np.unique(powers),
        zs = np.unique(targets[:,-1]),
        map_names=['raw']
    )

    total_subplots = n_powers
    axes_right = subfigs[1].subplots(total_subplots, 1, sharex=True, sharey=True)

    for i in range(total_subplots):
        these_trials = psc_tensor[i,
                                  roi_x[0]:roi_x[1]+1,
                                  roi_y[0]:roi_y[1]+1,
                                  roi_z[0]:roi_z[1]+1,
                                  :,:]
        these_trials = these_trials.reshape(-1,900)
        these_trials = these_trials[~np.isnan(these_trials[:,0])]

        idxs = np.argsort(these_trials.sum(1))[::-1]
        these_trials = these_trials[idxs[0:10]]
        phorc.utils.plot_current_traces(these_trials.reshape(-1,900), ax=axes_right[i], scalebar_loc='bottom',
                                          linewidth=0.5, alpha=0.8, IV_bar_length=0.01, box_aspect=0.5, scalebar=False)
        axes_right[i].set_title('%d mW' % np.unique(powers)[i])
        axes_right[i].set_ylabel('Current (nA)')

    plt.show()

out = widgets.interactive_output(interactive_plot, {'roi_x': x_slider, 'roi_y': y_slider, 'roi_z': z_slider})

# Display everything
display(x_slider, y_slider, z_slider, out)


IntRangeSlider(value=(6, 19), continuous_update=False, description='X:', max=26)

IntRangeSlider(value=(6, 19), continuous_update=False, description='Y:', max=26)

IntRangeSlider(value=(1, 3), continuous_update=False, description='Z:', max=5)

Output()

## Estimate photocurrents with PhoRC

The idea of PhoRC is to approximate the data as a sum of scaled copies of the photocurrent.
Let $Y$ be a K x T matrix containing the PSCs along the rows (`pscs` in the code).
We'll approximate the data as a low-rank product $UV$, where $U$ is K x R, $V$ is R x T, and R is the rank of the factorization.
The algorithm proceeds in two steps. First, we estimate the photocurrent weights $U$ using only the samples inside the _photocurrent integration window_, which is set by the parameters `window_start` and `window_end` (referred to as `t1` and `t2` below for convenience). Second, we hold those weights fixed and estimate the photocurrent waveform $V$ over the full trace.


**Step 1: Estimate photocurrent weights**
$$
\min_{U_{\text{stim}}, V_{\text{stim}}} \| Y_{:, t_1:t_2} - U_{\text{stim}} V_{\text{stim}}\|_F^2
$$
Subject to:
$$
U_{\text{stim}}, V_{\text{stim}} \geq 0, \quad
U_{\text{stim}} V_{\text{stim}} \leq Y_{:, t_1:t_2}
$$

**Step 2: Estimate photocurrent waveform**
$$
\min_{V} \| Y - U_{\text{stim}} V\|_F^2
$$
Subject to:
$$
V \geq 0, \quad
U_{\text{stim}} V \leq Y
$$

The photocurrent is then estimated as $U_{\text{stim}}V$, which is subtracted from the data to reveal the underlying EPSCs.
For more details on the PhoRC algorithm, see the preprint.
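The two-step structure can be illustrated with a heavily simplified sketch. It uses standard multiplicative updates for nonnegative matrix factorization and omits the upper-bound constraints ($UV \leq Y$) entirely, so it is not PhoRC's solver and should not be used in place of `phorc.estimate`. It also assumes nonnegative data. The function names `fit_weights` and `fit_waveform`, the toy data, and all parameter values are assumptions for illustration.

```python
import numpy as np

def fit_weights(Y_win, rank, n_iter=200, eps=1e-12, seed=0):
    """Step 1 (simplified): nonnegative factorization of the windowed data.

    Returns only the per-trial weights U; the windowed waveform is discarded.
    """
    rng = np.random.default_rng(seed)
    n_trials, n_samples = Y_win.shape
    U = rng.random((n_trials, rank)) + 0.1
    V = rng.random((rank, n_samples)) + 0.1
    for _ in range(n_iter):
        V *= (U.T @ Y_win) / (U.T @ U @ V + eps)      # update waveform
        U *= (Y_win @ V.T) / (U @ (V @ V.T) + eps)    # update weights
    return U

def fit_waveform(Y, U, n_iter=200, eps=1e-12, seed=0):
    """Step 2 (simplified): with U fixed, fit a nonnegative waveform on all samples."""
    rng = np.random.default_rng(seed)
    V = rng.random((U.shape[1], Y.shape[1])) + 0.1
    for _ in range(n_iter):
        V *= (U.T @ Y) / (U.T @ U @ V + eps)
    return V

# Toy data: every trial is a scaled copy of one nonnegative waveform.
t = np.arange(60)
waveform = np.exp(-0.5 * ((t - 25) / 6.0) ** 2)
weights = np.linspace(0.2, 1.0, 8)
Y = np.outer(weights, waveform)

t1, t2 = 15, 35                                   # toy integration window
U_stim = fit_weights(Y[:, t1:t2], rank=1)
V = fit_waveform(Y, U_stim)
photocurrent_est = U_stim @ V
corrected = Y - photocurrent_est                  # near zero: the toy data is pure photocurrent
```

On real recordings the traces also contain synaptic currents and noise, which is where the omitted constraints matter: they keep the estimate from absorbing signal that is not photocurrent.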

### Setting hyperparameters
We encourage you to adjust parameters and see what happens! Below are the main hyperparameters that control PhoRC's behavior:
- `rank`: the rank of the factorization. We found that R=2 worked well in most cases, and that R=1 works well for datasets with smaller photocurrents (<0.05 nA).
- `window_start`: this should match the onset of laser stimulation. In our case, that's at sample 100.
- `window_end`: a good starting point is to match this to the end of stimulation, which in our case is sample 200. In some cases, it can be helpful to reduce this (to, say, 3 ms after laser onset) to ensure that the lowest-latency PSCs are preserved.
- `batch_size`: PhoRC splits the data into small batches to account for variation in waveform shape. We found that 100-200 worked well. A batch size of -1 will process the entire dataset in one batch.
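Because the window parameters are given in samples, it can help to convert from milliseconds explicitly. The sketch below infers the sampling density from the figures stated for this dataset (5 ms of stimulation spanning samples 100 to 200); the helper name `ms_after_onset_to_sample` is hypothetical, and other recordings may use a different sampling rate.

```python
# Values stated for this dataset: stimulation spans samples 100-200 and lasts 5 ms.
stim_onset_sample = 100
stim_offset_sample = 200
stim_duration_ms = 5.0

# Inferred sampling density: 100 samples / 5 ms = 20 samples per millisecond.
samples_per_ms = (stim_offset_sample - stim_onset_sample) / stim_duration_ms

def ms_after_onset_to_sample(ms):
    """Convert a time in ms after laser onset to a sample index."""
    return int(round(stim_onset_sample + ms * samples_per_ms))

window_start = ms_after_onset_to_sample(0.0)       # laser onset
window_end_full = ms_after_onset_to_sample(5.0)    # end of stimulation
window_end_short = ms_after_onset_to_sample(3.0)   # shortened window, 3 ms after onset
```

With these numbers, a window ending 3 ms after laser onset corresponds to `window_end=160`.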

**NB**: This cell takes roughly 3-5 minutes to run.


In [7]:
photocurrent_est = phorc.estimate(pscs, rank=2, batch_size=200, window_start=100, window_end=200, rho=1)

Running photocurrent estimation with 169 batches...


100%|██████████| 169/169 [02:59<00:00,  1.06s/it]


In [8]:
pscs_corrected = pscs - photocurrent_est

## Visualize the Raw vs. Subtracted Maps





**Grid Maps:** The left figure displays grid maps of raw data and subtracted maps side by side. The region of interest (ROI) is highlighted in each map.

**Photocurrent Traces:** The right panel displays three sets of photocurrent traces for each unique power level. Each row represents a unique power level, with raw traces, photocurrent estimates, and subtracted traces displayed from left to right, respectively. The top 10 traces with the largest response sum are selected from the chosen ROI and displayed for each category.

**Adjust the ROI via Sliders:** As above, the ROI is selected interactively with the 'X', 'Y', and 'Z' sliders placed below the figures. Adjusting these sliders updates the figures above, changing both the highlighted ROI in the grid maps and the displayed photocurrent traces.


In [9]:
#@title Make interactive plot showing raw vs. subtracted maps and traces
import ipywidgets as widgets
from IPython.display import display

psc_tensor_corrected = phorc.utils.make_psc_tensor_multispot(
            pscs_corrected,
            powers,
            targets,
            stim_matrix
        )
subtracted_map = phorc.utils.traces_tensor_to_map(psc_tensor_corrected)

est_tensor = phorc.utils.make_psc_tensor_multispot(
    photocurrent_est,
    powers,
    targets,
    stim_matrix
)

# Assuming raw_map and psc_tensor are defined
num_planes = raw_map.shape[-1]
n_powers = len(np.unique(powers))

# Define the sliders
x_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[1], description='X:', continuous_update=False)
y_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[2], description='Y:', continuous_update=False)
z_slider = widgets.IntRangeSlider(min=0, max=raw_map.shape[3], description='Z:', continuous_update=False)

fig_height=3
fig_width=8
fig = plt.figure(figsize=(fig_width, fig_height), dpi=200)
def interactive_plot(roi_x, roi_y, roi_z):
    global fig
    plt.clf()
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=300)

    # left panel: raw and subtracted maps. Right panel: raw, estimated, and subtracted traces
    subfigs = fig.subfigures(1, 2, wspace=0.07)

    # plot raw and subtracted maps next to each other
    plot_gridmaps(subfigs[0],
        [raw_map, subtracted_map],
        np.arange(num_planes),
        cmaps='magma',
        roi_bounds=(roi_x, roi_y, roi_z),
        powers=np.unique(powers),
        map_names = ['raw', 'subtracted'],
    )

    total_subplots = n_powers
    axes = subfigs[1].subplots(total_subplots, 3, sharex=True, sharey=True)

    axes_left = axes[:,0]
    axes_middle = axes[:,1]
    axes_right = axes[:,2]
    for i in range(total_subplots):

        # on left, plot raw traces
        these_trials = psc_tensor[i,
                                  roi_x[0]:roi_x[1]+1,
                                  roi_y[0]:roi_y[1]+1,
                                  roi_z[0]:roi_z[1]+1,
                                  :,:]
        these_trials = these_trials.reshape(-1,900)
        these_trials = these_trials[~np.isnan(these_trials[:,0])]

        idxs = np.argsort(these_trials.sum(1))[::-1]
        these_trials = these_trials[idxs[0:10]]
        phorc.utils.plot_current_traces(these_trials.reshape(-1,900), ax=axes_left[i], scalebar_loc='bottom',
                                          linewidth=0.5, alpha=0.8, IV_bar_length=0.01, box_aspect=0.5, scalebar=False)
        axes_left[i].set_title('%d mW' % np.unique(powers)[i])
        axes_left[i].set_ylabel('Current (nA)')


        # in the middle, plot estimates
        these_trials = est_tensor[i,
                                  roi_x[0]:roi_x[1]+1,
                                  roi_y[0]:roi_y[1]+1,
                                  roi_z[0]:roi_z[1]+1,
                                  :,:]
        these_trials = these_trials.reshape(-1,900)
        these_trials = these_trials[~np.isnan(these_trials[:,0])]

        these_trials = these_trials[idxs[0:10]]
        phorc.utils.plot_current_traces(these_trials.reshape(-1,900), ax=axes_middle[i], scalebar_loc='bottom',
                                          linewidth=0.5, alpha=0.8, IV_bar_length=0.01, box_aspect=0.5, scalebar=False)

        # on the right, plot subtracted traces
        these_trials = psc_tensor_corrected[i,
                                  roi_x[0]:roi_x[1]+1,
                                  roi_y[0]:roi_y[1]+1,
                                  roi_z[0]:roi_z[1]+1,
                                  :,:]
        these_trials = these_trials.reshape(-1,900)
        these_trials = these_trials[~np.isnan(these_trials[:,0])]

        these_trials = these_trials[idxs[0:10]]
        phorc.utils.plot_current_traces(these_trials.reshape(-1,900), ax=axes_right[i], scalebar_loc='bottom',
                                          linewidth=0.5, alpha=0.8, IV_bar_length=0.01, box_aspect=0.5, scalebar=False)


    plt.show()

out = widgets.interactive_output(interactive_plot, {'roi_x': x_slider, 'roi_y': y_slider, 'roi_z': z_slider})

# Display everything
display(x_slider, y_slider, z_slider, out)


IntRangeSlider(value=(6, 19), continuous_update=False, description='X:', max=26)

IntRangeSlider(value=(6, 19), continuous_update=False, description='Y:', max=26)

IntRangeSlider(value=(1, 3), continuous_update=False, description='Z:', max=5)

Output()