In [14]:
import torch
import json

In [5]:
convex_dump_path = '/share/data_supergrover1/weihsbach/shared_data/tmp/curriculum_deeplab/moving_convex.json'
deeds_dump_path = '/share/data_supergrover1/weihsbach/shared_data/tmp/curriculum_deeplab/moving_deeds.json'


def _read_json_dump(dump_path):
    """Return the parsed content of the JSON file stored at `dump_path`."""
    with open(dump_path, 'r') as dump_file:
        return json.load(dump_file)


# Both dumps are parsed the same way and kept under separate names for the cells below.
moving_convex = _read_json_dump(convex_dump_path)
moving_deeds = _read_json_dump(deeds_dump_path)

In [12]:
# Show both id lists in full, then compare them as sets.
print(moving_convex)
print(moving_deeds)

convex_ids = set(moving_convex)
deeds_ids = set(moving_deeds)
union_set = convex_ids | deeds_ids
common_set = convex_ids & deeds_ids

# Same four report lines as before: tag followed by the number of entries.
for tag, collection in (
    ("union", union_set),
    ("common", common_set),
    ("deeds", moving_deeds),
    ("convex", moving_convex),
):
    print(tag, len(collection))

['m024r', 'm105l', 'm031l', 'm020l', 'm048r', 'm045r', 'm033r', 'm037r', 'm100r', 'm049l', 'm026l', 'm056l', 'm014l', 'm029r', 'm102l', 'm012r', 'm040r', 'm041l', 'm062l', 'm019r']
['m019r', 'm049l', 'm037r', 'm105l', 'm034l', 'm022l', 'm021l', 'm102l', 'm040r', 'm013l', 'm012r', 'm024r', 'm058l', 'm020l', 'm062l', 'm042l', 'm051l', 'm029r', 'm100r', 'm041l', 'm103l', 'm061l', 'm002l', 'm063l', 'm026l', 'm043l', 'm038l', 'm015l', 'm018l', 'm010l', 'm031l', 'm014l', 'm033r', 'm048r', 'm056l', 'm028l', 'm045r', 'm055l', 'm104l', 'm064l']
union 40
common 20
deeds 40
convex 20


In [18]:
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.13.5
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %%
import sys
import os
import time
import random
import re
import warnings
import glob
from meidic_vtach_utils.run_on_recommended_cuda import get_cuda_environ_vars as get_vars
os.environ.update(get_vars(select="* -4"))
import pickle
import copy
from pathlib import Path
from tqdm import tqdm
from collections import OrderedDict

import functools
from enum import Enum, auto

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
import scipy

import wandb
import matplotlib.pyplot as plt
from IPython.display import display
from sklearn.model_selection import KFold

from mdl_seg_class.metrics import dice3d, dice2d
from mdl_seg_class.visualization import visualize_seg

from curriculum_deeplab.mindssc import mindssc
from curriculum_deeplab.utils import interpolate_sample, in_notebook, dilate_label_class, LabelDisturbanceMode, ensure_dense
from curriculum_deeplab.CrossmodaHybridIdLoader import CrossmodaHybridIdLoader, get_crossmoda_data_load_closure
from curriculum_deeplab.MobileNet_LR_ASPP_3D import MobileNet_LRASPP_3D, MobileNet_ASPP_3D

# Report the torch / cuDNN versions and the name of CUDA device 0.
print(torch.__version__)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

# `in_notebook` is imported from curriculum_deeplab.utils. If it is a function, testing the bare
# name is always truthy and the script branch below could never run, so call it when callable.
# NOTE(review): confirm in curriculum_deeplab.utils that `in_notebook` takes no arguments.
_running_in_notebook = in_notebook() if callable(in_notebook) else bool(in_notebook)

if _running_in_notebook:
    # os.path.abspath('') resolves to the current working directory.
    THIS_SCRIPT_DIR = os.path.abspath('')
else:
    THIS_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
print(f"Running in: {THIS_SCRIPT_DIR}")

def reset_determinism():
    """Re-seed the torch, `random` and numpy random generators with the fixed seed 0."""
    fixed_seed = 0
    for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
        seed_rng(fixed_seed)
    # torch.use_deterministic_algorithms(True)
    
def get_batch_dice_per_class(b_dice, class_tags, exclude_bg=True) -> dict:
    """Average dice scores per class over the first axis, ignoring NaN entries.

    Args:
        b_dice: torch tensor that is indexed as `b_dice[:, class_index]`.
        class_tags: Iterable of class names; position i names column i of `b_dice`.
        exclude_bg: If True, class index 0 is skipped and gets no entry in the result.

    Returns:
        dict mapping each class tag to the NaN-ignoring mean of its column as a python
        float. A column consisting only of NaN values yields NaN.
    """
    # np.nanmean cannot consume CUDA tensors or tensors tracked by autograd directly,
    # so convert once to a plain numpy array up front.
    dice_np = b_dice.detach().cpu().numpy()

    score_dict = {}
    for cls_idx, cls_tag in enumerate(class_tags):
        if exclude_bg and cls_idx == 0:
            continue

        cls_dice = dice_np[:, cls_idx]
        if np.isnan(cls_dice).all():
            # Avoid numpy's "mean of empty slice" warning for all-NaN columns.
            score = float('nan')
        else:
            score = np.nanmean(cls_dice).item()

        score_dict[cls_tag] = score

    return score_dict

def get_batch_dice_over_all(b_dice, exclude_bg=True) -> float:
    """Average all dice scores of a batch into one number, ignoring NaN entries.

    Args:
        b_dice: torch tensor that is sliced as `b_dice[:, start:]` along its second axis.
        exclude_bg: If True, column 0 is left out of the average.

    Returns:
        NaN-ignoring mean over the selected columns as a python float; NaN if every
        selected entry is NaN (or nothing is selected).
    """
    start_idx = 1 if exclude_bg else 0
    # np.nanmean cannot consume CUDA tensors or tensors tracked by autograd directly,
    # so convert to a plain numpy array first.
    selected = b_dice.detach().cpu().numpy()[:, start_idx:]
    if np.isnan(selected).all():
        return float('nan')
    return np.nanmean(selected).item()



def get_2d_stack_batch_size(b_input_size: torch.Size, stack_dim):
    """Return the batch length obtained when a 5D input is stacked along `stack_dim`.

    Args:
        b_input_size: Size with exactly five entries, read as BxCxDxHxW.
        stack_dim: "D", "H" or "W"; selects entry 2, 3 or 4 of `b_input_size`.

    Returns:
        Product of entry 0 and the entry selected by `stack_dim`.

    Raises:
        ValueError: If `stack_dim` is none of "D", "H", "W".
    """
    assert len(b_input_size) == 5, f"Input size must be 5D: BxCxDxHxW but is {b_input_size}"

    batch_len = b_input_size[0]
    for dim_tag, dim_pos in (("D", 2), ("H", 3), ("W", 4)):
        if stack_dim == dim_tag:
            return batch_len*b_input_size[dim_pos]

    raise ValueError(f"stack_dim '{stack_dim}' must be 'D' or 'H' or 'W'.")



def make_2d_stack_from_3d(b_input, stack_dim):
    """Fold one spatial axis of a 5D tensor (BxCxDxHxW) into the batch axis.

    Args:
        b_input: 5D tensor.
        stack_dim: "D", "H" or "W"; the spatial axis that is merged into the batch axis.

    Returns:
        4D tensor whose first axis has length B*<stacked axis>; the channel axis and the
        two remaining spatial axes follow in their original order.

    Raises:
        ValueError: If `stack_dim` is none of "D", "H", "W".
    """
    assert b_input.dim() == 5, f"Input must be 5D: BxCxDxHxW but is {b_input.shape}"
    n_batch, n_chan, depth, height, width = b_input.shape

    # Move the stacked axis right behind the batch axis, then merge both by reshaping.
    if stack_dim == "D":
        moved = b_input.permute(0, 2, 1, 3, 4)
        stacked_shape = (n_batch*depth, n_chan, height, width)
    elif stack_dim == "H":
        moved = b_input.permute(0, 3, 1, 2, 4)
        stacked_shape = (n_batch*height, n_chan, depth, width)
    elif stack_dim == "W":
        moved = b_input.permute(0, 4, 1, 2, 3)
        stacked_shape = (n_batch*width, n_chan, depth, height)
    else:
        raise ValueError(f"stack_dim '{stack_dim}' must be 'D' or 'H' or 'W'.")

    return moved.reshape(stacked_shape)



def make_3d_from_2d_stack(b_input, stack_dim, orig_stack_size):
    """Unfold a 4D stack of 2D slices back into a 5D tensor.

    Args:
        b_input: 4D tensor; its first axis is split into
            (orig_stack_size, first_axis_length // orig_stack_size).
        stack_dim: "D", "H" or "W"; the position the unfolded axis takes among the
            spatial axes of the result.
        orig_stack_size: Length of the leading axis of the returned tensor.

    Returns:
        5D tensor with the unfolded axis placed at position 2 ("D"), 3 ("H") or 4 ("W").

    Raises:
        ValueError: If `stack_dim` is none of "D", "H", "W".
    """
    assert b_input.dim() == 4, f"Input must be 4D: (orig_batch_size/B)xCxSPAT1xSPAT0 but is {b_input.shape}"
    stacked_len, n_chan, spat1, spat0 = b_input.shape

    slices_per_stack = int(stacked_len//orig_stack_size)
    unstacked = b_input.reshape(orig_stack_size, slices_per_stack, n_chan, spat1, spat0)

    # Axis 1 of `unstacked` holds the slices; each order moves it to its spatial position.
    axis_orders = (
        ("D", (0, 2, 1, 3, 4)),
        ("H", (0, 2, 3, 1, 4)),
        ("W", (0, 2, 3, 4, 1)),
    )
    for dim_tag, axis_order in axis_orders:
        if stack_dim == dim_tag:
            return unstacked.permute(*axis_order)

    raise ValueError(f"stack_dim is '{stack_dim}' but must be 'D' or 'H' or 'W'.")

# %%
class DataParamMode(Enum):
    """Options for the `data_param_mode` configuration entry."""
    # Explicit values equal to what auto() assigns (counting up from 1).
    INSTANCE_PARAMS = 1
    DISABLED = 2

class DotDict(dict):
    """dot.notation access to dictionary attributes
        See https://stackoverflow.com/questions/49901590/python-using-copy-deepcopy-on-dotdict

    Reading, writing and deleting `d.key` is forwarded to `d['key']`.
    """

    def __getattr__(self, item):
        # Python only calls this after normal attribute lookup failed. A missing key has to
        # surface as AttributeError (not KeyError) so attribute probing such as hasattr() or
        # copy.deepcopy()'s lookup of __deepcopy__ keeps working.
        try:
            return self[item]
        except KeyError as e:
            # Name the missing attribute so failed lookups are easy to diagnose.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from e

    # Attribute assignment / deletion map directly onto item assignment / deletion.
    # Deleting a missing attribute therefore raises KeyError.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# Notebook-wide configuration; every entry is reachable as `config_dict.<name>`.
config_dict = DotDict(
    # Folds
    num_folds=3,
    only_first_fold=True,

    use_mind=False,
    epochs=40,

    # Batching
    batch_size=8,
    val_batch_size=1,
    use_2d_normal_to=None,
    train_patchwise=False,

    num_val_images=20,
    atlas_count=1,

    # Dataset selection and cropping
    dataset='crossmoda',
    reg_state=None,
    train_set_max_len=None,
    crop_3d_w_dim_range=(45, 95),
    crop_2d_slices_gt_num_threshold=0,

    # Optimization
    lr=0.01,
    use_scheduling=True,

    # Data parameter config
    data_param_mode=DataParamMode.INSTANCE_PARAMS,
    init_inst_param=0.0,
    lr_inst_param=0.1,
    use_risk_regularization=True,
    use_fixed_weighting=True,
    use_ool_dp_loss=False,

    fixed_weight_file=None,
    fixed_weight_min_quantile=None,  # e.g. .9
    fixed_weight_min_value=None,
    override_embedding_weights=False,

    # Saving
    save_every=200,
    mdl_save_prefix='data/models',

    debug=False,
    wandb_mode='disabled',  # e.g. online, disabled
    do_sweep=False,

    # Checkpoint selection
    checkpoint_name=None,
    fold_override=None,
    checkpoint_epx=None,

    do_plot=False,
    save_dp_figures=False,
    save_labels=False,

    # Label disturbance
    disturbance_mode=None,
    disturbance_strength=0.,
    disturbed_percentage=0.,
)

# Patchwise training is not implemented; refuse the setting right away.
if config_dict['train_patchwise']:
    raise NotImplementedError()

# %%
# Number of leading entries taken from the left / right optimal registration files.
_NUM_REG_LEFT = 44
_NUM_REG_RIGHT = 63

_CONVEX_ADAM_REGISTERED_FILE = "/share/data_supergrover1/weihsbach/shared_data/important_data_artifacts/curriculum_deeplab/20220318_crossmoda_convex_adam_lr/crossmoda_convex_registered_new_convex.pth"
_DEEDS_REGISTERED_FILE = "/share/data_supergrover1/weihsbach/shared_data/important_data_artifacts/curriculum_deeplab/20220114_crossmoda_multiple_registrations/crossmoda_deeds_registered.pth"

# reg_state -> (registration file, atlas_count, keep every n-th moving sample)
_MULTI_ATLAS_REG_STATES = {
    "acummulate_convex_adam_FT2_MT1": (_CONVEX_ADAM_REGISTERED_FILE, 10, 3),
    "acummulate_every_third_deeds_FT2_MT1": (_DEEDS_REGISTERED_FILE, 10, 3),
    "acummulate_every_deeds_FT2_MT1": (_DEEDS_REGISTERED_FILE, 30, 1),
}


def _load_optimal_reg_labels(label_keys):
    """Load both optimal registration files and join their left / right parts.

    Returns a tuple (identifiers, labels): `identifiers` is the left 'valid_left_t1' list
    followed by the right 'valid_right_t1' list; `labels` maps every key in `label_keys`
    to the concatenation of the first 44 left and the first 63 right entries.
    """
    # NOTE(review): torch.load unpickles its input; only point this at trusted files.
    label_data_left = torch.load('./data/optimal_reg_left.pth')
    label_data_right = torch.load('./data/optimal_reg_right.pth')
    identifiers = label_data_left['valid_left_t1'] + label_data_right['valid_right_t1']
    labels = {
        key: torch.cat(
            [label_data_left[key][:_NUM_REG_LEFT], label_data_right[key][:_NUM_REG_RIGHT]],
            dim=0
        )
        for key in label_keys
    }
    return identifiers, labels


def _load_multi_atlas_labels(registered_file, every_nth):
    """Collect the warped labels of every `every_nth` moving sample per fixed id."""
    # NOTE(review): torch.load unpickles its input; only point this at trusted files.
    bare_data = torch.load(registered_file)
    label_data = []
    loaded_identifier = []
    for fixed_id, moving_dict in bare_data.items():
        # OrderedDict keeps the stored order of the moving samples; nothing is sorted here.
        ordered_moving_dict = OrderedDict(moving_dict)
        for idx_mov, (moving_id, moving_sample) in enumerate(ordered_moving_dict.items()):
            if idx_mov % every_nth == 0:
                label_data.append(moving_sample['warped_label'].cpu())
                loaded_identifier.append(f"{fixed_id}:m{moving_id}")
    return label_data, loaded_identifier


def _load_registered_labels(reg_state):
    """Return (atlas_count, domain, label_data, loaded_identifier) for `reg_state`."""
    if reg_state == "mix_combined_best":
        identifiers, labels = _load_optimal_reg_labels(('best_all', 'combined_all'))

        # Random half of the samples takes the 'best' label, the other half the 'combined' one.
        perm = np.random.permutation(len(identifiers))
        half_len = int(.5*len(identifiers))
        best_choice = perm[:half_len]
        combined_choice = perm[half_len:]

        label_data = torch.zeros([_NUM_REG_LEFT + _NUM_REG_RIGHT, 128, 128, 128])
        label_data[best_choice] = labels['best_all'][best_choice]
        label_data[combined_choice] = labels['combined_all'][combined_choice]

        best_idxs = set(best_choice.tolist())
        identifiers = [
            f"{_id}:{'mBST' if idx in best_idxs else 'mCMB'}"
            for idx, _id in enumerate(identifiers)
        ]
        return 1, 'source', label_data, identifiers

    if reg_state == "acummulate_combined_best":
        identifiers, labels = _load_optimal_reg_labels(('best_all', 'combined_all'))
        label_data = torch.cat([labels['best_all'], labels['combined_all']])
        identifiers = [_id+':mBST' for _id in identifiers] + [_id+':mCMB' for _id in identifiers]
        return 2, 'source', label_data, identifiers

    if reg_state in ("best", "combined"):
        label_key = reg_state+'_all'
        identifiers, labels = _load_optimal_reg_labels((label_key,))
        postfix = 'mBST' if reg_state == "best" else 'mCMB'
        return 1, 'source', labels[label_key], [_id+':'+postfix for _id in identifiers]

    if reg_state in _MULTI_ATLAS_REG_STATES:
        registered_file, atlas_count, every_nth = _MULTI_ATLAS_REG_STATES[reg_state]
        label_data, identifiers = _load_multi_atlas_labels(registered_file, every_nth)
        return atlas_count, 'target', label_data, identifiers

    raise ValueError(f"Unknown reg_state '{reg_state}'.")


def prepare_data(config):
    """Build the training dataset described by `config`.

    Args:
        config: Attribute-style configuration. Reads `reg_state`, `dataset`,
            `use_2d_normal_to`, `crop_3d_w_dim_range`, `train_set_max_len`,
            `crop_2d_slices_gt_num_threshold`, `debug` and the `fixed_weight_*` entries.
            When `reg_state` is set, `config.atlas_count` is overwritten as a side effect.

    Returns:
        The training dataset object for `config.dataset`.

    Raises:
        ValueError: For an unknown `reg_state`, an unknown `dataset`, or a registered
            identifier that does not match the expected pattern.
        NotImplementedError: For `dataset == 'ixi'`.
    """
    reset_determinism()
    if config.reg_state:
        print("Loading registered data.")
        atlas_count, domain, label_data, loaded_identifier = _load_registered_labels(config.reg_state)
        config.atlas_count = atlas_count

        modified_3d_label_override = {}
        for idx, identifier in enumerate(loaded_identifier):
            # Find sth. like 100r:mBST or 100r:m001l
            id_match = re.search(r'(\d{1,3})([lr]):m([A-Z0-9a-z]{3,4})$', identifier)
            if id_match is None:
                raise ValueError(f"Cannot parse registered sample identifier '{identifier}'.")
            nl_id, lr_id, m_id = id_match.groups()
            # Normalize the numeric part to three digits.
            crossmoda_var_id = f"{int(nl_id):03d}{lr_id}:m{m_id}"
            modified_3d_label_override[crossmoda_var_id] = label_data[idx]

        prevent_disturbance = True

    else:
        domain = 'source'
        modified_3d_label_override = None
        prevent_disturbance = False

    if config.dataset == 'crossmoda':
        # Factor 2. when predicting on 2D slices, 1.5 when predicting on 3D volumes.
        pre_interpolation_factor = 2. if config.use_2d_normal_to is not None else 1.5
        clsre = get_crossmoda_data_load_closure(
            base_dir="/share/data_supergrover1/weihsbach/shared_data/tmp/CrossMoDa/",
            domain=domain, state='l4', use_additional_data=False,
            size=(128,128,128), resample=True, normalize=True, crop_3d_w_dim_range=config.crop_3d_w_dim_range,
            ensure_labeled_pairs=True, modified_3d_label_override=modified_3d_label_override,
            debug=config.debug
        )
        training_dataset = CrossmodaHybridIdLoader(
            clsre,
            size=(128,128,128), resample=True, normalize=True, crop_3d_w_dim_range=config.crop_3d_w_dim_range,
            ensure_labeled_pairs=True,
            max_load_3d_num=config.train_set_max_len,
            prevent_disturbance=prevent_disturbance,
            use_2d_normal_to=config.use_2d_normal_to,
            crop_2d_slices_gt_num_threshold=config.crop_2d_slices_gt_num_threshold,
            pre_interpolation_factor=pre_interpolation_factor,
            fixed_weight_file=config.fixed_weight_file, fixed_weight_min_quantile=config.fixed_weight_min_quantile, fixed_weight_min_value=config.fixed_weight_min_value,
        )

    elif config.dataset == 'ixi':
        # The former IXI loading code sat unreachable behind this raise and used names
        # that are not imported in this file, so it was removed.
        raise NotImplementedError("Dataset 'ixi' is not supported.")

    elif config.dataset == 'organmnist3d':
        # NOTE(review): WrapperOrganMNIST3D is not imported in the visible part of this
        # file; confirm it is available before using this branch.
        training_dataset = WrapperOrganMNIST3D(
            split='train', root='./data/medmnist', download=True, normalize=True,
            max_load_num=300, crop_3d_w_dim_range=None,
            disturbed_idxs=None, use_2d_normal_to='W'
        )
        print(training_dataset.mnist_set.info)
        print("Classes: ", training_dataset.label_tags)
        print("Samples: ", len(training_dataset))

    else:
        # Previously this fell through to an UnboundLocalError at the return statement.
        raise ValueError(f"Unknown dataset '{config.dataset}'.")

    return training_dataset


### Recommended gpus on this machine (descending order) ###
  ID  Card name      Util    Mem free  Cuda             User(s)
----  -----------  ------  ----------  ---------------  ---------------
   0  TITAN RTX       0 %    7725 MiB  11.2(460.73.01)  weihsbach, root

Will apply following mapping

  ID  Card name        torch
----  -----------  --  -------
   0  TITAN RTX    ->  cuda:0
1.9.0a0+gitdfbd030
8200
TITAN RTX
Running in: /share/data_supergrover1/weihsbach/shared_data/tmp/curriculum_deeplab


In [19]:
# Build the training dataset from the notebook-level configuration.
dataset = prepare_data(config_dict)

Loading CrossMoDa ceT1 images and labels...


209 images, 209 labels: 100%|██████████| 418/418 [00:34<00:00, 12.06it/s]


Postprocessing 3D volumes
Removed 0 3D images in postprocessing
Equal image and label numbers: True (107)
Image shape: torch.Size([107, 128, 128, 50]), mean.: -0.00, std.: 1.00
Label shape: torch.Size([107, 128, 128, 50]), max.: 1
Data import finished.
CrossMoDa loader will yield 3D samples


In [43]:
d_ids = dataset.get_3d_ids()
target_path = "/share/data_rechenknecht01_2/weihsbach/nnunet/nnUNet_raw_data_base/nnUNet_raw_data/Task562_CM_domain_adaptation_insane_moving_convex_adam/"

# Make sure both output folders exist before any file is written into them.
images_dir = Path(target_path, "imagesTr")
labels_dir = Path(target_path, "labelsTr")
images_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)

# Every volume is stored with an identity affine.
identity_affine = np.eye(4)

for moving_idx, _id in enumerate(moving_convex):
    # Entries look like 'm024r'; drop the leading 'm' (presumably this yields the id
    # format the dataset uses, see `d_ids`).
    _id = _id[1:]
    sample = dataset[dataset.switch_3d_identifiers([_id])]

    # Files are numbered by the position of the id in `moving_convex`.
    img_path = images_dir / f"CrossMoDa_{moving_idx:03d}_0000.nii.gz"
    label_path = labels_dir / f"CrossMoDa_{moving_idx:03d}.nii.gz"
    nib.save(nib.Nifti1Image(sample['image'].numpy(), affine=identity_affine), img_path)
    nib.save(nib.Nifti1Image(sample['label'].numpy(), affine=identity_affine), label_path)

In [44]:
d_ids = dataset.get_3d_ids()
target_path = "/share/data_rechenknecht01_2/weihsbach/nnunet/nnUNet_raw_data_base/nnUNet_raw_data/Task561_CM_domain_adaptation_insane_moving_deeds/"

# Make sure both output folders exist before any file is written into them.
images_dir = Path(target_path, "imagesTr")
labels_dir = Path(target_path, "labelsTr")
images_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)

# Every volume is stored with an identity affine.
identity_affine = np.eye(4)

for moving_idx, _id in enumerate(moving_deeds):
    # Entries look like 'm019r'; drop the leading 'm' (presumably this yields the id
    # format the dataset uses, see `d_ids`).
    _id = _id[1:]
    sample = dataset[dataset.switch_3d_identifiers([_id])]

    # Files are numbered by the position of the id in `moving_deeds`.
    img_path = images_dir / f"CrossMoDa_{moving_idx:03d}_0000.nii.gz"
    label_path = labels_dir / f"CrossMoDa_{moving_idx:03d}.nii.gz"
    nib.save(nib.Nifti1Image(sample['image'].numpy(), affine=identity_affine), img_path)
    nib.save(nib.Nifti1Image(sample['label'].numpy(), affine=identity_affine), label_path)