# Inspect multisession outputs

This notebook is currently being tested as a way to inspect processed multi-session `roicat` outputs.

It currently uses the interactive `tk` matplotlib backend, so it cannot run headless on Oscar.

Note: this is now superseded by `refine-muse.ipynb`

In [None]:
import os
import glob
import warnings

from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import skimage
import matplotlib.pyplot as plt

%matplotlib tk

In [None]:
# Multi-session (roicat) output folder for one imaging plane, and the
# three files this notebook reads from it.
muse_dir = Path('data/SD_0664/multi-session/plane1/')
roicat_out_file, aligned_img_file, roi_table_file = (
    muse_dir / file_name
    for file_name in ('roicat-output.pkl', 'aligned-images.pkl', 'finalized-roi.csv')
)

In [None]:
# Options forwarded to `compute_session_fused_footprints`.
fuse_cfg = {
    # background image key: 'max_proj', 'meanImg', 'meanImgE' or 'Vcorr'
    'background_choice': 'max_proj',
    # weight of the ROI overlay when blending it onto the background
    'fuse_alpha': 0.4,
    # background percentile mapped to 1.0 (brighter pixels saturate)
    'fuse_bg_max_p': 98,
}

In [None]:
# List the files in the multi-session directory (IPython shell escape).
!ls $muse_dir

In [None]:
# Load the roicat results, the aligned per-session images and the finalized ROI table.
with open(roicat_out_file, 'rb') as f:
    # NOTE: pickle can execute arbitrary code on load; only open trusted pipeline outputs.
    roicat_out = pickle.load(f)

with open(aligned_img_file, 'rb') as f:
    aligned_images = pickle.load(f)

roicat_rois = roicat_out['results']['ROIs']
aligned_rois = roicat_rois['ROIs_aligned']

# (rows, cols) of the FOV. Footprints are unflattened with C-order reshapes
# (`reshape(*image_dims)`), and the click handler unravels (y, x) pixel
# indices with `np.ravel_multi_index((row, col), image_dims)`, so the
# height (rows) must come first.
# NOTE(review): assumes roicat flattens each FOV row-major - confirm.
image_dims = (
    roicat_rois['frame_height'],
    roicat_rois['frame_width'],
)
roi_table = pd.read_csv(roi_table_file)

# Sanity check: 2-D aligned images of the first session should share the FOV shape,
# otherwise blending ROIs onto them will fail or be misaligned.
if len(aligned_images) > 0 and isinstance(aligned_images[0], dict):
    for image_name, image in aligned_images[0].items():
        if np.ndim(image) == 2 and tuple(np.shape(image)) != tuple(image_dims):
            warnings.warn(
                f'Aligned image {image_name!r} has shape {np.shape(image)}, '
                f'but the ROI frame is {image_dims} (height, width)'
            )

num_sessions = len(aligned_rois)
session_counts = {
    'aligned_images': len(aligned_images),
    'roi_table': roi_table['session'].nunique(),
}
if any(count != num_sessions for count in session_counts.values()):
    raise ValueError(
        f'Inconsistent number of sessions: ROIs_aligned has {num_sessions}, '
        f'other sources have {session_counts}'
    )

num_global_rois = roi_table['global_roi'].nunique()

In [None]:
# One random RGB colour per global ROI, shared across sessions so the same
# tracked cell has the same colour in every panel. Seeded for reproducible figures.
color_rng = np.random.default_rng(0)
global_roi_colors = color_rng.random((num_global_rois, 3)).round(3)

# Map colours by global ROI *value* rather than using the id as an array
# index, so non-contiguous (or negative) ids cannot pick a wrong colour or
# raise IndexError.
global_roi_color_map = dict(zip(roi_table['global_roi'].unique(), global_roi_colors))

roi_table['global_roi_color'] = roi_table['global_roi'].apply(
    lambda x: global_roi_color_map[x]
)
roi_table

In [None]:
def throw_invalid_footprint_warning(footprints, roi_table, session):
    """Warn about tracked ROIs of ``session`` whose footprint has max <= 0.

    A footprint row is invalid when its maximum is <= 0. Only table rows of
    ``session`` with such a ``session_roi`` that also satisfy
    ``num_sessions > 1 or roicat_global_roi >= 0`` are reported, in a single
    ``UserWarning`` whose message lists the offending records.

    Parameters
    ----------
    footprints : scipy.sparse matrix
        Footprints of this session, one ROI per row (row index == ``session_roi``).
    roi_table : pandas.DataFrame
        Needs ``session``, ``session_roi``, ``num_sessions`` and
        ``roicat_global_roi`` columns; ``global_roi_color`` is dropped from the
        report when present.
    session : int
        Session index to check.
    """
    # ravel (not squeeze) keeps a 1-D array even when there is a single ROI.
    row_max = footprints.max(axis=1).toarray().ravel()
    invalid_max_idx = np.flatnonzero(row_max <= 0)
    if invalid_max_idx.size == 0:
        return

    roi_record = roi_table.query(
        'session == @session and '
        'session_roi in @invalid_max_idx and '
        '(num_sessions > 1 or roicat_global_roi >= 0)'
    )
    if len(roi_record) > 0:
        records = (
            roi_record
            .drop(columns='global_roi_color', errors='ignore')
            .to_dict('records')
        )
        # The records belong in the message: the 2nd positional argument of
        # `warnings.warn` is the warning *category*, and passing a list there
        # raises TypeError.
        # NOTE(review): the message text says `roicat_global_roi=-1`, while the
        # query selects `roicat_global_roi >= 0` - confirm which is intended.
        warnings.warn(
            'The following ROI records have invalid footprints, i.e. '
            '(a) max = 0, (b) persist for more than 1 sessions or `roicat_global_roi=-1`: '
            f'{records}'
        )
        
def select_session_rois(footprints, roi_table, session):
    """Pick the footprints of ``session`` listed in ``roi_table``.

    Rows of ``roi_table`` belonging to ``session`` are sorted by
    ``session_roi``; those indices select footprint rows and their
    ``global_roi_color`` entries are stacked into a colour matrix.

    Returns
    -------
    (selected footprints, colour matrix) - both in ``session_roi`` order.
    """
    session_rows = roi_table.query('session == @session').sort_values(by='session_roi')
    selected_idx = session_rows['session_roi'].to_list()
    selected_colors = np.stack(session_rows['global_roi_color'].to_list())

    selected = footprints[selected_idx]
    assert selected.shape[0] == len(selected_idx)
    return selected, selected_colors

def normalize_footprint(X):
    """Scale each footprint (row of the sparse matrix ``X``) by its maximum.

    Rows whose maximum is <= 0 (e.g. empty footprints) are left unscaled
    instead of being divided by zero.

    Returns
    -------
    Sparse matrix of the same shape with each valid row's maximum equal to 1.
    """
    max_X = X.max(axis=1).toarray()
    # 1 / max where max > 0, else 1: avoids divide-by-zero warnings and inf factors.
    scale = np.divide(1.0, max_X, out=np.ones_like(max_X, dtype=float), where=max_X > 0)
    return X.multiply(scale)
    
def sparse2image(sparse_footprints, image_dims):
    """Turn a flattened sparse footprint into a dense boolean mask.

    Pixels with a positive weight become True; the flat pixel axis is
    reshaped (C order) to ``image_dims``.
    """
    positive = sparse_footprints > 0
    as_column = positive.T
    return as_column.reshape(*image_dims).toarray()

def sparse2contour(sparse_footprints, image_dims):
    """Return the outline(s) of a single flattened sparse footprint.

    The footprint is turned into a boolean mask with ``sparse2image`` and
    traced with ``skimage.measure.find_contours``.

    Returns
    -------
    list of np.ndarray
        Each contour is an (n, 2) array of (row, col) coordinates; the list
        is empty when the mask is empty.
    """
    image = sparse2image(sparse_footprints, image_dims)
    # An explicit level of 0.5 traces the boundary between ROI (1) and
    # background (0) pixels. The default level (mid-range of the image)
    # degenerates for all-empty or all-full masks.
    return skimage.measure.find_contours(image.astype(float), level=0.5)

def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):
    """Render one colour channel of a set of sparse footprints.

    Every ROI (row) contributes its ``color_vec`` entry to each pixel it
    covers (positive weight); contributions of overlapping ROIs add up. The
    per-pixel sums are reshaped to ``image_dims``.

    ``color_vec`` is a column vector with one value per ROI.
    """
    weighted = (sparse_footprints > 0).multiply(color_vec)
    per_pixel = weighted.sum(axis=0).T
    return np.array(per_pixel.reshape(*image_dims))

def color_footprint(sparse_footprints, color_matrix, image_dims):
    """Paint sparse footprints into a multi-channel image.

    ``color_matrix`` holds one colour (row) per ROI. For each channel, every
    pixel receives the summed channel values of the ROIs covering it; the
    channels are stacked on the last axis and clipped to [0, 1].
    """
    covered = sparse_footprints > 0
    channels = []
    for channel_values in color_matrix.T:
        per_pixel = covered.multiply(channel_values.reshape(-1, 1)).sum(axis=0).T
        channels.append(np.array(per_pixel.reshape(*image_dims)))
    return np.clip(np.stack(channels, axis=-1), a_min=0.0, a_max=1.0)

def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):
    """Alpha-blend a coloured ROI image onto a grayscale background.

    The background is shifted to start at 0 and divided by
    (``background_max_percentile``-th percentile - min), then clipped to
    [0, 1]; a constant background becomes all zeros instead of NaN. It is
    converted to RGB and blended as ``(1 - alpha) * background + alpha * rois``.

    Parameters
    ----------
    rois : np.ndarray
        RGB overlay in [0, 1]; must broadcast against the RGB background.
    background : np.ndarray
        Grayscale background image (presumably 2-D - confirm for new image types).
    alpha : float
        Weight of the ROI overlay.
    background_max_percentile : float
        Percentile of the background mapped to 1.0; brighter pixels saturate.

    Returns
    -------
    np.ndarray
        Blended image clipped to [0, 1].
    """
    background = np.asarray(background, dtype=float)
    bg_min = background.min()
    bg_span = np.percentile(background, background_max_percentile) - bg_min
    # Guard the flat-background case, which would otherwise divide by zero.
    if bg_span > 0:
        background = (background - bg_min) / bg_span
    else:
        background = np.zeros_like(background)
    background = np.clip(background, a_min=0, a_max=1.0)
    background = skimage.color.gray2rgb(background)

    fused_image = background * (1 - alpha) + rois * alpha
    fused_image = np.clip(fused_image, a_min=0, a_max=1.0)
    return fused_image
 
def compute_session_fused_footprints(
    images, footprints, roi_table, session,
    background_choice = 'max_proj',
    fuse_alpha = 0.5,
    fuse_bg_max_p = 99,
    image_dims = None,
):
    """Render one session's listed ROIs, coloured by global ROI, over a background.

    Parameters
    ----------
    images : sequence
        Per-session collections of aligned images, indexed by ``session`` and
        then by image name (e.g. ``'max_proj'``, ``'fov'``).
    footprints : sequence
        Per-session sparse footprint matrices, one flattened ROI per row.
    roi_table : pandas.DataFrame
        ROI table with ``session``, ``session_roi`` and ``global_roi_color``
        (plus the columns queried by ``throw_invalid_footprint_warning``).
    session : int
        Index of the session to render.
    background_choice : str
        Background image key; falls back to ``'fov'`` (with a warning) if missing.
    fuse_alpha : float
        Weight of the ROI overlay in the blend.
    fuse_bg_max_p : float
        Background percentile mapped to 1.0.
    image_dims : tuple of int, optional
        (rows, cols) used to unflatten footprints. Defaults to the 2-D shape
        of the chosen background image, which the overlay must match to be
        blended with it.

    Returns
    -------
    np.ndarray
        Blended image with values in [0, 1].

    Raises
    ------
    KeyError
        If neither ``background_choice`` nor ``'fov'`` is available.
    """
    background_images = images[session]
    if background_choice not in background_images:
        warnings.warn(f'{background_choice} not in the aligned images. Using "fov" field instead')
        if 'fov' not in background_images:
            raise KeyError('"fov" not found in aligned images')
        background_choice = 'fov'
    background_image = background_images[background_choice]

    if image_dims is None:
        # Derive the grid from the background instead of a hidden global,
        # so the coloured overlay always matches the image it is blended into.
        image_dims = np.shape(background_image)[:2]

    sparse_footprints = footprints[session]
    throw_invalid_footprint_warning(sparse_footprints, roi_table, session)

    sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)
    sparse_footprints = normalize_footprint(sparse_footprints)
    colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)
    fused_footprints = fuse_rois_in_background(
        colored_footprints,
        background_image,
        alpha=fuse_alpha,
        background_max_percentile=fuse_bg_max_p
    )
    return fused_footprints

In [None]:
# Shared arguments for every session, plus the display options in `fuse_cfg`.
fuse_kwargs = {
    'images': aligned_images,
    'footprints': aligned_rois,
    'roi_table': roi_table,
    **fuse_cfg,
}

# One blended ROI/background image per session, in session order.
fused_footprints = []
for session_idx in range(num_sessions):
    fused_footprints.append(
        compute_session_fused_footprints(session=session_idx, **fuse_kwargs)
    )

In [None]:
# roi_contours[session][session_roi] -> contours of that aligned footprint.
roi_contours = [
    [sparse2contour(footprint, image_dims) for footprint in session_footprints]
    for session_footprints in aligned_rois
]

In [None]:
# Static overview: one panel per session, side by side.
plt.figure(figsize=(20, 10))
for panel in range(1, num_sessions + 1):
    plt.subplot(1, num_sessions, panel)
    plt.imshow(fused_footprints[panel - 1])
    plt.axis('off')
plt.show()

In [None]:
%matplotlib tk

# Style of the contours drawn around the clicked global ROI.
contour_kwargs = dict(c='r', lw=2, alpha=0.8)
num_cols = min(num_sessions, 6)
num_rows = int(np.ceil(num_sessions / num_cols))

# squeeze=False keeps `axes` 2-D even for a single session, so it can
# always be flattened.
fig, axes = plt.subplots(
    num_rows, num_cols,
    figsize=(20,10),
    sharex=True,
    sharey=True,
    squeeze=False,
)
ax_list = axes.flatten()

# The grid can have more panels than sessions: remove the spare ones so they
# neither show empty axes nor receive clicks / "[NOT FOUND]" titles.
for spare_ax in ax_list[num_sessions:]:
    fig.delaxes(spare_ax)
ax_list = ax_list[:num_sessions]

for i, ax in enumerate(ax_list):
    ax.imshow(fused_footprints[i])
    ax.set_axis_off()
    ax.set_title(f'session #{i+1}')
    # read back by the click handler to know which session a panel shows
    ax.session_index = i
    
def find_tagged_objects(fig, value, key='tag'):
    """Return the artists of ``fig`` whose ``key`` attribute contains ``value``.

    Artists without that attribute are skipped. The check uses ``in``, so
    for a string attribute it is a substring test.
    """
    def has_tag(artist):
        if not hasattr(artist, key):
            return False
        return value in getattr(artist, key)

    return fig.findobj(match=has_tag)

    
def _global_roi_at(session, x, y):
    """Return the ``global_roi`` of a listed ROI covering pixel (x, y) of ``session``, or None."""
    row, col = int(round(y)), int(round(x))
    n_rows, n_cols = image_dims
    # Clicks on the outer half-pixel of the image can round out of range.
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        return None
    flat_idx = np.ravel_multi_index((row, col), image_dims)

    # Session ROIs whose footprint covers the pixel (overlaps are possible).
    candidate_rois = aligned_rois[session][:, flat_idx].nonzero()[0]
    # Only ROIs present in the table are drawn, so ignore the others.
    session_table = roi_table.query('session == @session')
    matches = session_table[session_table['session_roi'].isin(candidate_rois)]
    if len(matches) == 0:
        return None

    # If several listed ROIs overlap here, take the lowest session_roi.
    session_roi = matches['session_roi'].min()
    select_global_roi = matches.loc[matches['session_roi'] == session_roi, 'global_roi'].to_list()
    assert len(select_global_roi) == 1
    return select_global_roi[0]


def onclick(event, tag='highlight'):
    """Highlight, in every session panel, the global ROI under the clicked pixel.

    Matplotlib ``button_press_event`` callback. The clicked pixel is mapped to
    a session ROI of the clicked panel and then to its ``global_roi`` through
    ``roi_table``. That global ROI's contours are drawn (tagged with ``tag``)
    in every session panel, and each title shows the matched ids or
    "[NOT FOUND]". A double click clears the highlight and restores the titles.

    Uses notebook globals: ``fig``, ``ax_list``, ``image_dims``,
    ``aligned_rois``, ``roi_table``, ``roi_contours``, ``contour_kwargs``.

    Parameters
    ----------
    event : matplotlib MouseEvent
    tag : str
        Marker stored on the highlight artists so the next click removes them.
    """
    # remove previous highlighted objects
    for artist in find_tagged_objects(fig, value=tag):
        artist.remove()

    # only panels that display a session (spare grid axes have no index)
    session_axes = [a for a in ax_list if hasattr(a, 'session_index')]

    # double click: clear the highlight and reset the titles
    if event.dblclick:
        for panel in session_axes:
            panel.set_title(f'session #{panel.session_index+1}')
        fig.canvas.draw_idle()
        return

    # clicks outside any session panel only clear the previous highlight
    session = getattr(event.inaxes, 'session_index', None)
    if session is None or event.xdata is None or event.ydata is None:
        fig.canvas.draw_idle()
        return

    select_global_roi = _global_roi_at(session, event.xdata, event.ydata)
    if select_global_roi is None:
        fig.canvas.draw_idle()
        return

    # contours of the selected global ROI, keyed by the sessions it appears in
    select_contours = {
        r['session']: dict(
            contour = roi_contours[r['session']][r['session_roi']],
            **r,
        )
        for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()
    }

    # plot contours
    for panel in session_axes:
        panel_session = panel.session_index
        if panel_session not in select_contours:
            panel.set_title(f'session #{panel_session+1} [NOT FOUND]')
            continue

        session_contours = select_contours[panel_session]['contour']
        select_session_roi = select_contours[panel_session]['session_roi']

        # contours are (row, col) pairs, so plot col on x and row on y
        for c in session_contours:
            for handle in panel.plot(c[:,1], c[:,0], **contour_kwargs):
                handle.tag = tag

        panel.set_title(f'session #{panel_session+1} [id={select_session_roi} | ID={select_global_roi}]')

    # plt.show() does not redraw from inside a callback; request a redraw instead
    fig.canvas.draw_idle()
    
# Register `onclick` for mouse presses on the figure; the returned connection
# id `cid` can be passed to `fig.canvas.mpl_disconnect` to detach it.
cid = fig.canvas.mpl_connect('button_press_event', onclick)

plt.show()