In [1]:
import os
import glob

# h5py can read hdf5 dataset
import h5py

# delete bad data files
from send2trash import send2trash

# fastmri has some k-space undersampling functions we can use
# git clone https://github.com/facebookresearch/fastMRI.git
# go to the fastmri directory
# pip install -e.
import fastmri

# We will use this functions to generate masks
from fastmri.data.subsample import RandomMaskFunc, EquispacedMaskFunc

# sigpy is apparently a good MRI viewing tool
# pip install sigpy
import sigpy as sp
import sigpy.plot as pl

import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib

%matplotlib notebook

In [2]:
# define constants
DATASET = 'singlecoil_train'
AXES = {
        'singlecoil_train' : (1, 2),
        'multicoil_train' : (2, 3),
       }

log = '1'
MODEL_NAME = 'model_1'

data_path = os.path.join(os.getcwd(), DATASET)
mri_paths = glob.glob(os.path.join(data_path, '*1.h5'))
log_paths = os.path.join(os.getcwd(), f'logs/{log}.txt')

In [3]:
# this block gets Dataset object with imaginary and real separated
def _get_kspace_and_reconstruction_rss(filename):
    """
    @params filename: full path to .h5 mri file
    @return kspace data of that particular file
    """
    try:
        with h5py.File(filename, 'r') as hr:
            return hr['kspace'][:], hr['reconstruction_rss'][:]
    except:
        print(f'Error could not open {filename}')

def _get_kspace_undersampled(kspace, center_fractions = [0.04], accelerations = [4]):
    """
    @params kspace: from _get_kspace_and_reconstruction_rss(filename)
    @params center_fractions: for undersampling, 
        N*center_fraction columns in center corresponding to low-frequencies
    @params accelerations: how much mri acquisition is sped up
    @return undersampled k-space
    """
    mask_func = RandomMaskFunc(
        center_fractions = center_fractions, 
        accelerations = accelerations
    )
    mask = np.array(mask_func(kspace.shape))
    return kspace * mask



def _get_mri_im_separated(
#     kspace, 
    reconstruction_rss,
    kspace_undersampled, 
    DATASET
):
    """
    separates imaginary from real values
    # @params kspace: from _get_kspace_and_reconstruction_rss(filename)
    @params reconstruction_rss: reconstructed MR image of fully sampled kspace, provided
    @params kspace_undersampled: mask-undersampled k-space from _get_kspace_undersampled
    @params DATASET: i.e. 'singlecoil_challenge' or 'multicoil_challenge'
    @return (undersampled mri image, fully sampled mri image (i.e. label for GAN))
    """
    undersampled_im = sp.ifft(kspace_undersampled, axes=AXES[DATASET])
#     fullysampled_im = sp.ifft(kspace, axes=AXES[DATASET])
    
    #crop to make sure images are all the size
    undersampled_crop = sp.resize(
        undersampled_im,
        [1, 32, 256, 256]
#         [1, 30, 320, 320] # [batch size, height, length, width]
    )
    
    undersampled_crop_real = tf.math.real(undersampled_crop)
    undersampled_crop_imag = tf.math.imag(undersampled_crop)
    
    undersampled_crop = np.stack(
        (undersampled_crop_real, undersampled_crop_imag),
        axis = 4,
    )
    
    
    fullysampled_crop = sp.resize(
        reconstruction_rss,
        [1, 32, 256, 256]
#         [1, 30, 320, 320]
    )
    
    
    return (
        undersampled_crop,
        fullysampled_crop,
    )
    



def get_datum_from_single_file_separated(filename, DATASET):
    """
    user-facing function for tf Dataset object
    @params filename: full path to .h5 mri file
    @params DATASET: i.e. 'singlecoil_challenge' or 'multicoil_challenge'
    @return (undersampled mri image, fully sampled mri image (i.e. label for GAN))
    """
    kspace, reconstruction_rss = _get_kspace_and_reconstruction_rss(filename)
    kspace_undersampled = _get_kspace_undersampled(kspace)
    return _get_mri_im_separated(
        reconstruction_rss,
        kspace_undersampled,
        DATASET,
    )




def get_data_from_files_separated(filenames, DATASET):  
    """
    user-facing function for tf Dataset object
    @params filenames: list of full paths to .h5 mri files
    @params DATASET: i.e. 'singlecoil_train' or 'multicoil_train'
    @return ndarray of 
        (undersampled mri image, fully sampled mri image (i.e. label for GAN))
    """
    undersampled_images = np.ones((1, 32, 256, 256, 2)) #[bn, h, l, w, c]
    fullysampled_images = np.ones((1, 32, 256, 256))
#     undersampled_images = np.ones((1, 30, 320, 320, 2)) #[bn, h, l, w, c]
#     fullysampled_images = np.ones((1, 30, 320, 320))
    for mri_path in filenames:
        try:
            # undersampled_crop has real and imag components
            undersampled_crop, fullysampled_crop = get_datum_from_single_file_separated(
                mri_path, DATASET
            )
               

            undersampled_images = np.vstack(
                (undersampled_images, undersampled_crop)
            )
            
            
            fullysampled_images = np.vstack(
                (fullysampled_images, fullysampled_crop)
            )
            
#             print (f'undersampled image shape: {undersampled_crop.shape}')           
#             print (f'undersampled images running total shape: {undersampled_images.shape}')
#             print (f'fully sampled images running total shape {fullysampled_images.shape} \n\n')

            
            
        except:
            print(f'could not open file {mri_path}')
#             send2trash(mri_path)
            print(f'sent file {mri_path} to trash')
    
    # reshape with extra one at the end for channel
    fullysampled_images = fullysampled_images.reshape(
        (-1, 32, 256, 256, 1)
    )


    return undersampled_images[1:], fullysampled_images[1:]

#

In [4]:
under_sampled_separated, fully_sampled_separated = get_data_from_files_separated(mri_paths, DATASET)
ds_separated = tf.data.Dataset.from_tensor_slices((under_sampled_separated, fully_sampled_separated))
ds_separated = ds_separated.shuffle(150, seed = 123, reshuffle_each_iteration = True)

In [5]:
for undersampled_im, fullysampled_im in ds_separated.take(20):
    undersampled_im = tf.reshape(undersampled_im, (-1, 32, 256, 256, 2))
    fullysampled_im = tf.reshape(fullysampled_im, (-1, 32, 256, 256, 1))
    print(f'undersampled size {undersampled_im.shape} fullysampled size {fullysampled_im.shape}')

undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)
undersampled size (1, 32, 256, 256, 2) fullysampled size (1, 32, 256, 256, 1)


In [6]:
import os

import numpy as np

import tensorflow as tf

from keras.models import Model, Input
from keras.optimizers import Adam

from keras.layers import Dense
from keras.layers import Conv3D, Conv3DTranspose
from keras.layers import Add, Concatenate
from keras.layers import Activation, LeakyReLU
from keras.layers import BatchNormalization, Lambda

from keras.utils import multi_gpu_model 
from keras import backend as K

from tensorflow.python.client import device_lib
devices = device_lib.list_local_devices()
gpus = [d for d in devices if d.name.lower().startswith('/device:gpu')]
print (f'using {len(gpus)} GPUs')

using 1 GPUs


In [7]:
def accm(y_true, y_pred):
    '''
    accuracy metric
    '''
    y_pred = K.clip(y_pred, -1, 1)
    return K.mean(K.equal(y_true, K.round(y_pred)))

def mssim(y_true, y_pred):
    '''
    mean structural similarity index
    '''
    costs = 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 2.0))
    return costs

def wloss(y_true, y_predict):
    '''
    Wasserstein loss
    '''
    return -K.mean(y_true * y_predict)

In [8]:
def discriminator(
    kernel_initz,
    inp_shape, # 3d, with channels last
    trainable = True,
):
    
    gamma_init = tf.random_normal_initializer(1., 0.02)
    
    inp = Input(shape = inp_shape) # 3d
    
    l0 = Conv3D(
        filters = 16, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(inp)
    l0 = LeakyReLU(alpha = 0.2)(l0)
    
    
    l1 = Conv3D(
        filters = 16 * 2, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l0)
    l1 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l1)
    l1 = LeakyReLU(alpha = 0.2)(l1)
    
    
    l2 = Conv3D(
        filters = 16 * 4, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l1)
    l2 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l2)
    l2 = LeakyReLU(alpha = 0.2)(l2)
    
    
    l3 = Conv3D(
        filters = 16 * 8, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz, 
        )(l2)
    l3 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l3)
    l3 = LeakyReLU(alpha = 0.2)(l3)
    
    
    l4 = Conv3D(
        filters = 16 * 16, kernel_size = 4, strides = (2, 2, 2),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l3)
    l4 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l4)
    l4 = LeakyReLU(alpha = 0.2)(l4)
    
    
    l7 = Conv3D(
        filters = 16 * 8, kernel_size = 1, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l4)
    l7 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l7)
    l7 = LeakyReLU(alpha = 0.2)(l7)
    
    
    l8 = Conv3D(
        filters = 16 * 4, kernel_size = 1, strides = (1, 1, 1), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l7)
    l8 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l8)
    l8 = LeakyReLU(alpha = 0.2)(l8)
    
    
    l9 = Conv3D(
        filters = 16 * 2, kernel_size = 3, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l8)
    l9 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l9)
    l9 = LeakyReLU(alpha = 0.2)(l9)
    
    
    l10 = Conv3D(
        filters = 16 * 8, kernel_size = 3, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l9)
    l10 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l10)
    l10 = LeakyReLU(alpha = 0.2)(l10)
    
    
    l11 = Add()([l7,l10])
    l11 = LeakyReLU(alpha = 0.2)(l11)
    
    
    out = Conv3D(
        filters = 1, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = 'he_normal',
        )(l11)
    
    
    model = Model(inputs = inp, outputs = out)
    return model

In [9]:
def resden(
    x, fil_lay, fil_end, beta, 
    gamma_init, kernel_initz, trainable,
):   
    
    x1 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x)
    x1 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x1)
    x1 = LeakyReLU(alpha = 0.2)(x1)
    x1=Concatenate(axis=-1)([x, x1])
    
    
    x2 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x1)
    x2 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x2)
    x2 = LeakyReLU(alpha = 0.2)(x2)
    x2 = Concatenate(axis = -1)([x1, x2])
     
        
    x3 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(x2)
    x3 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x3)
    x3 = LeakyReLU(alpha = 0.2)(x3)
    x3 = Concatenate(axis = -1)([x2, x3])
    
    
#     x4 = Conv3D(
#         filters = fil_lay, kernel_size = 3, strides = 1,
#         padding = 'same', kernel_initializer = kernel_initz, 
#         )(x3)
#     x4 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x4)
#     x4 = LeakyReLU(alpha = 0.2)(x4)
#     x4 = Concatenate(axis = -1)([x3, x4])
    
    
    x5 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x3)
    
    x5 = Lambda(lambda x: x * beta)(x5)
    
    xout = Add()([x5,x])
    return xout

def resresden(x, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable):
    
    x1 = resden(x,  fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
#     x2 = resden(x1, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
    x3 = resden(x, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
    
    x3 = Lambda(lambda x : x * beta)(x3)
    
    xout = Add()([x3,x])
    return xout

In [10]:
def generator(inp_shape, kernel_initz, trainable = True,):
    gamma_init = tf.random_normal_initializer(1., 0.02)
    
    fil_lay = 32
    fil_end = 512
    rrd_count = 12
    beta = 0.2

    
    inp_usamp_imag = Input(inp_shape) # (-1, 32, 256, 256, 2)
    
    
    lay_1dn = Conv3D(
        filters = 32, kernel_size = 4, strides = (1, 2, 2), # 32, 128, 128
        padding = 'same', kernel_initializer = kernel_initz,
        )(inp_usamp_imag)
    lay_1dn = LeakyReLU(alpha = 0.2)(lay_1dn)

    
    lay_2dn = Conv3D(
        filters = 64, kernel_size = 4, strides = (1, 2, 2), # 32, 64, 64
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_1dn)
    lay_2dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2dn)
    lay_2dn = LeakyReLU(alpha = 0.2)(lay_2dn)

    
    lay_3dn = Conv3D(
        filters = 128, kernel_size = 4, strides = (1, 2, 2), # 32, 32, 32
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_2dn)
    lay_3dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3dn)
    lay_3dn = LeakyReLU(alpha = 0.2)(lay_3dn)

    
    lay_4dn = Conv3D(
        filters = 256, kernel_size = 4, strides = (2, 2, 2), # 16, 16, 16
        padding = 'same', kernel_initializer = kernel_initz, 
        )(lay_3dn)
    lay_4dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_4dn)
    lay_4dn = LeakyReLU(alpha = 0.2)(lay_4dn)  

    
    lay_5dn = Conv3D(
        filters = 256, kernel_size = 4, strides = (2, 2, 2), # 8, 8, 8
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_4dn)
    lay_5dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_5dn)
    lay_5dn = LeakyReLU(alpha = 0.2)(lay_5dn)


    c1 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(lay_5dn)
    
    xrrd = c1
    for _ in range(rrd_count):
        xrrd = resresden(xrrd, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable)

    c2 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(xrrd)
    
    
    lay_5upc = Add()([c1, c2]) # 8, 8, 8

    
    lay_4up = Conv3DTranspose(
        filters = 256, kernel_size = 4, strides = (2, 2, 2), # 16, 16, 16
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_5upc)
    lay_4up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_4up)
    lay_4up = Activation('relu')(lay_4up) 

    lay_4upc = Concatenate(axis = -1)([lay_4up,lay_4dn]) 

    
    
    lay_3up = Conv3DTranspose(
        filters = 128, kernel_size = 4, strides = (2, 2, 2), # 32, 32, 32
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_4upc) 
    lay_3up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3up)
    lay_3up = Activation('relu')(lay_3up)

    lay_3upc = Concatenate(axis = -1)([lay_3up,lay_3dn])

    
    lay_2up = Conv3DTranspose(
        filters = 64, kernel_size = 4, strides = (1, 2, 2), #32, 64, 64
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_3upc)
    lay_2up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2up)
    lay_2up = Activation('relu')(lay_2up)

    lay_2upc = Concatenate(axis = -1)([lay_2up, lay_2dn])

    
    lay_1up = Conv3DTranspose(
        filters = 32, kernel_size = 4, strides = (1, 2, 2), #32, 128, 128
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_2upc)
    lay_1up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_1up)
    lay_1up = Activation('relu')(lay_1up) 

    lay_1upc = Concatenate(axis = -1)([lay_1up,lay_1dn])

    lay_256up = Conv3DTranspose(
        filters = 32, kernel_size = 4, strides = (1, 2, 2), #32, 256, 256
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_1upc)
    lay_256up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_256up)
    lay_256up = Activation('relu')(lay_256up)

    out = Conv3D(
        filters = 1, kernel_size = 1, strides = (1, 1, 1), activation = 'tanh', 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_256up)

    model = Model(inputs = inp_usamp_imag, outputs = out)

    return model

In [11]:
def define_gan_model(gen_model, dis_model, inp_shape):
        
    dis_model.trainable = False
    inp = Input(shape = inp_shape)
    out_g = gen_model(inp)
    out_dis = dis_model(out_g)
    model = Model(inputs = inp, outputs = [out_dis, out_g])
    model.summary()
    return model

In [41]:
def train(
    g_model, d_model, gan_model, 
    dataset, 
    n_epochs, n_batch, n_critic, 
    clip_val, n_patch, 
    f,
):
    
    bat_per_epo = int(np.ceil(tf.data.experimental.cardinality(dataset).numpy() / n_batch))
    half_batch = int(np.ceil(n_batch / 2))
    
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            
            # training the discriminator
            for k in range(n_critic):
                X = np.ones((1, 32, 256, 256, 1)) #[h, l, w, c]
                y = np.ones((1, n_patch, n_patch, n_patch, 1))

                for usamp_data, fsamp_data in dataset.take(half_batch):
                    usamp_data = tf.reshape(usamp_data, (-1, 32, 256, 256, 2))
                    fsamp_data = tf.reshape(fsamp_data, (-1, 32, 256, 256, 1))
                    
#                     print(f'usamp shape: {usamp_data.shape}, fsamp shape: {fsamp_data.shape}')
                    
                    X_real = fsamp_data
                    X_fake = g_model.predict(usamp_data)
                    
                    y_real = np.ones((1, n_patch, n_patch, n_patch, 1))
                    y_fake = -np.ones((1, n_patch, n_patch, n_patch, 1))

                    X, y = np.vstack((X, X_real)), np.vstack((y, y_real))
                    X, y = np.vstack((X, X_fake)), np.vstack((y, y_fake))
                    
                    

                X, y = X[1:], y[1:] # take out first np.ones
                
#                 print(f'X shape: {X.shape}, y shape: {y.shape}')
                d_loss, accuracy = d_model.train_on_batch(X, y)

                for l in d_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -clip_val, clip_val) for w in weights]
                    l.set_weights(weights)

                    
            # training the generator
            X_usamps = np.ones((1, 32, 256, 256, 2)) #[h, l, w, c]
            X_fsamps = np.ones((1, 32, 256, 256, 1))

            for X_usamp, X_fsamp in dataset.take(n_batch):
                
                X_usamp = tf.reshape(X_usamp, (-1, 32, 256, 256, 2))
                X_fsamp = tf.reshape(X_fsamp, (-1, 32, 256, 256, 1))

                X_usamps = np.vstack((X_usamps, X_usamp))
                X_fsamps = np.vstack((X_fsamps, X_fsamp))
                
            X_usamps, X_fsamps = X_usamps[1:], X_fsamps[1:] # take out first np.ones
            y_gan = np.ones((n_batch, n_patch, n_patch, n_patch, 1))
            g_loss = gan_model.train_on_batch ([X_usamps], [y_gan, X_fsamps])
            
            
#             f.write(f'>epoch: {i+1}, batch: {(j+1)/bat_per_epo}, discriminator loss: {d_loss}, acc: {accuracy},  wasserstein: {g_loss[1]},  mae: {g_loss[2]},  mssim: {g_loss[3]}, generator_loss: {g_loss[0]}')
            
            f.write(f'''>epoch: {i+1}, batch: {np.round((j+1)/bat_per_epo, 6)} \n''')
            f.write(f'''    discriminator loss: {np.round(d_loss, 6)}, generator loss: {np.round(g_loss[0], 6)}, acc: {np.round(accuracy, 6)} \n\n\n''')        
    

            print (f'''>epoch: {i+1}, batch: {(j+1)/bat_per_epo}''')
            print(f'''    discriminator loss: {d_loss}, generator loss: {g_loss[0]}, acc: {accuracy} \n\n''')    
        
        filename = f'{MODEL_NAME}_epoch_{i+1}.h5'
        g_model.save(filename)
        
    f.close() 


In [42]:
K.clear_session()

# hyperparameters       
n_epochs = 1
n_batch = 1
n_critic = 1
clip_val = 0.05
in_shape_gen = (32, 256, 256, 2)
in_shape_dis = (32, 256, 256, 1)


# multiple GPUs (doesn't work, kernel crashes)
# strategy = tf.distribute.MirroredStrategy()
# print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# with strategy.scope():
#     d_model = discriminator(inp_shape = in_shape_dis, kernel_initz = 'he_normal', trainable = True)
#     d_model = multi_gpu_model(d_model, gpus = 2, cpu_relocation = True)
#     opt = Adam(lr = 0.0002, beta_1 = 0.5)
#     d_model.compile(loss = wloss, optimizer = opt, metrics = [accm])
#     # d_model.summary()


#     g_model = generator(inp_shape = in_shape_gen, kernel_initz = 'he_normal', trainable = True)
#     g_model = multi_gpu_model(g_model, gpus = 2, cpu_relocation = True)


#     gan_model = define_gan_model(g_model, d_model, in_shape_gen)
#     opt1 = Adam(lr = 0.0001, beta_1 = 0.5)
#     gan_model.compile(loss = [wloss, 'mae', mssim], optimizer = opt1, loss_weights = [0.01, 20.0, 1.0])
#     # g_model.summary()



d_model = discriminator(inp_shape = in_shape_dis, kernel_initz = 'he_normal', trainable = True)
opt = Adam(lr = 0.0002, beta_1 = 0.5)
d_model.compile(loss = wloss, optimizer = opt, metrics = [accm])
# d_model.summary()


g_model = generator(inp_shape = in_shape_gen, kernel_initz = 'he_normal', trainable = True)

gan_model = define_gan_model(g_model, d_model, in_shape_gen)
opt1 = Adam(lr = 0.0001, beta_1 = 0.5)
gan_model.compile(loss = [wloss, 'mae', mssim], optimizer = opt1, loss_weights = [0.01, 20.0, 1.0])
# g_model.summary()


# other paramters
n_patch = d_model.output_shape[1]
f = open(log_paths, 'w')

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 32, 256, 256, 2)] 0         
_________________________________________________________________
model_1 (Model)              (None, 32, 256, 256, 1)   149392001 
_________________________________________________________________
model (Model)                (None, 8, 8, 8, 1)        3000785   
Total params: 152,392,786
Trainable params: 149,387,265
Non-trainable params: 3,005,521
_________________________________________________________________


In [43]:
train(
    g_model, d_model, gan_model, 
    ds_separated, 
    n_epochs, n_batch, n_critic, 
    clip_val, n_patch, 
    f
)

>epoch: 1, batch: 0.1111111111111111
    discriminator loss: 0.03282199427485466, generator loss: 15.623870849609375, acc: 0.29296875 


>epoch: 1, batch: 0.2222222222222222
    discriminator loss: -0.008487838320434093, generator loss: 14.36258602142334, acc: 0.0 


>epoch: 1, batch: 0.3333333333333333
    discriminator loss: -0.020223159343004227, generator loss: 10.282111167907715, acc: 0.0 


>epoch: 1, batch: 0.4444444444444444
    discriminator loss: -0.02248506061732769, generator loss: 6.822440147399902, acc: 0.0 


>epoch: 1, batch: 0.5555555555555556
    discriminator loss: -0.03181546926498413, generator loss: 5.372978687286377, acc: 0.0 


>epoch: 1, batch: 0.6666666666666666
    discriminator loss: -0.04207652062177658, generator loss: 4.410793304443359, acc: 0.0 


>epoch: 1, batch: 0.7777777777777778
    discriminator loss: -0.05107714980840683, generator loss: 3.803701877593994, acc: 0.0 


>epoch: 1, batch: 0.8888888888888888
    discriminator loss: -0.0605213157832622