In this notebook, I train a Keras neural network to predict the tendency of precipitable water using data from the tropics only. My hope is that the network can learn to conserve water reasonably well. My goal here is to overfit the data using the neural network.

In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

In [None]:
ds = xr.open_dataset('../data/processed/training.nc')

# Normalization statistics are taken from the full dataset, before it is
# subset below. `mean` averages over time and x; `sig` is one third of the
# largest absolute deviation from that mean over x, time and z.
mean = ds.mean(['time', 'x'])
sig = np.abs(ds-mean).max(['x', 'time', 'z'])/3

# Keep only a tiny subset (one x index, two y indices, the first 20 time
# steps) for the overfitting experiment.
ds = ds.isel(x=[0, ], y=[30, 5], time=slice(0,20))

# compute Q1 and Q2
# Each source term is the finite-difference tendency minus the forcing
# averaged over the same time interval. `label='lower'` stamps the difference
# x[i+1] - x[i] at time i, which matches the forcing average
# (F[i] + F[i+1])/2 produced by `shift(time=-1)`. The default
# (`label='upper'`) would pair each tendency with the forcing of the *next*
# interval.
# NOTE(review): the factor 86400 presumably converts a per-second forcing to
# per-day to match a time coordinate in days -- confirm against the data.
time = ds.time
dt = time[1] - time[0]
ds['Q2'] = ds.QT.diff('time', label='lower')/dt - 86400*(ds.FQT + ds.FQT.shift(time=-1))/2
ds['Q1'] = ds.SLI.diff('time', label='lower')/dt - 86400*(ds.FSLI + ds.FSLI.shift(time=-1))/2

# Mass-weighted vertical sum of Q2.
ds['CQ2'] = (ds.Q2 * ds.layer_mass).sum('z')/1e4

# drop the time steps where the differencing/shifting left missing values
ds = ds.dropna('time')

In [None]:
# Plot Q1 against time for the first selected x and y index.
ds['Q1'].isel(x=0, y=0).plot(x='time')

In [None]:
# Plot Q2 against time for the first selected x and y index.
ds['Q2'].isel(x=0, y=0).plot(x='time')

In [None]:
# Plot the vertically summed Q2 (CQ2) against time for the first selected
# x and y index.
ds['CQ2'].isel(x=0, y=0).plot(x='time')

In [None]:
from itertools import product
import random

def grouplen(sequence, chunk_size):
    """Split ``sequence`` into consecutive tuples of exactly ``chunk_size`` items.

    Trailing items that do not fill a complete chunk are discarded.

    Returns a list of tuples.
    """
    # Every argument handed to zip is the *same* iterator, so each output
    # tuple consumes `chunk_size` consecutive items from it.
    shared = iter(sequence)
    return list(zip(*(shared for _ in range(chunk_size))))


def index_and_stack(inds, x):
    """Gather ``x[ind]`` for every ``ind`` in ``inds`` and stack the results.

    Parameters
    ----------
    inds : iterable
        Index expressions accepted by ``x[...]``.
    x : array-like
        Array to index.

    Returns
    -------
    numpy.ndarray
        The selections stacked along a new leading axis, with two singleton
        axes inserted after the second axis. When every ``x[ind]`` is 1-D with
        length ``k`` the result has shape ``(len(inds), k, 1, 1)``.
    """
    # np.stack needs a real sequence: passing a generator was deprecated in
    # NumPy 1.16 and is rejected with a TypeError by current releases.
    selected = [x[ind] for ind in inds]
    return np.stack(selected)[:, :, np.newaxis, np.newaxis]


def prepare_input_output_data(ds, input_fields, output_fields, height_dim='z'):
    """Extract the requested variables of ``ds`` as plain arrays.

    Variables that lack ``height_dim`` get a singleton axis inserted at
    position -3.
    NOTE(review): this presumably puts a length-1 height axis in front of
    trailing (y, x) axes so all arrays share a (time, z, y, x) layout --
    confirm against the dataset.

    Parameters
    ----------
    ds : mapping of name -> object with ``dims`` and ``values``
    input_fields, output_fields : sequences of variable names
    height_dim : name of the vertical dimension

    Returns
    -------
    list
        ``[inputs, outputs]`` where each element is a list of arrays in the
        order of the corresponding field list.
    """
    def as_array(variable):
        data = variable.values
        if height_dim not in variable.dims:
            data = np.expand_dims(data, -3)
        return data

    prepared = []
    for names in (input_fields, output_fields):
        prepared.append([as_array(ds[name]) for name in names])
    return prepared



class batch_generator(object):
    """Endlessly cycling iterator over batches of single-column samples.

    Every (time, y, x) position of the prepared arrays is one sample; the
    second axis is kept whole. ``next()`` returns a pair
    ``(list_of_input_batches, list_of_output_batches)`` and wraps around to
    the first batch after the last one, so iteration never stops by itself.

    Parameters
    ----------
    ds : dataset passed to ``prepare_input_output_data``
    in_fields, out_fields : names of the input / output variables
    batch_size : samples per batch; ``None`` puts every sample in one batch.
        Samples that do not fill a complete batch are left out.
    shuffle : shuffle the sample order once, at construction time
    """

    def __init__(self, ds, in_fields, out_fields, batch_size=None, shuffle=True):
        inputs, outputs = prepare_input_output_data(ds, in_fields, out_fields)

        # The sample grid is taken from the first input array, which must be 4-D.
        n_time, _, n_y, n_x = inputs[0].shape

        sample_index = [
            (it, slice(None), iy, ix)
            for it, iy, ix in product(range(n_time), range(n_y), range(n_x))
        ]

        if shuffle:
            random.shuffle(sample_index)

        self.batch_indices = (
            [sample_index] if batch_size is None
            else grouplen(sample_index, batch_size)
        )
        self.ins = inputs
        self.outs = outputs

        # position of the batch that the next call to __next__ will return
        self.cur_batch = 0

    def __len__(self):
        """Number of batches in one pass over the data."""
        return len(self.batch_indices)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the current batch and advance, wrapping around at the end."""
        batch = self.batch_indices[self.cur_batch]
        self.cur_batch = (self.cur_batch + 1) % len(self)
        batch_inputs = [index_and_stack(batch, arr) for arr in self.ins]
        batch_outputs = [index_and_stack(batch, arr) for arr in self.outs]
        return batch_inputs, batch_outputs

        
def normalize_dataset(ds):
    """Center and scale ``ds``.

    The mean is taken over whichever of the 'x', 'y' and 'time' dimensions
    are present. The scale is one third of the largest absolute deviation
    from that mean, with the maximum taken over all dimensions.

    Returns the anomaly divided by the scale.
    """
    averaged_dims = [name for name in ('x', 'y', 'time') if name in ds.dims]
    anomaly = ds - ds.mean(averaged_dims)
    scale = np.abs(anomaly).max() / 3
    return anomaly / scale
    

# Names of the network input variables and of the variable it predicts.
in_fields = ['QT', 'SLI', 'LHF', 'SHF', 'SOLIN']
out_fields = ['Q2']

# Raw arrays for those fields, one list for inputs and one for outputs.
ins, outs = prepare_input_output_data(ds, in_fields, out_fields)

In [None]:
from keras.layers import *
from keras.models import Model
from keras.optimizers import absolute_import
from keras.optimizers import Adam

# Messing around with keras

In [None]:
# Toy model to try out the Keras functional API: an input with 34 entries on
# its first (non-batch) axis and two unspecified trailing axes.
i = Input(shape=(34, None, None))
# Move the length-34 axis to the end so the Dense layer acts on it.
p = Permute([2, 3, 1])(i)
o = Dense(512)(p)

mod_simple = Model(inputs=[i], outputs=[o])
mod_simple.compile(optimizer='adam', loss='mse')

# Main model

In [None]:


# Pull a single batch; `x` and `y` are read by get_model to size the network's
# input and output layers.
k, (x, y) = next(enumerate(batch_generator(ds, in_fields, out_fields)))

    
def get_model(n):
    """Build a network with one hidden layer of ``n`` relu units.

    Layer sizes come from the module-level sample batch ``x`` / ``y``, and
    layer names from ``in_fields`` / ``out_fields``. All inputs are
    concatenated along axis -3, that axis is moved last so the Dense layers
    act on it, and each output is permuted back afterwards.
    """
    # assume inputs have shape (*, z, y, x)
    inputs_keras = []
    for name, sample in zip(in_fields, x):
        inputs_keras.append(Input((sample.shape[1], None, None), name=name))
    output_sizes = [sample.shape[1] for sample in y]

    stacked = Concatenate(axis=-3)(inputs_keras)
    channels_last = Permute([2, 3, 1])(stacked)
    hidden = Dense(n, activation='relu')(channels_last)

    # One linear head per output field, sized to that field's second axis.
    outputs = []
    for name, size in zip(out_fields, output_sizes):
        outputs.append(Permute([3, 1, 2], name=name)(Dense(size)(hidden)))

    return Model(inputs_keras, outputs)


def fit_model(ds_normalized, model):
    """Compile ``model`` and train it on ``ds_normalized``.

    Uses mean-squared-error loss and Adam with a learning rate of 0.01, and
    runs 5000 silent epochs. ``batch_size=None`` makes the generator serve
    all samples as one batch.

    Returns the history object produced by ``fit_generator``.
    """
    optimizer = Adam(lr=.01)
    model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

    batches = batch_generator(ds_normalized, in_fields, out_fields, batch_size=None)
    n_batches = len(batches)
    return model.fit_generator(batches, steps_per_epoch=n_batches, epochs=5000, verbose=0)


def plot_model(ds_normalized, model):
    """Scatter the model's prediction against ``ds_normalized.Q2``.

    The prediction goes on the horizontal axis, the reference ``Q2`` on the
    vertical axis, and points are colored by the ``z`` coordinate on a
    logarithmic color scale. Draws into the current matplotlib figure.

    Parameters
    ----------
    ds_normalized : dataset holding the ``in_fields`` variables, ``Q2`` and
        the 'time', 'z' and 'y' coordinates
    model : trained model exposing ``predict``
    """
    from matplotlib.colors import LogNorm

    inputs, _ = prepare_input_output_data(ds_normalized, in_fields, out_fields)

    # Run the network once only; a second forward pass would add nothing.
    # NOTE(review): squeeze() is assumed to drop a length-1 x axis so that the
    # remaining axes are (time, z, y) -- confirm if more than one x is selected.
    predicted = model.predict(inputs).squeeze()

    dims = ['time', 'z', 'y']
    coords = {key: ds_normalized.coords[key] for key in dims}
    predicted = xr.DataArray(predicted, dims=dims, coords=coords)

    # Broadcast so prediction, truth and height share one shape for scatter.
    pred, truth, height = xr.broadcast(predicted, ds_normalized.Q2, predicted.z)
    plt.scatter(pred, truth, c=height, norm=LogNorm())
    plt.colorbar()
    
    
def fit_and_plot(ds, n=256):
    """Build a model with ``n`` hidden units, train it on ``ds`` and plot the fit.

    The scatter plot is drawn in a new figure. Returns the training history.
    """
    network = get_model(n)
    training_history = fit_model(ds, network)

    plt.figure()
    plot_model(ds, network)

    return training_history

# Fit the model

In [None]:
import holoviews as hv
hv.extension('bokeh')

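We use two normalization strategies: by latitude and global. In the by-latitude strategy, each input is normalized with the mean and scale pre-computed separately for each latitude from the full dataset; the scale is one third of the maximum absolute deviation from the mean, used here in place of a standard deviation. In the global strategy, the mean and scale are computed from the selected subset itself, without distinguishing between latitudes.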

In [None]:
# normalize by pre-computed mean/sig
ds_normalized = ds.copy()
for key in in_fields:
    ds_normalized[key] = (ds[key]-mean[key].sel(y=ds.y))/sig[key].sel(y=ds.y)

# normalize by mean accross meridional slices
ds_normalized_global = ds.copy()
for key in in_fields:
    ds_normalized_global[key] = normalize_dataset(ds[key])

In [None]:
# Train one model per (hidden-layer size, normalization) pair and store its
# loss curve under that key.
hmap = hv.HoloMap(kdims=['n', 'norm'])
for n in [16, 32, 64]:
    for label, data in [('ByLat', ds_normalized), ('Global', ds_normalized_global)]:
        history = fit_and_plot(data, n)
        hmap[(n, label)] = hv.Curve(history.history['loss'])

In [None]:
%%opts Curve[width=400, height=int(400/1.61)]
# Overlay the loss curves of the two normalizations for each n and restrict
# the y range to (0, 0.5).
hmap.overlay("norm").redim.range(y=(0,.5))

This shows that normalizing the input data by zonal means significantly improves the training procedure. For a low number of hidden nodes, it seems that the Global normalization procedure cannot overfit even this small number of training points. This suggests that it will be very challenging to sufficiently train the network on a full dataset. Therefore, we should **normalize the data by the zonal mean and scaling**. This is not a problem for now, but in the future we could perhaps make the normalization depend on the state, for instance by subtracting a moist adiabatic profile or something similar.