In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.callbacks import TQDMProgressBar

%run ../gridsearch/xception.ipynb

In [0]:
# Load the denoised log-images and their labels from the NPZ archive.
basedir = Path("../../")
dataset = basedir / "./data/data_v1.npz"
modelpath = "xception_log.tf"  # destination of the best model saved by ModelCheckpoint

with np.load(dataset) as data:
    print("Available variables:", data.files)
    # NOTE(review): the saved output of this cell does not list "img_log_denoised",
    # so the archive on disk may predate that array. Fail early, naming the file
    # and the arrays it does contain, instead of a bare NpzFile KeyError.
    missing = [key for key in ("img_log_denoised", "label") if key not in data.files]
    if missing:
        raise KeyError(f"{dataset} is missing arrays {missing}; available: {data.files}")
    image = data["img_log_denoised"]
    label = data["label"]

Available variables: ['img', 'img_nonoise', 'label', 'psf_r', 'snr', 'sigma']


In [0]:
# Inputs: keep the raw (unscaled) images and append a single channel axis.
# Normalization is left to the model's own normalization layer.
X = np.expand_dims(image, axis=-1)
input_shape = X.shape[1:]

# Targets: map every label column onto [-1, 1]. Fitting the scaler on the full
# dataset is deliberate, since the ranges of the generated labels are known.
scaler = MinMaxScaler(feature_range=(-1, 1))
y = scaler.fit_transform(label)
output_shape = y.shape[1:]

In [0]:
# Seed TensorFlow's global RNG before any weights are created, so that weight
# initialization is repeatable across Restart & Run All. GPU kernels may still
# introduce small run-to-run differences.
SEED = 42
tf.random.set_seed(SEED)

# Build the best model. SeparableConv2D layers form Xception blocks.
# Hyperparameters presumably come from the grid search in ../gridsearch/xception.ipynb.
XCEPTION_PARAMS = dict(conv2d_num_filters=16, sep_num_filters=64, num_residual_blocks=8)
model = build_xception_model(input_shape, output_shape, **XCEPTION_PARAMS)

# Adapt the normalization layer to the (unscaled) image data.
# NOTE(review): X also contains the tail that model.fit(validation_split=...) holds
# out, so the normalization statistics see validation images as well.
adapt_model(model, X)
model.summary()

In [0]:
# Setup callbacks. TQDM replaces the default Keras progress bar, which misbehaved
# on the author's TF 2.1 installation (hence verbose=0 in model.fit below).
early_stopping = EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
tqdm_callback = TQDMProgressBar()
# Only overwrite the saved model when validation loss improves.
checkpoint = ModelCheckpoint(modelpath, monitor="val_loss", save_best_only=True)

# Train the Xception model. Keras holds out the *last* 10% of the samples
# (selected before shuffling) as validation data; the evaluation cell relies on that.
model.compile(optimizer=Adam(0.0001), loss="mse")
history = model.fit(
    X,
    y,
    batch_size=32,
    epochs=100,
    validation_split=0.1,
    verbose=0,
    callbacks=[early_stopping, tqdm_callback, checkpoint],
)

# EarlyStopping restores the best weights only when it actually stops training.
# If all epochs run without triggering it, `model` would keep the last-epoch weights
# while the checkpoint holds the best ones, so restore them explicitly.
# (Re-applying them after an early stop is a no-op.)
best_weights = getattr(early_stopping, "best_weights", None)
if best_weights is not None:
    model.set_weights(best_weights)

In [0]:
# Evaluate the model on the validation set. Keras' validation_split holds out the
# last samples, splitting at floor(n * (1 - validation_split)); mirror that exactly.
VALIDATION_SPLIT = 0.1  # must equal the validation_split passed to model.fit
n_train = int(np.floor(len(X) * (1.0 - VALIDATION_SPLIT)))
X_val = X[n_train:]
label_val = label[n_train:]

# Predict in the scaled [-1, 1] space, then convert back to the original label scale.
predictions = scaler.inverse_transform(model.predict(X_val))
# Guard against silent broadcasting in the RMSE below (e.g. (n,) vs (n, 1)).
assert predictions.shape == label_val.shape, (predictions.shape, label_val.shape)

# Compute RMSE for each label individually, in original units.
# NOTE: this same split drove early stopping and checkpoint selection, so the
# estimate is somewhat optimistic compared to a truly held-out test set.
np.set_printoptions(suppress=True)
rmse = np.sqrt(np.mean((label_val - predictions) ** 2, axis=0))
rmse