# Train a CNN

In this notebook we will go through all the steps required to train a fully convolutional neural network. Because this takes a while and uses a lot of GPU RAM a separate command line script (`train_nn.py`) is also provided in the `src` directory.

In [1]:
import climetlab as cml

In [2]:
import tensorflow as tf
def limit_mem():
    """By default TF uses all available GPU memory. This function prevents this."""
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.compat.v1.Session(config=config)

In [3]:
limit_mem()

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



2023-02-06 13:25:42.011368: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-02-06 13:25:42.011813: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


## Create data generator

First up, we want to write our own Keras data generator. The key advantage to just feeding in numpy arrays is that we don't have to load the data twice because our intputs and outputs are the same data just offset by the lead time. Since the dataset is quite large and we might run out of CPU RAM this is important.

In [32]:
features_names = ['geopotential_500','temperature_850']
train = {f:cml.load_dataset('weatherbench-extended', f, year = ['2015']) for f in features_names}
#train = {f:cml.load_dataset('weatherbench-extended', f, year = ['2015', '2016']) for f in features_names}
valid = {f:cml.load_dataset('weatherbench-extended', f, year = ['2017']) for f in features_names}
test =  {f:cml.load_dataset('weatherbench-extended', f, year = ['2018']) for f in features_names}

# Could also use: climetlab>=0.13.2
# ds = cml.load_dataset('weatherbench-extended', features_names, year = ['2017'])
# len(ds), len(ds.sel(param='t'))

[('geopotential', 500)]
[('temperature', 850)]
[('geopotential', 500)]
[('temperature', 850)]
[('geopotential', 500)]
[('temperature', 850)]


In [33]:
bs = 32
lead_time = 6

offset = 6


def add_offset(ds, offset):
    def ds_with_offset(i):
        return ds[i + offset].to_numpy()

    return ds_with_offset


def features_to_targets(data):
    return [add_offset(f, offset) for f in data]


print(train)
print([len(v) for k, v in train.items()])


def to_tfdataset(features):
	options = [dict(normalize="min-max") for f, feature in features.items()]
	# options = [dict(normalize='min-max', feature.mean, feature.std) for f, feature in train.items()]

	# targets = {f: add_offset(feature, offset) for f, feature in features.items()}
	targets = features
	target_options = [dict(normalize="min-max") for f, feature in features.items()]
	# target_options = [dict(normalize="min-max", offset=offset) for f, feature in features.items()]

	total_length = len(list(features.values())[0]) - offset
	print(total_length)

	first_feature = list(features.values())[0]
	return first_feature.to_tfdataset(
	    features=list(features.values()),
	    targets=list(targets.values()),
	    options=options,
	    target_options=target_options,
	    total_length=total_length,
	)

tfds_train = to_tfdataset(train)
tfds_valid = to_tfdataset(valid)
tfds_test = to_tfdataset(test)


{'geopotential_500': <climetlab_weatherbench.extended.WeatherbenchExtendedCDS object at 0x2d69471c0>, 'temperature_850': <climetlab_weatherbench.extended.WeatherbenchExtendedCDS object at 0x2a4dcb550>}
[1460, 1460]
1454
1454
1454


In [34]:
for i in tfds_train.as_numpy_iterator():
	print([_.shape for _ in i])
	break

[(2, 33, 64), (2, 33, 64)]


In [35]:
tfds_train._climetlab_tf_input

<PrefetchDataset element_spec=TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)>

## Create and train model

Next up, we need to create the model architecture. Here we will use a fully connected convolutional network. Because the Earth is periodic in longitude, we want to use a periodic convolution in the lon-direction. This is not implemented in Keras, so we have to do it manually.

In [36]:
###########################

# from climetlab.ml.tf import PeriodicConv2D
from tensorflow.keras.layers import Conv2D


class PeriodicPadding2D(tf.keras.layers.Layer):
    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def call(self, inputs, **kwargs):
        if self.pad_width == 0:
            return inputs
        inputs_padded = tf.concat(
            [
                inputs[:, :, -self.pad_width :, :],
                inputs,
                inputs[:, :, : self.pad_width, :],
            ],
            axis=2,
        )
        # Zero padding in the lat direction
        inputs_padded = tf.pad(
            inputs_padded, [[0, 0], [self.pad_width, self.pad_width], [0, 0], [0, 0]]
        )
        return inputs_padded

    def get_config(self):
        config = super().get_config()
        config.update({"pad_width": self.pad_width})
        return config


class PeriodicConv2D(tf.keras.layers.Layer):
    def __init__(
        self,
        filters,
        kernel_size,
        conv_kwargs={},
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv_kwargs = conv_kwargs
        if type(kernel_size) is not int:
            assert (
                kernel_size[0] == kernel_size[1]
            ), "PeriodicConv2D only works for square kernels"
            kernel_size = kernel_size[0]
        pad_width = (kernel_size - 1) // 2
        self.padding = PeriodicPadding2D(pad_width)
        self.conv = Conv2D(filters, kernel_size, padding="valid", **conv_kwargs)


#####################

In [37]:
def build_cnn(filters, kernels, input_shape, dr=0):
    """Fully convolutional network"""
    import tensorflow.keras as keras
    from tensorflow.keras.layers import Input, Dropout,LeakyReLU
    #from climetlab.ml.tf import PeriodicConv2D

    x = input = Input(shape=input_shape)
    for f, k in zip(filters[:-1], kernels[:-1]):
        x = PeriodicConv2D(f, k)(x)
        x = LeakyReLU()(x)
        if dr > 0:
            x = Dropout(dr)(x)
    output = PeriodicConv2D(filters[-1], kernels[-1])(x)
    return keras.models.Model(input, output)


In [20]:
cnn = build_cnn([64, 2], [5, 5], (32, 64, 2))
# cnn = build_cnn([64, 64, 64, 64, 2], [5, 5, 5, 5, 5], (32, 64, 2))

In [21]:
cnn.compile(tf.keras.optimizers.Adam(1e-4), 'mse')

In [22]:
cnn.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 64, 2)]       0         
                                                                 
 periodic_conv2d_2 (Periodic  (None, 32, 64, 2)        0         
 Conv2D)                                                         
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 32, 64, 2)         0         
                                                                 
 periodic_conv2d_3 (Periodic  (None, 32, 64, 2)        0         
 Conv2D)                                                         
                                                                 
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [23]:
# Since we didn't load the full data this is only for demonstration.
cnn.fit(tfds_train, epochs=100, validation_data=tfds_valid, 
          callbacks=[tf.keras.callbacks.EarlyStopping(
                        monitor='val_loss',
                        min_delta=0,
                        patience=2,
                        verbose=1, 
                        mode='auto'
                    )]
         )

Epoch 1/100


2023-02-06 13:28:33.193564: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2023-02-06 13:28:41.368338: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/100
Epoch 3/100
Epoch 3: early stopping


<keras.callbacks.History at 0x2abb92200>

In [24]:
cnn.save_weights('cnn.h5')

In [25]:
# Load weights from externally trained model
cnn.load_weights('cnn.h5')
# cnn.load_weights('cnn_good.h5')

## Create predictions

Now that we have our model we need to create a prediction NetCDF file. This function does this. 

We can either directly predict the target lead time (e.g. 5 days) or create an iterative forecast by chaining together many e.g. 6h forecasts.

In [40]:
def create_predictions(model, dg):
    """Create predictions for non-iterative model"""
    preds = model.predict_generator(dg)
    # Unnormalize
    preds = preds * dg.std.values + dg.mean.values
    fcs = []
    lev_idx = 0
    for var, levels in dg.var_dict.items():
        if levels is None:
            fcs.append(xr.DataArray(
                preds[:, :, :, lev_idx],
                dims=['time', 'lat', 'lon'],
                coords={'time': dg.valid_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon},
                name=var
            ))
            lev_idx += 1
        else:
            nlevs = len(levels)
            fcs.append(xr.DataArray(
                preds[:, :, :, lev_idx:lev_idx+nlevs],
                dims=['time', 'lat', 'lon', 'level'],
                coords={'time': dg.valid_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon, 'level': levels},
                name=var
            ))
            lev_idx += nlevs
    return xr.merge(fcs)

In [42]:
fc = create_predictions(cnn, tfds_test._climetlab_tf_input)

  preds = model.predict_generator(dg)


AttributeError: 'PrefetchDataset' object has no attribute 'std'

In [None]:
compute_weighted_rmse(fc, valid).compute()

NameError: name 'compute_weighted_rmse' is not defined

In [43]:
def create_iterative_predictions(model, dg, max_lead_time=5*24):
    state = dg.data[:dg.n_samples]
    preds = []
    for _ in range(max_lead_time // dg.lead_time):
        state = model.predict(state)
        p = state * dg.std.values + dg.mean.values
        preds.append(p)
    preds = np.array(preds)
    
    lead_time = np.arange(dg.lead_time, max_lead_time + dg.lead_time, dg.lead_time)
    das = []; lev_idx = 0
    for var, levels in dg.var_dict.items():
        if levels is None:
            das.append(xr.DataArray(
                preds[:, :, :, :, lev_idx],
                dims=['lead_time', 'time', 'lat', 'lon'],
                coords={'lead_time': lead_time, 'time': dg.init_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon},
                name=var
            ))
            lev_idx += 1
        else:
            nlevs = len(levels)
            das.append(xr.DataArray(
                preds[:, :, :, :, lev_idx:lev_idx+nlevs],
                dims=['lead_time', 'time', 'lat', 'lon', 'level'],
                coords={'lead_time': lead_time, 'time': dg.init_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon, 'level': levels},
                name=var
            ))
            lev_idx += nlevs
    return xr.merge(das)

In [None]:
fc_iter = create_iterative_predictions(cnn, tfds_test)

: 

In [None]:
rmse = evaluate_iterative_forecast(fc_iter, valid)

: 

In [None]:
rmse.load()

: 

In [None]:
rmse.z_rmse.plot()

: 

In [None]:
rmse.t_rmse.plot()

: 

# The end