In [1]:
import sys
# general
import os
import time
from meidic_vtach_utils.run_on_recommended_cuda import get_cuda_environ_vars as get_vars
os.update(get_vars(select="* -4"))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import nibabel as nib
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import pandas as pd  
import numpy as np
import scipy.ndimage
from sklearn.model_selection import KFold
import cc3d
from mdl_seg_class

# custom functions
from mindssc import mindssc

print(torch.__version__)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

1.8.1
7605
Tesla T4


In [2]:
def overlaySegment(gray1,seg1,colors,flag=False):
    H, W = seg1.squeeze().size()
    #colors=torch.FloatTensor([0,0,0,199,67,66,225,140,154,78,129,170,45,170,170,240,110,38,111,163,91,235,175,86,202,255,52,162,0,183]).view(-1,3)/255.0
    segs1 = F.one_hot(seg1.long(),29).float().permute(2,0,1)

    seg_color = torch.mm(segs1.view(29,-1).t(),colors).view(H,W,3)
    alpha = torch.clamp(1.0 - 0.5*(seg1>0).float(),0,1.0)

    overlay = (gray1*alpha).unsqueeze(2) + seg_color*(1.0-alpha).unsqueeze(2)
    if(flag):
        plt.imshow((overlay).numpy()); 
        plt.axis('off');
        plt.show()
    return overlay# Atrous Spatial Pyramid Pooling (Segmentation Network)

class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv3d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            # nn.AdaptiveAvgPool2d(1),
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-3:]
        # x = F.adaptive_avg_pool3d(x, (1))
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='nearest')  # , align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = [nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU())]

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)
 
        self.project = nn.Sequential(
            nn.Conv3d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)


# Mobile-Net with depth-separable convolutions and residual connections
class ResBlock(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        return self.module(inputs) + inputs


def create_model(output_classes: int = 14,input_channels: int = 1):
    # in_channels = torch.Tensor([1,16,24,24,32,32,32,64]).long()
    in_channels = torch.Tensor([input_channels, 24, 24, 32, 48, 48, 48, 64]).long()
    mid_channels = torch.Tensor([64, 128, 192, 192, 256, 256, 256, 384]).long()
    out_channels = torch.Tensor([24, 24, 32, 48, 48, 48, 64, 64]).long()
    mid_stride = torch.Tensor([1, 1, 1, 2, 1, 1, 1, 1])
    net = [nn.Identity()]
    for i in range(8):
        inc = int(in_channels[i])
        midc = int(mid_channels[i])
        outc = int(out_channels[i])
        strd = int(mid_stride[i])
        layer = nn.Sequential(nn.Conv3d(inc, midc, 1, bias=False), nn.BatchNorm3d(midc), nn.ReLU6(True),
                              nn.Conv3d(midc, midc, 3, stride=strd, padding=1, bias=False, groups=midc),
                              nn.BatchNorm3d(midc), nn.ReLU6(True),
                              nn.Conv3d(midc, outc, 1, bias=False), nn.BatchNorm3d(outc))
        if i == 0:
            layer[0] = nn.Conv3d(inc, midc, 3, padding=1, stride=2, bias=False)
        if (inc == outc) & (strd == 1):
            net.append(ResBlock(layer))
        else:
            net.append(layer)

    backbone = nn.Sequential(*net)

    count = 0
    # weight initialization
    for m in backbone.modules():
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            count += 1
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.zeros_(m.bias)

    print('#CNN layer', count)
    # complete model: MobileNet + ASPP + head (with a single skip connection)
    # newer model (one more stride, no groups in head)
    aspp = ASPP(64, (2, 4, 8, 16, 32), 128)
    head = nn.Sequential(nn.Conv3d(128 + 24, 64, 1, padding=0, groups=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(), \
                         nn.Conv3d(64, 64, 3, groups=1, padding=1, bias=False), nn.BatchNorm3d(64), nn.ReLU(), \
                         nn.Conv3d(64, output_classes, 1))
    return backbone, aspp, head


def apply_model(backbone, aspp, head, img, checkpointing=True, return_intermediate=False):
    if checkpointing:
        x1 = checkpoint(backbone[:3], img)
        x2 = checkpoint(backbone[3:], x1)
        y = checkpoint(aspp, x2)
        y1 = torch.cat((x1[0:1], F.interpolate(y[0:1], scale_factor=2)), 1)
        output_j = checkpoint(head, y1)
    else:
        x1 = backbone[:3](img)
        x2 = backbone[3:](x1)
        y = aspp(x2)
        y1 = torch.cat((x1[0:1], F.interpolate(y[0:1], scale_factor=2)), 1)
        output_j = head(y1)
    if return_intermediate:
        return y1,output_j
    else:
        return output_j


In [3]:
def crop_center_plane(img,size):
    s0,s1,s2 = img.size()
    i0 = (s0-size[0])//2
    i1 = (s1-size[1])//2
    i2 = 0
    img = img[i0:i0+size[0],i1:i1+size[1],i2:i2+size[2]]
    return img

    

def load_source_data(normalize:bool = True,size:tuple = (192,192,120)):
    t0 = time.time()

    source_train_num = 105
    img_train = torch.zeros(source_train_num,1,size[0],size[1],size[2])
    label_train = torch.zeros(source_train_num,size[0],size[1],size[2])
    path = '/share/data_sam1/ckruse/image_data/CrossMoDa/source_training_resampled_05mm/'
    for i in range(source_train_num):
        ind = i+1
        tmp = torch.from_numpy(
                    nib.load(path + 'crossmoda_'+ str(ind) + '_ceT1.nii.gz').get_fdata())
        #print(tmp.shape) 421 421 120/107
        #tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size)
#        if tmp.size(-1)<120:
#            tmp = F.pad(tmp,(0,120-tmp.size(-1)))
        tmp = crop_center_plane(tmp,size)
        if normalize:
            tmp = (tmp - tmp.mean()) / tmp.std()
        img_train[i, 0, :, :, :] = tmp.squeeze()
        tmp = torch.from_numpy(
            nib.load(path + 'crossmoda_'+ str(ind) + '_Label.nii.gz').get_fdata())
        tmp = crop_center_plane(tmp,size)
#        if tmp.size(-1)<120:
#            tmp = F.pad(tmp,(0,120-tmp.size(-1)))
        #tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size)
        label_train[i] = tmp.squeeze()

    print('Time for data import:', round(time.time() - t0, 2), 's')
    return img_train, label_train



In [4]:
def load_staples_training_data(normalize:bool = True,size:tuple = (192,192,120)):
    t0 = time.time()

    target_train_num = 105
    img_train = torch.zeros(target_train_num,1,size[0],size[1],size[2])
    label_train = torch.zeros(target_train_num,size[0],size[1],size[2])
    path = '/share/data_supergrover1/hansen/temp/crossmoda/data/target_training/'
    for i in range(target_train_num):
        ind = i+106
        tmp = torch.from_numpy(
                    nib.load(path + 'crossmoda_'+ str(ind) + '_hrT2.nii.gz').get_fdata())
        if tmp.size(-1)<size[-1]:
            tmp = F.pad(tmp,(0,size[-1]-tmp.size(-1)))
        tmp = crop_center_plane(tmp,size)
        if normalize:
            tmp = (tmp - tmp.mean()) / tmp.std()
        img_train[i, 0, :, :, :] = tmp.squeeze()
        tmp = torch.from_numpy(
                    nib.load(path + 'crossmoda_'+ str(ind) + '_Label.nii.gz').get_fdata())
        if tmp.size(-1)<size[-1]:
            tmp = F.pad(tmp,(0,size[-1]-tmp.size(-1)))
        tmp = crop_center_plane(tmp,size)

        label_train[i] = tmp.squeeze()

    print('Time for data import:', round(time.time() - t0, 2), 's')
    return img_train, label_train


In [5]:
def plot_slice_overlay(img,seg,slc= [100,60,100]):
    print('image size:',img.size())
    
    i1 = overlaySegment(img[0,:,slc[1],:], seg[:,slc[1],:], flag=False)
    i2 = overlaySegment(img[0,:,:,slc[2]], seg[:,:,slc[2]], flag=False)
    i3 = overlaySegment(img[0,slc[0],:,:], seg[slc[0],:,:], flag=False)
    fig,axs = plt.subplots(1, 3)
    axs[0].imshow(i1.cpu().numpy())
    axs[1].imshow(i2.cpu().numpy())
    axs[2].imshow(i3.cpu().numpy())
    plt.show()
    return None

In [6]:
# load data
imgs_train_target, labels_train_target = load_staples_training_data(normalize = True,size = (192,192,64))
imgs_train_source, labels_train_source = load_source_data(normalize = True,size = (192,192,64))

Time for data import: 46.79 s
Time for data import: 129.2 s


In [7]:
def augmentAffine(img_in, seg_in, strength=0.05):
    """
    3D affine augmentation on image and segmentation mini-batch on GPU.
    (affine transf. is centered: trilinear interpolation and zero-padding used for sampling)
    :input: img_in batch (torch.cuda.FloatTensor), seg_in batch (torch.cuda.LongTensor)
    :return: augmented BxCxTxHxW image batch (torch.cuda.FloatTensor), augmented BxTxHxW seg batch (torch.cuda.LongTensor)
    """
    B,C,D,H,W = img_in.size()
    affine_matrix = (torch.eye(3,4).unsqueeze(0) + torch.randn(B, 3, 4) * strength).to(img_in.device)

    meshgrid = F.affine_grid(affine_matrix,torch.Size((B,1,D,H,W)))

    img_out = F.grid_sample(img_in, meshgrid,padding_mode='border')
    seg_out = F.grid_sample(seg_in.float().unsqueeze(1), meshgrid, mode='nearest').long().squeeze(1)

    return img_out, seg_out

def augmentNoise(img_in,strength=0.05):
    return img_in + strength*torch.randn_like(img_in)

In [8]:
#training routine
def train_DL(imgs,segs,epochs=500,update_epx = 50,use_mind = True):
    img_num,C,_,_,_ = imgs.size()
    if use_mind:
        C =12
    else:
        C = 1
    ds_segs = F.interpolate(segs.unsqueeze(0),scale_factor=0.5,mode='nearest').squeeze(0)
    num_class = int(torch.max(segs).item()+1)
    backbone, aspp, head = create_model(output_classes=num_class,input_channels=C)
    optimizer = torch.optim.Adam(list(backbone.parameters())+list(aspp.parameters())+list(head.parameters()),lr=0.001)
    class_weight = torch.sqrt(1.0/(torch.bincount(segs.long().view(-1)).float()))
    class_weight = class_weight/class_weight.mean()
    class_weight[0] = 0.15
    class_weight = class_weight.cuda()
    print('inv sqrt class_weight',class_weight)
    criterion = nn.CrossEntropyLoss(class_weight)
    scaler = amp.GradScaler()
    backbone.cuda() 
    backbone.train()
    aspp.cuda() 
    aspp.train()
    head.cuda() 
    head.train()
    t0 = time.time()
    for epx in range(epochs):
        optimizer.zero_grad()
        ind = torch.randint(0,img_num,(1,))
        img = imgs[ind:ind+1].cuda()
        seg = segs[ind:ind+1].long().cuda()
        if use_mind:
            img = mindssc(img)
        img, seg = augmentAffine(img, seg, strength=0.1)
        img = augmentNoise(img,strength=0.02)
        seg = F.interpolate(seg.float().unsqueeze(0),scale_factor=0.5,mode='nearest').squeeze(0).long()
        img.requires_grad = True
        #img_mr.requires_grad = True
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=True)
            loss = criterion(output_j,seg)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if epx%update_epx==update_epx-1 or epx == 0:
            dice = dice_coeff(output_j.argmax(1),seg)
            print('epx',epx,round(time.time()-t0,2),'s','loss',round(loss.item(),6),'dice mean', round(dice.mean().item(),4),'dice',dice)
#    stat_cuda('Visceral training')
    backbone.cpu()
    aspp.cpu() 
    head.cpu() 
    return backbone,aspp,head

In [9]:
epochs = 2000
updates = 50
imgs = torch.cat((imgs_train_source,imgs_train_target),dim=0)
label = torch.cat((labels_train_source,labels_train_target),dim=0)
backbone,aspp,head = train_DL(imgs,label,epochs=epochs,update_epx =updates,use_mind = True)




#CNN layer 24
inv sqrt class_weight tensor([0.1500, 0.5172, 2.4564], device='cuda:0')




epx 0 5.48 s loss 1.173589 dice mean 0.006 dice tensor([0.0060])
epx 49 261.58 s loss 0.14849 dice mean 0.0 dice tensor([0.])
epx 99 527.87 s loss 0.072097 dice mean 0.0 dice tensor([0., 0.])
epx 149 795.07 s loss 0.058573 dice mean 0.1088 dice tensor([0.2177, 0.0000])
epx 199 1062.63 s loss 0.03156 dice mean 0.0 dice tensor([0., 0.])
epx 249 1330.62 s loss 0.026504 dice mean 0.2031 dice tensor([0.4062, 0.0000])
epx 299 1598.62 s loss 0.015467 dice mean 0.2812 dice tensor([0.5625, 0.0000])
epx 349 1866.45 s loss 0.035605 dice mean 0.0 dice tensor([0., 0.])
epx 399 2134.39 s loss 0.013029 dice mean 0.2774 dice tensor([0.2930, 0.2618])
epx 449 2402.47 s loss 0.009394 dice mean 0.2839 dice tensor([0.2267, 0.3410])
epx 499 2670.51 s loss 0.010311 dice mean 0.4167 dice tensor([0.6934, 0.1401])
epx 549 2938.68 s loss 0.032879 dice mean 0.5712 dice tensor([0.8256, 0.3167])
epx 599 3207.06 s loss 0.012772 dice mean 0.1363 dice tensor([0.0519, 0.2206])
epx 649 3475.59 s loss 0.018782 dice mean 

In [10]:
    
def pad_center_plane(img,size):
    s0,s1,s2 = img.size()
    i0 = (size[0]-s0)//2
    i1 = (size[1]-s1)//2
    i2 = 0
    pd = (i2,size[2]-s2-i2,i1,size[1]-s1-i1,i0,size[0]-s0-i0)
    #print('pad',pd)
    img = F.pad(img, pd, "constant", 0)
    return img

In [17]:
path = '/share/data_sam1/ckruse/image_data/CrossMoDa/target_training/'
label_path = '/share/data_supergrover1/weihsbach/tmp/crossmoda_full_set/'
plot = False
backbone.cuda() 
aspp.cuda() 
head.cuda() 
target_dices = torch.zeros(32)
source_dices = torch.zeros(32)
for i in range(32):
    ind = i+150
    nii_img = nib.load(path + 'crossmoda_'+ str(ind) + '_hrT2.nii.gz')
    nii_label = nib.load(label_path + 'crossmoda_'+ str(ind) + '_hrT2_Label.nii.gz')
    tmp = torch.from_numpy(nii_img.get_fdata())
    label = torch.from_numpy(nii_label.get_fdata()).cuda()
    org_size = tmp.size()  
    size = (192,192,64)
    tmp = crop_center_plane(tmp,size)
    tmp = (tmp - tmp.mean()) / tmp.std()
    org_img = torch.from_numpy(nii_img.get_fdata())
    
    img= tmp.float().cuda()
    img = mindssc(img.unsqueeze(0).unsqueeze(0))
    with torch.no_grad():
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=False)#
    modeled_seg = F.interpolate(output_j,scale_factor=2).argmax(1)
    modeled_seg = pad_center_plane(modeled_seg.squeeze(),org_size)
    #print(modeled_seg.shape,torch.max(modeled_seg))
    connectivity = 18 # only 4,8 (2D) and 26, 18, and 6 (3D) are allowed
    np_label = modeled_seg.long().cpu().numpy().astype('int32')
    tmp = pad_center_plane(tmp.squeeze(),org_size)
    #plot = True
    if plot:
        slc = 30
        i0 = overlaySegment(tmp[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), flag=False)
        i1 = overlaySegment(org_img[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), flag=False)
        fig,axs = plt.subplots(1, 2,figsize=(18, 9))
        axs[0].imshow((i0+i1).cpu().numpy())
        axs[1].imshow(i1.cpu().numpy())
        plt.show()
        fig.set_figwidth(40)
        fig.set_figheight(10)
    #print(modeled_seg.shape,torch.max(modeled_seg))
    target_dices[i] = dice_coeff(modeled_seg,label)
    print(f'image: crossmoda_{ind}_hrT2.nii.gz, dice: {target_dices[i]*100:0.2f}')
    
print(f'target dice mean: {target_dices.mean()*100:0.2f}')


image: crossmoda_150_hrT2.nii.gz, dice: 77.95
image: crossmoda_151_hrT2.nii.gz, dice: 75.39
image: crossmoda_152_hrT2.nii.gz, dice: 60.75
image: crossmoda_153_hrT2.nii.gz, dice: 76.53
image: crossmoda_154_hrT2.nii.gz, dice: 58.11
image: crossmoda_155_hrT2.nii.gz, dice: 51.08
image: crossmoda_156_hrT2.nii.gz, dice: 69.33
image: crossmoda_157_hrT2.nii.gz, dice: 19.09
image: crossmoda_158_hrT2.nii.gz, dice: 68.11
image: crossmoda_159_hrT2.nii.gz, dice: 79.19
image: crossmoda_160_hrT2.nii.gz, dice: 74.65
image: crossmoda_161_hrT2.nii.gz, dice: 51.15
image: crossmoda_162_hrT2.nii.gz, dice: 78.01
image: crossmoda_163_hrT2.nii.gz, dice: 58.91
image: crossmoda_164_hrT2.nii.gz, dice: 63.81
image: crossmoda_165_hrT2.nii.gz, dice: 44.12
image: crossmoda_166_hrT2.nii.gz, dice: 63.75
image: crossmoda_167_hrT2.nii.gz, dice: 80.05
image: crossmoda_168_hrT2.nii.gz, dice: 69.89
image: crossmoda_169_hrT2.nii.gz, dice: 31.71
image: crossmoda_170_hrT2.nii.gz, dice: 31.81
image: crossmoda_171_hrT2.nii.gz, 

In [18]:
path = '/share/data_sam1/ckruse/image_data/CrossMoDa/target_validation/'
plot = False
save = True
backbone.cuda() 
aspp.cuda() 
head.cuda() 
for i in range(32):
    ind = i+211
    nii_img = nib.load(path + 'crossmoda_'+ str(ind) + '_hrT2.nii.gz')
    img_affine = nii_img.affine
    tmp = torch.from_numpy(nii_img.get_fdata())
    org_img = torch.from_numpy(nii_img.get_fdata())
    org_size = tmp.size()  
    size = (192,192,64)
    tmp = crop_center_plane(tmp,size)
    #print('red. shape',tmp.shape)
    #tmp = F.interpolate(tmp.unsqueeze(0).unsqueeze(0), size=size)
    tmp = (tmp - tmp.mean()) / tmp.std()
        
    img= tmp.float().cuda()
    #print(img.shape)
    img = mindssc(img.unsqueeze(0).unsqueeze(0))
    with torch.no_grad():
        with amp.autocast(enabled=True):
            output_j = apply_model(backbone,aspp,head,img,checkpointing=False)#
    modeled_seg = F.interpolate(output_j,scale_factor=2).argmax(1)
    modeled_seg = pad_center_plane(modeled_seg.squeeze(),org_size)
    tmp = pad_center_plane(tmp.squeeze(),org_size)
    #print('org shape',tmp.shape)
    if plot:
        slc = 30
        i0 = overlaySegment(tmp[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), flag=False)
        i1 = overlaySegment(org_img[:,:,slc].cpu(), modeled_seg[:,:,slc].cpu(), flag=False)
        fig,axs = plt.subplots(1, 2,figsize=(18, 9))
        axs[0].imshow((i0+i1).cpu().numpy())
        axs[1].imshow(i1.cpu().numpy())
        plt.show()
        fig.set_figwidth(40)
        fig.set_figheight(10)
    if save:
        label_nii = nib.Nifti1Image(modeled_seg.float().squeeze().cpu().numpy(), img_affine)
        nib.save(label_nii, 'Deeplab_validation/crossmoda_'+ str(ind) + '_Label.nii.gz')  
        

In [19]:
#!rm Deeplab_validation.zip
#!zip -r Deeplab_validation_half_res_adapted_model_target_train.zip Deeplab_validation

rm: cannot remove 'Deeplab_validation.zip': No such file or directory
  adding: Deeplab_validation/ (stored 0%)
  adding: Deeplab_validation/crossmoda_240_Label.nii.gz (deflated 95%)
  adding: Deeplab_validation/crossmoda_242_Label.nii.gz (deflated 99%)
  adding: Deeplab_validation/.ipynb_checkpoints/ (stored 0%)
  adding: Deeplab_validation/crossmoda_211_Label.nii.gz (deflated 97%)
  adding: Deeplab_validation/crossmoda_212_Label.nii.gz (deflated 98%)
  adding: Deeplab_validation/crossmoda_213_Label.nii.gz (deflated 93%)
  adding: Deeplab_validation/crossmoda_214_Label.nii.gz (deflated 99%)
  adding: Deeplab_validation/crossmoda_215_Label.nii.gz (deflated 98%)
  adding: Deeplab_validation/crossmoda_216_Label.nii.gz (deflated 97%)
  adding: Deeplab_validation/crossmoda_217_Label.nii.gz (deflated 99%)
  adding: Deeplab_validation/crossmoda_218_Label.nii.gz (deflated 96%)
  adding: Deeplab_validation/crossmoda_219_Label.nii.gz (deflated 98%)
  adding: Deeplab_validation/crossmoda_220_Lab

In [14]:
def save_model(backbone,aspp,head,name):
    torch.save(backbone.state_dict(), name + '_backbone.pth')
    torch.save(aspp.state_dict(), name + '_aspp.pth')
    torch.save(head.state_dict(), name + '_head.pth')
    return None

def load_model(name,output_classes,input_channels):
    backbone, aspp, head = create_model(output_classes=output_classes,input_channels=input_channels)
    backbone.load_state_dict(torch.load( name + '_backbone.pth'))
    aspp.load_state_dict(torch.load(name + '_aspp.pth'))
    head.load_state_dict(torch.load(name + '_head.pth'))
    return backbone,aspp,head

save_model(backbone,aspp,head,'Models/half_res_adapted_model_target_training')
#backbone,aspp,head=load_model('Models/half_res_adapted_model_target_training',3,12)