In [1]:
%load_ext autoreload
%autoreload 2
# %matplotlib widget
import numpy as np
import os
import sys
import h5py
import cv2
import importlib
import holoviews as hv
hv.extension('bokeh')

import pandas as pd
import matplotlib.pyplot as plt
from skimage import color

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import processing_parameters
import functions_bondjango as bd
import functions_misc as fmisc
import functions_matching as fm
import functions_data_handling as fdh
import functions_tuning as tuning
import functions_plotting as fp
from wirefree_experiment import WirefreeExperiment, DataContainer
from functions_wirefree_trigger_fix import get_trial_duration_stats

# Output folder for exported figure media (not referenced by the cells shown here).
# NOTE(review): absolute, user-specific path - consider making it configurable.
fig_path = r"C:\Users\mmccann\Dropbox\bonhoeffer lab\conferences\senses_in_motion_24\figure_media"

In [2]:
def get_footprint_centroids(calcium_data):
    """Return the integer (x, y) centroid of every ROI footprint.

    Each footprint is binarized (every pixel > 0 counts equally, pixels <= 0 are
    ignored), so the centroid is the mean position of the ROI's pixels rather than
    an intensity-weighted one. This is the binarization the previous version
    intended; it used ``==`` instead of ``=``, which made the step a no-op.

    Args:
        calcium_data: iterable of 2D footprints, e.g. an (n_rois, H, W) array.

    Returns:
        list of ``[cX, cY]`` pairs (column, row), truncated to ``int``, one per footprint.

    Raises:
        ValueError: if a footprint has no pixel > 0, so its centroid is undefined.
    """
    cents = []
    for roi_idx, cell in enumerate(calcium_data):
        rows, cols = np.nonzero(np.asarray(cell) > 0)
        if rows.size == 0:
            raise ValueError(f'footprint {roi_idx} has no positive pixels; centroid is undefined')
        # Equivalent to the m10/m00 and m01/m00 image moments of the binary mask
        cents.append([int(cols.mean()), int(rows.mean())])
    return cents

def get_footprint_contours(calcium_data):
    """Outline each ROI footprint and measure its shape.

    Each footprint is scaled to 8 bit, binarized with Otsu's threshold and traced
    with ``cv2.findContours``; only the largest contour of each footprint is kept.

    Args:
        calcium_data: iterable of 2D footprints (e.g. an (n_rois, H, W) array).
            Values are expected in [0, 1]; anything outside that range is clipped
            before the 8-bit conversion instead of wrapping around.

    Returns:
        contour_list: list with one OpenCV contour per footprint.
        contour_stats: array with one (area, perimeter, compactness) row per
            footprint, where compactness = 4*pi*area / perimeter**2.

    Raises:
        ValueError: if a footprint yields no contour at all (e.g. an all-zero footprint).
    """
    contour_list = []
    contour_stats = []
    for roi_idx, frame in enumerate(calcium_data):
        # Otsu thresholding needs uint8 input; clip first so out-of-range values
        # saturate instead of overflowing in the float -> uint8 cast.
        frame_u8 = np.clip(frame * 255., 0, 255).astype(np.uint8)
        thresh = cv2.threshold(frame_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        # get contours; smaller ones are treated as defects below
        contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        if len(contours) == 0:
            raise ValueError(f'footprint {roi_idx} produced no contour (empty footprint?)')
        # Only take the largest contour
        cntr = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(cntr)
        perimeter = cv2.arcLength(cntr, True)
        # epsilon avoids division by zero for degenerate (zero-length) contours
        compactness = 4 * np.pi * area / (perimeter + 1e-16) ** 2

        contour_list.append(cntr)
        contour_stats.append((area, perimeter, compactness))

    return contour_list, np.array(contour_stats)

def make_contour_projection(contour_list, shape, threshold=0.1):
    """Draw each contour on its own canvas, add the canvases up and binarize the sum.

    Args:
        contour_list: sequence of OpenCV contours; entry ``k`` is drawn on canvas ``k``.
        shape: shape of the canvas stack, normally (n_rois, H, W).
        threshold: pixels whose summed contour count exceeds this value become 1.0;
            all other pixels keep their count.

    Returns:
        float array of shape ``shape[1:]``.
    """
    canvases = np.zeros(shape)
    for idx, contour in enumerate(contour_list):
        # 1-pixel-wide outline with value 1, drawn in place into slice idx
        cv2.drawContours(canvases[idx, :], [contour], 0, 1, 1)

    projection = canvases.sum(axis=0)
    projection[projection > threshold] = 1.0
    return projection

def get_binary_footprints(footprint_pic, threshold=0.1):
    """Binarize a footprint image: 1 where the value exceeds ``threshold``, else 0.

    The result keeps the dtype of ``footprint_pic``.
    """
    above = footprint_pic > threshold
    return above.astype(footprint_pic.dtype)

def make_rgb_overlay(max_proj, footprints, contour_img, channel='r'):
    """Tint a grayscale max projection with ROI footprints and draw contours on top.

    The min-max normalized max projection provides the brightness of the result.
    The max-normalized footprint image, placed in a single RGB channel, provides
    hue and saturation (merged in HSV space). ``contour_img`` is then added to all
    three channels, so contour pixels turn white.

    The inputs are not modified; normalization works on float copies. An all-zero
    footprint image or a constant max projection is left at zero instead of
    becoming NaN.

    Args:
        max_proj: 2D array, max projection of the recording.
        footprints: 2D array of the same shape, e.g. summed or binary footprints.
        contour_img: 2D array of the same shape; its values are added to every RGB channel.
        channel: 'r', 'g' or 'b', the RGB channel that carries the footprints.

    Returns:
        (H, W, 3) float RGB image, clipped to [0, 1].

    Raises:
        ValueError: if ``channel`` is not 'r', 'g' or 'b'.
    """
    channel_index = {'r': 0, 'g': 1, 'b': 2}.get(channel)
    if channel_index is None:
        raise ValueError('channel must be r, g, or b')

    # Min-max normalize the background into a new float array (never in place)
    background = np.asarray(max_proj, dtype=float)
    background = background - background.min()
    background_peak = background.max()
    if background_peak > 0:
        background = background / background_peak

    # Max-normalize the footprints into a new float array
    footprint_map = np.asarray(footprints, dtype=float)
    footprint_peak = footprint_map.max()
    if footprint_peak > 0:
        footprint_map = footprint_map / footprint_peak

    background_rgb = np.dstack((background, background, background))
    footprint_rgb = np.zeros((*background.shape, 3))
    footprint_rgb[..., channel_index] = footprint_map

    # Keep the background's value (brightness); take hue and saturation from the footprints
    overlay_hsv = color.rgb2hsv(background_rgb)
    overlay_hsv[..., :2] = color.rgb2hsv(footprint_rgb)[..., :2]

    overlay = color.hsv2rgb(overlay_hsv)
    overlay += np.expand_dims(contour_img, -1).astype(float)
    # Contour pixels can exceed 1 after the addition; keep a valid float RGB range
    return np.clip(overlay, 0.0, 1.0)

def hv_plot_FOVs(rigs, binary_footprints, contour_images, alpha=0, labels=None, overlay=True, fov_size=320):
    """Build a HoloViews layout with one FOV panel per session, optionally plus a contour overlay.

    Args:
        rigs: per-session names, used as panel titles.
        binary_footprints: per-session images handed to ``hv.RGB`` (in this notebook,
            the RGB overlays returned by ``make_rgb_overlay``).
        contour_images: per-session 2D contour projections. Only entries 0 and 1 are
            used, for the combined overlay panel.
        alpha: constant value stacked onto each image as extra channel(s).
        labels: optional per-session arrays. Columns 0-1 are x, y pixel coordinates
            and the last column is the text to draw.
        overlay: if True, append a panel showing session 1 contours in red and
            session 0 contours in blue.
        fov_size: side length used for the image bounds and for flipping label y
            coordinates. Defaults to 320, the previously hard-coded value.

    Returns:
        hv.Layout with all panels in a single row.
    """
    bounds = (0, 0, fov_size, fov_size)
    binary_images = []

    for i, (rig, bin_pic) in enumerate(zip(rigs, binary_footprints)):
        # Plot all binarized ROIs with contours
        # NOTE(review): np.ones_like(bin_pic) has bin_pic's full shape, so an (H, W, 3)
        # input gets 3 extra channels here rather than a single alpha channel.
        # Confirm how hv.RGB treats this before changing it.
        alpha_mask = np.ones_like(bin_pic) * alpha
        bin_pic = np.dstack((bin_pic, alpha_mask))
        binary_image = hv.RGB(bin_pic.astype(float), bounds=bounds).opts(title=rig)

        if labels is not None:
            cents = labels[i][:, :2].copy()
            # Centroid y is an image row (0 at the top); plot y grows upwards, so flip it
            cents[:, 1] = fov_size - cents[:, 1]
            label = labels[i][:, -1]
            label_plot = hv.Labels({('x', 'y'): cents, 'text': label}, ['x', 'y'], 'text').opts(
                text_color='white', xoffset=0.0, yoffset=0.0, text_font_size='8pt')
            binary_image = binary_image * label_plot

        binary_images.append(binary_image)

    if overlay:
        # R = session 1 contours, G = empty, B = session 0 contours
        overlay_rgb = np.dstack((contour_images[1], np.zeros_like(contour_images[0]), contour_images[0]))
        binary_images.append(hv.RGB(overlay_rgb, bounds=bounds).opts(title='Overlay'))

    return hv.Layout(binary_images).cols(len(binary_images))

In [24]:
importlib.reload(processing_parameters)

# Resolve the search string from the freshly reloaded parameters and parse it
search_string = processing_parameters.search_string
parsed_search = fdh.parse_search_string(search_string)


def _is_mouse_analysis(info, analysis_type):
    """True if a database entry has the given analysis type and belongs to the searched mouse."""
    return info['analysis_type'] == analysis_type and parsed_search['mouse'].lower() in info['slug']


# Query the database, then pick this mouse's entries per analysis type
file_infos = bd.query_database('analyzed_data', search_string)
preproc_paths = np.sort(np.array([info['analysis_path'] for info in file_infos
                                  if _is_mouse_analysis(info, 'preprocessing')]))
# Raw calcium file paths are derived from the preprocessing paths ('preproc' -> 'calciumraw')
calcium_paths = np.sort(np.array([info['analysis_path'].replace('preproc', 'calciumraw') for info in file_infos
                                  if _is_mouse_analysis(info, 'preprocessing')]))
tc_paths = np.sort(np.array([info['analysis_path'] for info in file_infos
                             if _is_mouse_analysis(info, 'tc_analysis')]))
cell_matching_path = [info['analysis_path'] for info in file_infos
                      if 'daycellmatch' in info['slug'] and parsed_search['mouse'].lower() in info['slug']]

# Rig name = 7th '_'-separated field of each calcium file name
# (assumes the project's file naming scheme - verify if names change)
rigs = np.array([os.path.basename(path).split('_')[6] for path in calcium_paths])

for listing in (cell_matching_path, calcium_paths, preproc_paths, tc_paths, rigs):
    print(listing)

In [25]:
# Load the cross-session assignments: one column per session, one row per component,
# NaN where the component was not registered in that session.
assignments = fm.match_cells(cell_matching_path[0])
# Keep only the second-to-last '_'-separated token of each column name for matching to `rigs`
new_cols = [col.split('_')[-2] for col in assignments.columns]
assignments.columns = new_cols
n_sessions = assignments.shape[1]

# For each rig (in `calcium_paths` order), find its assignments column. The column at the
# rig's own position is tried first (the previous positional behaviour). Other, not yet used
# columns are searched only if that fails, so a different column order no longer silently
# drops a session.
match_cols = []
for rig_idx, rig in enumerate(rigs):
    candidates = [rig_idx] + [col_idx for col_idx in range(n_sessions) if col_idx != rig_idx]
    found = [col_idx for col_idx in candidates
             if col_idx < n_sessions and col_idx not in match_cols
             and str(assignments.columns[col_idx]) in rig]
    if found:
        match_cols.append(found[0])
    else:
        print(f'No assignments column matches rig {rig}; it is left out of match_cols')

# Components registered in every session are the matched cells
assignments_filtered = assignments.dropna().astype(int).to_numpy()

# Components missing from at least one session: for every session (column) keep the
# sorted, unique indices of those components that the session does contain.
unassigned_rows = assignments[assignments.notna().sum(axis=1) < n_sessions]
unassigned = [np.unique(unassigned_rows.iloc[:, col_idx].dropna().to_numpy().astype(int))
              for col_idx in range(n_sessions)]

In [26]:
# Per-session containers, one entry per successfully loaded calcium file
calcium_list = []
max_proj_list = []
footprint_list = []
contour_list = []
size_list = []
template_list = []
footprint_pics = []
countour_pics = []
centroids_list = []
binary_footprints = []
overlay_footprints = []
overlay_binary_footprints = []


# Load the calcium data; the first session is tinted blue, the second red
for files, channel in zip(calcium_paths, ['b', 'r']):

    with h5py.File(files, mode='r') as f:
        try:
            calcium_data = np.array(f['A'])
            max_proj = np.array(f['max_proj'])
        except KeyError:
            # A skipped file shifts the index of every later session, so report it
            print(f'Skipping {files}: missing "A" or "max_proj" dataset')
            continue

    # Presumably files without ROIs store the string 'no_ROIs' instead of footprints.
    # Only non-numeric data can hold that string, so the expensive element-wise string
    # conversion is skipped for the usual numeric footprint stacks.
    if calcium_data.dtype.kind not in 'biufc' and np.any(calcium_data.astype(str) == 'no_ROIs'):
        print(f'Skipping {files}: no ROIs')
        continue

    # Clear the ROIs that don't pass the size or compactness criteria
    roi_stats = fmisc.get_roi_stats(calcium_data)
    contours, contour_stats = get_footprint_contours(calcium_data)

    # 1D stats (presumably a single ROI) -> one row so the column indexing below works
    if roi_stats.ndim == 1:
        roi_stats = roi_stats.reshape(1, -1)
        contour_stats = contour_stats.reshape(1, -1)

    areas = roi_stats[:, -1]
    compactness = contour_stats[:, -1]

    keep_vector = (areas > processing_parameters.roi_parameters['area_min']) & \
                  (areas < processing_parameters.roi_parameters['area_max']) & \
                  (compactness > 0.5)

    if not np.any(keep_vector):
        print(f'Skipping {files}: no ROI passes the area/compactness criteria')
        continue

    calcium_data = calcium_data[keep_vector, :, :]
    contours = [contours[i] for i, keep in enumerate(keep_vector) if keep]

    centroids = get_footprint_centroids(calcium_data)
    footprint_proj = np.sum(calcium_data, axis=0)
    binary_footprint_proj = get_binary_footprints(footprint_proj)
    contour_proj = make_contour_projection(contours, calcium_data.shape, threshold=0.5)

    # Format the masks and store them for matching
    calcium_list.append(calcium_data)
    # (H*W, n_rois): one flattened footprint per column
    footprint_list.append(np.moveaxis(calcium_data, 0, -1).reshape((-1, calcium_data.shape[0])))
    contour_list.append(contours)
    countour_pics.append(contour_proj)

    size_list.append(calcium_data.shape[1:])
    template_list.append(max_proj)
    # Min-max normalize to [0, 1]; the old code divided by the max instead of the range
    max_proj_shifted = max_proj - max_proj.min()
    max_proj_range = max_proj.max() - max_proj.min()
    max_proj_list.append(max_proj_shifted / max_proj_range if max_proj_range > 0
                         else max_proj_shifted.astype(float))
    footprint_pics.append(footprint_proj)
    binary_footprints.append(binary_footprint_proj)
    centroids_list.append(np.array(centroids))

    overlay_footprints.append(make_rgb_overlay(max_proj, footprint_proj, contour_proj, channel=channel))
    overlay_binary_footprints.append(make_rgb_overlay(max_proj, binary_footprint_proj, contour_proj, channel=channel))

# The matching code below indexes the first two sessions directly
if len(calcium_list) < 2:
    raise ValueError(f'Need calcium data from two sessions for matching, but only {len(calcium_list)} loaded')

# Cells registered in both sessions: take their footprints in assignment-row order, so that
# row k of match_ca1 and row k of match_ca2 describe the same cell.
# NOTE(review): assumes the assignment indices refer to the size/compactness-filtered ROI
# order built above - confirm against how the day cell matching was computed.
match_ca1, match_ca2 = (
    calcium_list[session][assignments_filtered[:, match_cols[session]], :]
    for session in (0, 1)
)

match_footprint_projs = [ca.sum(axis=0) for ca in (match_ca1, match_ca2)]
match_binary_footprint_projs = list(map(get_binary_footprints, match_footprint_projs))
match_centroids = [np.array(get_footprint_centroids(ca)) for ca in (match_ca1, match_ca2)]

# Largest contour per matched footprint, then one contour image per session
match_contours = [get_footprint_contours(ca)[0] for ca in (match_ca1, match_ca2)]
match_contours1, match_contours2 = match_contours
match_contour_proj1, match_contour_proj2 = (
    make_contour_projection(cntrs, ca.shape)
    for cntrs, ca in zip(match_contours, (match_ca1, match_ca2))
)
match_binary_contour_projs = [match_contour_proj1, match_contour_proj2]

# Session 0 tinted blue, session 1 red
match_overlay_binary_footprints = [
    make_rgb_overlay(max_proj_list[idx], match_binary_footprint_projs[idx],
                     match_binary_contour_projs[idx], channel=chan)
    for idx, chan in enumerate(['b', 'r'])
]

# Cells not registered in every session: take their footprints (sorted by ROI index)
# from each session, using that session's own assignments column.
unmatch_ca1, unmatch_ca2 = (
    calcium_list[session][unassigned[match_cols[session]], :]
    for session in (0, 1)
)

unmatch_footprint_projs = [ca.sum(axis=0) for ca in (unmatch_ca1, unmatch_ca2)]
unmatch_binary_footprint_projs = list(map(get_binary_footprints, unmatch_footprint_projs))
unmatch_centroids = [np.array(get_footprint_centroids(ca)) for ca in (unmatch_ca1, unmatch_ca2)]

# Largest contour per unmatched footprint, then one contour image per session
unmatch_contours = [get_footprint_contours(ca)[0] for ca in (unmatch_ca1, unmatch_ca2)]
unmatch_contours1, unmatch_contours2 = unmatch_contours
unmatch_contour_proj1, unmatch_contour_proj2 = (
    make_contour_projection(cntrs, ca.shape)
    for cntrs, ca in zip(unmatch_contours, (unmatch_ca1, unmatch_ca2))
)
unmatch_binary_contour_projs = [unmatch_contour_proj1, unmatch_contour_proj2]

# Session 0 tinted blue, session 1 red
unmatch_overlay_binary_footprints = [
    make_rgb_overlay(max_proj_list[idx], unmatch_binary_footprint_projs[idx],
                     unmatch_binary_contour_projs[idx], channel=chan)
    for idx, chan in enumerate(['b', 'r'])
]

In [30]:
def _label_table(centroids, cell_ids):
    """Stack (x, y) centroids with an id column into the (n, 3) label array hv_plot_FOVs expects.

    Empty centroid lists become a (0, 3) array instead of failing the concatenate.
    """
    return np.concatenate((np.asarray(centroids).reshape(-1, 2),
                           np.asarray(cell_ids).reshape(-1, 1)), axis=1)


# All kept ROIs, labelled with their index in the session's filtered ROI list
all_cell_labels = [_label_table(centroids, np.arange(len(centroids))) for centroids in centroids_list]
all_cells = hv_plot_FOVs(rigs, overlay_binary_footprints, countour_pics, overlay=True, labels=all_cell_labels)

# Matched ROIs: read each session's ROI index from its own assignments column (match_cols[i]),
# the same column used to select match_ca1 / match_ca2.
matched_cell_labels = [_label_table(match_centroids[i], assignments_filtered[:, match_cols[i]])
                       for i in range(len(match_centroids))]
match_cells = hv_plot_FOVs(rigs, match_overlay_binary_footprints, match_binary_contour_projs,
                           overlay=True, labels=matched_cell_labels).opts(hv.opts.RGB(title=''))

# Unmatched ROIs: same column lookup as used for unmatch_ca1 / unmatch_ca2
unmatched_cell_labels = [_label_table(unmatch_centroids[i], unassigned[match_cols[i]])
                         for i in range(len(unmatch_centroids))]
unmatch_cells = hv_plot_FOVs(rigs, unmatch_overlay_binary_footprints, unmatch_binary_contour_projs,
                             overlay=True, labels=unmatched_cell_labels).opts(hv.opts.RGB(title=''))

# Rows: all cells / matched cells / unmatched cells; 3 panels per row
match_plot = hv.Layout(all_cells + match_cells + unmatch_cells).cols(3).opts(
    hv.opts.RGB(xlabel=None, ylabel=None, xaxis=None, yaxis=None, width=350, height=350),
    hv.opts.Labels(text_font_size='15pt'))

match_plot