# Inspect and refine multisession outputs

## Description

This notebook is an experimental tool for inspecting and refining processed `roicat` outputs across sessions.

Very **EXPERIMENTAL**. Use with caution.

## Requirements

- `numpy`
- `pandas`
- `scikit-image`
- `matplotlib` with the `tk` backend (other backends such as `qt` may work but are untested)

## Limitations

This currently uses `tk`, which needs a display, so it cannot run on Oscar in headless mode. It has not been checked whether it works in an Open OnDemand Jupyter session.

So either run this locally or use an Open OnDemand Desktop session.

The rendering is currently not optimized and will be slow, because many operations replot every session.

This only plots the `iscell` ones.
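
The headless limitation above can be checked before the GUI is launched. The sketch below is a best-effort guess only: it assumes that on Linux a usable display is advertised through the `DISPLAY` environment variable, and the helper name `tk_display_available` is hypothetical, not part of the notebook.

```python
import os
import sys


def tk_display_available(environ=None, platform=None):
    """Best-effort guess at whether a Tk window can be opened."""
    environ = os.environ if environ is None else environ
    platform = sys.platform if platform is None else platform
    if platform.startswith('linux'):
        # On Linux, Tk needs an X display, normally advertised through DISPLAY.
        return bool(environ.get('DISPLAY'))
    # Windows and macOS desktops do not rely on DISPLAY.
    return True


if not tk_display_available():
    print('No display found: run locally or in a desktop session.')
```

A `True` result does not guarantee that `tk` itself is installed; it only rules out the most common headless case.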

## Instructions

### Paths and config

Define the parameters and paths in the `Define paths and parameters` section.

Then run the whole notebook. 
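
Before running the whole notebook, it can help to confirm that `muse_dir` holds the files the notebook reads. The sketch below assumes the file names used in the `Initialize` section (`roicat-output.pkl`, `aligned-images.pkl`, `finalized-roi.csv`, and `refined-roi.csv` when `plot_refined = True`). The helper name `missing_inputs` is hypothetical and not part of the notebook.

```python
from pathlib import Path


def missing_inputs(muse_dir, plot_refined=False):
    """Return the names of expected input files that are absent from muse_dir."""
    # File names follow the notebook's Initialize section.
    required = ['roicat-output.pkl', 'aligned-images.pkl']
    # A re-inspection run reads the refined table instead of the finalized one.
    required.append('refined-roi.csv' if plot_refined else 'finalized-roi.csv')
    muse_dir = Path(muse_dir)
    return [name for name in required if not (muse_dir / name).exists()]
```

An empty list means every expected file is in place; otherwise the list names what is missing.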

### GUI interactions

- Note:
    - Rendering can sometimes take a while
    - The buttons can be unresponsive at times; click firmly, and click again if nothing happens
- To simply inspect:
    - Click on an ROI of a certain session
    - If available, **red** contours of the same ROI will show up in other sessions
    - The titles will change:
        - Lowercase `id` is the session local ROI index
        - Uppercase `ID` is the global experiment ROI index
        - If ROI cannot be found in a given session based on `roicat` output, the title will say `NOT FOUND`
- At any point, to clear current selections and requests, click on `Clear`
- To **chain** ROIs across sessions: 
    - First select an ROI so that the **red** contours show up
    - Then click `Chain Request`
    - Then click on another ROI in other sessions
    - **Light blue** contours will show up
    - Then click `Chain`
    - Wait a bit; the colors will change so that the chained ROIs match
    - Click `Clear` and then re-select the ROIs to make sure they now have the **same** global ID (i.e. *chained*)
    - Note: two ROIs can only be chained if their global IDs never appear in the same session
- To **unchain** ROIs of a certain session (only one at a time):
    - First select an ROI so that the **red** contours show up
    - Then click `Unchain Request`
    - Then enter the session number (starts from 1) to unchain in the box next to it
    - Then press the `Enter` key (or `Return` key on Mac)
    - Wait a bit, the color of the requested ROI to unchain will change
    - Click `Clear` and then re-select the ROIs to make sure they now have **different** global IDs (i.e. *unchained*)
- If you're satisfied, click `Save` to write the refined table to `refined-roi.csv`
    - Check the `muse_dir` directory to make sure `refined-roi.csv` appears
    - Note: this will not save the color columns
- To re-inspect the refined output, re-run the notebook with `plot_refined = True`
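
The chain and unchain rules above can be illustrated on a toy table. This is a simplified sketch, not the notebook's implementation: it stores the mapping as a plain dict from `(session, session_roi)` to global ID, uses 0-based sessions (the GUI text box takes 1-based session numbers), and the function names are hypothetical.

```python
def sessions_of(table, global_id):
    """Sessions in which a global ID appears."""
    return {session for (session, _), gid in table.items() if gid == global_id}


def chain(table, id_a, id_b):
    """Merge two global IDs if they never appear in the same session."""
    if id_a == id_b or sessions_of(table, id_a) & sessions_of(table, id_b):
        return False  # nothing to do, or the two IDs overlap in a session
    keep = min(id_a, id_b)  # the smaller ID survives, as in the notebook
    for key, gid in table.items():
        if gid in (id_a, id_b):
            table[key] = keep
    return True


def unchain(table, global_id, session):
    """Give the ROI of `global_id` in `session` a fresh global ID."""
    keys = [k for k, gid in table.items() if gid == global_id and k[0] == session]
    if not keys or len(sessions_of(table, global_id)) < 2:
        return False  # ROI absent from that session, or already on its own
    table[keys[0]] = max(table.values()) + 1
    return True


# Toy table: (session, session_roi) -> global ID
table = {(0, 5): 0, (1, 7): 0, (2, 3): 1, (1, 9): 2}
print(chain(table, 0, 1))    # True: IDs 0 and 1 share no session
print(chain(table, 0, 2))    # False: both appear in session 1
print(unchain(table, 0, 2))  # True: the session 2 ROI gets a new ID
```

After these three calls the ROI in session 2 carries global ID 3, while the other rows are unchanged.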


## Import

In [None]:
import os
import glob
import warnings
from datetime import datetime

from pathlib import Path
import pickle

import numpy as np
import pandas as pd
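# import the submodules explicitly: `import skimage` alone does not expose them in every scikit-image version
import skimage.color
import skimage.measure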

import matplotlib.pyplot as plt
from matplotlib.widgets import Button, MultiCursor, TextBox

# change to other backend if needed
%matplotlib tk

## Define paths and parameters

In [None]:
# Define where multisession directory is (with specified plane)
# e.g:`path/to/SUBJECT_ID/multi-session/PLANE_ID`
muse_dir = 'data/SD_0664/multi-session/plane0'

# Whether to plot the refined ROIs from an existing `refined-roi.csv`
# Use `False` the first time you run this for a given `muse_dir`
# Use `True` only if `refined-roi.csv` already exists in `muse_dir`
plot_refined = True

# Define visualization configurations (combining ROI masks and background images)
fuse_cfg = dict(
    background_choice = 'max_proj', # max_proj, meanImg, meanImgE, Vcorr
    fuse_alpha = 0.4, # fusing ROI mask with background
    fuse_bg_max_p = 98, # for clipping background max value with percentile
)

# use -1, 0 or None to make the number of subplot columns equal the number of sessions
# otherwise use an integer >= 2
column_wrap = None
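
How `column_wrap` turns into a subplot grid can be sketched as follows. The function mirrors the logic in the `Refine GUI` cell; the name `grid_shape` is hypothetical.

```python
import math


def grid_shape(num_sessions, column_wrap=None):
    """Return (rows, cols) of the subplot grid for a given wrap setting."""
    # None, 0, -1 (anything below 2) means one column per session.
    if column_wrap is None or column_wrap < 2:
        column_wrap = num_sessions
    cols = min(num_sessions, column_wrap)
    rows = math.ceil(num_sessions / cols)
    return rows, cols


print(grid_shape(5))      # (1, 5)
print(grid_shape(5, 3))   # (2, 3)
print(grid_shape(2, 4))   # (1, 2)
```

With 5 sessions and `column_wrap = 3`, the grid has 6 cells, so one cell stays empty.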

## Initialize

In [None]:
# define paths
muse_dir = Path(muse_dir)
roicat_out_file = muse_dir / 'roicat-output.pkl'
aligned_img_file = muse_dir / 'aligned-images.pkl'
roi_table_file = muse_dir / 'finalized-roi.csv'
save_roi_file = muse_dir / 'refined-roi.csv'

if plot_refined:
    assert save_roi_file.exists(), f'The refined file "{save_roi_file}" does not exist to plot'
    roi_table_file = save_roi_file

In [None]:
# read in
with open(roicat_out_file, 'rb') as f:
    roicat_out = pickle.load(f)

with open(aligned_img_file, 'rb') as f:
    aligned_images = pickle.load(f)

aligned_rois = roicat_out['results']['ROIs']['ROIs_aligned']

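# `reshape` and `np.ravel_multi_index` below treat this as (rows, columns),
# i.e. (height, width), matching the shape of the background images
image_dims = (
    roicat_out['results']['ROIs']['frame_height'],
    roicat_out['results']['ROIs']['frame_width']
)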
roi_table = pd.read_csv(roi_table_file)

num_sessions = len(aligned_rois)
assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))

num_global_rois = roi_table['global_roi'].nunique()

In [None]:
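# assign unique colors
# size the lookup by the largest ID (not the number of unique IDs), because
# chaining in a previous refinement can leave gaps in the global ROI IDs
global_roi_colors = np.random.rand(roi_table['global_roi'].max() + 1, 3).round(3)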

roi_table['global_roi_color'] = roi_table['global_roi'].apply(
    lambda x: global_roi_colors[x]
)

# save backup of modifiable columns 
backup_tag = datetime.now().strftime('pre[%Y%m%d_%H%M%S]')

roi_table[f'{backup_tag}::num_sessions'] = roi_table['num_sessions'].copy()
roi_table[f'{backup_tag}::global_roi'] = roi_table['global_roi'].copy()
roi_table[f'{backup_tag}::global_roi_color'] = roi_table['global_roi_color'].copy()

roi_table
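
The cell above gives every global ROI a random RGB color, so the same cell shares one color across sessions. A minimal sketch of the same idea, keyed by ID value rather than by array position so that gaps in the IDs are harmless, is shown below. It uses the standard-library `random` module with an optional seed purely for illustration (the notebook uses `np.random.rand` without a seed), and `assign_colors` is a hypothetical name.

```python
import random


def assign_colors(global_ids, seed=None):
    """Map each distinct global ROI ID to one random RGB triple in [0, 1]."""
    rng = random.Random(seed)
    return {
        gid: tuple(round(rng.random(), 3) for _ in range(3))
        for gid in sorted(set(global_ids))
    }


# IDs need not be contiguous: 2 is skipped here.
ids_per_row = [0, 0, 1, 3, 3]
colors = assign_colors(ids_per_row, seed=0)
row_colors = [colors[gid] for gid in ids_per_row]
```

Rows that share a global ID end up with identical colors, which is what makes matched ROIs visually comparable across sessions.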

In [None]:
# functions
# TODO: these should be in a source file
def throw_invalid_footprint_warning(footprints, roi_table, session):
    invalid_max_idx = footprints.max(axis=1).toarray().squeeze() <= 0
    if sum(invalid_max_idx) == 0:
        return
    
    invalid_max_idx = np.where(invalid_max_idx)[0]
    
    roi_record = roi_table.query(
        'session == @session and '
        'session_roi in @invalid_max_idx and '
        '(num_sessions > 1 or roicat_global_roi >= 0)'
    )
    if len(roi_record) > 0:
        # the second positional argument of `warnings.warn` is the warning category,
        # so the records must be formatted into the message itself
        records = roi_record.drop(columns='global_roi_color').to_dict('records')
        warnings.warn(
            'The following ROI records have invalid footprints, i.e. '
            '(a) max <= 0 and (b) persist for more than 1 session or '
            f'have `roicat_global_roi >= 0`: {records}'
        )
        
def select_session_rois(footprints, roi_table, session):
    session_iscell_rois = (
        roi_table.query('session == @session')
        .sort_values(by='session_roi')
        .reset_index(drop=True)
    )
    roi_colors = np.stack(session_iscell_rois['global_roi_color'].to_list())
    roi_iscell_idx = session_iscell_rois['session_roi'].to_list()

    footprints = footprints[roi_iscell_idx]
    assert footprints.shape[0] == len(roi_iscell_idx)
    
    return footprints, roi_colors 

def normalize_footprint(X):
    max_X = X.max(axis=1).toarray()
    X = X.multiply(1.0 / max_X)
    return X
    
def sparse2image(sparse_footprints, image_dims):
    image = (
        (sparse_footprints > 0)
        .T
        .reshape(*image_dims)
        .toarray()
    )
    return image

def sparse2contour(sparse_footprints, image_dims):
    image = (
        (sparse_footprints > 0)
        .T
        .reshape(*image_dims)
        .toarray()
    )
    contour = skimage.measure.find_contours(image)
    return contour

def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):
    image = np.array(
        (sparse_footprints > 0)
        .multiply(color_vec)
        .sum(axis=0)
        .T
        .reshape(*image_dims)
    )
    return image

def color_footprint(sparse_footprints, color_matrix, image_dims):
    image = np.stack([
        color_footprint_one_channel(
            sparse_footprints,
            color_vec.reshape(-1,1),
            image_dims
        )
        for color_vec in color_matrix.T
    ], axis=-1)
    image = np.clip(image, a_min=0.0, a_max=1.0)
    return image

def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):
    background = (background - background.min()) / \
        (np.percentile(background, background_max_percentile) - background.min())
    background = np.clip(background, a_min=0, a_max=1.0)
    background = skimage.color.gray2rgb(background)
    
    fused_image = background * (1 - alpha) + rois * alpha
    fused_image = np.clip(fused_image, a_min=0, a_max=1.0)    
    return fused_image
 
def compute_session_fused_footprints(
    images, footprints, roi_table, session, 
    background_choice = 'max_proj',
    fuse_alpha = 0.5,
    fuse_bg_max_p = 99,
):
    background_images = images[session]
    if background_choice not in background_images:
        warnings.warn(f'{background_choice} not in the aligned images. Using "fov" field instead')
        assert 'fov' in background_images, '"fov" not found in aligned images'
        background_choice = 'fov'
    background_image = background_images[background_choice]
        
    sparse_footprints = footprints[session]
    throw_invalid_footprint_warning(sparse_footprints, roi_table, session)       
    
    sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)
    sparse_footprints = normalize_footprint(sparse_footprints)
    colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)
    fused_footprints = fuse_rois_in_background(
        colored_footprints,
        background_image,
        alpha=fuse_alpha, 
        background_max_percentile=fuse_bg_max_p
    )
    return fused_footprints
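
The fusing step in `fuse_rois_in_background` comes down to rescaling the background against a percentile and then alpha-blending. The sketch below reproduces that arithmetic on a tiny array. It is a simplified stand-in: `blend` is a hypothetical name, and `np.repeat` replaces `skimage.color.gray2rgb` so that only `numpy` is needed.

```python
import numpy as np


def blend(rois_rgb, background, alpha=0.4, bg_max_p=98):
    """Alpha-blend an RGB ROI image onto a grayscale background."""
    lo = background.min()
    hi = np.percentile(background, bg_max_p)
    # Rescale so that `lo` maps to 0 and the chosen percentile maps to 1.
    scaled = np.clip((background - lo) / (hi - lo), 0.0, 1.0)
    # Repeat the gray channel three times (stand-in for skimage.color.gray2rgb).
    background_rgb = np.repeat(scaled[..., None], 3, axis=-1)
    return np.clip(background_rgb * (1 - alpha) + rois_rgb * alpha, 0.0, 1.0)


background = np.arange(16, dtype=float).reshape(4, 4)
rois_rgb = np.zeros((4, 4, 3))
rois_rgb[0, 0] = [1.0, 0.0, 0.0]   # one red ROI pixel on the darkest background pixel

fused = blend(rois_rgb, background)
print(fused.shape)   # (4, 4, 3)
```

With `alpha = 0.4`, the red ROI pixel contributes 0.4 to the red channel, and background pixels above the 98th percentile saturate at 0.6. A constant background would make `hi - lo` zero, which this sketch does not guard against.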

In [None]:
# fuse ROI and background
fuse_kwargs = dict(
    images=aligned_images,
    footprints=aligned_rois,
    roi_table=roi_table,
    **fuse_cfg
)

fused_footprints = [
    compute_session_fused_footprints(
        session=session_idx,
        **fuse_kwargs
    )
    for session_idx in range(num_sessions)
]

# get contours
roi_contours = []
for session_rois in aligned_rois:
    roi_contours.append([
        sparse2contour(roi_footprint, image_dims) 
        for roi_footprint in session_rois
    ])

# plot to see what they look like
# plt.figure(figsize=(20,10))
# for i in range(num_sessions):
#     plt.subplot(1,num_sessions,i+1)
#     plt.imshow(fused_footprints[i])
#     plt.axis('off')
# plt.show()

## Refine GUI

In [None]:
# TODO: set variables inside objects instead of modifying global objects
#       - roi_table
#       - aligned_images
#       - aligned_rois
#       - highlight_contour_kwargs
#       - request_contour_kwargs
#       - save_roi_file
# TODO: figure out how to optimize re-drawing bc very slow right now
# TODO: warn when an operation is not valid

highlight_contour_kwargs = dict(c='r', lw=2, alpha=0.8)
request_contour_kwargs = dict(c='#67a9cf', lw=2, alpha=0.8, ls='--')

if column_wrap is None:
    column_wrap = num_sessions
if column_wrap < 2:
    column_wrap = num_sessions
num_cols = min(num_sessions, column_wrap)

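fig, axes = plt.subplots(
    int(np.ceil(num_sessions/num_cols)), num_cols,
    figsize=(20,10),
    sharex=True,
    sharey=True,
    squeeze=False, # always return a 2D array, even for a single session
)
ax_list = axes.flatten()
# hide unused axes when the grid has more cells than sessions
for ax in ax_list[num_sessions:]:
    ax.set_visible(False)
ax_list = ax_list[:num_sessions]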

for i in range(num_sessions):
    ax_list[i].imshow(fused_footprints[i])
    ax_list[i].set_axis_off()
    ax_list[i].set_title(f'session #{i+1}')
    ax_list[i].session_index = i
    
def find_tagged_objects(fig, value, key='tag'):
    objs = fig.findobj(
        match=lambda x:
            False if not hasattr(x, key) 
            else value in getattr(x, key)
    )
    return objs    

class MuseStateCallBacks:
    state = dict(
        chain_request = False,
        chain = False,
        unchain_request = False,
        # unchain = False,
    )
    selection = dict(
        session = None,
        session_roi = None,
        global_roi = None
    )
    chain_selection = dict(   
        session = None,
        session_roi = None,
        global_roi = None
    )
    
    def __init__(self, save_path):
        self.save_path = save_path
    
    def refuse_footprint(self):
        fuse_kwargs = dict(
            images=aligned_images,
            footprints=aligned_rois,
            roi_table=roi_table,
            **fuse_cfg
        )

        fused_footprints = [
            compute_session_fused_footprints(
                session=session_idx,
                **fuse_kwargs
            )
            for session_idx in range(num_sessions)
        ]
        
        
        for i in range(num_sessions):
            ax_list[i].imshow(fused_footprints[i])
        
        plt.draw()
    
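    def clear(self, event):
        self.reset_state()
        # also forget the selected ROIs so that stale selections cannot be reused
        for sel in (self.selection, self.chain_selection):
            for k in sel:
                sel[k] = None
        clear_objs = find_tagged_objects(fig, value='highlight') +\
            find_tagged_objects(fig, value='chain_request')
        
        for x in clear_objs:
            x.remove()        
        for i, ax in enumerate(ax_list):
            ax.set_title(f'session #{i+1}')
        plt.draw()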
        
    def reset_state(self):
        for k in self.state.keys():
            self.state[k] = False
    
    def unchain(self, unchain_session):
        print(self.state, self.selection)
        if not self.state['unchain_request'] or None in self.selection.values():
            self.reset_state()
            return
        unchain_session = int(unchain_session) - 1
        curr_global_roi = self.selection['global_roi']
        unchain_roi_idx = (
            roi_table
            .query('global_roi == @curr_global_roi and session == @unchain_session')
            .index.to_list()
        )
        if len(unchain_roi_idx) < 1:
            self.reset_state()
            return
        unchain_roi_idx = list(unchain_roi_idx)[0]
        if roi_table.loc[unchain_roi_idx, 'num_sessions'] < 2:
            self.reset_state()
            return
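        next_global_roi = roi_table['global_roi'].max() + 1
        # the remaining ROIs of this global ROI now span one session fewer
        roi_table.loc[roi_table['global_roi'] == curr_global_roi, 'num_sessions'] -= 1
        roi_table.loc[unchain_roi_idx, 'global_roi'] = next_global_roi
        # the unchained ROI is on its own, so it spans a single session
        roi_table.loc[unchain_roi_idx, 'num_sessions'] = 1
        roi_table.at[unchain_roi_idx, 'global_roi_color'] = np.random.rand(3)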
        
        self.reset_state()
        self.refuse_footprint()
        
    def chain(self, event):
        print(self.state, self.selection)
        if (
            not self.state['chain_request'] or \
            None in self.selection.values()
        ):
            self.reset_state()
            return
        
        if None in self.chain_selection.values():
            return
        
        curr_glob_roi = self.selection['global_roi']
        curr_glob_roi_rows = roi_table.query('global_roi == @curr_glob_roi')
        
        chain_glob_roi = self.chain_selection['global_roi']
        chain_glob_roi_rows = roi_table.query('global_roi == @chain_glob_roi')
        
        if curr_glob_roi == chain_glob_roi:
            return
        
        print(curr_glob_roi_rows, chain_glob_roi_rows)
        shared_sessions = (
            set(curr_glob_roi_rows['session'].to_list())
            .intersection(chain_glob_roi_rows['session'].to_list())
        )
        union_sessions = (
            set(curr_glob_roi_rows['session'].to_list())
            .union(chain_glob_roi_rows['session'].to_list())
        )
        if len(shared_sessions) > 0:
            return
            
        min_glob_roi = min(curr_glob_roi, chain_glob_roi)
        glob_color = roi_table.query('global_roi == @min_glob_roi')['global_roi_color'].iloc[0]
        concat_idx = list(curr_glob_roi_rows.index) + list(chain_glob_roi_rows.index)
        roi_table.loc[concat_idx, 'num_sessions'] = len(union_sessions)
        roi_table.loc[concat_idx, 'global_roi'] = min_glob_roi
        for idx in concat_idx:
            roi_table.at[idx, 'global_roi_color'] = glob_color
            
        self.reset_state()
        self.refuse_footprint()
        
    def unchain_request(self, event):
        self.reset_state()
        if None in self.selection.values():
            return
        self.state['unchain_request'] = True
        
    def chain_request(self, event):
        # reset pending requests first, mirroring `unchain_request`
        self.reset_state()
        if None in self.selection.values():
            return
        self.state['chain_request'] = True
        
    def onclick(self, event, tag='highlight'):      
        # get data
        ix, iy, ax = event.xdata, event.ydata, event.inaxes
        if not hasattr(ax, 'session_index'):
            return
        
        contour_kwargs = highlight_contour_kwargs.copy()
        if self.state['chain_request']:
            tag = 'chain_request'
            contour_kwargs = request_contour_kwargs.copy()

        # remove previous highlighted objects    
        for x in find_tagged_objects(fig, value=tag):
            x.remove()        
        session = ax.session_index

        # obtain session ROI index
        flat_idx = np.ravel_multi_index((round(iy),round(ix)), image_dims)
        session_roi = aligned_rois[session][:,flat_idx].nonzero()[0]
        if len(session_roi) == 0:
            return
        session_roi = session_roi[0] # just select the first one if ROIs overlap
        select_global_roi = (
            roi_table.query('session == @session and session_roi == @session_roi')
            ['global_roi'].to_list()
        )
        if len(select_global_roi) != 1:
            return
        select_global_roi = select_global_roi[0]
        
        if not (self.state['chain_request']):
            self.selection['session'] = session
            self.selection['session_roi'] = session_roi
            self.selection['global_roi'] = select_global_roi
        else:
            self.chain_selection['session'] = session
            self.chain_selection['session_roi'] = session_roi
            self.chain_selection['global_roi'] = select_global_roi
            
        # obtain contours
        select_contours = {
            r['session']: dict(
                contour = roi_contours[r['session']][r['session_roi']],
                **r,
            )
            for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()
        }

        # plot contours
        for session, ax in enumerate(ax_list):
            if session not in select_contours:
                ax.set_title(f'session #{session+1} [NOT FOUND]')
                continue

            session_contours = select_contours[session]['contour']
            select_session_roi = select_contours[session]['session_roi']

            for c in session_contours:
                c_handles = ax.plot(c[:,1],c[:,0], **contour_kwargs)
                for ch in c_handles:
                    ch.tag = tag

            ax.set_title(f'session #{session+1} [id={select_session_roi} | ID={select_global_roi}]')

        plt.draw()
        
    def save(self, event):
        # currently avoid saving color columns
        # TODO: if a file exists, warn before saving
        color_columns = (
            roi_table
            .filter(regex='.*global_roi_color.*')
            .columns
            .to_list()
        )
        
        (
            roi_table
            .drop(columns=color_columns)
            .to_csv(
                self.save_path,
                index=False
            )
        )
        

muse_cb = MuseStateCallBacks(
    save_path = save_roi_file,
)
cid = fig.canvas.mpl_connect('button_press_event', muse_cb.onclick)

ax_buttons = dict(
    chain_request = fig.add_axes([0.10, 0.05, 0.08, 0.05]),
    chain = fig.add_axes([0.20, 0.05, 0.08, 0.05]),
    unchain_request = fig.add_axes([0.40, 0.05, 0.08, 0.05]),
    clear = fig.add_axes([0.75, 0.05, 0.08, 0.05]),
    save = fig.add_axes([0.9, 0.05, 0.08, 0.05]),
)

buttons = {
    k: Button(v, k.replace('_', ' ').title())
    for k, v in ax_buttons.items()
}
buttons['clear'].on_clicked(muse_cb.clear)
buttons['chain_request'].on_clicked(muse_cb.chain_request)
buttons['chain'].on_clicked(muse_cb.chain)
buttons['unchain_request'].on_clicked(muse_cb.unchain_request)
buttons['save'].on_clicked(muse_cb.save)

ax_text = fig.add_axes([0.52, 0.05, 0.05, 0.05])
text_box = TextBox(ax_text, f"Session \n (1-{num_sessions}) ", textalignment="left")
text_box.on_submit(muse_cb.unchain)

multi = MultiCursor(None, ax_list, color='r', lw=0.5, horizOn=True, vertOn=True)

plt.show()
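
The `onclick` callback converts a click position into a column index of the sparse footprint matrix with `np.ravel_multi_index`. The sketch below spells out that arithmetic in plain Python and adds a bounds check that the notebook does not perform. It assumes `image_dims` is ordered as (rows, columns), and `pixel_to_flat_index` is a hypothetical name.

```python
def pixel_to_flat_index(x, y, image_dims):
    """Map click coordinates (x = column, y = row) to a row-major flat index.

    Returns None when the rounded click falls outside the image.
    """
    n_rows, n_cols = image_dims
    row, col = round(y), round(x)
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        return None
    # Same value as np.ravel_multi_index((row, col), image_dims) in C order.
    return row * n_cols + col


print(pixel_to_flat_index(2.2, 1.4, (4, 5)))   # 7
print(pixel_to_flat_index(4.6, 0.0, (4, 5)))   # None
```

A click near the right or bottom edge can round to an index one past the image, which is why the sketch returns `None` instead of raising.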

In [None]:
!ls $muse_dir
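
Once `refined-roi.csv` is listed, the refinement can be audited by comparing it with `finalized-roi.csv`. The sketch below is one possible check using only the standard library. It assumes both files contain the `session`, `session_roi` and `global_roi` columns used throughout the notebook, compares the values as strings exactly as written in the files, and `changed_rois` is a hypothetical name.

```python
import csv


def changed_rois(finalized_csv, refined_csv):
    """List (session, session_roi, old_id, new_id) for rows whose global_roi changed."""
    def load(path):
        with open(path, newline='') as f:
            return {
                (row['session'], row['session_roi']): row['global_roi']
                for row in csv.DictReader(f)
            }

    before, after = load(finalized_csv), load(refined_csv)
    return [
        (session, roi, before[(session, roi)], new_id)
        for (session, roi), new_id in after.items()
        if (session, roi) in before and before[(session, roi)] != new_id
    ]
```

Each returned tuple corresponds to one ROI that was chained or unchained in the GUI; an empty list means the saved table has the same global IDs as the input.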