# CNN for Regression from Noisy Images to Labels

## Stage 1: Hyperparameter Grid Search on a Subset of the Data (20K Samples)

In [14]:
import json
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from sklearn.model_selection import ParameterGrid
from sklearn.preprocessing import MinMaxScaler

%run xception.ipynb

In [15]:
dataset = "../../data/data_v1_small.npz"

# Copy the noisy images and their regression targets out of the .npz archive;
# the context manager closes the underlying file once the arrays are read.
with np.load(dataset) as data:
    print("Available variables:", data.files)
    image, label = data["img"], data["label"]

Available variables: ['img', 'img_nonoise', 'label', 'psf_r', 'snr', 'sigma']


In [16]:
# Feed the noisy images in unscaled; input normalization is left to the model.
# NOTE(review): presumably handled by `adapt_model` from xception.ipynb — confirm.
# Append a trailing singleton channel axis (assumes `image` is (N, H, W) — confirm).
X = np.expand_dims(image, axis=-1)
input_shape = X.shape[1:]

# Map each label column linearly onto [-1, 1] (MinMaxScaler works per column).
# The scaler is fitted on every sample, training and validation alike; this is
# acceptable here because the ranges of the generated labels are known.
scaler = MinMaxScaler(feature_range=(-1, 1))
y = scaler.fit(label).transform(label)
output_shape = y.shape[1:]

In [None]:
# Grid over the two filter-count arguments of build_xception_model (from xception.ipynb).
# A previously used, wider grid was:
#   conv2d_num_filters=[16, 32, 64], sep_num_filters=[128, 256, 384]
params = dict(conv2d_num_filters=[8, 16, 32, 64], sep_num_filters=[8, 16, 32, 64])
model_args = list(ParameterGrid(params))
total = len(model_args)

# Each trial writes one JSON file here, and existing files are skipped so an
# interrupted search can resume. Create the directory up front; otherwise the
# first write below fails with FileNotFoundError on a fresh checkout.
trials_dir = Path("trials-v1")
trials_dir.mkdir(parents=True, exist_ok=True)

for i, args in enumerate(model_args, start=1):
    cval = args["conv2d_num_filters"]
    sval = args["sep_num_filters"]
    filename = trials_dir / f"trial-c{cval}-s{sval}.json"
    if filename.exists():
        continue  # trial already finished in an earlier run

    print(f"Testing {i}/{total}: {args}")
    # Release Keras global state left over from the previous trial; building
    # many models in one process otherwise keeps accumulating memory.
    tf.keras.backend.clear_session()
    model = build_xception_model(input_shape, output_shape, num_residual_blocks=8, **args)
    model.compile(optimizer="adam", loss="mse")
    # Monitors val_loss (the default) and rolls back to the best epoch's weights.
    early_stopping = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    # NOTE(review): adapt_model comes from xception.ipynb; presumably it fits
    # input-normalization statistics on X — confirm.
    adapt_model(model, X)
    # NOTE: Keras takes the validation set from the *last* 20% of X/y, before
    # any shuffling. This assumes that tail is representative of the data —
    # TODO confirm the samples are not stored in a meaningful order.
    history = model.fit(
        X,
        y,
        batch_size=32,
        epochs=50,
        validation_split=0.2,
        verbose=2,
        callbacks=[early_stopping],
    )

    # Score = best validation MSE reached during training. Cast to a builtin
    # float so the record is plain JSON. Serialize before touching the file so a
    # failure cannot leave a truncated result that would be skipped on resume.
    doc = {**args, "score": float(np.min(history.history["val_loss"]))}
    filename.write_text(json.dumps(doc), encoding="utf-8")
    print()