In [1]:
%reload_ext autoreload
%autoreload 2

from spottunet.dataset.cc359 import *
from spottunet.split import one2all
from spottunet.utils import sdice
from dpipe.im.metrics import dice_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image

from monai import transforms as T
from monai.transforms import Compose, apply_transform
from fastprogress.fastprogress import master_bar, progress_bar

import json
import nibabel as nib
import pandas as pd
import numpy as np
from scipy import ndimage
from dpipe.im.shape_ops import zoom
import cv2
import os
import gc
from collections import defaultdict
from pathlib import Path
from monai.networks.nets import UNETR, UNet

import matplotlib.pyplot as plt


# Pin all subsequent CUDA work in this kernel to the first GPU.
torch.cuda.set_device(torch.device('cuda:0'))

In [2]:
# Sanity check: report the active CUDA device index (selected via set_device('cuda:0') in the first cell).
torch.cuda.current_device()

0

In [3]:
import wandb
from configs.volume_config import CFG
from utils import *

def class2dict(f):
    """Return ``{name: value}`` for every attribute of ``f`` listed by ``dir``,
    excluding names that start with a double underscore.

    In this notebook it turns the ``CFG`` class into the ``config`` dict passed to wandb.
    """
    return {attr: getattr(f, attr) for attr in dir(f) if not attr.startswith('__')}

def class2str(f):
    """Return ``[name, value]`` pairs for the attributes of ``f`` in ``dir`` order,
    excluding names that start with a double underscore.

    Despite the name, the result is a list of two-item lists, not a string;
    ``write_config`` is what formats the pairs as text.
    """
    pairs = []
    for attr in dir(f):
        if attr.startswith('__'):
            continue
        pairs.append([attr, getattr(f, attr)])
    return pairs

def write_config(CFG):
    """Write every non-dunder attribute of ``CFG`` to ``{CFG.results_dir}/config.txt``.

    Each attribute becomes one ``name=value`` line (with a trailing space before
    the newline), in the order produced by ``class2str``. The results directory
    is created if it does not exist yet, so the function also works standalone.
    """
    os.makedirs(CFG.results_dir, exist_ok=True)
    # The `with` block closes the file; no explicit close() is needed.
    with open(f"{CFG.results_dir}/config.txt", "w") as f:
        for n, v in class2str(CFG):
            f.write(f"{n}={v} \n")

In [4]:

from dataset.dataloader_3d import *
from dataset.loader import PrefetchLoader, fast_collate
from dataset.dataloader_utils import *

from dataset.mixup import FastCollateMixup
from dataset.augment import get_transforms, get_test_transforms

from volume_trainer import Trainer
from dataset.loader import *

from scheduler import LinearWarmupCosineAnnealingLR

In [5]:
def _build_collate_fn():
    """Return the DataLoader collate function selected by ``CFG.mixup``.

    A ``FastCollateMixup`` built from the CFG mixup/cutmix settings when
    ``CFG.mixup`` is truthy, otherwise ``fast_collate``.
    """
    if CFG.mixup:
        mixup_args = dict(
            mixup_alpha=CFG.mixup, cutmix_alpha=CFG.cutmix, cutmix_minmax=None,
            prob=CFG.mixup_prob, switch_prob=CFG.mixup_switch_prob, mode='batch',
            label_smoothing=CFG.smoothing, num_classes=2)
        return FastCollateMixup(**mixup_args)
    # Plain fallback: the former `elif fast_collate:` was always true for a
    # function object and left a latent path where collate_fn was never bound.
    return fast_collate


def _build_scheduler(optimizer, train_loader):
    """Create the LR scheduler selected by ``CFG.scheduler``.

    Supported values:
      * ``torch.optim.lr_scheduler.OneCycleLR``: ``len(train_loader) // CFG.accumulation_steps``
        steps per epoch for ``CFG.epochs`` epochs, peaking at ``CFG.lr``.
      * ``"warmup_cosine"``: ``LinearWarmupCosineAnnealingLR`` with
        ``CFG.warmup_epochs`` warmup epochs out of ``CFG.epochs``.

    Raises:
        ValueError: for any other ``CFG.scheduler`` (previously this surfaced later
            as an UnboundLocalError when the Trainer was built).
    """
    if CFG.scheduler == torch.optim.lr_scheduler.OneCycleLR:
        steps_per_epoch = len(train_loader) // CFG.accumulation_steps
        return CFG.scheduler(optimizer, max_lr=CFG.lr,
                             steps_per_epoch=steps_per_epoch,
                             epochs=CFG.epochs)
    if CFG.scheduler == "warmup_cosine":
        return LinearWarmupCosineAnnealingLR(optimizer,
                                             warmup_epochs=CFG.warmup_epochs,
                                             max_epochs=CFG.epochs)
    raise ValueError(f"Unsupported CFG.scheduler: {CFG.scheduler!r}")


def _train_fold(fold, result_dir, writer):
    """Build data, model and optimizer for split ``fold``, then train, test and log metrics.

    All heavyweight objects (datasets, loaders, model, trainer) are locals of this
    function, so they become unreachable as soon as it returns.
    """
    write_config(CFG)
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv", delimiter=",")

    # When CFG.mixup is set, mixup is presumably applied by FastCollateMixup at
    # collate time; the Trainer's mixup_fn stays None, as before.
    mixup_fn = None
    collate_fn = _build_collate_fn()

    val_size = 4
    n_experiments = len(cc359_df.fold.unique())
    split = one2all(df=cc359_df, val_size=val_size)[:n_experiments]
    train_df = cc359_df.iloc[split[fold][0]].reset_index()
    valid_df = cc359_df.iloc[split[fold][1]].reset_index()
    test_df = cc359_df.iloc[split[fold][2]].reset_index()

    train_sa_x, train_sa_y = None, None
    if CFG.cache:
        print("Caching Train Data ...")
        train_sa_x, train_sa_y = create_3d_shared_arrays(CFG, train_df, root_dir=CFG.dataset_path)
        # Only the training volumes are cached: valid_dataset below is built with
        # cache=False and takes no cached arrays.

    train_dataset = CC359_Dataset(CFG, df=train_df, root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing, transforms=None,
                                  mode="train", cache=CFG.cache,
                                  cached_x=train_sa_x, cached_y=train_sa_y)
    valid_dataset = CC359_Dataset(CFG, df=valid_df, root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing, transforms=None,
                                  mode="val", cache=False)
    test_dataset = CC359_Dataset(CFG, df=test_df, root_dir=CFG.dataset_path,
                                 voxel_spacing=CFG.voxel_spacing, transforms=None,
                                 mode="test", cache=False)

    train_loader = PrefetchLoader(DataLoader(train_dataset,
                                             batch_size=CFG.bs,
                                             shuffle=True,
                                             num_workers=CFG.num_workers,
                                             sampler=None,
                                             collate_fn=collate_fn,
                                             pin_memory=False,
                                             drop_last=True),
                                  fp16=True)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False,
                              num_workers=1, pin_memory=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                             num_workers=1, pin_memory=False)

    model = UNet(spatial_dims=3,
                 in_channels=1,
                 out_channels=1,
                 channels=(8, 16, 32, 64),
                 strides=(1, 1, 1))
    # Alternative backbone kept for reference:
    # model = UNETR(spatial_dims=3, in_channels=1, out_channels=1, img_size=(128, 128, 128),
    #               hidden_size=1024, num_heads=4, feature_size=32, norm_name='batch')
    model.to(CFG.device)

    optimizer = get_optimizer(model, optim=CFG.optim, lr=CFG.lr, weight_decay=CFG.wd)
    scheduler = _build_scheduler(optimizer, train_loader)

    trainer = Trainer(CFG,
                      model=model,
                      device=CFG.device,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      criterion=CFG.crit,
                      writer=writer,
                      fold=fold,
                      max_norm=CFG.max_norm,
                      mixup_fn=mixup_fn)

    # NOTE(review): CFG.epochs is passed as both the 1st and the 5th positional
    # argument, exactly as in the original call; Trainer.fit's signature is not
    # visible here - TODO confirm what the last argument means.
    trainer.fit(CFG.epochs, train_loader, valid_loader, f"{result_dir}/", CFG.epochs)
    trainer.test(test_loader, result_dir)

    td_sdice = get_target_domain_metrics(CFG.dataset_path, Path(CFG.results_dir), fold)
    wandb.log(td_sdice)


def run_fold(fold):
    """Run one full experiment (train, validate, test) for split index ``fold``.

    Outputs go to ``{CFG.results_dir}/mode_{fold}``; metrics are logged to wandb
    and TensorBoard. The TensorBoard writer and the wandb run are always closed,
    even if training raises or is interrupted, so the next fold starts cleanly.
    """
    result_dir = CFG.results_dir + "/mode_" + str(fold)
    os.makedirs(result_dir, exist_ok=True)
    run = wandb.init(project="domain_shift",
                     group=CFG.model_name,
                     name=f"mode_{str(fold)}",
                     job_type=f"baseline_{CFG.tfms}",
                     config=class2dict(CFG),
                     reinit=True,
                     sync_tensorboard=True)
    writer = None
    try:
        writer = SummaryWriter(log_dir=result_dir + "/logs")
        _train_fold(fold, result_dir, writer)
    finally:
        if writer is not None:
            writer.close()
        run.finish()
        # On a normal return _train_fold's locals (loaders, model, trainer) are
        # already unreachable, so this actually frees them before the next fold.
        gc.collect()
        torch.cuda.empty_cache()

In [6]:
# Run one independent experiment per configured split, sequentially.
for fold_id in CFG.fold:
    run_fold(fold_id)

[34m[1mwandb[0m: Currently logged in as: [33mmklasen[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Caching Train Data ...


Epoch,train/loss,train/dice,valid/loss,valid/dice,valid/sdice,LR
0,6.4837,0.399,6.5246,0.0643,0.0592,0.0003
1,5.8271,0.244,5.1743,0.0557,0.052,0.0005
2,4.2453,0.3392,3.5465,0.081,0.0814,0.0008
3,3.0327,0.5203,2.6271,0.1274,0.121,0.001
4,2.1807,0.6808,1.9135,0.1906,0.1594,0.001
5,1.6528,0.7671,1.5327,0.2382,0.2131,0.001
6,1.2544,0.8304,1.247,0.2646,0.2784,0.001
7,1.0555,0.8488,1.0797,0.3333,0.2567,0.001
8,0.83,0.8755,0.868,0.3492,0.2962,0.001
9,0.6856,0.8898,0.8107,0.3615,0.348,0.001


KeyboardInterrupt: 

#### Test: sanity-check the cached training data loader

In [None]:
# Rebuild one training split on its own (same one2all logic as run_fold) and
# cache its volumes in shared arrays for the loader smoke test below.
# The former unused `seed = 0xBadCafe` was dropped: nothing here consumed it.
TEST_FOLD = 0
val_size = 4
cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv", delimiter=",")
n_experiments = len(cc359_df.fold.unique())
split = one2all(df=cc359_df, val_size=val_size)[:n_experiments]
train_df = cc359_df.iloc[split[TEST_FOLD][0]].reset_index()

sa_x, sa_y = create_3d_shared_arrays(CFG, train_df, root_dir=CFG.dataset_path)

In [None]:
# Smoke test: wrap the cached training volumes (sa_x / sa_y from the previous
# cell) with the "rrc" transforms and DataLoader's default collate, then print
# the shapes of a single batch.
train_dataset = CC359_Dataset(
    CFG,
    df=train_df,
    root_dir=CFG.dataset_path,
    voxel_spacing=CFG.voxel_spacing,
    transforms=get_transforms("rrc"),
    mode="train",
    cache=True,
    cached_x=sa_x,
    cached_y=sa_y,
)
train_loader = PrefetchLoader(
    DataLoader(
        train_dataset,
        batch_size=CFG.bs,
        shuffle=True,
        num_workers=CFG.num_workers,
        sampler=None,
        collate_fn=None,
        pin_memory=False,
        drop_last=True,
    ),
    fp16=True,
)
for x, y in train_loader:
    print(x.size(), y.size())
    break