In [1]:
import tensorflow as tf
import xarray as xr
import numpy as np
import os
import sys
from scipy.ndimage import gaussian_filter

In [16]:
############################
# set up the run information
############################
# All run-specific settings live here; later cells only read these names.
output_root_directory = '/glade/work/hardt/models'
model_run_name        = 'test2'
input_model_name      = 'start_model_v2.h5'        # starting model; built with unet() if the file is missing
output_model_name     = 'trained_model_s1_{}.h5'   # '{}' is filled with a timestamp when the trained model is saved

feature_data          = '/glade/work/hardt/ds612/2000-2013_June-Sept_CTRLradrefl_REFL.nc'  # read for its `refl` variable
label_data            = '/glade/work/hardt/ds612/2000-2013_June-Sept_CTRL3D_maxW.nc'       # read for its `maxW` variable
############################

# Directory that holds both the starting model and the trained checkpoints of this run.
# exist_ok=True makes the cell idempotent and avoids the exists()-then-makedirs() race.
output_path = os.path.join(output_root_directory, model_run_name)
os.makedirs(output_path, exist_ok=True)

In [17]:
#
# set up the model: reuse the saved starting model, or build and save one
#
from model import unet

# Both model paths live inside this run's output directory.
input_model, output_model = (
    os.path.join(output_path, file_name)
    for file_name in (input_model_name, output_model_name)
)

start_model_exists = os.path.isfile(input_model)
if not start_model_exists:
    # Persist a freshly built U-Net so the training cell (and any later run)
    # loads the same starting weights from disk instead of rebuilding.
    test_model = unet()
    test_model.save(input_model)

In [54]:
#
# load the data
#
# Features come from the `refl` variable, labels from `maxW`. `.values` reads each
# variable into a NumPy array, so the files can be closed as soon as it is read.
with xr.open_dataset(feature_data) as xds:
    x = xds.refl.values
with xr.open_dataset(label_data) as yds:
    y = yds.maxW.values

# Smooth the labels over the spatial axes only. Axis 0 indexes samples (the
# datasets are later built from y[i, :, :]), so it gets sigma 0. A scalar
# sigma=1 would also blend each label with its neighbouring samples (presumably
# adjacent times), mismatching labels with their features and leaking across
# the train/validation boundary.
LABEL_SMOOTH_SIGMA = 1
LABEL_THRESHOLD = 2.0  # smoothed maxW values below this are zeroed
y = gaussian_filter(y, sigma=(0,) + (LABEL_SMOOTH_SIGMA,) * (y.ndim - 1))
y[y < LABEL_THRESHOLD] = 0

In [55]:
# Contiguous train/validation split along the sample axis. Slices are half-open,
# so validation begins exactly where training ends and no sample is skipped.
# np.newaxis adds a trailing length-1 channel axis (elements become (H, W, 1)).
TRAIN_END = 6112  # samples [0, TRAIN_END) are used for training
VAL_END = 7649    # samples [TRAIN_END, VAL_END) are used for validation
train_dataset = tf.data.Dataset.from_tensor_slices(
    (x[:TRAIN_END, :, :, np.newaxis], y[:TRAIN_END, :, :, np.newaxis]))
val_dataset = tf.data.Dataset.from_tensor_slices(
    (x[TRAIN_END:VAL_END, :, :, np.newaxis], y[TRAIN_END:VAL_END, :, :, np.newaxis]))

In [56]:
# Inspect per-element shapes and dtypes of the (not yet batched) training dataset.
print(train_dataset)

<TensorSliceDataset shapes: ((256, 256, 1), (256, 256, 1)), types: (tf.float32, tf.float32)>


In [57]:
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 6112  # equals the number of training samples, so the shuffle is over the whole set

# Training: reshuffle every epoch, then form fixed-size batches (partial last batch dropped).
train_dataset = (
    train_dataset
    .shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
# Validation: same batching, no shuffling.
val_dataset = val_dataset.batch(batch_size=BATCH_SIZE, drop_remainder=True)

In [58]:
# Inspect the training dataset after batching: element shapes gain a leading batch dimension.
print(train_dataset)

<BatchDataset shapes: ((32, 256, 256, 1), (32, 256, 256, 1)), types: (tf.float32, tf.float32)>


In [59]:
# Number of training batches per epoch. Querying the dataset's static cardinality
# avoids list(train_dataset), which materialised every shuffled batch in memory
# just to count them.
n_train_batches = int(tf.data.experimental.cardinality(train_dataset).numpy())
if n_train_batches < 0:
    # Cardinality not statically known: count by streaming without keeping batches.
    # (An infinite dataset would never finish here, just as list() never would.)
    n_train_batches = sum(1 for _ in train_dataset)
n_train_batches

191

In [60]:
# Start training from the starting model saved (or created) in the model-setup cell.
model = tf.keras.models.load_model(filepath=input_model)

In [61]:
# Train on the shuffled training batches, evaluating on the validation batches each epoch.
# The returned History is kept so per-epoch loss/metrics can be inspected or plotted later.
EPOCHS = 5
history = model.fit(train_dataset, epochs=EPOCHS, validation_data=val_dataset)

Train for 191 steps, validate for 48 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2aec840140d0>

In [62]:
#
# write out the trained model
#
import time

# Zero-padded local timestamp (YYYY_MM_DD_HH_MM) fills the '{}' in output_model,
# so saved checkpoints sort chronologically by file name.
timestamp = time.strftime("%Y_%m_%d_%H_%M", time.localtime())
model.save(output_model.format(timestamp))