In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from proofreader.utils.vis import plot_3d
import torch
%load_ext autoreload
%autoreload 2
%matplotlib inline

def plot_example(x, y=None, title=None, lim=0.1):
    """Render one point cloud with ``plot_3d``.

    Args:
        x: array/tensor whose axes 0 and 1 are swapped before plotting
           (presumably (coords, num_points) -> (num_points, coords); confirm).
        y: optional single-element label; when given, the default title is
           ``y.item() == 1``.
        title: explicit title; overrides the label-derived one.
        lim: half-width of the symmetric axis limits, applied to all three axes.
    """
    points = np.swapaxes(x, 0, 1)
    bound = (-lim, lim)
    label = '' if y is None else (y.item() == 1)
    if title is not None:
        label = title
    plot_3d(points, title=label, lims=(bound,) * 3)


In [None]:
from proofreader.data.cremi import prepare_cremi_vols

# Load the CREMI volumes already split into training and test volumes.
cremi_root = '../../dataset/cremi'
train_vols, test_vols = prepare_cremi_vols(cremi_root)

In [None]:
from proofreader.data.splitter import NeuriteDataset
from proofreader.data.augment import Augmentor
# Sampling parameters for NeuriteDataset.
num_slices = [4, 4]
radius = 96
context_slices = 4
num_points = 1024

augmentor = Augmentor(center=True, shuffle=True, normalize=[125, 1250, 1250])
# NOTE(review): despite its name, this dataset is built from the *test* volumes and the
# following cells evaluate a checkpoint on it — confirm which split is intended.
train_dataset = NeuriteDataset(
    test_vols,
    num_slices,
    radius,
    context_slices,
    num_points=num_points,
    torch=True,
    open_vol=True,
    verbose=False,
    Augmentor=augmentor,
)
print(len(train_dataset))

In [None]:
from torch.utils.data import DataLoader
import torch.nn as nn
# get_config, build_full_model_from_config and load_model were previously only in scope
# if a *later* cell had already been run; import them here so this cell works on a fresh
# kernel. (Same modules the later cells import.)
from proofreader.model.config import *
from proofreader.model.classifier import *  # NOTE(review): confirm which module defines build_full_model_from_config
from proofreader.utils.torch import load_model

config = get_config('cn_context_4_aug_small')
dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
# build_full_model_from_config returns three values; only the model is used here.
model, _, _ = build_full_model_from_config(config.model, config.dataset)
# Wrap before loading — presumably the checkpoint was saved from a DataParallel model
# (parameter names prefixed with "module."); TODO confirm.
model = nn.DataParallel(model)
model = load_model(model, '../../330.ckpt', map_location=torch.device('cpu'))



In [None]:
from proofreader.model.classifier import *
# Evaluate the loaded model on `dataloader`, printing each batch's accuracy dict and the
# running mean of 'total_acc' (unweighted mean over batches).
# torch.no_grad() only disables autograd; eval() is what switches BatchNorm/dropout to
# inference behaviour, which matters a lot with batch_size=2.
model.eval()
with torch.no_grad():
    count, acc = 0, 0
    for x, y in dataloader:
        count += 1
        y_hat = model(x)
        pred = predict_class(y_hat)
        accs = get_accuracy(y, pred)
        print(accs)
        acc += accs['total_acc']
        print(round(acc / count, 3))


In [None]:
from proofreader.data.augment import Augmentor
from proofreader.utils.torch import load_model
from proofreader.model.config import *
from proofreader.model.classifier import *
from torch.utils.data import DataLoader
import torch.nn as nn

# SliceDataset was previously only imported much further down the notebook; import it
# here so this cell runs on a fresh kernel.
from proofreader.data.splitter import SliceDataset

config = get_config('pointnet-pre-cs2')
ds_config = config.dataset  # NOTE(review): not used in the visible notebook

model = PointNet(num_points=2000, classes=2, batch_norm=True)
model = nn.DataParallel(model)
model = load_model(model, '../../750.ckpt', map_location=torch.device('cpu'))

augmentor = Augmentor(center=True, shuffle=True, normalize=[125, 1250, 1250])
num_slices = 1
# Positional args: volumes, num_slices, radius=96, context_slices=2.
tester = SliceDataset(test_vols, num_slices, 96, 2,
                      num_points=2000, Augmentor=augmentor, verbose=False,
                      drop_false=True, candidate_group=True, randomize=True)


In [None]:

def _ratio(num, den):
    """Return num / den, or NaN when den is 0 (nothing counted yet)."""
    return num / den if den else float('nan')


# Candidate-group evaluation over `tester`: every candidate scored counts towards
# "seen" accuracy; within a group, scanning stops at the first candidate that is
# predicted or labelled 1, and only that candidate counts towards "neurite" accuracy.
with torch.no_grad():
    model.eval()
    neurites = 0
    seen = 0
    seen_correct = 0
    neurite_correct = 0
    for step, (x, y) in enumerate(tester):
        for i in range(x.shape[0]):
            y_hat = model(x[i].unsqueeze(dim=0))  # score one candidate as a batch of 1
            pred = predict_class(y_hat)
            true = int(pred == y[i])
            seen_correct += true
            seen += 1

            if pred == 1 or y[i] == 1:
                neurites += 1
                neurite_correct += true
                break

        # Guarded ratios: at step 0 no positive may have been seen yet (neurites == 0),
        # which previously raised ZeroDivisionError.
        if step % 10 == 0:
            print('seen acc:', _ratio(seen_correct, seen), 'neurite acc:', _ratio(neurite_correct, neurites))
            print(tester.get_stats())
        if step > 1000:
            break

print('seen acc:', _ratio(seen_correct, seen))
print('neurite acc:', _ratio(neurite_correct, neurites))
print(tester.get_stats())


In [None]:
import torch
from proofreader.utils.vis import *
import numpy as np

# for i in range(1,6):
#     print('NUM SLICES: ', i)
# Load the saved (X, Y) pair. torch.load unpickles its input, so only load trusted files.
dataset_dir = '/mnt/home/jberman/ceph/pf/dataset'
path = f'{dataset_dir}/ns=3_r=128_cs=2_np=2048_dataset_test.pt'
X, Y = torch.load(path)


In [None]:
# Distribution of len() over the elements of X (presumably candidates per saved group).
lens = np.array([len(b) for b in X])
n_groups = len(lens)

cutoff = 8
# Print the fraction of groups of each size 0..14.
for i in range(15):
    p = np.count_nonzero(lens == i) / n_groups if n_groups else 0.0
    print(i, p)
# Fraction of groups with more than `cutoff` entries, computed over ALL sizes; the old
# version only summed sizes 0..14 and silently dropped groups of size 15+.
p_sum = np.count_nonzero(lens > cutoff) / n_groups if n_groups else 0.0
print(cutoff, p_sum)
make_histogram(lens, bins=50)

In [None]:
# Count nonzero entries in column 0 of every label tensor in Y (presumably the positive
# labels) and report that count and its fraction of all rows.
ts = 0
total = 0
for y in Y:
    ts += y[:, 0].count_nonzero().item()
    total += len(y[:, 0])

print(ts, ts / total if total else 0.0)

In [None]:
from proofreader.model.config import *
# Merge the candidate batches of the test set into one flat dataset.
# Column 1 of each label tensor is overwritten with the index of its batch, so merged
# candidates can still be grouped by source batch. Note: this mutates the tensors in Y.
for i in range(len(X)):
    Y[i][:, 1] = i

test_dataset_merged = SimpleDataset(torch.cat(X), torch.cat(Y), shuffle=True)


In [None]:
from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset=test_dataset_merged,
    batch_size=256,
    shuffle=True,
    drop_last=False,  # keep the final partial batch so every candidate is included
)

In [None]:
from proofreader.model.classifier import *
# Collect labels (column 0), batch ids (column 1) and predictions for every candidate.
# NOTE(review): predictions are torch.zeros_like(labels), i.e. a constant all-zero
# baseline — the model is never called in this cell. Confirm this is intended.
label_parts, pred_parts, bid_parts = [], [], []
for x, y in dataloader:
    bid_parts.append(y[:, 1])
    labels = y[:, 0]
    label_parts.append(labels)
    pred_parts.append(torch.zeros_like(labels))

ys = torch.cat(label_parts)
preds = torch.cat(pred_parts)
bids = torch.cat(bid_parts)

# Per source batch: 'neurite_acc' averages get_accuracy's 'perfect' value over batches;
# 'seen_acc' is 'total_acc' weighted by the batch's candidate count, divided by dataset size.
uids = np.unique(bids)
batch_acc = {'neurite_acc': 0, 'seen_acc': 0}
print(uids)
for uid in uids:
    mask = bids == uid
    group_ys, group_preds = ys[mask], preds[mask]
    print(group_ys, group_preds)
    accs = get_accuracy(group_ys, group_preds, ret_perfect=True)
    batch_acc['neurite_acc'] += accs['perfect']
    batch_acc['seen_acc'] += accs['total_acc'] * len(group_ys)

batch_acc['neurite_acc'] /= len(uids)
batch_acc['seen_acc'] /= len(test_dataset_merged)
print(len(uids), len(test_dataset_merged))
print(batch_acc)

In [None]:
import random
import torch
# Saved groups, each consumed by the next cell as (inputs, labels, predictions).
wrong_path = '../../wrong.t'
W = torch.load(wrong_path)

In [None]:
# Plot every candidate of one randomly chosen saved group, titled with prediction vs label.
# random.randrange excludes the upper bound; the previous random.randint(0, len(W)) is
# inclusive and could return len(W), raising IndexError.
i = random.randrange(len(W))
bx, by, bp = W[i]
print(i)
for j in range(len(bx)):
    title = f'pred: {bp[j]} true: {by[j]}'
    plot_example(bx[j], by[j], title=title)

In [2]:
import torch
from proofreader.utils.vis import *
import numpy as np

dataset_dir = '/mnt/home/jberman/ceph/pf/dataset'
path = f'{dataset_dir}/DATASET_m=False_ns=1_cs=3_r=96_np=2048_t=0_sc=1000_test.pt'
# The file holds an (X, Y) pair of tensor sequences; concatenate each into one tensor.
# torch.load unpickles its input, so only load trusted files.
X, Y = map(torch.cat, torch.load(path))


In [15]:
# Sanity check on samples 200..299: split each point cloud into the points before and
# after the first gap in its sorted unique z values, and print the resulting 0/1 masks.
num_slices = 1
new_cs = 2  # NOTE(review): num_slices and new_cs are not used in this cell
for i, x in enumerate(X[200:300]):
    x = x.numpy()
    z_cords = x[0, :]  # row 0 is treated as the z coordinate of every point
    unique_z = np.unique(z_cords)  # sorted ascending
    # Positions j where unique_z[j] + 1 != unique_z[j + 1], i.e. consecutive z values
    # that are not one step apart.
    gaps = np.flatnonzero(unique_z[:-1] + 1 != unique_z[1:])
    if gaps.size == 0:
        print(f'could not find drop start for zs {unique_z}, coudl be issue with data')
        # No gap means no meaningful top/bottom split. Previously drop_start stayed -1 and
        # the code silently labelled only the last z value as "bottom".
        continue
    drop_start = gaps[0] + 1

    top_z, bot_z = unique_z[:drop_start], unique_z[drop_start:]
    top_labels = np.isin(z_cords, top_z).astype(float)
    bot_labels = np.isin(z_cords, bot_z).astype(float)

    print(top_labels)
    print(bot_labels)
    print(bot_labels.shape)




[1. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 1. 1. 1.]
(2048,)
[0. 0. 0. ... 1. 0. 1.]
[1. 1. 1. ... 0. 1. 0.]
(2048,)
[0. 1. 1. ... 1. 1. 1.]
[1. 0. 0. ... 0. 0. 0.]
(2048,)
[0. 1. 0. ... 0. 0. 0.]
[1. 0. 1. ... 1. 1. 1.]
(2048,)
[0. 0. 0. ... 1. 1. 0.]
[1. 1. 1. ... 0. 0. 1.]
(2048,)
[1. 1. 1. ... 1. 0. 1.]
[0. 0. 0. ... 0. 1. 0.]
(2048,)
[0. 0. 0. ... 1. 0. 0.]
[1. 1. 1. ... 0. 1. 1.]
(2048,)
[1. 1. 1. ... 1. 1. 1.]
[0. 0. 0. ... 0. 0. 0.]
(2048,)
[1. 1. 0. ... 1. 1. 1.]
[0. 0. 1. ... 0. 0. 0.]
(2048,)
[0. 1. 1. ... 0. 1. 0.]
[1. 0. 0. ... 1. 0. 1.]
(2048,)
[0. 1. 0. ... 0. 1. 0.]
[1. 0. 1. ... 1. 0. 1.]
(2048,)
[1. 1. 1. ... 1. 1. 1.]
[0. 0. 0. ... 0. 0. 0.]
(2048,)
[1. 0. 0. ... 0. 1. 0.]
[0. 1. 1. ... 1. 0. 1.]
(2048,)
[0. 0. 0. ... 1. 0. 1.]
[1. 1. 1. ... 0. 1. 0.]
(2048,)
[0. 1. 0. ... 0. 0. 0.]
[1. 0. 1. ... 1. 1. 1.]
(2048,)
[0. 1. 1. ... 0. 0. 0.]
[1. 0. 0. ... 1. 1. 1.]
(2048,)
[0. 0. 0. ... 0. 1. 0.]
[1. 1. 1. ... 1. 0. 1.]
(2048,)
[1. 1. 0. ... 0. 0. 0.]
[0. 0. 1. ... 1. 1. 1.]


In [None]:
from proofreader.data.cremi import *
path = '/mnt/home/jberman/sc/proofreader/dataset/cremi'
# Segmentation volumes for CREMI samples A, B and C (read in that order).
trueA, trueB, trueC = (read_cremi_volume(name, seg=True, path=path) for name in 'ABC')

# Hold out the first 16 entries along axis 0 (presumably z-slices) for testing; .copy()
# makes the splits independent arrays rather than views into the full volumes.
# Original note: "A is clean".
trueA_test, trueA_train = trueA[:16].copy(), trueA[16:].copy()
trueB_test, trueB_train = trueB[:16].copy(), trueB[16:].copy()


In [None]:
from proofreader.data.augment import Augmentor
from proofreader.data.splitter import SliceDataset
num_slices = 1
radius = 96
context_slices = 3
num_points = 2048
scale = 1000
# z extent used for normalisation: context slices on both sides plus the core slices,
# minus one (presumably the z-span of one sample).
z_amt = num_slices + 2 * context_slices - 1
augmentor = Augmentor(center=True, shuffle=True, normalize=(z_amt, scale, scale))
dataset = SliceDataset(
    [trueB_train],
    num_slices,
    radius,
    context_slices,
    num_points=num_points,
    Augmentor=augmentor,
    truncate_candidates=0,
    candidate_group=False,
    scale=scale,
    verbose=False,
    allow_multiple=True,
)
itr = iter(dataset)

In [None]:
# Draw the next (x, y) sample from the iterator and plot it; lim=0.5 because the
# coordinates are normalised by the augmentor.
x, y = example = next(itr)
plot_example(x, y=y, lim=0.5)

In [None]:
# Shape of the sample, then the min/max of each of its first three rows.
print(x.shape)
for axis in range(3):
    print(torch.min(x[axis]), torch.max(x[axis]))