In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch

from proofreader.utils.vis import plot_3d
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from proofreader.data.cremi import prepare_cremi_vols

# Location of the CREMI data relative to this notebook; the loader's result is
# unpacked into `train_vols` / `test_vols`.
CREMI_DIR = '../../dataset/cremi'
train_vols, test_vols = prepare_cremi_vols(CREMI_DIR)

In [None]:
from proofreader.data.splitter import NeuriteDataset
from proofreader.data.augment import Augmentor

# Sampling parameters for NeuriteDataset
num_slices = [4, 4]
radius = 96
context_slices = 4
num_points = 1024

augmentor = Augmentor(center=True, shuffle=True, normalize=[125, 1250, 1250])

# NOTE(review): despite its name, this dataset is built from `test_vols`, and the
# accuracy cell that follows evaluates on it -- confirm this is intended.
train_dataset = NeuriteDataset(
    test_vols,
    num_slices,
    radius,
    context_slices,
    num_points=num_points,
    torch=True,
    open_vol=True,
    verbose=False,
    Augmentor=augmentor,
)
print(len(train_dataset))

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# These names were previously only imported in later cells, so this cell
# failed on a fresh kernel (Restart & Run All). The imports mirror the later
# model-setup cell.
from proofreader.model.config import *
from proofreader.utils.torch import load_model

config = get_config('cn_context_4_aug_small')
dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
model, _, _ = build_full_model_from_config(config.model, config.dataset)
# Wrap before loading -- presumably the checkpoint was saved from a
# DataParallel-wrapped model; TODO confirm.
model = nn.DataParallel(model)
model = load_model(model, '../../330.ckpt', map_location=torch.device('cpu'))



In [None]:
from proofreader.model.classifier import *
import torch

# `no_grad` only disables gradient tracking; `eval()` is also needed so that
# layers such as dropout / batch-norm (if the model has any) run in inference mode.
model.eval()
with torch.no_grad():
    num_batches, acc_sum = 0, 0
    for x, y in dataloader:
        num_batches += 1
        y_hat = model(x)
        pred = predict_class(y_hat)
        accs = get_accuracy(y, pred)
        print(accs)
        acc_sum += accs['total_acc']
        # running mean of per-batch total accuracy
        print(round(acc_sum / num_batches, 3))


In [41]:
from proofreader.utils.data import *
from proofreader.utils.all import *
from proofreader.utils.vis import *
import cc3d
from skimage.segmentation import find_boundaries
from scipy import ndimage
from proofreader.utils.torch import *
from proofreader.data.augment import Augmentor

class TestDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that simulates missing slices and yields reattachment candidates.

    For each volume and each drop position, ``num_slices`` consecutive z slices
    (axis 0) are treated as missing. Every neurite on the slice directly above the
    gap (a "top neurite") is paired with nearby neurites on the slice directly below
    the gap (those inside a circle of ``radius`` around the top neurite's centre of
    mass). Each yielded item is ``(examples, labels)``:

    * ``examples`` -- float32 tensor of shape ``(num_candidates, 3, num_points)``,
      one point cloud per (top neurite, candidate) pair;
    * ``labels`` -- tensor of shape ``(num_candidates,)``; 1 where both pieces map
      to the same original label and 0 otherwise.

    Iteration state (current volume, drop and neurite) lives on the instance and is
    not reset by ``__iter__``, so a second loop over the same object continues where
    the previous one stopped.
    """

    def __init__(self,
                 vols,
                 num_slices: int,
                 radius: int,
                 context_slices: int,
                 num_points: int = None,
                 add_batch_id: bool = False,
                 Augmentor: Augmentor = Augmentor(),
                 verbose: bool = False,
                 min_neurite_volume: int = 500,
                 ):
        """
        Args:
            vols: iterable of 3D label volumes; axis 0 is treated as z.
            num_slices: number of consecutive slices dropped at each position.
            radius: radius of the circular search window on the bottom border slice.
            context_slices: max number of slices kept on each side of the gap.
            num_points: points per point cloud. Required in practice, because the
                output tensor is allocated with this size.
            add_batch_id: stored but not used by this class.
            Augmentor: optional point-cloud augmentor applied to every example.
                The default instance is created once, when the class is defined,
                and is shared by every instance that relies on the default.
            verbose: print progress messages.
            min_neurite_volume: threshold passed to ``zero_classes_with_min_volume``
                when cleaning the input volumes (previously hard-coded to 500).
        """
        # clean vols (presumably removes labels smaller than the threshold)
        self.vols = [zero_classes_with_min_volume(vol, min_neurite_volume) for vol in vols]

        self.num_slices = num_slices
        self.radius = radius
        self.context_slices = context_slices
        self.num_points = num_points
        self.Augmentor = Augmentor
        self.verbose = verbose
        self.add_batch_id = add_batch_id

        self.test_iteration_batch = None
        self.test_iteration_i = 0
        self.test_iteration_len = 0

        # per-drop state, filled by _load_cur_drop
        self.top_neurites = np.zeros((0))
        self.cur_neurite_i = 0
        self.vol_relabeled = None
        self.label_map = None
        # whether any drop has been loaded yet (the first load must not advance)
        self._drop_loaded = False

        self.cur_drop_start = context_slices
        self.cur_vol_i = 0

        # sanity-check counters: candidate batches with zero / several true matches
        self.no_true = 0
        self.multi_true = 0

    def load_next_candidate_batch(self):
        """Return ``(examples, labels)`` for the next top neurite.

        Advances to the next drop / volume when the current drop is exhausted.

        Raises:
            StopIteration: propagated from ``increment_vol_and_drop`` once every
                volume has been processed.
        """
        # A loop rather than a single `if`: a drop whose top border slice holds no
        # neurites is skipped instead of indexing into an empty array.
        while self.cur_neurite_i >= self.top_neurites.shape[0]:
            if self.verbose:
                print('getting all top neurites for drop')

            # only advance after the first drop has been loaded
            if self._drop_loaded:
                self.increment_vol_and_drop()
            self._load_cur_drop()
            self._drop_loaded = True

        drop = self.get_cur_drop()
        c = self.top_neurites[self.cur_neurite_i]

        examples, labels = self.get_examples_from_top_class(
            self.vol_relabeled, c, drop, self.label_map)

        # sanity check
        num_true = labels.count_nonzero()
        if num_true == 0:
            self.no_true += 1
        if num_true > 1:
            self.multi_true += 1

        self.cur_neurite_i += 1
        return (examples, labels)

    def _load_cur_drop(self):
        """Build the relabeled volume, label map and shuffled top neurites for the current drop."""
        vol = self.get_cur_vol()
        (drop_start, drop_end) = self.get_cur_drop()
        cs = self.context_slices

        # Keep only `cs` context slices on each side of the dropped slices;
        # everything else (including the dropped slices) stays 0.
        vol_relabeled = np.zeros_like(vol)
        vol_relabeled[drop_start - cs:drop_start] = vol[drop_start - cs:drop_start]
        vol_relabeled[drop_end:drop_end + cs] = vol[drop_end:drop_end + cs]

        # Connected components gives the now-detached pieces on either side of the
        # gap distinct ids; background is reset to 0 afterwards.
        zero_indices = vol_relabeled == 0
        vol_relabeled = cc3d.connected_components(vol_relabeled)
        vol_relabeled[zero_indices] = 0

        # map new labels -> original labels, used for the ground truth
        label_map = correspond_labels(vol_relabeled, vol, bg_label=0)

        # Neurites on the slice just above the gap. Filter 0 explicitly: np.unique
        # only puts 0 first when background is present on that slice, so deleting
        # index 0 could drop a real neurite.
        top_neurites = np.unique(vol_relabeled[drop_start - 1])
        top_neurites = top_neurites[top_neurites != 0]
        np.random.shuffle(top_neurites)

        self.top_neurites = top_neurites
        self.cur_neurite_i = 0
        self.vol_relabeled = vol_relabeled
        self.label_map = label_map

    def get_examples_from_top_class(self, vol, c, drop, label_map):
        """Build one point-cloud example per candidate bottom neurite for top neurite ``c``.

        Args:
            vol: relabeled volume in which ``c`` occurs.
            c: label of the top neurite.
            drop: ``(drop_start, drop_end)`` z range of the missing slices.
            label_map: mapping from relabeled ids to original ids.

        Returns:
            ``(examples, labels)`` -- see the class docstring.

        Raises:
            ValueError: if ``c`` does not occur in ``vol``.
        """
        top_c = c
        (sz, sy, sx) = vol.shape

        # first z slice on which c occurs
        zmin = next((i for i in range(sz) if c in vol[i]), None)
        if zmin is None:
            raise ValueError(f'class {c} does not occur in vol')

        drop_start, drop_end = drop
        num_slices = drop_end - drop_start
        top_z_len = min(self.context_slices, drop_start - zmin)
        bot_z_len = min(self.context_slices, sz - drop_end)

        # Allocate the full example volume: top context + gap + bottom context.
        mz = num_slices + top_z_len + bot_z_len
        final_vol = np.zeros((mz, sy, sx), dtype='uint')

        # Top section (a view into final_vol): only voxels of top_c.
        top_vol_section = final_vol[0:top_z_len]
        top_vol_section[vol[drop_start - top_z_len:drop_start] == top_c] = top_c

        # Centre of mass of the top neurite on the slice bordering the gap.
        # center_of_mass returns (axis-0, axis-1) coordinates.
        top_border = top_vol_section[-1]
        (com_row, com_col) = ndimage.center_of_mass(top_border)
        (com_row, com_col) = round(com_row), round(com_col)

        # Keep only bottom-border neurites inside the search circle.
        # circular_mask receives (axis-1, axis-0) order -- TODO confirm its convention.
        bot_border = vol[drop_end].copy()  # copy because we zero in place
        mask = circular_mask(
            bot_border.shape[0], bot_border.shape[1], center=(com_col, com_row), radius=self.radius)
        bot_border[~mask] = 0

        # Candidate classes ordered by distance to the top neurite; only the two
        # border slices are used, for efficiency.
        d_vol = np.stack([top_border, bot_border])
        d_vol = crop_where(d_vol, d_vol != 0)
        mismatch_classes = get_classes_sorted_by_distance(d_vol, top_c, method='mean')

        final_examples = torch.zeros(
            (len(mismatch_classes), 3, self.num_points))
        final_labels = []
        bot_start = num_slices + top_z_len
        for i, bot_c in enumerate(mismatch_classes):
            cur_vol = final_vol.copy()
            # Bottom section (a view into cur_vol): only voxels of bot_c.
            bot_vol_section = cur_vol[bot_start:]
            bot_vol_section[vol[drop_end:drop_end + bot_z_len] == bot_c] = bot_c

            final_examples[i] = self.convert_volumetric_to_final(cur_vol)
            final_labels.append(int(label_map[top_c] == label_map[bot_c]))

        return final_examples, torch.tensor(final_labels)

    def remove_vol_interiors(self, vol):
        """Keep only the inner boundary voxels of each z slice. Modifies ``vol`` in place and returns it."""
        for z in range(vol.shape[0]):
            vol[z] = vol[z] * find_boundaries(vol[z], mode='inner')
        return vol

    def convert_to_point_cloud(self, vol):
        """Convert a label volume to a point cloud, resampled to ``num_points`` if set."""
        pc = convert_grid_to_pointcloud(vol)
        if self.num_points is not None:
            if pc.shape[0] < self.num_points:
                # too few points: sample with replacement to reach num_points
                pc = random_sample_arr(
                    pc, count=self.num_points, replace=True)
            else:
                pc = random_sample_arr(pc, count=self.num_points)
        return pc

    def convert_volumetric_to_final(self, vol_example):
        """Turn a two-neurite example volume into a ``(3, num_points)`` float32 tensor."""
        # final crop and relabel
        vol_example = crop_where(vol_example, vol_example != 0)
        vol_example = cc3d.connected_components(vol_example)

        # sanity check
        all_classes = np.unique(vol_example)
        assert len(
            all_classes) == 3, f'final sample should have 3 classes, [0, n1, n2] not {all_classes}'

        vol_example = self.remove_vol_interiors(vol_example)

        pc_example = self.convert_to_point_cloud(vol_example)
        if self.Augmentor is not None:
            pc_example = self.Augmentor.transfrom(pc_example)
        # (num_points, 3) -> (3, num_points)
        pc_example = np.swapaxes(pc_example, 0, 1)

        return torch.from_numpy(pc_example).type(torch.float32)

    def get_drop_start_range(self):
        """Return ``(range_start, range_stop)``, the inclusive range of valid drop starts.

        A drop start is valid when ``context_slices`` slices fit on both sides of
        the ``num_slices`` dropped slices within the current volume.
        """
        cur_vol = self.get_cur_vol()
        range_start = self.context_slices
        range_stop = cur_vol.shape[0] - self.num_slices - self.context_slices
        return (range_start, range_stop)

    def increment_vol_and_drop(self):
        """Advance to the next drop position, moving to the next volume when needed.

        Raises:
            StopIteration: when all volumes have been processed.
        """
        if self.verbose:
            print('increment drop')
        self.cur_drop_start += 1
        range_start, range_stop = self.get_drop_start_range()

        # range_stop is inclusive: drop_start == range_stop still leaves a full
        # bottom context (drop_end + context_slices == vol.shape[0]).
        if self.cur_drop_start > range_stop:
            self.cur_vol_i += 1
            self.cur_drop_start = range_start
            if self.verbose:
                print('increment vol')
            if self.cur_vol_i >= len(self.vols):
                raise StopIteration

    def get_cur_vol(self):
        """Return the volume currently being processed."""
        return self.vols[self.cur_vol_i]

    def get_cur_drop(self):
        """Return ``(drop_start, drop_end)`` of the current drop (end exclusive)."""
        cur_drop_end = self.cur_drop_start + self.num_slices
        return (self.cur_drop_start, cur_drop_end)

    def get_next(self):
        """Return the next ``(examples, labels)`` pair."""
        (examples, labels) = self.load_next_candidate_batch()

        if self.verbose:
            print(
                f'vol: {self.cur_vol_i}, drop: {self.get_cur_drop()}, neurite: {self.cur_neurite_i}, candidate: {self.test_iteration_i}')

        return (examples, labels)

    def __iter__(self):
        # increment_vol_and_drop signals exhaustion with StopIteration; catch it so
        # the generator ends cleanly (an uncaught StopIteration inside a generator
        # is turned into RuntimeError by PEP 479).
        while True:
            try:
                yield self.get_next()
            except StopIteration:
                return

In [42]:
from proofreader.data.augment import Augmentor
from proofreader.utils.torch import load_model
from proofreader.model.config import *
from torch.utils.data import DataLoader
# Explicit import: `torch` was otherwise only available through one of the
# earlier `import *` lines.
import torch
import torch.nn as nn

config = get_config('cn_context_4_aug_small')

model, _, _ = build_full_model_from_config(config.model, config.dataset)
model = nn.DataParallel(model)
model = load_model(model, '../../330.ckpt', map_location=torch.device('cpu'))

# TestDataset parameters (num_slices is a single int here, unlike NeuriteDataset)
num_slices = 4
radius = 96
context_slices = 4
num_points = 1024
augmentor = Augmentor(center=True, shuffle=True, normalize=[125, 1250, 1250])
# evaluate on the test volumes from index 2 onward
tester = TestDataset(test_vols[2:], num_slices, radius, context_slices, num_points=num_points, Augmentor=augmentor, verbose=False)


In [43]:
import torch

# `tester` already yields one batch per top neurite (all of its candidate
# partners), so it is iterated directly. Wrapping it in a DataLoader with
# batch_size > 1 would try to stack candidate batches of different lengths.
with torch.no_grad():
    model.eval()
    count, acc = 0, 0
    for x, y in tester:
        if x.shape[0] == 0:
            # no candidate partners inside the search radius; nothing to score
            continue
        count += 1
        # score all candidates for this neurite in a single forward pass
        y_hat = model(x)
        pred = predict_class(y_hat)
        accs = get_accuracy(y, pred)
        print(accs)
        acc += accs['total_acc']
        # running mean of per-neurite total accuracy
        print(round(acc / count, 3))
            




0.28125
{'total_acc': 0.734375, 'true_acc': 0.8888888955116272, 'false_acc': 0.6739130616188049}
0.734
0.171875
{'total_acc': 0.59375, 'true_acc': 0.9090909361839294, 'false_acc': 0.5283018946647644}
0.664
0.28125


In [None]:

import itertools

# Preview the labels of the next few candidate batches. `tester` keeps its
# position between loops, so this continues from wherever the previous
# iteration stopped.
MAX_PREVIEW_BATCHES = 52  # same count as the previous `if step > 50: break`
PLOT_CANDIDATES = False   # set True to render every candidate point cloud

plot_lims = ((-.1, .1),) * 3
for step, (x, y) in enumerate(itertools.islice(tester, MAX_PREVIEW_BATCHES)):
    print(step, y)
    if PLOT_CANDIDATES:
        for i in range(x.shape[0]):
            # (3, num_points) -> (num_points, 3) for plotting
            pc = x[i].numpy().T
            label = y[i].item() == 1
            plot_3d(pc, title=label, lims=plot_lims)

