In [None]:
import os
import nrrd
import numpy as np
from pathlib import Path
from typing import Tuple, List, Union, Set, Dict
from collections import defaultdict
from tqdm import tqdm
import json

In [None]:
def get_pos_idxs(x:np.ndarray) -> np.ndarray:
    """Return the coordinates of every non-zero element of ``x``.

    The result has one row per non-zero element and one column per axis
    of ``x``.
    """
    per_axis = np.nonzero(x)
    return np.stack(per_axis, axis=1)

def get_foreground_idxs(targs:np.ndarray) -> np.ndarray:
    """Return the coordinates of all non-zero (foreground) entries of ``targs``."""
    foreground_idxs = get_pos_idxs(targs)
    return foreground_idxs

def get_background_idxs(targs:np.ndarray, heart_mask:np.ndarray) -> np.ndarray:
    """Return coordinates of voxels that are neither foreground nor in the heart mask.

    Args:
        targs: target mask; expected to hold only 0/1 values because
            ``1 - targs`` is used as its complement.
        heart_mask: heart mask with the same shape and the same 0/1 convention.

    Returns:
        Integer array with one row per background voxel and one column per axis.
    """
    # A voxel is background only when BOTH masks are 0.  The product used to be
    # computed and then discarded: the lookup ran on ``1 - heart_mask`` alone,
    # so foreground voxels outside the heart mask were reported as background.
    background = (1-targs) * (1-heart_mask)
    return get_pos_idxs(background)

def filter_invalid_idxs(idxs:np.ndarray, patch_size:int, vol_shape: Tuple[int,int,int]) -> np.ndarray:
    """Keep only the centres whose patch lies completely inside the volume.

    A centre ``c`` is kept when ``c - patch_size / 2 >= 0`` and
    ``c + patch_size / 2 <= vol_shape`` hold on every axis.  The input array
    is left untouched; a new array is returned.
    """
    half = patch_size / 2
    low_edge = idxs - half
    high_overshoot = idxs + half - vol_shape
    inside = (low_edge >= 0).all(axis=1) & (high_overshoot <= 0).all(axis=1)
    return idxs[inside]

def filter_containing_foregroung(targs:np.ndarray, idxs:np.ndarray, patch_size:int, n_samples:int=None, thresh:float=0.0) -> np.ndarray:
    """Randomly pick patch centres whose patch contains (almost) no foreground.

    Args:
        targs: target mask the patches are cut from.
        idxs: candidate centre coordinates, one row per candidate.
        patch_size: edge length forwarded to ``get_patch_bbox``.
        n_samples: number of centres to return; a falsy value means
            "as many as qualify".
        thresh: a patch qualifies when ``targs[bbox].sum() <= thresh``.

    Returns:
        Array with one row per selected centre.  It holds at most
        ``n_samples`` distinct centres, and fewer only when not enough
        candidates qualify.  An empty selection keeps the trailing shape of
        ``idxs`` so callers can still ``hstack`` a label column.
    """
    n_samples = n_samples or len(idxs)

    # Visit every candidate at most once, in random order.  Re-drawing until
    # ``n_samples`` hits were collected never terminated when too few
    # candidates qualified, and ``np.random.choice(..., replace=False)``
    # raised when more samples than candidates were requested.
    selected = []
    seen = set()
    for pos in np.random.permutation(len(idxs)):
        if len(selected) >= n_samples:
            break
        coord = idxs[pos]
        key = tuple(coord)
        if key in seen:
            continue
        seen.add(key)
        if targs[get_patch_bbox(coord, patch_size)].sum() <= thresh:
            selected.append(coord)

    if not selected:
        return np.empty((0, *idxs.shape[1:]), dtype=idxs.dtype)
    return np.array(selected)

def get_patch_bbox(center_coords:np.ndarray, patch_size:int) -> Tuple[slice, slice, slice]:
    """Build the three slices that cut a cubic patch around ``center_coords``.

    Each axis spans ``c - patch_size / 2`` to ``c + patch_size / 2`` with both
    bounds truncated to integers, so the upper bound is exclusive.
    """
    half = patch_size / 2
    starts = (center_coords - half).astype(int)
    stops = (center_coords + half).astype(int)
    (x0, y0, z0), (x1, y1, z1) = starts, stops
    return slice(x0, x1), slice(y0, y1), slice(z0, z1)

def get_vol_id(vol_path):
    """Parse the integer id from the part of the file name before the first dot."""
    stem, _, _ = vol_path.name.partition('.')
    return int(stem)

def get_paths(folder:Union[str, Path], sort_fn=lambda x: x) -> List:
    """List the entries of ``folder`` as ``folder``-prefixed paths, ordered by ``sort_fn``."""
    entries = []
    for name in os.listdir(Path(folder)):
        entries.append(Path(folder, name))
    entries.sort(key=sort_fn)
    return entries

def get_vol_paths(vol_dir:str,
                  vol_subdir:str='Train',
                  targ_subdir:str='Train_Masks',
                  heart_mask_subdir='Train_heart_mask') -> List:
    """Collect ``(vol_id, volume, target, heart mask)`` path tuples under ``vol_dir``.

    Each sub-directory is listed and ordered by ``get_vol_id``.  The three
    listings are then paired position by position, and the id of every tuple
    is taken from the volume path.
    """
    vol_paths, targ_paths, heart_mask_paths = [
        get_paths(Path(vol_dir, subdir), sort_fn=get_vol_id)
        for subdir in (vol_subdir, targ_subdir, heart_mask_subdir)
    ]
    return [
        (get_vol_id(vol_path), vol_path, targ_path, heart_mask_path)
        for vol_path, targ_path, heart_mask_path
        in zip(vol_paths, targ_paths, heart_mask_paths)
    ]

def sample_hard_coords(targs: np.ndarray,
                       idxs: np.ndarray,
                       patch_size:int,
                       blacklist: Set[Tuple[int, ...]],
                       n_samples:int,
                       thresh:float=0.001) -> np.ndarray:
    """Randomly pick non-blacklisted centres whose patch holds (almost) no foreground.

    Args:
        targs: target mask the patches are cut from.
        idxs: candidate centre coordinates, one row per candidate.
        patch_size: edge length forwarded to ``get_patch_bbox``.
        blacklist: coordinate tuples that must not be returned.
        n_samples: maximum number of centres to return.
        thresh: a patch qualifies when ``targs[bbox].sum() <= thresh``.

    Returns:
        Array with one row per selected centre.  It holds at most
        ``n_samples`` distinct centres, and fewer only when not enough
        candidates qualify.  An empty selection keeps the trailing shape of
        ``idxs`` so callers can still ``hstack`` a label column.
    """
    # Visit every candidate at most once, in random order.  Re-drawing until
    # ``n_samples`` hits were collected looped forever when too few candidates
    # qualified, and ``np.random.choice(..., replace=False)`` raised when more
    # samples than candidates were requested.
    selected = []
    seen = set()
    for pos in np.random.permutation(len(idxs)):
        if len(selected) >= n_samples:
            break
        coord = idxs[pos]
        key = tuple(coord)
        if key in blacklist or key in seen:
            continue
        seen.add(key)
        if targs[get_patch_bbox(coord, patch_size)].sum() <= thresh:
            selected.append(coord)

    if not selected:
        return np.empty((0, *idxs.shape[1:]), dtype=idxs.dtype)
    return np.array(selected)

def get_pos_coords(targs: np.ndarray, patch_size:int, n_samples:int) -> np.ndarray:
    """Sample foreground patch centres and tag them with label 1.

    Args:
        targs: target mask; its non-zero voxels are the candidate centres.
        patch_size: edge length used to discard centres whose patch would
            leave the volume.
        n_samples: number of centres to draw without replacement.

    Returns:
        Float array with one row per sampled centre: the coordinates followed
        by a label column of ones.  Fewer than ``n_samples`` rows are returned
        when the volume holds fewer valid foreground centres.
    """
    coords = get_foreground_idxs(targs)
    coords = filter_invalid_idxs(coords, patch_size, targs.shape)
    # Drawing without replacement raises when more items are requested than
    # exist, so cap the request at the number of valid candidates.
    n_draw = min(n_samples, len(coords))
    sampled = np.random.choice(len(coords), n_draw, replace=False)
    coords = coords[sampled]
    labels = np.ones((len(coords), 1))
    return np.hstack([coords, labels])

def get_neg_coords(targs:np.ndarray, heart_mask:np.ndarray, patch_size:int, n_samples:int) -> np.ndarray:
    """Sample background patch centres free of foreground and tag them with label 0.

    Candidates come from ``get_background_idxs``, are restricted to centres
    whose patch fits inside ``targs`` and are then passed through
    ``filter_containing_foregroung``.  The result holds the coordinates
    followed by a label column of zeros.
    """
    candidates = filter_invalid_idxs(
        get_background_idxs(targs, heart_mask), patch_size, targs.shape)
    negatives = filter_containing_foregroung(targs, candidates, patch_size, n_samples=n_samples)
    zero_labels = np.zeros((len(negatives), 1))
    return np.hstack([negatives, zero_labels])

def get_vol_hard_mask(vol:np.ndarray, targs:np.ndarray, heart_mask:np.ndarray,
                      low_pct:float=5.0, high_pct:float=95.0) -> np.ndarray:
    """Mark background voxels whose intensity lies strictly inside a percentile band.

    Args:
        vol: intensity volume.
        targs: target mask (0/1 expected, ``1 - targs`` is its complement).
        heart_mask: heart mask with the same 0/1 convention.
        low_pct: lower percentile of the band, on ``np.percentile``'s
            0-100 scale.
        high_pct: upper percentile of the band, on the same scale.

    Returns:
        Array shaped like ``vol`` holding 1 where the voxel is outside both
        masks, has a non-zero intensity and lies strictly between the two
        percentiles, and 0 everywhere else.
    """
    # Zero out everything that is foreground or inside the heart mask.
    background_vol = vol * (1-targs) * (1-heart_mask)

    # np.percentile expects q in [0, 100].  The thresholds used to be 0.05 and
    # 0.95, which select a sliver at the very bottom of the distribution
    # rather than the 5th-95th percentile band.
    low, high = np.percentile(background_vol, [low_pct, high_pct])
    in_band = (background_vol != 0) & (background_vol > low) & (background_vol < high)

    # Return a real 0/1 mask.  Writing 1 into the intensity volume left all
    # out-of-band voxels with their raw non-zero values, so a non-zero lookup
    # on the result selected them as well.
    return in_band.astype(background_vol.dtype)

def get_hard_neg_coords(vol_hard_mask:np.ndarray,
                        targs:np.ndarray,
                        already_sampled: Set[Tuple[int,int,int]],
                        patch_size:int,
                        n_samples:int):
    """Sample hard-negative patch centres and tag them with label 0.

    Candidates are the non-zero voxels of ``vol_hard_mask`` whose patch fits
    inside the volume.  ``sample_hard_coords`` then draws from them while
    skipping everything in ``already_sampled``.  The result holds the
    coordinates followed by a label column of zeros.
    """
    candidates = filter_invalid_idxs(
        get_foreground_idxs(vol_hard_mask), patch_size, vol_hard_mask.shape)
    hard_negatives = sample_hard_coords(
        targs, candidates, patch_size, already_sampled, n_samples=n_samples)
    zero_labels = np.zeros((len(hard_negatives), 1))
    return np.hstack([hard_negatives, zero_labels])

def normalize_vols(vol_paths:List[Union[Path,str]], output_dir:Union[Path, str], stats:Dict):
    """Clip and standardise every volume and save it as ``<vol_id>.npy``.

    ``vol_paths`` is iterated as 4-tuples whose first two items are the volume
    id and the path of the ``.nrrd`` volume.  Each volume is clipped to
    ``stats['percentile_00_5']``..``stats['percentile_99_5']``, shifted by
    ``stats['mean']`` and divided by ``stats['std']`` before being written to
    ``output_dir`` (created when missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    for vol_id, src_path, _targ_path, _heart_path in tqdm(vol_paths):
        raw, _header = nrrd.read(src_path, index_order='C')
        clipped = np.clip(raw, stats['percentile_00_5'], stats['percentile_99_5'])
        normalized = (clipped - stats['mean']) / stats['std']
        destination = Path(output_dir, f'{vol_id}.npy')
        np.save(destination, normalized)

def get_patch_coords(vol_paths:List, patch_size:int, n_patches:int=100000):
    """Sample labelled patch centres for every volume.

    The per-volume budget is ``n_patches // len(vol_paths)``: half of it is
    requested as positives, a quarter as plain negatives and a quarter as
    hard negatives.  Hard negatives are drawn so that they do not repeat a
    plain negative.

    Returns:
        Dict mapping each volume id to an integer array whose rows are the
        centre coordinates followed by the label, ordered positives, negatives,
        hard negatives.
    """
    assert n_patches % 4 == 0
    n_per_vol = n_patches // len(vol_paths)
    n_pos, n_neg, n_hard = n_per_vol//2, n_per_vol//4, n_per_vol//4

    coords_by_vol = {}
    for vol_id, vol_path, targ_path, heart_mask_path in tqdm(vol_paths):
        vol = np.load(vol_path)
        targs_raw, _ = nrrd.read(targ_path, index_order='C')
        heart_raw, _ = nrrd.read(heart_mask_path, index_order='C')
        targs = targs_raw.astype(np.uint8)
        heart_mask = heart_raw.astype(np.uint8)

        pos_coords = get_pos_coords(targs, patch_size, n_samples=n_pos)
        neg_coords = get_neg_coords(targs, heart_mask, patch_size, n_neg)

        # Coordinates (label column dropped) of the plain negatives, so the
        # hard-negative sampler can avoid them.
        already_sampled = {tuple(row) for row in neg_coords[:,:-1].tolist()}
        vol_hard_mask = get_vol_hard_mask(vol, targs, heart_mask)
        hard_neg_coords = get_hard_neg_coords(
            vol_hard_mask, targs, already_sampled, patch_size, n_hard)

        stacked = np.vstack((pos_coords, neg_coords, hard_neg_coords))
        coords_by_vol[vol_id] = stacked.astype(int)
    return coords_by_vol

In [None]:
# Dataset root; the volume, target-mask and heart-mask sub-directories are
# looked up underneath it by get_vol_paths.
root = '../dataset/raw/ASOCA2020Data/'
# Sub-directory of `root` that receives the normalised .npy volumes.
processed_subdir = 'processed'

In [None]:
# Edge length of the cubic patches cut around each sampled centre.
patch_size = 68
# Total number of patch centres requested from get_patch_coords.
n_patches = 100000
# Volume ids assigned to the 'valid' split; every other id goes to 'train'.
valid_split = [1, 9, 13, 19, 22, 28, 38, 39]
# Clipping bounds and mean/std that normalize_vols applies to every volume.
stats = {
        'mean': 347.14618,
        'std': 120.35282,
        'percentile_00_5': 95.0,
        'percentile_99_5': 698.0,
}

In [None]:
vol_paths = get_vol_paths(root)

In [None]:
# Normalise only when the processed directory is missing or empty.
# (`os.listdir(...) == 0` compared a list with an int and was always False,
# so an existing but empty directory was never filled.)
if not Path(root, processed_subdir).is_dir() or len(os.listdir(Path(root, processed_subdir))) == 0:
    normalize_vols(vol_paths, Path(root, processed_subdir), stats)

In [None]:
vol_paths = get_vol_paths(root, vol_subdir=processed_subdir)

In [None]:
patch_idxs = get_patch_coords(vol_paths, patch_size, n_patches=n_patches)

In [None]:
# Create <processed>/<split>/vols for both splits.
for split in ['train', 'valid']:
    os.makedirs(Path(root, processed_subdir, split, 'vols'), exist_ok=True)
# Move every normalised volume into the directory of its split.
# NOTE(review): os.rename fails once the file has been moved, so this cell
# cannot be re-run as is; `vol_paths` also keeps pointing at the old locations.
for vol_id, vol_path, targ_path, _ in vol_paths:
    split = 'valid' if vol_id in valid_split else 'train'
    os.rename(Path(root, processed_subdir, f'{vol_id}.npy'), Path(root, processed_subdir, split, 'vols', f'{vol_id}.npy'))

In [None]:
# Metadata written next to the processed volumes.
dataset = {
    # Reuse the statistics the volumes were normalised with instead of
    # repeating the literal values, so the two can never drift apart.
    'stats' : dict(stats),
    'patch_size': patch_size,
    'patch_stride': 1,
    # Requested number of patches; the per-volume counts are in 'n_patches'.
    'N': n_patches,
    'vol_meta': {
        k: {
        'split': 'valid' if int(k) in valid_split else 'train',
        'n_patches': len(v),
        'patches': v.tolist()
    }
    for k,v in patch_idxs.items()}
}

In [None]:
# Build the target from `root`/`processed_subdir` so the file always lands
# next to the volumes it describes, even when those settings change.
with open(Path(root, processed_subdir, 'dataset.json'), 'w') as f:
    json.dump(dataset, f)

In [None]:
# Print every sampled patch whose centre has a zero coordinate, together with
# its label.  Iterate over the ids that are actually present: a hard-coded
# range(40) raises KeyError as soon as the set of volumes differs.
for vol_id, vol_entry in dataset['vol_meta'].items():
    for i, patch in enumerate(vol_entry['patches']):
        if patch[0] == 0 or patch[1] == 0 or patch[2] == 0:
            print(vol_id, i, patch[-1])

In [None]:
# Sanity check: accumulate, separately for label 1 and label 0, how many
# patches were sampled and how much target mask their patches contain.
# The totals are only stored; nothing is displayed or asserted here.
# wrong_label = []
pos_sum = 0
pos_count = 0
neg_sum = 0
neg_count = 0
for vol_id, coords in patch_idxs.items():
    # Last column is the label, the preceding columns are the centre.
    idxs, labels = coords[...,:-1], coords[...,-1]
    mask, _ = nrrd.read(Path(root, 'Train_Masks', f'{vol_id}.nrrd'), index_order='C')
    for i, idx in enumerate(idxs):
        bbox = get_patch_bbox(idx, patch_size)
        patch = mask[bbox]
        if labels[i] == 1:
            pos_count += 1
            pos_sum += patch.sum()
        elif labels[i] == 0:
            neg_count += 1
            neg_sum += patch.sum()
# assert len(wrong_label) == 0

In [None]:
# The normalised volumes were moved into <processed>/<split>/vols by the split
# cell above, so the file no longer sits directly under the processed folder.
vol = np.load(Path(root, processed_subdir, 'valid' if 10 in valid_split else 'train', 'vols', '10.npy'))
mask, _ = nrrd.read(Path(root, 'Train_Masks', '10.nrrd'), index_order='C')

In [None]:
dataset['vol_meta'][29]['patches'][-1]

In [None]:
bbox = get_patch_bbox(np.array([64, 203, 459]), patch_size)

In [None]:
patch = vol[bbox]
patch_m = mask[bbox]

In [None]:
import k3d

In [None]:
patch_m.mean()

In [None]:
# Volume-render the intensity patch cut out above.
k3d_volume = k3d.volume(
    patch.astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

plot.lighting = 2
plot.display()

In [None]:
import sys
sys.path.append('..')
from data_utils.datamodule import AsocaClassificationDataModule
from data_utils.helpers import get_volume_pred
from data_utils.helpers_classification import get_patch_bbox
import torch

In [None]:
adm = AsocaClassificationDataModule(data_dir='../dataset/classification', sourcepath='../dataset/ASOCA2020Data.zip')

In [None]:
adm.prepare_data()

In [None]:
dl = adm.train_dataloader()

In [None]:
# Collect the first `n` samples with label 1 from the training dataloader.
n = 10
# NOTE(review): torch.empty leaves these buffers uninitialised; if the loader
# yields fewer than `n` positives the tail holds arbitrary values.
x = torch.empty(n,68,68,68)
y = torch.empty(n)
meta = []
cur = 0
for i, batch in enumerate(dl):
    # `.item()` only works on a single-element tensor, so this presumably
    # relies on a batch size of 1 — TODO confirm against the datamodule.
    if batch[1].item() == 1:
        x[cur], y[cur] = batch[:2]
        meta.append(batch[2])
        cur += 1
    if cur == n: break
    

In [None]:
y

In [None]:
meta

In [None]:
import nrrd

In [None]:
mask, _ = nrrd.read('../dataset/raw/ASOCA2020Data/Train_Masks/3.nrrd', index_order='C')

In [None]:
import json

In [None]:
with open('../dataset/classification/dataset.json', 'r') as f:
    meta = json.load(f)

In [None]:
mask_patch = mask[get_patch_bbox(np.array(meta['vol_meta']['3']['patches'][461][:-1]), 68)]

In [None]:
import numpy as np
import k3d

# Render one collected sample (x[7]) and add the mask patch loaded above to
# the same plot in a red colour map.
k3d_volume = k3d.volume(
    x[7].numpy().astype(np.float32),
    alpha_coef=15,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

# Second volume: the target-mask patch.
k3d_volume = k3d.volume(
    mask_patch.astype(np.float32),
    alpha_coef=1000,
#     shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Reds,
    compression_level=9
)

plot += k3d_volume

plot.lighting = 2
plot.display()

In [None]:
y

In [None]:
dl, meta = adm.volume_dataloader(0, batch_size=1)

In [None]:
patches = torch.cat(list(iter(dl)))
volume_rec = get_volume_pred(patches, meta, [128,128,128], [92,92,92], normalize=False); volume_rec.shape

In [None]:
heart_mask, _ = nrrd.read(Path(root, 'Train_heart_mask', '0.nrrd'), index_order='C')

In [None]:
bg = 1 - heart_mask

In [None]:
volume_rec_bg = volume_rec * bg

In [None]:
# Highlight background voxels whose intensity lies strictly between the 5th
# and 95th percentile by setting them to 1 in a copy of the background volume.
# np.percentile expects q in [0, 100]; 0.05 / 0.95 selected a sliver at the
# very bottom of the distribution instead of that band.
volume_hard_mask = volume_rec_bg.copy()
volume_hard_mask[(
    (volume_hard_mask != 0) &
    (volume_hard_mask > np.percentile(volume_hard_mask, 5)) &
    (volume_hard_mask < np.percentile(volume_hard_mask, 95))
)] = 1

In [None]:
volume_rec_bg[volume_rec_bg == 0] = volume_rec_bg.min()

In [None]:
k3d_volume = k3d.volume(
    volume_hard_mask[::4,::4,::4].astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

plot.lighting = 2
plot.display()

In [None]:
# NOTE(review): `Baseline3DCNN` is not imported in any cell visible in this
# notebook — confirm where it comes from before running.  The checkpoint path
# is absolute and machine-specific.
model = Baseline3DCNN.load_from_checkpoint('/var/scratch/ebekkers/damyan/models/cnn-baseline-epoch=49-step=7109.ckpt', arch='strided')

In [None]:
meta['n_patches']

In [None]:
# Run the model over every patch of the volume and keep the rounded sigmoid
# output of each one.
model.eval()
model.cuda()
res = torch.empty((meta['n_patches'], *[92,92,92]))
# Inference only: without no_grad every forward pass builds an autograd graph
# that is kept alive until the output is detached, wasting GPU memory.
with torch.no_grad():
    for i, x in enumerate(dl):
        preds = model(x.cuda())
        preds = torch.sigmoid(preds).round()
        res[i] = preds.detach().cpu()

In [None]:
pred_rec = get_volume_pred(res, meta, [128,128,128], [92,92,92], normalize=False)

In [None]:
k3d_volume = k3d.volume(
    pred_rec[::4,::4,::4].astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

plot.lighting = 2
plot.display()

In [None]:
center_coords = np.array([
    [1,2,3],
    [4,5,6],
    [7,8,9]
])

In [None]:
bboxo = np.array([
    center_coords - np.ceil(np.array([5,5,5])/2),
    center_coords + np.ceil(np.array([5,5,5])/2),
]).T

In [None]:
bboxo.shape

In [None]:
[ col for row in bboxo for col in row ]

In [None]:
import sys
sys.path.append('..')
from models.classification.cnn import Baseline3DClassification
from data_utils.datamodule import AsocaClassificationDataModule
from data_utils.helpers_classification import get_patch_bbox
import torch
import numpy as np

In [None]:
adm = AsocaClassificationDataModule(data_dir='../dataset/classification', sourcepath='../dataset/ASOCA2020Data.zip')

In [None]:
bs = 16

In [None]:
dl = adm.val_dataloader(batch_size=bs)

In [None]:
model = Baseline3DClassification.load_from_checkpoint('../wandb/run-20210725_111105-1jyukl95/files/asoca/1jyukl95/checkpoints/epoch=1-step=313.ckpt')

In [None]:
model = torch.nn.DataParallel(model)

In [None]:
# Accuracy of the rounded sigmoid predictions on (at most) the first 100
# validation batches.
model.eval()
model.cuda();

# Collect per-batch results and concatenate them afterwards.  Fixed 1000-slot
# buffers overflowed once more than 1000 samples were seen, and the trailing
# `[:-8]` trim was only right for one particular dataset size.
pred_batches = []
targ_batches = []

with torch.no_grad():
    for i, (x, t) in enumerate(dl):
        if i >= 100: break
        x = x.cuda()
        y = torch.sigmoid(model(x)).round()
        targ_batches.append(t.cpu().reshape(-1).numpy())
        pred_batches.append(y.cpu().reshape(-1).numpy())

preds = np.concatenate(pred_batches)

targs = np.concatenate(targ_batches)

(targs == preds).mean()

In [None]:
import nrrd
from tqdm import tqdm
import json
from data_utils.helpers_classification import get_foreground_idxs, filter_invalid_idxs

In [None]:
vol, _ = nrrd.read('../dataset/raw/ASOCA2020Data/Train/1.nrrd', index_order='C')
mask, _ = nrrd.read('../dataset/raw/ASOCA2020Data/Train_Masks/1.nrrd', index_order='C')

In [None]:
dims = vol.shape
dims_max = dims - np.array([57,150,150])
dims_min = np.array([57,150,150])

In [None]:
pred_shape = dims_max-dims_min

targ_center = np.array([112, 239, 142])

targs = mask[get_patch_bbox(targ_center, 68)]

In [None]:
targs = mask[
    dims_min[0]:dims_max[0],
    dims_min[1]:dims_max[1],
    dims_min[2]:dims_max[2],
]

In [None]:
preds = np.load('../class_vol_preds_136_6201.npy')

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(preds.sum(axis=1)[:-14].reshape(-1, 28))

preds = np.concatenate((preds.flatten(), np.zeros(32))).reshape((68,68,68))

In [None]:
# One row per voxel centre inside [dims_min, dims_max) on every axis.
# With indexing='ij' followed by `.T`, the rows are ordered with the last
# axis varying slowest and the first axis varying fastest.
voxels = np.stack(np.meshgrid(
    np.arange(dims_min[0], dims_max[0]),
    np.arange(dims_min[1], dims_max[1]),
    np.arange(dims_min[2], dims_max[2]),
    indexing='ij'
)).T.reshape((-1,3))

In [None]:
np.stack(np.meshgrid(
    np.arange(dims_min[0], dims_max[0]),
    np.arange(dims_min[1], dims_max[1]),
    np.arange(dims_min[2], dims_max[2]),
    indexing='ij'
)).shape

In [None]:
pred_shape

In [None]:
preds = np.concatenate((preds.flatten(), np.zeros(136))).reshape((212,69,212))

In [None]:
(preds * targs).mean()

In [None]:
import k3d
k3d_volume = k3d.volume(
    preds.transpose(1,0,2).astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

k3d_volume = k3d.volume(
    targs.astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Greens,
    compression_level=9
)


# plot += k3d_volume


plot.lighting = 2
plot.display()

In [None]:
8**3 / 68**3

In [None]:
with open('../dataset/classification/dataset.json', 'r') as f:
    meta = json.load(f)

In [None]:
stats = meta['stats']

In [None]:
vol = np.clip(vol, stats['percentile_00_5'], stats['percentile_99_5'])
vol = (vol - stats['mean']) / stats['std']

In [None]:
dims = np.array(vol.shape)

In [None]:
dims_max = dims - np.array([57,150,150])

In [None]:
dims_min = np.array([57,150,150])

In [None]:
dims_min, dims_max

In [None]:
meta['vol_meta']['1']

In [None]:
center = np.array([112, 239, 142])
left = center - 34
right = center + 34

In [None]:
voxels = np.stack(np.meshgrid(
    np.arange(left[0], right[0]),
    np.arange(left[1], right[1]),
    np.arange(left[2], right[2])
)).T.reshape((-1,3))

In [None]:
# One row per voxel centre inside [dims_min, dims_max) on every axis.
# meshgrid defaults to indexing='xy' here, so after `.T` the rows are ordered
# with axis 2 varying slowest, then axis 0, and axis 1 varying fastest.
voxels = np.stack(np.meshgrid(
    np.arange(dims_min[0], dims_max[0]),
    np.arange(dims_min[1], dims_max[1]),
    np.arange(dims_min[2], dims_max[2])
)).T.reshape((-1,3))

In [None]:
bs = 600

In [None]:
# Number of trailing voxels that do not fill a complete batch.
remainder = len(voxels) % bs

# Slice by explicit length: `voxels[:-remainder]` drops EVERYTHING when the
# voxel count is an exact multiple of the batch size (remainder == 0).
voxels = voxels[:len(voxels) - remainder]

# Group the remaining voxels into batches of `bs` centres.
voxels = voxels.reshape(-1,bs,3)

In [None]:
voxels.shape

In [None]:
# One prediction per voxel, laid out as (n_batches, batch_size).
# np.empty_like(<tuple>) builds an array shaped like the tuple itself (two
# integers), not a buffer of that shape.
preds = np.empty(voxels.shape[:2])

In [None]:
# Classify the 68^3 patch around every voxel centre, one batch of `bs` centres
# at a time, and store the rounded sigmoid output per centre.
with torch.no_grad():
    # voxels has shape (n_batches, bs, 3), so this range covers every batch.
    for i in tqdm(range(0, voxels.shape[0]*voxels.shape[1] // bs)):
        x = np.empty((bs,68,68,68))
        for j in range(bs):
            x[j] = vol[get_patch_bbox(voxels[i][j], 68)]
        # Add a dimension of size 1 at position 1 before moving to the GPU.
        x = torch.from_numpy(x).float().unsqueeze(1).cuda()
        y = torch.sigmoid(model(x)).round()
        preds[i] = y.cpu().squeeze(-1).numpy()
        # Periodic checkpoint of everything predicted so far.
        if i % 500 == 0: np.save(f'class_vol_preds_{i}.npy', preds)

In [None]:
# Rebuild a volume from the per-voxel predictions.
# * `preds` is (n_batches, bs) and must be flattened before the dropped
#   remainder is padded back with zeros (concatenating 2-D with 1-D raises).
# * `voxels` was built with meshgrid's default 'xy' indexing followed by `.T`,
#   so the flat order is axis 2 slowest, then axis 0, then axis 1 fastest.
#   Reshape in that order and transpose back to (axis 0, axis 1, axis 2).
grid_shape = dims_max - dims_min
preds_vol = (
    np.concatenate((preds.ravel(), np.zeros(remainder)))
    .reshape((grid_shape[2], grid_shape[0], grid_shape[1]))
    .transpose(1, 2, 0)
)

In [None]:
import k3d
k3d_volume = k3d.volume(
    preds_vol.astype(np.float32),
    alpha_coef=1000,
    shadow='dynamic',
    samples=600,
    color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    compression_level=9
)

plot = k3d.plot(camera_auto_fit=True)
plot += k3d_volume

plot.lighting = 2
plot.display()

In [None]:
remainder

In [None]:
np.save(f'../class_vol_preds_{i}.npy', preds)

In [None]:
preds = np.load('../class_vol_preds_5167.npy')

In [None]:
preds.mean()

In [None]:
dims_max - dims_min

In [None]:
3100800 / (94*330*330)

In [None]:
preds.reshape((94,330,-1))

In [None]:
i

In [None]:
preds.sum()

In [None]:
np.save('preds.npy', preds)

In [None]:
patch = vol[get_patch_bbox(voxels[-1][120], 68)]

In [None]:
pred = model(torch.from_numpy(patch).float().unsqueeze(0).unsqueeze(0).cuda())

In [None]:
torch.sigmoid(pred).item()

In [None]:
voxels.shape

In [None]:
preds