In [None]:
import os
import time
from meidic_vtach_utils.run_on_recommended_cuda import get_cuda_environ_vars as get_vars
os.environ.update(get_vars(select="* -3 -4"))
from curriculum_deeplab.crossmoda_dataloader import CrossMoDa_Data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
from torchvision import transforms
import torchvision.utils as vision_utils
from IPython.display import display
import nibabel as nib
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import pandas as pd  
import numpy as np
import scipy.ndimage
from sklearn.model_selection import KFold
import cc3d
from mdl_seg_class.metrics import dice3d

from PIL import Image as pil_image

import glob
import re
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

# Log the runtime this notebook ran with: PyTorch version, cuDNN version and the
# name of CUDA device 0 (the first device left visible by the CUDA environment
# variables set at the top of this cell).
print(torch.__version__)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

In [None]:
class ASPPConv(nn.Sequential):
    """Atrous ASPP branch: dilated 3x3x3 conv -> BatchNorm -> ReLU.

    ``padding`` equals ``dilation`` so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous_conv = nn.Conv3d(in_channels, out_channels, 3,
                                padding=dilation, dilation=dilation, bias=False)
        norm = nn.BatchNorm3d(out_channels)
        super(ASPPConv, self).__init__(atrous_conv, norm, nn.ReLU())


class ASPPPooling(nn.Sequential):
    """"Image-level" ASPP branch: 1x1x1 conv -> BatchNorm -> ReLU, resized to the input size.

    NOTE: the global average pooling of the classic ASPP design is commented out,
    so this branch is a plain pointwise projection and the final nearest resize
    targets the size the features already have.
    """

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            # nn.AdaptiveAvgPool2d(1),
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU())

    def forward(self, x):
        spatial_size = x.shape[-3:]
        # x = F.adaptive_avg_pool3d(x, (1))
        projected = super().forward(x)  # runs the child modules in order
        return F.interpolate(projected, size=spatial_size, mode='nearest')


class ASPP(nn.Module):
    """3D Atrous Spatial Pyramid Pooling.

    Applies parallel branches to the input: a 1x1x1 conv, one dilated 3x3x3 conv
    per entry in ``atrous_rates`` and an ``ASPPPooling`` branch. The branch
    outputs are concatenated on the channel axis and projected back to
    ``out_channels`` (1x1x1 conv + BatchNorm + ReLU + Dropout(0.5)).
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        pointwise_branch = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU())
        atrous_branches = [ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)]
        pooling_branch = ASPPPooling(in_channels, out_channels)
        self.convs = nn.ModuleList([pointwise_branch, *atrous_branches, pooling_branch])

        n_branches = len(self.convs)
        self.project = nn.Sequential(
            nn.Conv3d(n_branches * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))


# Mobile-Net with depth-separable convolutions and residual connections
class ResBlock(torch.nn.Module):
    """Residual wrapper computing ``module(inputs) + inputs``.

    The wrapped module must return a tensor broadcast-compatible with its input.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        residual = self.module(inputs)
        return residual + inputs


def create_model(output_classes: int = 14,input_channels: int = 1):
    """Build the 3D MobileNet-style backbone, the ASPP module and the decoder head.

    Parameters:
        output_classes: number of logit channels produced by the head (default 14).
        input_channels: number of channels of the input volume (default 1).

    Returns:
        ``(backbone, aspp, head)``, chained by ``apply_model``. ``backbone[0]`` is an
        identity followed by 8 inverted-residual blocks (1x1x1 expand -> depthwise
        3x3x3 -> 1x1x1 project). Blocks with equal in/out channels and stride 1
        are wrapped in ``ResBlock``.
    """
    # Per block: input channels, expanded channels, output channels, depthwise stride.
    # (an older variant used input channels [1, 16, 24, 24, 32, 32, 32, 64])
    widths_in = [int(input_channels), 24, 24, 32, 48, 48, 48, 64]
    widths_mid = [64, 128, 192, 192, 256, 256, 256, 384]
    widths_out = [24, 24, 32, 48, 48, 48, 64, 64]
    strides = [1, 1, 1, 2, 1, 1, 1, 1]

    stages = [nn.Identity()]
    for block_idx, (inc, midc, outc, strd) in enumerate(zip(widths_in, widths_mid, widths_out, strides)):
        block = nn.Sequential(
            nn.Conv3d(inc, midc, 1, bias=False), nn.BatchNorm3d(midc), nn.ReLU6(True),
            nn.Conv3d(midc, midc, 3, stride=strd, padding=1, bias=False, groups=midc),
            nn.BatchNorm3d(midc), nn.ReLU6(True),
            nn.Conv3d(midc, outc, 1, bias=False), nn.BatchNorm3d(outc))
        if block_idx == 0:
            # stem: the first expansion is a strided 3x3x3 conv instead of a 1x1x1 conv
            block[0] = nn.Conv3d(inc, midc, 3, padding=1, stride=2, bias=False)
        stages.append(ResBlock(block) if inc == outc and strd == 1 else block)

    backbone = nn.Sequential(*stages)

    # weight initialization: Kaiming-normal (fan_out) convs, unit/zero norm layers
    n_convs = 0
    for module in backbone.modules():
        if isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out')
            n_convs += 1
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0, 0.01)
            nn.init.zeros_(module.bias)
    print('#CNN layer', n_convs)

    # complete model: MobileNet + ASPP + head; the head's extra 24 input channels
    # are the skip features from the 24-channel stage (see apply_model)
    aspp = ASPP(64, (2, 4, 8, 16, 32), 128)
    head = nn.Sequential(
        nn.Conv3d(128 + 24, 64, 1, padding=0, groups=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(),
        nn.Conv3d(64, 64, 3, groups=1, padding=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(),
        nn.Conv3d(64, output_classes, 1))
    return backbone, aspp, head


def apply_model(backbone, aspp, head, img, checkpointing=True, return_intermediate=False):
    """Run backbone -> ASPP -> head as built by ``create_model``.

    ``backbone[:3]`` (identity + first two blocks) yields skip features ``x1``. The
    remaining backbone blocks plus ASPP yield coarser features ``y``. These are
    upsampled to ``x1``'s spatial size, concatenated with ``x1`` on the channel
    axis and decoded by ``head``.

    Parameters:
        backbone, aspp, head: modules as returned by ``create_model``.
        img: input batch, assumed (B, C, D, H, W).
        checkpointing: if True, wrap each stage in ``torch.utils.checkpoint`` to
            trade compute for memory. NOTE(review): the reentrant checkpoint only
            back-propagates into parameters when an input requires grad, which is
            presumably why the training loop sets ``img.requires_grad = True``.
        return_intermediate: if True, also return the fused decoder input ``y1``.

    Returns:
        ``output_j`` (logits for the whole batch), or ``(y1, output_j)``.
    """
    def _fuse(x1, y):
        # Resize to x1's exact spatial size instead of a fixed factor 2 (which breaks
        # for odd sizes). Keep the whole batch: the previous ``[0:1]`` indexing
        # silently decoded only the first sample.
        y_up = F.interpolate(y, size=x1.shape[2:])
        return torch.cat((x1, y_up), 1)

    if checkpointing:
        x1 = checkpoint(backbone[:3], img)
        x2 = checkpoint(backbone[3:], x1)
        y = checkpoint(aspp, x2)
        y1 = _fuse(x1, y)
        output_j = checkpoint(head, y1)
    else:
        x1 = backbone[:3](img)
        x2 = backbone[3:](x1)
        y = aspp(x2)
        y1 = _fuse(x1, y)
        output_j = head(y1)
    if return_intermediate:
        return y1, output_j
    return output_j


In [None]:
class CrossMoDa_Data(Dataset):
    """In-memory ``torch.utils.data.Dataset`` holding CrossMoDa images and labels.

    NOTE(review): this definition shadows the ``CrossMoDa_Data`` imported from
    ``curriculum_deeplab.crossmoda_dataloader`` at the top of the notebook.
    """

    # finished preprocessing states: key -> (sub directory, default size used when resample=False)
    STATES = {
        'l1':('L1_original/', (512,512,160)),
        'l2':('L2_resampled_05mm/', (420,420,360)),
        'l3':('L3_coarse_fixed_crop/', (128,128,192)),
        'l4':('L4_fine_localized_crop/', (128,128,128))
    }

    def __init__(self,
        base_dir, domain, state,
        use_additional_data=False, resample = True,
        size:tuple = (96,96,60), normalize:bool = True):
        """
        Load the CrossMoDa data of one domain and preprocessing state into memory.

        Images are optionally resampled to ``size`` (trilinear; labels nearest),
        optionally normalized per image to mean=0 / std=1, and finally zero-padded
        (or cropped) symmetrically to ``size``.

        Parameters:
                base_dir (str or os.PathLike): directory that contains the "L1_original/" ... "L4_fine_localized_crop/" sub directories
                domain (str, case-insensitive): "source"/"ceT1", "target"/"hrT2" or "validation".
                    Source images are ceT1, target and validation images are hrT2.
                state (str, case-insensitive): state of preprocessing:
                                                        "l1" = original data,
                                                        "l2" = resampled data @ 0.5mm,
                                                        "l3" = center-cropped data,
                                                        "l4" = image specific crops for desired anatomy
                use_additional_data (bool): also load the additional TCIA data; ignored for "validation" (default: False)
                resample (bool): set to False to disable resampling; ``size`` is then replaced by the
                    state's default size (default: True)
                size (tuple): 3d-tuple(int) the data is resampled and padded to (default: (96,96,60)).
                    WARNING: choosing large sizes or not resampling can lead to excess memory usage
                normalize (bool): set to False to disable normalization to mean=0, std=1 for each image (default: True)

        Raises:
                Exception: for an unknown ``state`` or ``domain``.

        Useful Links:
        CrossMoDa challenge:
        https://crossmoda.grand-challenge.org/

        ToDos:
            extend to other preprocessing states

        Example:
            dataset = CrossMoDa_Data(base_dir, domain="source", state="l3")
            imgs, labels = dataset.get_data()
        """
        t0 = time.time()
        # Lower-case before validating so e.g. "L3" is accepted like the lookup below.
        state_key = str(state).lower()
        if state_key not in self.STATES:
            raise Exception("Unknown state. Choose one of: "+str(list(self.STATES.keys())))
        state_dir, default_size = self.STATES[state_key]
        if not resample:
            size = default_size
        size = tuple(size)
        path = os.path.join(base_dir, state_dir)

        display_name, modality, directory, add_directory = self._resolve_domain(domain)
        files = self._collect_files(path, directory, display_name, add_directory, use_additional_data)

        imgs, labels = [], []
        self.img_nums = []
        self.label_nums = []
        print("Loading CrossMoDa {} images and labels...".format(display_name))

        for f in tqdm(files):
            # Match on the file name only, so directory names cannot affect classification.
            fname = os.path.basename(f)
            if "Label" in fname:
                self.label_nums.append(int(re.findall(r'\d+', fname)[0]))
                tmp = torch.from_numpy(nib.load(f).get_fdata())
                if resample: #resample label to specified size
                    tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size,mode='nearest').squeeze()
                labels.append(self._pad_to_size(tmp, size))
            elif modality in fname:
                self.img_nums.append(int(re.findall(r'\d+', fname)[0]))
                tmp = torch.from_numpy(nib.load(f).get_fdata())
                if resample: #resample image to specified size
                    tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size,mode='trilinear',align_corners=False).squeeze()
                if normalize: #normalize image to zero mean and unit std
                    tmp = (tmp - tmp.mean()) / tmp.std()
                imgs.append(self._pad_to_size(tmp, size))

        # Stack once at the end instead of torch.cat per file (quadratic copying).
        self.imgs = torch.stack(imgs) if imgs else torch.zeros(0, *size)
        self.labels = (torch.stack(labels) if labels else torch.zeros(0, *size)).long()
        #check for consistency
        print("Equal image and label numbers: {}".format(self.img_nums==self.label_nums))
        print("Image shape: {}, mean.: {:.2f}, std.: {:.2f}".format(self.imgs.shape,self.imgs.mean(),self.imgs.std()))
        print("Label shape: {}, max.: {}".format(self.labels.shape,torch.max(self.labels)))
        print("Data import finished. Elapsed time: {:.1f} s".format(time.time()-t0 ))

    @staticmethod
    def _resolve_domain(domain):
        """Map a domain name to (display name, image file tag, training dir, additional-data dir or None).

        Fixes the previous comparison of ``domain.lower()`` against the mixed-case
        strings "ceT1"/"hrT2", which never matched.
        """
        key = domain.lower()
        if key in ("cet1", "source"):
            return "ceT1", "ceT1", "source_training_labeled/", "__additional_data_source_domain__"
        if key in ("hrt2", "target"):
            return "hrT2", "hrT2", "target_training_unlabeled/", "__additional_data_target_domain__"
        if key == "validation":
            # validation images are hrT2 scans; there is no additional data for this domain
            return "validation", "hrT2", "target_validation_unlabeled/", None
        raise Exception("Unknown domain. Choose either 'source', 'target' or 'validation'")

    @staticmethod
    def _collect_files(path, directory, display_name, add_directory, use_additional_data):
        """Sorted ``*.nii.gz`` file list for one domain below a preprocessing-state directory."""
        def _glob(sub_dir):
            return sorted(glob.glob(os.path.join(path, sub_dir, "*.nii.gz")))

        files = _glob(directory)
        if display_name == "hrT2":
            files += _glob("__omitted_labels_target_training__")
        elif display_name == "validation":
            files += _glob("__omitted_labels_target_validation__")
        if use_additional_data and add_directory is not None:
            files += _glob(add_directory)
            files = [f for f in files if "additionalLabel" not in f] #remove additional label files
        return files

    @staticmethod
    def _pad_to_size(volume, size):
        """Symmetrically zero-pad a 3D tensor to ``size``; negative differences crop (``F.pad`` semantics)."""
        if volume.shape == size:
            return volume
        difs = [size[0]-volume.size(0), size[1]-volume.size(1), size[2]-volume.size(2)]
        # F.pad expects (last dim before/after, middle dim before/after, first dim before/after)
        pad = (difs[2]//2, difs[2]-difs[2]//2, difs[1]//2, difs[1]-difs[1]//2, difs[0]//2, difs[0]-difs[0]//2)
        return F.pad(volume, pad)

    def __len__(self):
        return int(self.imgs.size(0))

    def __getitem__(self, idx):
        image = self.imgs[idx]
        label = self.labels[idx]
        return image, label.long()

    def get_data(self):
        return self.imgs,self.labels

    def get_image_numbers(self):
        return self.img_nums

    def get_label_numbers(self):
        return self.label_nums

In [None]:
# Root of the CrossMoDa data tree (holds the L1_... to L4_... preprocessing folders).
CROSSMODA_BASE_DIR = "/share/data_supergrover1/weihsbach/shared_data/tmp/CrossMoDa/"
crossmoda_dataset = CrossMoDa_Data(CROSSMODA_BASE_DIR, domain="source", state="l3")  # , size=(128,128,192))

In [None]:
def pil_images_from_gray_tensor(s_tensor, scale_min_max=True):
    """Turn a stack of 2D slices into a list of RGBA PIL images.

    Parameters:
        s_tensor: 4D tensor (S, C, H, W); C is expected to be 1 (grayscale) — TODO confirm for C > 1.
        scale_min_max: if True, rescale the whole stack linearly to [0, 255] using its
            global min/max; a constant stack becomes all zeros.

    Returns:
        list of S ``PIL.Image`` objects in RGBA mode.
    """
    assert s_tensor.dim() == 4 # S,C,H,W

    stack = s_tensor.detach().cpu()
    lo, hi = stack.min(), stack.max()

    if scale_min_max:
        stack = torch.zeros_like(stack) if hi == lo else stack.sub(lo).div(hi - lo).mul(255)

    # S,C,H,W -> S,H,W,C; drop the channel axis for single-channel (grayscale) stacks
    stack = stack.permute(0, 2, 3, 1)
    if stack.shape[-1] == 1:
        stack = stack.squeeze(-1)

    return [pil_image.fromarray(slice_array).convert('RGBA') for slice_array in stack.numpy()]

def pil_images_from_onehot_seg(s_tensor_onehot, onehot_colormap, alpha) -> list:
    """Render a stack of one-hot segmentations as RGBA PIL images.

    Parameters:
        s_tensor_onehot: 4D tensor (S, H, W, E) with the one-hot/score axis last;
            each pixel is assigned its argmax class.
        onehot_colormap: dict class id -> RGB tuple. Entries whose value is not a
            tuple (e.g. ``None`` for background) stay fully transparent.
        alpha: opacity in [0, 1] applied to colored pixels.

    Returns:
        list of S RGBA ``PIL.Image`` objects of size H x W.
    """
    assert s_tensor_onehot.dim() == 4, "Tensor must be 2d with onehot encoding: Dim = [S, H, W, E]"
    s_tensor_onehot = s_tensor_onehot.detach().cpu()
    alpha_channel = (int(255.*alpha),)
    # S,H,W,RGBA canvas, initially transparent black
    s_rgba_tensor = torch.zeros(*s_tensor_onehot.shape[:-1], 4, dtype=torch.uint8)
    # Per-pixel class id; loop-invariant, so compute it once instead of once per color.
    class_ids = s_tensor_onehot.argmax(dim=-1)

    for onehot_id, rgb_val in onehot_colormap.items():
        if isinstance(rgb_val, tuple):
            s_rgba_tensor[class_ids == onehot_id] = torch.tensor(rgb_val + alpha_channel, dtype=torch.uint8)

    # one H,W,RGBA image per slice
    return [pil_image.fromarray(rgba).convert('RGBA') for rgba in s_rgba_tensor.numpy()]

def get_stacked_overlays(s_2d_img_tensor, s_2d_seg_tensor_onehot, onehot_colormap, alpha=0.3):
    """Alpha-composite colored segmentations over grayscale slices.

    Parameters:
        s_2d_img_tensor: (S, C, H, W) image stack.
        s_2d_seg_tensor_onehot: (S, H, W, E) one-hot segmentation stack.
        onehot_colormap: class id -> RGB tuple (non-tuple values are not drawn).
        alpha: opacity of the segmentation colors.

    Returns:
        (list of S RGBA PIL images, stacked (S, 4, H, W) float tensor in [0, 1]).
    """
    assert s_2d_img_tensor.dim() == 4, "" #S,C,H,W
    base_images = pil_images_from_gray_tensor(s_2d_img_tensor, scale_min_max=True)
    seg_images = pil_images_from_onehot_seg(s_2d_seg_tensor_onehot, onehot_colormap=onehot_colormap, alpha=alpha)
    pil_overlays = [pil_image.alpha_composite(base, seg) for base, seg in zip(base_images, seg_images)]
    tensor_overlays = torch.stack([transforms.functional.to_tensor(ovl) for ovl in pil_overlays], dim=0)
    return pil_overlays, tensor_overlays

def get_overlay_grid(s_2d_img_tensor, s_2d_seg_tensor_onehot, onehot_colormap, alpha=0.3, n_per_row=4):
    """Tile segmentation overlays of a slice stack into one grid image.

    Returns:
        (grid as a PIL image, grid tensor from ``torchvision.utils.make_grid``).
    """
    overlays = get_stacked_overlays(s_2d_img_tensor, s_2d_seg_tensor_onehot, onehot_colormap, alpha=alpha)[1]
    grid = vision_utils.make_grid(overlays, nrow=n_per_row)
    grid_image = transforms.functional.to_pil_image(grid)
    return grid_image, grid

In [None]:
# Preview one training case: move the last axis to the front so each entry is one 2D slice.
sample_idx = 1
img, seg = crossmoda_dataset[sample_idx]
img_slices = img.permute(2, 0, 1)[:, None]  # (slices, 1, H, W)
seg_slices = seg.permute(2, 0, 1)           # (slices, H, W)
label_values, label_counts = seg.unique(return_counts=True)
print((label_values, label_counts))

def get_cmap_dict(class_max_id, pyplot_map_name='rainbow', include_background=False):
    """Build a {class id: RGB tuple} color map sampled evenly from a matplotlib colormap.

    Parameters:
        class_max_id: highest class id to color.
        pyplot_map_name: name of the matplotlib colormap.
        include_background: if False, id 0 maps to None (not drawn) and ids
            1..class_max_id get colors; if True, ids 0..class_max_id get colors.

    Returns:
        dict int id -> (R, G, B) tuple of np.int32 in [0, 255], or None for id 0.
    """
    cmap = plt.get_cmap(pyplot_map_name)
    first_id = 0 if include_background else 1
    n_colors = class_max_id + 1 - first_id
    cmap_dict = {} if include_background else {0: None}

    rgba_rows = (cmap(np.arange(n_colors) / float(n_colors)) * 255).astype(np.int32)
    for offset, rgba in enumerate(rgba_rows):
        cmap_dict[first_id + offset] = tuple(rgba)[:3]  # keep RGB, drop alpha
    return cmap_dict

# Class id -> RGB color; None means "leave transparent" (background).
color_map = {0: None, 1: (255, 0, 0), 2: (0, 255, 0)}
# color_map = get_cmap_dict(2, pyplot_map_name='rainbow')
seg_onehot_slices = torch.nn.functional.one_hot(seg_slices, 3)
pil_ov, _ = get_overlay_grid(img_slices, seg_onehot_slices, color_map, n_per_row=10, alpha=.2)
display(pil_ov)

In [None]:
def overlaySegment(gray1,seg1,colors=None,flag=False):
    """Blend a 2D label map over a 2D grayscale image.

    Labeled pixels (``seg1 > 0``) become 50% gray + 50% class color; unlabeled
    pixels keep their gray value.

    Parameters:
        gray1: (H, W) image; presumably already in a displayable range such as [0, 1].
        seg1: (H, W) integer label map with values in [0, colors.shape[0]).
        colors: (num_classes, 3) RGB palette in [0, 1]. Optional now: every caller in
            this notebook omitted it (TypeError). Defaults to a built-in 10-color
            palette with class 0 black. The class count is taken from the palette
            instead of a hard-coded 29.
        flag: if True, also display the overlay with matplotlib.

    Returns:
        (H, W, 3) overlay tensor.
    """
    if colors is None:
        colors = torch.tensor([0,0,0,199,67,66,225,140,154,78,129,170,45,170,170,
                               240,110,38,111,163,91,235,175,86,202,255,52,162,0,183],
                              dtype=torch.float32).view(-1,3)/255.0
    # torch.mm needs matching dtype and device
    colors = colors.to(device=seg1.device, dtype=torch.float32)
    num_classes = colors.shape[0]

    H, W = seg1.squeeze().size()
    # (H, W) labels -> (num_classes, H*W) one-hot -> per-pixel RGB via a matrix product
    segs1 = F.one_hot(seg1.long(), num_classes).float().permute(2,0,1)
    seg_color = torch.mm(segs1.reshape(num_classes,-1).t(),colors).view(H,W,3)
    alpha = torch.clamp(1.0 - 0.5*(seg1>0).float(),0,1.0)

    overlay = (gray1*alpha).unsqueeze(2) + seg_color*(1.0-alpha).unsqueeze(2)
    if flag:
        plt.imshow(overlay.cpu().numpy())
        plt.axis('off')
        plt.show()
    return overlay

def plot_slice_overlay(img,seg,slc=(100,60,100),colors=None):
    """Plot three orthogonal slices of ``img`` with the label map ``seg`` overlaid.

    Parameters:
        img: image tensor indexed as ``img[0, a, b, c]`` (presumably (1, A, B, C) — TODO confirm).
        seg: label tensor indexed as ``seg[a, b, c]``.
        slc: slice indices; the panels fix b = slc[1], c = slc[2] and a = slc[0].
        colors: optional (num_classes, 3) RGB palette in [0, 1] for ``overlaySegment``.
            ``overlaySegment`` needs a palette, and the previous calls omitted it
            (TypeError). The default has 29 entries: 0 black, 1 red, 2 green, rest gray.

    Returns:
        None
    """
    print('image size:',img.size())
    if colors is None:
        colors = torch.full((29, 3), 0.5)
        colors[0] = 0.0
        colors[1] = torch.tensor([1.0, 0.0, 0.0])
        colors[2] = torch.tensor([0.0, 1.0, 0.0])
    colors = colors.to(seg.device)

    panels = (
        overlaySegment(img[0,:,slc[1],:], seg[:,slc[1],:], colors, flag=False),
        overlaySegment(img[0,:,:,slc[2]], seg[:,:,slc[2]], colors, flag=False),
        overlaySegment(img[0,slc[0],:,:], seg[slc[0],:,:], colors, flag=False),
    )
    fig,axs = plt.subplots(1, 3)
    for ax, panel in zip(axs, panels):
        ax.imshow(panel.cpu().numpy())
    plt.show()
    return None

In [None]:
def augmentAffine(img_in, seg_in, strength=0.05):
    """
    Random 3D affine augmentation applied identically to an image and a segmentation batch.

    One matrix ``I + strength * N(0, 1)`` (B x 3 x 4, one per sample) is drawn and both
    inputs are resampled on ``img_in``'s device. The image uses trilinear sampling
    with border padding; the segmentation uses nearest sampling with zero padding.

    :input: img_in: B x C x D x H x W float tensor;
            seg_in: B x C' x D x H x W label tensor with the same batch and spatial size
            (cast to float for sampling)
    :param strength: standard deviation of the random perturbation of the identity matrix
    :return: (augmented image batch, augmented segmentation batch). The segmentation is
             returned as a FLOAT tensor holding integer label values; cast with ``.long()``.
    """
    B,C,D,H,W = img_in.size()
    affine_matrix = (torch.eye(3,4).unsqueeze(0) + torch.randn(B, 3, 4) * strength).to(img_in.device)

    # align_corners=False is PyTorch's default; passing it explicitly silences the
    # "default behavior has changed" warning without changing the result.
    meshgrid = F.affine_grid(affine_matrix,torch.Size((B,1,D,H,W)), align_corners=False)

    img_out = F.grid_sample(img_in, meshgrid,padding_mode='border', align_corners=False)
    seg_out = F.grid_sample(seg_in.float(), meshgrid, mode='nearest', align_corners=False)

    return img_out, seg_out



def augmentNoise(img_in,strength=0.05):
    """Return ``img_in`` plus i.i.d. Gaussian noise with standard deviation ``strength``."""
    noise = torch.randn_like(img_in)
    return img_in + noise * strength


    
def pad_center_plane(img,size):
    """Zero-pad a 3D tensor to ``size``: centered on the first two axes, appended at the end of the third.

    For the same ``size`` this places data back where ``crop_center_plane`` took it
    from. Negative differences crop (``F.pad`` semantics).
    """
    s0, s1, s2 = img.size()
    front0 = (size[0] - s0) // 2
    front1 = (size[1] - s1) // 2
    padding = (0, size[2] - s2,
               front1, size[1] - s1 - front1,
               front0, size[0] - s0 - front0)
    return F.pad(img, padding, "constant", 0)



def crop_center_plane(img, size):
    """Crop a 3D tensor to ``size``: centered on the first two axes, starting at index 0 on the third."""
    s0, s1, _ = img.size()
    start0 = (s0 - size[0]) // 2
    start1 = (s1 - size[1]) // 2
    return img[start0:start0 + size[0], start1:start1 + size[1], :size[2]]

In [None]:
# NOTE: while DEBUG is True, train_DL displays an overlay of each augmented sample
# and stops after the FIRST iteration.
DEBUG = True

# training routine: class id -> RGB color for the debug overlays (None = background, not drawn)
color_map = {0: None, 1: (255, 0, 0), 2: (0, 255, 0)}

def _inverse_sqrt_class_weights(all_segs, num_class):
    """Cross-entropy class weights proportional to 1/sqrt(voxel count), mean-normalized; background fixed to 0.15."""
    counts = torch.bincount(all_segs.long().view(-1), minlength=num_class).float()
    # A class without voxels would get an infinite weight (NaN after normalizing); count it as 1.
    class_weight = torch.sqrt(1.0 / counts.clamp_min(1.0))
    class_weight = class_weight / class_weight.mean()
    class_weight[0] = 0.15
    return class_weight


def _show_debug_overlay(img, seg, num_class):
    """Display the labeled slices (along the last axis) of a (1, 1, D, H, W) image/label pair."""
    img_slices = img.permute(0,1,4,2,3).squeeze().unsqueeze(1)
    seg_slices = seg.permute(0,1,4,2,3).squeeze().to(dtype=torch.int64)
    idx_dept_with_segs, *_ = torch.nonzero(seg_slices > 0, as_tuple=True)
    idx_dept_with_segs = idx_dept_with_segs.unique()
    if idx_dept_with_segs.numel() == 0:
        # e.g. the augmentation moved all labels out of view; the overlay helpers cannot handle an empty stack
        print('DEBUG: no labeled slices in this sample')
        return
    pil_ov, _ = get_overlay_grid(
        img_slices[idx_dept_with_segs],
        torch.nn.functional.one_hot(seg_slices[idx_dept_with_segs], num_class),
        color_map, n_per_row=10, alpha=.5
    )
    display(pil_ov)


def train_DL(dataset ,epochs=500, update_epx = 50, use_mind = True):
    """Train the MobileNet + ASPP segmentation model on random single-sample batches.

    Parameters:
        dataset: indexable dataset where ``dataset[i]`` -> (image, label) and
            ``dataset[i:i+1]`` -> (1, D, H, W) image and label tensors (as
            ``CrossMoDa_Data`` provides).
        epochs: number of optimization steps, one random sample each.
        update_epx: print loss/Dice every ``update_epx`` steps (and at step 0).
        use_mind: build the model for 12 input channels and convert each image with
            ``mindssc`` before augmentation. NOTE(review): ``mindssc`` is not defined
            in this notebook; previously the call was commented out, so
            ``use_mind=True`` crashed with a channel mismatch in the first conv.

    Returns:
        (backbone, aspp, head) on the CPU. They are left in train mode; call
        ``.eval()`` before inference.

    While the global ``DEBUG`` is True, an overlay of the augmented sample is shown
    and training stops after the first step.
    """
    img_num = len(dataset)
    C = 12 if use_mind else 1  # number of model input channels
    all_segs = torch.cat([seg for _, seg in dataset])

    num_class = int(torch.max(all_segs).item()+1)
    class_weight = _inverse_sqrt_class_weights(all_segs, num_class).cuda()
    print('inv sqrt class_weight',class_weight)

    backbone, aspp, head = create_model(output_classes=num_class,input_channels=C)
    optimizer = torch.optim.Adam(list(backbone.parameters())+list(aspp.parameters())+list(head.parameters()),lr=0.001)

    criterion = nn.CrossEntropyLoss(class_weight)
    scaler = amp.GradScaler()
    for module in (backbone, aspp, head):
        module.cuda()
        module.train()
    t0 = time.time()

    for epx in range(epochs):
        optimizer.zero_grad()
        ind = int(torch.randint(0,img_num,(1,)))
        img, seg = dataset[ind:ind+1]
        img, seg = img.unsqueeze(0).float().cuda(), seg.unsqueeze(0).float().cuda()

        if use_mind:
            img = mindssc(img)

        img, seg = augmentAffine(img, seg, strength=0.1)
        img = augmentNoise(img, strength=0.02)

        if DEBUG:
            _show_debug_overlay(img, seg, num_class)

        # The reentrant gradient checkpointing inside apply_model only back-propagates
        # into the parameters when its input requires grad.
        img.requires_grad = True
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=True)
            # labels at the network's output resolution (exact size, also for odd inputs)
            seg = F.interpolate(seg, size=output_j.shape[2:], mode='nearest').squeeze(0).long()
            loss = criterion(output_j, seg)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if epx%update_epx==update_epx-1 or epx == 0:
            dice = dice3d(output_j.permute(0,2,3,4,1),torch.nn.functional.one_hot(seg, num_class), one_hot_torch_style=True)
            print('epx',epx,round(time.time()-t0,2),'s','loss',round(loss.item(),6),'dice mean', round(dice.mean().item(),4),'dice',dice)

        if DEBUG:
            break
#    stat_cuda('Visceral training')
    for module in (backbone, aspp, head):
        module.cpu()
    return backbone,aspp,head

In [None]:
# Training configuration: number of optimization steps and logging interval.
epochs = 2000
updates = 50
backbone, aspp, head = train_DL(crossmoda_dataset, epochs=epochs, update_epx=updates, use_mind=False)

In [None]:
# Evaluate the trained model on hrT2 target-training cases 150-181 against their labels.
# NOTE(review): the model was trained on "l3" crops resampled to (96, 96, 60), while here
# raw scans are center-cropped to (192, 192, 64) — confirm this mismatch is intended.
path = '/share/data_sam1/ckruse/image_data/CrossMoDa/target_training/'
label_path = '/share/data_supergrover1/weihsbach/tmp/crossmoda_full_set/'
plot = False
USE_MIND = False  # must match the use_mind flag passed to train_DL
size = (192,192,64)  # network input crop
num_cases = 32
first_case = 150

# Palette for overlaySegment (29 classes): 0 black, 1 red, 2 green, others gray.
overlay_colors = torch.full((29, 3), 0.5)
overlay_colors[0] = 0.0
overlay_colors[1] = torch.tensor([1.0, 0.0, 0.0])
overlay_colors[2] = torch.tensor([0.0, 1.0, 0.0])


def _mean_foreground_dice(pred, target, num_classes):
    """Mean Dice over foreground classes 1..num_classes-1 present in ``pred`` or ``target``.

    Replaces ``dice_coeff``, which is not defined anywhere in this notebook.
    Returns 1.0 when no foreground class occurs in either volume.
    """
    scores = []
    for c in range(1, num_classes):
        p, t = pred == c, target == c
        denom = p.sum() + t.sum()
        if denom > 0:
            scores.append(2.0 * (p & t).sum().float() / denom.float())
    return torch.stack(scores).mean().item() if scores else 1.0


for module in (backbone, aspp, head):
    module.cuda()
    # train_DL returns the modules in train mode; eval() makes BatchNorm use its running
    # statistics and disables the ASPP Dropout(0.5) during inference.
    module.eval()

target_dices = torch.zeros(num_cases)
for i in range(num_cases):
    ind = i + first_case
    nii_img = nib.load(path + 'crossmoda_'+ str(ind) + '_hrT2.nii.gz')
    nii_label = nib.load(label_path + 'crossmoda_'+ str(ind) + '_hrT2_Label.nii.gz')
    org_img = torch.from_numpy(nii_img.get_fdata())
    label = torch.from_numpy(nii_label.get_fdata()).cuda()
    org_size = org_img.size()
    tmp = crop_center_plane(org_img, size)
    tmp = (tmp - tmp.mean()) / tmp.std()

    img = tmp.float().cuda().unsqueeze(0).unsqueeze(0)
    if USE_MIND:
        img = mindssc(img)  # NOTE(review): mindssc is not defined in this notebook
    with torch.no_grad():
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=False)
    # logits are at half resolution: upsample x2, take the arg-max class, undo the crop
    modeled_seg = F.interpolate(output_j,scale_factor=2).argmax(1)
    modeled_seg = pad_center_plane(modeled_seg.squeeze(),org_size)
    tmp = pad_center_plane(tmp.squeeze(),org_size)
    if plot:
        slc = 30
        i0 = overlaySegment(tmp[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), overlay_colors, flag=False)
        i1 = overlaySegment(org_img[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), overlay_colors, flag=False)
        fig,axs = plt.subplots(1, 2,figsize=(18, 9))
        axs[0].imshow((i0+i1).cpu().numpy())
        axs[1].imshow(i1.cpu().numpy())
        plt.show()
    target_dices[i] = _mean_foreground_dice(modeled_seg, label, output_j.shape[1])
    print(f'image: crossmoda_{ind}_hrT2.nii.gz, dice: {target_dices[i]*100:0.2f}')

print(f'target dice mean: {target_dices.mean()*100:0.2f}')


In [None]:
# Predict hrT2 validation cases 211-242 and write the segmentations as NIfTI files.
path = '/share/data_sam1/ckruse/image_data/CrossMoDa/target_validation/'
plot = False
save = True
USE_MIND = False  # must match the use_mind flag passed to train_DL
size = (192,192,64)  # network input crop
out_dir = 'Deeplab_validation'
if save:
    # nib.save fails if the output directory is missing
    os.makedirs(out_dir, exist_ok=True)

# Palette for overlaySegment (29 classes): 0 black, 1 red, 2 green, others gray.
overlay_colors = torch.full((29, 3), 0.5)
overlay_colors[0] = 0.0
overlay_colors[1] = torch.tensor([1.0, 0.0, 0.0])
overlay_colors[2] = torch.tensor([0.0, 1.0, 0.0])

for module in (backbone, aspp, head):
    module.cuda()
    # inference mode: BatchNorm running statistics, ASPP Dropout disabled
    module.eval()

for i in range(32):
    ind = i+211
    nii_img = nib.load(path + 'crossmoda_'+ str(ind) + '_hrT2.nii.gz')
    img_affine = nii_img.affine
    org_img = torch.from_numpy(nii_img.get_fdata())
    org_size = org_img.size()
    tmp = crop_center_plane(org_img, size)
    tmp = (tmp - tmp.mean()) / tmp.std()

    img = tmp.float().cuda().unsqueeze(0).unsqueeze(0)
    if USE_MIND:
        img = mindssc(img)  # NOTE(review): mindssc is not defined in this notebook
    with torch.no_grad():
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=False)
    # logits are at half resolution: upsample x2, take the arg-max class, undo the crop
    modeled_seg = F.interpolate(output_j,scale_factor=2).argmax(1)
    modeled_seg = pad_center_plane(modeled_seg.squeeze(),org_size)
    tmp = pad_center_plane(tmp.squeeze(),org_size)
    if plot:
        slc = 30
        i0 = overlaySegment(tmp[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), overlay_colors, flag=False)
        i1 = overlaySegment(org_img[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), overlay_colors, flag=False)
        fig,axs = plt.subplots(1, 2,figsize=(18, 9))
        axs[0].imshow((i0+i1).cpu().numpy())
        axs[1].imshow(i1.cpu().numpy())
        plt.show()
    if save:
        label_nii = nib.Nifti1Image(modeled_seg.float().squeeze().cpu().numpy(), img_affine)
        nib.save(label_nii, os.path.join(out_dir, 'crossmoda_'+ str(ind) + '_Label.nii.gz'))
        

In [None]:
#!rm Deeplab_validation.zip
#!zip -r Deeplab_validation_half_res_adapted_model_target_train.zip Deeplab_validation

In [None]:
def save_model(backbone,aspp,head,name):
    """Save the state dicts to ``<name>_backbone.pth``, ``<name>_aspp.pth`` and ``<name>_head.pth``.

    Creates the parent directory of ``name`` if needed. Previously ``torch.save``
    raised FileNotFoundError when e.g. ``Models/`` did not exist.
    """
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    for suffix, module in (('_backbone.pth', backbone), ('_aspp.pth', aspp), ('_head.pth', head)):
        torch.save(module.state_dict(), name + suffix)
    return None

def load_model(name,output_classes,input_channels,map_location='cpu'):
    """Rebuild the model with ``create_model`` and load the weights written by ``save_model``.

    Parameters:
        name: path prefix used with ``save_model``.
        output_classes, input_channels: must match the saved model.
        map_location: where ``torch.load`` materializes the tensors. It defaults to
            the CPU so checkpoints saved from GPU modules also load on CPU-only
            machines. The weights are copied into the freshly built (CPU) modules
            either way.

    Returns:
        (backbone, aspp, head)
    """
    backbone, aspp, head = create_model(output_classes=output_classes,input_channels=input_channels)
    backbone.load_state_dict(torch.load(name + '_backbone.pth', map_location=map_location))
    aspp.load_state_dict(torch.load(name + '_aspp.pth', map_location=map_location))
    head.load_state_dict(torch.load(name + '_head.pth', map_location=map_location))
    return backbone,aspp,head

MODEL_PREFIX = 'Models/half_res_adapted_model_target_training'
save_model(backbone, aspp, head, MODEL_PREFIX)
# To restore instead (3 classes, 12 input channels as used with use_mind=True):
#backbone,aspp,head=load_model('Models/half_res_adapted_model_target_training',3,12)