In [None]:
# Index into TensorFlow's list of physical GPUs; the next cell makes only this device visible.
gpu_id = 2

In [None]:
import tensorflow as tf
# Restrict TensorFlow to the single physical GPU selected by gpu_id.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not -len(gpus) <= gpu_id < len(gpus):
    # Same exception type as plain indexing, but with a message that says what to change.
    raise IndexError('gpu_id=%d is out of range: TensorFlow detected %d GPU(s)' % (gpu_id, len(gpus)))
tf.config.experimental.set_visible_devices(gpus[gpu_id], 'GPU')
# Let the process grow its GPU memory allocation on demand instead of reserving it all up front.
tf.config.experimental.set_memory_growth(gpus[gpu_id], True)

In [None]:
# --- Training schedule; the training loop advances its counter by bsize per step ---
# Report mean losses / save models and snapshots / stop at these counter values.
iter_display, iter_save, iter_max = 2000, 10000, 500000

# --- Data pair: CSV column name and scaling branch for each side ---
# is_aia_* = True selects the log2 scaling branch of make_tensor / make_result,
# False selects the linear +-lim_hmi branch.
name_input, is_aia_input = 'hmi_M_720s', False
name_output, is_aia_output = 'aia_304', True

# --- CSV files listing the .npy paths of each split ---
csv_train = '/home/park_e/datasets/solar_generation_train.csv'
csv_validation = '/home/park_e/datasets/solar_generation_validation.csv'
csv_test = '/home/park_e/datasets/solar_generation_test.csv'

# Root directory under which models, validation and test results are written.
root_save = '/userhome/park_e/solar_generation_tf2'

# --- Image geometry ---
# isize: side length (pixels) of the square images.
# rsun: radius (pixels); pixels farther than 0.99*rsun from the centre are masked to -1.
isize, rsun = 1024, 392
ch_input = ch_output = 1

# --- Batching ---
bsize = 1
# When True, every training batch gets a random translation (see shake_tensor).
do_shake = True

# Passed to patch_discriminator and embedded in the experiment name.
layer_max_d = 3

# Clipping limits: inputs of the linear branch are clipped to +-lim_hmi,
# inputs of the log branch to 2**lim_aia before taking log2.
lim_hmi, lim_aia = 1000., 12.

In [None]:
import os

# Tags identifying this experiment: network variant and "<input>.<output>" data pair.
mode1 = 'unet_cgan%d' % layer_max_d
mode2 = '.'.join((name_input, name_output))

# Every artefact of one experiment lives under <root_save>/<mode1>/<mode2>/.
root_experiment = '/'.join((root_save, mode1, mode2))
root_model = root_experiment + '/model'
root_validation = root_experiment + '/validation'
root_test = root_experiment + '/test'

# Create the three output directories; existing ones are left untouched.
for root_dir in (root_model, root_validation, root_test):
    os.makedirs(root_dir, exist_ok=True)

In [None]:
from pandas import read_csv

# Training split: pair each input path with the target path from the same CSV row.
total_train = read_csv(csv_train)
list_train_input, list_train_output = (total_train[column] for column in (name_input, name_output))
assert len(list_train_input) == len(list_train_output)
list_train = [(path_in, path_out) for path_in, path_out in zip(list_train_input, list_train_output)]
nb_train = len(list_train)

# Validation and test splits only need the input paths (targets are generated).
total_validation = read_csv(csv_validation)
list_validation, nb_validation = total_validation[name_input], len(total_validation)

total_test = read_csv(csv_test)
list_test, nb_test = total_test[name_input], len(total_test)

# Quick sanity check of the three split sizes.
print(nb_train, nb_validation, nb_test)

In [None]:
import numpy as np
from utils import rescale, bytescale
from imageio import imsave
from random import shuffle

# Distance of every pixel from the image centre, obtained by broadcasting a column
# of row indices against a row of column indices.
_center = isize/2.
X = np.arange(isize).reshape(-1, 1)
Y = np.arange(isize).reshape(1, -1)
XY = np.sqrt(np.square(X - _center) + np.square(Y - _center))
# Index tuple (as returned by np.where) of the pixels lying beyond 99% of rsun.
cfilter = np.where(XY > 0.99*rsun)

class make_tensor():
    """Callable that loads a .npy image file and turns it into a normalised network tensor.

    Parameters
    ----------
    is_aia : bool
        True selects the log2 scaling branch, False the linear +-lim_hmi branch.
    ch : int
        Number of channels of the returned tensor.
    """
    def __init__(self, is_aia, ch):
        self.is_aia = is_aia
        self.ch = ch

    def __call__(self, x):
        """Load the file at path `x` and return an array of shape (1, isize, isize, ch).

        Pixels selected by the module-level `cfilter` index are set to -1 in both branches.
        """
        x = np.load(x)
        if self.is_aia :
            # After clipping to [1, 2**lim_aia] the log2 values lie exactly in [0, lim_aia].
            x = np.log2((x+1.).clip(1., 2.**lim_aia))
            # Rescale with imax=lim_aia (was a hard-coded 14.) so that the forward scaling
            # matches the clipping limit and the lim_aia-based inverse used for results.
            # NOTE(review): assumes utils.rescale maps [imin, imax] linearly onto [omin, omax].
            x = rescale(x, imin=0., imax=lim_aia, omin=-1, omax=1)
        else :
            x = x.clip(-lim_hmi, lim_hmi) / lim_hmi
        x[cfilter] = -1
        # reshape() instead of assigning to .shape, which NumPy discourages.
        return x.reshape(1, isize, isize, self.ch)

class make_result():
    """Callable that converts a generator output back into a data array and a PNG-ready image.

    Parameters
    ----------
    is_aia : bool
        True selects the inverse of the log2 scaling, False the linear lim_hmi scaling.
    ch : int
        Number of channels of the network output.
    """
    def __init__(self, is_aia, ch):
        self.is_aia = is_aia
        self.ch = ch

    def __call__(self, x):
        """Return (result_npy, result_png) computed from the network output `x`.

        `x` is viewed as (isize, isize) when ch == 1, otherwise as (isize, isize, ch).
        The caller's array keeps its own shape: the reshape is done on a new view
        instead of assigning to `x.shape`.
        """
        shape = (isize, isize) if self.ch == 1 else (isize, isize, self.ch)
        x = np.reshape(x, shape)
        if self.is_aia :
            # Map [-1, 1] back to log2 units in [0, lim_aia], then undo log2(x + 1).
            result_npy = rescale(x, imin=-1., imax=1., omin=0., omax=lim_aia)
            result_npy = 2**result_npy - 1
            # The image is made from the normalised values, not from the restored data.
            result_png = bytescale(x, imin=-1., imax=1.)
        else :
            result_npy = x*lim_hmi
            # The image shows only the +-100 range of the restored data.
            result_png = bytescale(result_npy.clip(-100, 100), imin=-100, imax=100)
        return result_npy, result_png

def shake_tensor(batch_A, batch_B):
    """Apply one shared random translation to a pair of 4-D batches.

    Both batches are padded with -1 on axes 1 and 2 by `pad = isize // 64 - 1`
    pixels, then the same randomly placed isize x isize window is cut from each,
    so input and target stay aligned.

    Returns the shifted (batch_A, batch_B).
    """
    pad = isize // 64 - 1
    n_offsets = 2*pad + 1
    # Draw the row offset first, then the column offset.
    x = np.random.randint(n_offsets)
    y = np.random.randint(n_offsets)

    pad_width = ((0, 0), (pad, pad), (pad, pad), (0, 0))
    rows, cols = slice(x, x + isize), slice(y, y + isize)

    shaken = []
    for batch in (batch_A, batch_B):
        padded = np.pad(batch, pad_width, 'constant', constant_values=-1)
        shaken.append(padded[:, rows, cols, :])
    return shaken[0], shaken[1]

# Converters used by the batch generator and the snapshot loop:
#   make_tensor_input  : input file path  -> normalised tensor (scaling chosen by is_aia_input)
#   make_tensor_output : target file path -> normalised tensor (scaling chosen by is_aia_output)
#   make_result_output : generator output -> (restored data array, image array for the PNG)
make_tensor_input = make_tensor(is_aia_input, ch_input)
make_tensor_output = make_tensor(is_aia_output, ch_output)
make_result_output = make_result(is_aia_output, ch_output)    

In [None]:
def train_batch_generator():
    """Endless generator of training batches.

    Yields (epoch, batch_A, batch_B). The batch size is `bsize` unless the
    consumer passes another size through `.send(size)`, which then applies to
    the next batch only. When the remaining samples cannot fill a batch, the
    list is shuffled in place, reading restarts from the beginning and the
    epoch counter is incremented (the leftover samples are skipped).
    """
    epoch, cursor = 0, 0
    requested = None
    while True:
        size = requested or bsize
        stop = cursor + size
        if stop > nb_train :
            shuffle(list_train)
            epoch += 1
            cursor, stop = 0, size
        indices = range(cursor, stop)
        # Load all inputs first, then all targets.
        tensors_A = [make_tensor_input(list_train[k][0]) for k in indices]
        tensors_B = [make_tensor_output(list_train[k][1]) for k in indices]
        batch_A = np.concatenate(tensors_A, axis=0)
        batch_B = np.concatenate(tensors_B, axis=0)
        if do_shake :
            batch_A, batch_B = shake_tensor(batch_A, batch_B)
        cursor = stop
        requested = yield epoch, batch_A, batch_B

# Single generator instance consumed by the training loop.
train_batch = train_batch_generator()

In [None]:
from networks_tf2 import unet_generator, patch_discriminator
# Build the generator and the discriminator from the project-local networks_tf2 module.
# NOTE(review): the argument meanings (image size, input/output channel counts,
# discriminator depth) are inferred from the variable names - confirm against networks_tf2.
network_G = unet_generator(isize, ch_input, ch_output)
network_D = patch_discriminator(isize, ch_input, ch_output, layer_max_d)

In [None]:
def loss_object(target, output):
    """Mean binary cross-entropy between `target` and `output`.

    `output` goes straight into a logarithm, so it is treated as a probability
    (presumably network_D ends with a sigmoid - TODO confirm). A small epsilon
    keeps both logarithms finite at 0 and 1.
    """
    eps = 1e-12
    log_likelihood = tf.math.log(output + eps)*target + tf.math.log(1 - output + eps)*(1 - target)
    return -tf.math.reduce_mean(log_likelihood)
# Alternative kept for reference; note that it expects logits, not probabilities:
#loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def loss_Dis(output_D_real, output_D_fake):
    """Discriminator loss: average of the cross-entropy terms for real and generated pairs.

    Real pairs are scored against an all-ones target, generated pairs against all zeros.
    """
    target_real = tf.ones_like(output_D_real)
    target_fake = tf.zeros_like(output_D_fake)
    loss_D_real = loss_object(target=target_real, output=output_D_real)
    loss_D_fake = loss_object(target=target_fake, output=output_D_fake)
    return (loss_D_real + loss_D_fake)/2.

def loss_Gen(output_D_fake, fake_B, real_B):
    """Generator loss terms, returned separately as (adversarial term, L1 term).

    The adversarial term scores the discriminator output on generated pairs
    against an all-ones target; the L1 term is the mean absolute difference
    between target and generated images. Weighting is left to the caller.
    """
    target_real = tf.ones_like(output_D_fake)
    loss_G_fake = loss_object(target=target_real, output=output_D_fake)
    abs_error = tf.abs(real_B - fake_B)
    loss_L = tf.reduce_mean(abs_error)
    return loss_G_fake, loss_L

# One independent Adam optimizer per network, both with the same hyper-parameters.
optimizer_D = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
optimizer_G = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
@tf.function
def train_step(real_A, real_B):
    """Run one simultaneous generator/discriminator update on a single batch.

    Parameters: real_A is the input batch, real_B the matching target batch.
    Returns (loss_D, loss_G_fake, loss_L): the discriminator loss, the
    adversarial generator term and the unweighted L1 term.
    """
    weight_L = 100.  # weight of the L1 term in the total generator loss

    # Two tapes record the same forward pass so each network gets its own gradients.
    with tf.GradientTape() as tape_G, tf.GradientTape() as tape_D:
        fake_B = network_G(real_A, training=True)
        output_D_real = network_D([real_A, real_B], training=True)
        output_D_fake = network_D([real_A, fake_B], training=True)

        loss_G_fake, loss_L = loss_Gen(output_D_fake, fake_B, real_B)
        loss_D = loss_Dis(output_D_real, output_D_fake)
        loss_G = loss_G_fake + loss_L*weight_L

    # Both gradients are computed before either network is updated.
    variables_G = network_G.trainable_variables
    grads_G = tape_G.gradient(loss_G, variables_G)
    variables_D = network_D.trainable_variables
    grads_D = tape_D.gradient(loss_D, variables_D)

    optimizer_G.apply_gradients(zip(grads_G, network_G.trainable_variables))
    optimizer_D.apply_gradients(zip(grads_D, network_D.trainable_variables))

    return loss_D, loss_G_fake, loss_L

In [None]:
import time

# Print a banner summarising the session before training starts.
TM = time.localtime(time.time())
t_now = time.strftime('%Y-%m-%d %H:%M:%S', TM)

_separator = '\n--------------------------------\n'
_summary = (
    '\n%s : Now start below session!\n'%(t_now),
    'Model mode : %s'%mode1,
    'Data mode : %s'%mode2,
    'Model save path: %s'%(root_model),
    'Validation snap save path: %s'%(root_validation),
    'Test result save path: %s'%(root_test),
    '# of train, validation, and test datasets : %d, %d, %d'%(nb_train, nb_validation, nb_test),
)

print(_separator)

for _line in _summary:
    print(_line)

print(_separator)

In [None]:
def _timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())


def _passed_multiple(iteration, interval):
    """Return True when the latest batch moved `iteration` onto or past a multiple of `interval`.

    Same as `iteration % interval == 0` whenever bsize divides interval, but it
    still fires when bsize does not divide interval.
    """
    return iteration // interval > (iteration - bsize) // interval


def _generate_and_save(list_files, path_save, iteration):
    """Run network_G on every input file and write the .npy and .png results into path_save."""
    os.makedirs(path_save, exist_ok=True)
    for file_A in list_files :
        real_A = tf.cast(make_tensor_input(file_A), tf.float32)
        fake_B = network_G.predict(real_A)
        fake_B, fake_B_png = make_result_output(fake_B)
        # The output name reuses characters [-23:-4] of the input file name
        # (presumably its date/time tag - TODO confirm against the dataset naming).
        name_save = '%s.%s.%07d.%s'%(mode1, mode2, iteration, file_A.split('/')[-1][-23:-4])
        np.save('%s/%s.npy'%(path_save, name_save), fake_B)
        imsave('%s/%s.png'%(path_save, name_save), fake_B_png)


t0 = time.time()
epoch = iter_gen = 0
nb_steps = 0  # optimisation steps accumulated since the last loss report
err_D = err_D_sum = err_D_mean = 0
err_G = err_G_sum = err_G_mean = 0
err_L = err_L_sum = err_L_mean = 0

# Strict '<': stop once iter_max is reached; '<=' ran one extra step whose
# result was never reported or saved.
while iter_gen < iter_max :

    epoch, train_A, train_B = next(train_batch)
    train_A = tf.cast(train_A, tf.float32)
    train_B = tf.cast(train_B, tf.float32)

    err_D, err_G, err_L = train_step(train_A, train_B)

    err_D_sum += err_D
    err_G_sum += err_G
    err_L_sum += err_L

    nb_steps += 1
    iter_gen += bsize

    if _passed_multiple(iter_gen, iter_display):

        # Average over the steps actually accumulated (equals iter_display only when bsize == 1).
        err_D_mean = err_D_sum/nb_steps
        err_G_mean = err_G_sum/nb_steps
        err_L_mean = err_L_sum/nb_steps
        message1 = '[%d][%d/%d]' % (epoch, iter_gen, iter_max)
        # The elapsed time covers only the current reporting window, so it is paired with
        # the count processed in that window rather than with the cumulative count.
        message2 = 'LOSS_D: %5.3f LOSS_G: %5.3f LOSS_L: %5.3f T: %dsec/%dits' % (err_D_mean, err_G_mean, err_L_mean, time.time()-t0, nb_steps*bsize)
        t_now = _timestamp()
        print('%s: %s %s'%(t_now, message1, message2))

        err_L_sum, err_G_sum, err_D_sum = 0, 0, 0
        nb_steps = 0
        t0 = time.time()

    if _passed_multiple(iter_gen, iter_save):

        # Checkpoint both networks.
        dst_model = '%s/%s.%s.%07d'%(root_model, mode1, mode2, iter_gen)
        network_G.save('%s.G.h5'%(dst_model))
        network_D.save('%s.D.h5'%(dst_model))
        message3 = 'network_G and network_D are saved under %s'%(root_model)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message3))

        # Generated snapshots for the validation split.
        path_validation = '%s/iter_%07d'%(root_validation, iter_gen)
        _generate_and_save(list_validation, path_validation, iter_gen)
        message4 = 'Validation snaps (%d images) are saved in %s'%(nb_validation, path_validation)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message4))

        # Generated snapshots for the test split.
        path_test = '%s/iter_%07d'%(root_test, iter_gen)
        _generate_and_save(list_test, path_test, iter_gen)
        message5 = 'Test snaps (%d images) are saved in %s'%(nb_test, path_test)
        t_now = _timestamp()
        print('%s: %s'%(t_now, message5))

        # Exclude the time spent saving from the next timing report.
        t0 = time.time()