# Refining CaImAn components
This notebook refines the spatiotemporal components extracted by CNMF in CaImAn (https://github.com/flatironinstitute/CaImAn; Giovannucci et al., 2019, eLife) in two steps. First, the quality thresholds (SNR, r-value and CNN score) are set interactively. Second, the components that pass them are curated manually, one by one.

In [None]:
from ipywidgets import interact, widgets
import numpy as np
import glob
import os
import pandas as pd
from bokeh.io import push_notebook, show, output_notebook, output_file
from bokeh.plotting import figure, ColumnDataSource
from bokeh.models import HoverTool, CustomJS, Button, LinearAxis, Range1d, Span, PointDrawTool
from bokeh.models.glyphs import Circle, Ray
from bokeh.layouts import row, column
from bokeh.events import ButtonClick
output_notebook()
import caiman as cm
from caiman.source_extraction.cnmf.cnmf import load_CNMF
from caiman.utils import visualization as cmviz
from scipy.stats import zscore
from IPython.core.display import display, HTML, Markdown, clear_output

scaling = 1  # global size multiplier applied to plot dimensions and line widths below
_widen_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_widen_css))  # stretch the notebook container to the full window width

In [None]:
# Path to the saved CNMF results file to refine; set this before running the notebook.
cnm_path = ''
if not os.path.isfile(cnm_path):
    # Fail early with a readable message instead of an opaque loader error.
    raise FileNotFoundError("CNMF file not found: {!r}. Set `cnm_path` to a saved CNMF file.".format(cnm_path))
cnm = load_CNMF(cnm_path)

In [None]:
# Point this at the new location if the memmap file has moved since the CNMF file was created,
# e.g. cnm.mmap_file = '/new/location/movie.mmap'.
cnm.mmap_file = cnm.mmap_file
# Restore components discarded by a previous curation run's select_components call.
cnm.estimates.restore_discarded_components()
if not os.path.exists(cnm.mmap_file):
    # Everything below needs the movie; stop with a proper exception (works under Run All).
    raise FileNotFoundError("Memory mapped file not found. Has file been moved?")

print("Loading 'C' memmap file, calculating correlation image...")
Yr, dims, T = cm.load_memmap(cnm.mmap_file)  # loading 'C' memmap file
images = np.reshape(Yr.T, [T] + list(dims), order='F')  # load frames in python format (T x X x Y)
Cn = cm.local_correlations(images.transpose(1, 2, 0))  # background image for the spatial maps
Cn[np.isnan(Cn)] = 0

# Quality metrics are required for thresholding. Compute them when they are missing (None or
# empty) or all zero; np.any also copes with cnn_preds being a plain list.
if cnm.estimates.cnn_preds is None or not np.any(cnm.estimates.cnn_preds):
    print("Components not yet evaluated. Running component evaluation now...")
    cnm.estimates.evaluate_components(images, cnm.params)

coords = cmviz.get_contours(cnm.estimates.A, cnm.dims)

# Shared source for the threshold guide lines in the quality plots. 'span' / 'span_log' give each
# line's extent along the other axis. Every threshold value is repeated twice (one per line end),
# starting from the current CaImAn quality parameters.
quality_params = cnm.params.quality
thresholdsource = ColumnDataSource({
    'span': [-1, 1],
    'span_log': [0.01, 100],
    'thr_snr': [quality_params['min_SNR']] * 2,
    'min_snr': [quality_params['SNR_lowest']] * 2,
    'thr_r': [quality_params['rval_thr']] * 2,
    'thr_cnn': [quality_params['min_cnn_thr']] * 2,
    'min_r': [quality_params['rval_lowest']] * 2,
    'min_cnn': [quality_params['cnn_lowest']] * 2,
})

In [None]:
def normalise_evals(val, by=1, type=float):
    """Linearly rescale values onto ``[0, by]`` and cast them with ``astype(type)``.

    The smallest finite value maps to 0 and the largest finite value maps to ``by``.
    Non-finite inputs do not affect the range: ``+inf`` maps to ``by``, ``-inf`` maps to 0 and
    NaN maps to 0, so integer casts (used for RGB colour channels) always give valid numbers.
    If all finite values are equal (zero range), they map to 0 instead of dividing by zero.

    Parameters
    ----------
    val : array-like of numbers (ndarray, pandas Series, list, ...)
    by : upper end of the output scale.
    type : dtype passed to ``ndarray.astype`` (e.g. ``'int'``, which truncates).

    Returns
    -------
    numpy.ndarray with the same length as ``val``.
    """
    values = np.asarray(val, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return np.zeros(values.shape).astype(type)
    lo, hi = finite.min(), finite.max()
    if hi > lo:
        scaled = (values - lo) / (hi - lo) * by
    else:
        # Zero range: only values above it (i.e. +inf) go to the top of the scale.
        scaled = np.where(values > lo, float(by), 0.0)
    # Clip +/-inf onto the ends of the scale, then send NaN to 0.
    scaled = np.clip(scaled, min(0, by), max(0, by))
    return np.nan_to_num(scaled, nan=0.0).astype(type)

# Per-component temporal traces as (frames x components) tables. Columns are named F_k / C_k / S_k
# so the JS selection callback can look up a component's trace by name. F also gets a 'Frame'
# column, which serves as the shared x values.
Fcols = ['F_' + str(num) for num in range(len(cnm.estimates.F_dff))]
F = pd.DataFrame(cnm.estimates.F_dff.T, index=np.arange(len(cnm.estimates.F_dff.T)),
                 columns=Fcols)
F['Frame'] = F.index
Ccols = ['C_' + str(num) for num in range(len(cnm.estimates.C))]
C = pd.DataFrame(cnm.estimates.C.T, index=np.arange(len(cnm.estimates.C.T)),
                 columns=Ccols)
Scols = ['S_' + str(num) for num in range(len(cnm.estimates.S))]
S = pd.DataFrame(cnm.estimates.S.T, index=np.arange(len(cnm.estimates.S.T)),
                 columns=Scols)

# One row per component: quality metrics, and a colour encoding them as
# R = cbrt(SNR), G = r, B = CNN, each rescaled to 0-255.
comps = pd.DataFrame()
comps['snr'] = cnm.estimates.SNR_comp
comps['r'] = np.array(cnm.estimates.r_values)
comps['cnn'] = np.array(cnm.estimates.cnn_preds)
rgbArray = np.array([normalise_evals(np.cbrt(comps['snr']), 255, 'int'),
                     normalise_evals(comps['r'], 255, 'int'),
                     normalise_evals(comps['cnn'], 255, 'int')]).T.tolist()
comps['hex'] = ['#%02x%02x%02x' % tuple(rgb) for rgb in rgbArray]
comps['good'] = 1
comps['alpha'] = 0.5
# Contours have different numbers of vertices, so keep them as plain lists of arrays.
# np.array() on such a ragged list raises ValueError on NumPy >= 1.24, and builds a 2-D array
# (which cannot be assigned to a column) when all lengths happen to match.
comps['xcoords'] = [contour['coordinates'][:, 0] for contour in coords]
comps['ycoords'] = [contour['coordinates'][:, 1] for contour in coords]
s1 = ColumnDataSource(comps)

# Hover fields shared by every view backed by s1 (spatial map and quality scatters).
TOOLTIPS = [("index", "$index"), ("SNR", "@snr"), ("R", "@r"), ("CNN", "@cnn"), ("good", "@good")]

n_good = sum(s1.data['good'])
n_total = len(s1.data['good'])
# Spatial map of all components; the y-range is inverted (dims[1] down to 0).
p1 = figure(title="Components: " + str(n_good) + " out of " + str(n_total) + " selected to meet criteria",
            plot_height=scaling*640, plot_width=scaling*640,
            x_range=(0, cnm.dims[0]), y_range=(cnm.dims[1], 0),
            tools='hover, reset, tap, pan, box_zoom, wheel_zoom',
            active_scroll="wheel_zoom", tooltips=TOOLTIPS)

def update(thr_snr=cnm.params.quality['min_SNR'], thr_r=cnm.params.quality['rval_thr'],
           thr_cnn=cnm.params.quality['min_cnn_thr'], min_snr=cnm.params.quality['SNR_lowest'],
           min_r=cnm.params.quality['rval_lowest'], min_cnn=cnm.params.quality['cnn_lowest']):
    """Re-mark which components meet the quality thresholds and refresh the plots.

    Driven by the ``interact`` sliders. A component is marked good (1) when at least two of
    SNR / r / CNN exceed their ``thr_*`` values AND all three exceed their ``min_*`` values.
    Otherwise it is marked 0. The threshold guide lines are moved to the new values, and p1's
    title shows the new count.
    """
    n_above_thr = ((comps['snr'] > thr_snr).astype(int)
                   + (comps['r'] > thr_r).astype(int)
                   + (comps['cnn'] > thr_cnn).astype(int))
    above_lowest = (comps['snr'] > min_snr) & (comps['r'] > min_r) & (comps['cnn'] > min_cnn)
    # NOTE(review): CaImAn's filter_components, which is run when the thresholds are saved, may
    # combine these criteria differently (e.g. any single criterion passing rather than two of
    # three). Confirm that this preview matches what gets saved.
    # Store 0/1 ints, matching the column's initial value, rather than a pandas bool Series,
    # so alpha mappings and the hover/title text always see the same kind of value.
    good = ((n_above_thr >= 2) & above_lowest).astype(int).to_numpy()
    s1.data['good'] = good
    thresholdsource.data['thr_snr'] = [thr_snr]*2
    thresholdsource.data['min_snr'] = [min_snr]*2
    thresholdsource.data['thr_r'] = [thr_r]*2
    thresholdsource.data['min_r'] = [min_r]*2
    thresholdsource.data['thr_cnn'] = [thr_cnn]*2
    thresholdsource.data['min_cnn'] = [min_cnn]*2
    p1.title.text = "Components: " + str(good.sum()) + " out of " + str(len(good)) + " selected"
    push_notebook()

# Background: local-correlation image, flipped along axis 0 (p1's y-axis runs from dims[1] down to 0).
i = p1.image(image=[np.flip(Cn, axis=0)], x=[0], y=[Cn.shape[1]],
             dw=[Cn.shape[0]], dh=[Cn.shape[1]], palette='Greys256')
# Component contours: filled with the metric colour ('hex'); fill opacity follows the 'good' flag.
r = p1.patches('xcoords',
               'ycoords',
               line_color='white',
               selection_line_color='white',
               nonselection_line_color='white',
               line_alpha='alpha',
               selection_line_alpha=1,
               nonselection_line_alpha=.5,
               fill_color='hex',
               selection_fill_color='hex',
               fill_alpha='good',
               selection_fill_alpha=1,
               source=s1)
# Draw non-selected outlines dotted (1px on, 1px off). An explicit pattern is used rather than a
# data column, so the dash array does not grow with the number of components.
r.nonselection_glyph.line_dash = [1, 1]

# s2 / s4 / s6 start empty and are filled by the selection callback with the selected component's
# dF/F, C and S traces; s3 / s5 / s7 hold every component's traces by column name.
s2, s4, s6 = (ColumnDataSource(data=dict(x=[], y=[])) for _ in range(3))
s3 = ColumnDataSource(data=F)
s5 = ColumnDataSource(data=C)
s7 = ColumnDataSource(data=S)

TOOLTIPS_TRACES = [("Y", "@y"), ("X", "@x")]

p2 = figure(title="Traces", x_axis_label='Frames', y_axis_label=u'\u0394F/F',
            plot_height=scaling*250, plot_width=scaling*900,
            x_range=(0, F.shape[0]), y_range=(-1.5, 3),
            tools='reset, pan, box_zoom', toolbar_location='above')
# Secondary axis for the C and S traces.
p2.extra_y_ranges = {"deconv": Range1d(start=-500, end=20000)}
p2.add_layout(LinearAxis(y_range_name="deconv", axis_label='AU'), 'right')
trace_style = dict(alpha=.5, line_width=scaling*1)
p2.line('x', 'y', source=s2, line_color='blue', legend_label=u'\u0394F/F', **trace_style)
p2.line('x', 'y', source=s4, line_color='green', y_range_name="deconv", legend_label='Deconvolved', **trace_style)
p2.line('x', 'y', source=s6, line_color='red', y_range_name="deconv", legend_label='Events', **trace_style)
p2.add_tools(HoverTool(mode='vline', tooltips=TOOLTIPS_TRACES))

def _quality_scatter(title, x_col, y_col, x_label, y_label, x_axis_type='auto'):
    """Build a square scatter of two ``comps`` quality metrics drawn from ``s1``.

    Axis ranges span the min/max of the plotted columns. Because the glyphs use ``s1``, selection
    and the 'good' outline stay linked with the spatial map. Returns ``(figure, circle_renderer)``.
    """
    fig = figure(title=title, plot_height=scaling*220, plot_width=scaling*220,
                 x_axis_label=x_label, y_axis_label=y_label, x_axis_type=x_axis_type,
                 tools='tap, hover', tooltips=TOOLTIPS,
                 x_range=(min(comps[x_col]), max(comps[x_col])),
                 y_range=(min(comps[y_col]), max(comps[y_col])))
    renderer = fig.circle(x_col, y_col, fill_color='hex', size=scaling*6,
                          alpha=.7, line_color='black', line_width=scaling*.8, line_alpha='good',
                          nonselection_fill_color='black', source=s1)
    return fig, renderer


def _add_threshold_lines(fig, x_metric, y_metric, x_color, y_color, x_span):
    """Draw the threshold guide lines for both axes from ``thresholdsource``.

    Dashed lines mark the ``thr_*`` values and solid lines mark the ``min_*`` values. Vertical lines
    (x-metric) span y over 'span'. Horizontal lines (y-metric) span x over ``x_span``.
    """
    fig.line(x='thr_' + x_metric, y='span', line_width=scaling*2, line_color=x_color,
             line_dash=[2*scaling, 3*scaling], source=thresholdsource)
    fig.line(x='min_' + x_metric, y='span', line_width=scaling*2, line_color=x_color,
             source=thresholdsource)
    fig.line(x=x_span, y='thr_' + y_metric, line_width=scaling*2, line_color=y_color,
             line_dash=[2*scaling, 3*scaling], source=thresholdsource)
    fig.line(x=x_span, y='min_' + y_metric, line_width=scaling*2, line_color=y_color,
             source=thresholdsource)


# SNR axes are log-scaled, so their horizontal guide lines use the 'span_log' extent.
p3, scatter1 = _quality_scatter("Quality: SNR v r", 'snr', 'r', 'SNR', 'r', x_axis_type='log')
_add_threshold_lines(p3, 'snr', 'r', 'red', 'green', x_span='span_log')

p4, scatter2 = _quality_scatter("Quality: SNR v CNN", 'snr', 'cnn', 'SNR', 'CNN', x_axis_type='log')
_add_threshold_lines(p4, 'snr', 'cnn', 'red', 'blue', x_span='span_log')

p5, scatter3 = _quality_scatter("Quality: CNN v r", 'cnn', 'r', 'CNN', 'r')
_add_threshold_lines(p5, 'cnn', 'r', 'blue', 'green', x_span='span')

# Selecting a component (on the map or any quality scatter; they all share s1) copies its dF/F,
# C and S traces into the plotted sources and writes its metrics into p2's title.
s1.selected.js_on_change('indices', CustomJS(args={
    'title': p2.title, 's1': s1, 's2': s2, 's3': s3, 's4': s4, 's5': s5, 's6': s6, 's7': s7
    }, code="""
        var inds = cb_obj.indices;
        if (inds.length === 0) {
            // Selection cleared: keep showing the last component instead of undefined columns.
            return;
        }
        var k = inds[0];
        var d1 = s1.data;
        var d2 = s2.data;
        var d3 = s3.data;
        var d4 = s4.data;
        var d5 = s5.data;
        var d6 = s6.data;
        var d7 = s7.data;
        d2['x'] = d3['Frame'];
        d2['y'] = d3['F_' + k];
        d4['x'] = d3['Frame'];
        d4['y'] = d5['C_' + k];
        d6['x'] = d3['Frame'];
        d6['y'] = d7['S_' + k];
        var snr = Math.round(d1['snr'][k] * 100) / 100;
        var r = Math.round(d1['r'][k] * 100) / 100;
        var cnn = Math.round(d1['cnn'][k] * 100) / 100;
        var goodness = d1['good'][k];
        title.text = 'Component: ' + k + ', SNR: ' + snr + ', R: ' + r + ', CNN: ' + cnn + ', good: ' + goodness;
        s2.change.emit();
        s4.change.emit();
        s6.change.emit();
    """)
)

# Dashboard: trace viewer on top; below it, the spatial map beside the stacked quality scatters.
quality_col = column(p3, p4, p5)
row1, row2 = row(p2), row(p1, quality_col)
layout = column(row1, row2)

In [None]:
# Render the dashboard and attach the threshold sliders; each (min, max, step) tuple becomes a slider.
# The SNR sliders top out at the 95th percentile of component SNRs.
show(layout, notebook_handle=True)
snr_slider_max = np.percentile(comps['snr'], 95)
interact(update,
         thr_snr=(0, snr_slider_max, .1), thr_r=(-0.5, 1, .01), thr_cnn=(0, 1, .01),
         min_snr=(0, snr_slider_max, .1), min_r=(-0.5, 1, .01), min_cnn=(0, 1, .01))

In [None]:
# SAVE AND FILTER COMPONENTS BASED ON NEW THRESHOLDS
# The slider values live in `thresholdsource`. Map them back to CaImAn's quality-parameter names,
# re-run CaImAn's component filter with them, and save to the file the CNMF was loaded from.
threshold_column_for = {
    'min_SNR': 'thr_snr',
    'rval_thr': 'thr_r',
    'min_cnn_thr': 'thr_cnn',
    'SNR_lowest': 'min_snr',
    'rval_lowest': 'min_r',
    'cnn_lowest': 'min_cnn',
}
new_quality = {param: thresholdsource.data[column][0] for param, column in threshold_column_for.items()}
cnm.estimates.filter_components(images, cnm.params, new_quality)
cnm.save(cnm_path)

In [None]:
# Queue the components that passed the quality filter for manual review.
# Check emptiness explicitly: `.any()` is False for [0], which is a valid, non-empty selection.
if cnm.estimates.idx_components is None or len(cnm.estimates.idx_components) == 0:
    cnm.estimates.restore_discarded_components()
idx = cnm.estimates.idx_components
if idx is None or len(idx) == 0:
    # Nothing passed the filter: review every component rather than failing on idx[0] below.
    # NOTE(review): assumes components are the columns of A - TODO confirm.
    idx = np.arange(cnm.estimates.A.shape[1])
    print('No components passed the quality filter; reviewing all ' + str(len(idx)) + ' components.')
i = 0                 # position in `idx` of the component under review
idx_curated = []      # component ids accepted during curation
idx_curated_bad = []  # component ids rejected during curation

# Traces of the component under review, one row per frame; the curation buttons overwrite C/S/F.
traces = pd.DataFrame()
traces['x'] = np.arange(len(cnm.estimates.F_dff.T))
traces['C'] = cnm.estimates.C[idx[i]]
traces['S'] = cnm.estimates.S[idx[i]]
traces['F'] = cnm.estimates.F_dff[idx[i]]
tracesource = ColumnDataSource(data=traces)

# The hover is bound only to the dF/F line (by name), so one tooltip lists all three traces per frame;
# the CSS hides duplicate tooltip rows.
trace_hover = HoverTool(mode='vline', names=["line_with_hovertool"], tooltips="""
    <style>
        .bk-tooltip>div:not(:first-child) {display:none;}
    </style>

    <b>Frame: </b> @x <br>
    <b>dF/F: </b> @F <br>
    <b>C: </b> @C <br>
    <b>S: </b> @S
""")

# Trace plot for the component under review (C and S on the secondary 'deconv' axis).
p20 = figure(x_axis_label='Frames', y_axis_label=u'\u0394F/F',
             plot_height=scaling*250, plot_width=scaling*900,
             x_range=(0, F.shape[0]), y_range=(-1.5, 3),
             tools=[trace_hover, 'reset, crosshair, pan, box_zoom'], toolbar_location='above')
p20.extra_y_ranges = {"deconv": Range1d(start=-500, end=20000)}
p20.add_layout(LinearAxis(y_range_name="deconv", axis_label='AU'), 'right')
p20.title.text = ('Component: ' + str(idx[i]) + '/' + str(max(idx))
                  + ', SNR: ' + str(round(comps['snr'][idx[i]]*100)/100)
                  + ', r: ' + str(round(comps['r'][idx[i]]*100)/100)
                  + ', CNN: ' + str(round(comps['cnn'][idx[i]]*100)/100))
trace_line_style = dict(alpha=.5, line_width=scaling*1, source=tracesource)
f = p20.line('x', 'F', line_color='blue', legend_label=u'\u0394F/F', name="line_with_hovertool", **trace_line_style)
c = p20.line('x', 'C', line_color='green', y_range_name="deconv", legend_label='Deconvolved', **trace_line_style)
s = p20.line('x', 'S', line_color='red', y_range_name="deconv", legend_label='Events', **trace_line_style)

# Spatial map: dashed outlines for every queued component, thick outline for the one under review.
p10 = figure(plot_height=scaling*640, plot_width=scaling*640,
             x_range=(0, cnm.dims[0]), y_range=(cnm.dims[1], 0),
             tools='reset, tap, pan, box_zoom, wheel_zoom', active_scroll="wheel_zoom")
p10.image(image=[np.flip(Cn, axis=0)], x=[0], y=[Cn.shape[1]], dw=[Cn.shape[0]], dh=[Cn.shape[1]], palette='Viridis256')
q = p10.patches(comps['xcoords'][idx], comps['ycoords'][idx], line_color='white', line_alpha=1,
                line_width=scaling*1, line_dash=[2*scaling, 3*scaling], fill_alpha=0)
r = p10.patch(comps['xcoords'][idx[i]], comps['ycoords'][idx[i]], line_color='white', line_alpha=1,
              line_width=scaling*3, fill_alpha=0)

buttonYes = widgets.Button(description='Accept')
buttonYes.style.button_color = 'lightgreen'
buttonNo = widgets.Button(description='Reject', button_style='danger')
buttonNo.style.button_color = 'red'
for button in (buttonYes, buttonNo):
    display(button)
layout = column(p20, p10)
show(layout, notebook_handle=True)
      
def _component_title(k):
    """Title for component ``k``: its id over the highest queued id, then SNR, r and CNN rounded to 2 dp."""
    return ('Component: ' + str(k) + '/' + str(max(idx))
            + ', SNR: ' + str(round(comps['snr'][k]*100)/100)
            + ', r: ' + str(round(comps['r'][k]*100)/100)
            + ', CNN: ' + str(round(comps['cnn'][k]*100)/100))


def _show_component(k):
    """Load component ``k``'s traces into the trace plot and outline it (thick, white) on the map."""
    p20.title.text = _component_title(k)
    tracesource.data['C'] = cnm.estimates.C[k]
    tracesource.data['S'] = cnm.estimates.S[k]
    tracesource.data['F'] = cnm.estimates.F_dff[k]
    p10.patch(comps['xcoords'][k], comps['ycoords'][k], line_color='white', line_alpha=1,
              line_width=scaling*3, fill_alpha=0)


def _record_decision(accepted):
    """Accept or reject the component under review, fill it on the map and advance to the next one.

    Does nothing once every component in ``idx`` has been reviewed. After the last component, the
    trace-plot title reports the totals instead of indexing past the end of ``idx``.
    """
    global i
    if i >= len(idx):
        return
    decided = idx[i]
    if accepted:
        idx_curated.append(decided)
    else:
        idx_curated_bad.append(decided)
    i += 1
    if i < len(idx):
        _show_component(idx[i])
    else:
        p20.title.text = ('Curation complete: ' + str(len(idx_curated)) + ' accepted, '
                          + str(len(idx_curated_bad)) + ' rejected')
    # Fill only the component just decided (white = accepted, black = rejected); earlier ones
    # are already filled, so they are not redrawn on every click.
    p10.patch(comps['xcoords'][decided], comps['ycoords'][decided], line_alpha=0, fill_alpha=1,
              fill_color='white' if accepted else 'black')
    push_notebook()


def on_button_clicked_yes(b):
    """'Accept' button callback: keep the current component and move to the next one."""
    _record_decision(True)


def on_button_clicked_no(b):
    """'Reject' button callback: discard the current component and move to the next one."""
    _record_decision(False)


buttonYes.on_click(on_button_clicked_yes)
buttonNo.on_click(on_button_clicked_no)

In [None]:
# SAVE CURATED SELECTION
include_uncurated = True  # if set to TRUE, 'good' index will include all components that have not been curated (yet)
if include_uncurated:
    # Everything queued for curation except explicit rejections, in the original queue order.
    rejected = set(idx_curated_bad)
    idx_curated = [component for component in idx if component not in rejected]
print(idx_curated)
print('Saving new component selection...')
cnm.estimates.select_components(idx_components=idx_curated, save_discarded_components=True)
cnm.save(cnm_path)