In [1]:
# general tools
import sys
from glob import glob
# data tools
import time
import numpy as np
from random import shuffle

# deep learning tools
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

# custom tools
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/')
from namelist import *
import data_utils as du
import model_utils as mu
import train_utils as tu

In [2]:
# Re-import the project module `model_utils` (aliased `mu`) so that edits made
# to the file on disk are picked up without restarting the kernel.
from importlib import reload
reload(mu)

<module 'model_utils' from '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/model_utils.py'>

In [3]:
class AdaIN(keras.layers.Layer):
    '''
    Adaptive instance normalization layer.

    The layer is called on a list of three tensors, `[x, beta, gamma]`.
    `x` is standardized with its own mean and standard deviation and the
    result is multiplied by `gamma` and shifted by `beta`. The statistics are
    taken over every axis of `x` that is left after removing `axis` and then
    the first remaining axis; with the default `axis=-1` and a 4-D `x` these
    are axes 1 and 2.

    The layer creates no weights. `momentum`, `center` and `scale` are stored
    and serialized by `get_config`, but neither `build` nor `call` reads them.
    '''
    def __init__(self, axis=-1, momentum=0.99, epsilon=1e-3,
                 center=True, scale=True, **kwargs):
        super(AdaIN, self).__init__(**kwargs)
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

    def build(self, input_shape):
        '''
        Check that the normalized tensor has a known size along `axis`.

        `input_shape` is indexed as a list; only its first entry (the shape
        of `x`) is inspected. Raises ValueError when that size is None.
        '''
        feature_shape = input_shape[0]
        if feature_shape[self.axis] is None:
            message = ('Axis ' + str(self.axis) + ' of '
                       'input tensor should have a defined dimension '
                       'but the layer received an input with shape ' +
                       str(input_shape[0]) + '.')
            raise ValueError(message)

        super(AdaIN, self).build(input_shape)

    def call(self, inputs, training=None):
        '''
        Return `(x - mean) / (std + epsilon) * gamma + beta`.

        `inputs[0]` is `x`, `inputs[1]` is `beta` and `inputs[2]` is `gamma`.
        `training` is accepted but not used.
        '''
        features, beta, gamma = inputs[0], inputs[1], inputs[2]

        # start from every axis of x, then drop `axis` and the first axis left
        stat_axes = list(range(len(K.int_shape(features))))
        if self.axis is not None:
            del stat_axes[self.axis]
        del stat_axes[0]

        center_val = K.mean(features, stat_axes, keepdims=True)
        # epsilon is added to the standard deviation itself, not to the variance
        spread_val = K.std(features, stat_axes, keepdims=True) + self.epsilon

        return (features - center_val) / spread_val * gamma + beta

    def get_config(self):
        '''Return the base layer config extended with this layer's constructor arguments.'''
        merged = super(AdaIN, self).get_config()
        merged.update({
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
        })
        return merged

    def compute_output_shape(self, input_shape):
        '''The output has the shape of the first input (`x`).'''
        return input_shape[0]

def stride_conv(X, channel, pool_size, activation='relu'):
    '''
    Down-sampling block: strided Conv2D -> BatchNormalization -> activation.

    The convolution uses `pool_size` both as kernel size and as stride, with
    'valid' padding and no bias. `activation` selects ReLU ('relu') or
    LeakyReLU with alpha=0.3 ('leaky'); any other value applies no activation.
    '''
    down_conv = keras.layers.Conv2D(channel, pool_size, strides=(pool_size, pool_size), padding='valid',
                                    use_bias=False, kernel_initializer='he_normal')
    normed = keras.layers.BatchNormalization(axis=3)(down_conv(X))

    if activation == 'relu':
        return keras.layers.ReLU()(normed)
    if activation == 'leaky':
        return keras.layers.LeakyReLU(alpha=0.3)(normed)
    return normed

def DENSE_stack(X, units):
    '''
    Stacked Dense-BN-ReLU blocks.

    One block is added for each entry of `units`, in order; the entry is the
    width of that block's bias-free Dense layer.
    '''
    n_blocks = len(units)
    for idx in range(n_blocks):
        dense_out = keras.layers.Dense(units[idx], use_bias=False, kernel_initializer='he_normal')(X)
        normed = keras.layers.BatchNormalization()(dense_out)
        X = keras.layers.ReLU()(normed)
    return X

def CONV_stack(X, channel, kernel_size, stack_num, activation='relu'):
    '''
    Stacked convolution-BN-activation blocks.

    `stack_num` blocks are chained. Each block is a bias-free, 'same'-padded
    Conv2D with `channel` filters, BatchNormalization on axis 3, and then
    ReLU ('relu'), LeakyReLU with alpha=0.3 ('leaky'), or no activation for
    any other value of `activation`.
    '''
    def _activate(tensor):
        # the activation layer is only created when it is actually applied
        if activation == 'relu':
            return keras.layers.ReLU()(tensor)
        if activation == 'leaky':
            return keras.layers.LeakyReLU(alpha=0.3)(tensor)
        return tensor

    for _ in range(stack_num):
        conv_out = keras.layers.Conv2D(channel, kernel_size, padding='same', use_bias=False, kernel_initializer='he_normal')(X)
        X = _activate(keras.layers.BatchNormalization(axis=3)(conv_out))
    return X

# UNet
def UNET_left(X, channel, kernel_size=3, pool_size=2, pool=True, activation='relu'):
    '''
    U-Net encoder block: down-sample, then one conv-BN-activation block.

    With `pool` true the down-sampling is max-pooling; otherwise it is the
    strided convolution of `stride_conv`.
    '''
    if not pool:
        downsampled = stride_conv(X, channel, pool_size, activation=activation)
    else:
        downsampled = keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size))(X)
    return CONV_stack(downsampled, channel, kernel_size, stack_num=1, activation=activation)

def UNET_right(X, X_left, channel, kernel_size=3, pool_size=2, activation='relu'):
    '''
    U-Net decoder block.

    `X` is up-sampled with a transposed convolution (stride `pool_size`),
    passed through one conv-BN-activation block, concatenated with the skip
    tensor `X_left` along axis 3, and passed through a second
    conv-BN-activation block.
    '''
    upsampled = keras.layers.Conv2DTranspose(channel, kernel_size, strides=(pool_size, pool_size), padding='same')(X)
    upsampled = CONV_stack(upsampled, channel, kernel_size, stack_num=1, activation=activation)
    merged = keras.layers.concatenate([X_left, upsampled], axis=3)
    return CONV_stack(merged, channel, kernel_size, stack_num=1, activation=activation)

def _style_activation(activation):
    '''
    Build the activation layer applied after an AdaIN step.

    'leaky' gives LeakyReLU(alpha=0.3), the same layer `stride_conv` and
    `CONV_stack` use for that option. Every other value gives ReLU, which is
    what the style blocks applied unconditionally before `activation` was
    honoured, so existing 'relu' callers are unaffected.
    '''
    if activation == 'leaky':
        return keras.layers.LeakyReLU(alpha=0.3)
    return keras.layers.ReLU()


def _conv_adain(X, STY, channel, kernel_size, activation):
    '''
    Conv2D -> AdaIN modulated by the style tensor -> activation.

    Two linear Dense layers map `STY` to a per-channel shift (`b`) and scale
    (`g`). Both are reshaped to [1, 1, channel] so they broadcast over the two
    middle axes of the convolution output inside `AdaIN`.
    Layers are created in the order conv, shift Dense/Reshape, scale
    Dense/Reshape, AdaIN, activation, so auto-generated layer names keep the
    same sequence as the previously hand-unrolled code.
    '''
    X = keras.layers.Conv2D(channel, kernel_size, padding='same', use_bias=False, kernel_initializer='he_normal')(X)
    # additive noise (not applied)
    # affine transform of the style tensor: shift ...
    b = keras.layers.Dense(channel, activation=keras.activations.linear)(STY)
    b = keras.layers.Reshape([1, 1, channel])(b)
    # ... and scale
    g = keras.layers.Dense(channel, activation=keras.activations.linear)(STY)
    g = keras.layers.Reshape([1, 1, channel])(g)
    # AdaIN expects [features, shift, scale]
    X = AdaIN()([X, b, g])
    return _style_activation(activation)(X)


def UNET_in_style(X, STY, channel, kernel_size=3, pool_size=2, pool=True, activation='relu'):
    '''
    Style-modulated input block: two Conv2D -> AdaIN -> activation steps.

    `pool_size` and `pool` are accepted for signature parity with the other
    blocks but are not used (no down-sampling happens here).
    `activation`: 'leaky' selects LeakyReLU(alpha=0.3); anything else ReLU.
    '''
    X = _conv_adain(X, STY, channel, kernel_size, activation)
    X = _conv_adain(X, STY, channel, kernel_size, activation)
    return X


def UNET_out_style(X, STY, channel=2, kernel_size=3, pool_size=2, pool=True, activation='relu'):
    '''
    Style-modulated output block: a single Conv2D -> AdaIN -> activation step.

    `pool_size` and `pool` are accepted but not used.
    `activation`: 'leaky' selects LeakyReLU(alpha=0.3); anything else ReLU.
    '''
    return _conv_adain(X, STY, channel, kernel_size, activation)


def UNET_left_style(X, STY, channel, kernel_size=3, pool_size=2, pool=True, activation='relu'):
    '''
    Style-modulated encoder block: down-sample, then Conv2D -> AdaIN -> activation.

    With `pool` true the down-sampling is max-pooling; otherwise it is the
    strided convolution of `stride_conv`, which receives the same `activation`.
    '''
    # downsampling layer
    if pool:
        X = keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size))(X)
    else:
        X = stride_conv(X, channel, pool_size, activation=activation)
    return _conv_adain(X, STY, channel, kernel_size, activation)


def UNET_right_style(X, X_left, STY, channel, kernel_size=3, pool_size=2, activation='relu'):
    '''
    Style-modulated decoder block.

    `X` is up-sampled with a transposed convolution (stride `pool_size`) and
    passed through Conv2D -> AdaIN -> activation. The result is concatenated
    with the skip tensor `X_left` along axis 3 and passed through a second
    Conv2D -> AdaIN -> activation step.
    '''
    # up-sampling
    X = keras.layers.Conv2DTranspose(channel, kernel_size, strides=(pool_size, pool_size), padding='same')(X)
    X = _conv_adain(X, STY, channel, kernel_size, activation)

    # merge with the encoder skip connection
    H = keras.layers.concatenate([X_left, X], axis=3)
    H = _conv_adain(H, STY, channel, kernel_size, activation)
    return H

# Style-SRGAN

In [4]:
# Hyper-parameters shared by the model-building and training cells below.
N_input = 3 # number of channels of the generator's gridded input
input_size = (None, None, N_input) # generator input shape; the two leading dimensions are left unspecified
input_stack_num = 2 # NOTE(review): not referenced by any visible cell -- confirm it is still needed
pool = False # forwarded to mu.UNET_STYLE; in the local UNET_left* helpers False selects strided convolution over max-pooling
activation = 'relu' # activation keyword forwarded to the model builder
N = [48, 96, 192, 384] # channel numbers passed to the G and D builders; N[-1] also sets the latent/mapping sizes
l = [5e-5, 5e-5] # learning rates: l[0] for the GAN (generator) optimizer, l[1] for the D optimizer
lmd = 1e-3 # weight of the discriminator-output loss in the GAN loss
epochs = 150 # maximum number of training epochs
# early stopping settings
min_del = 0 # the validation loss has to drop by more than this to count as an improvement
max_tol = 10 # early stopping with patience: epochs without improvement before training stops

In [5]:
# Season tag; it is inserted into the batch-file glob patterns and into the
# names of the saved model and loss files.
sea = 'jja' 

In [15]:
# Width of the random latent vectors fed to the generator, and the size
# passed to mu.UNET_STYLE as `mapping_size`; both follow the last entry of N.
latent_size = N[-1]
mapping_size = N[-1]

In [16]:
# Prefix of the file names under which G, D and the loss history are saved.
key = 'TEST'

In [17]:
# IN1 = keras.layers.Input(shape=[latent_size])
# # layer 1
# STY = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(IN1)
# STY = keras.layers.ReLU()(STY)
# # layer 2
# STY = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY)
# STY = keras.layers.ReLU()(STY)
# # layer 3
# STY = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY)
# STY = keras.layers.ReLU()(STY)

# IN2 = keras.layers.Input(input_size)
# # left blocks
# X_en1 = UNET_in_style(IN2, STY, N[0], activation=activation)
# X_en2 = UNET_left_style(X_en1, STY, N[1], pool=pool, activation=activation)
# X_en3 = UNET_left_style(X_en2, STY, N[2], pool=pool, activation=activation)
# # bottom
# X4 = UNET_left_style(X_en3, STY, N[3], pool=pool, activation=activation)
# # right blocks
# X_de3 = UNET_right(X4, X_en3, N[2], activation=activation)
# X_de2 = UNET_right(X_de3, X_en2, N[1], activation=activation)
# X_de1 = UNET_right(X_de2, X_en1, N[0], activation=activation)
# # output
# OUT = CONV_stack(X_de1, 2, kernel_size=3, stack_num=1, activation=activation)
# OUT = keras.layers.Conv2D(1, 1, activation=keras.activations.linear, padding='same')(OUT)
# G_style = keras.models.Model(inputs=[IN1, IN2], outputs=[OUT])

In [18]:
# IN1 = keras.layers.Input(shape=[latent_size])
# # layer 1
# STY1 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(IN1)
# STY1 = keras.layers.ReLU()(STY1)
# # layer 2
# STY1 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY1)
# STY1 = keras.layers.ReLU()(STY1)
# # layer 3
# STY1 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY1)
# STY1 = keras.layers.ReLU()(STY1)


# IN2 = keras.layers.Input(shape=[latent_size])
# # layer 1
# STY2 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(IN2)
# STY2 = keras.layers.ReLU()(STY2)
# # layer 2
# STY2 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY2)
# STY2 = keras.layers.ReLU()(STY2)
# # layer 3
# STY2 = keras.layers.Dense(mapping_size, kernel_initializer='he_normal')(STY2)
# STY2 = keras.layers.ReLU()(STY2)


# IN3 = keras.layers.Input(input_size)
# # left blocks
# X_en1 = UNET_in_style(IN3, STY1, N[0], activation=activation)
# X_en2 = UNET_left_style(X_en1, STY1, N[1], pool=pool, activation=activation)
# X_en3 = UNET_left_style(X_en2, STY1, N[2], pool=pool, activation=activation)
# # bottom
# X4 = UNET_left_style(X_en3, STY2, N[3], pool=pool, activation=activation)
# # right blocks
# X_de3 = UNET_right_style(X4, X_en3, STY2, N[2], activation=activation)
# X_de2 = UNET_right_style(X_de3, X_en2, STY2, N[1], activation=activation)
# X_de1 = UNET_right_style(X_de2, X_en1, STY2, N[0], activation=activation)
# # output
# OUT = UNET_out_style(X_de1, STY2, activation=activation)
# OUT = keras.layers.Conv2D(1, 1, activation=keras.activations.linear, padding='same')(OUT)
# G_style = keras.models.Model(inputs=[IN1, IN2, IN3], outputs=[OUT])

In [19]:
#keras.utils.plot_model(G_style)

In [20]:
# Build the style-conditioned generator with the project's model builder.
# NOTE(review): the positional `4` and `noise=[0.2, 0.1]` are interpreted by
# mu.UNET_STYLE, which is not visible here -- confirm their meaning there.
G_style = mu.UNET_STYLE(N, input_size, 4, latent_size, mapping_size, pool=pool, activation=activation, noise=[0.2, 0.1])
# G is updated through the combined GAN model; this standalone compile only
# exists so that G_style can be evaluated on the validation generator, hence
# the zero learning rate. `learning_rate` replaces the deprecated `lr` keyword.
opt_G = keras.optimizers.Adam(learning_rate=0) # <--- compile G for validation only
print('Compiling G')
G_style.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)

Compiling G


In [21]:
G_style.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
unet_in (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
unet_left0_stack0_conv (Conv2D) (None, None, None, 4 1296        unet_in[0][0]                    
__________________________________________________________________________________________________
unet_left0_stack0_bn (BatchNorm (None, None, None, 4 192         unet_left0_stack0_conv[0][0]     
__________________________________________________________________________________________________
unet_left0_stack0_relu (ReLU)   (None, None, None, 4 0           unet_left0_stack0_bn[0][0]       
____________________________________________________________________________________________

In [22]:
# Directory of the saved model. `temp_dir` is not defined in this notebook;
# presumably it is provided by `from namelist import *` -- confirm.
model_import_dir = temp_dir
# Load the weights of a previously saved generator.
# NOTE(review): tu.dummy_loader is not visible here; its result is later
# passed to G_style.set_weights, so it is presumably a list of weight arrays.
W = tu.dummy_loader(model_import_dir+'STRANS_G_TMEAN_jja.hdf')

Import model:
/glade/work/ksha/data/Keras/BACKUP/STRANS_G_TMEAN_jja.hdf


In [23]:
# Copy the imported weights into the freshly built generator.
G_style.set_weights(W)

In [27]:
# Display the generator's layer objects for inspection.
G_style.layers

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x2b9b2a3b4fd0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x2b9b2a3b4f10>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x2b9b2a363810>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x2b9b2a3dc850>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x2b9b2a3dcdd0>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x2b9b2a3f5110>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x2b9b2a46af50>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x2b9b2a46a510>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x2b9b2a47ce50>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x2b9b2a511b10>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x2b9b2a577d50>,
 <tensorflow.python.keras.engine.input_layer.InputLayer at 0x2b9b2a362c90>,
 <tensorflow.python.keras.layers.normalization_v

In [42]:
# load weights
# model_name = 'NEO_D_TMEAN_{}_pretrain'.format(sea) # GAN_D_{}_{}
# model_path = temp_dir+model_name+'.hdf'

# print('Import model: {}'.format(model_name))
# backbone = keras.models.load_model(model_path)
# W = backbone.get_weights()

# The discriminator sees the generator inputs plus one extra channel (the real
# or generated target), hence N_input+1 channels. The shape gets its own name
# so that the generator's `input_size` is not overwritten: re-running the
# generator cell after this one would otherwise build G with the wrong
# number of input channels.
D_input_size = (None, None, N_input+1)
D = mu.vgg_descriminator(N, D_input_size)

# `learning_rate` replaces the deprecated `lr` keyword
opt_D = keras.optimizers.Adam(learning_rate=l[1])
print('Compiling D')
D.compile(loss=keras.losses.mean_squared_error, optimizer=opt_D)
#D.set_weights(W)

Compiling D


In [26]:
# Freeze the discriminator before it is wired into the combined model, so
# that training the GAN model only updates the generator. D itself was
# compiled earlier and is still trained directly with D.train_on_batch.
D.trainable = False
for layer in D.layers:
    layer.trainable = False
#
# Combined model inputs: two latent vectors and the gridded generator input
GAN_IN1 = keras.layers.Input(shape=[latent_size])
GAN_IN2 = keras.layers.Input(shape=[latent_size])
GAN_IN3 = keras.layers.Input((None, None, N_input))

G_OUT = G_style([GAN_IN1, GAN_IN2, GAN_IN3])
# D judges the generated field together with the inputs it was generated from
D_IN = keras.layers.Concatenate()([G_OUT, GAN_IN3])
D_OUT = D(D_IN)
GAN = keras.models.Model([GAN_IN1, GAN_IN2, GAN_IN3], [G_OUT, D_OUT])
# optimizer (`learning_rate` replaces the deprecated `lr` keyword)
opt_GAN = keras.optimizers.Adam(learning_rate=l[0])
print('Compiling GAN')
# content_loss + 1e-3 * adversarial_loss
# NOTE(review): the adversarial term here is binary cross-entropy while D is
# compiled on its own with mean squared error -- confirm this is intended.
GAN.compile(loss=[keras.losses.mean_squared_error, keras.losses.binary_crossentropy], 
            loss_weights=[1.0, lmd],
            optimizer=opt_GAN)

Compiling GAN


In [27]:
# ---------- Training settings ---------- #
# Macros
# Boolean masks applied to the last axis of a batch array (X[..., flag]).
# NOTE(review): input_flag selects a single channel while the generator input
# is declared with N_input channels -- confirm the flags match the batch layout.
input_flag = [False, False, False, False, False, True] # LR T2, HR elev, LR elev
output_flag = [False, False, False, False, True, False] # HR T2
inout_flag = [False, False, False, False, True, True] # output_flag and input_flag combined (discriminator "real" input)
labels = ['batch', 'batch'] # input and output labels

# Filepath
file_path = BATCH_dir
trainfiles = glob(file_path+'TMEAN_BATCH_*_TORI_*{}*.npy'.format(sea)) # e.g., TMAX_BATCH_128_VORIAUG_mam30.npy
validfiles = glob(file_path+'TMEAN_BATCH_*_VORI_*{}*.npy'.format(sea))
# Fail early when a pattern matches nothing: with no training files the epoch
# loop would run without a single update and the loss buffers would be empty.
if len(trainfiles) == 0:
    raise FileNotFoundError('No training batch files found in {} for season {}'.format(file_path, sea))
if len(validfiles) == 0:
    raise FileNotFoundError('No validation batch files found in {} for season {}'.format(file_path, sea))
# shuffle filenames
shuffle(trainfiles)
shuffle(validfiles)
#
L_train = len(trainfiles)
gen_valid = tu.grid_grid_gen_noise(validfiles, labels, input_flag, output_flag, latent_size, sampling=2)

# model names
G_name = '{}_G_TMEAN_{}'.format(key, sea)
D_name = '{}_D_TMEAN_{}'.format(key, sea)
G_path = temp_dir+G_name+'.hdf'
D_path = temp_dir+D_name+'.hdf'
hist_path = temp_dir+'{}_LOSS_TMEAN_{}.npy'.format(key, sea)

# loss backup: NaN-filled buffers, one row per training step (GAN, D) or per
# epoch (validation); entries that are never reached stay NaN
GAN_LOSS = np.full([int(epochs*L_train), 3], np.nan)
D_LOSS = np.full([int(epochs*L_train)], np.nan)
V_LOSS = np.full([epochs], np.nan)

In [28]:
# ---------- Adversarial training loop ---------- #
# Every training file gives one D update (real + fake) and one G update
# (through the combined GAN model). After each epoch G is evaluated on the
# validation generator; G and D are saved when the validation loss improves
# by more than min_del, and training stops after max_tol epochs without
# improvement.
tol = 0 # epochs since the last validation improvement
batch_size = 200 # upper bound handed to du.shuffle_ind; presumably the number of samples per batch file -- TODO confirm
train_size = 100 # number of samples used from each file per step
record = np.inf # best validation loss so far; inf so that the first epoch always counts as an improvement
for i in range(epochs):
    print('epoch = {}'.format(i))
    start_time = time.time()

    # shuffling at epoch begin
    shuffle(trainfiles)
    
    # loop over batches
    for j, name in enumerate(trainfiles):        
        
        # ----- import batch data subset ----- #
        inds = du.shuffle_ind(batch_size)[:train_size]
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load batch
        # files from a trusted source. [()] unwraps the dict stored in the file.
        temp_batch = np.load(name, allow_pickle=True)[()]
        X = temp_batch['batch'][inds, ...]
        # ------------------------------------ #
        
        # ----- D training ----- #
        # Latent space sampling
        Wf1 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        Wf2 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        # soft labels: ~0.1 for generated samples, ~0.9 for real ones, with small jitter
        dummy_bad = np.ones(train_size)*0.1 + np.random.uniform(-0.02, 0.02, train_size)
        dummy_good = np.ones(train_size)*0.9 + np.random.uniform(-0.02, 0.02, train_size)
        # get G_output (channel last)
        g_in = [Wf1, Wf2, X[..., input_flag]]
        g_out = G_style.predict(g_in) # <-- np.array
        # train on batch: D sees (target, inputs) for real and (G output, inputs) for fake
        d_in_fake = np.concatenate((g_out, X[..., input_flag]), axis=-1)
        d_in_true = X[..., inout_flag]
        d_loss1 = D.train_on_batch(d_in_true, dummy_good)
        d_loss2 = D.train_on_batch(d_in_fake, dummy_bad)
        d_loss = d_loss1 + d_loss2
        # ----------------------- #
        
        # ----- G training ----- #
        # Latent space sampling
        Wf1 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        Wf2 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        # soft labels: G is rewarded when D rates its output as "good"
        dummy_good = np.ones(train_size)*0.9 + np.random.uniform(-0.02, 0.02, train_size)
        # train on batch
        gan_in = [Wf1, Wf2, X[..., input_flag]]
        gan_target = [X[..., output_flag], dummy_good]
        gan_loss = GAN.train_on_batch(gan_in, gan_target)
        # ---------------------- #
        
        # ----- Backup training loss ----- #
        D_LOSS[i*L_train+j] = d_loss
        GAN_LOSS[i*L_train+j, :] = gan_loss
        # -------------------------------- #
        if j%50 == 0:
            print('\t{} step loss = {}'.format(j, gan_loss))
    # on epoch-end
    record_temp = G_style.evaluate_generator(gen_valid, verbose=1)

    # Backup validation loss
    V_LOSS[i] = record_temp
    # Overwrite loss info
    LOSS = {'GAN_LOSS':GAN_LOSS, 'D_LOSS':D_LOSS, 'V_LOSS':V_LOSS}
    np.save(hist_path, LOSS)

    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        tol = 0
        print('tol: {}'.format(tol))
        # save
        print('save to: {}\n\t{}'.format(G_path, D_path))
        G_style.save(G_path)
        D.save(D_path)
    else:
        print('Validation loss {} NOT improved'.format(record_temp))
        tol += 1
        print('tol: {}'.format(tol))
        if tol >= max_tol:
            print('Early stopping')
            # leave the loop instead of sys.exit(): exiting raises SystemExit
            # inside the notebook kernel, while break simply ends training
            break
        print('Pass to the next epoch')

    # reported for every epoch, improved or not
    print("--- %s seconds ---" % (time.time() - start_time))

epoch = 0
	0 step loss = [1.7764993, 1.7761368, 0.3625432]
	50 step loss = [1.0109146, 1.0097383, 1.1762408]
	100 step loss = [0.96827346, 0.96679986, 1.4736171]
	150 step loss = [0.8171611, 0.81634235, 0.8187239]
	200 step loss = [0.7714116, 0.7705031, 0.9084961]
Validation loss improved from 999 to 0.6774541662063127
tol: 0
save to: /glade/work/ksha/data/Keras/BACKUP/TEST_G_TMEAN_jja.hdf
	/glade/work/ksha/data/Keras/BACKUP/TEST_D_TMEAN_jja.hdf
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: /glade/work/ksha/data/Keras/BACKUP/TEST_G_TMEAN_jja.hdf/assets
INFO:tensorflow:Assets written to: /glade/work/ksha/data/Keras/BACKUP/TEST_D_TMEAN_jja.hdf/assets
--- 283.16722774505615 seconds ---
epoch = 1
	0 step loss = [0.6868811, 0.6861192, 0.76195407]
	50 step loss = [0.6269495, 0.62614983, 0.7996398]
	100 step loss = [0.66694015, 0.66605514, 0.88498425]
	150 step loss = [0.5704123, 0.5695648, 0.84744805]
	200 step loss = [0.5

KeyboardInterrupt: 