### TODOs

- [x] Inference Loop
- [ ] Oracle?
- [ ] Finetune
- [ ] Spottune
- [x] Voxelization NaN to num
- [x] Precompute 3rd order sample voxel spacing upfront


#### Augmentations / training tricks?
- [ ] LAMB optimizer
- [ ] Label smoothing
- [ ] Stochastic depth
- [ ] CutMix / MixUp
- [ ] Horizontal flip? (probably not)
- [ ] RandomResizedCrop
- [ ] RandAugment


In [4]:
from spottunet.dataset.cc359 import *
from spottunet.split import one2all
from spottunet.torch.module.unet import UNet2D
from spottunet.utils import sdice
from dpipe.im.metrics import dice_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image

from monai import transforms as T
from monai.transforms import Compose, apply_transform
from fastprogress.fastprogress import master_bar, progress_bar

import json
import nibabel as nib
import pandas as pd
import numpy as np
from scipy import ndimage
from dpipe.im.shape_ops import zoom
import cv2
import os
import gc
from collections import defaultdict
from pathlib import Path
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

### Config & Logging

In [5]:
import wandb
from config import CFG
from utils import *

def class2dict(f):
    """Collect the attributes of ``f`` into a plain ``{name: value}`` dict.

    Every name reported by ``dir(f)`` is included except those starting with
    a double underscore. Used below to hand the ``CFG`` settings to
    ``wandb.init`` as the run config.
    """
    kept_names = [attr for attr in dir(f) if not attr.startswith('__')]
    return {attr: getattr(f, attr) for attr in kept_names}



In [6]:
from dataset.dataloader import *
from dataset.loader import *

In [4]:
# Disabled scratch cell: the code below is wrapped in a string literal, so
# nothing in it is executed; the notebook only echoes the string back.
"""cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")
seed = 0xBadCafe
val_size = 4
n_experiments = len(cc359_df.fold.unique())
split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]
train_df = cc359_df.iloc[split[0][0]].reset_index()

train_dataset = CC359_Dataset(df=train_df,root_dir=CFG.dataset_path,
                              voxel_spacing=CFG.voxel_spacing,transforms=tfms,
                              mode="train")"""

'cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")\nseed = 0xBadCafe\nval_size = 4\nn_experiments = len(cc359_df.fold.unique())\nsplit = one2all(df=cc359_df,val_size=val_size)[:n_experiments]\ntrain_df = cc359_df.iloc[split[0][0]].reset_index()\n\ntrain_dataset = CC359_Dataset(df=train_df,root_dir=CFG.dataset_path,\n                              voxel_spacing=CFG.voxel_spacing,transforms=tfms,\n                              mode="train")'

### Training

In [7]:
from trainer import Trainer

In [6]:
from loader import *
def run_fold(fold):
    """Train, validate and test a ``UNet2D`` on one fold of the CC359 split.

    Steps, in order:
      1. create ``<CFG.results_dir>/mode_<fold>`` and start a W&B run plus a
         TensorBoard writer logging into it;
      2. build model, optimizer, scheduler and loss from ``CFG``;
      3. split ``meta.csv`` with ``one2all`` and build the train / valid /
         test datasets and loaders for ``fold``;
      4. run ``Trainer.fit`` followed by ``Trainer.test``;
      5. log the result of ``get_target_domain_metrics`` to W&B.

    The TensorBoard writer and the W&B run are always closed, even when
    training is interrupted or raises; in that case the W&B run is finished
    with a non-zero exit code so it is not reported as successful.

    Args:
        fold: index into the list of splits returned by ``one2all``; also
            used to name the result directory and the W&B run.

    Returns:
        None. Results are produced as side effects by the trainer and the
        loggers.
    """
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)
    #wandb.tensorboard.patch(root_logdir=result_dir+"/logs")
    run = wandb.init(project="domain_shift",
                     group=CFG.model_name,
                     name=f"mode_{str(fold)}",
                     job_type="rand_slice3",
                     config=class2dict(CFG),
                     reinit=True,
                     sync_tensorboard=True)

    writer = None
    completed = False
    try:
        writer = SummaryWriter(log_dir=result_dir+"/logs")
        cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

        #model = smp.Unet(encoder_name="resnet50", encoder_weights="swsl", in_channels=CFG.n_chans_in,classes=CFG.n_chans_out)
        model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
        model.to(CFG.device)
        optimizer = CFG.optim(model.parameters(),lr=CFG.lr,weight_decay=CFG.wd)
        #optimizer = CFG.optim(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd, betas=(.9, .999), adam=False)
        # The lambda ignores the epoch and always returns the same factor.
        scheduler = CFG.scheduler(optimizer, lr_lambda=lambda epoch: CFG.scheduler_multi_lr_fact )
        criterion = CFG.crit

        val_size = 4
        # Keep one split per distinct value of the `fold` column.
        n_experiments = len(cc359_df.fold.unique())
        split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]

        # Each split entry holds (train, valid, test) row positions.
        train_df = cc359_df.iloc[split[fold][0]].reset_index()
        valid_df = cc359_df.iloc[split[fold][1]].reset_index()
        test_df  = cc359_df.iloc[split[fold][2]].reset_index()

        print("Caching Train Data ...")
        train_dataset = CC359_Dataset(df=train_df,root_dir=CFG.dataset_path,
                                      voxel_spacing=CFG.voxel_spacing,transforms=tfms,
                                      mode="train")
        print("Caching Valid Data ...")
        valid_dataset = CC359_Dataset(df=valid_df,root_dir=CFG.dataset_path,
                                      voxel_spacing=CFG.voxel_spacing,transforms=None,
                                      mode="val")
        test_dataset = CC359_Dataset(df=test_df,root_dir=CFG.dataset_path,
                                      voxel_spacing=CFG.voxel_spacing,transforms=None,
                                      mode="test", cache=False)

        train_loader = PrefetchLoader(DataLoader(train_dataset,
                                                  batch_size=CFG.bs,
                                                  shuffle=True,
                                                  num_workers=CFG.num_workers,
                                                  sampler=None,
                                                  collate_fn=fast_collate,
                                                  pin_memory=False,
                                                  drop_last=True),
                                        fp16=CFG.fp16)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.bs,shuffle=False,
                                  num_workers=CFG.num_workers,pin_memory = False)
        test_dataloader = DataLoader(test_dataset,
                                      batch_size=1,shuffle=False,
                                      num_workers=1,pin_memory = False)

        # Disabled learning-rate range test, kept for reference:
        # from torch_lr_finder import LRFinder
        # lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
        # lr_finder.range_test(train_loader, end_lr=100, num_iter=100)
        # lr_finder.plot() # to inspect the loss-learning rate graph
        # lr_finder.reset()

        trainer = Trainer(model,
                          CFG.device,
                          optimizer,
                          scheduler,
                          criterion,writer)

        # NOTE(review): CFG.epochs is passed twice (first and last argument);
        # confirm both positions against Trainer.fit's signature.
        history = trainer.fit(
                CFG.epochs,
                train_loader,
                valid_loader,
                f"{CFG.exp_name}/mode_{str(fold)}/",
                CFG.epochs,
            )
        trainer.test(test_dataloader,result_dir)
        td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
        wandb.log(td_sdice)
        completed = True
    finally:
        # Close the writer before finishing the run, as the original
        # success path did, and do it on failures/interrupts as well.
        if writer is not None:
            writer.close()
        if completed:
            run.finish()
        else:
            run.finish(exit_code=1)

    # Drop the large objects explicitly before collecting, so their memory
    # is released before the next fold starts.
    del trainer
    del train_loader
    del valid_loader
    del train_dataset
    del valid_dataset
    gc.collect()

In [7]:
# Train and evaluate every fold listed in CFG.fold, one after another.
for fold in CFG.fold:
    run_fold(fold)

[34m[1mwandb[0m: Currently logged in as: [33mmklasen[0m (use `wandb login --relogin` to force relogin)


Caching Train Data ...
(12000, 256, 256)


Caching Valid Data ...
(1000, 256, 256)


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,0.5618,0.6702,0.7167,0.561,0.0028,0.0043,0.0002
1,0.4972,0.8335,0.874,0.5484,0.0,0.0003,0.0002
2,0.4802,0.8477,0.8884,0.6027,0.1556,0.1475,0.0002
3,0.4636,0.8625,0.9023,0.5252,0.0,0.0001,0.0002
4,0.4505,0.8671,0.9058,0.5162,0.0,0.0,0.0002


improved from inf to 0.5610. Saved model to 'baseline5/mode_0/e0-loss0.561.pth'
improved from 0.5610 to 0.5484. Saved model to 'baseline5/mode_0/e1-loss0.548.pth'
improved from 0.5484 to 0.5252. Saved model to 'baseline5/mode_0/e3-loss0.525.pth'
improved from 0.5252 to 0.5162. Saved model to 'baseline5/mode_0/e4-loss0.516.pth'


KeyboardInterrupt: 

In [None]:
%debug

In [None]:
def test_run(fold):
    """Run test-only inference for one fold from its saved best checkpoint.

    Loads ``<CFG.results_dir>/mode_<fold>/mode_<fold>_best_model.pth``,
    rebuilds the test part of the ``one2all`` split for ``fold`` and calls
    ``Trainer.test`` on it, passing the fold's result directory.

    Args:
        fold: index into the list of splits returned by ``one2all``; also
            selects the result directory and checkpoint file.

    Returns:
        None.
    """
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)

    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

    # NOTE(review): run_fold currently trains a UNet2D while this builds an
    # smp.Unet; the architecture here must match the one that produced the
    # checkpoint, otherwise load_state_dict will fail. Confirm before use.
    #model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    model = smp.Unet(encoder_name="resnet50", encoder_weights="swsl", in_channels=CFG.n_chans_in,classes=CFG.n_chans_out)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded directly onto the configured device.
    checkpoint = torch.load(f"{result_dir}/mode_{fold}_best_model.pth",
                            map_location=CFG.device)
    model.load_state_dict(checkpoint["model_state_dict"])

    model.to(CFG.device)

    val_size = 4
    # Keep one split per distinct value of the `fold` column.
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df,val_size=val_size)[:n_experiments]

    test_df  = cc359_df.iloc[split[fold][2]].reset_index()

    # No training transforms at test time, consistent with the test dataset
    # built in run_fold.
    test_dataset = CC359_Dataset(df=test_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=None,
                                  mode="test", cache=False)
    # Only inference is needed, so optimizer, scheduler, criterion and
    # writer are passed as None.
    trainer = Trainer(model,CFG.device, None,None,None,None)

    test_dataloader = DataLoader(test_dataset,
                                  batch_size=1,shuffle=False,
                                  num_workers=1,pin_memory = False)

    trainer.test(test_dataloader, result_dir)
    
# Re-run test-only inference for every fold listed in CFG.fold.
for fold in CFG.fold:
    test_run(fold)

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

from dpipe.io import load

# Build a source-fold x target-fold table of mean surface-Dice scores from
# the per-fold `sdice_score.json` files of one experiment.

# Experiment directory holding one `mode_<fold>` sub-directory per fold.
# Must be a Path: the `/` joins below raise TypeError on a plain str.
path_base = Path('baseline_2/')

meta = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",", index_col='id')
meta.head()

# The set of folds does not change inside the loop, so compute it once.
folds = sorted(meta['fold'].unique())

records = []
for s in folds:
    res_row = {}

    # one2all results:
    # the loaded object is used as a mapping from ids in `meta`'s index to
    # scores.
    sdices = load(path_base / f'mode_{s}/sdice_score.json')
    #sdices = dict(sorted(sdices.items()))
    for t in sorted(set(folds) - {s}):
        # A fold is labelled by the scanner model and field strength of its
        # first row.
        df_row = meta[meta['fold'] == t].iloc[0]
        target_name = df_row['tomograph_model'] + str(df_row['tesla_value'])

        ids_t = meta[meta['fold'] == t].index
        # Mean score over the ids that belong to target fold `t`.
        res_row[target_name] = np.mean([sdsc for _id, sdsc in sdices.items() if _id in ids_t])
    print(res_row)
    df_row = meta[meta['fold'] == s].iloc[0]
    source_name = df_row['tomograph_model'] + str(df_row['tesla_value'])
    # Source-domain (diagonal) aggregation is currently disabled; `sdices`
    # is reset here as it was for that disabled code.
    sdices = {}
    #for n_val in range(3):
    #    sdices = {**sdices,
    #              **load(path_base / f'mode_{s}/sdice_score.json')}
    #res_row[source_name] = np.mean(list(sdices.values()))

    # The blank-named column carries the source label and becomes the index.
    res_row[' '] = source_name
    records.append(res_row)
df = pd.DataFrame.from_records(records, index=' ')
# Order the columns like the rows so sources and targets line up.
df[df.index]