In [None]:
import numpy as np
from proofreader.utils.io import read_cremi_volume, from_h5
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from proofreader.utils.plot import make_histogram

def segmentation_size_hist(seg):
    """Plot a histogram of how many voxels each label in ``seg`` occupies.

    Every distinct value found in ``seg`` is counted as one entry, so any
    background label present in the volume is included both in the
    histogram and in the total reported in the title.

    Args:
        seg: array of integer labels (anything ``np.unique`` accepts).
    """
    _, voxel_counts = np.unique(seg, return_counts=True)
    make_histogram(
        voxel_counts,
        bins=100,
        xlabel='Neuron Volume',
        title=f'{voxel_counts.size} Total Neurons',
        logscale=True,
    )


In [None]:
# Segmentation volumes (seg=True) of CREMI samples A, B and C.
CREMI_DIR = '../../dataset/cremi'
trueA = read_cremi_volume('A', seg=True, path=CREMI_DIR)
trueB = read_cremi_volume('B', seg=True, path=CREMI_DIR)
trueC = read_cremi_volume('C', seg=True, path=CREMI_DIR)

In [None]:
# Image volumes (img=True) of CREMI samples A, B and C.
CREMI_DIR = '../../dataset/cremi'
imgA = read_cremi_volume('A', img=True, path=CREMI_DIR)
imgB = read_cremi_volume('B', img=True, path=CREMI_DIR)
imgC = read_cremi_volume('C', img=True, path=CREMI_DIR)

In [None]:
# RSUnet segmentations of samples A, B and C, read from the same HDF5
# dataset key in each file.
NEURON_IDS_KEY = 'volumes/labels/neuron_ids'
predA = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_A_pad.hdf', dataset_path=NEURON_IDS_KEY)
predB = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_B_pad.hdf', dataset_path=NEURON_IDS_KEY)
predC = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_C_pad.hdf', dataset_path=NEURON_IDS_KEY)

In [None]:
from einops import rearrange
import open3d as o3d 

def convert_3D_img_to_point_cloud(img, threshold=0, flip=True):
    """Return the coordinates of every voxel of ``img`` above ``threshold``.

    Args:
        img: 3-D array.
        threshold: a voxel is selected when its value is strictly greater
            than this.
        flip: when True, axes 0 and 2 are swapped first (e.g. zyx -> xyz),
            so the returned rows follow the swapped axis order.

    Returns:
        Integer array of shape ``(3, N)``; row ``k`` holds the index along
        axis ``k`` of each selected voxel. ``N`` is 0 when nothing exceeds
        the threshold.

    Raises:
        ValueError: if ``img`` is not 3-dimensional.
    """
    img = np.asarray(img)
    if img.ndim != 3:
        raise ValueError(f'expected a 3D image, got {img.ndim} dimension(s)')

    # flip zyx to xyz
    if flip:
        img = np.swapaxes(img, 0, 2)

    # np.nonzero returns one index array per axis for the selected voxels,
    # which is the same result as masking a dense coordinate grid but
    # without allocating a grid three times the size of the volume.
    return np.stack(np.nonzero(img > threshold))

def numpy_to_pointcloud(arr, colors=None):
    """Wrap a numpy array of points in an Open3D ``PointCloud``.

    Args:
        arr: numpy array of point coordinates; cast to float64 before being
            handed to ``o3d.utility.Vector3dVector``.
        colors: optional numpy array of per-point colors, cast to float64.
            When None (the default) the point cloud is left uncolored.

    Returns:
        The populated ``o3d.geometry.PointCloud``.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(arr.astype('float64'))
    # Only touch `colors` once we know it was supplied; casting it
    # unconditionally made the default colors=None raise AttributeError.
    if colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(colors.astype('float64'))
    return pcd


def get_pointcloud(img, masks, resolution=(1, 1, 10), paint=False):
    """Build one Open3D point cloud per mask, colored by image intensity.

    Args:
        img: 3-D image array; the values under each mask are divided by 255
            and used as a gray color (same value for all three channels).
        masks: iterable of boolean arrays with the same shape as ``img``.
        resolution: accepted for interface compatibility but currently not
            applied; point coordinates are raw array indices.
        paint: when True, each cloud is repainted with one random uniform
            color, replacing the intensity-based colors.

    Returns:
        List of ``o3d.geometry.PointCloud``, one per mask, in input order.
        Points are in the array's own axis order (axis 0, axis 1, axis 2).
    """
    pcds = []
    for mask in masks:
        mask = np.asarray(mask, dtype=bool)
        # Indices of the selected voxels as an (N, 3) array. Equivalent to
        # masking a dense np.mgrid, but without rebuilding a coordinate grid
        # three times the size of the volume for every mask.
        pc = np.stack(np.nonzero(mask), axis=1)
        cm = img[mask] / 255
        colors = np.stack((cm, cm, cm), axis=1)
        pcd = numpy_to_pointcloud(pc, colors=colors)
        if paint:
            pcd.paint_uniform_color(np.random.rand(3))
        # Optional, currently disabled: fit the cloud to the unit cube with
        #   pcd.scale(1 / np.max(pcd.get_max_bound() - pcd.get_min_bound()),
        #             center=pcd.get_center())
        pcds.append(pcd)

    return pcds



In [None]:
# Labels of trueB ordered from the most to the least frequent.
# NOTE(review): get_classes_sorted_by_volume is defined in a later cell of this
# notebook, so on a fresh kernel this line raises NameError unless that cell
# has been run first; the definition should be moved above this cell.
classes = get_classes_sorted_by_volume(trueB, reverse=True)

In [None]:
# Point clouds for the four labels at ranks 50..53 of `classes`.
# Each entry is the one-element list returned by get_pointcloud.
neurites = [
    get_pointcloud(imgB, [trueB == classes[rank]], paint=False)
    for rank in range(50, 54)
]


In [None]:
# Flatten the per-neurite lists into a single list of geometries.
# Building a new list (instead of aliasing neurites[0] and extending it in
# place) keeps the first neurite from being added twice and leaves `neurites`
# untouched, so the cell can be re-run safely.
nur = [pcd for pcds in neurites for pcd in pcds]
o3d.visualization.draw_geometries(nur)

In [None]:
# Section: try to find mergers / splitters between the true and predicted
# segmentations.
# NOTE(review): these three loads repeat, with identical arguments, the loads
# done in an earlier cell; on a top-to-bottom run they only re-read the files.

trueA = read_cremi_volume('A',seg=True, path='../../dataset/cremi')
trueB = read_cremi_volume('B',seg=True, path='../../dataset/cremi')
trueC = read_cremi_volume('C',seg=True, path='../../dataset/cremi')

In [None]:
# NOTE(review): repeats, with identical arguments, the predicted-segmentation
# loads done in an earlier cell of this notebook.
predA = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_A_pad.hdf', dataset_path='volumes/labels/neuron_ids')
predB = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_B_pad.hdf', dataset_path='volumes/labels/neuron_ids')
predC = from_h5('../../dataset/segs/RSUnet_900000_seg_sample_C_pad.hdf', dataset_path='volumes/labels/neuron_ids')

In [None]:
def get_classes_sorted_by_volume(vol, reverse=False, return_counts=False):
    """Return the distinct labels of ``vol`` ordered by occurrence count.

    Args:
        vol: array of labels.
        reverse: False orders from least to most frequent; True reverses
            that order.
        return_counts: when True, also return the counts in matching order.

    Returns:
        ``labels`` or, if ``return_counts`` is set, ``(labels, counts)``.
    """
    labels, sizes = np.unique(vol, return_counts=True)

    order = np.argsort(sizes)
    order = order[::-1] if reverse else order

    if not return_counts:
        return labels[order]
    return labels[order], sizes[order]


In [None]:
def get_coverage_over_threshold(true, pred, c, threshold):
    """Find which labels of ``pred`` cover a meaningful share of label ``c``.

    The voxels where ``true == c`` are looked up in ``pred``; for each label
    found there, its share of those voxels is computed and rounded to two
    decimals. Labels whose rounded share is strictly greater than
    ``threshold`` are returned, largest share first.

    Args:
        true: numpy array of labels defining the region of interest.
        pred: numpy array of labels with the same shape as ``true``.
        c: the label in ``true`` to analyse.
        threshold: minimum (exclusive) rounded share for a label to be kept.

    Returns:
        ``(labels, fractions)``: two 1-D numpy arrays of equal length. Both
        are empty when ``c`` does not occur in ``true`` or nothing exceeds
        the threshold.
    """
    overlap = pred[true == c]
    labels, counts = np.unique(overlap, return_counts=True)

    # Largest overlap first (ascending argsort, then reversed).
    order = np.argsort(counts)[::-1]
    labels, counts = labels[order], counts[order]

    # For an empty overlap both arrays are empty, so the division is a no-op.
    fractions = np.round(counts / counts.sum(), 2)

    # Fractions are sorted, so the entries above the threshold form a prefix.
    # Selecting them with a mask also covers the case where *every* entry
    # passes, which the early-return loop used to report as "no results".
    keep = fractions > threshold
    return labels[keep], fractions[keep]


In [None]:
from proofreader.utils.plot import make_histogram
from matplotlib import pyplot as plt

def print_coverage_stats(c, classes, percents, true_base):
    """Print one line per entry describing how much of ``c`` each class covers.

    Args:
        c: the label being analysed.
        classes: labels from the other segmentation, indexable by position.
        percents: fractions matching ``classes``; printed multiplied by 100.
        true_base: True when ``c`` comes from the 'true' volume (and the
            classes from 'pred'); False for the opposite direction.
    """
    base, not_base = ('true', 'pred') if true_base else ('pred', 'true')
    for idx, fraction in enumerate(percents):
        print(f'{classes[idx]} in {not_base} covers {fraction*100}% of {c} in {base}')

def get_coverage_recursive(A, B, c, threshold=.1, seen=None, true_base=True, depth=0, max_depth=5, final_depth=-1, verbose=True):
    """Follow split/merge chains between two segmentations, alternating sides.

    Label ``c`` of ``A`` is checked against ``B``. If more than one label of
    ``B`` covers it above ``threshold``, each of those labels is in turn
    checked against ``A`` (with the roles swapped), up to ``max_depth``.

    Args:
        A: volume that ``c`` belongs to.
        B: volume whose labels are tested for coverage of ``c``.
        c: label to analyse.
        threshold: forwarded to ``get_coverage_over_threshold``.
        seen: list that every visited label is appended to. It is only
            recorded, not used to prune the recursion. A fresh list is
            created when omitted, so separate calls no longer share state.
        true_base: True when ``A`` is the 'true' volume; only affects the
            printed text.
        depth: current recursion depth.
        max_depth: recursion stops descending once ``depth`` reaches this.
        final_depth: value returned when ``c`` is not covered by more than
            one label.
        verbose: print the coverage statistics of every split that is found.

    Returns:
        The deepest level at which a label was covered by more than one
        label of the other volume, or ``final_depth`` if that never happened.
    """
    if seen is None:
        seen = []
    seen.append(c)

    classes0, percents0 = get_coverage_over_threshold(A, B, c, threshold)
    if len(percents0) <= 1:
        return final_depth

    if verbose:
        print(f'DEPTH: {depth}')
        print_coverage_stats(c, classes0, percents0, true_base)

    final_depth = depth
    if depth < max_depth:
        for c0 in classes0:
            # Swap A and B so the covering label is examined from its own side.
            child_depth = get_coverage_recursive(
                B, A, c0,
                threshold=threshold,
                seen=seen,
                true_base=not true_base,
                depth=depth + 1,
                max_depth=max_depth,
                verbose=verbose,
            )
            final_depth = max(child_depth, final_depth)

    return final_depth


        

In [None]:
def get_order_0_splitters(true, pred, top_k=200, threshold=0.1):
    """Collect labels of ``true`` that are split in ``pred`` and go no deeper.

    A label qualifies when ``get_coverage_recursive`` (limited to one level
    of recursion) reports a final depth of 0: the label is covered by several
    labels of ``pred``, and none of those is itself covered by several labels
    of ``true``.

    Args:
        true: reference label volume.
        pred: label volume to compare against.
        top_k: only the ``top_k`` most frequent labels of ``true`` are
            examined (default 200, the previously hard-coded value).
        threshold: coverage threshold used for both the recursive check and
            the reported covering labels (default 0.1, as before).

    Returns:
        Dict mapping each qualifying label of ``true`` to the array of
        ``pred`` labels that cover it above ``threshold``.
    """
    order_0_splitters = {}
    if top_k <= 0:
        # classes[-0:] would select every label rather than none.
        return order_0_splitters

    classes = get_classes_sorted_by_volume(true)
    for c in classes[-top_k:]:
        final_depth = get_coverage_recursive(
            true, pred, c, threshold=threshold, seen=[], max_depth=1, verbose=False)
        if final_depth == 0:
            classes0, _ = get_coverage_over_threshold(true, pred, c, threshold)
            order_0_splitters[c] = classes0
    return order_0_splitters


In [None]:
from proofreader.utils.io import from_h5

# Load the corrected segmentation file for sample A.
# NOTE(review): absolute, machine-specific path; this cell only runs where
# that directory exists.
path = '/mnt/home/jberman/sc/proofreader/dataset/CREMI/corrected/seg_A.h5'
seg_A = from_h5(path)


In [None]:
# Load syn_A.h5 from the corrected directory (presumably the synapse
# annotations for sample A; confirm against the dataset).
# NOTE(review): absolute, machine-specific path.
path = '/mnt/home/jberman/sc/proofreader/dataset/CREMI/corrected/syn_A.h5'
syn_A = from_h5(path)


In [None]:
# Load im_A.h5 from the corrected directory (presumably the image volume for
# sample A; confirm against the dataset).
# NOTE(review): absolute, machine-specific path.
path = '/mnt/home/jberman/sc/proofreader/dataset/CREMI/corrected/im_A.h5'
im_A = from_h5(path)

In [None]:
# Display the shape of the loaded syn_A volume.
syn_A.shape