In [2]:
from spottunet.dataset.cc359 import *
from spottunet.split import one2all, single_cv
from spottunet.torch.module.unet import UNet2D
from spottunet.utils import sdice
from dpipe.im.metrics import dice_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image

from monai import transforms as T
from monai.transforms import Compose, apply_transform
from fastprogress.fastprogress import master_bar, progress_bar

import json
import nibabel as nib
import pandas as pd
import numpy as np
from scipy import ndimage
from dpipe.im.shape_ops import zoom
import cv2
import os
import gc
from collections import defaultdict
from pathlib import Path
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

from trainer import Trainer
from dataset.loader import *

import wandb
from configs.config import CFG
from utils import *

def class2dict(f):
    """Map every attribute name of ``f`` listed by ``dir`` (excluding names that
    start with ``'__'``) to the attribute's current value."""
    return {attr: getattr(f, attr) for attr in dir(f) if not attr.startswith('__')}

def class2str(f):
    """Return ``[name, value]`` pairs for the attributes of ``f`` listed by ``dir``,
    skipping names that start with ``'__'``; order follows ``dir``."""
    pairs = []
    for attr in dir(f):
        if attr.startswith('__'):
            continue
        pairs.append([attr, getattr(f, attr)])
    return pairs

def write_config(CFG):
    """Dump the non-dunder attributes of ``CFG`` to ``<CFG.results_dir>/config.txt``.

    Writes one ``name=value `` line per attribute, in the order returned by
    ``class2str``. An existing file is overwritten.
    """
    path = os.path.join(CFG.results_dir, "config.txt")
    # The ``with`` statement closes the file on exit, including on error.
    with open(path, "w") as f:
        for n, v in class2str(CFG):
            f.write(f"{n}={v} \n")
    
%reload_ext autoreload
%autoreload 2
from dataset.dataloader import *
from dataset.loader import PrefetchLoader, fast_collate
from dataset.dataloader_utils import *

from dataset.mixup import FastCollateMixup
from dataset.augment import get_transforms, get_test_transforms

from scheduler import LinearWarmupCosineAnnealingLR

In [3]:
# Checkpoint directories of the baseline multi-view ensemble used by run_fold.
# models[i] is applied to the volume permuted by views[i]; run_fold treats
# index 0 / 1 / 2 as the top-down / front-back / side (left-right) view.
models = [
    "baseline_results/baseline_focal_lovasz_adam_rand_aug_default_v1",
    # NOTE(review): capital "Adam" differs from the sibling paths - confirm the directory name.
    "baseline_results/baseline_focal_lovasz_Adam_rand_aug_default_frontback",
    "baseline_results/baseline_focal_lovasz_adam_rand_aug_default_sideview",
]
views = [(2, 0, 1), (1, 0, 2), (0, 1, 2)]

In [None]:
def run_fold(fold, subject_ids=None):
    """Evaluate the baseline three-view 2D U-Net ensemble on one ``one2all`` fold.

    For every test volume, each checkpoint in the module-level ``models`` list
    is applied to the volume permuted by the matching entry of ``views``. The
    sigmoid probabilities are mapped back to a common orientation (the raw
    volume permuted by (2, 0, 1)), averaged, and thresholded at 0.5.
    Per-subject Surface Dice / Dice go to ``sdice_score.json`` /
    ``dice_score.json``. Each evaluated subject's ensemble mask goes to
    ``{_id}.npy``. Both are written under
    ``predictions/baseline_focal_lovasz_multiview_ensemble_v2/mode_{fold}/``.

    Args:
        fold: index into the ``one2all`` split list; also selects the
            ``mode_{fold}/e_39.pth`` checkpoint of each model.
        subject_ids: optional iterable of subject ids to restrict evaluation
            to (e.g. ``("CC0002",)`` for single-subject debugging). ``None``
            (default) evaluates the whole test split.

    Side effects:
        Sets ``CFG.transpose`` to the identity permutation and ``CFG.fcm_mask``
        to None on the global config.
    """
    CFG.transpose = (0, 1, 2)
    CFG.fcm_mask = None
    output_dir = os.path.join("predictions/baseline_focal_lovasz_multiview_ensemble_v2", f"mode_{fold}")
    os.makedirs(output_dir, exist_ok=True)

    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv", delimiter=",")
    val_size = 4
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df, val_size=val_size)[:n_experiments]
    test_df = cc359_df.iloc[split[fold][2]].reset_index()

    test_dataset = CC359_Dataset(CFG, df=test_df, root_dir=CFG.dataset_path,
                                 voxel_spacing=CFG.voxel_spacing, transforms=get_test_transforms(),
                                 mode="test", cache=False, td_specific_aug=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                 num_workers=0, pin_memory=False)

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    model.to(CFG.device)
    model.eval()

    # Checkpoints depend only on the fold, not on the subject. Read each one
    # once here instead of re-loading it from disk for every test volume.
    # map_location="cpu" lets GPU-saved checkpoints load anywhere;
    # load_state_dict copies the weights onto the model's device.
    state_dicts = [
        torch.load(f"{m}/mode_{fold}/e_39.pth", map_location="cpu")["model_state_dict"]
        for m in models
    ]

    wanted = None if subject_ids is None else set(subject_ids)
    results = defaultdict(dict)
    for x, y, _id in progress_bar(test_dataloader, len(test_dataloader)):
        _id = _id[0]
        if wanted is not None and _id not in wanted:
            continue

        volume = x.squeeze()
        view_probs = []
        with torch.no_grad():
            for state_dict, view in zip(state_dicts, views):
                model.load_state_dict(state_dict)
                # Slices along the first permuted axis are fed as a batch of 2D images.
                x_view = torch.permute(volume, view).to(CFG.device)
                probs = model(x_view.unsqueeze(1)).squeeze().sigmoid().cpu().numpy()
                view_probs.append(probs)

        # Bring every view into the top-down orientation (raw axes (2, 0, 1)).
        # These inverse permutations assume views == [(2,0,1), (1,0,2), (0,1,2)].
        td_view = view_probs[0]
        fb_view = np.transpose(view_probs[1], (2, 1, 0))
        lr_view = np.transpose(view_probs[2], (2, 0, 1))
        # Convert to NumPy before transposing. np.transpose on a torch tensor
        # only works through NumPy's __array_wrap__ fallback.
        target = np.transpose(y.squeeze().numpy(), (2, 0, 1)).astype(bool)

        total_outputs = np.mean([td_view, fb_view, lr_view], axis=0) > .5
        surface_dice = sdice(target, total_outputs, CFG.voxel_spacing, CFG.sdice_tolerance)
        results['sdice_score'][_id] = surface_dice
        results['dice_score'][_id] = dice_score(target, total_outputs)
        print(_id + str(surface_dice))
        # Persist the ensemble mask (top-down orientation), not a single view's mask.
        np.save(os.path.join(output_dir, f"{_id}.npy"), total_outputs)

    for metric in ('sdice_score', 'dice_score'):
        with open(os.path.join(output_dir, metric + '.json'), 'w') as f:
            json.dump(results[metric], f, indent=0)
        
        
# Evaluate each fold exactly once. The previous list [3, 1, 2, 3, 4, 5] ran
# fold 3 twice (overwriting its outputs) and skipped fold 0.
# NOTE(review): assumes six one2all folds (0-5), matching the highest index used.
for fold in [0, 1, 2, 3, 4, 5]:
    run_fold(fold)

In [None]:
# Checkpoint directories of the oracle multi-view ensemble; these replace the
# baseline lists for the run_fold defined below. models[i] is applied to the
# volume permuted by views[i] (same permutations as the baseline ensemble).
models = [
    "oracle_results/focal_lovasz_adam_rand_aug_default",
    "oracle_results/focal_lovasz_adam_rand_aug_default_frontview",
    "oracle_results/focal_lovasz_adam_rand_aug_default_sideview",
]
views = [(2, 0, 1), (1, 0, 2), (0, 1, 2)]

def run_fold(fold):
    """Evaluate the oracle three-view 2D U-Net ensemble on one ``single_cv`` fold.

    The split comes from ``single_cv(n_splits=3, val_size=2, seed=0xBadCafe)``.
    For every test volume, each checkpoint in the module-level ``models`` list
    is applied to the volume permuted by the matching entry of ``views``. The
    sigmoid probabilities are mapped back to a common orientation (the raw
    volume permuted by (2, 0, 1)), averaged, and thresholded at 0.5.
    Per-subject Surface Dice / Dice are written as ``sdice_score.json`` /
    ``dice_score.json`` under
    ``oracle_results/baseline_focal_lovasz_multiview_ensemble/mode_{fold}/``.

    Args:
        fold: index into the ``single_cv`` split list; also selects the
            ``mode_{fold}/e_39.pth`` checkpoint of each model.

    Side effects:
        Sets ``CFG.transpose`` to the identity permutation on the global config.
    """
    CFG.transpose = (0, 1, 2)
    output_dir = os.path.join("oracle_results/baseline_focal_lovasz_multiview_ensemble", f"mode_{fold}")
    os.makedirs(output_dir, exist_ok=True)

    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv", delimiter=",")
    seed = 0xBadCafe
    n_splits = 3
    val_size = 2
    split = single_cv(cc359_df, n_splits=n_splits, val_size=val_size, seed=seed)
    # Only the test partition (index 2) is needed for evaluation.
    test_df = cc359_df.iloc[split[fold][2]].reset_index()

    test_dataset = CC359_Dataset(CFG, df=test_df, root_dir=CFG.dataset_path,
                                 voxel_spacing=CFG.voxel_spacing, transforms=get_test_transforms(),
                                 mode="test", cache=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                 num_workers=0, pin_memory=False)

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    model.to(CFG.device)
    model.eval()

    # Checkpoints depend only on the fold: read each once rather than once per
    # test volume. map_location="cpu" lets GPU-saved checkpoints load anywhere;
    # load_state_dict copies the weights onto the model's device.
    state_dicts = [
        torch.load(f"{m}/mode_{fold}/e_39.pth", map_location="cpu")["model_state_dict"]
        for m in models
    ]

    results = defaultdict(dict)
    for x, y, _id in progress_bar(test_dataloader, len(test_dataloader)):
        _id = _id[0]

        volume = x.squeeze()
        view_probs = []
        with torch.no_grad():
            for state_dict, view in zip(state_dicts, views):
                model.load_state_dict(state_dict)
                # Slices along the first permuted axis are fed as a batch of 2D images.
                x_view = torch.permute(volume, view).to(CFG.device)
                probs = model(x_view.unsqueeze(1)).squeeze().sigmoid().cpu().numpy()
                view_probs.append(probs)

        # Bring every view into the top-down orientation (raw axes (2, 0, 1)).
        # These inverse permutations assume views == [(2,0,1), (1,0,2), (0,1,2)].
        td_view = view_probs[0]
        fb_view = np.transpose(view_probs[1], (2, 1, 0))
        lr_view = np.transpose(view_probs[2], (2, 0, 1))
        # Convert to NumPy before transposing. np.transpose on a torch tensor
        # only works through NumPy's __array_wrap__ fallback.
        target = np.transpose(y.squeeze().numpy(), (2, 0, 1)).astype(bool)

        total_outputs = np.mean([td_view, fb_view, lr_view], axis=0) > .5
        surface_dice = sdice(target, total_outputs, CFG.voxel_spacing, CFG.sdice_tolerance)
        results['sdice_score'][_id] = surface_dice
        results['dice_score'][_id] = dice_score(target, total_outputs)

    for metric in ('sdice_score', 'dice_score'):
        with open(os.path.join(output_dir, metric + '.json'), 'w') as f:
            json.dump(results[metric], f, indent=0)
        
# Oracle evaluation for folds 0..17.
# NOTE(review): assumed to be every split single_cv(n_splits=3) returns - confirm.
for fold in range(18):
    run_fold(fold)