In [None]:
import os
import time
from meidic_vtach_utils.run_on_recommended_cuda import get_cuda_environ_vars as get_vars
os.environ.update(get_vars(select="* -3 -4"))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp

from IPython.display import display
import nibabel as nib
from torch.utils.checkpoint import checkpoint

from sklearn.model_selection import KFold

from mdl_seg_class.metrics import dice3d, dice2d
from mdl_seg_class.visualization import get_overlay_grid
from curriculum_deeplab.mindssc import mindssc

import curriculum_deeplab.ml_data_parameters_utils as ml_data_parameters_utils
import wandb

import glob
import re
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np

# Log the library/device stack this notebook runs on.
print(torch.__version__)  # PyTorch version
print(torch.backends.cudnn.version())  # cuDNN version as reported by PyTorch
print(torch.cuda.get_device_name(0))  # name of CUDA device 0 (first device visible to this process)

In [None]:
class ASPPConv(nn.Sequential):
    """Atrous (dilated) 3x3x3 convolution branch of the ASPP module.

    Layout: Conv3d (no bias, ``padding == dilation`` so the spatial size is kept)
    -> BatchNorm3d -> ReLU.
    """

    def __init__(self, in_channels, out_channels, dilation):
        dilated_conv = nn.Conv3d(
            in_channels, out_channels, 3,
            padding=dilation, dilation=dilation, bias=False,
        )
        super().__init__(dilated_conv, nn.BatchNorm3d(out_channels), nn.ReLU())


class ASPPPooling(nn.Sequential):
    """Image-level branch of the ASPP module.

    NOTE(review): the global average pooling of the classic ASPP design is disabled,
    so this branch is a point-wise Conv3d -> BatchNorm3d -> ReLU applied at full
    resolution. The final nearest-neighbour resize targets the input's own spatial
    size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        spatial_size = x.shape[-3:]
        # nn.Sequential.forward applies the child modules in order.
        features = super().forward(x)
        return F.interpolate(features, size=spatial_size, mode='nearest')


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling for 3D volumes.

    All branches run in parallel on the same input:
    - a 1x1x1 conv,
    - one dilated 3x3x3 ``ASPPConv`` per entry of ``atrous_rates``,
    - an ``ASPPPooling`` branch.

    Their outputs are concatenated along the channel axis and fused by a 1x1x1
    projection followed by BatchNorm3d, ReLU and Dropout(0.5).
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super().__init__()
        pointwise_branch = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )
        dilated_branches = [ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)]
        self.convs = nn.ModuleList(
            [pointwise_branch, *dilated_branches, ASPPPooling(in_channels, out_channels)]
        )

        fused_channels = len(self.convs) * out_channels
        self.project = nn.Sequential(
            nn.Conv3d(fused_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))


# Mobile-Net with depth-separable convolutions and residual connections
class ResBlock(torch.nn.Module):
    """Residual wrapper that returns ``module(inputs) + inputs``.

    The wrapped module must preserve the input shape for the addition to be valid.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        residual = self.module(inputs)
        return residual + inputs


def create_model(output_classes: int = 14, input_channels: int = 1):
    """Build the MobileNet-style backbone, the ASPP module and the segmentation head.

    Backbone layout:
    - index 0 is an ``nn.Identity``;
    - it is followed by 8 inverted-residual blocks (1x1x1 expand -> depth-wise
      3x3x3 -> 1x1x1 project);
    - the first block uses a dense, strided 3x3x3 conv as its expansion;
    - the fourth block strides its depth-wise conv.

    As a result, ``backbone[:3]`` is downsampled by 2 and the full backbone by 4 in
    every spatial dimension. Blocks that keep both channel count and resolution
    are wrapped in ``ResBlock``. Only the backbone gets the explicit weight
    initialization below; ASPP and head keep the PyTorch defaults.

    Args:
        output_classes: number of output channels (classes) of the head.
        input_channels: number of channels of the input volume (1 for raw
            intensities, 12 when MIND-SSC features are used in this notebook).

    Returns:
        ``(backbone, aspp, head)`` as separate modules; combine them with ``apply_model``.
    """
    # Per-block channel/stride configuration of the inverted-residual backbone.
    in_channels = [int(input_channels), 24, 24, 32, 48, 48, 48, 64]
    mid_channels = [64, 128, 192, 192, 256, 256, 256, 384]
    out_channels = [24, 24, 32, 48, 48, 48, 64, 64]
    mid_stride = [1, 1, 1, 2, 1, 1, 1, 1]  # stride of the depth-wise conv

    net = [nn.Identity()]
    for i, (inc, midc, outc, strd) in enumerate(zip(in_channels, mid_channels, out_channels, mid_stride)):
        if i == 0:
            # Stem: a dense strided 3x3x3 conv replaces the 1x1x1 expansion.
            expand = nn.Conv3d(inc, midc, 3, padding=1, stride=2, bias=False)
        else:
            expand = nn.Conv3d(inc, midc, 1, bias=False)
        layer = nn.Sequential(
            expand, nn.BatchNorm3d(midc), nn.ReLU6(True),
            nn.Conv3d(midc, midc, 3, stride=strd, padding=1, bias=False, groups=midc),  # depth-wise
            nn.BatchNorm3d(midc), nn.ReLU6(True),
            nn.Conv3d(midc, outc, 1, bias=False), nn.BatchNorm3d(outc),
        )
        # A residual connection is only valid when the block preserves the shape.
        # The stem always downsamples, even though its mid_stride entry is 1.
        downsamples = strd != 1 or i == 0
        net.append(ResBlock(layer) if inc == outc and not downsamples else layer)

    backbone = nn.Sequential(*net)

    count = 0
    # Weight initialization (backbone only).
    for m in backbone.modules():
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            count += 1
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.zeros_(m.bias)

    print('#CNN layer', count)
    # Complete model: MobileNet + ASPP + head with a single skip connection.
    # The head input is the upsampled ASPP output (128 ch) concatenated with the
    # 24-channel output of backbone[:3].
    aspp = ASPP(64, (2, 4, 8, 16, 32), 128)
    head = nn.Sequential(
        nn.Conv3d(128 + 24, 64, 1, padding=0, groups=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(),
        nn.Conv3d(64, 64, 3, groups=1, padding=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(),
        nn.Conv3d(64, output_classes, 1),
    )
    return backbone, aspp, head


def apply_model(backbone, aspp, head, img, checkpointing=True, return_intermediate=False):
    """Run backbone -> ASPP -> head with a single skip connection.

    The pipeline:
    - ``backbone[:3]`` produces the skip features;
    - ``backbone[3:]`` followed by ``aspp`` produces the context features;
    - the context features are upsampled by 2 (nearest, the ``F.interpolate``
      default) and concatenated with the skip features along dim 1;
    - the concatenation is fed to ``head``.

    Args:
        backbone: sliceable module (e.g. the nn.Sequential built by create_model).
        aspp: module applied to the deep backbone features.
        head: module mapping the concatenated features to logits.
        img: input batch.
        checkpointing: if True, every sub-network runs through
            ``torch.utils.checkpoint.checkpoint`` to trade compute for memory.
        return_intermediate: if True, also return the concatenated head input.

    Returns:
        ``logits``, or ``(head_input, logits)`` when ``return_intermediate`` is set.
    """
    def run(module, tensor):
        return checkpoint(module, tensor) if checkpointing else module(tensor)

    skip_features = run(backbone[:3], img)
    deep_features = run(backbone[3:], skip_features)
    context = run(aspp, deep_features)
    head_input = torch.cat((skip_features, F.interpolate(context, scale_factor=2)), 1)
    logits = run(head, head_input)

    if not return_intermediate:
        return logits
    return head_input, logits


In [None]:
from pathlib import Path

class CrossMoDa_Data(Dataset):
    def __init__(self,
        base_dir, domain, state,
        ensure_labeled_pairs=True, use_additional_data=False, resample = True,
        size:tuple = (96,96,60), normalize:bool = True):
        """
        Dataset holding CrossMoDa volumes and labels in memory.

        Different preprocessing stages of the CrossMoDa data set can be selected, and
        additional data from the TCIA database can be included. The volumes can be
        resampled to a desired size and each image can be normalized to mean=0, std=1.

        Parameters:
                base_dir (str or os.PathLike): directory which contains the "L1..." to "L4..." directories
                    (a trailing slash is optional)
                domain (str, case-insensitive): domain to load.
                    "source"/"ceT1" loads the labeled ceT1 training images,
                    "target"/"hrT2" the hrT2 training images,
                    "validation" the hrT2 validation images.
                state (str, case-insensitive): state of preprocessing:
                    "l1" = original data,
                    "l2" = resampled data @ 0.5mm,
                    "l3" = center-cropped data,
                    "l4" = image specific crops for desired anatomy
                ensure_labeled_pairs (bool): only load images that have a corresponding label file (default: True)
                use_additional_data (bool): set to True to also load additional data from TCIA
                    (ignored for "validation"; default: False)
                resample (bool): set to False to disable resampling to `size` (default: True)
                size (tuple): 3d-tuple(int) to which the data is resampled (default: (96,96,60)).
                    If resample=False it is replaced by the default size of the chosen state,
                    and volumes of a different size are symmetrically zero-padded (or cropped).
                    WARNING: choosing large sizes or not resampling can lead to excess memory usage
                normalize (bool): set to False to disable normalization to mean=0, std=1 for each image (default: True)

        Raises:
                Exception: for an unknown `state` or `domain`.

        Useful Links:
        CrossMoDa challenge:
        https://crossmoda.grand-challenge.org/

        ToDos:
            extend to other preprocessing states

        Example:
            dataset = CrossMoDa_Data(base_dir, domain="source", state="l4")
            imgs, labels = dataset.get_data()
        """

        # Finished preprocessing states: (sub-directory, default volume size).
        states = {
            'l1':('L1_original/', (512,512,160)),
            'l2':('L2_resampled_05mm/', (420,420,360)),
            'l3':('L3_coarse_fixed_crop/', (128,128,192)),
            'l4':('L4_fine_localized_crop/', (128,128,128))
        }
        t0 = time.time()

        # Validate the state after lower-casing, consistent with the lookup itself.
        state_key = state.lower()
        if state_key not in states:
            raise Exception("Unknown state. Choose one of: "+str(list(states.keys())))
        state_dir, default_size = states[state_key]
        if not resample:
            size = default_size
        size = tuple(size)
        path = os.path.join(base_dir, state_dir)

        # Map the (case-insensitive) domain to its directories. The comparison is done on
        # the lower-cased string, so the literals must be lower-case as well.
        domain_key = domain.lower()
        add_directory = None
        if domain_key in ("cet1", "source"):
            directory = "source_training_labeled/"
            add_directory = "__additional_data_source_domain__"
            domain = "ceT1"
        elif domain_key in ("hrt2", "target"):
            directory = "target_training_unlabeled/"
            add_directory = "__additional_data_target_domain__"
            domain = "hrT2"
        elif domain_key == "validation":
            directory = "target_validation_unlabeled/"
            domain = "validation"
        else:
            raise Exception("Unknown domain. Choose either 'source', 'target' or 'validation'")

        def _glob_nifti(sub_dir):
            # Sorted list of all .nii.gz files in `path/sub_dir`.
            return sorted(glob.glob(os.path.join(path, sub_dir, "*.nii.gz")))

        files = _glob_nifti(directory)
        if domain == "hrT2":
            files = files + _glob_nifti("__omitted_labels_target_training__")
        elif domain == "validation":
            files = files + _glob_nifti("__omitted_labels_target_validation__")
        if use_additional_data and add_directory is not None:  # add additional data to file list
            files = files + _glob_nifti(add_directory)
            files = [f for f in files if "additionalLabel" not in f]  # remove additional label files

        if ensure_labeled_pairs:
            def get_bare_basename(_path):
                return str(Path(_path.replace('.nii.gz', '')).stem)

            # Keep label files and every image whose "<name>_Label" counterpart exists
            # (set lookup keeps this linear in the number of files).
            bare_names = {get_bare_basename(f) for f in files}
            files = [
                f for f in files
                if '_Label' in get_bare_basename(f) or get_bare_basename(f) + '_Label' in bare_names
            ]

        def _fit_to_size(vol):
            # Symmetric zero padding of a 3D volume to `size`; negative differences crop.
            if vol.shape == size:
                return vol
            difs = [size[0]-vol.size(0), size[1]-vol.size(1), size[2]-vol.size(2)]
            pad = (difs[2]//2, difs[2]-difs[2]//2,
                   difs[1]//2, difs[1]-difs[1]//2,
                   difs[0]//2, difs[0]-difs[0]//2)
            return F.pad(vol, pad)

        # Collect volumes in lists and stack once instead of re-copying with torch.cat per file.
        imgs, labels = [], []
        self.img_nums = []
        self.label_nums = []
        print("Loading CrossMoDa {} images and labels...".format(domain))

        for f in tqdm(files):
            if "Label" in f:
                self.label_nums.append(int(re.findall(r'\d+', os.path.basename(f))[0]))
                tmp = torch.from_numpy(nib.load(f).get_fdata())
                if resample:  # nearest-neighbour keeps label values discrete
                    tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size, mode='nearest').squeeze()
                labels.append(_fit_to_size(tmp))
            elif domain in f:
                self.img_nums.append(int(re.findall(r'\d+', os.path.basename(f))[0]))
                tmp = torch.from_numpy(nib.load(f).get_fdata())
                if resample:
                    tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size, mode='trilinear', align_corners=False).squeeze()
                if normalize:  # zero mean and unit std per image
                    tmp = (tmp - tmp.mean()) / tmp.std()
                imgs.append(_fit_to_size(tmp))

        empty = torch.zeros(0, *size)
        self.imgs = torch.stack(imgs) if imgs else empty
        self.labels = (torch.stack(labels) if labels else empty).long()

        # Check for consistency (images and labels are paired by their sorted order).
        print("Equal image and label numbers: {}".format(self.img_nums==self.label_nums))
        print("Image shape: {}, mean.: {:.2f}, std.: {:.2f}".format(self.imgs.shape,self.imgs.mean(),self.imgs.std()))
        print("Label shape: {}, max.: {}".format(self.labels.shape,torch.max(self.labels)))
        print("Data import finished. Elapsed time: {:.1f} s".format(time.time()-t0 ))

    def __len__(self):
        """Number of loaded images."""
        return int(self.imgs.size(0))

    def __getitem__(self, idx):
        """Return ``(image, label, idx)``; `idx` lets callers map batch entries back to dataset indices."""
        image = self.imgs[idx]
        label = self.labels[idx]
        return image, label, idx

    def get_data(self):
        """Return the full ``(imgs, labels)`` tensors."""
        return self.imgs,self.labels

    def get_image_numbers(self):
        """Case numbers parsed from the image file names, in loading order."""
        return self.img_nums

    def get_label_numbers(self):
        """Case numbers parsed from the label file names, in loading order."""
        return self.label_nums

In [None]:
# Visualization helper: overlay a segmentation on the slices that contain foreground.
def display_nonempty_seg_slices(img_slices, seg_slices, alpha=.5):
    """Display an overlay grid of only those slices that contain a non-zero label.

    Args:
        img_slices: slice-first image tensor, indexed along dim 0 with the same
            indices as ``seg_slices``.
        seg_slices: slice-first integer label tensor; it is one-hot encoded with 3
            classes, so values must lie in {0, 1, 2}.
        alpha: overlay opacity forwarded to ``get_overlay_grid``.

    Nothing is displayed when no slice contains a non-zero label.
    """
    rgb_per_class = {
        0: None,          # background: no overlay color
        1: (255, 0, 0),   # one-hot id -> RGB color
        2: (0, 255, 0),
    }

    foreground_slice_idx = torch.nonzero(seg_slices > 0, as_tuple=True)[0]
    if foreground_slice_idx.nelement() == 0:
        return

    foreground_slice_idx = foreground_slice_idx.unique()
    overlay_image, _ = get_overlay_grid(
        img_slices[foreground_slice_idx],
        torch.nn.functional.one_hot(seg_slices[foreground_slice_idx], 3),
        rgb_per_class, n_per_row=10, alpha=alpha,
    )
    display(overlay_image)

In [None]:
# Root of the CrossMoDa data (contains the "L1_..." to "L4_..." directories; keep the
# trailing slash). Override it via the CROSSMODA_DIR environment variable instead of
# editing the notebook.
CROSSMODA_DIR = os.environ.get("CROSSMODA_DIR", "/share/data_supergrover1/weihsbach/shared_data/tmp/CrossMoDa/")

training_dataset = CrossMoDa_Data(CROSSMODA_DIR,
    domain="source", state="l4", ensure_labeled_pairs=True)
validation_dataset = CrossMoDa_Data(CROSSMODA_DIR,
    domain="validation", state="l4", ensure_labeled_pairs=True)
# Required by the cell that previews target_dataset[30]. It used to be commented out,
# which made that cell fail with a NameError on "Restart & Run All".
target_dataset = CrossMoDa_Data(CROSSMODA_DIR,
    domain="target", state="l4", ensure_labeled_pairs=True)

In [None]:
# Preview the first two training cases, with and without the ground-truth overlay.
train_subset = torch.utils.data.Subset(training_dataset, range(2))
for img, seg, sample_idx in train_subset:
    print(f"Sample {sample_idx}:")
    # Slice along the last volume axis: images become (n_slices, 1, a, b), labels (n_slices, a, b).
    img_slices, seg_slices = img.permute(2, 0, 1).unsqueeze(1), seg.permute(2, 0, 1)
    for caption, overlay_alpha in (("With ground-truth overlay", .3), ("W/o ground-truth overlay", .0)):
        print(caption)
        display_nonempty_seg_slices(img_slices, seg_slices, alpha=overlay_alpha)

In [None]:
# Overlay the ground truth on every labeled slice of each validation case.
for img, seg, _ in validation_dataset:
    img_slices, seg_slices = img.permute(2, 0, 1).unsqueeze(1), seg.permute(2, 0, 1)
    display_nonempty_seg_slices(img_slices, seg_slices, alpha=.3)

In [None]:
# Inspect one target-domain (hrT2) case with and without the label overlay.
img, seg, _ = target_dataset[30]
img_slices, seg_slices = img.permute(2, 0, 1).unsqueeze(1), seg.permute(2, 0, 1)
for overlay_alpha in (.3, .0):
    display_nonempty_seg_slices(img_slices, seg_slices, alpha=overlay_alpha)

In [None]:
def augmentAffine(img_in, seg_in, strength=0.05):
    """
    Random 3D affine augmentation of an image batch and its label batch.

    A per-sample 3x4 affine matrix ``I + strength * N(0, 1)`` is drawn on the CPU, moved
    to the image's device and turned into a sampling grid with ``F.affine_grid``. The
    matrix is in normalized coordinates, i.e. centered on the volume. Both tensors are
    resampled with ``F.grid_sample`` (``align_corners=False`` throughout):
    - the image trilinearly, with border padding;
    - the labels with nearest-neighbour sampling and zero padding, so voxels mapped
      from outside the volume get label 0.

    :param img_in: float image batch, B x C x D x H x W.
    :param seg_in: integer label batch, B x 1 x D x H x W or B x D x H x W.
    :param strength: std of the random perturbation of the identity transform (0 -> identity).
    :return: (augmented image batch, augmented int64 label batch with the same rank as seg_in)
    """
    B, _, D, H, W = img_in.size()
    affine_matrix = (torch.eye(3, 4).unsqueeze(0) + torch.randn(B, 3, 4) * strength).to(img_in.device)

    meshgrid = F.affine_grid(affine_matrix, torch.Size((B, 1, D, H, W)), align_corners=False)

    # Match affine_grid's align_corners explicitly (also silences torch's UserWarning
    # about the unspecified default).
    img_out = F.grid_sample(img_in, meshgrid, padding_mode='border', align_corners=False)

    # grid_sample needs a channel dim; accept labels with or without it.
    has_channel_dim = seg_in.dim() == 5
    seg_5d = seg_in if has_channel_dim else seg_in.unsqueeze(1)
    seg_out = F.grid_sample(seg_5d.float(), meshgrid, mode='nearest', align_corners=False).long()

    return img_out, (seg_out if has_channel_dim else seg_out.squeeze(1))



def augmentNoise(img_in, strength=0.05):
    """Additive Gaussian noise augmentation.

    Returns ``img_in + strength * N(0, 1)``; the noise has the shape, dtype and
    device of ``img_in``.
    """
    gaussian_noise = torch.randn_like(img_in)
    return img_in + strength * gaussian_noise

In [None]:
def save_model(backbone, aspp, head, inst_parameters, class_parameters,
    optimizer, optimizer_inst_param, optimizer_class_param,
    scaler, name):
    """Persist a training checkpoint as one ``.pth`` file per component.

    Files are written as ``<name>_<component>.pth`` (the same set load_model reads).
    Modules, optimizers and the grad scaler are stored via ``state_dict()``; the
    instance/class data parameters are saved as-is.

    Args:
        backbone, aspp, head: model parts.
        inst_parameters, class_parameters: data-parameter objects, saved directly.
        optimizer, optimizer_inst_param, optimizer_class_param: optimizers.
        scaler: AMP grad scaler.
        name: path prefix; its parent directory is created if it does not exist.
    """
    # Create the target directory up-front so that a missing folder does not crash
    # saving at the very end of a (long) training run.
    parent_dir = os.path.dirname(name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(backbone.state_dict(), name + '_backbone.pth')
    torch.save(aspp.state_dict(), name + '_aspp.pth')
    torch.save(head.state_dict(), name + '_head.pth')

    torch.save(inst_parameters, name + '_inst_parameters.pth')
    torch.save(class_parameters, name + '_class_parameters.pth')

    torch.save(optimizer.state_dict(), name + '_optimizer.pth')
    torch.save(optimizer_inst_param.state_dict(), name + '_optimizer_inst_param.pth')
    torch.save(optimizer_class_param.state_dict(), name + '_optimizer_class_param.pth')

    torch.save(scaler.state_dict(), name + '_grad_scaler.pth')

def load_model(name, config, dataset_len):
    """Rebuild model and optimizers, then restore them from files written by save_model.

    Args:
        name: path prefix used when saving (files ``<name>_<component>.pth``).
        config: config object providing ``use_mind``, ``num_classes``, ``lr`` and the
            data-parameter settings (``init_*``, ``learn_*``, ``lr_*``).
        dataset_len: number of instances used to size the freshly created instance
            data parameters.

    Returns:
        (backbone, aspp, head, inst_parameters, class_parameters,
         optimizer, optimizer_inst_param, optimizer_class_param, scaler)

    NOTE(review): the data-parameter optimizers are presumably bound to the freshly
    created (and discarded) data parameters, not to the loaded ``inst_parameters`` /
    ``class_parameters``. Confirm before using them to resume training.
    """
    def restore(component):
        return torch.load(name + '_' + component + '.pth')

    # MIND-SSC features have 12 channels, raw intensities 1.
    input_channels = 12 if config.use_mind else 1
    backbone, aspp, head = create_model(output_classes=config.num_classes, input_channels=input_channels)
    optimizer = torch.optim.Adam(
        [*backbone.parameters(), *aspp.parameters(), *head.parameters()],
        lr=config.lr,
    )

    # Only the optimizers are kept here; the data parameters themselves are restored below.
    _, _, optimizer_class_param, optimizer_inst_param = \
        ml_data_parameters_utils.get_class_inst_data_params_n_optimizer(
            config.init_class_param, config.learn_class_parameters, config.lr_class_param,
            config.init_inst_param, config.learn_inst_parameters, config.lr_inst_param,
            nr_classes=config.num_classes,
            nr_instances=dataset_len,
            device='cuda'
        )

    scaler = amp.GradScaler()

    backbone.load_state_dict(restore('backbone'))
    aspp.load_state_dict(restore('aspp'))
    head.load_state_dict(restore('head'))

    inst_parameters = restore('inst_parameters')
    class_parameters = restore('class_parameters')

    optimizer.load_state_dict(restore('optimizer'))
    optimizer_inst_param.load_state_dict(restore('optimizer_inst_param'))
    optimizer_class_param.load_state_dict(restore('optimizer_class_param'))

    scaler.load_state_dict(restore('grad_scaler'))

    return (backbone, aspp, head, inst_parameters, class_parameters,
        optimizer, optimizer_inst_param, optimizer_class_param,
        scaler)

In [103]:
def train_DL(run_name, config_dict, training_dataset):
    """K-fold training of the DeepLab-style segmentation model with optional data parameters.

    For each of ``config_dict['num_folds']`` KFold splits of ``training_dataset``:
    - a W&B run is started, grouped by fold;
    - a fresh model is trained for ``config.epochs`` iterations, each drawing ONE
      random training batch;
    - train/val Dice are logged every ``config.log_every`` iterations;
    - the final state is written with ``save_model`` under
      ``f"{config.mdl_save_prefix}_fold{fold_idx}"``.

    Args:
        run_name: W&B run name shared by all folds.
        config_dict: hyper-parameters (see the config cell for the expected keys).
        training_dataset: dataset yielding ``(image, label, dataset_index)`` and
            exposing ``get_data()``.

    Notes:
        - The first ``config.disturbed_flipped_num`` training indices of each fold
          get their label volume flipped along its three spatial axes whenever they
          appear in a batch (label-noise experiment).
        - Requires a CUDA device.
    """
    kf = KFold(n_splits=config_dict['num_folds'])

    for fold_idx, (train_idxs, val_idxs) in enumerate(kf.split(training_dataset)):
        run = wandb.init(project="curriculum_deeplab", name=run_name, group=f"fold{fold_idx}", job_type="train",
            config=config_dict, settings=wandb.Settings(start_method="thread"),
            mode=config_dict['wandb_mode']
        )
        config = wandb.config

        # Dataset indices (not batch positions) whose labels are corrupted in this fold.
        disturbed_idxs = train_idxs[:config.disturbed_flipped_num]

        # MIND-SSC features have 12 channels, raw intensities 1.
        C = 12 if config.use_mind else 1
        _, all_segs = training_dataset.get_data()

        # Inverse square-root class-frequency weights (normalized to mean 1), background fixed to 0.15.
        class_weight = torch.sqrt(1.0/(torch.bincount(all_segs.long().view(-1)).float()))
        class_weight = class_weight/class_weight.mean()
        class_weight[0] = 0.15
        class_weight = class_weight.cuda()
        print('inv sqrt class_weight', class_weight)

        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idxs)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_idxs)

        train_dataloader = DataLoader(training_dataset, batch_size=config.batch_size, shuffle=False, sampler=train_subsampler)
        val_dataloader = DataLoader(training_dataset, batch_size=config.batch_size, shuffle=False, sampler=val_subsampler)

        backbone, aspp, head = create_model(output_classes=config.num_classes, input_channels=C)
        optimizer = torch.optim.Adam(list(backbone.parameters())+list(aspp.parameters())+list(head.parameters()),
            lr=config.lr)

        # Initialize class and instance based temperature. The instance parameters are
        # sized for the whole dataset and indexed with dataset indices below.
        (class_parameters, inst_parameters, optimizer_class_param, optimizer_inst_param) = \
            ml_data_parameters_utils.get_class_inst_data_params_n_optimizer(
                config.init_class_param, config.learn_class_parameters, config.lr_class_param,
                config.init_inst_param, config.learn_inst_parameters, config.lr_inst_param,
                nr_classes=config.num_classes,
                nr_instances=len(train_dataloader.dataset),
                device='cuda'
            )

        criterion = nn.CrossEntropyLoss(class_weight)
        scaler = amp.GradScaler()

        top1 = ml_data_parameters_utils.AverageMeter('Acc@1', ':6.2f')
        top5 = ml_data_parameters_utils.AverageMeter('Acc@5', ':6.2f')

        backbone.cuda()
        aspp.cuda()
        head.cuda()
        t0 = time.time()

        for epx in range(config.epochs):
            backbone.train()
            aspp.train()
            head.train()

            optimizer.zero_grad()
            if config.learn_class_parameters:
                optimizer_class_param.zero_grad()
            if config.learn_inst_parameters:
                optimizer_inst_param.zero_grad()

            # Load data: one random batch of this fold's training indices.
            b_img, b_seg, idxs_dataset = next(iter(train_dataloader))

            if len(disturbed_idxs) > 0:
                # disturbed_idxs are dataset indices while b_seg is indexed by batch
                # position: select the batch entries whose dataset index is disturbed,
                # then 3D-flip their label targets.
                disturbed_mask = torch.from_numpy(np.isin(idxs_dataset.numpy(), disturbed_idxs))
                b_seg[disturbed_mask] = torch.flip(b_seg[disturbed_mask], dims=(-3, -2, -1))

            b_img = b_img.unsqueeze(1)
            b_seg = b_seg.unsqueeze(1)

            b_img, b_seg = b_img.float().cuda(), b_seg.cuda()
            b_img, b_seg = augmentAffine(b_img, b_seg, strength=0.1)

            b_img = augmentNoise(b_img, strength=0.02)

            if config.use_mind:
                b_img = mindssc(b_img)

            # The network predicts at half resolution, so the target is downsampled to match.
            b_interpolated_seg = F.interpolate(b_seg.float(), scale_factor=0.5, mode='nearest').long()
            b_interpolated_seg = b_interpolated_seg.squeeze(1)

            # The reentrant torch.utils.checkpoint only back-propagates into a segment
            # when one of its inputs requires grad.
            b_img.requires_grad = True

            with amp.autocast(enabled=True):
                logits = apply_model(backbone, aspp, head, b_img, checkpointing=True)

                if config.learn_class_parameters or config.learn_inst_parameters:
                    # Compute data parameters for instances in the minibatch
                    class_parameter_minibatch = torch.tensor([0])
                    # class_parameter_minibatch = class_parameters[b_seg] TODO: Readd that again
                    inst_parameter_minibatch = inst_parameters[idxs_dataset]
                    data_parameter_minibatch = ml_data_parameters_utils.get_data_param_for_minibatch(
                        learn_class_parameters=config.learn_class_parameters,
                        learn_inst_parameters=config.learn_inst_parameters,
                        class_param_minibatch=class_parameter_minibatch,
                        inst_param_minibatch=inst_parameter_minibatch)

                    # Compute logits scaled by data parameters (one temperature per sample).
                    logits = logits / data_parameter_minibatch.view([-1] + [1]*(logits.dim()-1))

                loss = criterion(logits, b_interpolated_seg)
                # Apply weight decay on data parameters
                if config.learn_class_parameters or config.learn_inst_parameters:
                    loss = ml_data_parameters_utils.apply_weight_decay_data_parameters(
                        config.learn_inst_parameters, config.wd_inst_param,
                        config.learn_class_parameters, config.wd_class_param,
                        loss,
                        class_parameter_minibatch=class_parameter_minibatch,
                        inst_parameter_minibatch=inst_parameter_minibatch)

            scaler.scale(loss).backward()
            scaler.step(optimizer)

            if config.learn_class_parameters:
                scaler.step(optimizer_class_param)
            if config.learn_inst_parameters:
                scaler.step(optimizer_inst_param)

            scaler.update()

            # Clamp class and instance level parameters within certain bounds
            if config.learn_class_parameters or config.learn_inst_parameters:
                ml_data_parameters_utils.clamp_data_parameters(
                    config.skip_clamp_data_param, config.learn_inst_parameters, config.learn_class_parameters,
                    class_parameters, inst_parameters,
                    config.clamp_inst_sigma_config, config.clamp_cls_sigma_config)

            # # Measure accuracy and record loss # TODO add again
            # acc1, acc5 = ml_data_parameters_utils.compute_topk_accuracy(logits, b_interpolated_seg, topk=(1, 1))
            # top1.update(acc1[0], b_img.size(0))
            # top5.update(acc5[0], b_img.size(0))

            if epx % config.log_every == 0:
                dice = dice3d(
                    torch.nn.functional.one_hot(logits.argmax(1), config.num_classes),
                    torch.nn.functional.one_hot(b_interpolated_seg, config.num_classes), one_hot_torch_style=True
                )
                # Log data parameters
                ml_data_parameters_utils.log_intermediate_iteration_stats(
                    epx,
                    config.learn_class_parameters, config.learn_inst_parameters,
                    class_parameters, inst_parameters, top1, top5)

                with amp.autocast(enabled=True):
                    backbone.eval()
                    aspp.eval()
                    head.eval()

                    with torch.no_grad():
                        b_val_img, b_val_seg, _ = next(iter(val_dataloader))
                        b_val_img, b_val_seg = (
                            b_val_img.unsqueeze(1).float().cuda(),
                            b_val_seg.unsqueeze(1).float().cuda()
                        )
                        if config.do_plot:
                            print("Show val img/lbl")
                            # Plot the first case of the batch, sliced along the last axis
                            # (works for any batch size).
                            val_img_slices = b_val_img[0].detach().permute(3, 0, 1, 2)
                            val_seg_slices = b_val_seg[0, 0].detach().permute(2, 0, 1).to(dtype=torch.int64)
                            display_nonempty_seg_slices(val_img_slices, val_seg_slices)

                        if config.use_mind:
                            b_val_img = mindssc(b_val_img)

                        b_interpolated_val_seg = F.interpolate(b_val_seg, scale_factor=0.5, mode='nearest').long()
                        b_interpolated_val_seg = b_interpolated_val_seg.squeeze(1)

                        output_val = apply_model(backbone, aspp, head, b_val_img, checkpointing=False)

                        val_dice = dice3d(
                            torch.nn.functional.one_hot(output_val.argmax(1), config.num_classes),
                            torch.nn.functional.one_hot(b_interpolated_val_seg, config.num_classes), one_hot_torch_style=True
                        )

                        if config.do_plot:
                            print("Show val lbl/prediction")
                            pred_seg_slices = output_val.argmax(1)[0].permute(2, 0, 1)
                            # Upsample the half-resolution prediction back to label resolution.
                            pred_seg_slices = F.interpolate(
                                pred_seg_slices[None, None].float(), scale_factor=2, mode='nearest'
                            )[0, 0].long()
                            display_nonempty_seg_slices(val_seg_slices.unsqueeze(1), pred_seg_slices)

                # Mean Dice over the batch, then over all non-background classes.
                dice_mean_no_bg = round(dice.mean(dim=0)[1:].mean().item(),4)
                val_dice_mean_no_bg = round(val_dice.mean(dim=0)[1:].mean().item(),4)

                print(
                    f'fold{fold_idx}_epx', epx,round(time.time()-t0,2),'s',
                    f'fold{fold_idx}_loss', round(loss.item(),6),
                    f'fold{fold_idx}_dice_tensor', dice,
                    f'fold{fold_idx}_dice mean (nobg)', dice_mean_no_bg,
                    f'fold{fold_idx}_val_dice_mean (nobg)', val_dice_mean_no_bg
                )

                wandb.log({
                    'losses/loss': loss.item(),
                    'scores/dice_mean_wo_bg': dice_mean_no_bg,
                    'scores/val_dice_mean_wo_bg': val_dice_mean_no_bg,
                }, step=epx)

            if config.debug:
                break

        # TODO log instance parameters and disturbed instance parameters here
        backbone.cpu()
        aspp.cpu()
        head.cpu()

        save_model(backbone, aspp, head, inst_parameters, class_parameters,
            optimizer, optimizer_inst_param, optimizer_class_param,
            scaler, f"{config.mdl_save_prefix}_fold{fold_idx}")

        # Close this fold's W&B run so the next fold starts a fresh one.
        run.finish()
        

In [106]:
config_dict = dict(
    num_folds=3,
    num_classes=3,
    use_mind=True,   # feed 12-channel MIND-SSC features instead of raw intensities
    epochs=2000,     # number of training iterations; each draws one random batch
    batch_size=4,
    lr=0.001,
    # Data parameter config
    init_class_param=1.0,
    learn_class_parameters=False,
    lr_class_param=0.1,
    init_inst_param=1.0,
    learn_inst_parameters=False,
    lr_inst_param=0.1,
    wd_inst_param=0.0,
    wd_class_param=0.0,

    skip_clamp_data_param=False,
    # Clamp bounds for the data parameters, given in log space: [log(1/20), log(20)].
    clamp_inst_sigma_config=dict(min=np.log(1/20), max=np.log(20)),
    clamp_cls_sigma_config=dict(min=np.log(1/20), max=np.log(20)),

    log_every=50,
    mdl_save_prefix='data/models',  # checkpoints are written as <prefix>_fold<k>_<component>.pth

    do_plot=False,
    debug=False,
    wandb_mode="online",

    disturbed_flipped_num=0,  # training cases per fold whose labels get flipped
)

In [107]:
# All folds share this W&B run name; train_DL separates them by group ("fold<k>").
run_name = wandb.util.generate_id()
train_DL(run_name=run_name, config_dict=config_dict, training_dataset=training_dataset)
wandb.finish()  # close the last W&B run if it is still open

0,1
losses/loss,▁
scores/dice_mean_wo_bg,▁
scores/val_dice_mean_wo_bg,▁
train_iteration_stats/accuracy_top1,▁
train_iteration_stats/accuracy_top5,▁

0,1
losses/loss,1.02801
scores/dice_mean_wo_bg,0.0114
scores/val_dice_mean_wo_bg,0.0003
train_iteration_stats/accuracy_top1,0.0
train_iteration_stats/accuracy_top5,0.0


inv sqrt class_weight tensor([0.1500, 0.6111, 2.3484], device='cuda:0')
#CNN layer 24


  if disturbed_idxs:


fold0_epx 0 0.88 s fold0_loss 1.539262 fold0_dice_tensor tensor([[0.0931, 0.0366, 0.0004],
        [0.1142, 0.0000, 0.0003],
        [0.1047, 0.0000, 0.0005],
        [0.0989, 0.0033, 0.0005]]) fold0_dice mean (nobg) 0.0052 fold0_val_dice_mean (nobg) 0.0002
fold0_epx 50 33.07 s fold0_loss 0.221999 fold0_dice_tensor tensor([[0.9985, 0.0351, 0.0000],
        [0.9969, 0.0096, 0.0000],
        [0.9996, 0.0000, 0.0000],
        [0.9998, 0.0000, 0.0000]]) fold0_dice mean (nobg) 0.0056 fold0_val_dice_mean (nobg) 0.0206
fold0_epx 100 66.35 s fold0_loss 0.097489 fold0_dice_tensor tensor([[0.9998, 0.0000, 0.0000],
        [0.9981, 0.0000, 0.0000],
        [0.9970, 0.6499, 0.0000],
        [0.9973, 0.7398, 0.0000]]) fold0_dice mean (nobg) 0.1737 fold0_val_dice_mean (nobg) 0.1088
fold0_epx 150 99.09 s fold0_loss 0.042586 fold0_dice_tensor tensor([[0.9998, 0.0000, 0.5385],
        [0.9997, 0.0000, 0.5455],
        [0.9941, 0.7193, 0.4490],
        [0.9991, 0.0000, 0.6364]]) fold0_dice mean (nobg) 0

In [None]:
def inference_DL(run_name, config_dict, inf_dataset, num_inst_params=None):
    """Evaluate every saved fold model on ``inf_dataset`` and log Dice scores to W&B.

    For each fold ``k`` in ``range(config.num_folds)`` the checkpoint
    ``f"{config.mdl_save_prefix}_fold{k}"`` is loaded with ``load_model``, switched to
    eval mode and applied to every sample of ``inf_dataset``. Predictions are scored
    against the label downsampled by a factor of 0.5 (nearest), i.e. at the model's
    output resolution.

    Args:
        run_name: W&B run name used for this test run.
        config_dict: configuration dict; passed to ``wandb.init`` and read back via
            ``wandb.config`` (``debug``, ``num_folds``, ``mdl_save_prefix``,
            ``use_mind``, ``do_plot``) plus ``config_dict['wandb_mode']``.
        inf_dataset: iterable yielding ``(img, seg, sample_idx)`` per sample.
        num_inst_params: size argument forwarded to ``load_model``. Defaults to
            ``len(inf_dataset)``. NOTE(review): this previously read the global
            ``validation_dataset``; confirm which size ``load_model`` expects.

    Returns:
        List of dicts with keys ``fold_idx``, ``sample_idx``, ``class_idx`` and
        ``dice`` -- one entry per fold, sample and class (class 0 included).
    """
    wandb.init(project="curriculum_deeplab", name=run_name, group="testing", job_type="test",
            config=config_dict, settings=wandb.Settings(start_method="thread"),
            mode=config_dict['wandb_mode']
    )
    config = wandb.config

    debug = config.debug
    num_folds = config.num_folds
    if num_inst_params is None:
        num_inst_params = len(inf_dataset)

    score_dicts = []

    for fold_idx in range(num_folds):
        backbone, aspp, head, *_ = load_model(f"{config.mdl_save_prefix}_fold{fold_idx}", config, num_inst_params)

        backbone.eval()
        aspp.eval()
        head.eval()

        for img, seg, sample_idx in inf_dataset:

            # Prepend batch and channel dims.
            img, seg = (
                img.unsqueeze(0).unsqueeze(0).float(),
                seg.unsqueeze(0).unsqueeze(0)
            )

            # Plot the raw image before MIND features replace its single channel;
            # afterwards the slicing below would no longer yield (D, 1, H, W) slices.
            if config.do_plot:
                img_slices = img[0:1].permute(0,1,4,2,3).squeeze().unsqueeze(1)
                seg_slices = seg[0:1].permute(0,1,4,2,3).squeeze().to(dtype=torch.int64)
                display_nonempty_seg_slices(img_slices, seg_slices)

            if config.use_mind:
                img = mindssc(img)

            with amp.autocast(enabled=True):
                with torch.no_grad():
                    # Score at half resolution: the prediction is compared against the
                    # label downsampled by 0.5 (and upsampled x2 for plotting below).
                    interpolated_seg = F.interpolate(seg.float(), scale_factor=0.5, mode='nearest').long()
                    interpolated_seg = interpolated_seg.squeeze(1)

                    output_val = apply_model(backbone, aspp, head, img, checkpointing=False)

                    inf_dice = dice3d(
                        F.one_hot(output_val.argmax(1), 3),
                        F.one_hot(interpolated_seg, 3), one_hot_torch_style=True
                    )
                if config.do_plot:
                    lbl_slices = seg.detach().squeeze(0).permute(3,0,1,2)
                    pred_seg_slices = output_val.argmax(1).squeeze(0).permute(2,0,1)
                    # Bring the half-resolution prediction back to label resolution for display.
                    pred_seg_slices = F.interpolate(
                        pred_seg_slices.unsqueeze(0).unsqueeze(0).float(), scale_factor=2, mode='nearest'
                    ).squeeze(0).squeeze(0).long()
                    display_nonempty_seg_slices(lbl_slices, pred_seg_slices)

                # Row 0 holds the per-class scores of the single sample in this batch.
                for class_idx, class_dice in enumerate(inf_dice.tolist()[0]):
                    score_dicts.append(
                        {
                            'fold_idx': fold_idx,
                            'sample_idx': sample_idx,
                            'class_idx': class_idx,
                            'dice': class_dice,
                        }
                    )
                # Mean over all classes (w/o background)
                dice_mean_no_bg = inf_dice.mean(dim=0)[1:].mean()
                wandb.log({f'scores/dice_fold_{fold_idx}': dice_mean_no_bg, 'sample_idx': sample_idx})
                print(f"Dice of validation sample {sample_idx} @(fold={fold_idx}): {dice_mean_no_bg.item():.2f}")

            if debug:
                # Debug mode: evaluate only the first sample of each fold.
                break

    foreground_dices = [score['dice'] for score in score_dicts if score['class_idx'] != 0]
    mean_inf_dice = torch.tensor(foreground_dices).mean()
    print(f"Mean dice over all folds, classes and samples: {mean_inf_dice.item()*100:.2f}%")
    wandb.log({'scores/mean_dice_all_folds_samples_classes': mean_inf_dice})
    wandb.finish()

    return score_dicts

In [None]:
# Evaluate every fold's checkpoint on the validation split under the training run's name;
# the result is a flat list of per-(fold, sample, class) Dice records.
score_dicts = inference_DL(
    run_name=run_name,
    config_dict=config_dict,
    inf_dataset=validation_dataset,
)

In [None]:
# Inspect the learned per-instance data parameters of one fold's checkpoint.
# `config` and `fold_idx` are locals of inference_DL (not notebook globals), and
# inference_DL already called wandb.finish(), so both a config and an active run are
# set up explicitly here.
INSPECT_FOLD_IDX = 0  # which fold's checkpoint to inspect

wandb.init(project="curriculum_deeplab", name=run_name, group="testing", job_type="inspect",
           config=config_dict, settings=wandb.Settings(start_method="thread"),
           mode=config_dict['wandb_mode'])
config = wandb.config
try:
    backbone, aspp, head, inst_parameters, *_ = load_model(
        f"{config.mdl_save_prefix}_fold{INSPECT_FOLD_IDX}", config, len(validation_dataset)
    )

    # wandb.Table expects a list of rows; give every instance parameter its own row.
    # NOTE(review): assumes inst_parameters is a tensor with one value per instance -- confirm.
    inst_rows = inst_parameters.detach().reshape(-1, 1).tolist()
    table = wandb.Table(data=inst_rows, columns=["instance_parameters"])
    wandb.log(
        {'data_parameters/instance_parameters':
            wandb.plot.histogram(table, "instance_parameters", title="Instance parameters")
        }
    )
finally:
    wandb.finish()