# CNMF Component evaluation

This is similar to the last visualization in `cnmf_viz.ipynb`, but combines it with component evaluation.

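Use the radio buttons to color the component contours by `r_values`, `cnn_preds`, or the component SNR (`SNR_comp`, shown as log10). Move the `min_SNR` slider to re-run the component evaluation with a new threshold.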

Click on components and press `"g"` to add to good, `"b"` to add to bad.

In [None]:
from mesmerize_core import *
import numpy as np
import pandas as pd

In [None]:
from fastplotlib import ImageWidget, Plot, GridPlot
import fastplotlib as fpl
from fastplotlib.utils import make_colors
from ipywidgets import VBox, IntSlider, Layout

In [None]:
# widen DataFrame cells so long strings are displayed up to 120 characters
pd.set_option("display.max_colwidth", 120)

# Paths

In [None]:
import os
from pathlib import Path

# Root of your `caiman_data` dir, which holds the raw data for this demo.
# Follows CaImAn's convention: use the CAIMAN_DATA environment variable if set,
# otherwise ~/caiman_data. Edit this line to point somewhere else.
caiman_data_dir = Path(os.environ.get("CAIMAN_DATA", Path.home() / "caiman_data"))
set_parent_raw_data_path(caiman_data_dir)

# batch path, relative to the parent raw data dir set above
batch_path = get_parent_raw_data_path().joinpath("mesmerize-batch/batch.pickle")

## Load batch

In [None]:
# load the batch DataFrame from `batch_path` and display it
df = load_batch(batch_path)
df

In [None]:
# Which batch item (row of `df`) to visualize; change it to plot other items
index = 1
item = df.iloc[index]

# motion corrected input movie, returned as a memmap
cnmf_movie = item.caiman.get_input_movie()

# contours of all spatial components (`coms` presumably holds their centers
# of mass), followed by all temporal components
contours, coms = item.cnmf.get_contours("all", swap_dim=False)
temporal = item.cnmf.get_temporal("all")

# indices of the good and bad components
ixs_good, ixs_bad = (
    item.cnmf.get_good_components(),
    item.cnmf.get_bad_components(),
)

# the full CNMF output object (estimates, params, ...)
cnmf_obj = item.cnmf.get_output()

In [None]:
# inspect the component r_values, sorted in ascending order
np.sort(cnmf_obj.estimates.r_values)

In [None]:
# reconstructed movie (RCM) from the CNMF output, and the correlation image,
# both for the batch item selected by `index`
rcm, corr_img = (
    df.iloc[index].cnmf.get_rcm(),
    df.iloc[index].caiman.get_corr_image(),
)

In [None]:
class DummyMovie:
    """Present a single 2D image as though it were a movie.

    Indexing with anything returns the same ``image``. ``shape``, ``ndim`` and
    ``size`` are stored exactly as given (they are not derived from ``image``),
    so they can be copied from a real movie to display the image next to it.
    """

    def __init__(self, image: np.ndarray, shape, ndim, size):
        # the single frame returned for every index
        self.image = image
        # movie-like metadata reported verbatim to whoever inspects this object
        self.shape, self.ndim, self.size = shape, ndim, size

    def __getitem__(self, index: object):
        """Return ``image``, ignoring ``index`` entirely."""
        frame = self.image
        return frame

In [None]:
from ipywidgets import RadioButtons, FloatSlider

In [None]:
# ImageWidget with a 1 x 3 layout; the temporal traces go in a separate Plot:
# |=====================================|
# |  movie  |    rcm    |  corr image   |
# |=====================================|

# Wrap the 2D correlation image so it reports the same shape/ndim/size as the
# input movie and can be passed to ImageWidget next to the movie and the RCM.
corr_img_movie = DummyMovie(
    corr_img, shape=cnmf_movie.shape, ndim=cnmf_movie.ndim, size=cnmf_movie.size
)

iw = fpl.ImageWidget(
    [cnmf_movie, rcm, corr_img_movie],
    vmin_vmax_sliders=True,
    names=["movie", "rcm", "corr"],
    cmap="gnuplot2",
)

# Per-component evaluation metrics from the CNMF estimates.
snr_comps = cnmf_obj.estimates.SNR_comp  # raw SNR: sets the min_SNR slider range
snr_comps_log = np.log10(snr_comps)  # log10 SNR: used to color the contours
r_values = cnmf_obj.estimates.r_values
cnn_preds = cnmf_obj.estimates.cnn_preds

# Radio-button option -> values used as the contours' colormap values.
eval_metric_values = {
    "snr": snr_comps_log,
    "r_values": r_values,
    "cnn_preds": cnn_preds,
}

# Add the same contours to every subplot. They are initially colored by
# log10(SNR), the radio button's default, which is the same scale that
# change_eval_metric uses.
contour_graphics = {
    subplot_name: iw.gridplot[subplot_name].add_line_collection(
        contours,
        cmap="spring",
        cmap_values=snr_comps_log,
        thickness=3,
        name="contours",
    )
    for subplot_name in ["movie", "rcm", "corr"]
}
# keep the individual handles available for interactive use in later cells
contours_movie = contour_graphics["movie"]
contours_rcm = contour_graphics["rcm"]
contours_corr = contour_graphics["corr"]

# plot single temporal, just like before
plot_temporal_single = Plot()
temporal_graphic = plot_temporal_single.add_line_collection(temporal, colors="w")

radio_eval_metrics = RadioButtons(
    options=["snr", "r_values", "cnn_preds"],
    value="snr",
    description="current eval colors",
)


def change_eval_metric(change):
    """Recolor every contour collection by the selected evaluation metric.

    Parameters
    ----------
    change : dict
        ipywidgets change notification. Only ``change["new"]`` is read; it
        must be one of the ``radio_eval_metrics`` options.

    Components listed in the CNMF output's ``idx_components_bad`` are then
    drawn white on top of the colormap.
    """
    values = eval_metric_values[change["new"]]

    # Re-read the bad indices from the output on every call: update_eval calls
    # run_eval() before this, which presumably rewrites the stored evaluation.
    current_bad_ixs = df.iloc[index].cnmf.get_output().estimates.idx_components_bad

    for graphic in contour_graphics.values():
        graphic.cmap_values = values
        graphic[current_bad_ixs].colors = "w"


radio_eval_metrics.observe(change_eval_metric, "value")

slider_min_snr = FloatSlider(
    min=snr_comps.min(),
    max=snr_comps.max(),
    value=cnmf_obj.params.get_group("quality")["min_SNR"],
    step=(snr_comps.max() - snr_comps.min()) / 100,
    description="min_SNR",
    # run_eval re-evaluates all components, so only fire when the slider is
    # released rather than on every intermediate drag position
    continuous_update=False,
)


def update_eval(change):
    """Re-run component evaluation with the slider's min_SNR, then recolor.

    ``change`` (the ipywidgets notification) is not read; the slider's current
    value is used instead.
    """
    new_params = {"min_SNR": slider_min_snr.value}

    df.iloc[index].cnmf.run_eval(new_params)
    change_eval_metric({"new": radio_eval_metrics.value})


slider_min_snr.observe(update_eval, "value")

# Apply the initial coloring, including white for bad components, so the first
# render matches what the radio buttons and slider produce later.
change_eval_metric({"new": radio_eval_metrics.value})

VBox([plot_temporal_single.show(), iw.show(), radio_eval_metrics, slider_min_snr])

In [None]:
# path to the CNMF output file of the batch item selected by `index` above
# (rather than a hard-coded row, so it stays in sync with the visualization)
df.iloc[index].cnmf.get_output_path()

In [None]:
import h5py

In [None]:
# CNMF output file of the same batch item used in the visualization above
h5_path = df.iloc[index].cnmf.get_output_path()

# Open read-only: these cells only inspect the file, so they should never be
# able to modify it. The file is intentionally left open because the following
# cells read from `f`; call `f.close()` when you are done.
f = h5py.File(h5_path, "r")

In [None]:
# top-level groups/datasets in the CNMF output file
f.keys()

In [None]:
# contents of the "estimates" group
f["estimates"].keys()

In [None]:
# indices stored under estimates/idx_components (the "good" components)
f["estimates"]["idx_components"][:]

In [None]:
# indices stored under estimates/idx_components_bad (the "bad" components,
# the set the visualization draws in white); complements idx_components
f["estimates"]["idx_components_bad"][:]

In [None]:
# dF/F traces as stored in the file
# NOTE(review): presumably only a real array if dF/F was computed for this
# item; otherwise the stored value may not be sliceable — verify
f["estimates"]["F_dff"][:]

In [None]:
# shape of the stored dF/F dataset (presumably components x frames — verify)
f["estimates"]["F_dff"].shape