In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
# --- Analysis configuration -------------------------------------------------
# Mouse IDs to process; IDs below 31 are read from base_dir_1, the rest from base_dir_2.
mice = [25,27,30,36,39,52]
# Session numbers to process for every mouse.
test_sessions = [1,15]
# Imaging plane numbers 1..8 (range end is exclusive).
planes = range(1,9)
# Folder holding the masks_*.npy inputs; figures, the pickle and the CSV are written here too.
# NOTE(review): hard-coded absolute, machine-specific paths - consider a configurable root.
save_dir = Path(r'C:\Users\shires\Dropbox\Works\Projects\2020 Neural stretching in S1\Analysis\cellpose_mask_test')
# Root folders that contain the per-mouse ops.npy files.
base_dir_1 = Path('F:/')
base_dir_2 = Path('D:/')

In [3]:
def get_overlap_mat(mask1, mask2):
    """Compute the pairwise intersection-over-union (IoU) between two label masks.

    Parameters
    ----------
    mask1, mask2 : array_like of identical shape
        Label images in which 0 is background and every other value marks one ROI.

    Returns
    -------
    overlap_mat : numpy.ndarray, shape (n_rois_in_mask1, n_rois_in_mask2)
        ``overlap_mat[j, k]`` is the IoU between the j-th ROI of ``mask1`` and the
        k-th ROI of ``mask2``, where ROIs are ordered by ascending label value.
        Pairs that share no pixel are 0. A mask without ROIs yields an empty axis.
    """
    roi_inds1 = np.setdiff1d(np.unique(mask1), [0])
    roi_inds2 = np.setdiff1d(np.unique(mask2), [0])
    overlap_mat = np.zeros((len(roi_inds1), len(roi_inds2)))

    # Every overlapping (roi1, roi2) pair shares at least one pixel, so looking up
    # the mask2 labels found under each mask1 ROI already visits all of them.
    # A second pass over the mask2 labels that were never visited cannot find
    # anything (none of their pixels lies under a mask1 ROI); the previous version
    # of that pass also indexed with a stale loop variable and raised NameError
    # whenever the two masks had no overlapping ROI at all.
    for j, roi1 in enumerate(roi_inds1):
        roi1_map = mask1 == roi1
        touched_roi_inds2 = np.setdiff1d(np.unique(mask2[roi1_map]), [0])
        for roi2 in touched_roi_inds2:
            k = np.where(roi_inds2 == roi2)[0][0]
            roi2_map = mask2 == roi2
            n_intersection = np.logical_and(roi1_map, roi2_map).sum()
            n_union = np.logical_or(roi1_map, roi2_map).sum()
            overlap_mat[j, k] = n_intersection / n_union
    return overlap_mat

def get_overlap_mat_list(masks):
    """Return the overlap matrices of every pair of consecutive masks.

    ``masks`` is indexed along its first axis; element ``i`` of the returned list
    is ``get_overlap_mat(masks[i], masks[i + 1])``, so the list has
    ``masks.shape[0] - 1`` entries.
    """
    n_pairs = masks.shape[0] - 1
    return [
        get_overlap_mat(masks[first, :, :], masks[first + 1, :, :])
        for first in range(n_pairs)
    ]

def _roi_labels_and_sizes(mask):
    """Return the sorted non-zero labels of ``mask`` and the pixel count of each."""
    labels, counts = np.unique(mask, return_counts=True)
    is_roi = labels != 0
    return labels[is_roi], counts[is_roi]


def _outline_selected_rois(ax, mask, roi_nums, roi_inds, color, yrange, xrange):
    """Outline the selected ROIs of ``mask`` on ``ax`` and return their boolean union.

    ``yrange`` and ``xrange`` are ``[min, max]`` lists that are widened in place so
    that they enclose every pixel of every selected ROI.
    """
    combined = np.zeros(mask.shape, dtype=bool)
    for i in roi_inds:
        roi_map = mask == roi_nums[i]
        combined |= roi_map
        draw_contours(ax, roi_map, colors=color, linewidths=1)
        rows, cols = np.where(roi_map)
        yrange[0] = min(rows.min(), yrange[0])
        yrange[1] = max(rows.max(), yrange[1])
        xrange[0] = min(cols.min(), xrange[0])
        xrange[1] = max(cols.max(), xrange[1])
    return combined


def _size_percentiles(roi_sizes, roi_inds):
    """Fraction of ROIs strictly smaller than each selected ROI, rounded to 2 decimals."""
    return [float(f'{np.sum(roi_sizes < roi_sizes[i]) / len(roi_sizes):.2f}')
            for i in roi_inds]


def draw_mask_overlaps(ax, mask_curr, mask_next, roi_ind_curr, roi_ind_next, buffer=10):
    """Outline selected ROIs of two masks on one axes and quantify their overlap.

    Parameters
    ----------
    ax : axes-like object
        Receives the contours, the axis limits and the title.
    mask_curr, mask_next : numpy.ndarray
        Label masks of identical shape (0 is background).
    roi_ind_curr, roi_ind_next : int or sequence of int
        Positions into the sorted non-zero labels of the respective mask (not the
        label values themselves). Python ints, numpy integer scalars, lists and
        arrays are all accepted.
    buffer : int, optional
        Margin in pixels added around the bounding box of the selected ROIs,
        clipped to the image size.

    Returns
    -------
    ioa_curr, ioa_next : float
        Intersection of the two selections divided by the area of the current /
        next selection.
    curr_size_percentile, next_size_percentile : list of float
        For every selected ROI, the fraction of ROIs in the same mask that are
        strictly smaller, rounded to two decimals.
    """
    img_dims = mask_curr.shape
    assert img_dims == mask_next.shape

    # atleast_1d wraps any scalar index; the former `type(x) == int` test missed
    # numpy integer scalars, which then failed when iterated below.
    roi_ind_curr = np.atleast_1d(roi_ind_curr)
    roi_ind_next = np.atleast_1d(roi_ind_next)

    # One pass per mask gives both the labels and every ROI size.
    curr_roi_nums, curr_roi_sizes = _roi_labels_and_sizes(mask_curr)
    next_roi_nums, next_roi_sizes = _roi_labels_and_sizes(mask_next)

    # Start with an inverted (empty) bounding box so the first ROI initialises it.
    yrange = [img_dims[0], 0]
    xrange = [img_dims[1], 0]
    temp_curr_mask = _outline_selected_rois(ax, mask_curr, curr_roi_nums, roi_ind_curr,
                                            'm', yrange, xrange)
    temp_next_mask = _outline_selected_rois(ax, mask_next, next_roi_nums, roi_ind_next,
                                            'c', yrange, xrange)

    yrange[0] = max(yrange[0] - buffer, 0)
    yrange[1] = min(yrange[1] + buffer, img_dims[0])
    xrange[0] = max(xrange[0] - buffer, 0)
    xrange[1] = min(xrange[1] + buffer, img_dims[1])
    ax.set_xlim(xrange)
    ax.set_ylim(yrange)

    intersection = np.logical_and(temp_curr_mask, temp_next_mask)
    ioa_curr = intersection.sum() / temp_curr_mask.sum()
    ioa_next = intersection.sum() / temp_next_mask.sum()

    curr_size_percentile = _size_percentiles(curr_roi_sizes, roi_ind_curr)
    next_size_percentile = _size_percentiles(next_roi_sizes, roi_ind_next)

    ax.set_title(f'IOA curr: {ioa_curr:.2f}, IOA next: {ioa_next:.2f}\ncurr size percentile: {curr_size_percentile}\nnext size percentile: {next_size_percentile}')
    return ioa_curr, ioa_next, curr_size_percentile, next_size_percentile


def draw_contours(ax, mask, img=None, colors='r', linewidths=1, alpha=1):
    """Show a background image on ``ax`` and outline every non-zero label of ``mask``.

    When ``img`` is omitted an all-zero image with the shape of ``mask`` is shown.
    Each distinct non-zero value in ``mask`` gets its own contour call, styled by
    ``colors``, ``linewidths`` and ``alpha``.
    """
    background = np.zeros(mask.shape) if img is None else img
    ax.imshow(background, cmap='gray')

    for label in np.setdiff1d(np.unique(mask), 0):
        ax.contour(mask == label, colors=colors, linewidths=linewidths, alpha=alpha)

In [8]:
# Find ROIs that match more than one ROI in the neighbouring mask slice
# (IoU above `iou_threshold`), save a figure for every such group and collect
# one row per group in `overlap_df`.
iou_threshold = 0.3
image_type_list = ['mean', 'meanE', 'max']
# Key of the background image inside ops.npy for each image type.
ops_image_keys = {'mean': 'meanImg', 'meanE': 'meanImgE', 'max': 'max_proj'}
# NOTE(review): 'curr_multi_ind' and 'next_multi_ind' are never filled (the rows use
# 'curr_overlap_inds' / 'next_overlap_inds'); they are kept only so the saved
# pickle/CSV keep the same columns as before - consider dropping them.
overlap_columns = ['mouse', 'session', 'plane', 'image_type', 'overlap_mat_ind', 'pair_ind',
                   'overlap_mat', 'curr_mask', 'next_mask', 'curr_multi_ind', 'next_multi_ind',
                   'ioa_curr', 'ioa_next', 'curr_size_percentile', 'next_size_percentile',
                   'image', 'curr_overlap_inds', 'next_overlap_inds']

# savefig does not create missing folders.
curation_dir = save_dir / 'overlap_curation'
curation_dir.mkdir(parents=True, exist_ok=True)


def _collect_overlap_groups(overlap_mat, threshold):
    """Return (curr_inds, next_inds) index arrays for every multiply-matched ROI.

    A group is created for each next-slice ROI matched by more than one
    current-slice ROI, then for each current-slice ROI matched by more than one
    next-slice ROI. Each group also contains every partner of the ROIs on the
    "many" side, so both arrays describe the whole cluster.
    """
    above = overlap_mat > threshold
    groups = []
    for next_ind in np.where(above.sum(axis=0) > 1)[0]:
        curr_inds = np.where(above[:, next_ind])[0]
        # Partner lists can differ in length, so concatenate before np.unique
        # (np.unique on a ragged list of arrays fails on recent numpy).
        next_inds = np.unique(np.concatenate([np.where(above[c, :])[0] for c in curr_inds]))
        groups.append((curr_inds, next_inds))
    for curr_ind in np.where(above.sum(axis=1) > 1)[0]:
        next_inds = np.where(above[curr_ind, :])[0]
        curr_inds = np.unique(np.concatenate([np.where(above[:, n])[0] for n in next_inds]))
        groups.append((curr_inds, next_inds))
    return groups


# Rows are collected in a list and turned into a DataFrame once at the end;
# concatenating onto a growing frame inside the loop is quadratic.
overlap_records = []
for mouse in mice:
    base_dir = base_dir_1 if mouse < 31 else base_dir_2
    for session in test_sessions:
        for plane in planes:
            print(f'mouse {mouse}, session {session}, plane {plane}')
            ops_fn = base_dir / f'{mouse:03}' / f'plane_{plane}' / f'{session:03}' / 'plane0' / 'ops.npy'
            # allow_pickle=True unpickles arbitrary objects: only load files you produced yourself.
            ops = np.load(ops_fn, allow_pickle=True).item()

            results = np.load(save_dir / f'masks_{mouse:03d}_{session:03d}_{plane:03d}.npy', allow_pickle=True)
            for mask_ind, mask in enumerate(results):
                image_type = image_type_list[mask_ind]
                # The background image is stored once per (plane, image type).
                image_saved = False
                for overlap_ind, overlap_mat in enumerate(get_overlap_mat_list(mask)):
                    overlap_groups = _collect_overlap_groups(overlap_mat, iou_threshold)
                    if not overlap_groups:  # no multiply-matched ROI in this slice pair
                        continue

                    img = ops[ops_image_keys[image_type]]
                    for pair_ind, (curr_inds, next_inds) in enumerate(overlap_groups):
                        # draw and save images for each of the overlap inds pairs
                        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
                        ax[0].imshow(img)
                        ioa_curr, ioa_next, curr_size_percentile, next_size_percentile = \
                            draw_mask_overlaps(ax[1], mask[overlap_ind, :, :], mask[overlap_ind+1, :, :],
                                               curr_inds, next_inds)
                        # np.roll by 1 on the two limits swaps them, i.e. both axes are
                        # reversed relative to what draw_mask_overlaps set.
                        ax[1].set_xlim(np.roll(ax[1].get_xlim(), 1))
                        ax[1].set_ylim(np.roll(ax[1].get_ylim(), 1))
                        ax[0].set_xlim(ax[1].get_xlim())
                        ax[0].set_ylim(ax[1].get_ylim())
                        save_fn = f'JK{mouse:03d}_s{session:03d}_plane{plane}_{image_type}_overlapmat_{overlap_ind}_pair_{pair_ind}.png'
                        fig.tight_layout()
                        fig.savefig(curation_dir / save_fn)
                        plt.close(fig)

                        # Store the bulky arrays only on the first row they belong to:
                        # the image once per image type, the masks and the overlap matrix
                        # once per slice pair. Other rows carry an empty list.
                        # (Previously the counters were reset one loop level too high, so
                        # later image types / slice pairs never got their arrays stored,
                        # and `overlap_mat` was re-wrapped in a list on every pair.)
                        save_img = [] if image_saved else img.copy()
                        image_saved = True
                        if pair_ind == 0:
                            curr_mask = mask[overlap_ind, :, :]
                            next_mask = mask[overlap_ind+1, :, :]
                            save_overlap_mat = overlap_mat
                        else:
                            curr_mask = []
                            next_mask = []
                            save_overlap_mat = []

                        overlap_records.append({'mouse': mouse,
                                                'session': session,
                                                'plane': plane,
                                                'image_type': image_type,
                                                'image': save_img,
                                                'overlap_mat_ind': overlap_ind,
                                                'pair_ind': pair_ind,
                                                'curr_mask': curr_mask,
                                                'next_mask': next_mask,
                                                'overlap_mat': save_overlap_mat,
                                                'curr_overlap_inds': curr_inds,
                                                'next_overlap_inds': next_inds,
                                                'ioa_curr': ioa_curr,
                                                'ioa_next': ioa_next,
                                                'curr_size_percentile': curr_size_percentile,
                                                'next_size_percentile': next_size_percentile})

overlap_df = pd.DataFrame(overlap_records, columns=overlap_columns)
# Save df
overlap_df.to_pickle(save_dir / 'overlap_df.pkl')
# Remove images, masks, and overlap_mats from df
overlap_df = overlap_df.drop(columns=['image', 'curr_mask', 'next_mask', 'overlap_mat'])
# Save to csv
overlap_df.to_csv(save_dir / 'overlap_df.csv')


mouse 25, session 1, plane 1
mouse 25, session 1, plane 2
mouse 25, session 1, plane 3
mouse 25, session 1, plane 4
mouse 25, session 1, plane 5
mouse 25, session 1, plane 6
mouse 25, session 1, plane 7
mouse 25, session 1, plane 8
mouse 25, session 15, plane 1
mouse 25, session 15, plane 2
mouse 25, session 15, plane 3
mouse 25, session 15, plane 4
mouse 25, session 15, plane 5
mouse 25, session 15, plane 6
mouse 25, session 15, plane 7
mouse 25, session 15, plane 8
mouse 27, session 1, plane 1
mouse 27, session 1, plane 2
mouse 27, session 1, plane 3
mouse 27, session 1, plane 4
mouse 27, session 1, plane 5
mouse 27, session 1, plane 6
mouse 27, session 1, plane 7
mouse 27, session 1, plane 8
mouse 27, session 15, plane 1
mouse 27, session 15, plane 2
mouse 27, session 15, plane 3
mouse 27, session 15, plane 4
mouse 27, session 15, plane 5
mouse 27, session 15, plane 6
mouse 27, session 15, plane 7
mouse 27, session 15, plane 8
mouse 30, session 1, plane 1
mouse 30, session 1, plane 

In [6]:
# NOTE(review): debugging leftover - inspects `curr_mask` as left in the kernel by the
# loop cell, so it fails on a fresh kernel and runs out of order; remove before sharing.
len(curr_mask)

0

In [5]:
# NOTE(review): debugging leftover - rebuilds a single row from variables left in the
# kernel by the loop cell (mouse, save_img, curr_mask, ...). The recorded output is
# "ValueError: All arrays must be of the same length". It cannot run on a fresh
# kernel; remove before sharing.
temp_df = pd.DataFrame({'mouse': mouse,
                                                    'session': session,
                                                    'plane': plane,
                                                    'image_type': image_type,
                                                    'image': save_img,
                                                    'overlap_mat_ind': overlap_ind,
                                                    'pair_ind': pair_ind,
                                                    'curr_mask': curr_mask,
                                                    'next_mask': next_mask,
                                                    'overlap_mat': overlap_mat,
                                                    'curr_overlap_inds': [curr_inds],
                                                    'next_overlap_inds': [next_inds],
                                                    'ioa_curr': ioa_curr,
                                                    'ioa_next': ioa_next,
                                                    'curr_size_percentile': [curr_size_percentile],
                                                    'next_size_percentile': [next_size_percentile]})

ValueError: All arrays must be of the same length

In [None]:
# NOTE(review): debugging leftover - shows the last `curr_inds` left in the kernel by
# the loop cell; depends on hidden state, remove before sharing.
curr_inds.tolist()

[84, 86]