In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import sys
sys.path.append('..')
from data_utils import AsocaDataModule
from models.base import Baseline3DCNN
from models.unet import UNet
import h5py
import nrrd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from data_utils.helpers import get_padding, get_patch_padding, vol2patches, patches2vol, get_volume_pred
import k3d

In [None]:
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Restore the trained UNet (weights + saved hyperparameters) from its checkpoint file.
model = UNet.load_from_checkpoint(
    checkpoint_path='../logs/unet-1618346108/version_0/checkpoints/epoch=9-step=7999.ckpt'
)

In [None]:
# Inference / patching configuration.
batch_size, patch_size = 20, 32
# The model's output patch side is the input side minus `model.crop` on each end;
# stepping by that size presumably makes the output patches tile the volume.
stride = output_dim = patch_size - 2 * model.crop
file_id = 39  # case id under ../dataset/Train and ../dataset/Train_Masks

In [None]:
# Data module over the HDF5 dataset, using the patching configuration above
# with normalisation disabled.
adm = AsocaDataModule(
    datapath='../dataset/asoca.hdf5',
    batch_size=batch_size,
    patch_size=patch_size,
    stride=stride,
    normalize=False,
)

In [None]:
# Patch dataloader for a single volume plus its metadata dict (provides 'n_patches').
[vol_dl, vol_meta] = adm.volume_dataloader(file_id)

In [None]:
# Raw volume and its ground-truth mask, both read in C index order
# (slowest-varying axis first).
data, header = nrrd.read(filename=f'../dataset/Train/{file_id}.nrrd', index_order='C')
targs, targs_header = nrrd.read(filename=f'../dataset/Train_Masks/{file_id}.nrrd', index_order='C')

In [None]:
def patching_test(data, patch_size, stride):
    """Check that cutting a volume into patches and stitching it back is lossless.

    Args:
        data: array-like volume accepted by ``torch.tensor``; converted to float32.
        patch_size: patch size passed to the patching helpers.
        stride: step between consecutive patches.

    Returns:
        bool: True when the reassembled volume is ``torch.allclose`` to the input.
    """
    volume = torch.tensor(data).float()
    pad = get_patch_padding(volume.shape, patch_size, stride)
    patches, patched_shape = vol2patches(volume, patch_size, stride, pad)

    # patches2vol expects the patches laid out in the patched grid shape.
    reconstructed = patches2vol(
        patches.view(patched_shape), patch_size, stride, padding=pad
    )
    return torch.allclose(volume, reconstructed)

In [None]:
# Sanity check: patch extraction followed by reassembly must round-trip the raw volume.
assert patching_test(data, patch_size=patch_size, stride=stride)

In [None]:
# Move the model's parameters and buffers to the inference device.
model = model.to(device=device)

In [None]:
# Run the model patch-by-patch over the volume and collect the per-patch
# sigmoid probabilities on the CPU, in dataloader order.
n_patches = vol_meta['n_patches']
preds = torch.empty((n_patches, output_dim, output_dim, output_dim))

# Inference: eval mode makes any dropout / BatchNorm layers behave
# deterministically, and no_grad avoids building an autograd graph per batch.
model.eval()
cur = 0
with torch.no_grad():
    for batch in vol_dl:
        x = batch[0].to(device)
        bs = x.shape[0]

        pred = torch.sigmoid(model(x)).squeeze(1).cpu()
        preds[cur:cur + bs] = pred
        cur += bs

# torch.empty leaves unwritten slots uninitialised, so make sure every patch was filled.
assert cur == n_patches, f'expected {n_patches} patches, got {cur}'

In [None]:
# Stitch the per-patch predictions back into a full volume.
# NOTE(review): this rebinds `preds` from the patch tensor to the stitched volume,
# so the cell is not idempotent — re-run the prediction cell before re-running this one.
preds = get_volume_pred(preds, vol_meta, stride)

In [None]:
# The stitched prediction must line up voxel-for-voxel with the raw volume.
assert preds.shape == data.shape

In [None]:
# For the 3D view: push probabilities below 0.5 far under the [0, 1] colour range.
# NOTE: modifies `preds` in place; the trailing `;` suppresses the tensor repr.
preds.masked_fill_(preds < 0.5, -10);

In [None]:
# For the 3D view: keep intensities in [0, 350] and send everything else to a
# -3000 sentinel below the plot's colour range. NOTE: mutates `data` in place.
data[(data < 0) | (data > 350)] = -3000

In [None]:
def add_k3d_volume(plot, volume, name, color_range, color_map, step=4):
    """Add a strided float32 copy of a 3D array to a k3d plot using shared render settings.

    Args:
        plot: k3d plot the volume is added to (via ``plot += ...``).
        volume: 3D NumPy array.
        name: object name passed to ``k3d.volume``.
        color_range: ``[min, max]`` range passed to ``k3d.volume``.
        color_map: k3d colour map.
        step: stride applied along every axis before plotting, to keep the widget light.

    Returns:
        The same plot, for chaining.
    """
    plot += k3d.volume(
        volume[::step, ::step, ::step].astype(np.float32),
        interpolation=False,
        name=name,
        alpha_coef=50,
        samples=600,
        compression=6,
        color_range=color_range,
        color_map=color_map,
    )
    return plot


plot = k3d.plot(camera_auto_fit=True, fps=30)

# NOTE(review): after clipping, `data` lies in [0, 350] (plus the -3000 sentinel),
# so a [0, 3000] colour range leaves most of the colour map unused — confirm intent.
add_k3d_volume(plot, data, 'input', [0, 3000],
               k3d.colormaps.matplotlib_color_maps.Coolwarm)
add_k3d_volume(plot, preds.numpy(), 'pred', [0, 1],
               k3d.colormaps.matplotlib_color_maps.Greens)
add_k3d_volume(plot, targs, 'target', [0, 1],
               k3d.colormaps.matplotlib_color_maps.Oranges)

plot.display()

In [None]:
from metrics import dice_score, hausdorff_95

In [None]:
# Per-axis voxel spacing for the distance-based metrics.
# pynrrd's index_order='C' reverses the *data* axes but returns the header as-is,
# so the rows of 'space directions' are in NRRD (fastest-axis-first) order; reverse
# them to match the array axes. Row norms (instead of the diagonal) stay positive
# for flipped axes and remain correct for non-axis-aligned directions.
# NOTE(review): assumes hausdorff_95 expects spacing in array-axis order — confirm.
spacing = np.ascontiguousarray(
    np.linalg.norm(np.asarray(targs_header['space directions'], dtype=float), axis=1)[::-1]
)

In [None]:
# Display the voxel spacing that is passed to hausdorff_95 below.
spacing

In [None]:
# Dice between the binarised prediction and the ground-truth mask.
# The prediction is thresholded at 0.5, the same cut-off used to hide voxels for the
# plot; voxels set to -10 there fall below the threshold and count as background.
# NOTE(review): the dtype dice_score expects is not visible here — float mask assumed.
dice_score((preds >= 0.5).float(), torch.tensor(targs))

In [None]:
%%time
hausdorff_95(res.contiguous(), torch.tensor(targs).contiguous(), spacing)