In [2]:
# Create a NN with uncertainty prediction (heteroskedastic regression)
# from an existing model by training a network that takes the penultimate
# hidden layer of the base network and predicts the standard deviation
# of the error distribution.
# Author: Peter Sadowski, Dec 2020
import importlib
import os, sys
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # Needed to avoid cudnn bug.
import tensorflow as tf
from tensorflow.keras.callbacks import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import plot_model

sys.path = ['../'] + sys.path
import sarhs.generator
from sarhs.generator import SARGenerator
from sarhs.heteroskedastic import Gaussian_NLL, Gaussian_MSE


In [42]:
# Load trained model.
file_model_src = './model.h5'
model_base = load_model(file_model_src)

# Grab the tensors of the base network that the new head is wired to, then
# freeze the base network so that compiling/fitting `model` below does not
# update its weights.
base_inputs = model_base.input
base_penultimate = model_base.get_layer('dense_7').output  # presumably the penultimate hidden layer -- confirm against model_base.summary()
base_output = model_base.output
model_base.trainable = False

# Define new output that predicts uncertainty.
# The softplus activation keeps the predicted standard deviation positive.
x = Dense(256, activation='relu', name='std_hidden')(base_penultimate)
std_output = Dense(1, activation='softplus', name='std_output')(x)

# Final output is [base prediction, predicted std] along the last axis.
output = concatenate([base_output, std_output], axis=-1)
model = Model(inputs=base_inputs, outputs=output)

# Import the submodule explicitly: a bare `import keras_extras` does not
# guarantee that `keras_extras.losses.dirichlet` is loaded, in which case the
# reload below would fail with AttributeError.
import keras_extras.losses.dirichlet
importlib.reload(keras_extras.losses.dirichlet)
# NOTE(review): this rebinds Gaussian_NLL / Gaussian_MSE, which the import cell
# takes from sarhs.heteroskedastic -- confirm which implementation is intended.
from keras_extras.losses.dirichlet import Gaussian_NLL, Gaussian_MSE

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = Adam(learning_rate=0.0001)
model.compile(loss=Gaussian_NLL, optimizer=opt, metrics=[Gaussian_MSE])
#model.summary()

In [43]:
# Training / validation generators over the SAR HDF5 file.
# Training uses the '2015_2016' and '2018' subgroups; '2017' is held out for validation.
filename = '/mnt/tmp/psadow/sar/sar_hs.h5'
hp = {'bs':128}
train = sarhs.generator.SARGenerator2(filename=filename, subgroups=['2015_2016', '2018'], batch_size=hp['bs'])
valid = sarhs.generator.SARGenerator2(filename=filename, subgroups=['2017'], batch_size=hp['bs'])

# INTERACTIVE is not defined anywhere in the visible notebook, so look it up
# defensively: default to the progress bar (verbose=1) and use one line per
# epoch (verbose=2) only when INTERACTIVE has been explicitly set to a falsy value.
history = model.fit(
    train,
    epochs=5,
    validation_data=valid,
    #callbacks=clbks,
    #steps_per_epoch=100
    verbose=1 if globals().get('INTERACTIVE', True) else 2,
    )

# Save the base network together with the newly trained std head.
model.save('model_transfer.h5')

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100

KeyboardInterrupt: 

In [52]:
# Make predictions with uncertainty estimates.
from tqdm import tqdm 

def predict(model, dataset):
    """Run `model` on every batch of `dataset` and stack targets and predictions.

    Args:
        model: object exposing `predict_on_batch(inputs)`.
        dataset: iterable yielding `(inputs, y)` pairs; `y` may be None.

    Returns:
        Tuple `(y, yhat)` of row-stacked arrays. Labeled batches are reshaped
        to two columns; unlabeled batches contribute a single zero column
        with one row per prediction.
    """
    targets, predictions = [], []
    for inputs, y in dataset:
        batch_pred = model.predict_on_batch(inputs)
        # Unlabeled batch: substitute zeros, one row per predicted example.
        batch_y = np.zeros((batch_pred.shape[0], 1)) if y is None else y.reshape(-1, 2)
        targets.append(batch_y)
        predictions.append(batch_pred)
    yhat = np.vstack(predictions)
    y = np.vstack(targets)
    return y, yhat

def define_groups():
    """Map each monthly group name to its `(isat, year, month)` tuple.

    Names have the form `S1{A|B}_{year}{month:02d}S`. Satellite index 1 maps to
    'A' and index 0 to 'B'; only the year 2019 is covered, months 1 to 12.

    Returns:
        dict with 24 entries, ordered by satellite index, then year, then month.
    """
    satellite_letter = {0: 'B', 1: 'A'}
    years = (2019,)
    return {
        f'S1{satellite_letter[isat]}_{year}{month:02d}S': (isat, year, month)
        for isat in satellite_letter
        for year in years
        for month in range(1, 13)
    }

# Dataset
filename = '/home/psadow/lts/preserve/stopa/sar_hs/data/alt/sar_hs_2019.h5' # Contains all processed 2019 data.

# Create the output directory once, up front, instead of on every iteration.
Path("./predictions/").mkdir(parents=True, exist_ok=True)

for group, (isat, year, month) in tqdm(define_groups().items()):
    # Make predictions for this group.
    test = sarhs.generator.SARGenerator(filename, subgroups=[group], batch_size=200)
    #print(test._num_examples())
    _, yhat = predict(model,test)

    # The predictions should be in order, so they are paired row by row with
    # the time and position metadata read straight from the HDF5 group.
    group_data = test.h5file[group]
    time_sar = group_data['timeSAR'][:].flatten()
    if len(yhat) != len(time_sar):
        # Fail with a clear message rather than pandas' generic length error.
        raise ValueError(
            f'{group}: got {len(yhat)} predictions for {len(time_sar)} records'
        )

    # Build the frame in one step: column 0 of the model output is the
    # prediction, column 1 the predicted standard deviation.
    df = pd.DataFrame({
        'hsNN': yhat[:,0],
        'hsNN_std': yhat[:,1],
        'timeSAR': time_sar,
        'latSAR': group_data['latlonSAR'][:, 0],
        'lonSAR': group_data['latlonSAR'][:, 1],
    })
    df.to_csv(f'predictions/{group}.csv', index=False, )

print('Done')
print(df.columns)

100%|██████████| 24/24 [14:35<00:00, 36.48s/it]

Done
Index(['hsNN', 'hsNN_std', 'timeSAR', 'latSAR', 'lonSAR'], dtype='object')





# Fine tune the prediction part on 2017 data (the validation set).

In [55]:
# Load trained model.
file_model = './models/model_45.h5'
model_base = load_model(file_model)
# Fine tune.
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = Adam(learning_rate=0.00001)
model_base.compile(loss='mae', optimizer=opt, metrics=['mae', 'mse'])
# `valid` is the 2017 generator created in the earlier training cell.
# INTERACTIVE is not defined in the visible notebook, so look it up defensively:
# progress bar (verbose=1) unless it has been explicitly set to a falsy value.
history = model_base.fit(valid, epochs=5, verbose=1 if globals().get('INTERACTIVE', True) else 2)
# Loss initially at mse 0.15, more than 0.13 because of dropout.

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [56]:
# Save the fine-tuned base model so it can be reloaded without repeating the fit.
model_base.save('model_45_tuned.h5')

In [63]:


# Import the custom loss/metric before load_model needs them in
# `custom_objects` (previously this import came after its first use).
from keras_extras.losses.dirichlet import Gaussian_NLL, Gaussian_MSE

# Add back output that predicts uncertainty.
base_inputs = model_base.input
base_penultimate = model_base.get_layer('dense_7').output
base_output = model_base.output
# NOTE(review): './model_45_std.h5' is not written by any cell visible in this
# notebook -- confirm where it comes from before a clean re-run.
model_std = load_model('./model_45_std.h5', custom_objects={'Gaussian_NLL':Gaussian_NLL, 'Gaussian_MSE': Gaussian_MSE})
# Reuse the std-head layers from model_std on top of the fine-tuned base
# network's 'dense_7' output.
x = model_std.get_layer('std_hidden')(base_penultimate)
std_output = model_std.get_layer('std_output')(x)
output = concatenate([base_output, std_output], axis=-1)
model = Model(inputs=base_inputs, outputs=output)
# Compile and save.
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = Adam(learning_rate=0.0001)
model.compile(loss=Gaussian_NLL, optimizer=opt, metrics=[Gaussian_MSE])
model.save('model_45_std_tuned.h5')

In [64]:
# Make predictions with uncertainty estimates.
from tqdm import tqdm 

def predict(model, dataset):
    """Collect targets and model outputs over all batches of `dataset`.

    Same behavior as the `predict` defined in the earlier prediction cell;
    it is defined again here so this cell can be run on its own.

    Args:
        model: object exposing `predict_on_batch(inputs)`.
        dataset: iterable yielding `(inputs, y)` pairs; `y` may be None.

    Returns:
        Tuple `(y, yhat)` of row-stacked arrays. Batches with labels are
        reshaped to two columns; batches without labels contribute one zero
        column with as many rows as predictions.
    """
    label_chunks = []
    output_chunks = []
    for inputs, y in dataset:
        outputs = model.predict_on_batch(inputs)
        if y is None:
            # No labels for this batch: use a zero placeholder instead.
            labels = np.zeros((outputs.shape[0], 1))
        else:
            labels = y.reshape(-1, 2)
        label_chunks.append(labels)
        output_chunks.append(outputs)
    yhat = np.vstack(output_chunks)
    y = np.vstack(label_chunks)
    return y, yhat

def define_groups():
    """Build the lookup from monthly group name to `(isat, year, month)`.

    Same behavior as the `define_groups` defined in the earlier prediction
    cell. Group names look like `S1{A|B}_{year}{month:02d}S`, where satellite
    index 1 is 'A' and index 0 is 'B'; only 2019 is covered.

    Returns:
        dict with 24 entries, ordered by satellite index, then year, then month.
    """
    satellite_letter = {0: 'B', 1: 'A'}
    years = (2019,)
    return {
        f'S1{satellite_letter[isat]}_{year}{month:02d}S': (isat, year, month)
        for isat in satellite_letter
        for year in years
        for month in range(1, 13)
    }

# Dataset
filename = '/home/psadow/lts/preserve/stopa/sar_hs/data/alt/sar_hs_2019.h5' # Contains all processed 2019 data.

# Create the output directory once, up front, instead of on every iteration.
Path("./predictions_std_tuned/").mkdir(parents=True, exist_ok=True)

for group, (isat, year, month) in tqdm(define_groups().items()):
    # Make predictions for this group.
    test = sarhs.generator.SARGenerator(filename, subgroups=[group], batch_size=200)
    #print(test._num_examples())
    _, yhat = predict(model,test)

    # The predictions should be in order, so they are paired row by row with
    # the time and position metadata read straight from the HDF5 group.
    group_data = test.h5file[group]
    time_sar = group_data['timeSAR'][:].flatten()
    if len(yhat) != len(time_sar):
        # Fail with a clear message rather than pandas' generic length error.
        raise ValueError(
            f'{group}: got {len(yhat)} predictions for {len(time_sar)} records'
        )

    # Build the frame in one step: column 0 of the model output is the
    # prediction, column 1 the predicted standard deviation.
    df = pd.DataFrame({
        'hsNN': yhat[:,0],
        'hsNN_std': yhat[:,1],
        'timeSAR': time_sar,
        'latSAR': group_data['latlonSAR'][:, 0],
        'lonSAR': group_data['latlonSAR'][:, 1],
    })
    df.to_csv(f'predictions_std_tuned/{group}.csv', index=False, )

print('Done')
print(df.columns)

100%|██████████| 24/24 [09:18<00:00, 23.27s/it]

Done
Index(['hsNN', 'hsNN_std', 'timeSAR', 'latSAR', 'lonSAR'], dtype='object')



