In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning) 
import numpy as np
import os
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
import h5py
import glob
import random
import time
import tensorflow as tf
import datetime


# Create model

In [2]:
from temporal_3d_general import *

def vae_loss(x, t_decoded):
    '''Total loss for the plain UAE.

    Averages the per-sample reconstruction loss over the batch. The two
    arguments enter a squared difference, so their order does not change the
    value (Keras passes ``(y_true, y_pred)`` when this is used via ``compile``).
    '''
    per_sample_loss = reconstruction_loss(x, t_decoded)
    return K.mean(per_sample_loss)

def reconstruction_loss(x, t_decoded):
    '''Reconstruction loss for the plain UAE.

    Flattens every non-batch axis of both tensors and returns, per sample, the
    sum of squared element-wise differences (a tensor of shape ``(batch,)``).
    '''
    flat_x = K.batch_flatten(x)
    flat_decoded = K.batch_flatten(t_decoded)
    squared_error = (flat_x - flat_decoded) ** 2
    return K.sum(squared_error, axis=-1)

# Directory that weight checkpoints are written into during training.
output_dir = 'saved_models/'

# Per-sample input shape handed to create_vae; build the model and print its layers.
input_shape = (96, 200, 24, 3)
vae_model = create_vae(input_shape)
vae_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              (None, 96, 200, 24,  0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 48, 100, 12,  1312        image[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 48, 100, 12,  64          conv3d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 48, 100, 12,  0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv3d_1 (

# Load data set

In [3]:
def load_data(name, data_dir='/data/cees/gegewen/CCSNet_publish_dataset'):
    '''Read dataset ``name`` from ``<data_dir>/<name>.hdf5`` into a NumPy array.

    Parameters
    ----------
    name : str
        Base name of the ``.hdf5`` file and of the dataset read from it.
    data_dir : str, optional
        Directory holding the ``.hdf5`` files. The default is the original
        hard-coded location, so existing one-argument calls are unchanged.

    Returns
    -------
    numpy.ndarray
        The dataset contents, fully read into memory.

    Raises
    ------
    KeyError
        If the file contains no object called ``name``. A missing dataset used
        to yield ``array(None, dtype=object)`` and fail much later.
    '''
    path = os.path.join(data_dir, f'{name}.hdf5')
    # `with` releases the file handle even if reading the dataset raises.
    with h5py.File(path, 'r') as hf_r:
        return np.array(hf_r[name])

def create_x(train_input, train_BPR, train_BYMF):
    '''Stack the model input channels along the last axis.

    Channel order in the result:
      1. the slice of ``train_input`` at index -4 on axis 3, re-expanded to a
         length-1 axis and repeated 24 times along axis -2;
      2. ``train_BPR`` min-max scaled with bounds 100 and 565;
      3. ``train_BYMF`` unchanged.
    '''
    bpr_lo, bpr_hi = 100, 565
    bpr_scaled = (train_BPR - bpr_lo) / (bpr_hi - bpr_lo)

    # Take index -4 on axis 3, put that axis back with length 1, then tile it
    # 24 times so it can be concatenated with the BPR/BYMF arrays.
    selected_slice = np.expand_dims(train_input[:, :, :, -4, :], axis=3)
    selected_tiled = np.repeat(selected_slice, 24, axis=-2)

    return np.concatenate([selected_tiled, bpr_scaled, train_BYMF], axis=-1)

In [None]:
# Training inputs; the last channel produced by create_x is BYMF.
train_x = create_x(
    load_data('train_x'),
    load_data('train_y_BPR'),
    load_data('train_y_BYMF'),
)
# Normalise the BYMF channel in place through a basic-slice view (writes go
# through to train_x): zeros become 0.9 first so they map to 0 under (v - 0.9) / 0.1.
_bymf_view = train_x[:, :, :, :, -1]
_bymf_view[_bymf_view == 0] = 0.9
_bymf_view -= 0.9
_bymf_view /= 0.1
del _bymf_view

# Training target (BDENG): zeros become 100 so they map to 0 under (v - 100) / 900.
train_y = load_data('train_y_BDENG')
_zero_mask = train_y == 0
train_y[_zero_mask] = 100
del _zero_mask
train_y = (train_y - 100) / 900

In [None]:
# Test inputs; the last channel produced by create_x is BYMF.
test_x = create_x(
    load_data('test_x'),
    load_data('test_y_BPR'),
    load_data('test_y_BYMF'),
)
# Normalise the BYMF channel in place through a basic-slice view (writes go
# through to test_x): zeros become 0.9 first so they map to 0 under (v - 0.9) / 0.1.
_bymf_view = test_x[:, :, :, :, -1]
_bymf_view[_bymf_view == 0] = 0.9
_bymf_view -= 0.9
_bymf_view /= 0.1
del _bymf_view

# Test target (BDENG): zeros become 100 so they map to 0 under (v - 100) / 900.
test_y = load_data('test_y_BDENG')
_zero_mask = test_y == 0
test_y[_zero_mask] = 100
del _zero_mask
test_y = (test_y - 100) / 900

In [None]:
# Sanity check: shapes of inputs and targets for both splits.
for _arr in (train_x, train_y, test_x, test_y):
    print(_arr.shape)

In [None]:
# Random shuffle, reproducible via the global NumPy seed.
# NOTE: re-running this cell without reloading the data shuffles the
# already-shuffled arrays again, giving a different order.
np.random.seed(0)

train_nr = train_x.shape[0]
train_shuffle_index = np.random.permutation(train_nr)
print(train_shuffle_index.shape)

test_nr = test_x.shape[0]
test_shuffle_index = np.random.permutation(test_nr)
print(test_shuffle_index.shape)

# Apply the same permutation to inputs and targets so pairs stay aligned.
train_x, train_y = train_x[train_shuffle_index], train_y[train_shuffle_index]
test_x, test_y = test_x[test_shuffle_index], test_y[test_shuffle_index]

for _name, _arr in (('train_x', train_x), ('train_y', train_y),
                    ('test_x', test_x), ('test_y', test_y)):
    print(_name + ' shape is ', _arr.shape)

# Training

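*Note: gradually reduce the learning rate from 1e-4 to 1e-7 as training progresses. Set `learning_rate` in the cell below, and set `e_start` when continuing from a later epoch.*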

In [None]:
# Training hyper-parameters.
train_nr = train_x.shape[0]
batch_size = 8
num_batch = train_nr // batch_size   # a trailing partial batch is skipped each epoch
test_nr = 8        # test samples scored after each epoch (replaces the test-set size stored in test_nr earlier)
e_start = 0        # epoch offset used in log lines and checkpoint names
epoch = 1000       # number of epochs run by the training loop
learning_rate = 1e-4

opt = Adam(lr=learning_rate)

# Target placeholders. The batch axis is left as None so batches of any size
# can be fed. With a fixed batch axis, K.function rejects any batch whose size
# differs from the declared one (e.g. an evaluation slice of 4 against a
# placeholder declared for 8).
train_target = K.placeholder(shape=(None, 96, 200, 24, 1))
test_target = K.placeholder(shape=(None, 96, 200, 24, 1))

# Training graph: reconstruction loss between model output and target.
rec_loss = vae_loss(vae_model.output, train_target)
vae_model.compile(loss=vae_loss, optimizer=opt)

total_loss = rec_loss

# Optimizer update ops; running `iterate` performs one gradient step and
# returns the batch loss.
updates = opt.get_updates(total_loss, vae_model.trainable_weights)

iterate = K.function(vae_model.inputs + [train_target], [rec_loss], updates=updates)

# Evaluation graph: same loss, no weight updates.
eval_rec_loss = vae_loss(vae_model.output, test_target)

evaluate = K.function(vae_model.inputs + [test_target], [eval_rec_loss])

# Prefix for checkpoint file names.
model_string = 'bdeng'

In [None]:
# Epoch numbers shown in logs run up to the overall last epoch, so a resumed
# run (e_start > 0) does not report e.g. "Epoch 1001/1000".
last_epoch = e_start + epoch

for e in range(e_start, last_epoch):
    # Per-batch training losses; their mean is reported at the end of the epoch.
    epoch_rec_losses = []
    for ib in range(num_batch):
        ind0 = ib * batch_size
        x_batch = train_x[ind0:ind0 + batch_size, ...]
        y_batch = train_y[ind0:ind0 + batch_size, ...]
        rec_loss_val = iterate([x_batch, y_batch])
        epoch_rec_losses.append(rec_loss_val[0])

        if ib % 100 == 0:
            print('Epoch %d/%d, Batch %d/%d, Rec Loss %f' % (e + 1, last_epoch, ib + 1, num_batch, rec_loss_val[0]))

    # Score exactly `test_nr` test samples. This matches the batch size the
    # evaluation placeholder was declared with; feeding fewer samples (the
    # previous hard-coded 4) fails when that placeholder has a fixed batch axis.
    eval_rec_loss_val = evaluate([test_x[:test_nr, ...], test_y[:test_nr, ...]])
    # Mean over the whole epoch rather than only the last batch; NaN if the
    # epoch had no full batch.
    train_rec_loss = float(np.mean(epoch_rec_losses)) if epoch_rec_losses else float('nan')
    print('Epoch %d/%d, Train Rec loss %f, Eval Rec loss %f' % (e + 1, last_epoch, train_rec_loss, eval_rec_loss_val[0]))

    # NOTE(review): 'lr%f' keeps 6 decimals, so very small learning rates
    # (e.g. 1e-7) are written as 'lr0.000000'. The format is kept unchanged
    # for compatibility with existing checkpoint names.
    if (e + 1) % 5 == 0:
        vae_model.save_weights(output_dir + model_string + '_%dtrain_lr%f_ep%d.h5' % (train_nr, learning_rate, (e + 1)))

vae_model.save_weights(output_dir + model_string + '_%dtrain_lr%f_ep%d.h5' % (train_nr, learning_rate, epoch + e_start))