Training data: one `.npz` file per MRI scan

In [1]:
import os
import glob

import numpy as np

import tensorflow as tf

from keras.models import Model, Input
from keras.optimizers import Adam

from keras.layers import Dense
from keras.layers import Conv3D, Conv3DTranspose
from keras.layers import Add, Concatenate
from keras.layers import Activation, LeakyReLU
from keras.layers import BatchNormalization, Lambda

from keras import backend as K

from tensorflow.python.client import device_lib
# Report GPU availability two ways: from TensorFlow's local device list and
# from tf.config's list of physical GPUs.
devices = device_lib.list_local_devices()
# GPU entries are recognised by a case-insensitive '/device:gpu' name prefix.
gpus = list(filter(lambda dev: dev.name.lower().startswith('/device:gpu'), devices))
print(f'using {len(gpus)} GPUs')
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

using 1 GPUs
Num GPUs Available:  1


In [2]:
MODEL_NAME = '1'
DATASET = 'singlecoil_train' # dataset split folder (under dataset_objects/) to train on
# Root directory holding logs, saved models and dataset objects.
BASE_DIR = '/central/groups/BEBi_205_Spring_2021/vliu'
log_path = os.path.join(BASE_DIR, 'logs', 'train', f'model_{MODEL_NAME}.txt')
models_path = os.path.join(BASE_DIR, 'models')
# glob returns files in filesystem-dependent order; sort so that a file's
# index (used by the random sampling in training) is stable across runs.
npz_data_path = sorted(glob.glob(
    os.path.join(BASE_DIR, 'dataset_objects', DATASET, '*.npz')
))

In [3]:
def acc(y_true, y_pred):
    '''
    Fraction of elements where the prediction, after clipping to [-1, 1]
    and rounding to the nearest integer, equals the target.
    '''
    rounded = K.round(K.clip(y_pred, -1, 1))
    matches = K.equal(y_true, rounded)
    return K.mean(matches)

def mssim(y_true, y_pred):
    '''
    SSIM-based loss, summed over the batch.

    For each sample, `tf.image.ssim` (max_val=2.0) is evaluated on every
    trailing (H, W, C) image, the values are averaged, and the sample's cost
    is `1 - mean_ssim`; the per-sample costs are then summed.
    max_val=2.0 presumes intensities span [-1, 1] (the generator ends in
    tanh) -- TODO confirm for the stored targets.

    Vectorized over the batch so it also works when the batch dimension is
    not statically known (None) while the graph is being built.
    '''
    # tf.image.ssim reduces the trailing (H, W, C) axes and keeps all leading
    # axes, e.g. (batch, depth) for 5-D volumes.
    ssim_vals = tf.image.ssim(y_true, y_pred, max_val=2.0)
    # Average everything except the batch axis -> one SSIM per sample.
    per_sample = tf.reduce_mean(
        tf.reshape(ssim_vals, [tf.shape(ssim_vals)[0], -1]), axis=1
    )
    return tf.reduce_sum(1.0 - per_sample)


def wloss(y_true, y_predict):
    '''
    Wasserstein loss, summed over the batch:
    -sum_i mean(y_true[i] * y_predict[i]).

    Training uses targets of +1 (real) and -1 (generated), giving the usual
    WGAN critic objective. Vectorized over the batch so it also works when
    the batch dimension is not statically known (None).
    '''
    prod = y_true * y_predict
    # Flatten each sample and take its mean -> one value per sample.
    per_sample = K.mean(tf.reshape(prod, [tf.shape(prod)[0], -1]), axis=-1)
    return -K.sum(per_sample)

In [4]:
def discriminator(
    kernel_initz,
    inp_shape, # 3d, with channels last
    trainable = True,
):
    '''
    Build the (uncompiled) 3-D convolutional critic.

    @params kernel_initz: kernel initializer passed to every Conv3D.
    @params inp_shape: per-sample input shape, channels last,
        e.g. (32, 256, 256, 1).
    @params trainable: forwarded to the BatchNormalization layers only.
    @returns: keras Model mapping the input volume to a single-channel,
        linearly activated score map (8, 8, 8, 1) for a (32, 256, 256, 1) input.
    '''
    gamma_init = tf.random_normal_initializer(1., 0.02)

    def conv_block(tensor, filters, kernel_size, strides, normalize = True):
        # Conv3D -> (optional) BatchNorm -> LeakyReLU, in that creation order.
        tensor = Conv3D(
            filters = filters, kernel_size = kernel_size, strides = strides,
            padding = 'same', kernel_initializer = kernel_initz,
        )(tensor)
        if normalize:
            tensor = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(tensor)
        return LeakyReLU(alpha = 0.2)(tensor)

    inp = Input(shape = inp_shape)

    # Downsampling; shapes given for a (32, 256, 256, 1) input.
    feat = conv_block(inp, 8, 4, (1, 2, 2), normalize = False)  # 32, 128, 128
    feat = conv_block(feat, 16, 4, (1, 2, 2))                   # 32, 64, 64
    feat = conv_block(feat, 32, 4, (1, 2, 2))                   # 32, 32, 32
    feat = conv_block(feat, 64, 4, (2, 2, 2))                   # 16, 16, 16
    feat = conv_block(feat, 128, 4, (2, 2, 2))                  # 8, 8, 8

    # Bottleneck with a residual shortcut around three further conv blocks.
    skip = conv_block(feat, 64, 1, (1, 1, 1))
    branch = conv_block(skip, 32, 1, (1, 1, 1))
    branch = conv_block(branch, 16, 3, (1, 1, 1))
    branch = conv_block(branch, 64, 3, (1, 1, 1))

    merged = Add()([skip, branch])
    merged = LeakyReLU(alpha = 0.2)(merged)

    score = Conv3D(
        filters = 1, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
    )(merged)

    return Model(inputs = inp, outputs = score)

In [5]:
def resden(
    x, fil_lay, fil_end, beta, 
    gamma_init, kernel_initz, trainable,
):
    '''
    Residual dense block.

    Three Conv3D(fil_lay)->BatchNorm->LeakyReLU stages, each concatenating
    its output onto everything produced so far, are followed by a
    Conv3D(fil_end). That output is scaled by `beta` and added back to `x`,
    so `x` must have `fil_end` channels for the Add to be valid.
    '''
    dense = x
    for _ in range(3):
        grown = Conv3D(
            filters = fil_lay, kernel_size = 3, strides = 1,
            padding = 'same', kernel_initializer = kernel_initz,
        )(dense)
        grown = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(grown)
        grown = LeakyReLU(alpha = 0.2)(grown)
        dense = Concatenate(axis = -1)([dense, grown])

    residual = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
    )(dense)
    residual = Lambda(lambda t: t * beta)(residual)

    return Add()([residual, x])

def resresden(x, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable):
    '''
    Two chained `resden` blocks wrapped in an outer residual connection.

    The chain's output is scaled by `beta` and added back to `x`.
    '''
    inner = x
    for _ in range(2):
        inner = resden(inner, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable)

    inner = Lambda(lambda t: t * beta)(inner)
    return Add()([inner, x])

In [6]:
def generator(inp_shape, kernel_initz, trainable = True,):
    '''
    Build the (uncompiled) 3-D encoder/decoder generator.

    Five strided Conv3D downsampling stages feed a trunk of `resresden`
    blocks with a long residual connection. Four transposed-conv stages then
    upsample, each concatenated with the matching encoder feature map. A
    final upsampling stage and a 1x1x1 tanh conv produce a one-channel output.

    @params inp_shape: per-sample input shape, e.g. (32, 256, 256, 2).
    @params kernel_initz: kernel initializer for every conv layer.
    @params trainable: forwarded to BatchNormalization layers (including
        those inside the residual trunk).
    @returns: keras Model; shape comments below assume a (32, 256, 256, 2) input.
    '''
    gamma_init = tf.random_normal_initializer(1., 0.02)

    fil_lay, fil_end = 64, 512
    rrd_count = 6
    beta = 0.2

    def down(tensor, filters, strides, normalize = True):
        # Conv3D(k=4) -> (optional) BatchNorm -> LeakyReLU
        tensor = Conv3D(
            filters = filters, kernel_size = 4, strides = strides,
            padding = 'same', kernel_initializer = kernel_initz,
        )(tensor)
        if normalize:
            tensor = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(tensor)
        return LeakyReLU(alpha = 0.2)(tensor)

    def up(tensor, filters, strides, skip = None):
        # Conv3DTranspose(k=4) -> BatchNorm -> ReLU, then optional skip concat
        tensor = Conv3DTranspose(
            filters = filters, kernel_size = 4, strides = strides,
            padding = 'same', kernel_initializer = kernel_initz,
        )(tensor)
        tensor = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(tensor)
        tensor = Activation('relu')(tensor)
        if skip is not None:
            tensor = Concatenate(axis = -1)([tensor, skip])
        return tensor

    def trunk_conv(tensor):
        return Conv3D(
            filters = fil_end, kernel_size = 3, strides = 1,
            padding = 'same', kernel_initializer = kernel_initz,
        )(tensor)

    inp_usamp_imag = Input(inp_shape) # (-1, 32, 256, 256, 2)

    enc1 = down(inp_usamp_imag, 64, (1, 2, 2), normalize = False)  # 32, 128, 128
    enc2 = down(enc1, 128, (1, 2, 2))                              # 32, 64, 64
    enc3 = down(enc2, 256, (1, 2, 2))                              # 32, 32, 32
    enc4 = down(enc3, 512, (2, 2, 2))                              # 16, 16, 16
    enc5 = down(enc4, 512, (2, 2, 2))                              # 8, 8, 8

    head = trunk_conv(enc5)
    body = head
    for _ in range(rrd_count):
        body = resresden(body, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable)
    tail = trunk_conv(body)
    bottleneck = Add()([head, tail])  # 8, 8, 8

    dec4 = up(bottleneck, 512, (2, 2, 2), skip = enc4)  # 16, 16, 16
    dec3 = up(dec4, 256, (2, 2, 2), skip = enc3)        # 32, 32, 32
    dec2 = up(dec3, 128, (1, 2, 2), skip = enc2)        # 32, 64, 64
    dec1 = up(dec2, 64, (1, 2, 2), skip = enc1)         # 32, 128, 128
    full_res = up(dec1, 32, (1, 2, 2))                  # 32, 256, 256

    out = Conv3D(
        filters = 1, kernel_size = 1, strides = (1, 1, 1), activation = 'tanh',
        padding = 'same', kernel_initializer = kernel_initz,
    )(full_res)

    return Model(inputs = inp_usamp_imag, outputs = out)



# # contracted
# def generator(inp_shape, kernel_initz, trainable = True):
    
#     gamma_init = tf.random_normal_initializer(1., 0.02)
#     fil_lay = 32
#     fil_end = 64
#     rrd_count = 6
#     beta = 0.2

    
#     inp_usamp_imag = Input(inp_shape) # (-1, 32, 256, 256, 2)
    
    
#     lay_1dn = Conv3D(
#         filters = 8, kernel_size = 4, strides = (2, 4, 4), # 16, 64, 64
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(inp_usamp_imag)
#     lay_1dn = LeakyReLU(alpha = 0.2)(lay_1dn)

    
#     lay_2dn = Conv3D(
#         filters = 16, kernel_size = 4, strides = (2, 2, 2), # 8, 32, 32
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_1dn)
#     lay_2dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2dn)
#     lay_2dn = LeakyReLU(alpha = 0.2)(lay_2dn)

    
#     lay_3dn = Conv3D(
#         filters = 32, kernel_size = 4, strides = (1, 2, 2), # 8, 16, 16
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_2dn)
#     lay_3dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3dn)
#     lay_3dn = LeakyReLU(alpha = 0.2)(lay_3dn)

    
#     lay_4dn = Conv3D(
#         filters = 64, kernel_size = 4, strides = (1, 2, 2), # 8, 8, 8
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_3dn)
#     lay_4dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_4dn)
#     lay_4dn = LeakyReLU(alpha = 0.2)(lay_4dn)


#     c1 = Conv3D(
#         filters = fil_end, kernel_size = 3, strides = 1,
#         padding = 'same', kernel_initializer = kernel_initz, 
#         )(lay_4dn)
    
#     xrrd = c1
#     for _ in range(rrd_count):
#         xrrd = resresden(xrrd, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable)

#     c2 = Conv3D(
#         filters = fil_end, kernel_size = 3, strides = 1,
#         padding = 'same', kernel_initializer = kernel_initz, 
#         )(xrrd)
    
    
#     lay_4upc = Add()([c1, c2]) # 8, 8, 8

    
#     lay_3up = Conv3DTranspose(
#         filters = 32, kernel_size = 4, strides = (1, 2, 2), # 8, 16, 16
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_4upc)
#     lay_3up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3up)
#     lay_3up = Activation('relu')(lay_3up) 

#     lay_3upc = Concatenate(axis = -1)([lay_3up,lay_3dn]) 

    
    
#     lay_2up = Conv3DTranspose(
#         filters = 16, kernel_size = 4, strides = (1, 2, 2), # 8, 32, 32
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_3upc) 
#     lay_2up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2up)
#     lay_2up = Activation('relu')(lay_2up)

#     lay_2upc = Concatenate(axis = -1)([lay_2up,lay_2dn])

    
#     lay_1up = Conv3DTranspose(
#         filters = 8, kernel_size = 4, strides = (2, 2, 2), # 16, 64, 64
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_2upc)
#     lay_1up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_1up)
#     lay_1up = Activation('relu')(lay_1up)

#     lay_1upc = Concatenate(axis = -1)([lay_1up, lay_1dn])


#     lay_256up = Conv3DTranspose(
#         filters = 8, kernel_size = 4, strides = (2, 4, 4), #32, 256, 256
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_1upc)
#     lay_256up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_256up)
#     lay_256up = Activation('relu')(lay_256up)

#     out = Conv3D(
#         filters = 1, kernel_size = 1, strides = (1, 1, 1), activation = 'tanh', 
#         padding = 'same', kernel_initializer = kernel_initz,
#         )(lay_256up)

#     model = Model(inputs = inp_usamp_imag, outputs = out)

#     return model

In [7]:
def define_gan_model(gen_model, dis_model, inp_shape):
    '''
    Stack the generator and the discriminator into one combined model.

    `dis_model.trainable` is set to False first, so that (per Keras'
    `trainable` semantics) training the combined model does not update the
    discriminator. Prints the model summary and returns a Model mapping an
    input of shape `inp_shape` to [discriminator score, generated image].
    '''
    dis_model.trainable = False

    source = Input(shape = inp_shape)
    generated = gen_model(source)
    score = dis_model(generated)

    stacked = Model(inputs = source, outputs = [score, generated])
    stacked.summary()
    return stacked

In [16]:
def train(
    g_model, d_model, gan_model, 
    data_filenames_all, 
    n_epochs, n_batch, n_unrolled, 
    clip_val, n_patch, 
    f,
):
    '''
    Adversarially train the generator against a weight-clipped (WGAN-style) critic.

    Each step runs `n_unrolled` critic updates, then one update of the
    combined `gan_model`. Generator weights are saved to
    `models_path/{MODEL_NAME}_epoch_{epoch}.h5` after every epoch.

    @params g_model: generator model.
    @params d_model: compiled critic, trained with targets +1 (real) and -1 (generated).
    @params gan_model: compiled combined model with outputs [critic score, image].
    @params data_filenames_all: list of .npz paths (e.g. from glob). Each file
        provides 'arr_0' (generator input) and 'arr_1' (target image).
    @params n_epochs: number of epochs; one epoch is ceil(len(files) / n_batch) steps.
    @params n_batch: generator batch size. The critic sees ceil(n_batch / 2)
        real and as many generated samples per update.
    @params n_unrolled: critic updates per generator update.
    @params clip_val: critic weights are clipped to [-clip_val, clip_val]
        after each critic update.
    @params n_patch: side length of the critic's (cubic) output.
    @params f: writable text file for the per-step log; it is closed on
        return, including on error.
    @raises ValueError: if there are fewer data files than `n_batch`.

    Assumes both models have fully defined (non-None) per-sample shapes.
    '''
    data_sample_count = len(data_filenames_all)
    if data_sample_count < n_batch:
        # Sampling below uses replace=False, so it cannot draw n_batch distinct
        # files. Fail up front (this also catches an empty glob) instead of
        # crashing mid-run or saving untrained weights every epoch.
        raise ValueError(
            f'need at least n_batch={n_batch} data files, found {data_sample_count}'
        )
    bat_per_epo = int(np.ceil(data_sample_count / n_batch))
    half_batch = int(np.ceil(n_batch / 2))

    # Per-sample array shapes, taken from the models rather than hard-coded.
    gen_in_shape = tuple(g_model.input_shape[1:])  # generator input
    img_shape = tuple(d_model.input_shape[1:])     # real / generated image

    # Targets never change, so build them once: +1 real, -1 generated.
    y_disc = np.vstack((
        np.ones((half_batch, n_patch, n_patch, n_patch, 1)),
        -np.ones((half_batch, n_patch, n_patch, n_patch, 1)),
    ))
    y_gan = np.ones((n_batch, n_patch, n_patch, n_patch, 1))

    def _read_npz(path, *keys):
        # `with` closes the archive. Indexing an NpzFile reads the array into
        # memory, so the returned arrays remain valid afterwards.
        with np.load(path) as npz:
            return tuple(npz[key] for key in keys)

    def _clip_critic_weights():
        # Clip every weight of every critic layer.
        # NOTE(review): get_weights() also returns BatchNormalization moving
        # statistics, so those are clipped as well -- confirm this is intended.
        for layer in d_model.layers:
            layer.set_weights(
                [np.clip(w, -clip_val, clip_val) for w in layer.get_weights()]
            )

    # Defined up front so logging still works when n_unrolled == 0.
    d_loss = d_mae = float('nan')

    try:
        for i in range(n_epochs):
            print (f'EPOCH {i+1} \n\n')
            for j in range(bat_per_epo):

                # training the discriminator
                for _ in range(n_unrolled):
                    # indices for generated (fake) and real samples
                    fake_idx = np.random.choice(data_sample_count, size = half_batch, replace = False)
                    real_idx = np.random.choice(data_sample_count, size = half_batch, replace = False)

                    X = np.empty((2 * half_batch,) + img_shape)
                    usamps = np.empty((half_batch,) + gen_in_shape)
                    for v in range(half_batch):
                        usamps[v] = _read_npz(data_filenames_all[fake_idx[v]], 'arr_0')[0]
                        X[v] = _read_npz(data_filenames_all[real_idx[v]], 'arr_1')[0]
                    # One batched predict instead of one call per sample.
                    X[half_batch:] = g_model.predict(usamps)

                    d_loss, d_mae = d_model.train_on_batch(X, y_disc)
                    _clip_critic_weights()

                # training the generator
                gen_idx = np.random.choice(data_sample_count, size = n_batch, replace = False)

                X_usamps = np.empty((n_batch,) + gen_in_shape)
                X_fsamps = np.empty((n_batch,) + img_shape)
                for v in range(n_batch):
                    X_usamps[v], X_fsamps[v] = _read_npz(
                        data_filenames_all[gen_idx[v]], 'arr_0', 'arr_1'
                    )

                g_loss = gan_model.train_on_batch([X_usamps], [y_gan, X_fsamps])

                f.write(f'''>epoch: {i+1}, batch: {np.round((j+1)/bat_per_epo, 6)} \n''')
                f.write(f'''    discriminator loss: {np.round(d_loss, 6)}, generator loss: {np.round(g_loss[0], 6)}, MAE: {np.round(d_mae, 6)} \n\n''')
                # Flush so the log can be followed while a long run is in progress.
                f.flush()

                if i % 10 == 0 and (j+1) % 100 == 0:
                    print (f'''>epoch: {i+1}, batch: {(j+1)/bat_per_epo}''')
                    print(f'''    discriminator loss: {d_loss}, generator loss: {g_loss[0]}, MAE: {d_mae}\n''')

            filename = os.path.join(models_path, f'{MODEL_NAME}_epoch_{i+1}.h5')
            if os.path.isfile(filename):
                os.remove(filename)

            g_model.save_weights(filename)
    finally:
        f.close()
    print(f'check out saved models at {models_path}')


In [17]:
K.clear_session()

# hyperparameters
n_epochs = 50
n_batch = 1
n_unrolled = 3   # critic updates per generator update
clip_val = 0.05  # critic weights are clipped to [-clip_val, clip_val]
in_shape_gen = (32, 256, 256, 2)
in_shape_dis = (32, 256, 256, 1)


# models
# The critic is compiled while still trainable. define_gan_model then sets
# d_model.trainable = False for the combined model, so keep this order.
d_model = discriminator(inp_shape = in_shape_dis, kernel_initz = 'glorot_uniform', trainable = True)
opt = Adam(learning_rate = 0.0001)
d_model.compile(loss = wloss, optimizer = opt, metrics = ['mae'])
# d_model.summary()


g_model = generator(inp_shape = in_shape_gen, kernel_initz = 'glorot_uniform', trainable = True)

gan_model = define_gan_model(g_model, d_model, in_shape_gen)
opt1 = Adam(learning_rate = 0.0002)
# gan_model has exactly two outputs, [critic score, generated image], so it
# takes one loss (and one weight) per output. `mssim` is not included here:
# as a third loss it had no output to attach to, which depending on the Keras
# version either raises at compile time or is silently ignored. To use SSIM,
# combine it into the loss for the image output.
gan_model.compile(loss = [wloss, 'mae'], optimizer = opt1, loss_weights = [0.01, 20.0])
# g_model.summary()


# other parameters
n_patch = d_model.output_shape[1]  # side of the (cubic) critic output
# Make sure the output locations exist before writing the log / checkpoints.
os.makedirs(os.path.dirname(log_path), exist_ok = True)
os.makedirs(models_path, exist_ok = True)
f = open(log_path, 'w')

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 32, 256, 256, 2)] 0         
_________________________________________________________________
model_1 (Model)              (None, 32, 256, 256, 1)   233658433 
_________________________________________________________________
model (Model)                (None, 8, 8, 8, 1)        752361    
Total params: 234,410,794
Trainable params: 233,649,025
Non-trainable params: 761,769
_________________________________________________________________


In [None]:
train(g_model, d_model, gan_model, 
    npz_data_path, 
    n_epochs, n_batch, n_unrolled, 
    clip_val, n_patch, 
    f,)

EPOCH 1 


>epoch: 1, batch: 0.1111111111111111
    discriminator loss: -6.014625072479248, generator loss: 0.18149298429489136, MAE: 2.0136260986328125

>epoch: 1, batch: 0.2222222222222222
    discriminator loss: -5.581221580505371, generator loss: 0.1530897319316864, MAE: 1.881453514099121

>epoch: 1, batch: 0.3333333333333333
    discriminator loss: -9.854053497314453, generator loss: 0.09769800305366516, MAE: 3.92920184135437

>epoch: 1, batch: 0.4444444444444444
    discriminator loss: 6.740375518798828, generator loss: 0.30550748109817505, MAE: 4.370187282562256

>epoch: 1, batch: 0.5555555555555556
    discriminator loss: -10.918370246887207, generator loss: -0.07457803189754486, MAE: 4.4591851234436035

>epoch: 1, batch: 0.6666666666666666
    discriminator loss: -5.862706661224365, generator loss: 0.3957073390483856, MAE: 2.467292070388794

>epoch: 1, batch: 0.7777777777777778
    discriminator loss: -12.732583999633789, generator loss: -0.10126792639493942, MAE: 5.3662924766

In [133]:
def train_gen(
    g_model,
    data_filenames_all, 
    n_epochs, n_batch, 
    f,
):
    '''
    Supervised (non-adversarial) training of the generator alone.

    Each step fits `g_model` to map 'arr_0' to 'arr_1' from `n_batch`
    randomly chosen .npz files. Weights are saved to
    `models_path/{MODEL_NAME}_epoch_{epoch}.h5` after every epoch.

    @params g_model: compiled generator model.
    @params data_filenames_all: list of .npz paths (e.g. from glob).
    @params n_epochs: number of epochs; one epoch is ceil(len(files) / n_batch) steps.
    @params n_batch: batch size.
    @params f: writable text file for the per-step log; it is closed on
        return, including on error.
    @raises ValueError: if there are fewer data files than `n_batch`.

    Assumes the model has fully defined (non-None) per-sample shapes.
    '''
    data_sample_count = len(data_filenames_all)
    if data_sample_count < n_batch:
        # Sampling uses replace=False; fail early with a clear message
        # (this also catches an empty glob).
        raise ValueError(
            f'need at least n_batch={n_batch} data files, found {data_sample_count}'
        )
    bat_per_epo = int(np.ceil(data_sample_count / n_batch))

    # Per-sample shapes, taken from the model rather than hard-coded.
    in_shape = tuple(g_model.input_shape[1:])
    out_shape = tuple(g_model.output_shape[1:])

    try:
        for i in range(n_epochs):
            print (f'EPOCH {i+1} \n\n')
            for j in range(bat_per_epo):

                # random sample indices for this step
                gen_idx = np.random.choice(
                    data_sample_count, 
                    size = n_batch, 
                    replace = False
                )

                X_usamps = np.empty((n_batch,) + in_shape)
                X_fsamps = np.empty((n_batch,) + out_shape)

                for v, idx in enumerate(gen_idx):
                    # `with` closes the archive; indexing reads arrays into memory.
                    with np.load(data_filenames_all[idx]) as npz:
                        X_usamps[v], X_fsamps[v] = npz['arr_0'], npz['arr_1']

                g_loss = g_model.train_on_batch(X_usamps, X_fsamps)

                f.write(f'''>epoch: {i+1}, batch: {np.round((j+1)/bat_per_epo, 6)} \n''')
                # Terminate the entry; previously consecutive entries ran together.
                f.write(f'''    loss: {np.round(g_loss, 6)} \n\n''')
                f.flush()

                if (i+1) % 100 == 0 and (j+1) % 100 == 0:
                    print (f'''>epoch: {i+1}, batch: {(j+1)/bat_per_epo}''')
                    print(f'''    loss: {g_loss}\n''')

            filename = os.path.join(models_path, f'{MODEL_NAME}_epoch_{i+1}.h5')
            if os.path.isfile(filename):
                os.remove(filename)

            g_model.save_weights(filename)
    finally:
        f.close()
    print(f'check out saved models at {models_path}')



In [134]:
# g_model = generator(inp_shape = in_shape_gen, kernel_initz = 'he_normal', trainable = True)

# opt1 = Adam(lr = 0.0005)
# g_model.compile(loss = [wloss, 'mae', mssim], optimizer = opt1, loss_weights = [0.01, 20.0, 1.0])

# n_epochs = 200
# n_batch = 1
# f = open(log_path, 'w')
# train_gen(
#     g_model, 
#     npz_data_path,
#     n_epochs, n_batch, 
#     f
# )