Trust the notebook to see the training progress.

In [1]:
import pickle
import numpy as np
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.callbacks import TQDMProgressBar

%run xception.ipynb

In [2]:
# Locations of the input dataset and of the artifacts this notebook writes
dataset = "../data/data_v2.npz"
modelpath = "xception_data_v2.tf"
scalerpath = "xception_data_v2.scaler"

# Pull both arrays out of the archive; the context manager closes it afterwards
with np.load(dataset) as data:
    image, label = data["img"], data["label"]

In [3]:
# Network input: the images exactly as loaded, with a trailing channel axis appended.
# NOTE(review): the original note called these "noise unscaled" images - presumably
# noisy images whose pixel values are not rescaled here; confirm.
X = np.expand_dims(image, axis=-1)
input_shape = X.shape[1:]

# Targets: rescale the labels to the range [-1, 1]. The scaler is fitted on the full
# dataset on purpose, because the ranges of the generated labels are known.
scaler = MinMaxScaler(feature_range=(-1, 1)).fit(label)
y = scaler.transform(label)
output_shape = y.shape[1:]

In [5]:
# Seed the global RNGs before any weights are created, so that a fresh
# Restart-and-Run-All starts from the same initialization and shuffling order.
# NOTE(review): some GPU kernels are non-deterministic, so runs may still
# differ slightly even with fixed seeds.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Build an Xception model.
# NOTE(review): build_xception_model and adapt_model are not imported above, so they
# presumably come from the `%run xception.ipynb` line in the first cell - confirm.
model = build_xception_model(
    input_shape, output_shape, conv2d_num_filters=16, sep_num_filters=64, num_residual_blocks=8,
)
# Regression on the scaled labels: mean squared error with the default Adam settings
model.compile(optimizer="adam", loss="mse")

# Adapt the normalization layer to the data (this uses all of X, including the
# samples that are later held out for validation)
adapt_model(model, X)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 1)]  0                                            
__________________________________________________________________________________________________
normalization (Normalization)   (None, 64, 64, 1)    3           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 16)   416         normalization[0][0]              
__________________________________________________________________________________________________
separable_conv2d (SeparableConv (None, 32, 32, 64)   1232        conv2d[0][0]                     
______________________________________________________________________________________________

In [6]:
# Callbacks, all keyed on the validation loss:
#  - stop after 20 epochs without improvement and roll back to the best weights
#  - write the model to disk whenever the validation loss improves
#  - draw tqdm progress bars (Keras' own logging is silenced with verbose=0)
early_stopping = EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True)
checkpoint = ModelCheckpoint(filepath=modelpath, monitor="val_loss", save_best_only=True)
tqdm_callback = TQDMProgressBar()
training_callbacks = [early_stopping, tqdm_callback, checkpoint]

# Fit on the first 90% of the samples; Keras holds out the last 10% for validation
history = model.fit(
    x=X,
    y=y,
    epochs=1000,
    batch_size=32,
    validation_split=0.1,
    callbacks=training_callbacks,
    verbose=0,
)

# Persist the model as it stands after training, overwriting the last checkpoint
model.save(filepath=modelpath)

HBox(children=(FloatProgress(value=0.0, description='Training', layout=Layout(flex='2'), max=1000.0, style=Pro…

Epoch 1/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: xception_data_v2.tf/assets
Epoch 2/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 3/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: xception_data_v2.tf/assets
Epoch 4/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 5/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: xception_data_v2.tf/assets
Epoch 6/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: xception_data_v2.tf/assets
Epoch 7/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 8/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 9/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 10/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 11/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 12/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 13/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 14/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 15/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: xception_data_v2.tf/assets
Epoch 16/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 17/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 18/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 19/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 20/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 21/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 22/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 23/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 24/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 25/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 26/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 27/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 28/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 29/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 30/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 31/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 32/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 33/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 34/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 35/1000


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…



INFO:tensorflow:Assets written to: xception_data_v2.tf/assets


In [7]:
# Rebuild the validation subset: the last 10% of the samples, which must stay in sync
# with the validation_split=0.1 passed to model.fit
n_train = int(len(label) * 0.9)
X_val, label_val = X[n_train:], label[n_train:]

# Run the model and map its outputs from the scaled range back to the original label units
predictions = scaler.inverse_transform(model.predict(X_val))

# Per-label RMSE (one value per label column), printed without scientific notation
np.set_printoptions(suppress=True)
rmse = np.sqrt(np.mean(np.square(label_val - predictions), axis=0))
rmse

array([8993.485     ,    0.83563304,    0.05117362,    0.08388926,
          0.08306667], dtype=float32)

In [None]:
# Persist the fitted MinMax scaler next to the model; it is needed later to map
# model outputs back to the original label scale
with open(scalerpath, mode="wb") as scaler_file:
    pickle.dump(scaler, scaler_file)