In [None]:
# ---- Hardware and training schedule ------------------------------------------
# iter_* values count training samples: the loop advances its counter by bsize.
gpu_id       = 1          # index into the list of physical GPUs
iter_display = 20000      # report averaged losses every this many samples
iter_save    = 100000     # checkpoint + export snapshots every this many samples
iter_max     = 1000000    # total number of training samples to process

# ---- Data and model shape ----------------------------------------------------
ch_input     = 1          # channels fed to the generator
ch_output    = 1          # channels produced by the generator
lim_hmi      = 3000.0     # HMI arrays are divided by this when loaded (make_tensor)
isize        = 256        # image height/width in pixels
bsize        = 16         # mini-batch size

# ---- Generator loss terms ----------------------------------------------------
use_fm_loss  = False      # add discriminator feature-matching loss (default False)
use_l1_loss  = True       # add L1 reconstruction loss (default True)

# ---- Paths: edit before running ----------------------------------------------
root_data    = 'path_to_data'
root_save    = 'path_to_save'

In [None]:
import os

# Everything this notebook writes lives under root_save, split by purpose.
root_model      = '{}/model'.format(root_save)
root_validation = '{}/validation'.format(root_save)
root_test       = '{}/test'.format(root_save)

for _out_dir in (root_model, root_validation, root_test):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
import tensorflow as tf
import tensorflow.keras as keras

# Restrict TensorFlow to the single GPU selected by gpu_id and let its memory grow
# on demand rather than pre-allocating the whole card. Both calls must happen
# before any op initializes the GPU, so this cell has to run early.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    # Previously this crashed with a bare IndexError on CPU-only machines.
    print('No GPU detected; TensorFlow will run on the CPU.')
else:
    try:
        selected_gpu = gpus[gpu_id]
    except IndexError:
        raise ValueError('gpu_id=%d is out of range: %d GPU(s) detected.'
                         % (gpu_id, len(gpus))) from None
    tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')
    tf.config.experimental.set_memory_growth(selected_gpu, True)

In [None]:
""" data tree

root_data - train - center (input)
                  - stacks (target)

          - validation - center
                       - stacks

          - test - center
                 - stacks
"""


from glob import glob

list_train_input  = sorted(glob('%s/train/center/*.npy'%(root_data)))
list_train_output = sorted(glob('%s/train/stacks/*.npy'%(root_data)))
list_train        = list(zip(list_train_input, list_train_output))
nb_train          = len(list_train)

list_validation   = sorted(glob('%s/validation/center/*.npy'%(root_data)))
nb_validation     = len(list_validation)

list_test         = sorted(glob('%s/test/center/*.npy'%(root_data)))
nb_test           = len(list_test)

print(nb_train, nb_validation, nb_test)

In [None]:
import numpy as np
from utils_data import bytescale
from random import shuffle

def make_tensor(file_):
    """Load one 2-D .npy array and return it as a (1, H, W, 1) batch divided by lim_hmi."""
    arr = np.load(file_)
    # Add a leading batch axis and a trailing channel axis.
    batched = arr[np.newaxis, :, :, np.newaxis]
    return batched / lim_hmi

def make_output(gen_, png_min=-30, png_max=30):
    """Convert one generator output back to data units plus an 8-bit preview.

    gen_ must contain exactly isize*isize values (one single-channel image). It is
    reshaped to (isize, isize) and multiplied by lim_hmi, which undoes the 1/lim_hmi
    scaling applied in make_tensor.

    png_min / png_max are forwarded to bytescale as imin / imax. They are presumably
    the value range mapped onto the PNG's 0..255 scale (TODO confirm in
    utils_data.bytescale). The defaults keep the previous hard-coded range.

    Returns:
        (x, x_png): the rescaled float array and the bytescaled preview image.
    """
    x = gen_.numpy().reshape(isize, isize) * lim_hmi
    x_png = bytescale(x, imin=png_min, imax=png_max)
    return x, x_png

def train_batch_generator():
    """Yield (epoch, batch_A, batch_B) training mini-batches forever.

    Pairs from list_train are visited in a freshly shuffled order each epoch,
    including the first. A trailing remainder smaller than bsize is skipped when
    the epoch rolls over. batch_A and batch_B are float32 tensors built by
    concatenating make_tensor() outputs along axis 0. epoch starts at 0.

    Raises:
        ValueError: (on the first next()) if there are fewer than bsize training
            pairs, since not even one full batch can be formed.
    """
    size = bsize
    # Shuffle a private copy so the generator never mutates the global list_train.
    # Mutating it would make re-running cells depend on earlier runs.
    order = list(list_train)
    n_pairs = len(order)
    if n_pairs < size:
        raise ValueError('Need at least %d training pairs for one batch, found %d.'
                         % (size, n_pairs))
    # Shuffle up front too; otherwise epoch 0 would be served in sorted file order.
    shuffle(order)

    epoch = i = 0
    while True:
        if i + size > n_pairs:
            shuffle(order)
            i = 0
            epoch += 1
        pairs = order[i:i + size]
        batch_A = np.concatenate([make_tensor(file_A) for file_A, _ in pairs], 0)
        batch_B = np.concatenate([make_tensor(file_B) for _, file_B in pairs], 0)
        i += size
        yield epoch, tf.cast(batch_A, tf.float32), tf.cast(batch_B, tf.float32)

def check_train_batch_generator():
    """Smoke test: draw five batches and print each one's epoch and tensor shapes."""
    batches = train_batch_generator()
    # zip stops once range(5) is exhausted, so exactly five batches are drawn.
    for _, (epoch, batch_A, batch_B) in zip(range(5), batches):
        print(epoch, batch_A.shape, batch_B.shape)

# Uncomment to sanity-check train_batch_generator():
#check_train_batch_generator()

In [None]:
from networks import generator, discriminator

# Both networks are built from the same image size and channel counts. The
# discriminator is later called on [input, output] pairs in train_step.
network_G = generator(isize, ch_input, ch_output)
network_D = discriminator(isize, ch_input, ch_output)

for _net in (network_G, network_D):
    _net.summary()

In [None]:
def loss_function(target, output):
    """Binary cross-entropy between `target` and `output`, averaged over all elements.

    `output` is treated as a probability. Presumably the discriminator ends in a
    sigmoid (TODO confirm in networks.py). The 1e-12 offsets keep log() away from 0.
    """
    log_p = tf.math.log(output + 1e-12)
    log_not_p = tf.math.log(1 - output + 1e-12)
    return -tf.math.reduce_mean(log_p * target + log_not_p * (1 - target))

# Generator and discriminator each get their own Adam (lr=2e-4, beta_1=0.5).
# Optional decay, currently disabled; pass lr_schedule as learning_rate to enable:
#lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(1e-4,decay_steps=1024,decay_rate=0.96)
optimizer_D = tf.keras.optimizers.Adam(learning_rate=2.0e-4, beta_1=0.5)
optimizer_G = tf.keras.optimizers.Adam(learning_rate=2.0e-4, beta_1=0.5)

@tf.function
def train_step(real_A, real_B, fm_weight=10, l1_weight=100):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        real_A: input batch fed to the generator.
        real_B: target batch the generator should reproduce.
        fm_weight: multiplier for the feature-matching loss loss_F.
        l1_weight: multiplier for the L1 reconstruction loss loss_L.

    Returns:
        (loss_D, loss_G_fake, loss_F, loss_L). loss_F and loss_L are always
        computed for logging. They enter the generator objective only when
        use_fm_loss / use_l1_loss are set.
    """
    with tf.GradientTape() as tape_G, tf.GradientTape() as tape_D:

        fake_B = network_G(real_A, training=True)

        # The discriminator sees (input, candidate) pairs. Element 0 of its output
        # feeds the adversarial loss; elements 1: are compared as feature maps.
        output_D_real = network_D([real_A, real_B], training=True)
        output_D_fake = network_D([real_A, fake_B], training=True)

        loss_D_real = loss_function(target=tf.ones_like(output_D_real[0]), output=output_D_real[0])
        loss_D_fake = loss_function(target=tf.zeros_like(output_D_fake[0]), output=output_D_fake[0])
        loss_D = (loss_D_real + loss_D_fake)/2.

        # The generator is rewarded when D labels its output as real.
        loss_G_fake = loss_function(target=tf.ones_like(output_D_fake[0]), output=output_D_fake[0])

        # Feature matching: mean absolute difference of each intermediate D feature map.
        loss_F = 0
        for feature_fake, feature_real in zip(output_D_fake[1:], output_D_real[1:]):
            loss_F += tf.math.reduce_mean(tf.abs(feature_fake - feature_real))
        loss_F *= fm_weight

        loss_L = tf.reduce_mean(tf.abs(real_B - fake_B)) * l1_weight

        loss_G = loss_G_fake
        if use_fm_loss:
            # Must add loss_F. The undefined name `loss_G_feature` previously raised
            # NameError as soon as feature matching was enabled.
            loss_G += loss_F
        if use_l1_loss:
            loss_G += loss_L

    # Each tape differentiates only its own network's variables, so the G and D
    # updates below come from a single forward pass.
    gradient_G = tape_G.gradient(loss_G, network_G.trainable_variables)
    gradient_D = tape_D.gradient(loss_D, network_D.trainable_variables)

    optimizer_G.apply_gradients(zip(gradient_G, network_G.trainable_variables))
    optimizer_D.apply_gradients(zip(gradient_D, network_D.trainable_variables))

    return loss_D, loss_G_fake, loss_F, loss_L

@tf.function
def generation_step(real_A):
    """Run network_G in inference mode (training=False) on real_A cast to float32."""
    inputs = tf.cast(real_A, tf.float32)
    return network_G(inputs, training=False)

In [None]:
import time
def t_now():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    # Fields 0..5 of struct_time are year, month, day, hour, minute, second.
    fields = time.localtime()[:6]
    return '{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}'.format(*fields)

# Echo where outputs will be written and how much data was found.
summary_lines = [
    '\n------------------------------------ Summary ------------------------------------\n',
    '\n%s: Now start below session!\n' % (t_now()),
    'Model save path: %s' % (root_model),
    'Validation snap save path: %s' % (root_validation),
    'Test result save path: %s' % (root_test),
    '# of train, validation, and test datasets : %d, %d, %d' % (nb_train, nb_validation, nb_test),
    '\n---------------------------------------------------------------------------------\n',
]
for summary_line in summary_lines:
    print(summary_line)

In [None]:
from imageio import imsave


def _export_predictions(file_list, out_dir, iteration):
    """Run the generator on every file in file_list and save its outputs in out_dir.

    Each prediction is written as 'denoising.gan.<iteration:07d>.<date>.npy'
    (the array returned by make_output) plus a matching '.png' preview. <date> is
    the second-to-last '.'-separated field of the file's base name, presumably a
    date stamp embedded in the file name.
    """
    for file_A in file_list:
        # Split only the base name: a '.' in a parent directory (or a file name with
        # a single '.') would otherwise leak path fragments, including '/', into
        # the output name.
        date = os.path.basename(file_A).split('.')[-2]
        fake_B, fake_B_png = make_output(generation_step(make_tensor(file_A)))
        name = 'denoising.gan.%07d.%s' % (iteration, date)
        np.save('%s/%s.npy' % (out_dir, name), fake_B)
        imsave('%s/%s.png' % (out_dir, name), fake_B_png)


train_batch = train_batch_generator()

t0 = time.time()
epoch = iter_gen = 0
err_D = err_D_sum = err_D_mean = 0
err_G = err_G_sum = err_G_mean = 0
err_F = err_F_sum = err_F_mean = 0
err_L = err_L_sum = err_L_mean = 0
n_accum = 0  # samples accumulated into the *_sum counters since the last report

# iter_gen counts training samples (bsize per step). Using `<` stops exactly at
# iter_max instead of training one extra batch after the final checkpoint.
while iter_gen < iter_max:

    epoch, train_A, train_B = next(train_batch)
    err_D, err_G, err_F, err_L = train_step(train_A, train_B)

    err_D_sum += err_D*bsize
    err_G_sum += err_G*bsize
    err_F_sum += err_F*bsize
    err_L_sum += err_L*bsize
    n_accum += bsize

    iter_prev = iter_gen
    iter_gen += bsize

    # Fire when a multiple of the interval is reached *or crossed*. A plain `%`
    # test would never fire if the interval were not a multiple of bsize.
    if iter_gen // iter_display > iter_prev // iter_display:

        err_D_mean = err_D_sum/n_accum
        err_G_mean = err_G_sum/n_accum
        err_F_mean = err_F_sum/n_accum
        err_L_mean = err_L_sum/n_accum

        message1 = '[%d][%d/%d]' % (epoch, iter_gen, iter_max)
        message2 = 'Loss_D: %5.3f Loss_G: %5.3f Loss_F: %5.3f Loss_L: %5.3f T: %dsec/%dits'%(err_D_mean, err_G_mean, err_F_mean, err_L_mean, time.time()-t0, n_accum)
        print('%s: %s %s'%(t_now(), message1, message2))

        err_G_sum, err_D_sum, err_F_sum, err_L_sum = 0, 0, 0, 0
        n_accum = 0
        t0 = time.time()

    if iter_gen // iter_save > iter_prev // iter_save:

        network_G.save('%s/denoising.gan.%07d.G.h5'%(root_model, iter_gen))
        network_D.save('%s/denoising.gan.%07d.D.h5'%(root_model, iter_gen))
        message3 = 'network_G and network_D are saved under %s'%(root_model)
        print('%s: %s'%(t_now(), message3))

        path_validation = '%s/iter_%07d'%(root_validation, iter_gen)
        path_test = '%s/iter_%07d'%(root_test, iter_gen)
        os.makedirs(path_validation, exist_ok=True)
        os.makedirs(path_test, exist_ok=True)

        _export_predictions(list_validation, path_validation, iter_gen)
        message4 = 'Validation snaps (%d images) are saved in %s'%(nb_validation, path_validation)
        print('%s: %s'%(t_now(), message4))

        _export_predictions(list_test, path_test, iter_gen)
        message5 = 'Test results (%d images) are saved in %s'%(nb_test, path_test)
        print('%s: %s'%(t_now(), message5))

        # Exclude checkpoint/export time from the next timing report.
        t0 = time.time()