# Training DeepSD models

This notebook trains DeepSD models with training data prepared using
`deepsd_data_prep.ipynb`. It assumes the training data is already on Azure
storage, split by year and month. The model architecture and training
parameters are taken from [this KDD paper](https://arxiv.org/abs/1703.03126),
and the model training code is adapted from the
[DeepSD GitHub repo](https://github.com/tjvandal/deepsd). We
[forked the SRCNN repo implemented by the DeepSD author](https://github.com/carbonplan/srcnn-tensorflow)
and made slight modifications, described in this
[PR](https://github.com/carbonplan/srcnn-tensorflow/pull/1).

The main differences between the DeepSD model described in the paper and the
one implemented in this notebook are:

1. The batch normalization layer originally in the SRCNN model was removed.
   Instead, the input data is normalized by a 30-year historical average and
   standard deviation in the data prep notebook. The same historical average
   and standard deviation are then used to restore the downscaled GCM data, as
   a form of bias correction. Alternatively, we could first bias correct the
   GCM data against the historical observations, then normalize and restore the
   GCM data using the historical average of the GCM data. This alternative is
   not implemented because of time constraints.
2. The paper describes training the model for 10^7 iterations. In this
   notebook, the precip model was trained for 3000 iterations and the tmax/tmin
   models were trained for 1000 iterations. The number of iterations was
   reduced mainly to save time, as model training was done on CPU only at this
   point. The precip and temperature models use different iteration counts
   because, for the temperature models, the training loss and RMSE quickly
   approach 0 after training starts. My hypothesis is that this is due to the
   data normalization routine. Since there are pronounced seasonal trends in
   temperature and we use the simple 30-year historical average at each pixel
   for normalization, the "normalized" temperature data may still mostly
   exhibit the seasonal trends. Thus, the data may seem "easy" for the model to
   predict from the coarse-scale data. To test this hypothesis, trend removal
   techniques (e.g. removing a long-term moving average) can be implemented to
   make sure that the model is learning to predict the "anomalies" instead of
   the trends. A complementary hypothesis is that the fine-scale elevation data
   used as auxiliary data is very informative in disaggregating temperature
   data (since there is a well-known relationship between elevation and
   temperature). If this is the case, the model may simply be doing well and we
   can trust the low error metrics.
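
The normalize-then-restore step from the first point can be sketched as a plain
z-score round trip. This is a minimal illustration at a single pixel with
made-up numbers. The helper names `zscore` and `restore` are hypothetical and
are not taken from the data prep notebook, which decides which statistics are
actually applied to the GCM input.

```python
from statistics import mean, pstdev


def zscore(values, clim_mean, clim_std):
    """Normalize a series at one pixel with a fixed historical mean and std."""
    return [(v - clim_mean) / clim_std for v in values]


def restore(normalized, clim_mean, clim_std):
    """Map normalized (or downscaled) values back to physical units."""
    return [z * clim_std + clim_mean for z in normalized]


# toy "historical" observations at a single pixel (made-up numbers)
historical = [10.0, 12.0, 14.0, 16.0, 18.0]
obs_mean, obs_std = mean(historical), pstdev(historical)

normalized = zscore(historical, obs_mean, obs_std)
restored = restore(normalized, obs_mean, obs_std)
```

Restoring with the observational mean and standard deviation is what puts the
model output back on the scale of the observations, which is why the text
describes it as a form of bias correction.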

DeepSD models are stacked SRCNN models. As an example, a DeepSD model that
downscales data from 2 degrees to 0.25 degrees (an 8-fold increase in
resolution) is a stack of 3 SRCNN models, each with a 2-fold increase in
resolution (2^3 = 8). During training, each SRCNN model is trained
independently. Then, as the last step, we stack the trained models together to
form the joint models. Two joint models are compiled for each variable in this
notebook, one downscaling from 2 degrees to 0.25 degrees, and the other
downscaling from 1 degree to 0.25 degrees. The model to use is chosen
according to the native resolution of the GCM in question. If the GCM has a
resolution coarser than or close to 2 degrees, the GCM data is first
interpolated to 2 degrees and then downscaled. If the GCM has a resolution
closer to 1 degree, the GCM data is first interpolated to 1 degree and then
downscaled. The starting resolution of each GCM is calculated in
`deepsd_data_prep.ipynb` and manually copied to the `deepsd.py` file for use
during inference.
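
The stacking arithmetic above can be written out as a short sketch. Both
functions are hypothetical helpers for illustration only. In particular,
`pick_start_resolution` assumes a simple "closest candidate, ties go to the
coarser grid" rule, which may differ from the rule actually used in
`deepsd_data_prep.ipynb`.

```python
import math

UPSCALE_FACTOR = 2  # each SRCNN stage doubles the resolution, as in the text


def n_stages(input_res, output_res, factor=UPSCALE_FACTOR):
    """Number of stacked SRCNN stages needed to go from input_res to output_res (degrees)."""
    ratio = input_res / output_res
    stages = round(math.log(ratio, factor))
    if factor**stages != ratio:
        raise ValueError(f"{input_res} -> {output_res} is not a power of {factor}")
    return stages


def pick_start_resolution(gcm_res, candidates=(1.0, 2.0)):
    """Assumed rule: closest candidate to the native GCM resolution, ties go coarser."""
    return min(candidates, key=lambda c: (abs(c - gcm_res), -c))
```

For example, `n_stages(2.0, 0.25)` gives 3 and `n_stages(1.0, 0.25)` gives 2,
matching the two joint models compiled per variable.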

This notebook saves the model files locally and then uploads them to Azure
cloud storage. The file names are then manually copied to `deepsd.py` to be
read during inference.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import xarray as xr
import fsspec
import os
import time
import tensorflow_io

In [None]:
# reading config values

import configparser

config = configparser.ConfigParser()
config.read("/home/jovyan/cmip6-downscaling/methods/deepsd_config.ini")

# config values that match the DeepSD paper; read directly from the config file
LAYER_SIZES = [int(k) for k in config.get("SRCNN", "layer_sizes").split(",")]
KERNEL_SIZES = [int(k) for k in config.get("SRCNN", "kernel_sizes").split(",")]
OUTPUT_DEPTH = LAYER_SIZES[-1]
LEARNING_RATE = float(config.get("SRCNN", "learning_rate"))
UPSCALE_FACTOR = config.getint("DeepSD", "upscale_factor")

# how many iterations to run --> determines how long training goes
TRAINING_ITERS = (
    1000  # int(config.get('SRCNN', 'training_iters'))  # paper uses 1e7 iterations/batches
)
TEST_STEP = 50  # int(config.get('SRCNN', 'test_step'))

# values that differ from the DeepSD paper; they are set here instead of being read from the config file
INPUT_DEPTH = 2  # int(config.get('SRCNN', 'training_input_depth'))  # 2 instead of the 1 since we combined elevation into the input when saving data
BATCH_SIZE = 200  # int(config.get('SRCNN', 'batch_size'))  # 200 instead of 100
INPUT_SIZE = 51  # int(config.get('SRCNN', 'training_input_size'))  # 51 instead of 38

# where to save and get data
az_storage_account = "cmip6downscaling/"
DATA_DIR = "az://{az_storage_account}training/deepsd/{var}/{output_resolution_str}/"
SAVE_DIR = "/home/jovyan/deepsd_models/{var}_{output_resolution_str}/"

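# specify training and testing years
# note: np.arange excludes the stop value, so training covers 1981-2010 and testing
# covers 2009-2010; the test years are also part of the training years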
train_years = np.arange(1981, 2011)
test_years = np.arange(2009, 2011)
months = np.arange(1, 13)

# variables
variables = ["pr", "tasmax", "tasmin"]
output_resolutions = [0.25, 0.5, 1.0]
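
For reference, the config parsing in the cell above can be reproduced without
the real file. The sketch below is a minimal, self-contained example: the
section and key names mirror the calls in the cell, but the values in
`EXAMPLE_INI` are assumptions for illustration and are not copied from
`deepsd_config.ini`.

```python
import configparser

# illustrative config text; the values are assumptions, not the project's settings
EXAMPLE_INI = """
[SRCNN]
layer_sizes = 64,32,1
kernel_sizes = 9,1,5
learning_rate = 0.001

[DeepSD]
upscale_factor = 2
"""

example_config = configparser.ConfigParser()
example_config.read_string(EXAMPLE_INI)

# comma-separated strings become lists of ints
layer_sizes = [int(k) for k in example_config.get("SRCNN", "layer_sizes").split(",")]
kernel_sizes = [int(k) for k in example_config.get("SRCNN", "kernel_sizes").split(",")]
learning_rate = example_config.getfloat("SRCNN", "learning_rate")
upscale_factor = example_config.getint("DeepSD", "upscale_factor")

# the last layer size is the number of output channels
output_depth = layer_sizes[-1]
```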

In [None]:
fs = fsspec.get_filesystem_class("az")(
    account_name="carbonplan", account_key=os.environ["TF_AZURE_STORAGE_KEY"]
)

## Individual model training


In [None]:
tf.compat.v1.disable_eager_execution()

In [None]:
import sys

sys.path.append("/home/jovyan/srcnn-tensorflow")
from srcnn import srcnn

In [None]:
feature = {
    "hr_h": tf.io.FixedLenFeature([], tf.int64),
    "hr_w": tf.io.FixedLenFeature([], tf.int64),
    "hr_d": tf.io.FixedLenFeature([], tf.int64),
    "lr_h": tf.io.FixedLenFeature([], tf.int64),
    "lr_w": tf.io.FixedLenFeature([], tf.int64),
    "lr_d": tf.io.FixedLenFeature([], tf.int64),
    "label": tf.io.FixedLenFeature([], tf.string),
    "img_in": tf.io.FixedLenFeature([], tf.string),
    "lat": tf.io.FixedLenFeature([], tf.string),
    "lon": tf.io.FixedLenFeature([], tf.string),
    # TODO: this needs to be string
    "time": tf.io.FixedLenFeature([], tf.string),
}


def read_and_decode(filename_queue, input_size, input_depth, output_depth):

    reader = tf.compat.v1.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.compat.v1.parse_single_example(serialized_example, features=feature)

    label = tf.io.parse_tensor(features["label"], out_type=tf.float32)
    img_in = tf.io.parse_tensor(features["img_in"], out_type=tf.float32)
    lat = tf.io.parse_tensor(features["lat"], out_type=tf.float32)
    lon = tf.io.parse_tensor(features["lon"], out_type=tf.float32)

    img_in.set_shape([input_size, input_size, input_depth])
    label.set_shape([input_size, input_size, output_depth])

    return {"input": img_in, "label": label, "lat": lat, "lon": lon}


def get_inputs(filenames, batch_size, input_size, input_depth, output_depth):
    with tf.name_scope("input"), tf.device("/cpu:0"):
        filename_queue = tf.compat.v1.train.string_input_producer(filenames)
        data = read_and_decode(filename_queue, input_size, input_depth, output_depth)

        images, labels = tf.compat.v1.train.shuffle_batch(
            [data["input"], data["label"]],
            batch_size=batch_size,
            num_threads=8,
            capacity=2000 + 3 * batch_size,
            min_after_dequeue=1000,
            allow_smaller_final_batch=True,
        )

    return images, labels

In [None]:
srcnn.__file__

In [None]:
from cmip6_downscaling.methods.deepsd import res_to_str

In [None]:
# build one model per variable x resolution
overwrite = False

for var in ["tasmax", "tasmin"]:
    for output_resolution in [0.25]:
        output_resolution_str = res_to_str(output_resolution)
        data_dir = DATA_DIR.format(
            az_storage_account=az_storage_account,
            var=var,
            output_resolution_str=output_resolution_str,
        )
        save_dir = SAVE_DIR.format(var=var, output_resolution_str=output_resolution_str)
        final_save_path = os.path.join(save_dir, "srcnn.ckpt")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        if os.path.exists(final_save_path + ".meta") and not overwrite:
            print(f"output already exists at {final_save_path}, skipping")
            continue

        print(f"starting model train for {var} {output_resolution}")
        with tf.Graph().as_default(), tf.device("/cpu:0"):
            # read inputs
            train_files = [
                f"{data_dir}{year}-{month:02d}-zscore.tfrecords"
                for year in train_years
                for month in months
            ]
            test_files = [
                f"{data_dir}{year}-{month:02d}-zscore.tfrecords"
                for year in test_years
                for month in months
            ]
            train_images, train_labels = get_inputs(
                filenames=train_files,
                batch_size=BATCH_SIZE,
                input_size=INPUT_SIZE,
                input_depth=INPUT_DEPTH,
                output_depth=OUTPUT_DEPTH,
            )
            test_images, test_labels = get_inputs(
                filenames=test_files,
                batch_size=BATCH_SIZE,
                input_size=INPUT_SIZE,
                input_depth=INPUT_DEPTH,
                output_depth=OUTPUT_DEPTH,
            )

            # crop the training labels
            # the labels currently have the same spatial dimensions as the input image,
            # but due to the convolution process, the output will be smaller than the input
            border_size = int((sum(KERNEL_SIZES) - len(KERNEL_SIZES)) / 2)
            train_labels_cropped = train_labels[
                :, border_size:-border_size, border_size:-border_size, :
            ]
            # the test labels are not cropped because within the srcnn code testing output is padded when `is_training` is false

            # instantiate the input x and y pipeline
            is_training = tf.compat.v1.placeholder_with_default(True, (), name="is_training")
            x = tf.cond(
                pred=is_training,
                true_fn=lambda: train_images,
                false_fn=lambda: test_images,
            )
            y = tf.cond(
                pred=is_training,
                true_fn=lambda: train_labels_cropped,
                false_fn=lambda: test_labels,
            )
            x = tf.identity(x, name="x")
            y = tf.identity(y, name="y")

            # instantiate the srcnn model
            model = srcnn.SRCNN(
                x,
                y,
                LAYER_SIZES,
                KERNEL_SIZES,
                is_training=is_training,
                learning_rate=LEARNING_RATE,
                device="/cpu:0",  # this is the line that needs to change if we want to run on gpu
            )
            prediction = tf.identity(model.prediction, name="prediction")

            # initialize graph and start session
            init_op = tf.group(
                tf.compat.v1.global_variables_initializer(),
                tf.compat.v1.local_variables_initializer(),
            )

            sess = tf.compat.v1.Session(
                config=tf.compat.v1.ConfigProto(
                    allow_soft_placement=True, log_device_placement=False
                )
            )
            saver = tf.compat.v1.train.Saver()

            sess.run(init_op)

            # start coordinator for data
            coord = tf.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)

            # summary data
            summary_op = tf.compat.v1.summary.merge_all()
            train_writer = tf.compat.v1.summary.FileWriter(save_dir + "/train", sess.graph)
            test_writer = tf.compat.v1.summary.FileWriter(save_dir + "/test", sess.graph)

            # run through model training
            for step in range(TRAINING_ITERS):
                start_time = time.time()
                _, train_loss, train_rmse, x_mean, pred_mean = sess.run(
                    [
                        model.opt,
                        model.loss,
                        model.rmse,
                        model.x_mean,
                        model.pred_mean,
                    ],
                    feed_dict={is_training: True},
                )
                duration = time.time() - start_time

                # print(f'step {step}: train loss = {train_loss:2.5f}, train rmse = {train_rmse:2.5f}, x mean = {x_mean:2.5f}, pred mean = {pred_mean:2.5f}')

                if step % TEST_STEP == 0:
                    train_summary = sess.run(summary_op, feed_dict={is_training: True})
                    train_writer.add_summary(train_summary, step)
                    test_loss, test_rmse, test_summary = sess.run(
                        [model.loss, model.rmse, summary_op],
                        feed_dict={is_training: False},
                    )
                    test_writer.add_summary(test_summary, step)
                    print(
                        "Step: %d, Examples/sec: %0.5f, Training Loss: %2.5f, Train RMSE: %2.5f, Test RMSE: %2.5f"
                        % (
                            step,
                            BATCH_SIZE / duration,
                            train_loss,
                            train_rmse,
                            test_rmse,
                        )
                    )

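            save_path = saver.save(sess, final_save_path)

            # stop the input queue threads and release the session before the next model
            coord.request_stop()
            coord.join(threads)
            sess.close()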

In [None]:
print("done")
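
The label cropping in the training cell can be checked with a little
arithmetic. The sketch below assumes stride-1 convolutions without padding and
odd kernel sizes, so each layer with kernel size `k` removes `k - 1` pixels per
spatial dimension. The kernel sizes used here are assumed values for
illustration, and both function names are hypothetical.

```python
def valid_conv_output_size(input_size, kernel_sizes):
    """Spatial size after a stack of stride-1 convolutions without padding."""
    size = input_size
    for k in kernel_sizes:
        size -= k - 1
    return size


def border_size(kernel_sizes):
    """Pixels lost on each side, using the same formula as the training cell."""
    return (sum(kernel_sizes) - len(kernel_sizes)) // 2


kernel_sizes = [9, 1, 5]  # assumed values, for illustration only
input_size = 51

border = border_size(kernel_sizes)
cropped_label_size = input_size - 2 * border
```

Under these assumptions the cropped label size equals the unpadded model
output size, which is why only the training labels need cropping.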

## Merge three models together


In [None]:
# these functions are adapted from deepsd inference code with minor edits
# edits mostly involve removing code related to batch normalization and making the code tf2 compatible

from tensorflow.python.framework import graph_util


def freeze_graph(model_folder):
    # We start a session and restore the graph weights
    with tf.compat.v1.Session() as sess:
        # We retrieve our checkpoint fullpath
        checkpoint = tf.train.get_checkpoint_state(model_folder)
        input_checkpoint = checkpoint.model_checkpoint_path

        # We specify the full file name of our frozen graph
        output_graph = model_folder + "/frozen_model.pb"
        if os.path.exists(output_graph):
            os.remove(output_graph)

        # Before exporting our graph, we need to specify our output node
        # This is how TF decides which part of the graph to keep and which part to discard
        # NOTE: this variable is plural, because you can have multiple output nodes
        output_node_names = "prediction"

        # We clear devices to allow TensorFlow to control on which device it will load operations
        clear_devices = True

        # We import the meta graph and retrieve a Saver
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + ".meta", clear_devices=clear_devices
        )

        # We retrieve the protobuf graph definition
        graph = tf.compat.v1.get_default_graph()
        input_graph_def = graph.as_graph_def()
        saver.restore(sess, input_checkpoint)
        gd = sess.graph.as_graph_def()

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            gd,  # The graph_def is used to retrieve the nodes
            output_node_names.split(","),  # The output node names select the useful nodes
            variable_names_blacklist=[],
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.io.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


def load_graph(frozen_graph_filename, graph_name, x):
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

        # Then, we can use again a convenient built-in function to import a graph_def into the
        # current default Graph
        is_training = tf.constant(False)
        (y,) = tf.import_graph_def(
            graph_def,
            input_map={"x": x, "is_training": is_training},
            return_elements=["prediction:0"],
            name=graph_name,
            op_dict=None,
            producer_op_list=None,
        )
    return y


def join_graphs(checkpoints, new_checkpoint):
    """
    placeholders:
        low-resolution ppt
        elevation for each checkpoint

    x = concat([ppt, elev_1])
    for each checkpoint:
        x -> y
        x = concat([y, elev_i])
    return y
    """
    # begin by freezing each graph independently
    for cpt in checkpoints:
        # freeze current graph
        freeze_graph(cpt)
        tf.compat.v1.reset_default_graph()

    x = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None, 1), name="lr_x")
    elevs = []
    for j, cpt in enumerate(checkpoints):
        # another elevation placeholder
        elv = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None, 1), name="elev_%i" % j)
        elevs.append(elv)

        # resize low-resolution
        h = tf.shape(input=x)[1]
        w = tf.shape(input=x)[2]
        size = tf.stack([h * UPSCALE_FACTOR, w * UPSCALE_FACTOR])
        x = tf.image.resize(x, size, method=tf.image.ResizeMethod.BILINEAR)

        # join elevation and interpolated image
        x = tf.concat([x, elv], axis=3)
        graph_name = os.path.basename(cpt.strip("/"))

        # load frozen graph with x as the input
        next_input = graph_name + "/x"
        x = load_graph(os.path.join(cpt, "frozen_model.pb"), graph_name, x=x)

    with tf.compat.v1.Session() as sess:
        summary_op = tf.compat.v1.summary.merge_all()
        train_writer = tf.compat.v1.summary.FileWriter(new_checkpoint, sess.graph)
        train_writer.add_graph(tf.compat.v1.get_default_graph())

        gd = sess.graph.as_graph_def()
        output_graph = os.path.join(new_checkpoint, "frozen_graph.pb")
        with tf.io.gfile.GFile(output_graph, "wb") as f:
            f.write(gd.SerializeToString())
        print("%d ops in the final graph." % len(gd.node))

    tf.compat.v1.reset_default_graph()
    return output_graph, x.name
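
The data flow in `join_graphs` can be followed by tracking tensor shapes alone.
The sketch below is a hypothetical helper, not part of the DeepSD code. It
assumes each frozen SRCNN returns a single channel at the same spatial size as
its input, which relies on the padding applied when `is_training` is false.

```python
def trace_stacked_shapes(lr_shape, n_stages, upscale_factor=2):
    """Follow (height, width, channels) through each stage of the stacked model."""
    h, w, c = lr_shape
    trace = []
    for stage in range(n_stages):
        # bilinear resize to the next finer grid
        h, w = h * upscale_factor, w * upscale_factor
        # concatenate the elevation channel for this stage
        model_input = (h, w, c + 1)
        # assumed: the frozen SRCNN keeps the spatial size and returns one channel
        model_output = (h, w, 1)
        trace.append({"stage": stage, "input": model_input, "output": model_output})
        c = 1
    return trace


trace = trace_stacked_shapes((10, 20, 1), n_stages=3)
```

With three stages and an upscale factor of 2, a 10 x 20 low-resolution grid
ends up as an 80 x 160 output, the 8-fold increase described at the top of the
notebook.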

In [None]:
# save two stacked models per variable
# one of the stacked models will consist of three layers 2.0 -> 1.0 -> 0.5 -> 0.25
# the other will consist of two layers 1.0 -> 0.5 -> 0.25
# each one will be used according to the starting resolution of the GCM model

resolutions = [
    [0.25, 0.5, 1.0],
    [0.25, 0.5],
]

output_model_files = []
output_node_names = []
for var in ["tasmax", "tasmin"]:
    for resolution in resolutions:
        resolution = sorted(
            resolution, reverse=True
        )  # make sure this is from coarse resolution to fine resolution

        # get the model names and checkpoint files in a list, also sorted from coarse resolution to fine resolution
        model_sections = [(f"{var}_{res_to_str(res)}", res) for res in resolution]
        CHECKPOINTS = [
            f"{SAVE_DIR.format(var=var, output_resolution_str=res_to_str(res))}"
            for res in resolution
        ]

        # create the joint model name indicating the coarsening factors
        input_res = res_to_str(UPSCALE_FACTOR * np.max(resolution))
        output_res = res_to_str(np.min(resolution))
        JOINED_RESOLUTION_STR = f"{input_res}d_to_{output_res}d"

        joined_checkpoint = SAVE_DIR.format(var=var, output_resolution_str=JOINED_RESOLUTION_STR)
        if not os.path.exists(joined_checkpoint):
            os.makedirs(joined_checkpoint)

        new_graph_path, output_node_name = join_graphs(CHECKPOINTS, joined_checkpoint)
        output_model_files.append(new_graph_path)
        output_node_names.append(output_node_name)

out = pd.DataFrame({"model_file": output_model_files, "output_node": output_node_names})

In [None]:
out

In [None]:
# upload models to cloud storage
# these paths will then be used in `deepsd_inference.ipynb`

for i, row in out.iterrows():
    local_model_path = row.model_file
    remote_model_path = local_model_path.replace("/home/jovyan", "az://training/deepsd")
    print(remote_model_path, row.output_node)
    fs.put_file(lpath=local_model_path, rpath=remote_model_path, overwrite=True)