In [60]:
# builtins
import locale
import math
import glob
import pathlib
import functools
import logging

# numerical stuff
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization, Conv2D, MaxPooling2D, Conv2DTranspose, Reshape
from tensorflow.keras.layers import Activation, Dropout, Dense, Flatten, Input, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# plotting
import matplotlib.pyplot as plt

# Show INFO-level progress messages (e.g. "training model...") in the notebook.
logging.basicConfig(level="INFO")


In [56]:
data_path = 'gs://bathy_sample/processed/20211013/combined_data'
all_checkpoints_path = 'gs://bathy_sample/dnn/checkpoints'
model_name = 'guus-2d-mlp-cnn-v0.1'
learning_rate = 1e-4
n_epochs = 20
batch_size = 8
checkpoints_path = all_checkpoints_path + '/' + model_name


In [27]:
def tf_parse(eg):
    """Decode one serialized ``tf.train.Example`` into ``((bathy, attr), hs)``.

    Args:
        eg: scalar ``tf.string`` tensor holding one serialized Example, as
            yielded element-wise by ``tf.data.TFRecordDataset``.

    Returns:
        A nested tuple ``((bathy, attr), hs)``:
          * ``bathy``, ``hs``: float32 tensors decoded with ``tf.io.parse_tensor``
            (their shape is not fixed statically here -- it is whatever was
            serialized; presumably the grid stored alongside height/width/depth).
          * ``attr``: float32 vector of shape ``(3,)`` = ``[eta, zeta, theta]``.
    """
    # The record schema is spelled out explicitly; inferring it from the data
    # via tf.train.Example.FromString did not work for these files.
    # height/width/depth are part of the schema but are not used below.
    features = {
        "height": tf.io.FixedLenFeature([], tf.int64),
        "width": tf.io.FixedLenFeature([], tf.int64),
        "depth": tf.io.FixedLenFeature([], tf.int64),
        "bathy": tf.io.FixedLenFeature([], tf.string),
        "hs": tf.io.FixedLenFeature([], tf.string),
        "eta": tf.io.FixedLenFeature([], tf.float32),
        "zeta": tf.io.FixedLenFeature([], tf.float32),
        "theta": tf.io.FixedLenFeature([], tf.float32),
    }
    # parse_single_example handles exactly one record, so no fake batch axis
    # (eg[tf.newaxis]) has to be added and indexed away afterwards.
    example = tf.io.parse_single_example(eg, features)

    bathy = tf.io.parse_tensor(example["bathy"], out_type=tf.float32)
    hs = tf.io.parse_tensor(example["hs"], out_type=tf.float32)

    # The three attributes are scalars here, so stacking yields a (3,) vector.
    attr = tf.stack([example["eta"], example["zeta"], example["theta"]])
    return (bathy, attr), hs

In [31]:
def get_files(data_path):
    """Return the paths of every ``*.tfrecords`` file directly under ``data_path``.

    Globbing goes through ``tf.io.gfile``, so ``data_path`` may be local or any
    filesystem TensorFlow supports (such as the gs:// URI used in this notebook).
    """
    pattern = "/".join([data_path, "*.tfrecords"])
    return tf.io.gfile.glob(pattern)

def get_dataset(files):
    """Build an unbatched dataset of parsed examples from the given TFRecord files.

    Each element is the ``((bathy, attr), hs)`` tuple returned by ``tf_parse``.
    """
    return tf.data.TFRecordDataset(files).map(tf_parse)

In [39]:
def create_mlp(dim, output_units=256 * 256):
    """Build the dense branch that maps the scalar attributes to a flat feature vector.

    Args:
        dim: number of input attributes (this notebook passes 3).
        output_units: width of the final Dense layer. ``full_model`` reshapes
            this output onto the CNN output's spatial grid, so it must be
            compatible with that grid. The default (65536) keeps the original
            architecture.

    Returns:
        An uncompiled ``Sequential`` model with input shape ``(None, dim)`` and
        output shape ``(None, output_units)``.
    """
    # Declare the input explicitly. The previous ``model.build((None, 256 * 256))``
    # described a 65536-wide input, contradicting the real ``dim``-wide one; it
    # only happened to be ignored because ``input_dim`` had already fixed the input.
    model = Sequential(
        [
            Input(shape=(dim,)),
            Dense(64, activation="relu"),
            Dense(1024, activation="relu"),
            Dense(output_units, activation="relu"),
        ]
    )
    return model


def create_cnn(width, height, depth):
    """Build the convolutional branch: an encoder/decoder with skip concatenations.

    The encoder applies four conv(3x3)/ReLU/BatchNorm + 2x2 max-pool stages
    (64, 32, 32, 32 filters). The outputs of the first three pools are kept as
    skips. A 32-filter bottleneck conv is followed by three upsample +
    Conv2DTranspose steps, each concatenated with the matching skip. A final
    16-filter conv/ReLU/BatchNorm produces the output.

    Args:
        width, height, depth: input grid size (channels-last). height and width
            must be divisible by 16, otherwise the skip concatenations mismatch.

    Returns:
        A functional ``Model`` mapping ``(height, width, depth)`` to a
        ``(height / 2, width / 2, 16)`` feature map.
    """

    def conv_relu_bn(tensor, filters):
        # 3x3 same-padded conv, then ReLU, then batch norm.
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = Activation("relu")(tensor)
        return BatchNormalization()(tensor)

    def upsample(tensor):
        # Double the resolution, then refine with a 32-filter transposed conv.
        tensor = UpSampling2D(size=(2, 2))(tensor)
        return Conv2DTranspose(32, (3, 3), padding="same")(tensor)

    inputs = Input(shape=(height, width, depth))

    # Encoder: keep each pooled output as a skip connection (1/2, 1/4, 1/8 res).
    skips = []
    x = inputs
    for filters in (64, 32, 32):
        features = conv_relu_bn(x, filters)
        x = MaxPooling2D(pool_size=(2, 2))(features)
        skips.append(x)

    # Bottleneck at 1/16 resolution, then back up to 1/8.
    features = conv_relu_bn(x, 32)
    x = MaxPooling2D(pool_size=(2, 2))(features)
    features = conv_relu_bn(x, 32)
    x = upsample(features)

    # Decoder: merge with the 1/8 and then the 1/4 skip, upsampling after each.
    for skip in reversed(skips[1:]):
        x = Concatenate()([x, skip])
        features = conv_relu_bn(x, 32)
        x = upsample(features)

    # Final merge with the 1/2-resolution skip.
    x = Concatenate()([x, skips[0]])
    x = conv_relu_bn(x, 16)

    return Model(inputs, x)


def full_model(cnn_model, mlp_model):
    """Fuse the CNN and MLP branches into a single network.

    The flat MLP output is reshaped onto the CNN output's spatial grid and
    concatenated with it along the channel axis. It then passes through six
    conv(3x3)/ReLU/BatchNorm blocks (16, 32, 64, 128, 128, 256 filters), a 2x
    upsampling, and a single-filter linear Conv2DTranspose.

    Args:
        cnn_model: Keras model whose output is a 4-D ``(batch, H, W, C)``
            feature map with static H and W.
        mlp_model: Keras model with a flat output whose size is a multiple of
            ``H * W``.

    Returns:
        A ``Model`` with inputs ``[cnn_model.input, mlp_model.input]`` and a
        single-channel output at twice the CNN output's spatial resolution.

    Raises:
        ValueError: (from ``Reshape``) if the MLP width is not a multiple of H * W.
    """
    x = cnn_model.output
    cx = mlp_model.output

    conv_shape = K.int_shape(x)

    # Let Reshape infer the channel count from the MLP width instead of
    # hard-coding int(C / 4), which only worked when the MLP width happened to
    # equal H * W * C / 4. For the default branches this still yields 4 channels.
    cx = Reshape((conv_shape[1], conv_shape[2], -1))(cx)

    x = Concatenate()([x, cx])

    for filters in (16, 32, 64, 128, 128, 256):
        x = Conv2D(filters, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2DTranspose(1, (3, 3), padding="same", activation="linear")(x)

    return Model(inputs=[cnn_model.input, mlp_model.input], outputs=x)

In [37]:
# Split whole files (not records) into train/test sets.
# glob order is not guaranteed, so sort first; with a fixed seed the split is
# then reproducible across kernel restarts.
split_seed = 42
files = sorted(get_files(data_path))
train_files, test_files = train_test_split(files, random_state=split_seed)

train_dataset = get_dataset(train_files).batch(batch_size)
test_dataset = get_dataset(test_files).batch(batch_size)

len(train_files), len(test_files)

(76, 26)

In [52]:
# Sanity check: number of records stored in a single TFRecord file.
# Count them rather than displaying every decoded tensor; files[:1] also
# works when only one file exists.
sum(1 for _ in get_dataset(files[:1]))

10

In [41]:
# Two input branches: a CNN over the 256x256, single-channel `bathy` grid and
# an MLP over the 3 scalar attributes (eta, zeta, theta).
cnn_model = create_cnn(width=256, height=256, depth=1)
mlp_model = create_mlp(dim=3)

In [57]:
model = full_model(cnn_model, mlp_model)

# `Adam(decay=...)` raises ValueError with the Keras optimizers shipped in
# TF >= 2.11. InverseTimeDecay with decay_steps=1 reproduces the legacy decay
# exactly: lr_t = learning_rate / (1 + decay * step). Note that `step` counts
# optimizer iterations (batches), not epochs, just like the legacy argument did.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=learning_rate,
    decay_steps=1,
    decay_rate=learning_rate / n_epochs,
)
opt = Adam(learning_rate=lr_schedule)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoints_path)
]
model.compile(loss="mean_squared_error", optimizer=opt)

In [None]:
logging.info("training model...")

# steps_per_epoch assumes every TFRecord file holds the same number of records
# (10 -- the count observed on a sample file; TODO confirm for all files).
records_per_file = 10
steps_per_epoch = len(train_files) * records_per_file // batch_size

# With an explicit steps_per_epoch and a dataset of unknown cardinality, Keras
# does not restart the iterator between epochs. A non-repeating dataset is
# therefore exhausted after the first epoch ("Your input ran out of data").
# Repeat the training data so all n_epochs actually run. Also pass the
# checkpoint callbacks; without callbacks=... they are never invoked.
model.fit(
    x=train_dataset.repeat(),
    validation_data=test_dataset,
    epochs=n_epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=callbacks,
)

INFO:root:training model...


Epoch 1/20
 2/95 [..............................] - ETA: 10:02 - loss: 34.7611

In [None]:
# Persist the freshly trained model, named after the configured model_name.
saved_model_path = f"{model_name}.h5"
model.save(saved_model_path)

# Optionally inspect a previously saved model instead of the one just trained.
# Set this to a path (e.g. '/content/drive/MyDrive/DeepLearning/ModelV11.h5' on
# Colab) to load it; None keeps the model trained above.
pretrained_model_path = None
if pretrained_model_path is not None:
    model = load_model(pretrained_model_path, compile=True)


def plot_prediction_vs_truth(trained_model, dataset, sample_index=0):
    """Plot the model's prediction for one example beside its ground truth.

    Args:
        trained_model: Keras model taking ``[bathy, attr]`` inputs.
        dataset: batched dataset of ``((bathy, attr), hs)`` elements; only its
            first batch is read.
        sample_index: position of the example within that first batch.

    Returns:
        The matplotlib ``Figure`` with both panels and a shared colorbar.
    """
    (bathy, attr), hs = next(iter(dataset))
    window = slice(sample_index, sample_index + 1)
    prediction = trained_model.predict([bathy[window], attr[window]])[0][:, :, 0]
    # squeeze drops a trailing size-1 channel axis if the stored target has one
    # (assumes the target is otherwise a 2-D grid -- TODO confirm).
    truth = np.squeeze(hs[sample_index].numpy())

    # Negative values are masked as NaN (left blank by imshow). np.where builds
    # new arrays instead of mutating the source data in place.
    prediction = np.where(prediction < 0, np.nan, prediction)
    truth = np.where(truth < 0, np.nan, truth)

    # One shared colour scale, so both panels and the colorbar are comparable.
    both = np.concatenate([prediction.ravel(), truth.ravel()])
    vmin, vmax = np.nanmin(both), np.nanmax(both)

    fig, (pred_ax, truth_ax) = plt.subplots(1, 2, figsize=(6, 3))
    pred_ax.set_title("Prediction")
    image = pred_ax.imshow(prediction, vmin=vmin, vmax=vmax)
    truth_ax.set_title("Ground truth")
    truth_ax.imshow(truth, vmin=vmin, vmax=vmax)
    fig.colorbar(image, ax=[pred_ax, truth_ax], orientation="vertical")
    return fig


prediction_fig = plot_prediction_vs_truth(model, test_dataset)
plt.show()