# Refining CaImAn components
This notebook serves to refine spatiotemporal components extracted via CNMF from CaImAn (https://github.com/flatironinstitute/CaImAn; Giovannucci et al., 2019, eLife), using interactive quality threshold setting followed by manual curation.

Author: Oliver Barnstedt

In [None]:
from ipywidgets import interact, widgets
import numpy as np
import glob
import os
import sys
import pandas as pd
from bokeh.io import push_notebook, show, output_notebook, output_file
from bokeh.plotting import figure, ColumnDataSource
from bokeh.models import HoverTool, CustomJS, Button, LinearAxis, Range1d, Span, PointDrawTool, Slider, Legend
from bokeh.models.glyphs import Circle, Ray
from bokeh.layouts import row, column
from bokeh.events import ButtonClick
output_notebook()
import caiman as cm
from caiman.source_extraction.cnmf.cnmf import load_CNMF
from caiman.utils import visualization as cmviz
from scipy.stats import zscore
import scipy
from IPython.core.display import display, HTML, Markdown, clear_output
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Let the notebook container use the full browser width so the side-by-side
# Bokeh layouts further down have room.
full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(full_width_css))

# Multiplier applied to the Bokeh figure sizes and line widths below;
# adapt to optimally fit screen.
scaling = 0.6

## Preparing and inspecting data

In [None]:
# enter path to CNM file
cnm_dir = os.path.join(os.path.expanduser('~'), 'data/1P')
cnm_files = glob.glob(os.path.join(cnm_dir, '*cnm.hdf5'))

# Fail with an actionable message instead of an opaque IndexError on cnm_files[0].
if not cnm_files:
    raise FileNotFoundError("No '*cnm.hdf5' file found in {}".format(cnm_dir))
print('CNM files found: {}'.format(cnm_files))

# Only the first match is loaded; the save cells below also write next to cnm_files[0].
cnm = load_CNMF(cnm_files[0])

After loading the CNM file generated from CaImAn, we first polish the data using several CaImAn functions and prepare the workspace for visualisation:
* Detrending fluorescence signals, removing any baseline shifts and obtaining deltaF/F
* Marking heavily overlapping neurons and removing them as potential duplicates
* Evaluating components. For 1P, this includes spatial correlations and signal-to-noise ratio (SNR).
* Generating correlation images for visualisation
* Obtaining spatial coordinates and preparing evaluation values of cells for later use with Bokeh interactive visualisation

In [None]:
# PREP DATA

# detrend data if not done yet (or if F_dff no longer has one row per component in C)
if cnm.estimates.F_dff is None or len(cnm.estimates.F_dff)!=len(cnm.estimates.C):
    cnm.estimates.detrend_df_f()

# remove neurons that heavily overlap and might be duplicates
cnm.estimates.threshold_spatial_components()
try:
    cnm.estimates.remove_duplicates(select_comp=True)
except Exception as err:
    # Best effort: duplicate removal is optional, so keep going, but report the
    # failure instead of hiding it (and let KeyboardInterrupt propagate).
    print("Duplicate removal skipped: {!r}".format(err))

# Load motion corrected video
if cnm.mmap_file is None or not os.path.exists(cnm.mmap_file):
    # Raise a regular exception rather than calling sys.exit() inside a notebook.
    raise FileNotFoundError("Memory mapped file not found. Has file been moved? ({})".format(cnm.mmap_file))
print("Loading 'C' memmap file, calculating correlation image...")
Yr, dims, T = cm.load_memmap(cnm.mmap_file)  # loading 'C' memmap file
images = np.reshape(Yr.T, [T] + list(dims), order='F')  # load frames in python format (T x X x Y)

# evaluate components if not done yet (CNN classifier switched off)
cnm.params.set('quality', {'use_cnn': False})
if cnm.estimates.r_values is None:
    print("Components not yet evaluated. Running component evaluation now...")
    cnm.estimates.evaluate_components(images, cnm.params)

# generate correlation images
Cn = cm.local_correlations(images.transpose(1, 2, 0))
Cn[np.isnan(Cn)] = 0
# images[::1] uses every frame; change swap dim if output looks weird, it is a problem with tiffile
cn_filter, pnr = cm.summary_images.correlation_pnr(images[::1], gSig=cnm.params.init['gSig'][0], swap_dim=False)

# generate Bokeh Dataframes for interactive visualisation
# thresholdsource holds the two threshold lines drawn in the SNR-vs-r scatter plot;
# each column has two entries (the two end points of a line).
thresholdsource = ColumnDataSource({
    'span': [-1, 1],
    'span_log': [0.01, 100],
    'min_snr': [cnm.params.quality['min_SNR']]*2,
    'min_r': [cnm.params.quality['rval_thr']]*2,
})
coords = cmviz.get_contours(cnm.estimates.A, cnm.dims)


In [None]:
# generate some overview plots
cnm.estimates.plot_contours_nb()

fig = plt.figure(figsize=(10,5))
grid = plt.GridSpec(2, 2, wspace=0.4, hspace=0.3)

# left column, spanning both rows: peak-to-noise ratio image
ax_pnr = fig.add_subplot(grid[:2,0])
ax_pnr.imshow(pnr)
ax_pnr.set_title('PNR image')

# top right: histogram of the finite-or-NaN SNR values (infinite entries are dropped)
ax_snr = fig.add_subplot(grid[0,1])
snr_not_inf = cnm.estimates.SNR_comp[~np.isinf(cnm.estimates.SNR_comp)]
p = ax_snr.hist(snr_not_inf, bins=24)
ax_snr.set_title('SNR values')

# bottom right: histogram of the spatial correlation (r) values
ax_r = fig.add_subplot(grid[1,1])
q = ax_r.hist(cnm.estimates.r_values, bins=24)
ax_r.set_title('r values')

## Interactive evaluation thresholding

In [None]:
# This big chunk prepares the Bokeh interactive component plotting and evaluation tools

# Forget the Bokeh source / component table left over from a previous run of this cell.
for stale_name in ('s1', 'comps'):
    globals().pop(stale_name, None)

# if running this part again, restore previously discarded components
previously_discarded = cnm.estimates.idx_components_bad
if previously_discarded is not None and len(previously_discarded) > 0:
    cnm.estimates.restore_discarded_components()
    
def normalise_evals(val, by=1, type=float):
    """Min-max rescale ``val`` to the range ``[0, by]`` and cast the result.

    NaNs are replaced by 0 (and infinities by large finite numbers) via
    ``np.nan_to_num`` before scaling.

    Parameters
    ----------
    val : array-like
        Values to rescale (list, ndarray or pandas Series).
    by : number, optional
        Upper end of the output range; the minimum of ``val`` maps to 0 and
        the maximum to ``by``.
    type : dtype-like, optional
        dtype the result is cast to (e.g. ``float`` or ``'int'``).

    Returns
    -------
    numpy.ndarray
        Rescaled values with dtype ``type``. Empty input gives an empty array;
        constant input gives all zeros instead of dividing by zero.
    """
    val = np.nan_to_num(np.asarray(val))
    if val.size == 0:
        return val.astype(type)
    val_bottomed = val - val.min()
    span = val_bottomed.max()
    if span == 0:
        # All values identical: 0/0 would yield NaN, which casts to garbage ints.
        return np.zeros(val.shape).astype(type)
    norm = val_bottomed / span * by
    return norm.astype(type)

# Loading components' temporal traces
def _trace_table(component_traces, prefix):
    """Return (column labels, DataFrame) for a components-by-frames trace array.

    The DataFrame has one row per frame and one column per component, with
    columns named ``prefix + <component number>``.
    """
    labels = [prefix + str(num) for num in range(len(component_traces))]
    frame_index = np.array(range(len(component_traces.T)))
    return labels, pd.DataFrame(component_traces.T, index=frame_index, columns=labels)


Fcols, F = _trace_table(cnm.estimates.F_dff, 'F_')
F['Frame'] = F.index  # frame numbers, used as the x values of every plotted trace
Ccols, C = _trace_table(cnm.estimates.C, 'C_')
Scols, S = _trace_table(cnm.estimates.S, 'S_')

# Creating a 'comps' ColumnDataSource for Bokeh
# One row per component: quality metrics, display colour and contour coordinates.
comps = pd.DataFrame()
snr_raw = np.asarray(cnm.estimates.SNR_comp, dtype=float)
comps['snr'] = np.where(np.isinf(snr_raw), 10, snr_raw)  # infinite SNR values are replaced by 10
comps['r'] = np.array(cnm.estimates.r_values)
if cnm.params.init['method_init']=='corr_pnr':
    # Two metrics only: SNR drives the red channel, r the blue channel, green stays 00.
    rgbArray = np.array([normalise_evals(np.sqrt(comps['snr']), 255, 'int'), normalise_evals(comps['r'], 255, 'int')]).T.tolist()
    comps['hex'] = ['#%02x00%02x' % tuple(channels) for channels in rgbArray]
else:
    # Three metrics: SNR -> red, r -> green, CNN prediction -> blue.
    comps['cnn'] = np.array(cnm.estimates.cnn_preds)
    rgbArray = np.array([normalise_evals(np.cbrt(comps['snr']), 255, 'int'), normalise_evals(comps['r'], 255, 'int'), normalise_evals(comps['cnn'], 255, 'int')]).T.tolist()
    comps['hex'] = ['#%02x%02x%02x' % tuple(channels) for channels in rgbArray]
comps['good'] = 1
comps['alpha'] = 0.5
# Contours have different numbers of vertices. Keep them as a list of arrays:
# wrapping the ragged list in np.array() raises ValueError on NumPy >= 1.24
# (and would give a 2-D array, not a column, if all lengths happened to match).
comps['xcoords'] = [contour['coordinates'][:,0] for contour in coords]
comps['ycoords'] = [contour['coordinates'][:,1] for contour in coords]
s1 = ColumnDataSource(comps)

# Hover fields shared by the contour map (p1) and the SNR-vs-r scatter plot (p3).
TOOLTIPS = [("index", "$index"), ("SNR", "@snr"), ("R", "@r"), ("good", "@good")]

# Contour map of all components; the y axis is inverted (y_range runs from dims[0] down to 0).
n_good_initial = sum(s1.data['good'])
n_total_initial = len(s1.data['good'])
p1 = figure(
    title="Components: {} out of {} selected to meet criteria".format(n_good_initial, n_total_initial),
    plot_height=int(scaling*640),
    plot_width=int(scaling*640),
    x_range=(0,cnm.dims[1]),
    y_range=(cnm.dims[0], 0),
    tools='hover, reset, tap, pan, box_zoom, wheel_zoom',
    active_scroll="wheel_zoom",
    tooltips=TOOLTIPS,
)

def update_cnmfe(min_snr=cnm.params.quality['min_SNR'], min_r=cnm.params.quality['rval_thr']):
    """Slider callback: re-flag components against the SNR and r thresholds.

    A component is flagged good (1) when its SNR exceeds ``min_snr`` and its
    r value exceeds ``min_r``, otherwise 0. The threshold lines, the title of
    the contour plot and the displayed Bokeh document are refreshed.
    """
    passes = (comps['snr'] > min_snr) & (comps['r'] > min_r)
    s1.data['good'] = passes.astype(int)
    for column_name, threshold in (('min_snr', min_snr), ('min_r', min_r)):
        thresholdsource.data[column_name] = [threshold]*2
    flags = s1.data['good']
    p1.title.text = "Components: {} out of {} selected".format(sum(flags), len(flags))
    push_notebook()

# Background of the contour map: cn_filter flipped along axis 0, drawn at
# half opacity in grey, anchored at x=0 / y=image height and sized to the image.
i = p1.image(image=[np.flip(cn_filter, axis=0)], x=[0], y=[cn_filter.shape[0]],
           dw=[cn_filter.shape[1]], dh=[cn_filter.shape[0]], palette='Greys256', alpha=0.5)#'Viridis256')
# One patch per row of s1. The fill colour comes from the 'hex' column and the
# fill alpha from the 'good' column, so components flagged 0 are drawn unfilled.
# A selected component is drawn fully opaque.
r = p1.patches('xcoords', 
               'ycoords', 
               line_color='black', 
               selection_line_color='black',
               nonselection_line_color='black',
               line_alpha='alpha',
               selection_line_alpha=1,
               nonselection_line_alpha=.5,
               fill_color='hex',
               selection_fill_color='hex',
               fill_alpha='good', 
               selection_fill_alpha=1,
               source=s1)
# NOTE(review): this assigns the whole 'good' column (one value per component)
# as the dash pattern of the non-selected outlines -- confirm this is intended.
r.nonselection_glyph.line_dash = r.data_source.data['good']

# s3/s5/s7 hold the full trace tables (F, C, S); s2/s4/s6 are the initially
# empty x/y sources that the selection callback fills with one component's traces.
s2 = ColumnDataSource(data=dict(x=[], y=[]))
s3 = ColumnDataSource(data=F)
s4 = ColumnDataSource(data=dict(x=[], y=[]))
s5 = ColumnDataSource(data=C)
s6 = ColumnDataSource(data=dict(x=[], y=[]))
s7 = ColumnDataSource(data=S)

TOOLTIPS_TRACES = [
        ("Y", "@y"),
        ("X", "@x"),
    ]

# Initial axis limits. F also contains the 'Frame' index column, so the dF/F peak
# must be taken over the trace columns only; otherwise the y-range becomes
# +/- (number of frames - 1) instead of the trace amplitude.
dff_peak = F[Fcols].max().max()
deconv_floor = C.min().min()
deconv_peak = C.max().max()

# Trace plot: dF/F on the left axis, deconvolved trace and events on a second ("deconv") axis.
p2 = figure(title="Traces", x_axis_label='Frames', y_axis_label=u'\u0394F/F', plot_height=int(scaling*350),
            plot_width=int(scaling*1100), x_range=(0,F.shape[0]), y_range=(-dff_peak, dff_peak),
            tools='reset, pan, box_zoom', toolbar_location='above')
p2.extra_y_ranges = {"deconv": Range1d(start=deconv_floor-0.1*deconv_peak, end=2*deconv_peak)}
p2.add_layout(LinearAxis(y_range_name="deconv", axis_label='AU'), 'right')
f=p2.line('x', 'y', source=s2, line_color='blue', alpha=.5, line_width=scaling*1.3)
c=p2.line('x', 'y', source=s4, line_color='green', alpha=.5, y_range_name="deconv", line_width=scaling*1.3)
s=p2.line('x', 'y', source=s6, line_color='red', alpha=.5, y_range_name="deconv", line_width=scaling*1.3)
p2.add_tools(HoverTool(mode='vline', tooltips=TOOLTIPS_TRACES))
# The legend is built but not attached to the figure (see the commented-out line below).
legend = Legend(location='center',
               glyph_width=5,
               label_width=10,
               items=[
                   (u'\u0394F/F', [f]),
                   ('Deconvolved', [c]),
                   ('Events', [s])
               ])
#p2.add_layout(legend, 'below')  # show legend (space issues though)

# Quality scatter plot: SNR (log axis) against r, sharing s1 with the contour map
# so that a selection in one plot is reflected in the other.
p3 = figure(title="Quality: SNR v r", plot_height=int(scaling*400), plot_width=int(scaling*500),
           x_axis_label='SNR', x_axis_type='log', y_axis_label='r', tools='tap, hover',
            x_range=(min(comps['snr']), np.percentile(comps['snr'],99)), y_range=(min(comps['r']), max(comps['r'])), tooltips=TOOLTIPS)
scatter1 = p3.circle('snr', 'r', fill_color='hex', size=scaling*10,
                    alpha=.7, line_color='black', line_width=scaling, line_alpha='good',
                    nonselection_fill_color='black', source=s1)
# Threshold lines: vertical red line at min_snr, horizontal green line at min_r.
p3.line(x='min_snr', y='span', line_width=scaling*2, line_color='red', source=thresholdsource)
p3.line(x='span_log', y='min_r', line_width=scaling*2, line_color='green', source=thresholdsource)

# Browser-side callback: when a component is selected (in p1 or p3, which share s1),
# copy its dF/F, deconvolved and event traces into the plotted sources s2/s4/s6,
# rescale both y axes of p2 and show the component's metrics in the p2 title.
s1.selected.js_on_change('indices', CustomJS(args={
'title': p2.title, 'p2':p2, 's1':s1, 's2':s2, 's3':s3, 's4':s4, 's5':s5, 's6':s6, 's7':s7
}, code="""
    const inds = cb_obj.indices;
    // Nothing selected (e.g. a click on empty canvas clears the selection):
    // keep the currently shown traces instead of looking up column 'F_undefined'.
    if (inds.length === 0) {
        return;
    }
    const comp = inds[0];

    // Loop-based min/max: spreading a long trace into Math.max(...) can exceed
    // the JavaScript engine's argument limit.
    function extent(values) {
        let lo = Infinity;
        let hi = -Infinity;
        for (let k = 0; k < values.length; k++) {
            if (values[k] < lo) { lo = values[k]; }
            if (values[k] > hi) { hi = values[k]; }
        }
        return [lo, hi];
    }

    const d1 = s1.data;
    const d2 = s2.data;
    const d3 = s3.data;
    const d4 = s4.data;
    const d5 = s5.data;
    const d6 = s6.data;
    const d7 = s7.data;

    d2['x'] = d3['Frame'];
    d2['y'] = d3['F_' + comp];
    const F_max = extent(d2['y'])[1];
    p2.y_range.start = -F_max;
    p2.y_range.end = F_max;

    d4['x'] = d3['Frame'];
    d4['y'] = d5['C_' + comp];
    const C_ext = extent(d4['y']);
    p2.extra_y_ranges['deconv'].start = C_ext[0] - 0.1 * C_ext[1];
    p2.extra_y_ranges['deconv'].end = C_ext[1];

    d6['x'] = d3['Frame'];
    d6['y'] = d7['S_' + comp];

    const snr = Math.round(d1['snr'][comp] * 100) / 100;
    const r = Math.round(d1['r'][comp] * 100) / 100;
    const goodness = d1['good'][comp];
    title.text = 'Component: ' + comp + ', SNR: ' + snr + ', R: ' + r + ', good: ' + goodness;
    s2.change.emit();
    s4.change.emit();
    s6.change.emit();
""")
)

# Contour map on the left, trace plot above the quality scatter plot on the right.
trace_col = column(p2, p3)
layout = row(p1, trace_col)

In [None]:
# Render the layout with a notebook handle (update_cnmfe calls push_notebook()
# to refresh it), then attach the two threshold sliders.
show(layout, notebook_handle=True)
snr_slider_spec = (0, np.percentile(comps['snr'], 95), .1)  # (min, max, step)
r_slider_spec = (-0.5, 1, .01)  # (min, max, step)
interact(update_cnmfe, min_snr=snr_slider_spec, min_r=r_slider_spec)

In [None]:
# SAVE AND FILTER COMPONENTS BASED ON NEW THRESHOLDS

# Thresholds last chosen with the sliders (both entries of each column are identical).
new_quality = {
    'min_SNR': thresholdsource.data['min_snr'][0],
    'rval_thr': thresholdsource.data['min_r'][0],
}
# Store the chosen thresholds in the CNMF parameters so the saved file records them
# (previously this dict was built but never used).
cnm.params.set('quality', new_quality)

# A component is kept when it exceeds both thresholds.
passes = (comps['snr'] > new_quality['min_SNR']) & (comps['r'] > new_quality['rval_thr'])
cnm.estimates.idx_components = comps[passes].index.tolist()
# Everything else is marked bad; set membership avoids a quadratic list scan.
kept_components = set(cnm.estimates.idx_components)
cnm.estimates.idx_components_bad = [comp for comp in np.arange(cnm.estimates.nr) if comp not in kept_components]
# cnm_files[0] ends in '.hdf5' (5 characters), which is replaced by '_refined.hdf5'.
cnm.save(cnm_files[0][:-5]+'_refined.hdf5')


## Manual curation of components

After interactively setting the quality thresholds, at least the borderline components (those close to the thresholds) should be inspected manually to decide whether to include them in or exclude them from further analysis. The following interactive tool helps with that.

In [None]:
start_with_best = False  # if True, will start with highest quality components first

# Without a previous selection, every component is a curation candidate.
if cnm.estimates.idx_components is None:
    cnm.estimates.idx_components = np.arange(len(cnm.estimates.SNR_comp)).tolist()
# Quality is currently ranked by SNR only (r values / CNN predictions could be added to the rank).
quality_rank = scipy.stats.rankdata(cnm.estimates.SNR_comp)
quality_order = np.argsort(quality_rank)  # ascending: lowest quality first
if start_with_best:
    quality_order = quality_order[::-1]  # descending: highest quality first
# All filtered components sorted by their quality; set membership avoids a quadratic list scan.
candidates = set(cnm.estimates.idx_components)
idx = [value for value in quality_order if value in candidates]
if not idx:
    # The cells below index idx[0]; fail here with a clear message instead.
    raise ValueError("No components to curate: cnm.estimates.idx_components is empty.")
i = 0  # position in idx of the component currently shown; advanced by the buttons
idx_curated = []  # components accepted so far
idx_curated_bad = []  # components rejected so far

# Traces of the first component to curate; the buttons later overwrite the
# C / S / F columns of tracesource with the next component's traces.
first_component = idx[0]
traces = pd.DataFrame({
    'x': np.array(range(len(cnm.estimates.F_dff.T))),
    'C': cnm.estimates.C[first_component],
    'S': cnm.estimates.S[first_component],
    'F': cnm.estimates.F_dff[first_component],
})
tracesource = ColumnDataSource(data=traces)

# Hover tool bound to the renderer named "line_with_hovertool"; the embedded
# CSS hides every tooltip box except the first one.
trace_hover = HoverTool(
    mode='vline',
    names=["line_with_hovertool"],
    tooltips="""
    <style>
        .bk-tooltip>div:not(:first-child) {display:none;}
    </style>

    <b>Frame: </b> @x <br>
    <b>dF/F: </b> @F <br>
    <b>C: </b> @C <br>
    <b>S: </b> @S
""",
)

# Axis limits derived from the traces currently held in tracesource.
dff_lo, dff_hi = min(tracesource.data['F']), max(tracesource.data['F'])
deconv_lo, deconv_hi = min(tracesource.data['C']), max(tracesource.data['C'])
shown_component = idx[i]

# Trace plot for the component under review: dF/F on the left axis,
# deconvolved trace and events on the second ("deconv") axis.
p20 = figure(
    x_axis_label='Frames',
    y_axis_label=u'\u0394F/F',
    plot_height=200,
    plot_width=750,
    x_range=(0,F.shape[0]),
    y_range=(dff_lo-dff_hi, dff_hi),
    tools=[trace_hover, 'reset, crosshair, pan, box_zoom'],
    toolbar_location='above',
)
p20.extra_y_ranges = {"deconv": Range1d(start=deconv_lo-.01*deconv_hi, end=2*deconv_hi)}
p20.add_layout(LinearAxis(y_range_name="deconv", axis_label='AU'), 'right')
p20.title.text = 'Component: {}/{}, SNR: {}, r: {}'.format(
    shown_component,
    max(idx),
    round(comps['snr'][shown_component]*100)/100,
    round(comps['r'][shown_component]*100)/100,
)
# Only the dF/F line carries the name the hover tool is restricted to.
f = p20.line('x', 'F', line_color='blue', alpha=.5, line_width=scaling*1, source=tracesource, name="line_with_hovertool")
c = p20.line('x', 'C', line_color='green', alpha=.5, y_range_name="deconv", line_width=1, source=tracesource)
s = p20.line('x', 'S', line_color='red', alpha=.5, y_range_name="deconv", line_width=1, source=tracesource)
# Legend is prepared but not attached to the figure (see the commented-out line below).
legend_entries = [
    (u'\u0394F/F', [f]),
    ('Deconvolved', [c]),
    ('Events', [s]),
]
legend = Legend(location='center', glyph_width=10, label_width=20, items=legend_entries)
#p20.add_layout(legend, 'below') # show legend for traces (display space issues though)

# Small overview map used during curation; the y axis is inverted
# (y_range runs from cnm.dims[0] down to 0).
p10 = figure(plot_height=200, plot_width=300, 
       x_range=(0,cnm.dims[1]), y_range=(cnm.dims[0], 0), tools='reset, tap, pan, box_zoom, wheel_zoom',
       active_scroll = "wheel_zoom")
# Background: the PNR image flipped along axis 0, drawn faintly (alpha 0.3).
p10.image(image=[np.flip(pnr, axis=0)], x=[0], y=[pnr.shape[0]], dw=[pnr.shape[1]], dh=[pnr.shape[0]], alpha=0.3, palette='Turbo256')
# Dashed white outlines for every component in idx (all curation candidates).
q = p10.patches(comps['xcoords'][idx], comps['ycoords'][idx], line_color='white', line_alpha=1, line_width=1, line_dash='2 3', fill_alpha=0)
# Solid red outline for the component currently under review (idx[i]).
r = p10.patch(comps['xcoords'][idx[i]], comps['ycoords'][idx[i]], line_color='red', line_alpha=1, line_width=3, fill_alpha=0)

# Accept / Reject buttons (ipywidgets); their click handlers are attached
# after the layout has been shown.
buttonYes = widgets.Button(description='Accept')
buttonYes.style.button_color = 'lightgreen'
buttonNo = widgets.Button(description='Reject', button_style='danger')
buttonNo.style.button_color = 'red'
display(buttonYes); display(buttonNo)
# Trace plot on top, overview map below; shown with a notebook handle because
# the button handlers call push_notebook() to refresh the display.
layout = column(p20, p10)
show(layout, notebook_handle=True)
      
def _show_component(comp):
    """Point the trace plot (p20) and the red outline in p10 at component ``comp``."""
    p20.title.text = 'Component: ' + str(comp) + '/' + str(max(idx)) + ', SNR: ' + str(round(comps['snr'][comp]*100)/100) + ', r: ' + str(round(comps['r'][comp]*100)/100)
    tracesource.data['C'] = cnm.estimates.C[comp]
    tracesource.data['S'] = cnm.estimates.S[comp]
    tracesource.data['F'] = cnm.estimates.F_dff[comp]
    # Rescale both y axes to the newly loaded traces.
    p20.y_range.start = min(tracesource.data['F']) - max(tracesource.data['F'])
    p20.y_range.end = max(tracesource.data['F'])
    p20.extra_y_ranges['deconv'].start = min(tracesource.data['C']) - .01*max(tracesource.data['C'])
    p20.extra_y_ranges['deconv'].end = 2*max(tracesource.data['C'])
    p10.patch(comps['xcoords'][comp], comps['ycoords'][comp], line_color='red', line_alpha=1, line_width=scaling*3, fill_alpha=0)


def _mark_decided():
    """Fill the most recently accepted component white and the most recently rejected one black."""
    if idx_curated:
        p10.patch(comps['xcoords'][idx_curated[-1]], comps['ycoords'][idx_curated[-1]], line_alpha=0, fill_alpha=1, fill_color='white')
    if idx_curated_bad:
        p10.patch(comps['xcoords'][idx_curated_bad[-1]], comps['ycoords'][idx_curated_bad[-1]], line_alpha=0, fill_alpha=1, fill_color='black')


def _record_decision(decided):
    """Append the current component to ``decided`` and move on to the next one.

    ``decided`` is either ``idx_curated`` or ``idx_curated_bad``. Clicks after
    the last component are ignored with a message, instead of raising
    IndexError on ``idx[i]``.
    """
    global i
    if i >= len(idx):
        print("All components have already been curated. You can filter and save the CNM object now.")
        return
    decided.append(idx[i])
    i += 1
    if i < len(idx):
        _show_component(idx[i])
        _mark_decided()
        push_notebook()
    else:
        print("Curation finished. You can filter and save the CNM object now.")


def on_button_clicked_yes(b):
    """'Accept' handler: record the current component as good and show the next one."""
    _record_decision(idx_curated)


def on_button_clicked_no(b):
    """'Reject' handler: record the current component as bad and show the next one."""
    _record_decision(idx_curated_bad)


buttonYes.on_click(on_button_clicked_yes)
buttonNo.on_click(on_button_clicked_no)

In [None]:
# SAVE CURATED SELECTION
include_uncurated = True  # if set to TRUE, 'good' index will include all components that have not been curated (yet)
if include_uncurated:
    # Keep every candidate that was not explicitly rejected.
    rejected = set(idx_curated_bad)
    idx_curated = [item for item in idx if item not in rejected]
print('Components included: {}'.format(idx_curated))
print('Saving new component selection of {} components...'.format(len(idx_curated)))
try:
    cnm.estimates.select_components(idx_components=idx_curated, save_discarded_components=True)
except Exception as err:
    # Still save (best effort), but make it visible that the curated selection
    # was NOT applied to the estimates.
    print('WARNING: select_components failed ({!r}); saving without applying the curated selection.'.format(err))
# cnm_files[0] ends in '.hdf5' (5 characters), which is replaced by '_refined.hdf5'.
refined_path = cnm_files[0][:-5]+'_refined.hdf5'
cnm.save(refined_path)
print('Successfully saved to '+refined_path)


len(cnm.estimates.idx_components)