In [1]:
from spottunet.dataset.cc359 import *
from spottunet.split import one2one
from models.spottune_unet import UNet2D
from models.resnet import resnet
from spottunet.utils import sdice
from dpipe.im.metrics import dice_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image

from monai import transforms as T
from monai.transforms import Compose, apply_transform
from fastprogress.fastprogress import master_bar, progress_bar


import json
import nibabel as nib
import pandas as pd
import numpy as np
from scipy import ndimage
from dpipe.im.shape_ops import zoom
import cv2
import os
import gc
from collections import defaultdict
from pathlib import Path
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

DEBUG=False

In [2]:
import wandb
from configs.config_spottune import CFG
from utils import *

def class2dict(f):
    return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

from dataset.dataloader import *
from dataset.loader import *

In [3]:
from trainer_spottune import SpotTuneTrainer
from dataset.dataloader import *
from dataset.loader import *
from dataset.dataloader_utils import *

In [4]:
if DEBUG:
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")
    seed = 0xBadCafe
    pretrained = True
    n_first_exclude = 5
    n_exps = 30
    split = one2one(cc359_df, val_size=CFG.val_size, n_add_ids=CFG.n_add_ids,
                train_on_add_only=pretrained, seed=seed)[n_first_exclude:n_exps]
    train_df = cc359_df.iloc[split[0][0]].reset_index()

    sa_x,sa_y = create_shared_arrays(CFG,train_df,root_dir=CFG.dataset_path)
    train_dataset = CC359_Dataset(CFG,df=train_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=tfms,
                                  mode="train", cache=True, cached_x=sa_x, cached_y=sa_y)
    for x,y in train_dataset:
        plt.imshow(x.squeeze(), "gray")
        plt.show()
        plt.imshow(y)
        plt.show()

In [5]:

def run_fold(fold):
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)
    #wandb.tensorboard.patch(root_logdir=result_dir+"/logs")
    """run = wandb.init(project="domain_shift",
                     group=CFG.model_name,
                     name=f"mode_{str(fold)}",
                     job_type="rand_slice3",
                     config=class2dict(CFG),
                     reinit=True,
                     sync_tensorboard=True)"""
    
    writer = SummaryWriter(log_dir=result_dir+"/logs")
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")
    

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    freeze_model_spottune(model)
    model.to(CFG.device)
    
    model_policy = resnet(num_class=64)
    model_policy.to(CFG.device)
    
    
    
    optimizer_main = CFG.optim(model.parameters(),lr=CFG.lr,weight_decay=CFG.wd)
    scheduler_main = CFG.scheduler(optimizer_main, lr_lambda=lambda epoch: CFG.scheduler_multi_lr_fact )
    
    optimizer_policy = CFG.optim(model.parameters(),lr=CFG.lr,weight_decay=CFG.wd)
    scheduler_policy = CFG.scheduler(optimizer_policy, lr_lambda=lambda epoch: CFG.scheduler_multi_lr_fact )
    
    criterion = CFG.crit

    seed = 0xBadCafe
    pretrained = True
    n_first_exclude = 5
    n_exps = 30
    split = one2one(cc359_df, val_size=CFG.val_size, n_add_ids=CFG.n_add_ids,
                train_on_add_only=pretrained, seed=seed)[n_first_exclude:n_exps]
    
    baseline_exp_path = "/media/mlk/New Volume/Lab/domain_shift_anatomy/dart_results/baseline"
    load_model_state_fold_wise(architecture=model, baseline_exp_path=baseline_exp_path, exp=fold,
                               modify_state_fn=modify_state_fn_spottune, n_folds=len(cc359_df.fold.unique()),
                               n_first_exclude=n_first_exclude),

    train_df = cc359_df.iloc[split[fold][0]].reset_index()
    valid_df = cc359_df.iloc[split[fold][1]].reset_index()
    test_df  = cc359_df.iloc[split[fold][2]].reset_index()

    print("Caching Train Data ...")
    
    sa_x,sa_y = create_shared_arrays(CFG,train_df,root_dir=CFG.dataset_path)
    train_dataset = CC359_Dataset(CFG,df=train_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,transforms=tfms,
                                  mode="train", cache=True, cached_x=sa_x, cached_y=sa_y)
    
    valid_dataset = CC359_Dataset(CFG,df=valid_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,
                                  transforms=None,mode="val", cache=False)
    test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,
                                  transforms=None,mode="test", cache=False)
    
    train_loader = DataLoader(train_dataset,
                                              batch_size=CFG.bs,
                                              shuffle=True,
                                              num_workers=CFG.num_workers,
                                              sampler=None,
                                              collate_fn=fast_collate,
                                              pin_memory=False)
    valid_loader = DataLoader(valid_dataset, 
                              batch_size=1,shuffle=False,
                              num_workers=1,pin_memory=True)
    test_dataloader = DataLoader(test_dataset, 
                                  batch_size=1,shuffle=False,
                                  num_workers=1,pin_memory=False)

    """from torch_lr_finder import LRFinder
    lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
    lr_finder.range_test(train_loader, end_lr=100, num_iter=100)
    lr_finder.plot() # to inspect the loss-learning rate graph
    lr_finder.reset()"""
    
    trainer = SpotTuneTrainer(CFG,
                      model,
                      model_policy,
                      CFG.device, 
                      optimizer_main,
                      scheduler_main,
                      optimizer_policy,
                      scheduler_policy,
                      criterion,writer,fold,
                      CFG.max_norm,
                      CFG.temperature,CFG.k_reg,CFG.reg_mode)
    
    history = trainer.fit(
            CFG.epochs, 
            train_loader, 
            valid_loader, 
            f"{result_dir}/", 
            CFG.epochs,
        )
    
    
    del train_loader
    del valid_loader
    del train_dataset
    del valid_dataset
    gc.collect()
    torch.cuda.empty_cache()
    
    trainer.test(test_dataloader,result_dir)
    td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
    #wandb.log(td_sdice)
    writer.close()
    #run.finish()
    del trainer
    gc.collect()
    

In [6]:
for fold in [0,1,2,3,4]:
    run_fold(fold)

Caching Train Data ...


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,0.1615,0.9665,0.9927,0.2159,0.8668,0.8036,0.0002
1,0.9432,0.8903,0.9762,0.2185,0.8694,0.8033,0.0002
2,0.8695,0.9039,0.981,0.2121,0.8749,0.8106,0.0002
3,1.5296,0.7949,0.9247,0.2288,0.8682,0.7953,0.0002
4,0.171,0.9646,0.9948,0.2146,0.8771,0.8069,0.0002
5,0.151,0.9665,0.9925,0.198,0.8874,0.8196,0.0002
6,1.1718,0.8615,0.9583,0.202,0.8868,0.8174,0.0002
7,0.4655,0.9296,0.9558,0.2052,0.8863,0.8155,0.0001
8,0.7557,0.9195,0.9599,0.202,0.8891,0.819,0.0001
9,0.8189,0.8967,0.9336,0.1884,0.8968,0.8298,0.0001


Caching Train Data ...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,0.1264,0.974,0.9931,1.1989,0.336,0.4521,0.0002
1,0.6494,0.9037,0.971,1.1377,0.3647,0.479,0.0002
2,1.0922,0.8873,0.9568,1.065,0.3998,0.514,0.0002
3,0.7894,0.8645,0.951,1.0626,0.4107,0.5211,0.0002
4,1.0519,0.8892,0.9573,0.9958,0.4442,0.5513,0.0002
5,0.5135,0.9183,0.9898,0.9502,0.4639,0.5678,0.0002
6,0.627,0.8987,0.9045,1.0721,0.4391,0.5408,0.0002
7,1.179,0.8883,0.9665,1.0944,0.4505,0.5479,0.0001
8,0.5043,0.9145,0.9498,1.348,0.4076,0.4951,0.0001
9,0.7207,0.8678,0.8942,1.6871,0.3614,0.4372,0.0001


Caching Train Data ...


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,1.0668,0.8915,0.9812,0.0713,0.9649,0.9511,0.0002
1,1.2513,0.8982,0.9891,0.0732,0.9643,0.9469,0.0002
2,1.5657,0.8725,0.9656,0.0758,0.9629,0.9426,0.0002
3,0.0854,0.9854,0.9982,0.0755,0.9643,0.9432,0.0002
4,0.0856,0.9855,0.9973,0.0756,0.9654,0.9434,0.0002
5,1.333,0.8904,0.9274,0.0817,0.9622,0.9355,0.0002
6,0.1253,0.9786,0.9965,0.0845,0.9598,0.9323,0.0002
7,1.476,0.8721,0.9741,0.0921,0.9544,0.9249,0.0001
8,0.1287,0.9769,0.9923,0.0898,0.9569,0.9274,0.0001
9,0.0936,0.9839,0.9974,0.0856,0.96,0.9314,0.0001




Caching Train Data ...


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,1.3007,0.9036,0.9624,0.2036,0.7275,0.7733,0.0002
1,0.2723,0.9589,0.94,0.1656,0.7612,0.8049,0.0002
2,0.7187,0.9365,0.986,0.158,0.7684,0.8096,0.0002
3,1.2749,0.9063,0.9897,0.1267,0.8119,0.8421,0.0002
4,1.2707,0.8915,0.9548,0.1241,0.8209,0.8483,0.0002
5,1.1978,0.9117,0.9776,0.1281,0.8102,0.8431,0.0002
6,0.6234,0.9479,0.995,0.1185,0.8313,0.8585,0.0002
7,0.0656,0.99,0.9985,0.0912,0.89,0.8993,0.0001
8,0.6856,0.9355,0.9903,0.087,0.9021,0.9069,0.0001
9,1.4645,0.8971,0.9533,0.0764,0.9255,0.9275,0.0001




Caching Train Data ...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch,train/loss,train/dice,train/sdice,valid/loss,valid/dice,valid/sdice,LR
0,0.1544,0.9691,0.9857,0.2637,0.6708,0.772,0.0002
1,1.3427,0.7871,0.9952,0.2793,0.6611,0.7639,0.0002
2,0.5009,0.9304,0.9769,0.2687,0.6725,0.774,0.0002
3,0.1267,0.9747,0.9896,0.2503,0.687,0.7879,0.0002
4,1.5389,0.7557,0.7957,0.285,0.658,0.7645,0.0002
5,0.2696,0.9456,0.9722,0.323,0.6339,0.7437,0.0002
6,1.4702,0.762,0.976,0.359,0.6433,0.7312,0.0002
7,0.1171,0.9766,0.9943,0.3778,0.6283,0.7153,0.0001
8,0.1141,0.9774,0.9964,0.3965,0.6213,0.7043,0.0001
9,0.2189,0.9535,0.9733,0.4225,0.6141,0.6928,0.0001


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [7]:
def test_run(fold):
    result_dir = CFG.results_dir + "/mode_"+str(fold)
    os.makedirs(result_dir, exist_ok=True)
    
    cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")

    model = UNet2D(n_chans_in=CFG.n_chans_in, n_chans_out=CFG.n_chans_out, n_filters_init=CFG.n_filters)
    model.load_state_dict(torch.load(f"{result_dir}/mode_{fold}_best_epoch_model.pth")["model_state_dict"])    
    freeze_model_spottune(model)
    model.to(CFG.device)
    
    model_policy = resnet(num_class=64)
    model_policy.to(CFG.device)
    model_policy.load_state_dict(torch.load(f"{result_dir}/mode_{fold}_best_epoch_model.pth")["model_policy_state_dict"])  


    seed = 0xBadCafe
    pretrained = True
    n_first_exclude = 5
    n_exps = 30
    split = one2one(cc359_df, val_size=CFG.val_size, n_add_ids=CFG.n_add_ids,
                train_on_add_only=pretrained, seed=seed)[n_first_exclude:n_exps]

    test_df  = cc359_df.iloc[split[fold][2]].reset_index()

    test_dataset = CC359_Dataset(CFG,df=test_df,root_dir=CFG.dataset_path,
                                  voxel_spacing=CFG.voxel_spacing,
                                  transforms=None,mode="test", cache=False)
    trainer = SpotTuneTrainer(CFG,
                              model,
                              model_policy,
                              CFG.device, 
                              None,
                              None,
                              None,
                              None,
                              None,
                              None,fold,
                              CFG.max_norm,
                              CFG.temperature,CFG.k_reg,CFG.reg_mode)

    test_dataloader = DataLoader(test_dataset, 
                                  batch_size=1,shuffle=False,
                                  num_workers=1,pin_memory = False)

    trainer.test(test_dataloader, result_dir)
    td_sdice = get_target_domain_metrics(CFG.dataset_path,Path(CFG.results_dir),fold)
    print(td_sdice)
    return td_sdice

#test_run(0)

In [8]:
import json
with open("spottune_results/exp_01/mode_0/sdice_score.json","r") as f:
    data = json.load(f)
np.mean(list(data.values()))

0.9039291321106401

In [9]:
def get_stats_spottune(exp_path, fold='inference'):
    p = torch.load(Path(exp_path) / f'policy_{fold}_record/policy_record')
    f = open(Path(exp_path) / f'policy_{fold}_record/iter_record', 'r')
    n_iter = f.read()
    f.close()
    record = (p / int(n_iter)).detach().numpy()
    return record

get_stats_spottune("spottune_results/exp_03_tv_aug/mode_1")

FileNotFoundError: [Errno 2] No such file or directory: 'spottune_results/exp_03_tv_aug/mode_1/policy_inference_record/policy_record'

In [None]:
cc359_df = pd.read_csv(f"{CFG.dataset_path}/meta.csv",delimiter=",")
seed = 0xBadCafe
pretrained = True
n_first_exclude = 5
n_exps = 30
split = one2one(cc359_df, val_size=CFG.val_size, n_add_ids=CFG.n_add_ids,
            train_on_add_only=pretrained, seed=seed)[n_first_exclude:n_exps]


In [None]:

cc359_df.iloc[split[9][0]]

In [None]:

    Siemens15
Siemens3
    Ge15
    ge3
    philips15
    philips3

    siemens15
    siemens3
ge15
    ge3
    philips15
    philips3


In [24]:
import json
mode_domain = ["Siemens15", "Ge15", "Ge3", "philips15", "philips3"]
paths = ["exp_03_no_aug", "exp_03_rrc", "exp_03_gaussianblur", "exp_03_ssi_48_nid_1_gamma"]

results = []
for path in paths:
    mode_dict = {}
    modes = [0,1,2,3,4]
    for mode in modes:
        
        with open(f"spottune_results/{path}/mode_{str(mode)}/sdice_score.json","r") as f:
            data = json.load(f)
            mode_dict[mode_domain[mode]] = np.mean(list(data.values()))
    mode_dict[' '] = path
    results.append(mode_dict)
results.append({"Siemens15": 0.849, "Ge15": 0.937, "Ge3": 0.422, "philips15": 0.743, "philips3": 0.644, " ": "baseline"})

In [25]:
df = pd.DataFrame.from_records(results, index=' ')
df

Unnamed: 0,Siemens15,Ge15,Ge3,philips15,philips3
,,,,,
exp_03_no_aug,0.913258,0.818775,0.935291,0.932747,0.822767
exp_03_rrc,0.913521,0.892594,0.930822,0.934765,0.834499
exp_03_gaussianblur,0.908799,0.88035,0.935382,0.930348,0.829545
exp_03_ssi_48_nid_1_gamma,0.909606,0.871523,0.922274,0.935277,0.81267
baseline,0.849,0.937,0.422,0.743,0.644


In [33]:
np.mean(df.iloc[0:-1] - df.iloc[-1], axis=1)

 
exp_03_no_aug                0.165568
exp_03_rrc                   0.182240
exp_03_gaussianblur          0.177885
exp_03_ssi_48_nid_1_gamma    0.171270
dtype: float64