In [78]:
import json
import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

from diced import DicedStore
# Connect to the data store at the URI below (the gs:// scheme presumably
# refers to a Google Cloud Storage bucket - confirm against the diced docs).
store = DicedStore("gs://flyem-public-connectome")


In [79]:
# Open the "medulla7column" repository and get handles to its "grayscale" and
# "groundtruth" arrays (sliced below into `image` and `label` respectively).
repo = store.open_repo("medulla7column")
grayscale = repo.get_array("grayscale")
groundtruth = repo.get_array("groundtruth")


In [80]:
def get_extent(volume_coord, volume_extent):
    """Build per-axis slices selecting a box from a (z, y, x)-ordered volume.

    Args:
        volume_coord: start index of the box; positions 0, 1, 2 are z, y, x.
        volume_extent: box length; positions 0, 1, 2 are z, y, x.

    Returns:
        Tuple ``(z_extent, y_extent, x_extent)`` of ``slice`` objects, each
        covering ``[start, start + length)`` on its axis.
    """
    return tuple(
        slice(volume_coord[axis], volume_coord[axis] + volume_extent[axis])
        for axis in range(3)
    )


In [81]:
# Start corner and size of the sub-volume to load; get_extent reads index 0 as
# z, index 1 as y and index 2 as x.
volume_coord = [3490, 2103, 3253]
volume_extent = [520, 520, 520]
# Tuple of three slices used to index the stored arrays.
slc = get_extent(volume_coord, volume_extent)


In [82]:
# Read the same box from both arrays, so `label` and `image` cover the same region.
label = groundtruth[slc]
image = grayscale[slc]


In [83]:
# Parameters for the partition computation below.
lom_radius = (16, 16, 16)  # per-axis window radius; window width is 2 * r + 1
min_size = 10000  # labels with fewer voxels than this are zeroed by _clear_dust
thresholds = [
    0.025, 0.05, 0.075,
    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
]


In [70]:
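# Optional toy inputs for debugging (disabled): uncomment to replace the real
# volume with a random 7x7x7 label array and matching small parameters.
# label = np.random.randint(5, size=(7, 7, 7))
# lom_radius = (2, 2, 2)
# min_size = 1
# thresholds = [0.025,0.05,0.075,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9 ]
# label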

array([[[1, 3, 4, 0, 0, 0, 1],
        [2, 4, 1, 1, 0, 0, 2],
        [4, 1, 1, 4, 4, 4, 2],
        [4, 1, 3, 2, 4, 0, 4],
        [4, 4, 1, 1, 1, 3, 0],
        [2, 1, 1, 4, 1, 2, 1],
        [0, 3, 3, 0, 0, 2, 1]],

       [[3, 2, 4, 2, 2, 3, 2],
        [3, 0, 3, 4, 4, 2, 0],
        [4, 4, 3, 4, 4, 3, 2],
        [3, 4, 0, 1, 1, 2, 0],
        [3, 3, 1, 2, 0, 1, 4],
        [0, 0, 0, 1, 4, 0, 3],
        [0, 4, 0, 2, 1, 3, 0]],

       [[0, 0, 4, 0, 0, 0, 3],
        [3, 0, 4, 0, 0, 4, 3],
        [3, 4, 2, 1, 3, 1, 0],
        [0, 3, 4, 1, 0, 0, 3],
        [4, 0, 3, 4, 3, 4, 4],
        [1, 4, 2, 3, 1, 0, 0],
        [1, 2, 4, 2, 3, 1, 3]],

       [[2, 4, 3, 3, 1, 3, 3],
        [4, 0, 1, 1, 4, 1, 2],
        [2, 4, 2, 2, 2, 3, 3],
        [1, 3, 2, 4, 0, 4, 1],
        [4, 4, 4, 3, 0, 1, 4],
        [0, 4, 0, 3, 3, 3, 3],
        [2, 2, 4, 3, 1, 1, 3]],

       [[1, 3, 3, 0, 3, 3, 0],
        [4, 4, 2, 3, 1, 2, 4],
        [3, 4, 4, 3, 4, 2, 2],
        [0, 0, 1, 0, 3, 0, 4],


# compute_partitions

In [84]:
# Inputs for the partition computation, derived from the loaded label volume.
shape = label.shape  # (520, 520, 520) for the extent requested above
lom_radius = list(map(int, lom_radius))  # plain Python ints: [16, 16, 16]
# Distinct label ids and the voxel count of each one.
ids, sizes = np.unique(label, return_counts=True)

def _clear_dust(label, ids, sizes):
    small = ids[sizes < min_size]
    small_mask = np.in1d(label.flat, small).reshape(label.shape)
    label[small_mask] = 0
    return label

# Zero out every object with fewer than min_size voxels before partitioning.
clean_label = _clear_dust(label, ids, sizes)
lom_radius = np.array(lom_radius) # array([16, 16, 16])
# Reverse the axis order of the radius (presumably x, y, z -> z, y, x, going
# by the name); the *_zyx values are the ones applied to the volume below.
lom_radius_zyx = lom_radius[::-1] # array([16, 16, 16])
# Full window width per axis: the radius on each side plus the centre voxel.
lom_diam_zyx = 2 * lom_radius_zyx + 1 # array([33, 33, 33])

def _sel(i):
    if i == 0:
        return slice(None)
    else:
        return slice(i, -i)
    
# One trimming slice per axis. This must be a tuple: NumPy treats a *list* of
# slices as a fancy index (deprecated since 1.15, an error since 1.23).
valid_sel = tuple(_sel(x) for x in lom_radius_zyx) # (slice(16, -16, None), slice(16, -16, None), slice(16, -16, None))
# uint8 output volume covering only the voxels that remain after trimming
# lom_radius_zyx from both ends of every axis.
output = np.zeros(clean_label[valid_sel].shape, dtype=np.uint8) # shape = [488, 488, 488]
corner = lom_radius

# Distinct ids left after dust removal.
labels = set(np.unique(clean_label))

# Number of voxels in one window; used to turn voxel counts into fractions.
fov_volume = np.prod(lom_diam_zyx) # 33 ** 3

def _summed_volume_table(val):
    val = val.astype(np.int32)
    svt = val.cumsum(axis=0).cumsum(axis=1).cumsum(axis=2)
    return np.pad(svt, [[1, 0], [1, 0], [1, 0]], mode='constant')

def _query_summed_volume(svt, diam):
    return (svt[diam[0]:, diam[1]:, diam[2]:] - svt[diam[0]:, diam[1]:, :-diam[2]] - 
            svt[diam[0]:, :-diam[1], diam[2]:] - svt[:-diam[0], diam[1]:, diam[2]:] +
            svt[:-diam[0], :-diam[1], diam[2]:] + svt[:-diam[0], diam[1]:, :-diam[2]]
            + svt[diam[0]:, :-diam[1], :-diam[2]] - svt[:-diam[0], :-diam[1], :-diam[2]])

# for l in labels:
#     if l == 0:
#         continue
    
#     object_mask = (clean_label == l)
    
#     svt = _summed_volume_table(object_mask) # shape = [521, 521, 521]
#     active_fraction = _query_summed_volume(svt, lom_diam_zyx) / float(fov_volume)
    
#     print(active_fraction)
    
#     object_mask = object_mask[valid_sel]

#     for i, th in enumerate(thresholds):
#         output[object_mask & (active_fraction < th) & (output == 0)] = i + 1
#     output[object_mask & (active_fraction >= thresholds[-1]) &
#            (output == 0)] = len(thresholds) + 1
    

In [87]:
# Spot-check the window statistics for a single object (label id 700).
object_mask = (clean_label == 700)
svt = _summed_volume_table(object_mask) # shape = [521, 521, 521]
# Per-window voxel count of the object, divided by the window volume.
active_fraction = _query_summed_volume(svt, lom_diam_zyx) / float(fov_volume)
print(active_fraction.shape)

(488, 488, 488)


In [34]:
# Distinct ids present in clean_label (built in the compute_partitions cell).
labels

{0, 1, 2, 3, 4}

In [27]:
diam = lom_diam_zyx

# Only the last expression of a cell is echoed, so this shape is evaluated but
# never shown; the cell displays diam[0] alone.
svt[diam[0]:, diam[1]:, diam[2]:].shape
diam[0]

33

In [18]:
# Same spot check for label id 4746. Note: the result is NOT divided by
# fov_volume here, so despite its name `active_fraction` holds raw per-window
# voxel counts in this cell.
object_mask = (clean_label ==4746)
svt = _summed_volume_table(object_mask)
active_fraction = _query_summed_volume(svt, lom_diam_zyx)
active_fraction

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ..., 
       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, 

In [19]:
# Shape of the windowed result: one entry per window position.
active_fraction.shape

(488, 488, 488)

In [23]:
# Value for the first window position (index 0 on every axis).
active_fraction[0, 0, 0]

0

In [110]:
# Fixed seed so the sampled voxel is reproducible.
np.random.seed(3)
# Draw one random (z, y, x) position in [24, 496) on each axis, i.e. at least
# 24 voxels from every face of the 520-wide volume, so the +/-24 box taken
# around it below stays in bounds.
z, y, x = np.random.randint(0+24, 520-24, 3)


In [111]:
# Label id at the sampled voxel.
clean_label[z, y, x]


58969

In [119]:
# 49-voxel-wide cube (24 voxels on either side) centred on the sampled voxel.
bbox = clean_label[z-24:z+25, y-24:y+25, x-24:x+25]
bbox.shape

(49, 49, 49)

In [121]:
# Fraction of the box occupied by the label found at its centre voxel.
# Counting the boolean mask directly avoids building index arrays with
# np.where, and bbox.size replaces the hard-coded 49.0**3 so the ratio stays
# right if the box size changes. float() guarantees true division on any
# Python version.
fraction = float(np.count_nonzero(bbox == clean_label[z, y, x])) / bbox.size