In [8]:
# Train neural network to predict significant wave height from SAR spectra.
# Author: Peter Sadowski, Dec 2020
import os, sys
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # Needed to avoid cudnn bug.
import numpy as np
import h5py

import tensorflow as tf
from tensorflow.keras.utils import Sequence, plot_model
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

sys.path = ['../'] + sys.path
from sarhs.generator import SARGenerator

In [11]:
def define_model(learning_rate=0.00025897101528140915):
    """Build and compile the two-branch regression network.

    The network has two inputs that are processed separately and then merged:

    * a convolutional branch over a ``(72, 60, 2)`` tensor (three
      Conv2D + MaxPooling2D stages, global max pooling, two dense layers), and
    * a fully connected branch over a length-32 vector (eleven dense layers).

    The branch outputs are concatenated and passed through two
    Dense + Dropout stages and a single softplus unit, so the prediction is
    always positive.

    Args:
        learning_rate: Step size handed to the Adam optimizer. The default is
            the value that was previously hard-coded in this function.

    Returns:
        A compiled ``tf.keras`` ``Model`` taking ``[cnn_input, ann_input]`` and
        producing one value per sample, trained with mean squared error.
    """
    # Low-level features: convolutional branch.
    inputs = Input(shape=(72, 60, 2))
    x = inputs
    for n_filters in (64, 128, 256):
        x = Conv2D(n_filters, (3, 3), activation='relu')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = GlobalMaxPooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    cnn = Model(inputs, x)

    # High-level features: fully connected branch.
    # Original note: 'hsSM', 'hsWW3v2', 'hsALT', 'altID', 'target' -> dropped
    # (presumably columns removed upstream to leave 32 features -- TODO confirm).
    inp = Input(shape=(32, ))
    x = inp
    for _ in range(11):
        x = Dense(units=256, activation='relu')(x)
    x = Dropout(0.5)(x)
    ann = Model(inputs=inp, outputs=x)

    # Combine both branches and regress a single positive value.
    combinedInput = concatenate([cnn.output, ann.output])
    x = combinedInput
    for _ in range(2):
        x = Dense(256, activation="relu")(x)
        x = Dropout(0.5)(x)
    x = Dense(1, activation="softplus")(x)
    model = Model(inputs=[cnn.input, ann.input], outputs=x)

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras releases.
    opt = Adam(learning_rate=learning_rate)
    model.compile(loss='mean_squared_error', optimizer=opt)

    return model


In [None]:
# Train
# Build a freshly initialised, compiled network; the best checkpoint goes to `file_model`.
model = define_model()
file_model = 'model.h5'
# Uncomment to inspect the architecture:
#model.summary()
#plot_model(model, to_file='model.png')

# Dataset
batch_size = 128
epochs = 123
filename = '/mnt/tmp/psadow/sar/sar_hs.h5'
#filename = '../../data/alt/sar_hs.h5'

# Both generators read the same file with the same batch size; only the
# groups differ (2015-2017 for training, 2018 held out for validation).
generator_args = dict(filename=filename, batch_size=batch_size)
train = SARGenerator(groups=['2015_2016', '2017'], **generator_args)
valid = SARGenerator(groups=['2018'], **generator_args)

# Callbacks
# This LR schedule is slower than in the paper.
# Scale the learning rate by 0.9 after one epoch without val_loss improvement.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.9,
    patience=1,
)
# Keep only the best full model (architecture + weights) on disk.
check = ModelCheckpoint(
    file_model,
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
)
# Give up after 10 epochs without improvement. Best weights are NOT restored
# in memory, so reload `file_model` to get the best model.
stop = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=10,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
)
clbks = [reduce_lr, check, stop]

history = model.fit(
    train,
    validation_data=valid,
    epochs=epochs,
    callbacks=clbks,
    verbose=1,
)


Epoch 1/123
Epoch 2/123
Epoch 3/123
Epoch 4/123
Epoch 5/123
Epoch 6/123
Epoch 7/123
Epoch 8/123
Epoch 9/123
Epoch 10/123
Epoch 11/123
Epoch 12/123
Epoch 13/123
Epoch 14/123
Epoch 15/123
Epoch 16/123
Epoch 17/123
Epoch 18/123
Epoch 19/123
Epoch 20/123
Epoch 21/123
Epoch 22/123
Epoch 23/123
Epoch 24/123