# CNN for Regression from Noiseless Images to Labels

## Stage 3: Training a Tiny Xception Model on the Full Dataset (200k Samples)

In [1]:
import json
from pprint import pprint
import numpy as np
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
import tensorflow_addons as tfa

%run xception.ipynb

In [2]:
# Load the noiseless images and their labels from the compressed archive.
dataset = "../../data/data_v1.npz"

with np.load(dataset) as data:
    # List every array stored in the archive; only two of them are used below.
    print("Available variables:", data.files)
    image, label = data["img_nonoise"], data["label"]

Available variables: ['img', 'img_nonoise', 'label', 'psf_r', 'snr', 'sigma']


In [3]:
# Images stay unscaled: the model's Normalization layer (adapted to X before
# training) takes care of normalising them.
# Append a trailing channel axis so each sample becomes (height, width, 1).
X = np.expand_dims(image, axis=-1)
input_shape = X.shape[1:]

# Map every label column onto [-1, 1]. Fitting on the full dataset is
# acceptable here because the ranges of the generated labels are known.
scaler = MinMaxScaler(feature_range=(-1, 1))
y = scaler.fit(label).transform(label)
output_shape = y.shape[1:]

In [4]:
# Build a small Xception model. SeparableConv2D layers form Xception blocks.
# Seed TensorFlow's global RNG *before* the model is created so that weight
# initialisation (and, presumably, the per-epoch shuffling done later by
# `model.fit`) is repeatable across kernel restarts. GPU kernels may still
# introduce some nondeterminism.
SEED = 42
tf.random.set_seed(SEED)

model = build_xception_model(
    input_shape, output_shape, conv2d_num_filters=16, sep_num_filters=64, num_residual_blocks=8
)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 1)]  0                                            
__________________________________________________________________________________________________
normalization (Normalization)   (None, 64, 64, 1)    3           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 16)   416         normalization[0][0]              
__________________________________________________________________________________________________
separable_conv2d (SeparableConv (None, 32, 32, 64)   1232        conv2d[0][0]                     
______________________________________________________________________________________________

In [5]:
# Configure the model: Adam optimiser, mean-squared-error loss on the scaled labels.
model.compile(optimizer="adam", loss="mse")

# Fraction of samples Keras holds out for validation. With numpy inputs Keras
# takes the *last* VALIDATION_SPLIT fraction of X/y (before any shuffling), so
# the evaluation cell must slice X[n_train:] using this same fraction.
VALIDATION_SPLIT = 0.1

# Setup callbacks. TQDM is used due to issues with the default progress bar on my TF2.1 installation.
# EarlyStopping monitors val_loss by default; restore_best_weights rolls the
# model back to its best epoch when training stops early.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
tqdm_callback = tfa.callbacks.TQDMProgressBar()
# Writes a full SavedModel each time the monitored val_loss improves.
checkpoint = tf.keras.callbacks.ModelCheckpoint("noiseless-checkpoint.tf", save_best_only=True)

# Adapt the normalization layer to the data.
# NOTE(review): this is passed all of X, including the rows later held out for
# validation -- presumably acceptable (like the label scaler); confirm.
adapt_model(model, X)

# Train on the full 200k-sample dataset: the first (1 - VALIDATION_SPLIT) of the
# rows are used for training, the last VALIDATION_SPLIT for validation.
history = model.fit(
    X,
    y,
    batch_size=32,
    epochs=100,
    validation_split=VALIDATION_SPLIT,
    verbose=0,
    callbacks=[early_stopping, tqdm_callback, checkpoint],
)

HBox(children=(FloatProgress(value=0.0, description='Training', layout=Layout(flex='2'), style=ProgressStyle(d…

Epoch 1/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 2/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 3/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 4/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 5/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 6/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 7/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 8/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 9/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 10/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 11/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 12/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 13/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 14/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 15/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 16/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 17/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 18/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 19/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 20/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 21/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 22/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


INFO:tensorflow:Assets written to: noiseless-checkpoint.tf/assets
Epoch 23/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 24/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 25/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 26/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…


Epoch 27/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5625.0), HTML(value='')), layout=Layout(d…





In [9]:
# Persist the trained model (weights as left by EarlyStopping's
# restore_best_weights) in TensorFlow's SavedModel format.
saved_model_path = "noiseless_tiny_200k.tf"
model.save(saved_model_path)

INFO:tensorflow:Assets written to: noiseless_tiny_200k.tf/assets


In [4]:
# Optional: reload the previously saved model, e.g. in a fresh kernel.
# The evaluation below still needs X, label and scaler from the data-prep cells.
model = tf.keras.models.load_model(filepath="noiseless_tiny_200k.tf")

In [6]:
# Evaluate the model on the validation set.
# Keras' `validation_split` holds out the *last* fraction of the samples, so the
# validation rows are X[n_train:]. VALIDATION_SPLIT must equal the
# validation_split passed to model.fit (0.1). The cut point is computed the same
# way Keras computes it, int(n * (1 - split)), so the two slices line up exactly.
VALIDATION_SPLIT = 0.1
n_train = int(len(X) * (1.0 - VALIDATION_SPLIT))
X_val = X[n_train:]
label_val = label[n_train:]

# Predict the (scaled) labels and map them back to the original label scale.
predictions = scaler.inverse_transform(model.predict(X_val))
assert predictions.shape == label_val.shape, "prediction/label shape mismatch"

# Compute RMSE for each label individually (one value per label column).
np.set_printoptions(suppress=True)
rmse = np.sqrt(((label_val - predictions) ** 2).mean(axis=0))
rmse

array([383.31216   ,   0.02920323,   0.00223816,   0.00404614,
         0.00360098], dtype=float32)

In [7]:
# Per-label standard deviation of the validation targets, as a yardstick for the RMSE.
np.std(label_val, axis=0)

array([48742.957     ,     1.5922561 ,     0.14061895,     0.2791248 ,
           0.27738345], dtype=float32)

In [8]:
# Relative error per label: RMSE divided by the label's std. dev.
# A ratio of 1 equals the RMSE of always predicting the validation mean,
# so values far below 1 indicate the model captures most of the variance.
np.divide(rmse, np.std(label_val, axis=0))

array([0.00786395, 0.01834079, 0.01591651, 0.0144958 , 0.01298194],
      dtype=float32)