# Train a CNN

In this notebook we will go through all the steps required to train a fully convolutional neural network. Because this takes a while and uses a lot of GPU RAM a separate command line script (`train_nn.py`) is also provided in the `src` directory.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Depending on your combination of package versions, this can raise a lot of TF warnings... 
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import seaborn as sns
import pickle
from src.score import *
from collections import OrderedDict

In [3]:
tf.__version__

'2.2.0'

In [4]:
keras.__version__

'2.3.0-tf'

In [5]:
def limit_mem():
    """By default TF uses all available GPU memory. This function prevents this."""
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.compat.v1.Session(config=config)

In [6]:
limit_mem()

In [7]:
sns.set_style('darkgrid')
sns.set_context('notebook')

In [8]:
DATADIR = '/rds/general/user/mc4117/home/WeatherBench/data/'

## Create data generator

First up, we want to write our own Keras data generator. The key advantage to just feeding in numpy arrays is that we don't have to load the data twice because our intputs and outputs are the same data just offset by the lead time. Since the dataset is quite large and we might run out of CPU RAM this is important.

In [9]:
# Load the validation subset of the data: 2017 and 2018
z500_valid = load_test_data(f'{DATADIR}geopotential_500', 'z')
t850_valid = load_test_data(f'{DATADIR}temperature_850', 't')
valid = xr.merge([z500_valid, t850_valid])

In [10]:
z = xr.open_mfdataset(f'{DATADIR}geopotential_500/*.nc', combine='by_coords')
t = xr.open_mfdataset(f'{DATADIR}temperature_850/*.nc', combine='by_coords').drop('level')

In [11]:
# For the data generator all variables have to be merged into a single dataset.
datasets = [z, t]
ds = xr.merge(datasets)

In [12]:
# In this notebook let's only load a subset of the training data

ds_2015 = ds.sel(time = '2015-06')
ds_2016 = ds.sel(time = '2016-06')
ds_2017 = ds.sel(time = '2017-06')
ds_2018 = ds.sel(time = '2018-06')

ds_train = xr.merge([ds_2015, ds_2016])
ds_test = xr.merge([ds_2017, ds_2018])

In [13]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True, mean=None, std=None):
        """
        Data generator for WeatherBench data.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
        Args:
            ds: Dataset containing all variables
            var_dict: Dictionary of the form {'var': level}. Use None for level if data is of single level
            lead_time: Lead time in hours
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, datadet is loaded into RAM.
            mean: If None, compute mean from data.
            std: If None, compute standard deviation from data.
        """
        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time

        data = []
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for var, levels in var_dict.items():
            try:
                data.append(ds[var].sel(level=levels))
            except ValueError:
                data.append(ds[var].expand_dims({'level': generic_level}, 1))

        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        self.mean = self.data.mean(('time', 'lat', 'lon')).compute() if mean is None else mean
        self.std = self.data.std('time').mean(('lat', 'lon')).compute() if std is None else std
        # Normalize
        self.data = (self.data - self.mean) / self.std
        self.n_samples = self.data.isel(time=slice(0, -lead_time)).shape[0]
        self.init_time = self.data.isel(time=slice(None, -lead_time)).time
        self.valid_time = self.data.isel(time=slice(lead_time, None)).time
        
        self.bins_z = np.linspace(ds.z.min(), ds.z.max(), 100)

        self.binned_data = xr.Dataset({'z': xr.DataArray(
               np.digitize(ds.z, self.bins_z)-1,
               dims=['time', 'lat', 'lon'],
               coords={'time':self.data.time.values, 'lat': self.data.lat.values, 'lon': self.data.lon.values
               },
               )})

        del ds
        self.on_epoch_end()

        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load: print('Loading data into RAM'); self.data.load()
        if load: print('Loading data into RAM'); self.binned_data.z.load()            

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        X = self.data.isel(time=idxs).values
        y = self.binned_data.z.isel(time=idxs + self.lead_time).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle == True:
            np.random.shuffle(self.idxs)

In [14]:
# then we need a dictionary for all the variables and levels we want to extract from the dataset
dic = OrderedDict({'z': None, 't': None})

In [15]:
bs=32
lead_time=72

In [16]:
# Create a training and validation data generator. Use the train mean and std for validation as well.
dg_train = DataGenerator(
    ds_train.sel(time=slice('2015', '2015')), dic, lead_time, batch_size=bs, load=True)

dg_valid = DataGenerator(
    ds_train.sel(time=slice('2016', '2016')), dic, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, shuffle=False)

dg_test = DataGenerator(
    ds_test, dic, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, shuffle=False)

Loading data into RAM
Loading data into RAM
Loading data into RAM
Loading data into RAM
Loading data into RAM
Loading data into RAM


## Create and train model

Next up, we need to create the model architecture. Here we will use a fully connected convolutional network. Because the Earth is periodic in longitude, we want to use a periodic convolution in the lon-direction. This is not implemented in Keras, so we have to do it manually.

In [17]:
class PeriodicPadding2D(tf.keras.layers.Layer):
    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def call(self, inputs, **kwargs):
        if self.pad_width == 0:
            return inputs
        inputs_padded = tf.concat(
            [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2)
        # Zero padding in the lat direction
        inputs_padded = tf.pad(inputs_padded, [[0, 0], [self.pad_width, self.pad_width], [0, 0], [0, 0]])
        return inputs_padded

    def get_config(self):
        config = super().get_config()
        config.update({'pad_width': self.pad_width})
        return config


class PeriodicConv2D(tf.keras.layers.Layer):
    def __init__(self, filters,
                 kernel_size,
                 conv_kwargs={},
                 **kwargs, ):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv_kwargs = conv_kwargs
        if type(kernel_size) is not int:
            assert kernel_size[0] == kernel_size[1], 'PeriodicConv2D only works for square kernels'
            kernel_size = kernel_size[0]
        pad_width = (kernel_size - 1) // 2
        self.padding = PeriodicPadding2D(pad_width)
        self.conv = Conv2D(
            filters, kernel_size, padding='valid', **conv_kwargs
        )

    def call(self, inputs):
        return self.conv(self.padding(inputs))

    def get_config(self):
        config = super().get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size, 'conv_kwargs': self.conv_kwargs})
        return config

In [22]:
block_no = 1

def convblock(inputs, f, k, l2, dr = 0):
    x = inputs
    if l2 is not None:
        x = PeriodicConv2D(f, k, conv_kwargs={
            'kernel_regularizer': keras.regularizers.l2(l2)})(x) 
    else:
        x = PeriodicConv2D(f, k)(x)
    x = LeakyReLU()(x)
    #x = BatchNormalization()(x)
    #if dr>0: x = Dropout(dr)(x, training = True)

    return x

def build_resnet_cnn(filters, kernels, input_shape, l2 = None, dr = 0, skip = True):
    """Fully convolutional residual network"""

    x = input = keras.layers.Input(shape=input_shape)
    x = convblock(x, filters[0], kernels[0], dr)

    #Residual blocks
    for f, k in zip(filters[1:-1], kernels[1:-1]):
        y = x
        for _ in range(2):
            x = convblock(x, f, k, l2, dr)
        #if skip: x = Add()([y, x])

    #out = Activation('softmax')(x)
    output = PeriodicConv2D(filters[-1], kernels[-1])(out)
   
    return keras.models.Model(input, output)

filt = [64]
kern = [5]

for i in range(int(block_no)-1):
    filt.append(64)
    kern.append(5)

filt.append(100)
kern.append(5)

filt.append(1)
kern.append(5)

cnn = build_resnet_cnn(filt, kern, (32, 64, 2), l2 = 1e-5)
print(cnn.summary())

NameError: name 'out' is not defined

In [None]:
def my_loss(yTrue,yPred):
    #if len(yPred.shape) == 4:
    #    yPred = tf.squeeze(yPred, axis = 3)
    #if len(yTrue.shape) == 4:
    #    yTrue = tf.squeeze(yTrue, axis = 3)
    #yTrue_reshape = tf.reshape(yTrue, shape = [-1])
    #yPred_reshape = tf.reshape(yPred, shape = [-1])
    #time, nx, ny = yPred.shape
    #nx_total = nx*ny
    #yTrue_reshape = tf.reshape(yTrue, [-1]) #[tf.shape(yPred)[0]])
    #yPred_reshape = tf.reshape(yPred, [-1]) #[tf.shape(yPred)[0]])
    yTrue_reshape = tf.keras.backend.flatten(yTrue)
    yPred_reshape = tf.keras.backend.flatten(yPred)    
    print(yTrue_reshape)
    print(yPred_reshape)
    loss = tf.keras.losses.sparse_categorical_crossentropy(yTrue_reshape, yPred_reshape)
    print(loss)
    #return loss
    return K.mean(tf.dtypes.cast(yTrue_reshape, tf.float32) - tf.dtypes.cast(yPred_reshape, tf.float32))
    #return K.mean(K.equal(tf.dtypes.cast(yTrue_reshape, tf.int32), tf.dtypes.cast(yPred_reshape, tf.int32)))

def sparse_cross_entropy(y_true, y_pred):
    y_true = tf.squeeze(y_true, axis = -1)    
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.dtypes.cast(y_true, tf.int32),
                                                          logits=y_pred)
    loss_mean = tf.reduce_mean(loss)
    return loss_mean    
    
    
def custom_sparse_categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.max(tf.dtypes.cast(y_true, tf.float32), axis=-1),
                          K.cast(K.argmax(tf.dtypes.cast(y_pred, tf.float32), axis=-1), K.floatx())),
                  K.floatx())
    
cnn.compile(keras.optimizers.Adam(5e-5), loss = 'mse') #my_loss) #loss=sparse_cross_entropy, metrics = [custom_sparse_categorical_accuracy])

print(cnn.summary())

early_stopping_callback = tf.keras.callbacks.EarlyStopping(
                        monitor='val_loss',
                        min_delta=0,
                        patience=5,
                        verbose=1, 
                        mode='auto'
                    )

reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
            monitor = 'val_loss',
            patience=2,
            factor=0.2,
            verbose=1)


In [23]:
cnn.fit(dg_train, epochs=20, validation_data=dg_valid, callbacks=[early_stopping_callback, reduce_lr_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x2b0c260fdcc0>

In [None]:
# Load weights from externally trained model
cnn.load_weights('/rds/general/user/mc4117/home/WeatherBench/saved_models/whole_train_72.h5')

## Create predictions

Now that we have our model we need to create a prediction NetCDF file. This function does this. 

We can either directly predict the target lead time (e.g. 5 days) or create an iterative forecast by chaining together many e.g. 6h forecasts.

In [26]:
fc = cnn.predict(dg_test)

In [27]:
fc

array([[[[1.3860289],
         [1.3860289],
         [1.3860289],
         ...,
         [1.3860289],
         [1.3860289],
         [1.3860289]],

        [[1.7386955],
         [1.7386955],
         [1.7386955],
         ...,
         [1.7386955],
         [1.7386955],
         [1.7386955]],

        [[2.1212044],
         [2.1212044],
         [2.1212044],
         ...,
         [2.1212044],
         [2.1212044],
         [2.1212044]],

        ...,

        [[2.1209831],
         [2.1211045],
         [2.1211774],
         ...,
         [2.1208968],
         [2.1208727],
         [2.1208944]],

        [[1.6720175],
         [1.672153 ],
         [1.6722298],
         ...,
         [1.6719476],
         [1.6719065],
         [1.6719164]],

        [[1.1996075],
         [1.1997206],
         [1.1997862],
         ...,
         [1.1995715],
         [1.1995174],
         [1.1995227]]],


       [[[1.3860289],
         [1.3860289],
         [1.3860289],
         ...,
         [1.3860

In [92]:
fc_conv = np.zeros(fc[:, :, :, 0].shape)
for i in range(fc.shape[0]):
    for j in range(fc.shape[1]):
        for k in range(fc.shape[2]):
            fc_conv[i][j][k] = dg_train.bins_z[np.int(np.round(fc[i][j][k][0], 0))]

In [97]:
X, y1 = dg_test[0]


for i in range(1, len(dg_test)):
    X2, y2 = dg_test[i]
    y1 = np.concatenate((y1, y2))
    
real_ds = xr.Dataset({
    'z': xr.DataArray(
        y1,
        dims=['time', 'lat', 'lon'],
        coords={'time':dg_test.data.time[72:], 'lat': dg_test.data.lat, 'lon': dg_test.data.lon,
                })})

In [95]:
fc_conv_ds = xr.Dataset({
    'z': xr.DataArray(
        fc_conv,
        dims=['time', 'lat', 'lon'],
        coords={'time':dg_test.data.time[72:], 'lat': dg_test.data.lat, 'lon': dg_test.data.lon,
                })})

In [96]:
cnn_rmse = compute_weighted_rmse(fc_conv_ds, ds_test.z[72:])
cnn_rmse.compute()

In [102]:
fc_ds = xr.Dataset({
    'z': xr.DataArray(
        fc[:, :, :, 0],
        dims=['time', 'lat', 'lon'],
        coords={'time':dg_test.data.time[72:], 'lat': dg_test.data.lat, 'lon': dg_test.data.lon,
                })})

cnn_rmse = compute_weighted_rmse(fc_ds, real_ds)
cnn_rmse.compute()
