In [None]:
import tensorflow as tf
import xarray as xr
import numpy as np
import os
import sys
from scipy.ndimage import gaussian_filter
from tensorflow.keras.optimizers import *
import glob
import time
import keras.backend as K

In [None]:
# Local-time timestamp (minute resolution); used below to tag this run's log directory.
t = time.strftime("%Y_%m_%d_%H_%M", time.localtime())

def scheduler(epoch):
  """Learning-rate schedule for `tf.keras.callbacks.LearningRateScheduler`.

  The rate is held at 1e-4 for the first 6 epochs (indices 0-5) and then
  decays exponentially by a factor of exp(-0.1) per epoch.

  The decay is anchored at the end of the hold period so the schedule is
  continuous at the hand-off.  Anchoring it at epoch 10 while the hold
  ended at epoch 6 made the rate jump *up* to ~1.49e-4 at epoch 6 before it
  started to fall.

  Args:
    epoch: Zero-based epoch index supplied by the callback.

  Returns:
    The learning rate for `epoch` as a plain Python float.
  """
  base_lr = 0.0001
  hold_epochs = 6
  decay_per_epoch = 0.1
  if epoch < hold_epochs:
    return base_lr
  # Return a Python float rather than a tf.Tensor: the callback documents a
  # float return value, and a float needs no TensorFlow op to evaluate.
  return base_lr * float(np.exp(decay_per_epoch * (hold_epochs - epoch)))

def ref_only_loss(y_true, y_pred, thresh):
    """Mean squared error restricted to points where `y_true` exceeds `thresh`.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predictions with the same shape as `y_true`.
        thresh: Scalar threshold; only elements with `y_true > thresh`
            contribute to the loss.

    Returns:
        Scalar tensor holding the mean of `(y_true - y_pred) ** 2` over the
        selected elements, or 0 when no element of `y_true` exceeds `thresh`.
    """
    mask = tf.math.greater(y_true, thresh)
    y_pred2 = tf.boolean_mask(y_pred, mask)
    # Cast the targets to the prediction dtype, as the Keras MSE loss does.
    y_true2 = tf.cast(tf.boolean_mask(y_true, mask), y_pred2.dtype)

    sq_err_sum = tf.reduce_sum(tf.math.squared_difference(y_true2, y_pred2))
    n_selected = tf.cast(tf.size(y_pred2), y_pred2.dtype)
    # If a batch has nothing above the threshold, a plain mean is 0/0 = NaN,
    # which would propagate into the weights.  divide_no_nan returns 0 for
    # that case while keeping the result connected to y_pred for gradients.
    return tf.math.divide_no_nan(sq_err_sum, n_selected)

def refl_loss(thresh):
    """Build a two-argument loss that applies `ref_only_loss` at a fixed threshold.

    Args:
        thresh: Threshold forwarded to `ref_only_loss` on every call.

    Returns:
        A function `ref(y_true, y_pred)` suitable for `model.compile(loss=...)`.
    """
    def ref(y_true, y_pred):
        # `thresh` is captured from the enclosing call.
        return ref_only_loss(y_true=y_true, y_pred=y_pred, thresh=thresh)

    return ref

#
# Mean Absolute Error metric
#
def mae(y_true, y_pred):
    """Mean absolute error between `y_pred` and `y_true`, averaged over the last axis."""
    abs_err = K.abs(y_pred - y_true)
    return K.mean(abs_err, axis=-1)

In [None]:
############################
# set up the run information
############################
# Root directory under which each run's outputs are written.
output_root_directory = '/glade/work/hardt/models'
# Name of this run; joined onto output_root_directory to form the run's paths.
model_run_name        = 'unet_v1p0'
# Model constructor used for this run (project-local module).
from unet_model_v1p0 import unet
#--------------------------
# File-name template for the saved model; '{}' is filled with a timestamp.
output_model_name     = 'trained_model_{}.h5'
# Log directory for this run.
# NOTE(review): the last path component is the formatted model file name, so
# the directory name itself ends in '.h5' -- confirm this is intended.
log_dir = os.path.join(output_root_directory, model_run_name, 'logs', 'fit',output_model_name.format(t))
# Input files: the `refl` variable is read from feature_data and the `maxW`
# variable from label_data when the data is loaded.
feature_data     = '/glade/work/hardt/ds612/2000-2013_June-Sept_scale_REFL.nc'
label_data       = '/glade/work/hardt/ds612/2000-2013_June-Sept_scale_maxW.nc'
# Alternative input files, currently disabled.
#feature_data     = '/glade/work/hardt/ds612/2000-2013_June-Sept_CTRLradrefl_REFL.nc'
#label_data       = '/glade/work/hardt/ds612/2000-2013_June-Sept_CTRL3D_maxW.nc'
############################

In [None]:
# Create the run's output and log directories.  exist_ok=True makes the cell
# safe to re-run and avoids the race between checking for a directory and
# creating it.
output_path = os.path.join(output_root_directory, model_run_name)
os.makedirs(output_path, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

In [None]:
#
# load the data
#
# `.values` reads each variable fully into memory as a NumPy array, so the
# files can be closed as soon as the arrays have been extracted.  The context
# manager guarantees the file handles are released even if a read fails.
with xr.open_dataset(feature_data) as fds, xr.open_dataset(label_data) as lds:
    feature = fds.refl.values
    label = lds.maxW.values

In [None]:
#
# set up the data sets
#
BATCH_SIZE = 32
TRAIN_END = 6112                 # samples [0, TRAIN_END) are used for training
VAL_END = 7649                   # samples [TRAIN_END, VAL_END) are used for validation
SHUFFLE_BUFFER_SIZE = TRAIN_END  # buffer spans the whole training split

# The splits are contiguous half-open slices.  Validation previously started
# at TRAIN_END + 1, which silently dropped the sample at index TRAIN_END.
# A trailing channel axis is added so each sample is (dim1, dim2, 1).
train_dataset = tf.data.Dataset.from_tensor_slices(
    (feature[:TRAIN_END, :, :, np.newaxis], label[:TRAIN_END, :, :, np.newaxis])
)
val_dataset = tf.data.Dataset.from_tensor_slices(
    (feature[TRAIN_END:VAL_END, :, :, np.newaxis], label[TRAIN_END:VAL_END, :, :, np.newaxis])
)
print(train_dataset)

# Only the training data is shuffled (re-shuffled every epoch); both splits
# drop a final partial batch so every batch has exactly BATCH_SIZE samples.
train_dataset = train_dataset.shuffle(
    SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)
val_dataset = val_dataset.batch(BATCH_SIZE, drop_remainder=True)

print(train_dataset)

In [None]:
#
# path template for the final saved model; the '{}' placeholder from
# output_model_name is filled with a timestamp when the model is written out
#
output_model = os.path.join(output_path, output_model_name)


In [None]:
model = unet()
mse = tf.keras.losses.MeanSquaredError()  # plain-MSE loss, kept for the alternative below

# Alternatives tried during development:
#   loss=refl_loss(thresh=0.5), metrics=['accuracy']
#   loss=mse,                   metrics=['accuracy']
#
# `learning_rate` replaces the deprecated `lr` optimizer argument, which newer
# Keras releases reject.  The 'mae' metric is what the checkpoint callback
# monitors as 'val_mae'.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.5),
    loss=refl_loss(0.01),
    metrics=['mae'],
    run_eagerly=True,
)

In [None]:
# Write a diagram of the model (including layer shapes) to a PNG file; the
# relative path means it lands in the current working directory.
# NOTE(review): plot_model needs pydot and graphviz installed -- confirm they
# are available in the run environment.
tf.keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

In [None]:
#
# callbacks
#
# TensorBoard logging into this run's log directory.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)
# Save the model at the end of every epoch.
model_save_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/glade/scratch/hardt/unet_v1/trained_model_{epoch}.h5',
    save_freq='epoch',
)
# Keep only the weights file with the lowest validation MAE seen so far.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(output_path, "weights_best.h5"),
    monitor='val_mae',
    verbose=1,
    save_best_only=True,
    mode='min',
)
# Per-epoch learning-rate schedule.
LRS = tf.keras.callbacks.LearningRateScheduler(scheduler)

In [None]:
# Train for 50 epochs with the learning-rate schedule and best-weights checkpoint.
# NOTE(review): tensorboard_callback and model_save_callback are defined above
# but are not passed here, so no TensorBoard logs or per-epoch models are
# written -- confirm this is intended.
model.fit(train_dataset, epochs=50, validation_data=val_dataset, callbacks=[LRS, checkpoint])

In [None]:
#
# write out the trained model
#
# Refresh the timestamp so the file name reflects when this cell runs (after
# training) rather than when the notebook started.
t = time.strftime("%Y_%m_%d_%H_%M", time.localtime())
model.save(output_model.format(t))
