# GAN_initial_attempt

Initial experiments of GAN training

* how training loops are organized
* how to "level-up" the descriminator (given that the generator has been fine-tuned)

In [1]:
# general tools
import sys
from glob import glob

# data tools
import numpy as np
from random import shuffle

# deep learning tools
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

# graph tools
import matplotlib.pyplot as plt
%matplotlib inline

# custom tools
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/')
from namelist import *
import data_utils as du
import model_utils as mu
import train_utils as tu

In [2]:
import importlib
importlib.reload(mu)

<module 'model_utils' from '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/model_utils.py'>

**Generator**, initialized with pre-trained UNET weights

In [3]:
N = [56, 112, 224, 448] # number of channels per downsampling level
l = [1e-4, 2e-4] # G lr; D lr

sea = 'jja' # testing in the JJA season
N_input = 3 # LR T2, HR elev, LR elev
VAR = 'TMAX'


model_name = 'UNET{}_{}_{}_tune'.format(N_input, VAR, sea)
model_path = temp_dir+model_name+'.hdf'
print('Import model: {}'.format(model_name))
backbone = keras.models.load_model(model_path)
W = backbone.get_weights()

# generator
G = mu.UNET(N, (None, None, N_input))
# optimizer
opt_G = keras.optimizers.Adam(lr=l[0])

print('Compiling G')
G.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_G)
G.set_weights(W)

Import model: UNET3_TMAX_jja_tune
Compiling G


VGG-like **Descriminator**

{G_output, HR_elev, LR_elev} --> {conv56, conv56, maxpool} --> {conv112, conv112, maxpool} --> {conv224, conv224, maxpool} --> {global_pool, sigmoid}

* 1/4 size of the G

In [4]:
# !! test and move to utils
def vgg_descriminator(N, input_size):
    IN = keras.layers.Input(input_size)

    X = mu.CONV_stack(IN, N[0], kernel_size=3, stack_num=2)
    X = keras.layers.MaxPooling2D(pool_size=(2, 2))(X)

    X = mu.CONV_stack(X, N[1], kernel_size=3, stack_num=2)
    X = keras.layers.MaxPooling2D(pool_size=(2, 2))(X)

    X = mu.CONV_stack(X, N[2], kernel_size=3, stack_num=2)
    X = keras.layers.MaxPooling2D(pool_size=(2, 2))(X)

    X = keras.layers.GlobalMaxPool2D()(X)
    FLAG = keras.layers.Dense(2, activation=keras.activations.sigmoid)(X) #, 

    return keras.models.Model(inputs=[IN], outputs=[FLAG])

    
input_size = (None, None, N_input+1)
D = vgg_descriminator(N, input_size)

opt_D = keras.optimizers.Adam(lr=l[1])
print('Compiling D')
D.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt_D)

Compiling D


**GAN**

In [5]:
GAN_IN = keras.layers.Input((None, None, N_input))
G_OUT = G(GAN_IN)
D_IN = keras.layers.Concatenate()([G_OUT, GAN_IN])
D_OUT = D(D_IN)
GAN = keras.models.Model(GAN_IN, [G_OUT, D_OUT])

print('Compiling GAN')
# content_loss + 1e-3 * adversarial_loss
GAN.compile(loss=[keras.losses.mean_absolute_error, keras.losses.categorical_crossentropy], 
            loss_weights=[1.0, 1e-3],
            optimizer=opt_G)

Compiling GAN


data generators

In [6]:
file_path = BATCH_dir
trainfiles = glob(file_path+'{}_BATCH_*_TORI*_{}*.npy'.format(VAR, sea)) # e.g., TMAX_BATCH_128_VORIAUG_mam30.npy
validfiles = glob(file_path+'{}_BATCH_*_VORI*_{}*.npy'.format(VAR, sea))
# shuffle filenames
shuffle(trainfiles)
shuffle(validfiles)

L_train = len(trainfiles)

In [8]:
temp_batch = np.load(trainfiles[0], allow_pickle=True)[()]
X = temp_batch['batch']

In [18]:
np.max(X[0, ..., 0])

1.3607242122565246

In [13]:
input_flag = [False, True, False, False, True, True] # LR T2, HR elev, LR elev
output_flag = [True, False, False, False, False, False] # HR T2
inout_flag = [True, True, False, False, True, True]
labels = ['batch', 'batch'] # input and output labels

In [7]:
model_path = temp_dir+'GAN_D_{}_{}.hdf'.format(VAR, sea)

epochs = 5
batch_size = 200

y_bad = np.zeros(batch_size)
y_good = np.ones(batch_size)
dummy_good = keras.utils.to_categorical(y_good)
dummy_mix = keras.utils.to_categorical(np.concatenate((y_bad, y_good), axis=0))

input_flag = [False, True, False, False, True, True] # LR T2, HR elev, LR elev
output_flag = [True, False, False, False, False, False] # HR T2
inout_flag = [True, True, False, False, True, True]
labels = ['batch', 'batch'] # input and output labels
#@tf.function(...)

#def ...

for i in range(epochs):
    print('epoch = {}'.format(i))
    for j, name in enumerate(trainfiles):
        # import batch data
        temp_batch = np.load(name, allow_pickle=True)[()]
        X = temp_batch['batch']
        
        # D training
        D.trainable = True
        g_in = X[..., input_flag]
        g_out = G.predict([g_in]) # <-- np.array

        d_in_fake = np.concatenate((g_out, g_in), axis=-1) # channel last
        d_in_true = X[..., inout_flag]
        d_in = np.concatenate((d_in_fake, d_in_true), axis=0) # batch size doubled
        d_target = dummy_mix
        d_shuffle_ind = du.shuffle_ind(2*batch_size)
        
        d_loss = D.train_on_batch(d_in[d_shuffle_ind, ...], d_target[d_shuffle_ind, ...])
        loss_hist[epochs*i+]
#         # G training / transferring
#         D.trainable = False
#         gan_in = X[..., input_flag]
#         gan_target = [X[..., output_flag], dummy_good]
        
#         gan_loss = GAN.train_on_batch(gan_in, gan_target)
        if j%100 == 0:
            print('\t{} step loss = {}'.format(j, d_loss))
            
D.save(model_path)

epoch = 0
	0 step loss = 0.8007091283798218
	100 step loss = 0.6926652789115906
	200 step loss = 0.17872953414916992
	300 step loss = 0.06848559528589249
	400 step loss = 0.05440543219447136
	500 step loss = 0.0018735595513135195
	600 step loss = 0.015836108475923538
	700 step loss = 0.008387360721826553


KeyboardInterrupt: 

In [10]:
D.predict(d_in_true)

array([[9.99752402e-01, 5.96046448e-08],
       [9.98400152e-01, 3.57627869e-07],
       [9.99999940e-01, 0.00000000e+00],
       [9.99885857e-01, 8.94069672e-08],
       [9.99999642e-01, 0.00000000e+00],
       [9.99994278e-01, 0.00000000e+00],
       [9.99684751e-01, 1.78813934e-07],
       [9.99998450e-01, 0.00000000e+00],
       [9.99999881e-01, 0.00000000e+00],
       [9.99999285e-01, 0.00000000e+00],
       [9.99972045e-01, 0.00000000e+00],
       [9.98746514e-01, 3.57627869e-07],
       [9.99801993e-01, 2.05636024e-06],
       [9.99818087e-01, 1.19209290e-07],
       [9.99967217e-01, 0.00000000e+00],
       [9.99992788e-01, 2.98023224e-08],
       [9.99933004e-01, 0.00000000e+00],
       [9.99992847e-01, 0.00000000e+00],
       [9.99999404e-01, 0.00000000e+00],
       [9.99999702e-01, 0.00000000e+00],
       [9.99998093e-01, 1.78813934e-07],
       [9.99996305e-01, 8.94069672e-08],
       [9.99991298e-01, 0.00000000e+00],
       [9.99998212e-01, 0.00000000e+00],
       [9.999527

In [9]:
dummy_mix

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.