# Miscellaneous helper functions for Xception CNN

In [None]:
import os
import pickle

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.python.util import nest

In [None]:
def adapt_model(model, dataset):
    """Fit the preprocessing layers that follow each model input.

    For every input of the model, the chain of preprocessing layers
    directly attached to it (e.g. Normalization()) is adapted on the
    matching part of the data.

    Args:
        model: a Keras model whose inputs may feed preprocessing layers.
        dataset: either a tf.data.Dataset yielding (x, y) pairs, or a nested
            structure that is flattened and then indexed per model input.
    Returns:
        the model that was passed in, after its layers have been adapted.
    """
    if isinstance(dataset, tf.data.Dataset):
        # Only the features are needed; drop the targets.
        features = dataset.map(lambda x, y: x)
    else:
        features = nest.flatten(dataset)
    from_tf_dataset = isinstance(features, tf.data.Dataset)

    def find_consumer(tensor):
        # First non-input layer whose first input is exactly this tensor,
        # or None when no such layer exists.
        target = nest.flatten(tensor)[0]
        candidates = (
            candidate
            for candidate in model.layers
            if not isinstance(candidate, tf.keras.layers.InputLayer)
            and nest.flatten(candidate.input)[0] is target
        )
        return next(candidates, None)

    for position, model_input in enumerate(nest.flatten(model.input)):
        if from_tf_dataset:

            def pick_input(*elements):
                return elements[position]

            input_data = features.map(pick_input)
        else:
            input_data = features[position]

        # Walk down the chain of preprocessing layers behind this input.
        # Every layer in the chain is adapted on the same input data.
        current = find_consumer(model_input)
        while isinstance(current, preprocessing.PreprocessingLayer):
            current.adapt(input_data)
            current = find_consumer(current.output)
    return model

In [None]:
# Registry of the available models, indexed by model version.
# Each entry holds a display name, a callable "X" that builds the model input
# from a loaded dataset, and the paths of the dataset, the saved model and
# the scaler.


def _single_channel(key):
    """Return a loader that appends a channel axis to ``data[key]``."""

    def extract(data):
        return data[key][..., np.newaxis]

    return extract


def _image_and_psf(data):
    """Stack the noisy image and the PSF image along a new last axis."""
    return np.stack((data["img"], data["psf_img"]), axis=-1)


xception_models = [
    {
        "name": "Noiseless images with fixed PSF",
        "X": _single_channel("img_nonoise"),
        "dataset": "../data/data_v1.npz",
        "modelpath": "../models/xception_data_v1_noiseless.tf",
        "scalerpath": "../models/xception_data_v1.scaler",
    },
    {
        "name": "Noisy images with fixed PSF and noise",
        "X": _single_channel("img"),
        "dataset": "../data/data_v1.npz",
        "modelpath": "../models/xception_data_v1.tf",
        "scalerpath": "../models/xception_data_v1.scaler",
    },
    {
        "name": "Noisy images with fixed PSF and varying noise",
        "X": _single_channel("img"),
        "dataset": "../data/data_v2.npz",
        "modelpath": "../models/xception_data_v2.tf",
        "scalerpath": "../models/xception_data_v2.scaler",
    },
    {
        "name": "Noisy images with varying PSF and noise",
        "X": _image_and_psf,
        "dataset": "../data/data_v3.npz",
        "modelpath": "../models/xception_data_v3.tf",
        "scalerpath": "../models/xception_data_v3.scaler",
    },
]

In [None]:
def load_xception_model(model_version):
    """Load a trained model together with its pickled scaler.

    Args:
        model_version: index of the entry in ``xception_models``
    Returns:
        a tuple (model, scaler)
    """
    entry = xception_models[model_version]

    # The scaler is unpickled from a local file; only use trusted files here.
    with open(entry["scalerpath"], "rb") as scaler_file:
        scaler = pickle.load(scaler_file)

    return tf.keras.models.load_model(entry["modelpath"]), scaler

In [1]:
def load_xception(model_version):
    """Load a model, its scaler and the dataset it was trained on.

    The last 10% of the samples are exposed separately as the validation set.

    Args:
        model_version: index of the entry in ``xception_models``
    Returns:
        a dictionary with the model, the scaler, the full X / y / snr arrays
        and the validation slices (keys ending in ``_val``)
    """
    settings = xception_models[model_version]
    model, scaler = load_xception_model(model_version)

    # Read everything that is needed while the archive is still open.
    with np.load(settings["dataset"]) as archive:
        features = settings["X"](archive)
        arrays = {
            key: archive[key]
            for key in ("label", "snr", "img", "img_nonoise", "psf_r", "sigma")
        }

    labels = arrays["label"]
    # The first 90% of the samples are the training part; the rest is
    # the validation set.
    split = int(labels.shape[0] * 0.9)

    bundle = {
        "model": model,
        "scaler": scaler,
        "X": features,
        "y": labels,
        "snr": arrays["snr"],
    }
    validation_sources = (
        ("X_val", features),
        ("y_val", labels),
        ("snr_val", arrays["snr"]),
        ("noisy_val", arrays["img"]),
        ("noiseless_val", arrays["img_nonoise"]),
        ("psf_r_val", arrays["psf_r"]),
        ("sigma_val", arrays["sigma"]),
    )
    for key, values in validation_sources:
        bundle[key] = values[split:]
    return bundle

In [None]:
def load_experiments():
    """Yield one training experiment per entry of ``xception_models``.

    Each yielded dictionary is a copy of the registry entry in which "X" is
    replaced by the input array built from the dataset and "label" holds the
    labels; the registry itself is left untouched.
    """
    for spec in xception_models:
        with np.load(spec["dataset"]) as archive:
            features = spec["X"](archive)
            labels = archive["label"]

        yield {**spec, "X": features, "label": labels}

In [None]:
def load_testset(snr, load_psf=False, data_dir="../data"):
    """Load a fixed-SNR test set as model input plus its single label.

    Args:
        snr: selects the test set; 30 and 60 are available
        load_psf: if True, the PSF image is broadcast to the shape of the
            images and stacked with them as a second channel; otherwise the
            images only get a trailing channel axis
        data_dir: directory that contains ``snr30.npz`` and ``snr60.npz``;
            the default is the location that used to be hard-coded
    Returns:
        a tuple (X, y) where X is the input array and y is the first label
    Raises:
        KeyError: if there is no test set for the requested ``snr``
    """
    filenames = {30: "snr30.npz", 60: "snr60.npz"}
    if snr not in filenames:
        raise KeyError(
            f"no test set for snr={snr!r}; available: {sorted(filenames)}"
        )

    with np.load(os.path.join(data_dir, filenames[snr])) as data:
        image = data["img"]
        label = data["label"]
        # Only touch the PSF image when it is actually requested, so files
        # without it still work for the single-channel case.
        psf_image = data["psf_img"] if load_psf else None

    if load_psf:
        X = np.stack((image, np.broadcast_to(psf_image, image.shape)), axis=-1)
    else:
        X = image[..., np.newaxis]

    # All labels in the test set are the same, take the first one
    y = label[0]

    return X, y