In [1]:
# %matplotlib widget
import numpy as np
import os
import sys
import h5py
import cv2
import importlib
import holoviews as hv
hv.extension('bokeh')

import pandas as pd
import matplotlib.pyplot as plt
from caiman.base.rois import register_multisession
from skimage import color

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import processing_parameters
import functions_bondjango as bd
import functions_misc as fm
import functions_data_handling as fdh

In [106]:
def get_footprint_centroids(calcium_data):
    cents = []
    for cell in calcium_data:
        new_cell = cell.copy()
        new_cell[new_cell > 0] == 1 
        M = cv2.moments(new_cell)
        
        # centroid calciulation
        cX = int(M["m10"] / M["m00"])
        cY = int(M["m01"] / M["m00"])
        cents.append((cX, cY))
    return cents

def get_footprint_contours(calcium_data):
    contour_list = []
    contour_stats = []
    for frame in calcium_data:
        frame = frame * 255.
        frame = frame.astype(np.uint8)
        thresh = cv2.threshold(frame, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
        
        # get contours and filter out small defects
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # Only take the top-level contour
        cntr = contours[0]
        area = cv2.contourArea(cntr)
        perimeter = cv2.arcLength(cntr, True)
        compactness = 4*np.pi*area / (perimeter + 1e-16)**2
        
        contour_list.append(cntr)
        contour_stats.append((area, perimeter, compactness))

    return contour_list, np.array(contour_stats)

def make_contour_projection(contour_list, shape):
    contour_img = np.zeros(shape)
    for i, cntr in enumerate(contour_list):
        cv2.drawContours(contour_img[i, :], [cntr], 0, 1, 1)

    contour_img = np.sum(contour_img, axis=0)
    contour_img[contour_img > 0] = 1.0
    return contour_img

def get_binary_footprints(footprint_pic, threshold=0.1):
    bin_pic = np.zeros_like(footprint_pic)
    bin_pic[footprint_pic >= threshold] = 1
    return bin_pic

def make_rgb_overlay(max_proj, footprints, contour_img, channel='r'):
    max_proj -= max_proj.min()
    max_proj /= max_proj.max()

    # Make RGB max projection
    max_proj_rgb = np.dstack((max_proj, max_proj, max_proj))

    # mak RGB footprint image
    footprint_rgb = np.zeros((*max_proj.shape, 3))

    footprints /= footprints.max()
    if channel == 'r':
        footprint_rgb[:, :, 0] = footprints
    elif channel == 'g':
        footprint_rgb[:, :, 1] = footprints
    elif channel == 'b':
        footprint_rgb[:, :, 2] = footprints
    else:
        raise ValueError('channel must be r, g, or b')
    
    footprint_rgb[:] += np.expand_dims(contour_img, -1).astype(float)

    # Convert RGB max proj anf RGB footprints to HSV colorspace
    max_proj_hsv = color.rgb2hsv(max_proj_rgb)
    footprint_mask_hsv = color.rgb2hsv(footprint_rgb)

    # Overlay the footprint mask on the max projection
    max_proj_hsv[..., 0] = footprint_mask_hsv[..., 0]
    max_proj_hsv[..., 1] = footprint_mask_hsv[..., 1] * 0.6

    # Return to RGB colorspace
    overlay = color.hsv2rgb(max_proj_hsv)
    return overlay

def hv_plot_matching(rigs, overlay_binary_footprints, centroids, match_overlay_binary_footprints, match_centroids, unmatch_overlay_binary_footprints, unmatch_centroids):
    binary_images = []
    match_images = []
    unmatch_images = []
    for i, (rig, bin_pic, cents, match_bin_pic, match_cents, unmatch_bin_pic, unmatch_cents) in enumerate(zip(rigs, overlay_binary_footprints, centroids, 
                                                                              match_overlay_binary_footprints, match_centroids,
                                                                              unmatch_overlay_binary_footprints, unmatch_centroids)):
        
        # Plot all binarized ROIS with contours
        binary_image = hv.RGB(bin_pic.astype(float), bounds=(0, 0, 320, 320)).opts(title=rig)
        for num, cent in enumerate(cents):
            binary_image = binary_image * hv.Text(cent[0], 320-cent[1], str(num), fontsize=8).opts(color='white')
        binary_images.append(binary_image)

        # Plot only matched binarized ROIS with contours
        match_image = hv.RGB(match_bin_pic.astype(float), bounds=(0, 0, 320, 320))
        for num, cent in enumerate(match_cents):
            match_image = match_image * hv.Text(cent[0], 320-cent[1], str(assignments_filtered[num, i]), fontsize=8).opts(color='white')
        match_images.append(match_image)

        # Plot only matched binarized ROIS with contours
        unmatch_image = hv.RGB(unmatch_bin_pic.astype(float), bounds=(0, 0, 320, 320))
        for num, cent in enumerate(unmatch_cents):
            unmatch_image = unmatch_image * hv.Text(cent[0], 320-cent[1], str(unassigned[i][num]), fontsize=8).opts(color='white')
        unmatch_images.append(unmatch_image)

    # Plot contour overlay for all cells
    binary_overlay = hv.RGB(np.dstack((countour_pics[0], np.zeros_like(countour_pics[0]), countour_pics[1])), bounds=(0, 0, 320, 320)).opts(title='Overlay')
    matched_overlay = hv.RGB(np.dstack((match_binary_contour_projs[0], np.zeros_like(match_binary_contour_projs[0]), match_binary_contour_projs[1])), bounds=(0, 0, 320, 320))
    unmatched_overlay = hv.RGB(np.dstack((unmatch_binary_contour_projs[0], np.zeros_like(unmatch_binary_contour_projs[0]), unmatch_binary_contour_projs[1])), bounds=(0, 0, 320, 320))

    layout = hv.Layout(binary_images[0] + binary_images[1] + binary_overlay + 
                       match_images[0] + match_images[1] + matched_overlay +
                       unmatch_images[0] + unmatch_images[1] + unmatched_overlay).cols(3).opts(hv.opts.RGB(tools=['hover'], xlabel=None, ylabel=None, xaxis=None, yaxis=None))
    return layout

In [102]:
importlib.reload(processing_parameters)

# get the search string
search_string = processing_parameters.search_string + r", analysistype:calciumraw"
parsed_search = parsed_search = fdh.parse_search_string(search_string)

# get the paths from the database
file_infos = bd.query_database('analyzed_data', search_string)
calcium_paths = np.array([el['analysis_path'] for el in file_infos if ('calciumraw' in el['slug']) and
                            (parsed_search['mouse'].lower() in el['slug'])])
cell_matching_path = [el['analysis_path'] for el in file_infos if ('daycellmatch' in el['slug']) and
                            (parsed_search['mouse'].lower() in el['slug'])]
rigs = np.array([el['rig'] for el in file_infos if ('calciumraw' in el['slug']) and 
                 (parsed_search['mouse'].lower() in el['slug'])])
print(calcium_paths)
print(cell_matching_path)

['Z:\\Prey_capture\\AnalyzedData\\07_24_2023_15_41_45_VTuningWF_MM_230518_b_control_dark_free0_nogabor_calciumraw.hdf5'
 'Z:\\Prey_capture\\AnalyzedData\\07_24_2023_16_17_48_VWheelWF_MM_230518_b_control_dark_fixed1_nogabor_calciumraw.hdf5']
['Z:\\Prey_capture\\AnalyzedData\\07_24_2023_MM_230518_b_dayCellMatch.hdf5']


In [103]:
# load the data for the matching
calcium_list = []
max_proj_list = []
footprint_list = []
contour_list = []
size_list = []
template_list = []
footprint_pics = []
countour_pics = []
centroids_list = []
binary_footprints = []
overlay_footprints = []
overlay_binary_footprints = []

# load the calcium data
for files, channel in zip(calcium_paths, ['r', 'b']):

    with h5py.File(files, mode='r') as f:

        try:
            calcium_data = np.array(f['A'])
            max_proj = np.array(f['max_proj'])     

        except KeyError:
            continue

    # if there are no ROIs, skip
    if (type(calcium_data) == np.ndarray) and np.any(calcium_data.astype(str) == 'no_ROIs'):
        continue
        
    # clear the rois that don't pass the size or compactness criteria
    roi_stats = fm.get_roi_stats(calcium_data)
    contours, contour_stats = get_footprint_contours(calcium_data)

    if len(roi_stats.shape) == 1:
        roi_stats = roi_stats.reshape(1, -1)
        contour_stats = contour_stats.reshape(1, -1)

    areas = roi_stats[:, -1]
    compactness = contour_stats[:, -1]

    keep_vector = (areas > processing_parameters.roi_parameters['area_min']) & \
                  (areas < processing_parameters.roi_parameters['area_max']) & \
                  (compactness > 0.5)

    if np.all(keep_vector == False):
        continue

    calcium_data = calcium_data[keep_vector, :, :]
    contours = [contours[i] for i, keep in enumerate(keep_vector) if keep]

    centroids = get_footprint_centroids(calcium_data)
    footprint_proj = np.sum(calcium_data, axis=0)
    binary_footprint_proj = get_binary_footprints(footprint_proj)
    contour_proj = make_contour_projection(contours, calcium_data.shape)
    
    # format and masks and store for matching
    calcium_list.append(calcium_data)
    footprint_list.append(np.moveaxis(calcium_data, 0, -1).reshape((-1, calcium_data.shape[0])))
    contour_list.append(contours)
    countour_pics.append(contour_proj)

    size_list.append(calcium_data.shape[1:])
    template_list.append(max_proj)
    max_proj_list.append((max_proj - max_proj.min())/ max_proj.max())
    footprint_pics.append(footprint_proj)
    binary_footprints.append(binary_footprint_proj)
    centroids_list.append(centroids)

    overlay_footprints.append(make_rgb_overlay(max_proj, footprint_proj, contour_proj, channel=channel))
    overlay_binary_footprints.append(make_rgb_overlay(max_proj, binary_footprint_proj, contour_proj, channel=channel))

# Run the  registration and filter the matched cells
spatial_union, assignments, matchings = register_multisession(A=footprint_list, dims=size_list[0], templates=template_list, 
                                                              align_flag=True, use_opt_flow=True, max_thr=0.1, thresh_cost=0.8, max_dist=8)

n_reg = 2  # minimal number of sessions that each component has to be registered in

# Use number of non-NaNs in each row to filter out components that were not registered in enough sessions
assignments_filtered = np.array(assignments[np.sum(~np.isnan(assignments), axis=1) >= n_reg], dtype=int)
unassigned = np.array(assignments[np.sum(~np.isnan(assignments), axis=1) < n_reg])
unassigned = [unassigned[~np.isnan(unassigned[:, 0]), 0].astype(int), unassigned[~np.isnan(unassigned[:,1]), 1].astype(int)]
unassigned = [np.sort(np.unique(unassigned[0])), np.sort(np.unique(unassigned[1]))]

# Use filtered indices to select the corresponding spatial components
spatial_filtered = footprint_list[0][:, assignments_filtered[:, 0]]
matched_footprints = np.sum(spatial_filtered.reshape(320, 320, spatial_filtered.shape[-1]), axis=-1)

# Filter footprints and contours based on the matching
match_ca1 = calcium_list[0][assignments_filtered[:,0], :]
match_ca2 = calcium_list[1][assignments_filtered[:,1], :]

match_footprint_projs = [np.sum(match_ca1, axis=0),  np.sum(match_ca2, axis=0)]
match_binary_footprint_projs = [get_binary_footprints(fp_proj) for fp_proj in match_footprint_projs]
match_centroids = [get_footprint_centroids(match_ca1), get_footprint_centroids(match_ca2)]

match_contours1, _ = get_footprint_contours(match_ca1)
match_contours2, _ = get_footprint_contours(match_ca2)
match_contour_proj1 = make_contour_projection(match_contours1, match_ca1.shape)
match_contour_proj2 = make_contour_projection(match_contours2, match_ca2.shape)
match_binary_contour_projs = [match_contour_proj1, match_contour_proj2]
match_contours = [match_contours1, match_contours2]

match_overlay_binary_footprints = []
for i, channel in enumerate(['r', 'b']):
    match_overlay_binary_footprints.append(make_rgb_overlay(max_proj_list[i], match_binary_footprint_projs[i], match_binary_contour_projs[i], channel=channel))

# Filter unmatched footprints and contours based on the matching
unmatch_ca1 = calcium_list[0][unassigned[0], :]
unmatch_ca2 = calcium_list[1][unassigned[1], :]

unmatch_footprint_projs = [np.sum(unmatch_ca1, axis=0),  np.sum(unmatch_ca2, axis=0)]
unmatch_binary_footprint_projs = [get_binary_footprints(fp_proj) for fp_proj in unmatch_footprint_projs]
unmatch_centroids = [get_footprint_centroids(unmatch_ca1), get_footprint_centroids(unmatch_ca2)]

unmatch_contours1, _ = get_footprint_contours(unmatch_ca1)
unmatch_contours2, _ = get_footprint_contours(unmatch_ca2)
unmatch_contour_proj1 = make_contour_projection(unmatch_contours1, unmatch_ca1.shape)
unmatch_contour_proj2 = make_contour_projection(unmatch_contours2, unmatch_ca2.shape)
unmatch_binary_contour_projs = [unmatch_contour_proj1, unmatch_contour_proj2]
unmatch_contours = [unmatch_contours1, unmatch_contours2]

unmatch_overlay_binary_footprints = []
for i, channel in enumerate(['r', 'b']):
    unmatch_overlay_binary_footprints.append(make_rgb_overlay(max_proj_list[i], unmatch_binary_footprint_projs[i], unmatch_binary_contour_projs[i], channel=channel))

  self._set_arrayXarray(i, j, x)


In [105]:
match_plot = hv_plot_matching(rigs, overlay_binary_footprints, centroids_list, match_overlay_binary_footprints, match_centroids, unmatch_overlay_binary_footprints, unmatch_centroids)
match_plot.opts(hv.opts.RGB(xlabel=None, ylabel=None, xaxis=None, yaxis=None))
match_plot

In [None]:
figure_save_path = r"C:\Users\mmccann\Dropbox\bonhoeffer lab\SFN 2023\poster"
fig.savefig(os.path.join(figure_save_path, 'Fig3', 'cell_matching.png'), dpi=600, format='png')

In [97]:
centroids_list[0]

[(5, 1),
 (218, 103),
 (262, 117),
 (251, 117),
 (234, 125),
 (224, 146),
 (230, 155)]