In [20]:
#!/usr/bin/env python
import os, datetime
import socket
from IPython import get_ipython
# Session facts: which machine we are on, and whether an IPython kernel is
# driving this code (as opposed to a plain `python` invocation).
HOSTNAME = socket.gethostname()
INTERACTIVE = get_ipython() is not None

# Trial identifier handed in through the environment; '0000' when absent.
SHERPA_TRIAL_ID = os.environ.get('SHERPA_TRIAL_ID', '0000')

# GPU-related environment variables are written here, ahead of the
# TensorFlow import further down in this cell.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # Needed to avoid cudnn bug.
if INTERACTIVE:
    # Only interactive sessions are pinned to device 0 via the %env magic.
    get_ipython().run_line_magic('env', 'CUDA_VISIBLE_DEVICES=0')

import sherpa
import numpy as np
import pandas as pd
import h5py
from pathlib import Path

import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

import sys
sys.path.append('/home/psadow/lts/preserve/stopa/sar_hs/sar_hs/')
import sarhs.generator 
import importlib
importlib.reload(sarhs.generator)

env: CUDA_VISIBLE_DEVICES=0


<module 'sarhs.generator' from '/home/psadow/lts/preserve/stopa/sar_hs/sar_hs/sarhs/generator.py'>

In [30]:
def define_model(learning_rate=0.00025897101528140915):
    """Build and compile the two-branch regression network.

    One branch is a small CNN over a (72, 60, 2) input, the other is a deep
    fully-connected stack over a 32-element feature vector. Both branches end
    in a 256-wide representation followed by dropout; the two are concatenated
    and passed through a dense head that ends in a single softplus unit.

    Args:
        learning_rate: Learning rate passed to the Adam optimizer. The default
            is the value that used to be hard-coded in this function.

    Returns:
        A compiled Keras ``Model`` whose inputs are ``[cnn_input, ann_input]``
        and whose loss is mean squared error.
    """
    # Low-level features: three Conv -> MaxPool stages with growing width.
    inputs = Input(shape=(72, 60, 2))
    x = inputs
    for n_filters in (64, 128, 256):
        x = Conv2D(n_filters, (3, 3), activation='relu')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = GlobalMaxPooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    cnn = Model(inputs, x)

    # High-level features: eleven identical Dense layers, then dropout.
    inp = Input(shape=(32, ))  # 'hsSM', 'hsWW3v2', 'hsALT', 'altID', 'target' -> dropped
    x = inp
    for _ in range(11):
        x = Dense(units=256, activation='relu')(x)
    x = Dropout(0.5)(x)
    ann = Model(inputs=inp, outputs=x)

    # Combine both branches and regress a single value.
    combinedInput = concatenate([cnn.output, ann.output])
    x = Dense(256, activation="relu")(combinedInput)
    x = Dropout(0.5)(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.5)(x)
    x = Dense(1, activation="softplus")(x)
    model = Model(inputs=[cnn.input, ann.input], outputs=x)

    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # releases no longer accept.
    opt = Adam(learning_rate=learning_rate)
    model.compile(loss='mean_squared_error', optimizer=opt)

    return model

def step_decay_schedule(initial_lr=1e-3, decay_factor=0.40, step_size=4):
    '''Wrapper function to create a LearningRateScheduler with step decay schedule.

    The learning rate is piecewise constant in the (0-based) epoch index:

        epochs   0 -   9 : initial_lr
        epochs  10 -  19 : initial_lr * decay_factor
        epochs  20 - 118 : initial_lr * decay_factor ** 2
        epochs 119 and up: initial_lr * decay_factor ** 3

    Args:
        initial_lr: Learning rate used before the first decay step.
        decay_factor: Multiplier applied once per decay step.
        step_size: Unused; the epoch boundaries are fixed. Kept so existing
            callers that pass it keep working.

    Returns:
        A ``LearningRateScheduler`` callback wrapping the schedule.
    '''
    def schedule(epoch):
        # Previously epochs 0-9 fell through to the final branch and were
        # trained at the *smallest* rate; they must use the undecayed rate.
        if epoch < 10:
            exponent = 0
        elif epoch < 20:
            exponent = 1
        elif epoch <= 118:
            exponent = 2
        else:
            exponent = 3
        return initial_lr * (decay_factor ** exponent)
    return LearningRateScheduler(schedule)



In [31]:
# ---- Model ----
# Best checkpoint (by val_loss) is written to this file during training.
file_model = 'model.h5'
model = define_model()
#model.summary()
#plot_model(model, to_file='model.png')

# ---- Data ----
# Re-import the generator module so edits made to it on disk are picked up.
importlib.reload(sarhs.generator)
batch_size = 128
epochs = 123
# Alternative location of the same file:
#   '/home/psadow/lts/preserve/stopa/sar_hs/data/alt/sar_hs.h5'
filename = '/mnt/tmp/psadow/sar/sar_hs.h5'
# One generator per split, created in train / valid / test order.
train, valid, test = [
    sarhs.generator.SARGenerator(filename=filename, split=split_name, batch_size=batch_size)
    for split_name in ('2015_2016', '2017', '2018')
]

# ---- Callbacks ----
# Fixed-schedule alternative, currently disabled:
#reduce_lr = step_decay_schedule(initial_lr=0.00025897101528140915, decay_factor=0.40, step_size=4)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.9,
    patience=1,
)  # This is slower than in paper.
check = ModelCheckpoint(
    file_model,
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
)
stop = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=10,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
)
clbks = [reduce_lr, check, stop]

# Older data pipeline, kept for reference:
# train = datagen(gen_data, 101171//batch_size, file, 'train', batch_size)
# val = datagen(gen_data, 28906//batch_size, file, 'val', batch_size)

# ---- Fit on the training split, validating on the 2017 split ----
history = model.fit(
    train,
    epochs=epochs,
    validation_data=valid,
    callbacks=clbks,
    verbose=1,
)


Epoch 1/123
Epoch 2/123
Epoch 3/123
Epoch 4/123
Epoch 5/123
Epoch 6/123
Epoch 7/123
Epoch 8/123
Epoch 9/123
Epoch 10/123
Epoch 11/123
Epoch 12/123
Epoch 13/123
Epoch 14/123
Epoch 15/123
Epoch 16/123
Epoch 17/123
Epoch 18/123
Epoch 19/123
Epoch 20/123
Epoch 21/123
Epoch 22/123
Epoch 23/123
 288/2469 [==>...........................] - ETA: 35s - loss: 0.1200

KeyboardInterrupt: 

In [32]:
# Continue training the already-fitted `model` on the 2017 split, using a
# freshly built 2018 generator as the validation data.
test = sarhs.generator.SARGenerator(
    filename=filename,
    split='2018',
    batch_size=batch_size,
)

# NOTE(review): `clbks` is the same list built for the first fit, so the
# checkpoint callback still writes to `file_model` — confirm that overwriting
# that file from this run is intended.
history = model.fit(
    valid,
    epochs=epochs,
    validation_data=test,
    callbacks=clbks,
    verbose=1,
)
# Persist the model as it stands when this fit returns.
model.save('model2_valid.h5')

Epoch 1/123
Epoch 2/123
Epoch 3/123
Epoch 4/123
Epoch 5/123
Epoch 6/123
Epoch 7/123
Epoch 8/123
Epoch 9/123
Epoch 10/123
Epoch 11/123
Epoch 12/123
Epoch 13/123
Epoch 14/123
Epoch 15/123
Epoch 16/123
Epoch 17/123


In [34]:
# Final pass: ten more epochs on the 2018 split, with no validation data and
# no callbacks, then save the resulting model.
history = model.fit(test, epochs=10, verbose=1)
model.save('model2_test.h5')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
