<a href="https://colab.research.google.com/github/mz0g/mri-translations/blob/main/UGP_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code adapted from 

### Data

In [5]:
import torch.utils.data as data
import os.path
import numpy as np
import random
import torch
# Seed Python's `random` module; PairedImages_w_nameList below draws its
# flip-augmentation decisions from `random.random()`.
random.seed(0)

class PairedImages_w_nameList(data.Dataset):
    '''
    Dataset of image pairs stored as ``.npy`` files and addressed by two
    parallel file lists.

    Item ``i`` is loaded from ``root1/flist1[i]`` and ``root2/flist2[i]``.
    It can act as supervised or un-supervised depending on how the caller
    aligns the two file lists.

    Args:
        root1, root2: directories the file names are joined onto.
        flist1, flist2: sequences of file names; the dataset length is
            ``len(flist1)``, so ``flist2`` must be at least that long.
        transform1, transform2: optional callables applied to the first and
            second image respectively. Each is applied independently, so
            either one may be ``None``.
        do_aug: when True, the same random horizontal / vertical flips
            (each with probability 0.5) are applied to both images via
            ``torch.fliplr`` / ``torch.flipud``; this requires the
            (transformed) images to be tensors with at least two dims.
    '''
    def __init__(self, root1, root2, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.root1 = root1
        self.root2 = root2
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        img1 = np.load(os.path.join(self.root1, self.flist1[index]))
        img2 = np.load(os.path.join(self.root2, self.flist2[index]))
        # Each transform is guarded on its own: previously transform2 was
        # called whenever transform1 was set (crashing if it was None) and
        # skipped whenever transform1 was None.
        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        if self.do_aug:
            # Both images share every flip decision so the pair stays aligned.
            if random.random() < 0.5:
                img1, img2 = torch.fliplr(img1), torch.fliplr(img2)
            if random.random() < 0.5:
                img1, img2 = torch.flipud(img1), torch.flipud(img2)
        return img1, img2

    def __len__(self):
        return len(self.flist1)

### Loss Functions

In [6]:
# from skimage.measure import compare_ssim as ssim
# from skimage.measure import compare_psnr as psnr
import torch
import numpy as np
from matplotlib import pyplot as plt
import scipy.ndimage.filters as fi
import matplotlib.pyplot as plt

# def compare_ssim(imgRef, imgT, K1=0.01, K2=0.03):
#     r = ssim(imgRef, imgT, data_range=imgT.max() - imgT.min(), multichannel=True, K1=K1, K2=K2)
#     return r

# def compare_psnr(imgRef, imgT):
#     r = psnr(imgRef, imgT, data_range=imgT.max() - imgT.min())
#     return r

# def compare_rrmse(imgRef, imgT):
#     numerator = (imgRef-imgT)**2
#     numerator = np.mean(numerator.flatten())
    
#     denominator = (imgRef)**2
#     denominator = np.mean(denominator.flatten())
    
#     r = numerator/denominator
#     r = np.sqrt(r)
#     return r

# def compare_qilv(I, I2, Ws=0.0, K1=0.01, K2=0.03):
#     C1 = K1**2
#     C2 = K2**2

#     kernsize=11
#     kernstd = 1.5
#     if Ws==0:
#         window = np.zeros((kernsize, kernsize))
#         window[kernsize//2, kernsize//2]=1
#         window = fi.gaussian_filter(window, kernstd)
#     window = window/np.sum(window)
    
#     chs = I.shape[2]
#     idxs = []
#     for ch in range(chs):
#         M1 = fi.convolve(I[:,:,ch], window)
#         M2 = fi.convolve(I2[:,:,ch], window)
#         Isq = I**2
#         I2sq = I2**2
#         V1 = fi.convolve(Isq[:,:,ch], window) - M1**2
#         V2 = fi.convolve(I2sq[:,:,ch], window) - M2**2

#         m1 = np.mean(V1)
#         m2 = np.mean(V2)
#         s1 = np.std(V1)
#         s2 = np.std(V2)
#         s12 = np.mean((V1-m1)*(V2-m2))

#         ind1 = (2*m1*m2+C1)/(m1**2+m2**2+C1)
#         ind2 = (2*s1*s2+C2)/(s1**2+s2**2+C2)
#         ind3 = (s12+C2/2)/(s1*s2+C2/2)
        
#         idxs.append(ind1*ind2*ind3)

#     return np.mean(idxs)

def bayeLq_loss(out_mean, out_log_var, target, q=2, k1=1, k2=1):
    """
    Variance-weighted L_q loss with a log-variance penalty.

    The variance is ``exp(out_log_var) + 1e-5``. The result is
    ``0.5 * (k1 * mean(|out_mean - target|^q / var) + k2 * mean(log(var)))``.
    """
    variance = 1e-5 + torch.exp(out_log_var)
    inv_variance = 1 / variance
    residual_q = torch.abs(out_mean - target).pow(q)

    fidelity_term = k1 * (inv_variance * residual_q).mean()
    penalty_term = k2 * torch.log(variance).mean()
    return 0.5 * (fidelity_term + penalty_term)

def bayeGen_loss(out_mean, out_1alpha, out_beta, target):
    """
    Uncertainty-aware loss built from a mean, a ``1/alpha`` map and a
    ``beta`` map.

    Per element the loss is
    ``clamp(|out_mean - target| * (1/alpha) * beta, 1e-6, 50)
    - log(1/alpha) + lgamma(1/beta) - log(beta)``, averaged over all
    elements. Small offsets (1e-5 for ``1/alpha``, 1e-1 for ``beta``) are
    added before the logs.

    Args:
        out_mean: predicted mean tensor.
        out_1alpha: predicted ``1/alpha`` tensor (non-negative expected).
        out_beta: predicted ``beta`` tensor (non-negative expected).
        target: ground-truth tensor, broadcast-compatible with ``out_mean``.

    Returns:
        Scalar tensor with the mean loss.

    The inputs are NOT modified: the offsets are added out of place.
    The previous ``+=`` form altered the caller's tensors, which are
    typically activations that are part of the autograd graph.
    """
    alpha_eps, beta_eps = 1e-5, 1e-1
    inv_alpha = out_1alpha + alpha_eps
    beta = out_beta + beta_eps

    resi = torch.abs(out_mean - target)
    # Clamp keeps the weighted residual in a numerically safe range.
    resi = (resi * inv_alpha * beta).clamp(min=1e-6, max=50)
    log_1alpha = torch.log(inv_alpha)
    log_beta = torch.log(beta)
    lgamma_beta = torch.lgamma(torch.pow(beta, -1))

    # Diagnostics only: report NaNs but do not interrupt training.
    if torch.isnan(log_1alpha).any():
        print('log_1alpha has nan')
        print(lgamma_beta.min(), lgamma_beta.max(), log_beta.min(), log_beta.max())
    if torch.isnan(lgamma_beta).any():
        print('lgamma_beta has nan')
    if torch.isnan(log_beta).any():
        print('log_beta has nan')

    l = resi - log_1alpha + lgamma_beta - log_beta
    return torch.mean(l)
    

def bayeLq_loss1(out_mean, out_var, target, q=2, k1=1, k2=1):
    '''
    Variant of the variance-weighted L_q loss that takes the variance
    directly (out_var has sigmoid applied to it and is between 0 and 1).

    Returns 0.5 * (k1 * mean((out_mean - target)^q / (out_var + 1e-7))
                   + k2 * mean(log(out_var + 1e-7))).
    '''
    tiny = 1e-7
    safe_var = out_var + tiny
    log_var = torch.log(safe_var)
    weight = 1 / safe_var
    # The signed residual is raised to q (no absolute value is taken).
    weighted_residual = weight * (out_mean - target).pow(q)
    data_term = k1 * weighted_residual.mean()
    reg_term = k2 * log_var.mean()
    return 0.5 * (data_term + reg_term)

def bayeLq_loss_n_ch(out_mean, out_log_var, target, q=2, k1=1, k2=1, n_ch=3):
    '''
    Multi-channel version of the log-variance weighted L_q loss.

    assumes uncertainty values are single channel: out_log_var is tiled
    n_ch times along dim 1 before weighting the residual, while the
    log-variance penalty is averaged over the un-tiled map.
    '''
    tiled_log_var = out_log_var.repeat(1, n_ch, 1, 1)
    precision = torch.exp(-tiled_log_var)

    residual_q = (out_mean - target).pow(q)
    data_term = k1 * (precision * residual_q).mean()
    reg_term = k2 * out_log_var.mean()
    return 0.5 * (data_term + reg_term)

def Sinogram_loss(A, out_y, target, q=2):
    '''
    Mean L_q distance between the projections A*out_y and A*target.

    A = n_rows x (128x88)
    expected image: 128 x 88
    So load the variable, transpose it.
    incoming variable: out_y, target: n_batch x 1 x 88 x 128

    Both inputs are reshaped with view(-1, n_batch) before being multiplied
    by A, where n_batch is out_y.shape[0].
    '''
    batch = out_y.shape[0]
    projected_pred = torch.mm(A, out_y.view(-1, batch))
    projected_target = torch.mm(A, target.view(-1, batch))
    gap = torch.abs(projected_pred - projected_target)
    return torch.pow(gap, q).mean()

def bayeLq_Sino_loss(A, out_mean, out_log_var, target, q=2, k1=1, k2=1):
    """
    Variance-weighted L_q loss evaluated after applying the matrix ``A``.

    Steps:
      * variance = exp(out_log_var) + 3e-3
      * residual = |A @ mean - A @ target|, both reshaped with
        ``view(-1, n_batch)`` where ``n_batch = out_mean.shape[0]``
      * projected log-variance = log(A @ variance + 2e-2)
      * loss = 0.5 * (k1 * mean(residual^q / projected variance)
                      + k2 * mean(projected log-variance))

    Args:
        A: 2-D matrix multiplied onto the reshaped inputs via ``torch.mm``.
        out_mean, out_log_var, target: tensors whose leading dim is the
            batch size and whose element count matches ``A``'s columns
            times the batch size.
        q, k1, k2: exponent and term weights.

    Returns:
        Scalar loss tensor.
    """
    n_batch = out_mean.shape[0]
    var_eps = 3e-3
    out_var = var_eps + torch.exp(out_log_var)

    resi = torch.abs(torch.mm(A, out_mean.view(-1, n_batch)) - torch.mm(A, target.view(-1, n_batch)))

    # Offset keeps the log finite when the projected variance is ~0.
    sino_var_eps = 2e-2
    A_out_log_var = torch.log(torch.mm(A, out_var.view(-1, n_batch)) + sino_var_eps)
    # (A leftover debug copy of A_out_log_var to a NumPy array was removed:
    # it was never read and forced a device-to-host transfer on every call.)
    factor = torch.exp(-1 * A_out_log_var)

    diffq = factor * torch.pow(resi, q)
    loss1 = k1 * torch.mean(diffq)
    loss2 = k2 * torch.mean(A_out_log_var)

    return 0.5 * (loss1 + loss2)

def bayeLq_Sino_loss1(A, out_mean, out_var, target, q=2, k1=1, k2=1):
    """
    Variant of the projected variance-weighted L_q loss that takes the
    variance directly and clamps intermediate values.

    Steps:
      * residual = clamp(|A @ mean - A @ target|, 0, 1e2)
      * projected log-variance = clamp(log(A @ out_var), -3, 3)
      * loss = 0.5 * (k1 * mean(residual^q * exp(-projected log-variance))
                      + k2 * mean(projected log-variance))

    All inputs are reshaped with ``view(-1, n_batch)`` (``n_batch`` is
    ``out_mean.shape[0]``) before being multiplied by ``A``.

    Returns:
        Scalar loss tensor.
    """
    n_batch = out_mean.shape[0]
    resi = torch.abs(torch.mm(A, out_mean.view(-1, n_batch)) - torch.mm(A, target.view(-1, n_batch)))
    resi = torch.clamp(resi, min=0, max=1e2)

    # The per-pixel log-variance that used to be computed here was never
    # used; only the projected variance enters the loss.
    # The clamp bounds the log (log(0) = -inf becomes -3).
    A_out_log_var = torch.log(torch.mm(A, out_var.view(-1, n_batch)))
    A_out_log_var = torch.clamp(A_out_log_var, min=-3, max=3)

    factor = torch.exp(-1 * A_out_log_var)

    diffq = factor * torch.pow(resi, q)
    loss1 = k1 * torch.mean(diffq)
    loss2 = k2 * torch.mean(A_out_log_var)

    return 0.5 * (loss1 + loss2)
    

def save_model(M, M_ckpt):
    """Write M's state_dict to the path M_ckpt and print a confirmation."""
    weights = M.state_dict()
    torch.save(weights, M_ckpt)
    print('model saved @ {}'.format(M_ckpt))

def show_G(G, x_lr, x_hr):
    """
    Plot a 1x5 figure for the first sample / first channel of a batch:
    the input ('lr'), the predicted mean ('sr'), the predicted log-variance,
    its exponential ('var sr') and the reference image ('hr').

    G is switched to eval mode (and left there); G(x_lr) must return a
    (mean, log_var) pair. Runs under torch.no_grad().
    """
    def _as_image(batch):
        # First sample, first channel, as a NumPy array on the host.
        return batch[0, 0, :, :].data.cpu().numpy()

    def _draw(slot, batch, cmap, caption):
        plt.subplot(1, 5, slot)
        plt.imshow(_as_image(batch), cmap=cmap)
        plt.title(caption)

    G.eval()
    with torch.no_grad():
        plt.figure(figsize=(15, 10))
        _draw(1, x_lr, 'gray', 'lr')

        mean_sr, log_var_sr = G(x_lr)
        var_sr = torch.exp(log_var_sr)
        _draw(2, mean_sr, 'gray', 'sr')
        _draw(3, log_var_sr, 'jet', 'log_var sr')
        _draw(4, var_sr, 'jet', 'var sr')
        _draw(5, x_hr, 'gray', 'hr')
        plt.show()

def Gen_loss(D_for_pred, pred, target, k1=1e-3):
    """
    Generator objective: MSE between pred and target plus k1 times the
    adversarial term mean(1 - D_for_pred), where D_for_pred is the
    discriminator's output for the prediction.
    """
    adversarial = (1 - D_for_pred).mean()
    fidelity = torch.nn.functional.mse_loss(pred, target)
    return fidelity + k1 * adversarial

def Gen_genUncer_loss(D_for_pred, pred, pred_1alpha, pred_beta, target, k1=1e-4):
    """
    Generator objective with an uncertainty-aware fidelity term:
    bayeGen_loss(pred, pred_1alpha, pred_beta, target) plus k1 times the
    adversarial term mean(1 - D_for_pred).
    """
    adversarial = (1 - D_for_pred).mean()
    fidelity = bayeGen_loss(pred, pred_1alpha, pred_beta, target)
    return fidelity + k1 * adversarial

def Dis_loss(D, SR_pred, HR_target):
    """
    Discriminator objective with randomly smoothed labels.

    Real targets are drawn uniformly from [0.8, 1.0) and fake targets from
    [0, 0.2), one value per sample (shape n_batch x 1), then cast to the
    tensor type of SR_pred. Returns BCE(D(HR_target), real targets) +
    BCE(D(SR_pred), fake targets).
    """
    batch = SR_pred.shape[0]
    tensor_kind = SR_pred.type()
    # Draw both label sets before calling D so the RNG is consumed in the
    # same order as label creation (real first, then fake).
    soft_real = (torch.rand(batch, 1) * 0.2 + 0.8).type(tensor_kind)
    soft_fake = (torch.rand(batch, 1) * 0.2).type(tensor_kind)

    bce = torch.nn.functional.binary_cross_entropy
    real_term = bce(D(HR_target), soft_real)
    fake_term = bce(D(SR_pred), soft_fake)
    return real_term + fake_term

### Networks

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

### components
class ResConv(nn.Module):
    """
    Residual convolutional block.

    The main branch (``double_conv``) is (3x3 convolution => BN => ReLU) * 3,
    going in_channels -> mid_channels -> mid_channels -> out_channels.
    The shortcut branch (``double_conv1``) is a single
    3x3 convolution => BN => ReLU that maps the input to out_channels, so
    the two branches can be summed even when in_channels != out_channels.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        mid_channels: channels inside the main branch; defaults to
            out_channels when falsy.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return main branch + shortcut branch, each evaluated once."""
        shortcut = self.double_conv1(x)
        # The main branch used to be evaluated twice (first result thrown
        # away), doubling compute and updating the BatchNorm running
        # statistics twice per training step.
        main = self.double_conv(x)
        return main + shortcut

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ResConv block."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            ResConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features

class Up(nn.Module):
    """
    Decoder stage: upsample x1 by 2, pad it to the spatial size of the skip
    tensor x2, concatenate [x2, x1] on the channel dim and apply a ResConv.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if bilinear:
            # Bilinear upsampling has no weights, so the ResConv narrows the
            # channels itself (mid_channels = in_channels // 2).
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, half)
        else:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are N x C x H x W: dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes the last dim first: [left, right, top, bottom].
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """1x1 convolution mapping in_channels to out_channels (no activation)."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

##### The composite networks
class UNet(nn.Module):
    """
    U-Net built from ResConv blocks: an input block, four Down stages, four
    Up stages with skip connections, and a 1x1 output convolution.
    """
    def __init__(self, n_channels, out_channels, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # With bilinear upsampling the decoder widths are halved.
        shrink = 2 if bilinear else 1
        # Encoder
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)
        # Decoder
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)

    def forward(self, x):
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        features = self.down4(enc4)
        # Walk back up, pairing each Up stage with its encoder skip.
        decoder = (self.up1, self.up2, self.up3, self.up4)
        skips = (enc4, enc3, enc2, enc1)
        for stage, skip in zip(decoder, skips):
            features = stage(features, skip)
        return self.outc(features)

class CasUNet(nn.Module):
    """
    Cascade of n_unet UNets with io_channels in and out. The first UNet
    sees the input x; every later UNet sees (previous output + x).
    """
    def __init__(self, n_unet, io_channels, bilinear=True):
        super().__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList()
        for _ in range(self.n_unet):
            stage = UNet(self.io_channels, self.io_channels, self.bilinear)
            self.unet_list.append(stage)

    def forward(self, x):
        y = x
        for idx in range(self.n_unet):
            stage_input = y if idx == 0 else y + x
            y = self.unet_list[idx](stage_input)
        return y

class CasUNet_2head(nn.Module):
    """
    Cascade of n_unet networks where every stage is a UNet except the last,
    which is a UNet_2head. The first stage sees x; later stages see
    (previous output + x). Returns the first two elements of the final
    stage's output.
    """
    def __init__(self, n_unet, io_channels, bilinear=True):
        super().__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList()
        last = self.n_unet - 1
        for idx in range(self.n_unet):
            stage_cls = UNet_2head if idx == last else UNet
            self.unet_list.append(
                stage_cls(self.io_channels, self.io_channels, self.bilinear)
            )

    def forward(self, x):
        y = x
        for idx in range(self.n_unet):
            stage_input = y if idx == 0 else y + x
            y = self.unet_list[idx](stage_input)
        return y[0], y[1]

class CasUNet_3head(nn.Module):
    """
    Cascade of n_unet networks where every stage is a UNet except the last,
    which is a UNet_3head. The first stage sees x; later stages see
    (previous output + x). Returns the first three elements of the final
    stage's output.
    """
    def __init__(self, n_unet, io_channels, bilinear=True):
        super().__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList()
        last = self.n_unet - 1
        for idx in range(self.n_unet):
            stage_cls = UNet_3head if idx == last else UNet
            self.unet_list.append(
                stage_cls(self.io_channels, self.io_channels, self.bilinear)
            )

    def forward(self, x):
        y = x
        for idx in range(self.n_unet):
            stage_input = y if idx == 0 else y + x
            y = self.unet_list[idx](stage_input)
        return y[0], y[1], y[2]

class UNet_2head(nn.Module):
    """
    U-Net backbone (ResConv input block, four Down and four Up stages) with
    two output heads applied to the same 64-channel feature map:
    out_mean (out_channels channels) and out_var (a single channel).
    """
    def __init__(self, n_channels, out_channels, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # With bilinear upsampling the decoder widths are halved.
        shrink = 2 if bilinear else 1
        # Encoder
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)
        # Decoder
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # Mean head: a pixel may carry several channels.
        self.out_mean = OutConv(64, out_channels)
        # Variance head: always one value per pixel.
        var_layers = [OutConv(64, 128), OutConv(128, 1)]
        self.out_var = nn.Sequential(*var_layers)

    def forward(self, x):
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        features = self.down4(enc4)
        decoder = (self.up1, self.up2, self.up3, self.up4)
        skips = (enc4, enc3, enc2, enc1)
        for stage, skip in zip(decoder, skips):
            features = stage(features, skip)
        y_mean = self.out_mean(features)
        y_var = self.out_var(features)
        return y_mean, y_var

class UNet_3head(nn.Module):
    """
    U-Net backbone (ResConv input block, four Down and four Up stages) with
    three output heads applied to the same 64-channel feature map:
    out_mean (out_channels channels), out_alpha and out_beta (one channel
    each, both ending in a ReLU so they are non-negative).
    """
    def __init__(self, n_channels, out_channels, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # With bilinear upsampling the decoder widths are halved.
        shrink = 2 if bilinear else 1
        # Encoder
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)
        # Decoder
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # Mean head: a pixel may carry several channels.
        self.out_mean = OutConv(64, out_channels)
        # Scalar-per-pixel heads.
        alpha_layers = [OutConv(64, 128), OutConv(128, 1), nn.ReLU()]
        self.out_alpha = nn.Sequential(*alpha_layers)
        beta_layers = [OutConv(64, 128), OutConv(128, 1), nn.ReLU()]
        self.out_beta = nn.Sequential(*beta_layers)

    def forward(self, x):
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        features = self.down4(enc4)
        decoder = (self.up1, self.up2, self.up3, self.up4)
        skips = (enc4, enc3, enc2, enc1)
        for stage, skip in zip(decoder, skips):
            features = stage(features, skip)
        y_mean = self.out_mean(features)
        y_alpha = self.out_alpha(features)
        y_beta = self.out_beta(features)
        return y_mean, y_alpha, y_beta

class ResidualBlock(nn.Module):
    """
    Identity-shortcut residual block: two (reflection pad => 3x3 conv =>
    instance norm) stages with a ReLU in between, added to the input.
    Channel count and spatial size are unchanged.
    """
    def __init__(self, in_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """
    Encoder / residual / decoder generator: a 7x7 input convolution, two
    stride-2 downsampling convolutions (64 -> 128 -> 256 channels),
    n_residual_blocks ResidualBlocks, two stride-2 transposed convolutions
    back to 64 channels, and a 7x7 output convolution followed by Tanh.
    """
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()
        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        width = 64
        # Two downsampling steps, doubling the channel count each time.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2
        # Residual trunk at the bottleneck width.
        for _ in range(n_residual_blocks):
            layers.append(ResidualBlock(width))
        # Two upsampling steps, halving the channel count each time.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2
        # Output head
        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        generated = self.model(x)
        return generated

### discriminator
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer;
                               later layers use ndf * min(2**n, 8)
            n_layers (int)  -- the number of stride-2 conv layers in the
                               discriminator (two stride-1 convs follow)
            norm_layer      -- normalization layer class, or a
                               functools.partial wrapping one
        """
        super(NLayerDiscriminator, self).__init__()
        # Conv bias is only enabled for InstanceNorm2d; BatchNorm2d already
        # has affine parameters. isinstance (rather than an exact type
        # comparison) is the idiomatic check for a partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        # One more widening conv at stride 1, then the 1-channel output conv.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: returns the 1-channel patch prediction map."""
        return self.model(input)

### Training

In [9]:
import torch
import numpy as np
import scipy as sp
import skimage
from scipy.special import gamma, factorial
import matplotlib.gridspec as gridspec
from scipy.stats import gennorm
# import seaborn as sns
# sns.set_style('darkgrid')
import os, sys
import PIL
from PIL import Image
from matplotlib import pyplot as plt
# from losses import *
# from networks import *
# from ds import *
import random
random.seed(0)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim
from torchvision import transforms, utils as tv_utils

# Global backend setup for this notebook: allow cuDNN and turn on its
# benchmark mode (per the PyTorch docs, this lets cuDNN pick convolution
# algorithms by timing them).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
# Module-level tensor type; the training functions below also take their
# own `dtype` argument with this same default.
dtype = torch.cuda.FloatTensor

In [10]:
def train_i2i_UNet3headGAN(
    netG_A,
    netD_A,
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-4,
    ckpt_path='../ckpt/i2i_UNet3headGAN',
):
    """
    Adversarially train a 3-head generator against a discriminator.

    Training runs in three fixed stages of 50, 50 and 150 epochs. In each
    stage the generator loss is
    ``lam1 * L1(rec, xB) + lam2 * bayeGen_loss(rec, alpha, beta, xB)
    + 0.001 * MSE(mean D(rec), 1)`` with (lam1, lam2) taken from
    (1, 1e-4), (0.5, 1e-3), (0.1, 1e-2). The discriminator is trained with
    a least-squares loss (real -> 1, fake -> 0). At most 1001 batches are
    used per epoch. Generator and discriminator weights are saved after
    every epoch to ``ckpt_path + '_eph{eph}_G_A.pth' / '_D_A.pth'``.

    Args:
        netG_A: generator; ``netG_A(x)`` must return (mean, alpha, beta).
        netD_A: discriminator returning an N x C x H x W map.
        train_loader: iterable of batches whose items 0 and 1 are the
            input and target tensors; must support ``len()``.
        test_loader: accepted for interface compatibility; not used here.
        dtype: tensor type the networks and batches are cast to.
        device: device the networks and batches are moved to.
        num_epochs: only used as ``T_max`` of the cosine schedulers; the
            number of epochs actually run is fixed by the stage list.
        init_lr: Adam learning rate for both optimizers.
        ckpt_path: checkpoint path prefix.

    Returns:
        The trained ``(netG_A, netD_A)``.
    """
    netG_A.to(device)
    netG_A.type(dtype)
    ####
    netD_A.to(device)
    netD_A.type(dtype)

    ####
    optimizerG = torch.optim.Adam(list(netG_A.parameters()), lr=init_lr)
    optimizerD = torch.optim.Adam(list(netD_A.parameters()), lr=init_lr)
    # NOTE(review): the schedulers are created but never stepped, so the
    # learning rate stays at init_lr. Kept as-is to preserve training
    # behaviour; confirm whether stepping once per epoch was intended.
    optimG_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, num_epochs)
    optimD_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, num_epochs)
    ####
    list_epochs = [50, 50, 150]
    list_lambda1 = [1, 0.5, 0.1]
    list_lambda2 = [0.0001, 0.001, 0.01]
    # `stage_epochs` (instead of re-using `num_epochs`) keeps the function
    # argument from being silently overwritten by the stage schedule.
    for stage_epochs, lam1, lam2 in zip(list_epochs, list_lambda1, list_lambda2):
        for eph in range(stage_epochs):
            netG_A.train()
            netD_A.train()
            avg_tot_loss = 0
            n_batches = 0
            print(len(train_loader))
            for i, batch in enumerate(train_loader):
                if i > 1000:
                    break
                xA, xB = batch[0].to(device).type(dtype), batch[1].to(device).type(dtype)
                # calc all the required outputs
                rec_B, rec_alpha_B, rec_beta_B = netG_A(xA)

                # first gen
                netD_A.eval()
                total_loss = lam1*F.l1_loss(rec_B, xB) + lam2*bayeGen_loss(rec_B, rec_alpha_B, rec_beta_B, xB)
                t0 = netD_A(rec_B)
                # Global average of the patch map -> one score per sample.
                t1 = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                e5 = 0.001*F.mse_loss(t1, torch.ones(t1.size()).to(device).type(dtype))
                total_loss += e5
                optimizerG.zero_grad()
                total_loss.backward()
                optimizerG.step()

                # then discriminator
                netD_A.train()
                t0 = netD_A(xB)
                pred_real_A = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                loss_D_A_real = 1*F.mse_loss(
                    pred_real_A, torch.ones(pred_real_A.size()).to(device).type(dtype)
                )
                # detach(): the discriminator step must not update the generator.
                t0 = netD_A(rec_B.detach())
                pred_fake_A = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                loss_D_A_pred = 1*F.mse_loss(
                    pred_fake_A, torch.zeros(pred_fake_A.size()).to(device).type(dtype)
                )
                loss_D = (loss_D_A_real + loss_D_A_pred)*0.5

                optimizerD.zero_grad()
                loss_D.backward()
                optimizerD.step()

                avg_tot_loss += total_loss.item()
                n_batches += 1

            # Average over the batches actually processed: the loop stops
            # after 1001 batches, so len(train_loader) can overcount (and is
            # zero for an empty loader).
            avg_tot_loss /= max(n_batches, 1)
            print(
                'epoch: [{}/{}] | avg_tot_loss: {}'.format(
                    eph, stage_epochs, avg_tot_loss
                )
            )
            # NOTE(review): `eph` restarts at 0 in every stage, so later
            # stages overwrite the checkpoints of earlier ones.
            torch.save(netG_A.state_dict(), ckpt_path+'_eph{}_G_A.pth'.format(eph))
            torch.save(netD_A.state_dict(), ckpt_path+'_eph{}_D_A.pth'.format(eph))
    return netG_A, netD_A


def train_i2i_Cas_UNet3headGAN(
    list_netG_A,
    list_netD_A,
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-4,
    ckpt_path='../ckpt/i2i_UNet3headGAN',
):
    """
    Adversarially train the LAST generator of a cascade.

    The generators in ``list_netG_A`` are chained: the first sees ``xA``;
    each later one sees ``cat([mean, alpha, beta, xA], dim=1)`` built from
    the previous generator's outputs. Only ``list_netG_A[-1]`` and
    ``list_netD_A[-1]`` are optimised. The stage schedule, losses, the
    1001-batch cap per epoch and the checkpoint naming are the same as in
    the single-generator training function.

    Args:
        list_netG_A: list of generators, each returning (mean, alpha, beta).
        list_netD_A: list of discriminators; the last one is trained.
        train_loader: iterable of batches whose items 0 and 1 are the
            input and target tensors; must support ``len()``.
        test_loader: passed to the periodic visualisation helper.
        dtype, device: tensor type / device for networks and batches.
        num_epochs: only used as ``T_max`` of the cosine schedulers.
        init_lr: Adam learning rate for both optimizers.
        ckpt_path: checkpoint path prefix.

    Returns:
        ``(list_netG_A, list_netD_A)`` (the same list objects).
    """
    # nn.Module.to / .type modify the module in place.
    for net in list_netG_A:
        net.to(device)
        net.type(dtype)
    for net in list_netD_A:
        net.to(device)
        net.type(dtype)
    ####
    optimizerG = torch.optim.Adam(list(list_netG_A[-1].parameters()), lr=init_lr)
    optimizerD = torch.optim.Adam(list(list_netD_A[-1].parameters()), lr=init_lr)
    # NOTE(review): the schedulers are created but never stepped, so the
    # learning rate stays at init_lr. Kept as-is to preserve training
    # behaviour; confirm whether stepping once per epoch was intended.
    optimG_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, num_epochs)
    optimD_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, num_epochs)
    ####
    list_epochs = [50, 50, 150]
    list_lambda1 = [1, 0.5, 0.1]
    list_lambda2 = [0.0001, 0.001, 0.01]
    netG_A, netD_A = list_netG_A[-1], list_netD_A[-1]
    last_nid = len(list_netG_A) - 1
    ####
    # `stage_epochs` (instead of re-using `num_epochs`) keeps the function
    # argument from being silently overwritten by the stage schedule.
    for stage_epochs, lam1, lam2 in zip(list_epochs, list_lambda1, list_lambda2):
        for eph in range(stage_epochs):
            netG_A.train()
            netD_A.train()
            avg_tot_loss = 0
            n_batches = 0
            print(len(train_loader))
            for i, batch in enumerate(train_loader):
                if i > 1000:
                    break
                xA, xB = batch[0].to(device).type(dtype), batch[1].to(device).type(dtype)
                # calc all the required outputs
                for nid, netG in enumerate(list_netG_A):
                    # Only the last generator is optimised, so autograd is
                    # switched off for the upstream ones: their graphs would
                    # never be used by optimizerG and only cost memory/time.
                    with torch.set_grad_enabled(nid == last_nid):
                        if nid == 0:
                            rec_B, rec_alpha_B, rec_beta_B = netG(xA)
                        else:
                            xch = torch.cat([rec_B, rec_alpha_B, rec_beta_B, xA], dim=1)
                            rec_B, rec_alpha_B, rec_beta_B = netG(xch)

                # first gen
                netD_A.eval()
                total_loss = lam1*F.l1_loss(rec_B, xB) + lam2*bayeGen_loss(rec_B, rec_alpha_B, rec_beta_B, xB)
                t0 = netD_A(rec_B)
                # Global average of the patch map -> one score per sample.
                t1 = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                e5 = 0.001*F.mse_loss(t1, torch.ones(t1.size()).to(device).type(dtype))
                total_loss += e5
                optimizerG.zero_grad()
                total_loss.backward()
                optimizerG.step()

                # then discriminator
                netD_A.train()
                t0 = netD_A(xB)
                pred_real_A = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                loss_D_A_real = 1*F.mse_loss(
                    pred_real_A, torch.ones(pred_real_A.size()).to(device).type(dtype)
                )
                # detach(): the discriminator step must not update the generator.
                t0 = netD_A(rec_B.detach())
                pred_fake_A = F.avg_pool2d(t0, t0.size()[2:]).view(t0.size()[0], -1)
                loss_D_A_pred = 1*F.mse_loss(
                    pred_fake_A, torch.zeros(pred_fake_A.size()).to(device).type(dtype)
                )
                loss_D = (loss_D_A_real + loss_D_A_pred)*0.5

                optimizerD.zero_grad()
                loss_D.backward()
                optimizerD.step()

                avg_tot_loss += total_loss.item()
                n_batches += 1

                if i % 500 == 0:
                    print(eph, i)
                    # NOTE(review): this helper is not defined in the part of
                    # the file visible here; it must exist before training.
                    test_uncorr2CT_Cas_UNet3headGAN_n_show(
                        list_netG_A,
                        test_loader,
                        device,
                        dtype,
                        nrow=1,
                        n_show = 1
                    )
            # Average over the batches actually processed: the loop stops
            # after 1001 batches, so len(train_loader) can overcount (and is
            # zero for an empty loader).
            avg_tot_loss /= max(n_batches, 1)
            print(
                'epoch: [{}/{}] | avg_tot_loss: {}'.format(
                    eph, stage_epochs, avg_tot_loss
                )
            )
            # NOTE(review): `eph` restarts at 0 in every stage, so later
            # stages overwrite the checkpoints of earlier ones.
            torch.save(netG_A.state_dict(), ckpt_path+'_eph{}_G_A.pth'.format(eph))
            torch.save(netD_A.state_dict(), ckpt_path+'_eph{}_D_A.pth'.format(eph))
    return list_netG_A, list_netD_A

In [11]:
# Stage 1: train a single 3-head cascade generator against a PatchGAN
# discriminator.
# NOTE(review): `train_loader` and `test_loader` are not created anywhere
# in the visible notebook; they must be defined before this cell is run.
netG_A = CasUNet_3head(1,1)
netD_A = NLayerDiscriminator(1, n_layers=4)
netG_A, netD_A = train_i2i_UNet3headGAN(
    netG_A, netD_A,
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-5,
    ckpt_path='../ckpt/i2i_UNet3headGAN',
)

# Stage 2: train the third block of a generator cascade. The first two
# blocks are restored from checkpoints; the later blocks take 4 input
# channels (mean, alpha, beta and the original input concatenated).
netG_A1 = CasUNet_3head(1,1)
netG_A1.load_state_dict(torch.load('../ckpt/uncorr2CT_UNet3headGAN_v1_eph78_G_A.pth'))
netG_A2 = UNet_3head(4,1)
netG_A2.load_state_dict(torch.load('../ckpt/uncorr2CT_Cas_UNet3headGAN_v1_eph149_G_A.pth'))
netG_A3 = UNet_3head(4,1)

netD_A = NLayerDiscriminator(1, n_layers=4)
# The cascade trainer defined in this notebook is
# `train_i2i_Cas_UNet3headGAN`; the previously called
# `train_uncorr2CT_Cas_UNet3headGAN` does not exist here.
list_netG_A, list_netD_A = train_i2i_Cas_UNet3headGAN(
    [netG_A1, netG_A2, netG_A3], [netD_A],
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-5,
    ckpt_path='../ckpt/uncorr2CT_Cas_UNet3headGAN_v1_block3',
)

NameError: ignored