In [None]:
from spottunet.dataset.cc359 import *
from spottunet.split import one2all
from spottunet.torch.module.unet import UNet2D
from spottunet.utils import sdice
from dpipe.im.metrics import dice_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image

from monai import transforms as T
from monai.transforms import Compose, apply_transform
from fastprogress.fastprogress import master_bar, progress_bar

import json
import nibabel as nib
import pandas as pd
import numpy as np
from scipy import ndimage
from dpipe.im.shape_ops import zoom
import cv2
import os
import gc
from collections import defaultdict
from pathlib import Path
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

DEBUG=True

### Config & Logging

In [None]:
import wandb
from configs.config import CFG
from utils import *

def class2dict(f):
    """Map every attribute name of ``f`` that does not start with ``__`` to its value."""
    return {
        attr: getattr(f, attr)
        for attr in dir(f)
        if not attr.startswith('__')
    }

def class2str(f):
    """List ``[name, value]`` pairs for every attribute of ``f`` not starting with ``__``."""
    pairs = []
    for attr in dir(f):
        if attr.startswith('__'):
            continue
        pairs.append([attr, getattr(f, attr)])
    return pairs

def write_config(CFG):
    """Dump the configuration to ``<CFG.results_dir>/config.txt``.

    One ``name=value`` line is written per attribute returned by
    ``class2str(CFG)``. The ``with`` block closes the file, so no explicit
    ``close()`` call is needed afterwards.
    """
    with open(f"{CFG.results_dir}/config.txt", "w") as f:
        for n, v in class2str(CFG):
            f.write(f"{n}={v} \n")

In [None]:
%reload_ext autoreload
%autoreload 2
from dataset.dataloader import *
from dataset.loader import PrefetchLoader, fast_collate
from dataset.dataloader_utils import *

from dataset.mixup import FastCollateMixup
from dataset.augment import get_transforms, get_test_transforms

from scheduler import LinearWarmupCosineAnnealingLR

In [None]:
if DEBUG:
    # Debug-only setup: derive the train/validation frames of the first split
    # and build the cached arrays consumed by the inspection cells below.
    from dataset.mixup import *

    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")
    seed = 0xBadCafe  # NOTE(review): assigned but not passed to anything in this cell
    val_size = 4
    # One experiment per distinct value of the `fold` column.
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]
    # split[0] is the first experiment; entries 0 and 1 are used as the row
    # positions of the training and validation subsets.
    train_df = cc359_df.iloc[split[0][0]].reset_index()
    valid_df = cc359_df.iloc[split[0][1]].reset_index()
    

    # Cached inputs/targets for training, cached volumes for validation and
    # the FCM masks for the training frame (all produced by project helpers).
    sa_x,sa_y = create_shared_arrays(CFG,train_df,root_dir=CFG.dataset_path)
    valid_sa_x,valid_sa_y = create_3d_shared_arrays(CFG,valid_df,root_dir=CFG.dataset_path)
    fcm_ma = create_shared_fcm_masks(CFG,train_df,root_dir=CFG.dataset_path)


In [None]:
# Inspect one training sample and one validation volume with the FCM mask set
# to "gm" and the "contrast_fcm" training transforms.
# NOTE: the config mutation below runs even when DEBUG is False.
CFG.fcm_mask = "gm"
if DEBUG:
    train_dataset = CC359_Dataset(CFG,df=train_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_transforms("contrast_fcm"),
                                  mode="train", cache=True, cached_x=sa_x, cached_y=sa_y, cached_fcm_mask=fcm_ma)
    valid_dataset = CC359_Dataset(CFG,df=valid_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_test_transforms(),
                                  mode="val", cache=CFG.cache, cached_x=valid_sa_x,cached_y=valid_sa_y)

    # Single-element list so the (image, target) pair can be unpacked in the
    # loop header; x1 is reused by the difference-plot cell further down.
    for x1,y in [train_dataset[1400]]:
        
        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        #out = transform(image=x.numpy().squeeze(),mask=y.numpy().squeeze())
        #normed_img = out["image"]
        inp = ax.imshow(x1.numpy().squeeze())
        plt.colorbar(inp, ax=ax)
        

        #tfms = T.MaskIntensity()
        #fcm_masked_img = tfms(x.numpy().squeeze(),mask_data=y.numpy().squeeze())
        #inv_masked_img = tfms(x.numpy().squeeze(),mask_data=~(y.numpy().astype(bool).squeeze()))
        #Need to norm the non masked input too

        bx.imshow(y.numpy().squeeze())
        
        plt.show()
        #plt.hist(x.cpu().reshape(-1).numpy())
        #plt.show()
    # Validation items unpack into (volume, target, id); slice 130 is shown.
    # x4 is reused by the difference-plot cell further down.
    for x4,y, id_ in [valid_dataset[1]]:
        
        print(x4.shape,x4.min(),x4.max(),y.max())
        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        transform = get_transforms("default")  # NOTE(review): not used afterwards in this cell
        #out = transform(image=x.numpy().squeeze(),mask=y.numpy().squeeze())
        #normed_img = out["image"]
        inp = ax.imshow(x4[130].numpy().squeeze())
        plt.colorbar(inp, ax=ax)
        

        # Image masked by the target and by its boolean complement.
        # NOTE(review): both results are computed but never displayed here.
        tfms = T.MaskIntensity()
        fcm_masked_img = tfms(x4[130].numpy().squeeze(),mask_data=y[130].numpy().squeeze())
        inv_masked_img = tfms(x4[130].numpy().squeeze(),mask_data=~(y[130].numpy().astype(bool).squeeze()))
        #Need to norm the non masked input too

        bx.imshow(y[130].numpy().squeeze())
        
        plt.show()
        #plt.hist(x.cpu().reshape(-1).numpy())
        #plt.show()

In [None]:
# Same inspection as the "gm" cell, but with the FCM mask disabled and the
# "default" training transforms. x2 / x3 are reused by the difference plot.
# NOTE: the config mutation below runs even when DEBUG is False.
CFG.fcm_mask = None
if DEBUG:
    train_dataset = CC359_Dataset(CFG,df=train_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_transforms("default"),
                                  mode="train", cache=True, cached_x=sa_x, cached_y=sa_y, cached_fcm_mask=fcm_ma)
    valid_dataset = CC359_Dataset(CFG,df=valid_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_test_transforms(),
                                  mode="val", cache=CFG.cache, cached_x=valid_sa_x,cached_y=valid_sa_y)

    # Single-element list so the (image, target) pair can be unpacked in the
    # loop header.
    for x2,y in [train_dataset[1400]]:

        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        inp = ax.imshow(x2.numpy().squeeze())
        plt.colorbar(inp, ax=ax)

        # Image masked by the target and by its boolean complement.
        # Fixed: these lines referenced `x`, which is not the loop variable of
        # this cell (NameError on a fresh kernel, stale tensor otherwise).
        tfms = T.MaskIntensity()
        fcm_masked_img = tfms(x2.numpy().squeeze(),mask_data=y.numpy().squeeze())
        inv_masked_img = tfms(x2.numpy().squeeze(),mask_data=~(y.numpy().astype(bool).squeeze()))
        #Need to norm the non masked input too

        bx.imshow(y.numpy().squeeze())

        plt.show()

    # Validation items unpack into (volume, target, id); slice 130 is shown.
    for x3,y, id_ in [valid_dataset[1]]:

        print(x3.shape,x3.min(),x3.max(),y.max())
        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        transform = get_transforms("default")
        inp = ax.imshow(x3[130].numpy().squeeze())
        plt.colorbar(inp, ax=ax)

        tfms = T.MaskIntensity()
        fcm_masked_img = tfms(x3[130].numpy().squeeze(),mask_data=y[130].numpy().squeeze())
        inv_masked_img = tfms(x3[130].numpy().squeeze(),mask_data=~(y[130].numpy().astype(bool).squeeze()))
        #Need to norm the non masked input too

        bx.imshow(y[130].numpy().squeeze())

        plt.show()

In [None]:
# Difference between the samples drawn with and without the "gm" FCM mask.
# x1..x4 are only created by the two DEBUG inspection cells, so this cell is
# guarded the same way to keep "Run All" working when DEBUG is False.
if DEBUG:
    fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
    # Training sample: "gm" cell (x1) minus unmasked cell (x2).
    tmp_diff = x1.numpy().squeeze() - x2.numpy().squeeze()
    # Validation volume: unmasked cell (x3) minus "gm" cell (x4).
    tmp_diff2 = x3.numpy().squeeze() - x4.numpy().squeeze()
    print(tmp_diff.min(),tmp_diff.max())
    print(tmp_diff2.min(),tmp_diff2.max())
    inp = ax.imshow(tmp_diff)
    plt.colorbar(inp, ax=ax)
    # Same slice index (130) as the inspection cells.
    inp = bx.imshow(tmp_diff2[130])
    plt.colorbar(inp, ax=bx)

In [None]:
if DEBUG:
    # Rebuild the validation frame, its cached volumes and a dataset that
    # applies no transforms (used by the augmentation preview below).
    valid_df = cc359_df.iloc[split[0][1]].reset_index()

    valid_sa_x, valid_sa_y = create_3d_shared_arrays(
        CFG,
        valid_df,
        root_dir=CFG.dataset_path,
    )
    valid_dataset = CC359_Dataset(
        CFG,
        df=valid_df,
        root_dir=CFG.dataset_path,
        voxel_spacing=CFG.voxel_spacing,
        transforms=None,  # get_test_transforms(),
        mode="val",
        cache=CFG.cache,
        cached_x=valid_sa_x,
        cached_y=valid_sa_y,
    )


In [None]:
if DEBUG:
    # Preview a single rand-augment op (solarize) on slice 130 of every
    # validation volume: first the raw slice/target, then the augmented slice.
    from dataset.rand_augment import *
    aa_params = dict(
                translate_const=int(256 * 0.20),
                img_mean=tuple([0]),
                #interpolation=str_to_pil_interp(interpolation)
                )
    for x,y,id_ in valid_dataset:

        print(x.min(),x.max(),y.max())
        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        inp = ax.imshow(x.cpu().squeeze().numpy()[130])
        plt.colorbar(inp, ax=ax)
        bx.imshow(y.cpu().squeeze().numpy()[130])
        plt.show()
        fig, (ax,bx) = plt.subplots(1,2, figsize=(20,10))
        tfm_fct = solarize
        # The op works on PIL images, so image and target slices are scaled by
        # 255 and converted to 8-bit greyscale ('L') first.
        tmp = Image.fromarray(x.cpu().squeeze().numpy()[130]*255).convert('L')
        tmp2 = Image.fromarray(y.cpu().squeeze().numpy()[130]*255).convert('L')
        tmp, tmp2 = tfm_fct(tmp,tmp2, _solarize_level_to_arg(9, aa_params))
        inp = ax.imshow(np.array(tmp))
        plt.colorbar(inp, ax=ax)
        # NOTE(review): this shows the original target `y`, not the transformed
        # `tmp2` returned above - confirm that is intended.
        bx.imshow(y.cpu().squeeze().numpy()[130])
        plt.show()


### Training

In [None]:
from trainer import Trainer
from dataset.loader import *

In [None]:
def run_fold(fold):
    """Train, validate and test one experiment (``fold``) end to end.

    Steps, in order: start a wandb run and a TensorBoard writer under
    ``<CFG.results_dir>/mode_<fold>``, dump the config, build the
    train/validation/test datasets and loaders from the ``one2all`` split,
    create model/optimizer/scheduler, call ``Trainer.fit`` and
    ``Trainer.test``, then log the metrics returned by
    ``get_target_domain_metrics`` to wandb.

    Args:
        fold: index into the list of splits returned by ``one2all``; also
            used to name the result directory and the wandb run.

    Returns:
        None. Results are written by the trainer and the loggers.
    """
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)
    run = wandb.init(project="domain_shift",
                     group=CFG.model_name,
                     name=f"mode_{str(fold)}",
                     job_type=CFG.exp_name,
                     config=class2dict(CFG),
                     reinit=True,
                     sync_tensorboard=True)

    writer = SummaryWriter(log_dir=result_dir+"/logs")
    write_config(CFG)
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

    # NOTE(review): mixup is applied through the collate function; `mixup_fn`
    # itself is always None when handed to the Trainer.
    mixup_fn = None
    # Bound up front so the DataLoader call below can never hit an unbound
    # local; None makes DataLoader fall back to its default collation.
    collate_fn = None
    if CFG.mixup:
        mixup_args = dict(
            mixup_alpha=CFG.mixup, cutmix_alpha=CFG.cutmix, cutmix_minmax=None,
            prob=CFG.mixup_prob, switch_prob=CFG.mixup_switch_prob, mode='batch',
            label_smoothing=CFG.smoothing, num_classes=2)
        collate_fn = FastCollateMixup(**mixup_args)
    elif fast_collate:
        collate_fn = fast_collate

    val_size = 4
    # One experiment per distinct value of the `fold` column.
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]

    # Each split entry holds the row positions of (train, validation, test).
    train_df = cc359_df.iloc[split[fold][0]].reset_index()
    valid_df = cc359_df.iloc[split[fold][1]].reset_index()
    test_df  = cc359_df.iloc[split[fold][2]].reset_index()

    sa_x = None; sa_y = None
    valid_sa_x = None; valid_sa_y = None
    if CFG.cache:
        print("Caching Train Data ...")
        sa_x,sa_y = create_shared_arrays(CFG,train_df,root_dir=CFG.dataset_path)
        valid_sa_x,valid_sa_y = create_3d_shared_arrays(CFG,valid_df,root_dir=CFG.dataset_path)
    train_dataset = CC359_Dataset(CFG,df=train_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_transforms(CFG.tfms),
                                  mode="train", cache=CFG.cache, cached_x=sa_x, cached_y=sa_y)

    valid_dataset = CC359_Dataset(CFG,df=valid_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_test_transforms(),
                                  mode="val", cache=CFG.cache, cached_x=valid_sa_x,cached_y=valid_sa_y)
    test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                                 voxel_spacing=CFG.voxel_spacing,
                                 transforms=get_test_transforms(),mode="test", cache=False)

    train_loader = PrefetchLoader(DataLoader(train_dataset,
                                              batch_size=CFG.bs,
                                              shuffle=True,
                                              num_workers=CFG.num_workers,
                                              sampler=None,
                                              collate_fn=collate_fn,
                                              pin_memory=False,
                                              drop_last=True),
                                  fp16=True)
    # Validation and test are iterated one item per batch.
    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,shuffle=False,
                              num_workers=1,pin_memory=False)
    test_dataloader = DataLoader(test_dataset,
                                  batch_size=1,shuffle=False,
                                  num_workers=1,pin_memory=False)

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    model.to(CFG.device)

    optim_dict = dict(optim=CFG.optim,lr=CFG.lr,weight_decay=CFG.wd)
    optimizer = get_optimizer(model, **optim_dict)

    # Fixed: `scheduler` used to be unbound in the fall-through branch, so the
    # Trainer(...) call below raised UnboundLocalError right after printing
    # "no scheduler selected".
    # NOTE(review): assumes Trainer accepts scheduler=None - confirm.
    scheduler = None
    if CFG.scheduler==torch.optim.lr_scheduler.OneCycleLR:
        scheduler = CFG.scheduler(optimizer, max_lr=CFG.lr, steps_per_epoch=len(train_loader), epochs=CFG.epochs)
    elif CFG.scheduler==LinearWarmupCosineAnnealingLR:
        scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                                    warmup_epochs=CFG.warmup_epochs,
                                                    max_epochs=CFG.epochs,
                                                    warmup_start_lr=CFG.warmup_lr)
    else:
        print("no scheduler selected")
    criterion = CFG.crit

    trainer = Trainer(CFG,
                      model=model,
                      device=CFG.device,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      criterion=criterion,
                      writer=writer,
                      fold=fold,
                      max_norm=CFG.max_norm,
                      mixup_fn=mixup_fn)

    history = trainer.fit(
            CFG.epochs,
            train_loader,
            valid_loader,
            f"{result_dir}/",
            CFG.epochs,
        )
    trainer.test(test_dataloader,result_dir)
    td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
    wandb.log(td_sdice)
    writer.close()
    run.finish()

    # Drop the large objects explicitly before the next fold starts.
    del trainer
    del train_loader
    del valid_loader
    del train_dataset
    del valid_dataset
    gc.collect()

In [None]:
# Run the full train/validate/test pipeline once per fold listed in CFG.fold.
for fold in CFG.fold:
    run_fold(fold)

### Test

In [None]:
from hausdorff import hausdorff_distance
from multiprocessing import Pool


def test_run(fold):
    """Re-evaluate the saved checkpoint of ``fold`` on its test split.

    Loads ``<CFG.results_dir>/mode_<fold>/e_39.pth``, predicts every test item
    in chunks of slices, and computes surface Dice, Dice and a per-slice mean
    Hausdorff distance. Only the Hausdorff values are written to disk
    (``hausdorff.json``); the other two JSON dumps are commented out.

    Args:
        fold: index into the list of splits returned by ``one2all``.

    Returns:
        Whatever ``get_target_domain_metrics`` returns for this fold.
    """
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)
    
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    # The checkpoint is a dict; the weights live under "model_state_dict".
    model.load_state_dict(torch.load(f"{result_dir}/e_39.pth")["model_state_dict"])
    #model.load_state_dict(torch.load(f"baseline_results/baseline_focal_lovasz_adam_default/mode_{fold}/e_39.pth")["model_state_dict"])
    
    model.to(CFG.device)


    seed = 0xBadCafe  # NOTE(review): assigned but never used in this function
    val_size = 4
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]

    # Entry 2 of a split holds the row positions of the test subset.
    test_df  = cc359_df.iloc[split[fold][2]].reset_index()

    test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=get_test_transforms(),
                                  mode="test", cache=False)

    test_dataloader = DataLoader(test_dataset, 
                                  batch_size=1,shuffle=False,
                                  num_workers=0,pin_memory = False)

    # Number of slices pushed through the model per forward pass.
    bs=256
    model.eval()
    metrics = {'sdice_score': sdice, 'dice_score': dice_score}  # NOTE(review): unused below

    results = defaultdict(dict)
    for step,(x,y,_id) in progress_bar(enumerate(test_dataloader, 1), len(test_dataloader)):
        y = y.squeeze(0)#.to(torch.int64)
        with torch.no_grad():
            # The loader uses batch_size=1, so index 0 strips the batch dimension.
            x = x[0].to(CFG.device); _id=_id[0]
            c,h,w = x.shape
            outputs = []
            for idx in range(0,c,bs):
                # Slices along the first axis become the batch; unsqueeze(1)
                # adds the channel axis.
                out = model(x[idx:min(idx+bs,c)].unsqueeze(1))
                out = out.squeeze().sigmoid().cpu().detach().numpy()
                # A chunk holding a single slice loses its leading axis in
                # squeeze(); restore it so concatenation works.
                if len(out.shape)==2: out = out[None,:,:]
                outputs.append(out)
            outputs = np.concatenate(outputs)#.transpose(1,0,2,3)
            # Binarise prediction and target at 0.5.
            outputs = (outputs > .5).squeeze()
            y = (y > .5).numpy().squeeze()

        results['sdice_score'][_id] = sdice(y, outputs, CFG.voxel_spacing,CFG.sdice_tolerance)
        results['dice_score'][_id]  = dice_score(y, outputs)
        # Hausdorff distance is computed per slice pair and averaged.
        results['hausdorff'][_id]  = np.mean([hausdorff_distance(y_,o_) for o_,y_ in zip(np.uint8(outputs), np.uint8(y))])
        #def mp_hausdorf(i):
        #    return hausdorff_distance(np.uint8(outputs[i,...]),np.uint8(y[i,...]))

        #with Pool(5) as p:
        #    print(p.map(mp_hausdorf, np.arange(outputs.shape[0])))
        #break
    # NOTE(review): the sdice/dice values computed above are not persisted
    # because the two dumps below are commented out.
    #with open(os.path.join(result_dir, 'sdice_score' + '.json'), 'w') as f:
    #    json.dump(results['sdice_score'], f, indent=0)
    #with open(os.path.join(result_dir, 'dice_score'+ '.json'), 'w') as f:
    #    json.dump(results['dice_score'], f, indent=0)
    with open(os.path.join(result_dir, 'hausdorff'+ '.json'), 'w') as f:
        json.dump(results['hausdorff'], f, indent=0)
    td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
    return td_sdice   

# Re-test a fixed, already-trained experiment instead of the configured folds.
# NOTE: this overwrites CFG.results_dir for every later cell, and fold 1 is
# not in the hard-coded list.
#for fold in CFG.fold:
CFG.results_dir = "baseline_results/baseline_focal_lovasz_adam_default/"
for fold in [0,2,3,4,5]:
    test_run(fold)

In [None]:
# Display the target-domain metrics of fold 1 for the current results dir.
fold = 1
result_dir = CFG.results_dir + "/mode_"+str(fold)  # NOTE(review): not used in this cell
td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
td_sdice

In [None]:
# Target-domain metrics of fold 3 for the current CFG.results_dir.
get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),3)

In [None]:
# Same lookup for fold 3 of the "baseline_lovasz_default" results, for comparison.
get_target_domain_metrics(CFG.dataset_path,Path("baseline_results/baseline_lovasz_default"),3)

### Predict

In [None]:
from monai import transforms as T

# Test-time augmentations tried by the prediction cells; None means "no
# augmentation". The other candidates are currently commented out.
ttas = [None,
        #T.ShiftIntensityd(keys=("image"), offset=-0.1),
        #T.ShiftIntensityd(keys=("image"), offset=0.1), 
        #T.Flipd(keys=("image","seg"), spatial_axis=0, allow_missing_keys=False)
       ]

# Intensity normalisation for the "image" key.
# NOTE(review): only referenced from commented-out lines in apply_tta.
norm = T.NormalizeIntensityd(keys=("image"), subtrahend=(0.122), divisor=(0.224))

def apply_tta(img, segm, tta):
    """Apply one optional dictionary transform to an image/segmentation pair.

    A leading axis is added to both inputs, they are passed through
    ``T.Compose`` as ``{'image': ..., 'seg': ...}``, and both results are
    squeezed before being returned. The module-level ``norm`` transform is
    deliberately not part of the pipeline (it was commented out).

    Args:
        img: image array/tensor indexable with three axes.
        segm: segmentation with the same layout as ``img``.
        tta: a transform operating on the dict above, or ``None`` for none.

    Returns:
        Tuple ``(img, segm)`` after the transform and ``squeeze()``.
    """
    # Identity check instead of `!= None`, so a transform that customises its
    # comparison operators cannot be mistaken for "no augmentation".
    steps = [tta] if tta is not None else []
    tfms = T.Compose(steps)

    tfmed = tfms({'image': img[None, :, :, :], 'seg': segm[None, :, :, :]})
    return tfmed['image'].squeeze(), tfmed['seg'].squeeze()

In [None]:
# Predict and visualise one test scan (id "CC0002") of fold 3 with the
# focal+lovasz baseline checkpoint, then save its binary prediction.
fold=3
slice_index = 130
# NOTE: both assignments mutate the shared CFG for every later cell.
CFG.transpose = (2,0,1)
CFG.fcm_mask = None #"csf"
result_dir = "baseline_results/baseline_focal_lovasz_adam_default" + f"/mode_{fold}/"
output_dir =  "predictions/baseline_focal_lovasz_adam_default/" + f"/mode_{fold}/"
os.makedirs(output_dir, exist_ok=True)

cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
model.load_state_dict(torch.load(f"{result_dir}/e_39.pth")["model_state_dict"])
model.to(CFG.device)


seed = 0xBadCafe  # NOTE(review): assigned but not used in this cell
val_size = 4
n_experiments = len(cc359_df.fold.unique())
split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]
test_df  = cc359_df.iloc[split[fold][2]].reset_index()

# NOTE(review): this uses get_transforms("default") while test_run uses
# get_test_transforms() for the test split - confirm which is intended.
test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                             voxel_spacing=CFG.voxel_spacing,transforms=get_transforms("default"),
                             mode="test", cache=False)

test_dataloader = DataLoader(test_dataset, 
                              batch_size=1,shuffle=False,
                              num_workers=0,pin_memory = False)

def id_to_scanner(id):
    # Label built from the scan's `tomograph_model` and `tesla_value` columns.
    df = test_df[test_df['id']==id]
    return df['tomograph_model'].values[0] + str(df['tesla_value'].values[0])

# Number of slices pushed through the model per forward pass.
bs=16    
model.eval()
import copy
for step,(x,y,_id) in progress_bar(enumerate(test_dataloader, 1), len(test_dataloader)):
    with torch.no_grad():
        _id=_id[0]
        # Only this one scan is processed; every other item is skipped.
        if _id=="CC0002":
            for tta in ttas:
                # Work on copies so each augmentation starts from the loaded data.
                x_c, y_c = copy.deepcopy(x), copy.deepcopy(y)
                x_c, y_c = apply_tta(x_c,y_c,tta)
                x_c = x_c.to(CFG.device); 
                c,h,w = x_c.shape
                outputs = []
                for idx in range(0,c,bs):
                    out = model(x_c[idx:min(idx+bs,c)].unsqueeze(1))
                    out = out.squeeze().sigmoid().cpu().detach().numpy()
                    # A single-slice chunk loses its leading axis in squeeze().
                    if len(out.shape)==2: out = out[None,:,:]
                    outputs.append(out)
                # NOTE: despite the name these are post-sigmoid values.
                logits = np.concatenate(outputs)
                outputs = (logits > .5)
                surface_dice = sdice(y_c.squeeze().numpy().astype(bool), outputs, CFG.voxel_spacing,CFG.sdice_tolerance)
                # Panels: input slice, prediction (titled with the surface
                # Dice), target, and target minus prediction.
                fig, (ax, bx, cx, dx) = plt.subplots(1,4, figsize=(20,5))
                inp = ax.imshow(x_c[slice_index].cpu().numpy())
                ax.set_title(str(_id)+" "+id_to_scanner(_id))
                fig.colorbar(inp, ax=ax)
                bx.imshow(outputs[slice_index])
                bx.set_title(str(round(surface_dice, 4)))
                cx.imshow(y_c[slice_index].squeeze())
                dx.imshow(y_c[slice_index].squeeze()-outputs[slice_index])
                plt.show()
            # Saves the prediction of the last augmentation in `ttas`.
            np.save(output_dir+f"/{_id}.npy",outputs)

In [None]:
# Predict and visualise every test scan of fold 3 with the "sideways" lovasz
# baseline checkpoint, showing the raw model output next to the thresholded
# mask and the target.
fold=3
result_dir = "baseline_results/baseline_lovasz_sideways_default" + f"/mode_{fold}/"
output_dir =  "predictions/baseline_lovasz_sideways_default/" + f"/mode_{fold}/"
os.makedirs(output_dir, exist_ok=True)

cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
model.load_state_dict(torch.load(f"{result_dir}/e_39.pth")["model_state_dict"])
model.to(CFG.device)


seed = 0xBadCafe
val_size = 4
n_experiments = len(cc359_df.fold.unique())
split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]

# Entry 2 of a split holds the row positions of the test subset.
test_df  = cc359_df.iloc[split[fold][2]].reset_index()

test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                             voxel_spacing=CFG.voxel_spacing,
                              transforms=None,mode="test", cache=False)

test_dataloader = DataLoader(test_dataset,
                              batch_size=1,shuffle=False,
                              num_workers=0,pin_memory = False)

# Number of slices pushed through the model per forward pass.
bs=16
model.eval()
import copy
for step,(x,y,_id) in progress_bar(enumerate(test_dataloader, 1), len(test_dataloader)):
    with torch.no_grad():
        # Fixed: unwrap the batch of ids once per scan. This used to sit inside
        # the TTA loop, so a second augmentation would index into the id
        # string itself (e.g. "CC0002"[0] -> "C").
        _id=_id[0]

        for tta in ttas:
            # Work on copies so each augmentation starts from the loaded data.
            x_c, y_c = copy.deepcopy(x), copy.deepcopy(y)
            x_c, y_c = apply_tta(x_c,y_c,tta)
            x_c = x_c.to(CFG.device)
            c,h,w = x_c.shape
            outputs = []
            for idx in range(0,c,bs):
                out = model(x_c[idx:min(idx+bs,c)].unsqueeze(1))
                # No sigmoid here: the raw model output is kept for display.
                out = out.squeeze().cpu().detach().numpy()
                # A single-slice chunk loses its leading axis in squeeze().
                if len(out.shape)==2: out = out[None,:,:]
                outputs.append(out)
            logits = np.concatenate(outputs)
            # NOTE(review): the other evaluation cells threshold sigmoid(out)
            # at 0.5, which corresponds to 0 on raw outputs; thresholding the
            # raw values at 0.5 gives a different mask - confirm intent.
            outputs = (logits > .5)
            surface_dice = sdice(y_c.squeeze().numpy().astype(bool), outputs, CFG.voxel_spacing,CFG.sdice_tolerance)
            # Panels: input slice, raw output, thresholded mask, target.
            fig, (ax, bx, cx, dx) = plt.subplots(1,4, figsize=(20,5))
            inp = ax.imshow(x_c[130].cpu().numpy())
            #ax.set_title(str(tta))
            fig.colorbar(inp, ax=ax)
            lgts = bx.imshow(logits[130])
            fig.colorbar(lgts, ax=bx)
            cx.imshow(outputs[130])
            dx.imshow(y_c[130].squeeze())
            plt.title(str(round(surface_dice, 4)))
            plt.show()
    #np.save(output_dir+f"/{_id}.npy",outputs)

### Dice Scores

In [None]:
# Scan metadata indexed by scan id; read as a module-level global by the
# get_* score helpers defined in this cell.
# NOTE(review): read from the working directory, whereas other cells read
# f"{CFG.dataset_path}/meta.csv" - confirm both point at the same file.
meta = pd.read_csv(f"meta.csv",delimiter=",", index_col='id')
meta.head()  # not the last expression of the cell, so nothing is displayed
import seaborn as sns


def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content."""
    with open(path, 'r') as handle:
        raw = handle.read()
    return json.loads(raw)
    
def _domain_name(fold_id):
    """Label a fold by the scanner model and field strength of its first row in ``meta``."""
    row = meta[meta['fold'] == fold_id].iloc[0]
    return row['tomograph_model'] + str(row['tesla_value'])


def _load_scores(path):
    """Load a ``{scan_id: score}`` JSON file; a missing or unreadable file yields ``{}``."""
    try:
        return load_json(path)
    except (OSError, ValueError):
        # Best effort, as before: absent runs simply contribute no scores.
        return {}


def _mean_or_nan(values):
    """Mean of ``values``; NaN for an empty list (same value as ``np.mean([])``, no warning)."""
    return np.mean(values) if values else np.nan


def _score_records(path_base, path_oracle, filename):
    """Build one record per source fold from the per-scan scores in ``filename``.

    For a source fold ``s`` the record holds:
      * one entry per other fold ``t``: the mean of the scores in
        ``path_base/mode_{s}/<filename>`` whose scan id belongs to ``t``;
      * if ``path_oracle`` is given, an entry for ``s`` itself: the mean over
        the merged contents of ``path_oracle/mode_{3*s + k}/<filename>`` for
        ``k`` in 0..2;
      * the key ``' '`` holding the source label (used as DataFrame index).
    """
    folds = sorted(meta['fold'].unique())
    records = []
    for s in folds:
        res_row = {}

        # one2all results; path_base=None means "no baseline scores".
        if path_base is None:
            base_scores = {}
        else:
            base_scores = _load_scores(Path(path_base) / f'mode_{s}/{filename}')
        for t in folds:
            if t == s:
                continue
            ids_t = meta[meta['fold'] == t].index
            res_row[_domain_name(t)] = _mean_or_nan(
                [score for _id, score in base_scores.items() if _id in ids_t])

        source_name = _domain_name(s)
        if path_oracle:
            oracle_scores = {}
            for n_val in range(3):
                oracle_scores.update(
                    _load_scores(Path(path_oracle) / f'mode_{s * 3 + n_val}/{filename}'))
            res_row[source_name] = _mean_or_nan(list(oracle_scores.values()))

        res_row[' '] = source_name
        records.append(res_row)
    return records


def get_sdices(path_base, path_oracle=None):
    """Per-fold records of mean surface Dice, read from ``sdice_score.json`` files.

    Args:
        path_base: directory containing ``mode_<fold>/`` baseline results, or None.
        path_oracle: optional directory containing the oracle ``mode_<n>/`` results.

    Returns:
        List of dicts, one per source fold (see ``_score_records``).
    """
    return _score_records(path_base, path_oracle, 'sdice_score.json')


def get_dice(path_base, path_oracle=None):
    """Per-fold records of mean Dice, read from ``dice_score.json`` files.

    Same arguments and return value as ``get_sdices``.
    """
    return _score_records(path_base, path_oracle, 'dice_score.json')


def get_hausdorff(path_base, path_oracle=None):
    """Per-fold records of mean Hausdorff distance, read from ``hausdorff.json`` files.

    Same arguments and return value as ``get_sdices``. The oracle entries are
    read from ``hausdorff.json`` as well (previously ``sdice_score.json`` was
    loaded there by mistake).
    """
    return _score_records(path_base, path_oracle, 'hausdorff.json')

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns  # imported here so the heatmaps below work on a fresh kernel

from dpipe.io import load

# Result directories handed to get_sdices / get_dice in this cell.
path_base = Path('baseline_results/baseline_focal_lovasz_adam_default')
oracle_path = Path('oracle_results/focal_lovasz_adam_default_None')

# Surface Dice matrix. `df[df.index]` picks the columns named like the index
# labels, in index order.
records = get_sdices(path_base, oracle_path)
df = pd.DataFrame.from_records(records, index=' ')
print(df.mean().mean())  # mean of the per-column means

fig = plt.figure(figsize=(10, 8))
ax = sns.heatmap(df[df.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15})
plt.title("Surface Dice Score")
plt.tick_params(axis='x', labelsize=12)  # x tick label size
plt.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

# Dice matrix for the same directories. `df` is rebound here, so after this cell
# it holds the Dice values.
records = get_dice(path_base, oracle_path)
df = pd.DataFrame.from_records(records, index=' ')
print(df.mean().mean())

fig = plt.figure(figsize=(10, 8))
ax = sns.heatmap(df[df.index], annot=True, vmin=0.8, vmax=1.0, annot_kws={'size': 15})
plt.title("Dice Score")
plt.tick_params(axis='x', labelsize=12)  # x tick label size
plt.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

# Separate means for the base-only and oracle-only tables
# ("TD" presumably stands for target domain - TODO confirm).
records = get_sdices(path_base)
tmp = pd.DataFrame.from_records(records, index=' ')
print("Sdice TD mean: ", tmp.mean().mean())
records = get_sdices(None, oracle_path)
tmp = pd.DataFrame.from_records(records, index=' ')
print("Sdice Oracle mean: ", tmp.mean().mean())

records = get_dice(path_base)
tmp = pd.DataFrame.from_records(records, index=' ')
print("dice TD mean: ", tmp.mean().mean())
records = get_dice(None, oracle_path)
tmp = pd.DataFrame.from_records(records, index=' ')
print("dice Oracle mean: ", tmp.mean().mean())

In [None]:
path_base = Path('baseline_results/baseline_focal_lovasz_adam_default')
oracle_path = Path('oracle_results/focal_lovasz_adam_default_None')

# Hausdorff table from the base results only: no oracle path is passed, so
# get_hausdorff adds no entry for the source domain itself.
records = get_hausdorff(path_base)
df_hd = pd.DataFrame.from_records(records, index=' ')

fig = plt.figure(figsize=(10, 8))
ax = sns.heatmap(df_hd[df_hd.index], annot=True, vmin=1, vmax=5.0, annot_kws={'size': 15}, cmap="Blues")
plt.title("Hausdorff Distance")  # titled like the other metric heatmaps
plt.tick_params(axis='x', labelsize=12)  # x tick label size
plt.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

In [None]:
# Surface Dice table for this results directory; the last line is the cell output.
path_base = Path('baseline_results/baseline_lovasz_nostopping_default')
records = get_sdices(path_base)
df3 = pd.DataFrame.from_records(data=records, index=' ')
# Columns selected by the index labels, in index order.
df3[list(df3.index)]

In [None]:
# Result directories handed to get_dice / get_sdices in this cell.
path_base = Path('baseline_results/baseline_focal_lovasz_adam_rand_aug_default_v1')
oracle_path = Path('oracle_results/focal_lovasz_adam_rand_aug_default')

# Same table + heatmap recipe for both metrics. Dice runs first, so `records`
# and `df2` hold the surface Dice results once the loop is done.
for metric_fn, plot_title in ((get_dice, "Dice Score"), (get_sdices, "Surface Dice Score")):
    records = metric_fn(path_base, oracle_path)
    df2 = pd.DataFrame.from_records(records, index=' ')
    df2[df2.index]  # no output mid-cell; raises if a column for an index label is missing
    print(df2.mean().mean())

    fig, bx = plt.subplots(figsize=(10, 8))
    sns.heatmap(df2[df2.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=bx)
    bx.set_title(plot_title)
    plt.show()

# Means of the base-only and oracle-only tables, printed in a fixed order
# ("TD" presumably stands for target domain - TODO confirm).
for summary_label, metric_fn, metric_args in (
    ("Sdice TD mean: ", get_sdices, (path_base,)),
    ("Sdice Oracle mean: ", get_sdices, (None, oracle_path)),
    ("dice TD mean: ", get_dice, (path_base,)),
    ("dice Oracle mean: ", get_dice, (None, oracle_path)),
):
    records = metric_fn(*metric_args)
    tmp = pd.DataFrame.from_records(records, index=' ')
    print(summary_label, tmp.mean().mean())

In [None]:
# Result directories handed to get_dice / get_sdices in this cell.
path_base = Path('predictions/baseline_focal_lovasz_multiview_ensemble')
oracle_path = Path('oracle_results/baseline_focal_lovasz_multiview_ensemble')

# Dice table and heatmap.
records = get_dice(path_base, oracle_path)
df2 = pd.DataFrame.from_records(records, index=' ')
df2[df2.index]  # no output mid-cell; raises if a column for an index label is missing
print(df2.mean().mean())

fig, bx = plt.subplots(figsize=(10, 8))
sns.heatmap(df2[df2.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=bx)
bx.set_title("Dice Scores")
plt.show()

# Surface Dice table and heatmap.
# NOTE(review): despite its name, `df_hd` holds get_sdices output here.
records = get_sdices(path_base, oracle_path)
df_hd = pd.DataFrame.from_records(records, index=' ')
df_hd[df_hd.index]

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(df_hd[df_hd.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=ax)
ax.set_title("Surface Dice Scores")
ax.tick_params(axis='x', labelsize=12)  # x tick label size
ax.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

# Means of the base-only and oracle-only tables, printed in a fixed order.
for summary_label, metric_fn, metric_args in (
    ("Sdice TD mean: ", get_sdices, (path_base,)),
    ("Sdice Oracle mean: ", get_sdices, (None, oracle_path)),
    ("dice TD mean: ", get_dice, (path_base,)),
    ("dice Oracle mean: ", get_dice, (None, oracle_path)),
):
    records = metric_fn(*metric_args)
    tmp = pd.DataFrame.from_records(records, index=' ')
    print(summary_label, tmp.mean().mean())

In [None]:
path_base = Path('predictions/baseline_focal_lovasz_multiview_ensemble')
oracle_path = Path('oracle_results/baseline_focal_lovasz_multiview_ensemble')

# Surface Dice table and heatmap.
# NOTE(review): despite its name, `df_hd` holds get_sdices output here.
records = get_sdices(path_base, oracle_path)
df_hd = pd.DataFrame.from_records(records, index=' ')

fig = plt.figure(figsize=(10, 8))
ax = sns.heatmap(df_hd[df_hd.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15})
plt.title("Surface Dice Scores")  # the figure had no title, so the plotted metric was not identifiable
plt.tick_params(axis='x', labelsize=12)  # x tick label size
plt.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

# Means of the base-only and oracle-only tables.
records = get_sdices(path_base)
tmp = pd.DataFrame.from_records(records, index=' ')
print("Sdice TD mean: ", tmp.mean().mean())
records = get_sdices(None, oracle_path)
tmp = pd.DataFrame.from_records(records, index=' ')
print("Sdice Oracle mean: ", tmp.mean().mean())

records = get_dice(path_base)
tmp = pd.DataFrame.from_records(records, index=' ')
print("dice TD mean: ", tmp.mean().mean())
records = get_dice(None, oracle_path)
tmp = pd.DataFrame.from_records(records, index=' ')
print("dice Oracle mean: ", tmp.mean().mean())

In [None]:
path_base = Path('baseline_results/baseline_focal_lovasz_SGD_None_None')

# Same table + heatmap recipe for both metrics (both figures carry the same
# title). Dice runs first, so `records` and `df4` hold the surface Dice results
# once the loop is done.
for metric_fn in (get_dice, get_sdices):
    records = metric_fn(path_base)
    df4 = pd.DataFrame.from_records(records, index=' ')
    df4[df4.index]  # no output mid-cell; raises if a column for an index label is missing
    print(df4.mean().mean())

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(df4[df4.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=ax)
    ax.set_title("SGD Baseline")
    ax.tick_params(axis='x', labelsize=12)  # x tick label size
    ax.tick_params(axis='y', labelsize=12)  # y tick label size
    plt.show()

# NOTE(review): `oracle_path` is not assigned in this cell; the "Oracle" means
# below use whatever value an earlier cell left behind - confirm it is the
# intended oracle run for this experiment.
for summary_label, metric_fn, metric_args in (
    ("Sdice TD mean: ", get_sdices, (path_base,)),
    ("Sdice Oracle mean: ", get_sdices, (None, oracle_path)),
    ("dice TD mean: ", get_dice, (path_base,)),
    ("dice Oracle mean: ", get_dice, (None, oracle_path)),
):
    records = metric_fn(*metric_args)
    tmp = pd.DataFrame.from_records(records, index=' ')
    print(summary_label, tmp.mean().mean())

In [None]:
# Surface Dice table and heatmap for this results directory.
path_base = Path('baseline_results/baseline_focal_lovasz_SGD_default_None')
records = get_sdices(path_base)
df5 = pd.DataFrame.from_records(records, index=' ')
df5[df5.index]  # no output mid-cell; raises if a column for an index label is missing

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(df5[df5.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=ax)
ax.set_title("SGD Baseline")
ax.tick_params(axis='x', labelsize=12)  # x tick label size
ax.tick_params(axis='y', labelsize=12)  # y tick label size
plt.show()

In [None]:
path_base = Path('baseline_results/baseline_focal_lovasz_SGD_rand_aug_default')

# Same table + heatmap recipe for both metrics (both figures carry the same
# title). Dice runs first, so `records` and `df5` hold the surface Dice results
# once the loop is done.
for metric_fn in (get_dice, get_sdices):
    records = metric_fn(path_base)
    df5 = pd.DataFrame.from_records(records, index=' ')
    df5[df5.index]  # no output mid-cell; raises if a column for an index label is missing
    print(df5.mean().mean())

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(df5[df5.index], annot=True, vmin=0.5, vmax=1.0, annot_kws={'size': 15}, ax=ax)
    ax.set_title("SGD Rand Aug.")
    ax.tick_params(axis='x', labelsize=12)  # x tick label size
    ax.tick_params(axis='y', labelsize=12)  # y tick label size
    plt.show()

In [None]:
# Mean surface Dice of the two SGD result directories. The labels differ so the
# two printed numbers can be told apart.
records = get_sdices(Path('baseline_results/baseline_focal_lovasz_SGD_None_None'))
tmp = pd.DataFrame.from_records(records, index=' ')
print("SGD TD mean: ", tmp.mean().mean())
records = get_sdices(Path('baseline_results/baseline_focal_lovasz_SGD_rand_aug_default'))
tmp = pd.DataFrame.from_records(records, index=' ')
print("SGD AUG TD mean: ", tmp.mean().mean())

In [None]:
# Rebinds the global oracle path; it is not read inside this cell.
oracle_path = Path('oracle_results/focal_lovasz_adam_default_None')

# Mean surface Dice of each base-only table, printed in the listed order.
# `records` / `tmp` keep the values of the last entry.
for summary_label, results_dir in (
    ("Adam TD mean: ", Path('baseline_results/baseline_focal_lovasz_adam_default')),
    ("Adam AUG TD mean: ", Path('baseline_results/baseline_focal_lovasz_adam_rand_aug_default_v1')),
    ("Adam AUG MVE TD mean: ", Path('predictions/baseline_focal_lovasz_multiview_ensemble')),
):
    records = get_sdices(results_dir)
    tmp = pd.DataFrame.from_records(records, index=' ')
    print(summary_label, tmp.mean().mean())


In [None]:
# Mean surface Dice of the oracle-only tables (no base path is passed), so the
# labels say "Oracle" as in the other summary cells rather than "TD".
records = get_sdices(None, Path('oracle_results/focal_lovasz_adam_default_None'))
tmp = pd.DataFrame.from_records(records, index=' ')
print("Adam Oracle mean: ", tmp.mean().mean())

records = get_sdices(None, Path('oracle_results/focal_lovasz_adam_rand_aug_default'))
tmp = pd.DataFrame.from_records(records, index=' ')
print("Adam AUG Oracle mean: ", tmp.mean().mean())



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of df3 on the current axes (no explicit figure, so default size).
ax = sns.heatmap(df3[df3.index], vmin=0.5, vmax=1.0, annot=True)
ax.set_title("Old Baseline NoAug")
plt.show()


In [None]:
# Element-wise difference df2 - df (aligned on labels), columns ordered like
# df2's index. Which tables these are depends on the cells executed before.
df2.sub(df)[df2.index]

In [None]:
# Mean of the per-column means, spelled like the other summary cells.
# np.mean(DataFrame) reduces over both axes on pandas >= 2.0, so the nested
# np.mean form gave a version-dependent result when cells are missing.
df2.mean().mean()

In [None]:
# NOTE(review): hard-coded values with no visible provenance - presumably mean
# scores pasted from earlier runs; confirm which experiments they belong to.
# Only the last value is displayed as the cell output.
0.8055838119624629
0.7694199297374601