In [None]:
# Index of the physical GPU to train on (index into tf.config's list of physical GPUs).
gpu_id = 3

In [None]:
import tensorflow as tf
# Restrict TensorFlow to the single GPU selected by `gpu_id` and let its memory
# grow on demand instead of reserving the whole card up front.
# TensorFlow requires both settings to be applied before the GPU is initialised.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpu_id >= len(gpus) or gpu_id < -len(gpus):
    # Fail with an explicit message instead of a bare "list index out of range".
    raise IndexError('gpu_id=%d is out of range: %d physical GPU(s) visible'
                     % (gpu_id, len(gpus)))
tf.config.experimental.set_visible_devices(gpus[gpu_id], 'GPU')
tf.config.experimental.set_memory_growth(gpus[gpu_id], True)

In [None]:
import os

# Training schedule. Iteration counters advance by bsize per optimisation step.
iter_display = 2000   # print averaged losses every this many iterations
iter_save = 10000     # save models + validation/test snaps every this many iterations
iter_max = 500000     # total number of training iterations

# Dataset sub-folders: files are read from <root_data>/<split>/hmi/<name>/*.npy
name_input = 'center'
name_output = 'stacks'

# Data / output roots. Override via environment variables to run on another machine;
# the defaults are the original hard-coded locations.
root_data = os.environ.get('DENOISING_ROOT_DATA', '/home/park_e/datasets')
root_save = os.environ.get('DENOISING_ROOT_SAVE', '/userhome/park_e/denoising_tf2')

# Channel counts passed to the generator / discriminator builders.
ch_input = 1
ch_output = 1

# Maps are clipped to [-lim_hmi, lim_hmi] and divided by lim_hmi before training.
lim_hmi = 100.

isize = 256        # spatial size (isize x isize) of the network inputs/outputs
bsize = 1          # training batch size
layer_max_d = 3    # forwarded to patch_discriminator (presumably its depth — TODO confirm)

In [None]:
import os

# Run name: the model family plus the discriminator setting layer_max_d.
mode = 'unet_cgan%d' % layer_max_d

# Every artefact of this run lives under <root_save>/<mode>/.
root_model = '{}/{}/model'.format(root_save, mode)
root_validation = '{}/{}/validation'.format(root_save, mode)
root_test = '{}/{}/test'.format(root_save, mode)

for _run_dir in (root_model, root_validation, root_test):
    os.makedirs(_run_dir, exist_ok=True)

In [None]:
from glob import glob

def _list_pairs(split):
    """Return the sorted (input_files, output_files) .npy lists for one dataset split.

    Files come from <root_data>/<split>/hmi/<name_input>/ and
    <root_data>/<split>/hmi/<name_output>/. Both lists are sorted, so the i-th
    input is paired with the i-th output.

    Raises:
        AssertionError: if the two folders hold different numbers of files.
    """
    inputs = sorted(glob('%s/%s/hmi/%s/*.npy' % (root_data, split, name_input)))
    outputs = sorted(glob('%s/%s/hmi/%s/*.npy' % (root_data, split, name_output)))
    assert len(inputs) == len(outputs), \
        '%s: %d input files vs %d output files' % (split, len(inputs), len(outputs))
    return inputs, outputs


list_train_input, list_train_output = _list_pairs('train')
list_train = list(zip(list_train_input, list_train_output))
nb_train = len(list_train)

list_validation_input, list_validation_output = _list_pairs('validation')
list_validation = list(zip(list_validation_input, list_validation_output))
nb_validation = len(list_validation)

list_test_input, list_test_output = _list_pairs('test')
list_test = list(zip(list_test_input, list_test_output))
nb_test = len(list_test)

print(nb_train, nb_validation, nb_test)

In [None]:
import numpy as np
from utils import bytescale
from imageio import imsave

def make_tensor(file_):
    """Load a .npy map as a network-ready batch of one.

    Assumes the file holds a 2-D (H, W) array, which becomes shape (1, H, W, 1).
    Values are clipped to [-lim_hmi, lim_hmi] and divided by lim_hmi,
    so the result lies in [-1, 1].
    """
    arr = np.load(file_)
    arr = arr[np.newaxis, :, :, np.newaxis]
    return arr.clip(-lim_hmi, lim_hmi) / lim_hmi

def make_output(gen_, tar_):
    """Rescale one generator output and build a comparison image.

    Args:
        gen_: generator output for a single sample; reshaped to (isize, isize).
        tar_: path to the matching target .npy file.

    Returns:
        (x, xy_png):
            x is the prediction multiplied back by lim_hmi.
            xy_png stacks three rows: prediction, target, and (clipped prediction -
            clipped target). Each row shows two panels rendered with bytescale over
            [-30, 30] and [-100, 100].
    """
    def _two_windows(img):
        # Same map rendered with a narrow and a wide display range, side by side.
        return np.hstack((bytescale(img, imin=-30, imax=30),
                          bytescale(img, imin=-100, imax=100)))

    pred = gen_.reshape(isize, isize) * lim_hmi
    rows = [_two_windows(pred)]

    target = np.load(tar_)
    rows.append(_two_windows(target))

    rows.append(np.hstack((
        bytescale(pred.clip(-30, 30) - target.clip(-30, 30), imin=-30, imax=30),
        bytescale(pred.clip(-100, 100) - target.clip(-100, 100), imin=-100, imax=100),
    )))

    return pred, np.vstack(rows)

In [None]:
from random import shuffle

def train_batch_generator():
    """Yield (epoch, batch_A, batch_B) training batches forever.

    Consecutive pairs are taken from a shuffled private copy of list_train; the
    global list is not modified. batch_A / batch_B concatenate make_tensor() of
    the input / target files along axis 0.

    The batch size is bsize by default. A value passed with .send(n) overrides
    it for the next batch only. When fewer than `size` pairs remain, the pairs
    are reshuffled and `epoch` is incremented, so epoch counts completed passes.

    Raises:
        ValueError: if the requested batch size exceeds the number of training
            pairs (including an empty training set).
    """
    pairs = list(list_train)
    shuffle(pairs)  # shuffle the first epoch too, not only the later ones
    epoch = i = 0
    tmpsize = None
    while True:
        size = tmpsize if tmpsize else bsize
        if size > len(pairs):
            raise ValueError('batch size %d exceeds the %d training pairs'
                             % (size, len(pairs)))
        if i + size > len(pairs):
            # Not enough pairs left for a full batch: reshuffle and start a new epoch.
            shuffle(pairs)
            i = 0
            epoch += 1
        chunk = pairs[i:i + size]
        batch_A = np.concatenate([make_tensor(file_A) for file_A, _ in chunk], 0)
        batch_B = np.concatenate([make_tensor(file_B) for _, file_B in chunk], 0)
        i += size
        tmpsize = yield epoch, batch_A, batch_B

train_batch = train_batch_generator()

In [None]:
from networks_tf2 import unet_generator, patch_discriminator
# Build the generator and the discriminator from networks_tf2.
# network_G is later called as network_G(A) to produce fake B.
# network_D is later called as network_D([A, B]) on (input, target) pairs.
# layer_max_d is forwarded as-is (presumably the discriminator depth — TODO confirm).
network_G = unet_generator(isize, ch_input, ch_output)
network_D = patch_discriminator(isize, ch_input, ch_output, layer_max_d)

In [None]:
def loss_object(target, output):
    """Binary cross-entropy averaged over all elements.

    `output` is treated as a probability, not a logit: log(output) and
    log(1 - output) are taken directly, with 1e-12 added inside each log to
    avoid log(0). This presumably requires the discriminator to end in a
    sigmoid — TODO confirm. For logit outputs, the alternative would be
    tf.keras.losses.BinaryCrossentropy(from_logits=True).
    """
    log_p = tf.math.log(output + 1e-12)
    log_not_p = tf.math.log(1 - output + 1e-12)
    return -tf.math.reduce_mean(log_p * target + log_not_p * (1 - target))

def multi_loss_L(fake_B, real_B):
    """Mean absolute error between fake_B and real_B, measured three ways.

    Returns:
        (loss_L0, loss_L1, loss_L2):
            loss_L0 is the MAE on the tensors as given.
            loss_L1 is the MAE after multiplying by lim_hmi, clipping to
            [-100, 100] and dividing by 100.
            loss_L2 is the same with a [-30, 30] window divided by 30.
    """
    def _windowed_mae(target, pred, limit):
        # Clip each side to [-limit, limit], normalise by limit, then take the MAE.
        return tf.reduce_mean(tf.abs(tf.clip_by_value(target, -limit, limit)/limit
                                     - tf.clip_by_value(pred, -limit, limit)/limit))

    loss_L0 = tf.reduce_mean(tf.abs(real_B - fake_B))

    real_scaled = real_B*lim_hmi
    fake_scaled = fake_B*lim_hmi
    loss_L1 = _windowed_mae(real_scaled, fake_scaled, 100)
    loss_L2 = _windowed_mae(real_scaled, fake_scaled, 30)

    return loss_L0, loss_L1, loss_L2

def loss_Dis(output_D_real, output_D_fake):
    """Discriminator loss: the average of BCE(real -> 1) and BCE(fake -> 0)."""
    real_term = loss_object(target=tf.ones_like(output_D_real), output=output_D_real)
    fake_term = loss_object(target=tf.zeros_like(output_D_fake), output=output_D_fake)
    total = real_term + fake_term
    return total/2.

def loss_Gen(output_D_fake, fake_B, real_B):
    """Generator losses.

    Returns:
        (loss_G_fake, loss_L):
            loss_G_fake is the adversarial BCE pushing D's output on fake_B towards 1.
            loss_L is the mean absolute error between real_B and fake_B.

    Note: multi_loss_L(fake_B, real_B) offers a multi-window alternative for loss_L
    (sum its three terms). It is not called here.
    """
    adversarial = loss_object(target=tf.ones_like(output_D_fake), output=output_D_fake)
    l1_error = tf.reduce_mean(tf.abs(real_B - fake_B))
    return adversarial, l1_error

# One Adam optimizer per network (separate state): learning rate 2e-4, beta_1 = 0.5.
optimizer_D = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
optimizer_G = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
@tf.function
def train_step(real_A, real_B):
    """Run one simultaneous update of the generator and the discriminator.

    The generator is optimised on loss_G_fake + 100 * loss_L; the discriminator
    on loss_Dis. Returns (loss_D, loss_G_fake, loss_L).
    """
    with tf.GradientTape() as tape_G, tf.GradientTape() as tape_D:
        fake_B = network_G(real_A, training=True)
        # D scores (input, target) pairs: the real pair and the generated pair.
        d_on_real = network_D([real_A, real_B], training=True)
        d_on_fake = network_D([real_A, fake_B], training=True)

        loss_G_fake, loss_L = loss_Gen(d_on_fake, fake_B, real_B)
        loss_D = loss_Dis(d_on_real, d_on_fake)
        loss_G = loss_G_fake + loss_L*100.

    vars_G = network_G.trainable_variables
    vars_D = network_D.trainable_variables
    grads_G = tape_G.gradient(loss_G, vars_G)
    grads_D = tape_D.gradient(loss_D, vars_D)

    optimizer_G.apply_gradients(zip(grads_G, vars_G))
    optimizer_D.apply_gradients(zip(grads_D, vars_D))

    return loss_D, loss_G_fake, loss_L

In [None]:
import time

# Print a banner describing this training session (paths and dataset sizes).
TM = time.localtime(time.time())
t_now = time.strftime('%Y-%m-%d %H:%M:%S', TM)

_rule = '\n--------------------------------\n'
print(_rule)

for _line in ('\n%s : Now start below session!\n' % (t_now),
              'Model mode : %s' % mode,
              'Data mode : %s to %s' % (name_input, name_output),
              'Model save path: %s' % (root_model),
              'Validation snap save path: %s' % (root_validation),
              'Test result save path: %s' % (root_test),
              '# of train, validation, and test datasets : %d, %d, %d' % (nb_train, nb_validation, nb_test)):
    print(_line)

print(_rule)

def _timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))


def _save_snaps(pairs, root_dir, iter_gen):
    """Predict every (input file, target file) pair with network_G and save the results.

    For each pair, writes <name>.npy (the prediction returned by make_output) and
    <name>.png (make_output's comparison image) into <root_dir>/iter_<iter_gen:07d>/.
    The folder is created if needed. Returns the folder path.
    """
    path_dir = '%s/iter_%07d' % (root_dir, iter_gen)
    os.makedirs(path_dir, exist_ok=True)
    for file_A, file_B in pairs:
        real_A = tf.cast(make_tensor(file_A), tf.float32)
        fake_B = network_G.predict(real_A)
        fake_B, fake_B_png = make_output(fake_B, file_B)
        # The 19 characters of the input file name just before '.npy' tag the sample.
        name_save = '%s.%07d.%s' % (mode, iter_gen, os.path.basename(file_A)[-23:-4])
        np.save('%s/%s.npy' % (path_dir, name_save), fake_B)
        imsave('%s/%s.png' % (path_dir, name_save), fake_B_png)
    return path_dir


t0 = time.time()
epoch = iter_gen = 0
err_D = err_D_sum = err_D_mean = 0
err_G = err_G_sum = err_G_mean = 0
err_L = err_L_sum = err_L_mean = 0
nb_steps = 0  # train_step calls since the last display; divisor for the loss means

# iter_gen advances by bsize per step. Stop once iter_max is reached;
# `<=` here would train one extra step past iter_max.
while iter_gen < iter_max:

    epoch, train_A, train_B = next(train_batch)
    train_A = tf.cast(train_A, tf.float32)
    train_B = tf.cast(train_B, tf.float32)

    err_D, err_G, err_L = train_step(train_A, train_B)

    err_D_sum += err_D
    err_G_sum += err_G
    err_L_sum += err_L
    nb_steps += 1

    iter_prev = iter_gen
    iter_gen += bsize

    # With bsize > 1, iter_gen can jump over an exact multiple of the interval, so
    # check whether a multiple was crossed this step. This equals `% == 0` for bsize == 1.
    if iter_gen // iter_display > iter_prev // iter_display:

        # Average over the steps actually taken (equals iter_display when bsize == 1).
        err_D_mean = err_D_sum/nb_steps
        err_G_mean = err_G_sum/nb_steps
        err_L_mean = err_L_sum/nb_steps
        message1 = '[%d][%d/%d]' % (epoch, iter_gen, iter_max)
        message2 = 'LOSS_D: %5.3f LOSS_G: %5.3f LOSS_L: %5.3f T: %dsec/%dits' % (err_D_mean, err_G_mean, err_L_mean, time.time()-t0, iter_display)
        t_now = _timestamp()
        print('%s: %s %s'%(t_now, message1, message2))

        err_L_sum, err_G_sum, err_D_sum = 0, 0, 0
        nb_steps = 0
        t0 = time.time()

    if iter_gen // iter_save > iter_prev // iter_save:

        dst_model = '%s/%s.%07d'%(root_model, mode, iter_gen)
        network_G.save('%s.G.h5'%(dst_model))
        network_D.save('%s.D.h5'%(dst_model))
        message3 = 'network_G and network_D are saved under %s'%(root_model)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message3))

        path_validation = _save_snaps(list_validation, root_validation, iter_gen)
        message4 = 'Validation snaps (%d images) are saved in %s'%(nb_validation, path_validation)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message4))

        path_test = _save_snaps(list_test, root_test, iter_gen)
        message5 = 'Test snaps (%d images) are saved in %s'%(nb_test, path_test)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message5))

        # Restart the display timer so snapshot time is not counted as training time.
        t0 = time.time()