In [1]:
# general tools
import sys
import time
from glob import glob

# data tools
import h5py
import numpy as np
from random import shuffle
from datetime import datetime, timedelta

# deep learning tools
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

# graph tools
import matplotlib.pyplot as plt
%matplotlib inline

# custom tools
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/')
sys.path.insert(0, '/glade/u/home/ksha/WORKSPACE/DL_downscaling/')
from namelist import *
import data_utils as du
import graph_utils as gu
import verif_utils as vu
import model_utils as mu
import train_utils as tu

In [31]:
# Re-import model_utils so edits made to the module on disk are picked up without a kernel restart.
from importlib import reload
reload(mu)

<module 'model_utils' from '/glade/u/home/ksha/WORKSPACE/DL_downscaling/utils/model_utils.py'>

In [3]:
# Season tag used in the batch-file globs and model filenames below.
sea = 'jja'
# BATCH_dir / temp_dir are not defined in any visible cell — presumably from `from namelist import *`.
file_path = BATCH_dir
model_import_dir = temp_dir

In [4]:
# Boolean masks applied to the last (channel) axis of every batch array.
# Channel meanings in the trailing comments are the author's notes (not verifiable here).
input_flag = [False, True, False, False, True, True] # LR T2, HR elev, LR elev
output_flag = [True, False, False, False, False, False] # HR T2
# target channel followed by the input channels; used as the "real" discriminator input,
# matching the fake input built as concatenate(G_output, X[..., input_flag]).
inout_flag = [True, True, False, False, True, True]
labels = ['batch', 'batch'] # input and output labels

In [5]:
# Two validation sets: TSUB and VORI batch files, each pooled over 64- and 96-size patches.
file_path = BATCH_dir
validfiles1 = glob(file_path+'TMEAN_BATCH_64_TSUB_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_TSUB_*{}*.npy'.format(sea))
validfiles2 = glob(file_path+'TMEAN_BATCH_64_VORI_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_VORI_*{}*.npy'.format(sea)) 
# Generators for models without style inputs (plain UNet / SRGAN generator).
gen_valid_base1 = tu.grid_grid_gen(validfiles1, labels, input_flag, output_flag)
gen_valid_base2 = tu.grid_grid_gen(validfiles2, labels, input_flag, output_flag)
# Generators for the full style-GAN generator.
# NOTE(review): 384 duplicates `latent_size` (N[-1]), which is only defined in a later cell.
gen_valid_style1 = tu.grid_grid_gen_noise(validfiles1, labels, input_flag, output_flag, 384)
gen_valid_style2 = tu.grid_grid_gen_noise(validfiles2, labels, input_flag, output_flag, 384)

In [6]:
# Number of input channels (== number of True entries in input_flag) and output channels.
N_input = 3
N_output = 1
# Filter counts per UNet level, shallowest to bottleneck.
N = [48, 96, 192, 384]
# Spatial size left as None so the fully-convolutional models accept any patch size.
input_size = (None, None, N_input)
mapping_size = N[-1]
latent_size = N[-1]
latent_lev = 4

In [7]:
# macros
key = 'MIX'
lmd = 1e-3
pool = False

In [8]:
l = [5e-5, 5e-5] # G lr; D lr
# maximum number of training epochs (early stopping may end sooner)
epochs = 150

## UNet baseline

In [10]:
# Reference: plain UNet (no style inputs) evaluated on both validation sets.
# NOTE(review): `activation` is not defined in any visible cell — presumably from `from namelist import *`.
W = tu.dummy_loader(model_import_dir+'UNET-G3_TMEAN_jja_tune.hdf')
model = mu.UNET(N, input_size, input_stack_num=2, pool=pool, activation=activation)
# lr=0: this cell only evaluates; compile is needed for evaluate_generator's MSE.
opt_G = keras.optimizers.Adam(lr=0)
model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
model.set_weights(W)
model.evaluate_generator(gen_valid_base1, verbose=1);
model.evaluate_generator(gen_valid_base2, verbose=1);

Import model:
/glade/work/ksha/data/Keras/BACKUP/UNET-G3_TMEAN_jja_tune.hdf


## SRGAN baseline (before tuning)

In [11]:
# Same UNet architecture, loaded with the weights saved in NEO_G_TMEAN_jja.hdf.
W = tu.dummy_loader(model_import_dir+'NEO_G_TMEAN_jja.hdf')
model = mu.UNET(N, input_size, input_stack_num=2, pool=pool, activation=activation)
opt_G = keras.optimizers.Adam(lr=0)
model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
model.set_weights(W)
model.evaluate_generator(gen_valid_base1, verbose=1);
model.evaluate_generator(gen_valid_base2, verbose=1);

Import model:
/glade/work/ksha/data/Keras/BACKUP/NEO_G_TMEAN_jja.hdf


## Style-GAN (before tuning)

In [19]:
# Full style-GAN generator; the *_style generators supply its extra latent inputs.
# NOTE(review): built with noise=[0.2, 0.1] (not [False, False]) even though this is evaluation — confirm intended.
W = tu.dummy_loader(model_import_dir+'SGAN_G_TMEAN_jja.hdf')
model = mu.UNET_STYLE(N, input_size, latent_lev, latent_size, mapping_size, 
                      pool=pool, activation='leaky', noise=[0.2, 0.1])
opt_G = keras.optimizers.Adam(lr=0)
model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
model.set_weights(W)
model.evaluate_generator(gen_valid_style1, verbose=1);
model.evaluate_generator(gen_valid_style2, verbose=1);

Import model:
/glade/work/ksha/data/Keras/BACKUP/SGAN_G_TMEAN_jja.hdf


# Latent code creation

In [9]:
def mapping_network_fixer(style_model, style_gan_name, trainable=False):
    '''
    Copy the mapping-network weights of a saved style-GAN generator into `style_model`.

    Assumes layer names (and number of nodes) match: the weights of every layer of
    the full generator whose name starts with 'map' are concatenated in layer order
    and assigned to `style_model` with a single set_weights call, so that order must
    line up with `style_model.get_weights()`.

    Parameters
    ----------
    style_model : keras.Model
        Stand-alone mapping network (e.g. built by mu.STYLE_MAP).
    style_gan_name : str
        Path to the saved full-generator weights (read by tu.dummy_loader).
    trainable : bool
        If False (default), freeze `style_model` and all its layers before compiling.

    Returns
    -------
    keras.Model
        `style_model`, compiled with a zero-learning-rate Adam and loaded with the weights.
    '''
    W = tu.dummy_loader(style_gan_name)
    model = mu.UNET_STYLE(N, input_size, latent_lev, latent_size, mapping_size, 
                          pool=pool, activation='leaky', noise=[0.2, 0.1])
    model.set_weights(W)
    W_mapping_network = []
    for layer in model.layers:
        if layer.name.startswith('map'):
            W_mapping_network += layer.get_weights()

    # `~trainable` is bitwise NOT (~False == -1, ~True == -2, both truthy), so the
    # original check always froze the model. Logical `not` honours the argument.
    if not trainable:
        style_model.trainable = False
        for layer in style_model.layers:
            layer.trainable = False
    # Compile after (un)freezing so the trainable state is captured.
    opt_G = keras.optimizers.Adam(lr=0)
    style_model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
    style_model.set_weights(W_mapping_network)
    return style_model

In [39]:
# Frozen stand-alone mapping networks (Z -> W) taken from the TMEAN- and ELEV-trained style-GANs.
style_model_base = mu.STYLE_MAP(latent_lev, latent_size, mapping_size, activation='leaky')
style_model_base = mapping_network_fixer(style_model_base, 
                                         model_import_dir+'SGAN_G_TMEAN_jja.hdf', 
                                         trainable=False)

style_model_elev = mu.STYLE_MAP(latent_lev, latent_size, mapping_size, activation='leaky')
style_model_elev = mapping_network_fixer(style_model_elev, 
                                         model_import_dir+'SGAN_G_ELEV_jja.hdf', 
                                         trainable=False)

Import model:
/glade/work/ksha/data/Keras/BACKUP/SGAN_G_TMEAN_jja.hdf
Import model:
/glade/work/ksha/data/Keras/BACKUP/SGAN_G_ELEV_jja.hdf


## Testing original W

In [10]:
class gan_predict(keras.utils.Sequence):
    '''
    Keras Sequence pairing each batch file with fixed, pre-computed style vectors.

    Each item is ``(style_inputs + [X], [Y])``. X and Y are the input and output
    channels selected from one saved batch file, so the model receives the style
    vectors first and the gridded input last.

    Parameters
    ----------
    filename : list of str
        Paths to .npy batch files (pickled dicts); one file is one Keras batch.
    labels : sequence of two str
        Dict keys of the input array and the output array inside each file.
    input_flag, output_flag : list of bool
        Channel masks applied to the last axis.
    sampling : list of array-like
        Style vectors fed ahead of X.
        - A 1-D entry is repeated for every sample in the batch.
        - A 2-D entry must have at least as many rows as the batch; extra rows
          are dropped.
        Arrays with exactly one row per sample are passed through unchanged.
    '''
    def __init__(self, filename, labels, input_flag, output_flag, sampling):

        self.filename = filename
        self.labels = labels
        self.input_flag = input_flag
        self.output_flag = output_flag
        self.sampling = sampling
        self.filenum = len(self.filename)

    def __len__(self):
        return self.filenum

    def __getitem__(self, index):
        temp_name = self.filename[index]
        return self.__readfile__(temp_name, self.labels, self.input_flag, self.output_flag, self.sampling)

    def __readfile__(self, temp_name, labels, input_flag, output_flag, sampling):
        # NOTE: allow_pickle=True executes pickled objects — only load trusted batch files.
        data_temp = np.load(temp_name, allow_pickle=True)[()]
        X = data_temp[labels[0]][..., input_flag] # channel last
        Y = data_temp[labels[1]][..., output_flag]
        n_sample = len(X)
        style_inputs = []
        for style in sampling:
            style = np.asarray(style)
            if style.ndim == 1:
                # a single style vector: repeat it for every sample in this batch
                style = np.repeat(style[None, :], n_sample, axis=0)
            elif len(style) < n_sample:
                raise ValueError('style input has {} rows but batch {} has {} samples'.format(
                    len(style), temp_name, n_sample))
            else:
                # Keras requires one row per sample; drop surplus rows
                style = style[:n_sample]
            style_inputs.append(style)
        return style_inputs+[X], [Y]

In [11]:
def dscale_network_fixer(dscale_model, dscale_gan_name, trainable=False):
    '''
    Copy the UNet ('unet*') weights of a saved style-GAN generator into `dscale_model`.

    The weights of every generator layer whose name starts with 'unet' are
    concatenated in layer order and assigned with a single set_weights call.
    `dscale_model.get_weights()` must therefore follow the same layer order.

    Parameters
    ----------
    dscale_model : keras.Model
        Target network containing the same 'unet*' layers (e.g. mu.UNET_MAP).
    dscale_gan_name : str
        Path to the saved full-generator weights (read by tu.dummy_loader).
    trainable : bool
        If False (default), freeze `dscale_model` and all its layers.

    Returns
    -------
    keras.Model
        `dscale_model` with the copied weights.
    '''
    W = tu.dummy_loader(dscale_gan_name)
    model = mu.UNET_STYLE(N, input_size, latent_lev, latent_size, mapping_size, 
                          pool=pool, activation='leaky', noise=[False, False])
    model.set_weights(W)
    W_unet = []
    for layer in model.layers:
        if layer.name.startswith('unet'):
            W_unet += layer.get_weights()
    dscale_model.set_weights(W_unet)
    # `~trainable` (bitwise NOT) is truthy for both True and False; use logical `not`.
    if not trainable:
        # NOTE(review): callers compile before calling this; in tf.keras a change to
        # `trainable` after compile() generally needs a re-compile to affect training.
        dscale_model.trainable = False
        for layer in dscale_model.layers:
            layer.trainable = False
    return dscale_model
    

In [42]:
# UNet that takes style vectors W directly (no mapping network), loaded with the TMEAN style-GAN's UNet weights.
model_dscale = mu.UNET_MAP(N, input_size, latent_lev, latent_size, mapping_size, 
                           pool=pool, activation='leaky', noise=[False, False])

opt_G = keras.optimizers.Adam(lr=0)
model_dscale.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
model_dscale = dscale_network_fixer(model_dscale,
                                    model_import_dir+'SGAN_G_TMEAN_jja.hdf',
                                    trainable=False)

Import model:
/glade/work/ksha/data/Keras/BACKUP/SGAN_G_TMEAN_jja.hdf


In [43]:
# Map 5000 standard-normal latent codes through both mapping networks and average,
# giving one representative style vector per model.
# NOTE(review): no RNG seed is set, so these means vary slightly between runs.
Z = np.random.normal(0.0, 1.0, size=[5000, latent_size])
W_base = style_model_base.predict(Z)
W_elev = style_model_elev.predict(Z)

W_base_mean = np.mean(W_base, axis=0)
W_elev_mean = np.mean(W_elev, axis=0)

# Tile to 200 rows — assumes each validation batch file holds 200 samples
# (cf. batch_size = 200 in the training cell); TODO confirm.
W_base_rep = np.repeat(W_base_mean[None, :], 200, axis=0)
W_elev_rep = np.repeat(W_elev_mean[None, :], 200, axis=0)

In [45]:
# Evaluate the UNet conditioned on the fixed mean TMEAN style.
# NOTE(review): rebinds gen_valid_base1/2 (earlier tu.grid_grid_gen objects) to gan_predict sequences.
gen_valid_base1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_base_rep])
gen_valid_base2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_base_rep])
model_dscale.evaluate_generator(gen_valid_base1, verbose=1);
model_dscale.evaluate_generator(gen_valid_base2, verbose=1);



## Mixing W

In [46]:
from model_utils import *

In [47]:
def UNET_MAP_SWAP(N, input_size, latent_lev, latent_size, mapping_size, pool=True, activation='relu', noise=[0.2, 0.1]):
    '''
    UNet with two separate style-vector inputs, so different styles can drive different levels.

    STY1 conditions only the bottleneck block ('unet_bottom'); STY2 conditions every
    decoder block ('unet_right2/1/0') and the output block ('unet_out').
    `latent_lev` and `mapping_size` are unused; presumably kept to match the mu.UNET_MAP signature.
    Layer creation order determines weight order, which dscale_network_fixer relies on.

    Inputs: [STY1 (latent_size,), STY2 (latent_size,), gridded input of shape `input_size`].
    Output: one-channel map from the linear 1x1 'unet_exit' convolution.
    '''
    # style-vector inputs (already mapped W vectors; there is no mapping network in this model)
    STY1 = keras.layers.Input(shape=[latent_size], name='mapping_in1')
    STY2 = keras.layers.Input(shape=[latent_size], name='mapping_in2')
    # main UNet
    IN3 = keras.layers.Input(input_size, name='unet_in')
    X_en1 = CONV_stack(IN3, N[0], kernel_size=3, stack_num=2, activation=activation, name='unet_left0')
    X_en2 = UNET_left(X_en1, N[1], pool=pool, activation=activation, name='unet_left1')
    X_en3 = UNET_left(X_en2, N[2], pool=pool, activation=activation, name='unet_left2')
    # bottleneck uses STY1; all decoder levels use STY2
    X_de4 = UNET_left_style(X_en3, STY1, N[3], pool=pool, activation=activation, noise=noise[0], name='unet_bottom')
    X_de3 = UNET_right_style(X_de4, X_en3, STY2, N[2], activation=activation, noise=noise[1], name='unet_right2')
    X_de2 = UNET_right_style(X_de3, X_en2, STY2, N[1], activation=activation, noise=noise[1], name='unet_right1')
    X_de1 = UNET_right_style(X_de2, X_en1, STY2, N[0], activation=activation, noise=noise[1], name='unet_right0')
    # output
    OUT = UNET_out_style(X_de1, STY2, activation=activation, name='unet_out')
    OUT = keras.layers.Conv2D(1, 1, activation=keras.activations.linear, padding='same', name='unet_exit')(OUT)
    G_style = keras.models.Model(inputs=[STY1, STY2, IN3], outputs=[OUT])
    return G_style

In [48]:
# Two-style UNet loaded with the TMEAN style-GAN's UNet weights (frozen).
model_mix = UNET_MAP_SWAP(N, input_size, latent_lev, latent_size, mapping_size, 
                             pool=pool, activation='leaky', noise=[False, False])

opt_G = keras.optimizers.Adam(lr=0)
model_mix.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
model_mix = dscale_network_fixer(model_mix, model_import_dir+'SGAN_G_TMEAN_jja.hdf',
                                  trainable=False)

Import model:
/glade/work/ksha/data/Keras/BACKUP/SGAN_G_TMEAN_jja.hdf


**Replace: elevation style (`W_elev_rep`) at the bottleneck, TMEAN style (`W_base_rep`) at all other levels**

In [49]:
# STY1 (bottleneck) <- mean elevation style; STY2 (decoder + output) <- mean TMEAN style.
gen_valid_mix1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_elev_rep, W_base_rep])
gen_valid_mix2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_elev_rep, W_base_rep])
model_mix.evaluate_generator(gen_valid_mix1, verbose=1);
model_mix.evaluate_generator(gen_valid_mix2, verbose=1);



**Other failed tests**

In [50]:
# #
# W_mix = np.copy(W_base_mean)
# flag = W_mix==0
# W_mix[flag] = W_elev_mean[flag]
# #
# W_mix = 0.5*(W_elev_mean+W_base_mean)
# #
# W_mix = np.max(np.concatenate((W_base_mean[..., None], W_elev_mean[..., None]), axis=-1), axis=-1)
# #
# W_mix = np.min(np.concatenate((W_base_mean[..., None], W_elev_mean[..., None]), axis=-1), axis=-1)

# gen_valid_mix1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_mix, W_base_mean])
# gen_valid_mix2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_mix, W_base_mean])
# model_mix.evaluate_generator(gen_valid_mix1, verbose=1);
# model_mix.evaluate_generator(gen_valid_mix2, verbose=1);

**Zeroing out sparse style components (improved?)**

In [21]:
# Spot-check one component of the mean TMEAN style vector.
W_base_mean[248]

1.5447198

In [144]:
# Components whose mean is exactly zero in either style model are treated as inactive.
flag_sparse = np.logical_or(W_base_mean==0, W_elev_mean==0)
print('zero out {} vectors'.format(np.sum(flag_sparse)))
# NOTE(review): uses the first 200 individual TMEAN style samples (not the mean) for STY1.
W_mix = np.copy(W_base[:200, :])
W_mix[:, flag_sparse] = 0

gen_valid_mix1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_mix, W_base_rep])
gen_valid_mix2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_mix, W_base_rep])
model_mix.evaluate_generator(gen_valid_mix1, verbose=1);
model_mix.evaluate_generator(gen_valid_mix2, verbose=1);

zero out 271 vectors


**Joint (overlapping) region of the two latent spaces**

In [61]:
# W_mix = np.copy(W_base)
# W_mix[:, flag_sparse] = 0
# C_mix = np.cov(W_mix, rowvar=0)
# C_base = np.cov(W_base, rowvar=0)
# C_elev = np.cov(W_elev, rowvar=0)
# mu_mix = np.copy(W_base_mean)
# mu_mix[flag_sparse] = 0

In [69]:
# Covariance of the mapped style samples (rows = samples, columns = components).
C_base = np.cov(W_base, rowvar=0)
C_elev = np.cov(W_elev, rowvar=0)

# Mean TMEAN style with the inactive (sparse) components forced to zero.
flag_sparse = np.logical_or(W_base_mean==0, W_elev_mean==0)
mu_base = np.copy(W_base_mean)
mu_base[flag_sparse] = 0

In [138]:
# Fit a Gaussian to each mapped style cloud and take per-component 1st/99th percentile bands.
# Each cloud is centred on its own mean; the original drew the elevation samples around
# `mu_base`, which discarded the elevation mean entirely.
mu_elev = np.copy(W_elev_mean)
mu_elev[flag_sparse] = 0

W_base_fake = np.random.multivariate_normal(mu_base, C_base, size=100000)
W_elev_fake = np.random.multivariate_normal(mu_elev, C_elev, size=100000)

# Inactive (sparse) components are pinned to zero in every band.
base_p99 = np.percentile(W_base_fake, 99, axis=0); base_p99[flag_sparse] = 0
base_p01 = np.percentile(W_base_fake, 1, axis=0); base_p01[flag_sparse] = 0

elev_p99 = np.percentile(W_elev_fake, 99, axis=0); elev_p99[flag_sparse] = 0
elev_p01 = np.percentile(W_elev_fake, 1, axis=0); elev_p01[flag_sparse] = 0

In [139]:
# Per-component overlap of the two 1st-99th percentile bands:
# the larger of the lower bounds and the smaller of the upper bounds.
mix_p01 = np.maximum(base_p01, elev_p01)
mix_p99 = np.minimum(base_p99, elev_p99)

In [142]:
# Midpoint of the overlap band per component.
# NOTE(review): if mix_p01 > mix_p99 for a component, the bands do not overlap there and this is just the midpoint of a gap.
W_mix = 0.5*(mix_p99+mix_p01)
# Redundant safeguard: both band endpoints are already zero on flag_sparse.
W_mix[flag_sparse] = 0
W_mix_rep = np.repeat(W_mix[None, :], 200, axis=0)

In [143]:
# STY1 (bottleneck) <- overlap-midpoint style; STY2 (decoder + output) <- mean TMEAN style.
gen_valid_mix1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_mix_rep, W_base_rep])
gen_valid_mix2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_mix_rep, W_base_rep])
model_mix.evaluate_generator(gen_valid_mix1, verbose=1);
model_mix.evaluate_generator(gen_valid_mix2, verbose=1);



In [32]:
# Build mu.UNET_MAP_MIX only to inspect its layer structure (weights are not loaded).
test = mu.UNET_MAP_MIX(N, input_size, latent_lev, latent_size, mapping_size, 
                            pool=pool, activation='leaky', noise=[False, False])

In [33]:
# Inspect layer names/shapes of the freshly built UNET_MAP_MIX.
test.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
unet_in (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
unet_left0_stack0_conv (Conv2D) (None, None, None, 4 1296        unet_in[0][0]                    
__________________________________________________________________________________________________
unet_left0_stack0_bn (BatchNorm (None, None, None, 4 192         unet_left0_stack0_conv[0][0]     
__________________________________________________________________________________________________
unet_left0_stack0_leaky (LeakyR (None, None, None, 4 0           unet_left0_stack0_bn[0][0]       
____________________________________________________________________________________________

In [63]:
# Copies of all mapped style samples with the inactive (sparse) components zeroed, for pairwise scatter plots.
W_base_test = np.copy(W_base)
W_base_test[:, flag_sparse] = 0
W_elev_test = np.copy(W_elev)
W_elev_test[:, flag_sparse] = 0

In [50]:
# Sanity check: (samples, style components).
W_elev_test.shape

(5000, 384)

In [51]:
# NOTE(review): `i` is not defined in this cell — it is leftover kernel state (e.g. a loop index from
# another cell), so this fails on a fresh kernel. On a numpy bool, `~` acts as logical NOT.
~flag_sparse[i]

False

In [78]:
# L = len(flag_sparse)
# for i in range(L):
#     for j in range(L):
#         if (i > j) and np.logical_not(np.logical_or(flag_sparse[i], flag_sparse[j])):
#             fig = plt.figure(figsize=(6, 6))
#             ax = fig.gca()
#             ax.plot(W_base_test[:, i], W_base_test[:, j], 'r.')
#             ax.plot(W_elev_test[:, i], W_elev_test[:, j], 'b.')
#             #ax = gu.ax_decorate_box(ax)
#             ax.set_aspect('equal')
#             minimum = np.min((ax.get_xlim(),ax.get_ylim()))
#             maximum = np.max((ax.get_xlim(),ax.get_ylim()))
#             ax.set_xlim(minimum*1.2,maximum*1.2)
#             ax.set_ylim(minimum*1.2,maximum*1.2)
#             ax.set_title('x-axis: {}, y-axis: {}'.format(i, j), fontsize=14)

**Manifold learning**

In [28]:
from sklearn import manifold
from sklearn import decomposition

In [27]:
import GPy
#GPy.models.GPLVM(data, input_dim=input_dim, kernel=kernel)

In [None]:
# from sklearn import manifold
# from sklearn import decomposition
# W_concat = np.concatenate((W_base.T, W_elev.T), axis=-1)

# TSVD = decomposition.TruncatedSVD(n_components=1, n_iter=1000)
# TSVD.fit(W_concat)
# W_reduced = TSVD.transform(W_concat)
# W_mix = np.repeat(W_reduced[:, 0][None, :], 200, axis=0)

# gen_valid_mix1 = gan_predict(validfiles1, labels, input_flag, output_flag, [W_mix, W_base_mean])
# gen_valid_mix2 = gan_predict(validfiles2, labels, input_flag, output_flag, [W_mix, W_base_mean])
# model_mix.evaluate_generator(gen_valid_mix1, verbose=1);
# model_mix.evaluate_generator(gen_valid_mix2, verbose=1);

**NN-based style mix**

In [9]:
def mapping_network_fixer(style_model, style_gan_name, trainable=False):
    '''
    Copy the mapping-network weights of a saved style-GAN generator into `style_model`.

    NOTE(review): this redefines (and shadows) the earlier mapping_network_fixer; the
    only difference is that the generator is built with the global `activation`
    instead of 'leaky'.

    Assumes layer names (and number of nodes) match: the weights of every generator
    layer whose name starts with 'map' are concatenated in layer order and assigned
    to `style_model` with a single set_weights call.

    Parameters
    ----------
    style_model : keras.Model
        Stand-alone mapping network (e.g. built by mu.STYLE_MAP).
    style_gan_name : str
        Path to the saved full-generator weights (read by tu.dummy_loader).
    trainable : bool
        If False (default), freeze `style_model` and all its layers before compiling.

    Returns
    -------
    keras.Model
        `style_model`, compiled with a zero-learning-rate Adam and loaded with the weights.
    '''
    W = tu.dummy_loader(style_gan_name)
    model = mu.UNET_STYLE(N, input_size, latent_lev, latent_size, mapping_size, 
                          pool=pool, activation=activation, noise=[0.2, 0.1])
    model.set_weights(W)
    W_mapping_network = []
    for layer in model.layers:
        if layer.name.startswith('map'):
            W_mapping_network += layer.get_weights()

    # `~trainable` is bitwise NOT (~False == -1, ~True == -2, both truthy), so the
    # original check always froze the model. Logical `not` honours the argument.
    if not trainable:
        style_model.trainable = False
        for layer in style_model.layers:
            layer.trainable = False
    # Compile after (un)freezing so the trainable state is captured.
    opt_G = keras.optimizers.Adam(lr=0)
    style_model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
    style_model.set_weights(W_mapping_network)
    return style_model


def mix_network_fixer(dscale_model, dscale_gan_name):
    '''
    Load the UNet ('unet*') weights of a saved style-GAN generator into `dscale_model` by layer name.

    Every 'unet*' layer of `dscale_model` is frozen and the model is compiled
    (zero-learning-rate Adam). Then each generator layer whose name matches one of
    those layers has its weights copied over. Layers without the 'unet' prefix
    keep their current weights and trainable state.

    Returns the same `dscale_model` object.
    '''
    W = tu.dummy_loader(dscale_gan_name)
    source = mu.UNET_STYLE(N, input_size, latent_lev, latent_size, mapping_size, 
                           pool=pool, activation=activation, noise=[False, False])
    source.set_weights(W)

    # Freeze the UNet part before compiling so the freeze is part of the compiled model.
    unet_targets = {lyr.name: lyr for lyr in dscale_model.layers if lyr.name[:4] == 'unet'}
    for lyr in unet_targets.values():
        lyr.trainable = False
    dscale_model.compile(loss=keras.losses.mean_squared_error,
                         optimizer=keras.optimizers.Adam(lr=0))

    # Name lookup instead of a nested scan; layer names are unique within a Keras model.
    for src_layer in source.layers:
        target = unet_targets.get(src_layer.name)
        if target is not None:
            target.set_weights(src_layer.get_weights())

    return dscale_model

In [10]:
# Frozen mapping networks (Z -> W) from the STYLE_G TMEAN and ELEV generators.
# NOTE(review): STYLE_MAP is called without `activation` here (its default applies), unlike the earlier 'leaky' cell.
style_model_base = mu.STYLE_MAP(latent_lev, latent_size, mapping_size)
style_model_base = mapping_network_fixer(style_model_base, 
                                         model_import_dir+'STYLE_G_TMEAN_jja.hdf', 
                                         trainable=False)

style_model_elev = mu.STYLE_MAP(latent_lev, latent_size, mapping_size)
style_model_elev = mapping_network_fixer(style_model_elev, 
                                         model_import_dir+'STYLE_G_ELEV_jja.hdf', 
                                         trainable=False)

Import model:
/glade/work/ksha/data/Keras/BACKUP/STYLE_G_TMEAN_jja.hdf
Import model:
/glade/work/ksha/data/Keras/BACKUP/STYLE_G_ELEV_jja.hdf


In [35]:
# Mixing generator: the UNet layers are copied from the TMEAN style-GAN and frozen; other layers stay trainable.
model_mix = mu.UNET_MAP_MIX(N, input_size, latent_lev, latent_size, mapping_size, 
                            pool=pool, activation=activation, noise=[False, False])
model_mix = mix_network_fixer(model_mix, model_import_dir+'STYLE_G_TMEAN_jja.hdf')

Import model:
/glade/work/ksha/data/Keras/BACKUP/STYLE_G_TMEAN_jja.hdf


In [12]:
# Discriminator input: one generated/true target channel stacked with the N_input conditioning channels.
input_size_d = (None, None, N_input+1)
D = mu.vgg_descriminator(N, input_size_d)
opt_D = keras.optimizers.Adam(lr=l[1])
print('Compiling D')
# D is compiled while trainable (so D.train_on_batch updates it) and frozen afterwards
# so the GAN compiled later does not update it — presumably relying on Keras capturing `trainable` at compile time.
D.compile(loss=keras.losses.mean_squared_error, optimizer=opt_D)
D.trainable = False
for layer in D.layers:
    layer.trainable = False

Compiling D


In [13]:
# Combined model: inputs are two style vectors (already mapped W, not Z) and the gridded input;
# outputs are the generator prediction and the frozen discriminator's score on it.
GAN_IN1 = keras.layers.Input(shape=[latent_size])
GAN_IN2 = keras.layers.Input(shape=[latent_size])
GAN_IN3 = keras.layers.Input((None, None, N_input))

# W1 = style_model_base(GAN_IN1)
# W2 = style_model_elev(GAN_IN2)

G_OUT = model_mix([GAN_IN1, GAN_IN2, GAN_IN3])
# same channel order as the "real" D input: target first, then conditioning channels
D_IN = keras.layers.Concatenate()([G_OUT, GAN_IN3])
D_OUT = D(D_IN)
GAN = keras.models.Model([GAN_IN1, GAN_IN2, GAN_IN3], [G_OUT, D_OUT])
# optimizer
opt_GAN = keras.optimizers.Adam(lr=l[0])
print('Compiling GAN')
# content_loss + 1e-3 * adversarial_loss
GAN.compile(loss=[keras.losses.mean_squared_error, keras.losses.binary_crossentropy], 
            loss_weights=[1.0, lmd],
            optimizer=opt_GAN)

Compiling GAN


In [14]:
# early stopping settings
# early stopping settings
# minimum decrease in validation loss that counts as an improvement
min_del = 0
max_tol = 10 # early stopping with patience

In [15]:
class grid_grid_gen_noise_temp(keras.utils.Sequence):
    '''
    Keras Sequence yielding ``([W_base, W_elev, X], [Y])`` with freshly sampled style vectors.

    For every batch file, two independent standard-normal latent matrices of width
    `latent_size` are mapped through the global `style_model_base` and
    `style_model_elev` networks. Every call therefore returns different style
    inputs, so losses evaluated with this sequence are stochastic.
    '''
    def __init__(self, filename, labels, input_flag, output_flag, latent_size):

        self.filename = filename
        self.labels = labels
        self.input_flag = input_flag
        self.output_flag = output_flag
        self.latent_size = latent_size
        self.filenum = len(self.filename)

    def __len__(self):
        return self.filenum

    def __getitem__(self, index):
        return self.__readfile__(self.filename[index], self.labels, self.input_flag,
                                 self.output_flag, self.latent_size)

    def __readfile__(self, temp_name, labels, input_flag, output_flag, latent_size):
        batch = np.load(temp_name, allow_pickle=True)[()]
        x_in = batch[labels[0]][..., input_flag]  # channel last
        y_out = batch[labels[1]][..., output_flag]
        n_sample = len(y_out)
        # draw both latent matrices first, then map them
        z_base = np.random.normal(0.0, 1.0, size = [n_sample, latent_size])
        z_elev = np.random.normal(0.0, 1.0, size = [n_sample, latent_size])
        w_base = np.squeeze(style_model_base.predict(z_base))
        w_elev = np.squeeze(style_model_elev.predict(z_elev))
        return [w_base, w_elev, x_in], [y_out]

In [16]:
# key = 'STYLE'
# # Filepath
# file_path = BATCH_dir
# trainfile64 = glob(file_path+'TMEAN_BATCH_64_TORI_*{}*.npy'.format(sea))
# trainfile96 = glob(file_path+'TMEAN_BATCH_96_TORI_*{}*.npy'.format(sea))
# validfiles1 = glob(file_path+'TMEAN_BATCH_64_TSUB_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_TSUB_*{}*.npy'.format(sea))
# validfiles2 = glob(file_path+'TMEAN_BATCH_64_VORI_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_VORI_*{}*.npy'.format(sea)) 

In [17]:
# gen_train = grid_grid_gen_noise_temp(trainfile64+trainfile96, labels, input_flag, output_flag, latent_size)
# gen_valid = grid_grid_gen_noise_temp(validfiles1, labels, input_flag, output_flag, latent_size)

In [18]:
# G_name = '{}_G_TMEAN_{}_trans'.format(key, sea)
# D_name = '{}_D_TMEAN_{}_trans'.format(key, sea)
# G_path = temp_dir+G_name+'.hdf'
# callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=True),
#              keras.callbacks.ModelCheckpoint(filepath=G_path, verbose=True, monitor='val_loss', save_best_only=True)]

In [19]:
# model_mix.fit_generator(generator=gen_train, validation_data=gen_valid, callbacks=callbacks, 
#                         verbose=1, shuffle=True, max_queue_size=8, workers=8)

In [20]:
key = 'STYLE'
# Filepath
file_path = BATCH_dir
trainfile64 = glob(file_path+'TMEAN_BATCH_64_TORI_*{}*.npy'.format(sea))
trainfile96 = glob(file_path+'TMEAN_BATCH_96_TORI_*{}*.npy'.format(sea))
validfiles1 = glob(file_path+'TMEAN_BATCH_64_TSUB_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_TSUB_*{}*.npy'.format(sea))
validfiles2 = glob(file_path+'TMEAN_BATCH_64_VORI_*{}*.npy'.format(sea)) + glob(file_path+'TMEAN_BATCH_96_VORI_*{}*.npy'.format(sea))
#
# batch files per epoch: up to 160 of each patch size (see `trainfiles` below)
L_train = 320
gen_valid1 = grid_grid_gen_noise_temp(validfiles1, labels, input_flag, output_flag, latent_size)
gen_valid2 = grid_grid_gen_noise_temp(validfiles2, labels, input_flag, output_flag, latent_size)
# model names
G_name = '{}_G_TMEAN_{}_trans'.format(key, sea)
D_name = '{}_D_TMEAN_{}_trans'.format(key, sea)
G_path = temp_dir+G_name+'.hdf'
D_path = temp_dir+D_name+'.hdf'
hist_path = temp_dir+'{}_LOSS_TMEAN_{}_trans.npy'.format(key, sea)

# loss backup; NaN marks steps/epochs that never ran
GAN_LOSS = np.full([int(epochs*L_train), 3], np.nan)
D_LOSS = np.full([int(epochs*L_train)], np.nan)
V_LOSS = np.full([epochs], np.nan)
V_LOSS_TRANS = np.full([epochs], np.nan)
tol = 0
batch_size = 200  # assumed number of samples stored in each batch file
train_size = 100  # random subset of each batch file used per step
record = 999      # best monitored validation loss so far (initial sentinel)

for i in range(epochs):
    print('epoch = {}'.format(i))
    start_time = time.time()
    shuffle(trainfile64)
    shuffle(trainfile96)
    trainfiles = trainfile64[:160] + trainfile96[:160]
    # shuffling at epoch begin
    shuffle(trainfiles)
    # loop over batches
    for j, name in enumerate(trainfiles):

        # ----- import batch data subset ----- #
        inds = du.shuffle_ind(batch_size)[:train_size]
        temp_batch = np.load(name, allow_pickle=True)[()]
        X = temp_batch['batch'][inds, ...]
        # conditioning channels, shared by the D and G steps
        X_in = X[..., input_flag]
        # ------------------------------------ #

        # ----- D training ----- #
        # Latent space sampling
        Z1 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        Z2 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        W1 = np.squeeze(style_model_base.predict(Z1))
        W2 = np.squeeze(style_model_elev.predict(Z2))
        # soft labels
        dummy_bad = np.ones(train_size)*0.1 + np.random.uniform(-0.02, 0.02, train_size)
        dummy_good = np.ones(train_size)*0.9 + np.random.uniform(-0.02, 0.02, train_size)
        # get G_output (channel last)
        g_out = model_mix.predict([W1, W2, X_in]) # <-- np.array
        # train on batch; fake and real inputs share the channel order (target, conditioning)
        d_in_fake = np.concatenate((g_out, X_in), axis=-1)
        d_in_true = X[..., inout_flag]
        d_loss1 = D.train_on_batch(d_in_true, dummy_good)
        d_loss2 = D.train_on_batch(d_in_fake, dummy_bad)
        d_loss = d_loss1 + d_loss2
        # ----------------------- #

        # ----- G training ----- #
        # Latent space sampling
        Z1 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        Z2 = np.random.normal(0.0, 1.0, size = [train_size, latent_size])
        W1 = np.squeeze(style_model_base.predict(Z1))
        W2 = np.squeeze(style_model_elev.predict(Z2))
        # soft labels
        dummy_good = np.ones(train_size)*0.9 + np.random.uniform(-0.02, 0.02, train_size)
        # train on batch
        gan_in = [W1, W2, X_in]
        gan_target = [X[..., output_flag], dummy_good]
        gan_loss = GAN.train_on_batch(gan_in, gan_target)
        # ---------------------- #

        # ----- Backup training loss ----- #
        D_LOSS[i*L_train+j] = d_loss
        GAN_LOSS[i*L_train+j, :] = gan_loss
        # -------------------------------- #
        if j%50 == 0:
            print('\t{} step loss = {}'.format(j, gan_loss))
    # on epoch-end
    record_temp_trans = model_mix.evaluate_generator(gen_valid1, verbose=1)
    record_temp = model_mix.evaluate_generator(gen_valid2, verbose=1)
    # Backup validation loss
    V_LOSS[i] = record_temp
    V_LOSS_TRANS[i] = record_temp_trans
    # Overwrite loss info
    LOSS = {'GAN_LOSS':GAN_LOSS, 'D_LOSS':D_LOSS, 
            'V_LOSS':V_LOSS, 'V_LOSS_TRANS':V_LOSS_TRANS}
    np.save(hist_path, LOSS)

    # checkpointing / early stopping monitor only the gen_valid1 (TSUB) loss
    #record_temp += record_temp_trans
    record_temp = record_temp_trans

    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        tol = 0
        print('tol: {}'.format(tol))
        # save
        print('save to: {}\n\t{}'.format(G_path, D_path))
        model_mix.save(G_path)
        D.save(D_path)
    else:
        print('Validation loss {} NOT improved'.format(record_temp))
        tol += 1
        print('tol: {}'.format(tol))
        if tol >= max_tol:
            print('Early stopping')
            # `break` ends training but keeps the kernel and its state alive;
            # sys.exit() raised SystemExit inside the notebook.
            break
        print('Pass to the next epoch')

    # reported every epoch (the original `continue` skipped this on non-improving epochs)
    print("--- %s seconds ---" % (time.time() - start_time))

epoch = 0
	0 step loss = [17.530735, 0.026471388, 0.45896655]


KeyboardInterrupt: 

In [21]:
# NOTE(review): debugging leftover — runs one more GAN training step on the last `gan_in`/`gan_target`
# from the interrupted loop, changing the generator weights; remove before Restart & Run All.
GAN.fit(gan_in, gan_target)

Train on 100 samples


<tensorflow.python.keras.callbacks.History at 0x2ab5f4d69c90>

In [38]:
# Manual check of the current mixing generator on the TSUB validation set (style vectors are resampled each call).
model_mix.evaluate_generator(gen_valid1, verbose=1)

 10/119 [=>............................] - ETA: 25s - loss: 0.1084

KeyboardInterrupt: 

In [121]:
# def UNET_MAP(N, input_size, latent_lev, latent_size, mapping_size, pool=True, activation='relu', noise=[0.2, 0.1]):
#     # mapping network #1
#     STY1 = keras.layers.Input(shape=[latent_size], name='mapping_in1')
#     STY2 = keras.layers.Input(shape=[latent_size], name='mapping_in2')
#     STY3 = keras.layers.Input(shape=[latent_size], name='mapping_in3')
#     STY4 = keras.layers.Input(shape=[latent_size], name='mapping_in4')
#     STY5 = keras.layers.Input(shape=[latent_size], name='mapping_in5')
#     # main UNet
#     IN3 = keras.layers.Input(input_size, name='unet_in')
#     X_en1 = CONV_stack(IN3, N[0], kernel_size=3, stack_num=2, activation=activation, name='unet_left0')
#     X_en2 = UNET_left(X_en1, N[1], pool=pool, activation=activation, name='unet_left1')
#     X_en3 = UNET_left(X_en2, N[2], pool=pool, activation=activation, name='unet_left2')
    
#     X_de4 = UNET_left_style(X_en3, STY1, N[3], pool=pool, activation=activation, noise=noise[0], name='unet_bottom')
#     X_de3 = UNET_right_style(X_de4, X_en3, STY2, N[2], activation=activation, noise=noise[1], name='unet_right2')
#     X_de2 = UNET_right_style(X_de3, X_en2, STY3, N[1], activation=activation, noise=noise[1], name='unet_right1')
#     X_de1 = UNET_right_style(X_de2, X_en1, STY4, N[0], activation=activation, noise=noise[1], name='unet_right0')
#     # output
#     OUT = UNET_out_style(X_de1, STY5, activation=activation, name='unet_out')
#     OUT = keras.layers.Conv2D(1, 1, activation=keras.activations.linear, padding='same', name='unet_exit')(OUT)
#     G_style = keras.models.Model(inputs=[STY1, STY2, STY3, STY4, STY5, IN3], outputs=[OUT])
#     return G_style


# IN1 = keras.layers.Input(shape=[latent_size], name='z_in1')
# IN2 = keras.layers.Input(shape=[latent_size], name='z_in2')
# IN3 = keras.layers.Input(input_size, name='unet_in')

# style_model_base = mu.STYLE_MAP(latent_lev, latent_size, mapping_size)
# opt_G = keras.optimizers.Adam(lr=0)
# style_model_base.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
# style_model_base = mapping_network_fixer(style_model_base, model_import_dir+'STYLE_G_TMEAN_jja.hdf', trainable=False)

# style_model_elev = mu.STYLE_MAP(latent_lev, latent_size, mapping_size)
# opt_G = keras.optimizers.Adam(lr=0)
# style_model_elev.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
# style_model_elev = mapping_network_fixer(style_model_elev, model_import_dir+'STYLE_G_ELEV_jja.hdf', trainable=False)

# W1 = style_model_base([IN1])
# W2 = style_model_base([IN2])
# W_mix = keras.layers.Concatenate()([W1, W2])
# # ----- manifold learning layers ----- #
# W_map1 = keras.layers.Dense(latent_size)(W_mix)
# W_map2 = keras.layers.Dense(latent_size)(W_mix)
# W_map3 = keras.layers.Dense(latent_size)(W_mix)
# W_map4 = keras.layers.Dense(latent_size)(W_mix)
# W_map5 = keras.layers.Dense(latent_size)(W_mix)
# # ------------------------------------- #
# unet_map = UNET_MAP(N, input_size, latent_lev, latent_size, mapping_size, 
#                     pool=pool, activation=activation, noise=[False, False])

# opt_G = keras.optimizers.Adam(lr=0)
# unet_map.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
# unet_map = dscale_network_fixer(unet_map, model_import_dir+'STYLE_G_TMEAN_jja.hdf', trainable=False)
# OUT = unet_map([W_map1, W_map2, W_map3, W_map4, W_map5, IN3])

# model = keras.models.Model([IN1, IN2, IN3], [OUT])

In [95]:
# NOTE(review): hidden state — the `model` shown in the summary below was built by code that is now
# commented out; on a fresh kernel `model` is the style-GAN from the "Style-GAN (before tuning)" cell.
opt = keras.optimizers.Adam(lr=1e-5)
model.compile(loss=keras.losses.mean_squared_error, optimizer=opt)

In [96]:
# Inspect whichever model is currently bound to `model` (depends on execution order).
model.summary()

Model: "model_18"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
z_in1 (InputLayer)              [(None, 384)]        0                                            
__________________________________________________________________________________________________
z_in2 (InputLayer)              [(None, 384)]        0                                            
__________________________________________________________________________________________________
model_12 (Model)                (None, 384)          739200      z_in1[0][0]                      
                                                                 z_in2[0][0]                      
__________________________________________________________________________________________________
concatenate_36 (Concatenate)    (None, 768)          0           model_12[1][0]            

In [60]:
# Inspect the layer structure of the current mixing generator.
model_mix.summary()

Model: "model_25"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
unet_in (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
unet_left0_stack0_conv (Conv2D) (None, None, None, 4 1296        unet_in[0][0]                    
__________________________________________________________________________________________________
unet_left0_stack0_bn (BatchNorm (None, None, None, 4 192         unet_left0_stack0_conv[0][0]     
__________________________________________________________________________________________________
unet_left0_stack0_relu (ReLU)   (None, None, None, 4 0           unet_left0_stack0_bn[0][0]       
___________________________________________________________________________________________