In [15]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as scio
import argparse
from prettytable import PrettyTable
from copy import deepcopy

from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from ssim_loss import SSIM
from Unet_block import *
from TIE_model import *
from split_dataset import *
from initialization import *
from torchvision.models import DenseNet
from torchvision.models.densenet import _Transition, _load_state_dict
from collections import OrderedDict

os.environ['CUDA_VISIBLE_DEVICES'] = "0" #select num.0 GPU

## Build Dataset

In [2]:
def NORMalize(image, MIN_B, MAX_B):
    image = (image - MIN_B) / (MAX_B - MIN_B)
   #image[image>1] = 1.
   #image[image<0] = 0.
    return image

def DENORMalize(image, MIN_B, MAX_B):
    image = image * (MAX_B - MIN_B) + MIN_B
    return image

    
def build_phasedata(dataset, norm_range = (0.0, 1.0)):
    
    nt_red = 6
    data = scio.loadmat('')['im']
    nx, ny, nt = data.shape
    s1 = nt//nt_red
    train_s = s1*4
    small_sample = s1*4+2
    reshape_order = 'C'
    
    data_path_train_in = '' #file direction
    data_path_train_out = ''
    data_path_intensity = ''
    
    mat_im = scio.loadmat(data_path_train_in + '\\im')['im']
    mat_gt = scio.loadmat(data_path_train_out + '\\label')['gt']
    mat_id = scio.loadmat(data_path_intensity + '\\in')['in']
    mat_i0 = scio.loadmat(data_path_intensity + '\\i0')['i0']

    mat_im = np.reshape(np.transpose(mat_im,(2,0,1)), (nt,1,nx,ny), order=reshape_order)
    mat_gt = np.reshape(np.transpose(mat_gt,(2,0,1)), (nt,1,nx,ny), order=reshape_order)
    mat_id = np.reshape(np.transpose(mat_id,(2,0,1)), (nt,1,nx,ny), order=reshape_order)
    mat_i0 = np.reshape(np.transpose(mat_i0,(2,0,1)), (nt,1,nx,ny), order=reshape_order)
    
    if dataset == 'train':
        
        im = mat_im[:small_sample,...]
        gt = mat_gt[:small_sample,...]
        I_d = mat_id[:small_sample,...]
        I_0= mat_i0[:small_sample,...]
        dataset = (im, gt, I_d, I_0)
            
    if dataset == 'valid':
        
        im = mat_im[small_sample:small_sample+s1,...]
        gt = mat_gt[small_sample:small_sample+s1,...]
        I_d = mat_id[small_sample:small_sample+s1,...]
        I_0 = mat_i0[small_sample:small_sample+s1,...]
        dataset = (im, gt, I_d, I_0)
        
    if dataset == 'test':
        
        im = mat_im[s1*5+2:,...]
        gt = mat_gt[s1*5+2:,...]
        I_d = mat_id[s1*5+2:,...]
        I_0 = mat_i0[s1*5+2:,...]
        dataset = (im, gt, I_d, I_0)
            
    return dataset


class train_phase_data_loader(Dataset):
    def __init__(self, dataset, crop_size=None, crop_n=None):
        self.dataset = dataset
        self.crop_size = crop_size
        self.crop_n = crop_n

    def __getitem__(self):
        (im, gt, I_d, I_0) = self.dataset
        input_img = im
        target_img = gt
        inten_id = I_d
        inten_i0 = I_0

        if self.crop_n:
            assert input_img.shape == target_img.shape
            crop_input = []
            crop_target = []
            n_c, h, w = input_img.shape
            new_h, new_w = self.crop_size, self.crop_size
            for _ in range(self.crop_n):
                top = np.random.randint(0, h - new_h)
                left = np.random.randint(0, w - new_w)
                input_img_ = input_img[top:top + new_h, left:left + new_w]
                target_img_ = target_img[top:top + new_h, left:left + new_w]
                crop_input.append(input_img_)
                crop_target.append(target_img_)
            crop_input = np.array(crop_input)
            crop_target = np.array(crop_target)

            sample = (crop_input, crop_target)
            return sample
        else:
            sample = (input_img, target_img, inten_id, inten_i0)
            return sample
        

class LambdaLR_():  #learning_rate schedulers
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)

def prep_input(im, mode):
    """Undersample the batch, then reformat them into what the network accepts.
    Parameters: patch the data
    """
    #input variable
    (phase_var, target_var, id_var, i0_var) = im
    if mode == 'train':
        phase_var = torch.tensor(phase_var, requires_grad=True).to(device)
    else:
        phase_var = torch.tensor(phase_var).to(device)
    target_var = torch.tensor(target_var).to(device)
    id_var = torch.tensor(id_var).to(device)
    i0_var = torch.tensor(i0_var).to(device)

    if torch.cuda.is_available():
        phase_var = phase_var.type(torch.cuda.FloatTensor)
        target_var = target_var.type(torch.cuda.FloatTensor)
        id_var = id_var.type(torch.cuda.FloatTensor)
        i0_var = i0_var.type(torch.cuda.FloatTensor)
    
    data = DeepChainMap({'input': phase_var}, {'targets': target_var}, {'fres_input': id_var}, {'inverse_input': i0_var})

    return data
    

def calc_gradeint_penalty(discriminator, real_data, fake_data):
    #alpha = torch.rand(real_data.size()[0], 1)
    #alpha = alpha.expand(real_data.size())
    alpha = torch.Tensor(np.random.random((real_data.size(0),1,1,1)))
    alpha = alpha.cuda() if torch.cuda.is_available() else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    if torch.cuda.is_available():
        interpolates = interpolates.cuda()
    interpolates = torch.autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, grad_outputs=torch.ones(disc_interpolates.size()).cuda() if torch.cuda.is_available() else torch.ones(disc_interpolates.size()), create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

class no_op(object):
    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass

def maybe_to_torch(d):
    if isinstance(d, list):
        d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
    elif not isinstance(d, torch.Tensor):
        d = torch.from_numpy(d).float()
    return d


def to_cuda(data, non_blocking=True, gpu_id=0):
    if isinstance(data, list):
        data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
    else:
        data = data.cuda(gpu_id, non_blocking=True)
    return data

softmax_helper = lambda x: F.softmax(x, 1)

def count_parameters(model):
    # table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

## Dense Unet

In [3]:
class _DenseUNetEncoder(DenseNet):
    def __init__(self, skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample):
        super(_DenseUNetEncoder, self).__init__(growth_rate, block_config, num_init_features, bn_size, drop_rate)
        
        self.skip_connections = skip_connections

        # remove last norm, classifier
        features = OrderedDict(list(self.features.named_children())[:-1])
        delattr(self, 'classifier')
        if not downsample:
            features['conv0'].stride = 1
            del features['pool0']
        self.features = nn.Sequential(features)
        
        for module in self.features.modules():
            if isinstance(module, nn.AvgPool2d):
                module.register_forward_hook(lambda _, input, output : self.skip_connections.append(input[0]))

    def forward(self, x):
        return self.features(x)
        
class _DenseUNetDecoder(DenseNet):
    def __init__(self, skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, upsample):
        super(_DenseUNetDecoder, self).__init__(growth_rate, block_config, num_init_features, bn_size, drop_rate)
        
        self.skip_connections = skip_connections
        self.upsample = upsample
        
        # remove conv0, norm0, relu0, pool0, last denseblock, last norm, classifier
        features = list(self.features.named_children())[4:-2]
        delattr(self, 'classifier')

        num_features = num_init_features
        num_features_list = []
        for i, num_layers in enumerate(block_config):
            num_input_features = num_features + num_layers * growth_rate
            num_output_features = num_features // 2
            num_features_list.append((num_input_features, num_output_features))
            num_features = num_input_features // 2
        
        for i in range(len(features)):
            name, module = features[i]
            if isinstance(module, _Transition):
                num_input_features, num_output_features = num_features_list.pop(1)
                features[i] = (name, _TransitionUp(num_input_features, num_output_features, skip_connections))

        features.reverse()
        
        self.features = nn.Sequential(OrderedDict(features))
        
        num_input_features, _ = num_features_list.pop(0)
        
        if upsample:
            self.features.add_module('upsample0', nn.Upsample(scale_factor=4, mode='bilinear'))
        self.features.add_module('norm0', nn.BatchNorm2d(num_input_features))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('conv0', nn.Conv2d(num_input_features, num_init_features, kernel_size=1, stride=1, bias=False))
        self.features.add_module('norm1', nn.BatchNorm2d(num_init_features))

    def forward(self, x):
        return self.features(x)
          
        
class _Concatenate(nn.Module):
    def __init__(self, skip_connections):
        super(_Concatenate, self).__init__()
        self.skip_connections = skip_connections
        
    def forward(self, x):
        return torch.cat([x, self.skip_connections.pop()], 1)

          
class _TransitionUp(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, skip_connections):
        super(_TransitionUp, self).__init__()
        
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, num_output_features * 2,
                                              kernel_size=1, stride=1, bias=False))
        
        self.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
        self.add_module('cat', _Concatenate(skip_connections))
        self.add_module('norm2', nn.BatchNorm2d(num_output_features * 4))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(num_output_features * 4, num_output_features,
                                          kernel_size=1, stride=1, bias=False))

class DenseUNet(nn.Module):
    def __init__(self, n_classes, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, downsample=False, pretrained_encoder_uri=None, progress=None):
        super(DenseUNet, self).__init__()
        self.skip_connections = []
        self.encoder = _DenseUNetEncoder(self.skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample)
        self.decoder = _DenseUNetDecoder(self.skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample)
        self.classifier = nn.Conv2d(num_init_features, n_classes, kernel_size=1, stride=1, bias=True)
        
        self.encoder._load_state_dict = self.encoder.load_state_dict
        self.encoder.load_state_dict = lambda state_dict : self.encoder._load_state_dict(state_dict, strict=False)
        if pretrained_encoder_uri:
            _load_state_dict(self.encoder, str(pretrained_encoder_uri), progress)
        self.encoder.load_state_dict = lambda state_dict : self.encoder._load_state_dict(state_dict, strict=True)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        y = self.classifier(x)
        return y


## PINN model

In [4]:
class Unet(nn.Module):
    def __init__(self, n_channels, n_classes, nf, ks, drop_rate=0.5):
        super(Unet, self).__init__()
        self.inc = inconv(n_channels, nf, ks)
        self.down1 = down(nf, nf*2, ks)
        self.down2 = down(nf*2, nf*4, ks)
        self.down3 = down(nf*4, nf*8, ks)
        self.down4 = down(nf*8, nf*16, ks)
        self.up1 = up(nf*24, nf*8, ks)
        self.up2 = up(nf*12, nf*4, ks)
        self.up3 = up(nf*6, nf*2, ks, drop_rate)
        self.up4 = up(nf*3, nf, ks, drop_rate)
        self.outc = outconv(nf, n_classes)
        

    def forward(self, im):
        raw = im['input']

        x1 = self.inc(raw)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.outc(x)

        x = raw - x

        return x

class Seven_Unetplusplus(nn.Module):
    def __init__(self, n_channels, n_classes, nf, ks, drop_rate=0.5):
        super(Seven_Unetplusplus, self).__init__()
        self.inc = inconv(n_channels, nf, ks)
        self.down1 = down(nf, nf*2, ks)
        self.down2 = down(nf*2, nf*4, ks)
        self.down3 = down(nf*4, nf*8, ks)
        self.up1 = up(nf*12, nf*4, ks)
        self.up2 = up(nf*6, nf*2, ks, drop_rate)
        self.up3 = up(nf*3, nf, ks, drop_rate)
        
        self.dense1 = Dense_Block(nf,3)
        self.dense2 = Dense_Block(nf,2)

        self.outc = outconv(nf, n_classes)
        

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x21,x22 = self.dense2(x2,x3)
        x = self.up2(x, x22)
        x13 = self.dense1(x1,x2,x21)
        x = self.up3(x, x13)

        x = self.outc(x)

        return x


class PINN(nn.Module):
    def __init__(self, n_channels, n_classes, nf, ks,trade_off, dpix, z, Hsize, _lambda, k, method, nc=3, nd=5, **kwargs):
        super(PINN, self).__init__()
        self.nc = nc
        self.nd = nd
        self.method = method
        self.trade_off = trade_off
        conv_dim = 2
        dilation = 1
        n_ch = n_channels
        n_out = n_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('Creating D{}C{}'.format(nd, nc))
        conv_blocks = []
        dcs = []
        gamma = []
        
        pad_conv = 1
        if dilation > 1:
        # in = floor(in + 2*pad - dilation * (ks-1) - 1)/stride + 1)
        # pad = dilation
            pad_dilconv = dilation
        else:
            pad_dilconv = pad_conv
        
        for i in range(nc):
            conv_blocks.append(Seven_Unetplusplus(n_ch, n_out, nf, ks, **kwargs))
            gamma.append(torch.tensor(self.trade_off,device = self.device,requires_grad=True))
            with torch.no_grad():
                dcs.append(Data_consistency(dpix, z, Hsize, _lambda, k, **kwargs))

        self.conv_blocks = nn.ModuleList(conv_blocks)
        self.dcs = dcs
        self.gamma = gamma

    def forward(self, im):
        
        raw = im['input']
        x = im['input']
        x_i0 = im['inverse_input']
        x_id = im['fres_input']

        for i in range(self.nc):
            x_phase = self.dcs[0].perform(x,x_i0,x_id)

            if i == 0:
                x = raw - self.conv_blocks[i](raw)
            else:
                x = x - self.conv_blocks[i](x)

            if self.method == 'constant':
                x = x - x_phase*self.trade_off
            elif self.method == 'learn':
                x = x - x_phase*self.gamma[i]
                
        return x 

    
class PINNShared(nn.Module):
    def __init__(self, n_channels, n_classes, nf, ks,trade_off, dpix, z, Hsize, _lambda, k, nc=3, nd=5, **kwargs):
        super(PINNShared, self).__init__()
        self.nc = nc
        self.nd = nd
        self.trade_off = trade_off
        print('Creating D{}C{}-S (2D)'.format(nd, nc))
        conv_blocks = []
        dc = []
        gamma = []

        conv_blocks.append(Seven_Unetplusplus(n_channels, n_classes, nf, ks, **kwargs))
        gamma.append(torch.tensor(self.trade_off,device = self.device,requires_grad=True))
        self.conv_blocks = nn.ModuleList(conv_blocks)
        with torch.no_grad():
            self.dcs = Data_consistency(dpix, z, Hsize, _lambda, k, **kwargs)
        self.gamma = gamma

    def forward(self, im):

        raw = im['input']
        x = im['input']
        x_i0 = im['inverse_input']
        x_id = im['fres_input']

        for i in range(self.nc):
            x_phase = self.dcs[0].perform(x,x_i0,x_id)
            x = x - self.conv_blocks[0](x)
            if self.method == 'constant':
                x = x - x_phase*self.trade_off
            elif self.method == 'learn':
                x = x - x_phase*self.gamma[0]

        return x


## Main function

In [5]:

if __name__=='__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser(description='parameters of PINN')
    parser.add_argument('--NUM_EPOCH', default = 200, type=int,nargs='*', help='number of epochs')
    parser.add_argument('--START_EPOCH', default = 0, type=int, nargs='*', help='start epoch')
    parser.add_argument('--DECAY_EPOCH', default = 5, type=int, nargs='*', help='decay epoch')
    parser.add_argument('--CRITIC_ITER', default = 1, type=int, nargs='*', help='critic iteration')
    parser.add_argument('--SAVE_EPOCH', default = 10, type=int, nargs='*', help='save iteration')
    parser.add_argument('--BATCH_SIZE', default = 1, type=int, nargs=1, help='batch size')
    parser.add_argument('--lr', default = 5e-4, type=float, nargs=1, help='initial learning rate')
    parser.add_argument('--beta1', default = 0.9, type=float, nargs=1, help='first-moment exponential decay rate')
    parser.add_argument('--beta2', default = 0.999, type=float, nargs=1, help='second-moment exponential decay rate')
    parser.add_argument('--input_nc', default = 1, type=int, nargs=1, help='input channel')
    parser.add_argument('--output_nc', default = 1, type=int, nargs=1, help='output channel')
    parser.add_argument('--nf', default = 32, type=int, nargs=1, help='number of filters')
    parser.add_argument('--ks', default = 3, type=int, nargs=1, help='kernel size')
    parser.add_argument('--tradeoff', default = 1e-2, nargs=1, type=float, help='trade-off value')
    parser.add_argument('--itr_method', default = 'learn', nargs=1, type=str, help='constant or learning trade-off value')
    parser.add_argument('--dpix', default = 0.22e-6, type=float, nargs=1, help='pixel size')
    parser.add_argument('--d', default = 2e-6, type=float, nargs=1, help='propagation distance')
    parser.add_argument('--Hsize', default = 512, type=int, nargs=1, help='image size')
    parser.add_argument('--wl', default = 660e-9, type=float, nargs=1, help='wave length')
    parser.add_argument('--k', default = 9.51e6, nargs=1, type=int, help='k value')
    parser.add_argument('--loss_function', default = 'mix', type=str, help='training loss function')
    parser.add_argument('--alpha', default = 5, type=float, nargs=1, help='trade-off value of mix loss function')
    parser.add_argument('--model_name', default = 'pinn', type=str, help='project model')
    parser.add_argument('--pretrain', default = False, type=bool, help='model pretrain')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--savefig', action='store_true',
                        help='Save output images and masks')

    args = parser.parse_args([])
    cuda = True if torch.cuda.is_available() else False

    save_frq = 10
    
    #specify work
    if args.model_name == 'pinn':
        pinn = PINN(n_channels = args.input_nc, n_classes = args.output_nc, nf = args.nf, ks = args.ks, trade_off = args.tradeoff,
                        dpix = args.dpix, z = args.d, Hsize = args.Hsize, _lambda = args.wl, k = args.k, method = args.itr_method) 
    elif args.model_name == 'unet':
        pinn = Unet(n_channels = args.input_nc, n_classes = args.output_nc, nf = args.nf, ks = args.ks)
    else:
        pinn= DenseUNet(n_classes = args.output_nc)
   
    if torch.cuda.device_count() > 1:
        print("Use {} GPUs".format(torch.cuda.device_count()), "=" * 9)
        pinn = nn.DataParallel(pinn)
        if args.pretrain is True:
            init_weights(pinn, 'normal', init_gain=0.02)
    
    pinn.to(device)
    # loss
    criterion_GAN = nn.MSELoss() #L2 loss
    criterion_Sensitive = nn.L1Loss() #L1 loss
    criterion_ssim = SSIM(window_size = 11) #SSIM loss
    
    # mix loss function recommended
    if cuda:
        pinn = pinn.cuda()
        if args.loss_function == 'l2':
            criterion_loss = criterion_GAN.cuda()
        elif args.loss_function == 'l1':
            criterion_loss = criterion_Sensitive.cuda()
        elif args.loss_function == 'ssim':
            criterion_loss = criterion_ssim.cuda()
        elif args.loss_function == 'mix':
            criterion_ssim = criterion_ssim.cuda() 
            criterion_mse = criterion_GAN.cuda()   


    # optimizer
    optimizer = torch.optim.Adam(pinn.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    # learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR_(args.NUM_EPOCH, args.START_EPOCH, args.DECAY_EPOCH).step)

   # input & target data & intensity images
    train_dataset = build_phasedata(dataset = 'train', norm_range=(0, 1))
    train_loader = train_phase_data_loader(dataset = train_dataset, crop_size=None, crop_n=None)
    train_data = train_loader.__getitem__()
    (train_im, train_gt, train_id, train_i0) = train_data
    (m, m_chan, m_r, m_c) = train_im.shape

    valid_dataset = build_phasedata(dataset = 'valid', norm_range=(0, 1))
    valid_loader = train_phase_data_loader(dataset = valid_dataset, crop_size=None, crop_n=None)
    valid_data = valid_loader.__getitem__()
    (valid_im, valid_gt, valid_id, valid_i0) = valid_data
    (m_valid, n_chan, n_r, n_c) = valid_im.shape

    test_dataset = build_phasedata(dataset = 'test', norm_range=(0, 1))
    test_loader = train_phase_data_loader(dataset = test_dataset, crop_size=None, crop_n=None)                                    
    test_data = test_loader.__getitem__()
    
#     Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    Tensor = torch.cuda.FloatTensor #GPU only
    
    Loss_list = [] #visualize result
    valid_list = []
    #Store models and the results in test-set
    project_root = '' 
    save_dir = os.path.join(project_root, '\\%s' % args.model_name)
    os.chdir(save_dir)
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    for epoch in range(args.START_EPOCH, args.NUM_EPOCH):
        start_time = time.time()
        count = 0
        total_loss = 0.
        valid_loss = 0.
        torch.cuda.empty_cache() #crear the GPU storage
        train_mini_batches, valid_mini_batches, test_mini_batches = random_mini_batches(train_data, valid_data, test_data, mini_batch_size=args.BATCH_SIZE, shuffle = True)

        #Training process
        for minibatch in train_mini_batches:
            count += 1
        
            im = prep_input(minibatch, mode='train')
            target_img = im['targets']

            ######## Train PINN ########
            optimizer.zero_grad()
        
            if count % args.CRITIC_ITER == 0:
                
                fake_img = pinn(im)
                if args.loss_function == 'ssim':
                    loss_G = -criterion_loss(fake_img,target_img)
                elif args.loss_function == 'mix':
                    loss_G = (1-criterion_ssim(fake_img,target_img)) * args.alpha + criterion_mse(fake_img,target_img)
                else:
                    loss_G = criterion_loss(fake_img,target_img)
                total_loss += float(loss_G)
                loss_G.backward()
                optimizer.step()

            if count % 10 == 0:
                print("EPOCH [{}/{}], STEP [{}/{}]".format(epoch+1, args.NUM_EPOCH, count+1, m))
                print("Total Loss G: {}".format(loss_G))
        end_time = time.time()
        print("Training time is ",(end_time-start_time)/60,'mins')
        Loss_list.append(total_loss / m)
        
    #Validation process
        count = 0
        for minibatch in valid_mini_batches:
            count += 1
            im = prep_input(minibatch, mode='valid')
            target_img = im['targets']
        
            if count % args.CRITIC_ITER == 0:
                
                fake_img = pinn(im)
                if args.loss_function == 'mix':
                    loss_V = (1-criterion_ssim(fake_img,target_img)) * args.alpha + criterion_mse(fake_img,target_img)
                else:
                    loss_V = criterion_loss(fake_img,target_img)
                valid_loss += float(loss_V)
                
            if count % 10 == 0:
                print("Valid: 2-shot EPOCH [{}/{}], STEP [{}/{}]".format(epoch+1, args.NUM_EPOCH, count, m_valid))
                print("Total Loss V: {}".format(loss_V))
        valid_list.append(valid_loss / m_valid)
    

        #Test process save test figure in every 10 epochs
        total_loss_t = 0.
        test_count = 0
        if (epoch+1) > (args.NUM_EPOCH/2) and (epoch+1) % args.SAVE_EPOCH == 0:
            for mini_batch in test_mini_batches:
                test_count += 1
                im = prep_input(mini_batch, mode='test')
                input_img = im['input']
                target_img = im['targets']

                fake_img = pinn(im)

                if test_count % 1 == 0:
                    
                    with torch.no_grad():

                        #pytorch tensor
                        input_pic = input_img.cpu().numpy()
                        fake_pic = fake_img.cpu().numpy()
                        target_pic = target_img.cpu().numpy()

                        #numpy array
                        scio.savemat("InputIn"+str(test_count)+".mat", {'array':input_pic})
                        scio.savemat(str(epoch+1)+"_"+str(test_count)+".mat", {'array':fake_pic})
                        scio.savemat("GT"+str(test_count)+".mat", {'array':target_pic})
    
    #save models
    torch.save(pinn.state_dict(),os.path.join(save_dir,'net.pth'))
            


Creating D5C3


  pi,lamda,z,dpix,Hsize = [torch.tensor(i,device=self.device) for i in [np.pi,self._lambda,z,self.dpix,self.Hsize]]
  self.z = torch.tensor(self.z,device = self.device)


EPOCH [1/200], STEP [11/134]
Total Loss G: 1.791059136390686
EPOCH [1/200], STEP [21/134]
Total Loss G: 3.2435414791107178
EPOCH [1/200], STEP [31/134]
Total Loss G: 0.6815721988677979
EPOCH [1/200], STEP [41/134]
Total Loss G: 2.3443777561187744
EPOCH [1/200], STEP [51/134]
Total Loss G: 1.8194869756698608
EPOCH [1/200], STEP [61/134]
Total Loss G: 1.6262032985687256
EPOCH [1/200], STEP [71/134]
Total Loss G: 1.1527962684631348
EPOCH [1/200], STEP [81/134]
Total Loss G: 2.3574440479278564
EPOCH [1/200], STEP [91/134]
Total Loss G: 1.4220061302185059
EPOCH [1/200], STEP [101/134]
Total Loss G: 1.6221528053283691
EPOCH [1/200], STEP [111/134]
Total Loss G: 1.4837125539779663
EPOCH [1/200], STEP [121/134]
Total Loss G: 1.5281918048858643
EPOCH [1/200], STEP [131/134]
Total Loss G: 2.2293293476104736
EPOCH [1/200], STEP [141/134]
Total Loss G: 2.4842910766601562
Training time is  1.1524493098258972 mins
Valid: 2-shot EPOCH [1/200], STEP [10/33]
Total Loss V: 1.6021842956542969
Valid: 2-sh