In [1]:
import datetime
import glob
import os
import time

import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr

In [2]:
# Switch off the Jedi backend of the IPython tab completer.
%config Completer.use_jedi = False

# Hyperparameter

In [14]:
# --- data parameters ---
# Length, in days, of the input window handed to the model for each prediction.
day_len = 40
# Every month of the projection period, both end years included.
first_year, last_year = 2015, 2100
total_months = 12 * (last_year - first_year + 1)
# A single batch holds the whole projection period, so one forward pass
# predicts all months at once.
batch_size = total_months

# --- model parameters ---
num_filters = 16

# Dataset

In [4]:
# Climate models that provide temp, prec and LAI.
data_models = ['GFDL-ESM4','IPSL-CM6A-LR','MPI-ESM1-2-HR']
# TODO: choose a model, e.g. np.random.choice(np.array(data_models))
dmodel = 'IPSL-CM6A-LR'
scenarios = ['ssp126', 'ssp370', 'ssp585']
scenario = scenarios[0]

# Near-surface air temperature: the historical run, then the scenario projection
# (until 2100), concatenated along the time dimension.
dx_temp = xr.open_mfdataset(
    'data/near_surface_air_temperature/historical/{}/*.nc'.format(dmodel)
).tas
temp_projection = xr.open_mfdataset(
    'data/near_surface_air_temperature/{}/{}/*.nc'.format(scenario, dmodel)
).tas
dx_temp_future = xr.concat((dx_temp, temp_projection), dim='time')

# Precipitation flux: same procedure as for temperature.
dx_prec = xr.open_mfdataset(
    'data/precipitation_flux/historical/{}/*.nc'.format(dmodel)
).pr
prec_projection = xr.open_mfdataset(
    'data/precipitation_flux/{}/{}/*.nc'.format(scenario, dmodel)
).pr
dx_prec_future = xr.concat((dx_prec, prec_projection), dim='time')

### Generator
A plain function would work as well. The generator slides a window systematically along the timestamp index to produce the model input for every future month.

In [6]:
def gen_future_climate():
    """Yield one input window per month, starting with January 2015.

    For each of the `total_months` consecutive months, a window is selected by
    label from the module-level `dx_temp_future` and `dx_prec_future`: it runs
    from `day_len - 1` days before the first day of the month to one day after
    it. The two selections are converted to NumPy arrays and stacked along a
    new trailing axis (temperature at index 0, precipitation at index 1).

    The function takes no arguments and reads `total_months`, `day_len`,
    `dx_temp_future` and `dx_prec_future` from the enclosing module.

    Yields:
        numpy.ndarray: the stacked window for one month. Presumably it holds
        `day_len` time steps when the data are daily and stamped mid-day --
        TODO confirm against the datasets' time coordinate.
    """
    first_year = 2015
    for month_index in range(total_months):
        # divmod on the zero-based index keeps December inside its own year:
        # index 11 -> (0, 11) -> 2015-12. Deriving the year from
        # (index + 1) // 12 would move every December one year forward.
        year_offset, month_offset = divmod(month_index, 12)
        current_timestamp = pd.Timestamp(first_year + year_offset, month_offset + 1, 1)

        window = slice(
            current_timestamp - pd.Timedelta(day_len - 1, unit='day'),
            current_timestamp + pd.Timedelta(1, unit='day'),
        )
        yield np.stack(
            (np.array(dx_temp_future.loc[window]),
             np.array(dx_prec_future.loc[window])),
            axis=-1,
        )

In [22]:
# data pipeline: wrap the monthly-window generator as a float32 dataset, then
# group its elements into batches of `batch_size`.
monthly_windows = tf.data.Dataset.from_generator(
    generator=gen_future_climate,
    output_types=tf.float32,
)
ds_future = monthly_windows.batch(batch_size)

# Model

In [9]:
class ConvLSTM(tf.keras.Model):
    """Four ConvLSTM2D + BatchNormalization stages followed by two Conv3D layers.

    Input layout:  (batch, time, lat, lon, channel)
    Output layout: (batch, lat, lon, 2)

    The recurrent stages all return full sequences; a 3x3x3 convolution then
    reduces the feature axis to 2, and a 1x1x1 convolution applied with the
    time axis moved to the channel position collapses time to a single slice.
    """

    @staticmethod
    def _make_convlstm(num_filters):
        """Build one ConvLSTM2D stage; all four stages share this configuration."""
        return tf.keras.layers.ConvLSTM2D(filters=num_filters, kernel_size=(3, 3),
                                          padding="same", return_sequences=True,
                                          activation="tanh")

    def __init__(self, num_filters):
        """Create the layers.

        Args:
            num_filters: number of filters of every ConvLSTM2D stage.
        """
        super(ConvLSTM, self).__init__()

        # NOTE(review): attribute names and creation order are kept exactly as
        # before; saved weight files are presumably matched against them --
        # confirm before renaming or reordering layers.
        self.convlstm2D_1 = self._make_convlstm(num_filters)
        self.bn_1 = tf.keras.layers.BatchNormalization()

        self.convlstm2D_2 = self._make_convlstm(num_filters)
        self.bn_2 = tf.keras.layers.BatchNormalization()

        self.convlstm2D_3 = self._make_convlstm(num_filters)
        self.bn_3 = tf.keras.layers.BatchNormalization()

        self.convlstm2D_4 = self._make_convlstm(num_filters)
        self.bn_4 = tf.keras.layers.BatchNormalization()

        # Convolve over time, lat, lon. This assumes that time steps close to
        # each other share local similarities.
        self.conv3d = tf.keras.layers.Conv3D(filters=2, kernel_size=(3, 3, 3),
                                             activation="tanh", padding="same")
        # 1x1x1 convolution; in `call` it is applied with time on the channel
        # axis, producing a weighted sum over all time steps (one time slice).
        self.bottleneck = tf.keras.layers.Conv3D(filters=1, kernel_size=1,
                                                 activation="relu", strides=1)

    def call(self, x, training=False, input_shape=None):
        """Run the forward pass.

        Args:
            x: input tensor laid out as (batch, time, lat, lon, channel).
            training: forwarded to the ConvLSTM2D and BatchNormalization
                layers. Defaults to False (inference).
            input_shape: optional static shape asserted on `x` via
                `tf.ensure_shape`. Pass it when `x` arrives without static
                shape information; when None the check is skipped.

        Returns:
            Tensor laid out as (batch, lat, lon, 2).

        Both extra arguments now have defaults. NOTE(review): this should also
        let `model.build()` / `model.summary()` run, since Keras cannot build a
        model whose `call` has required non-input arguments -- confirm against
        the installed Keras version.
        """
        # (batch, time, lat, lon, channel)
        if input_shape is not None:
            x = tf.ensure_shape(x, input_shape)

        x = self.convlstm2D_1(x, training=training)
        # (batch, time, lat, lon, num_filters)
        x = self.bn_1(x, training=training)

        x = self.convlstm2D_2(x, training=training)
        x = self.bn_2(x, training=training)

        x = self.convlstm2D_3(x, training=training)
        x = self.bn_3(x, training=training)

        x = self.convlstm2D_4(x, training=training)
        x = self.bn_4(x, training=training)
        # (batch, time, lat, lon, num_filters)
        x = self.conv3d(x)
        # (batch, time, lat, lon, 2)

        # Swap the time and feature axes so the bottleneck mixes time steps.
        x = tf.transpose(x, [0, 4, 2, 3, 1])
        # (batch, 2, lat, lon, time)
        x = self.bottleneck(x)
        # (batch, 2, lat, lon, 1)

        # Swap back and drop the now-singleton time axis.
        x = tf.transpose(x, [0, 4, 2, 3, 1])
        # (batch, 1, lat, lon, 2)
        x = tf.squeeze(x, axis=1)
        # (batch, lat, lon, 2)

        return x

In [15]:
# Static shape of one full input batch: (batch, time, lat, lon, channel).
n_lat, n_lon, n_channels = 36, 72, 2
input_shape = (batch_size, day_len, n_lat, n_lon, n_channels)

model = ConvLSTM(num_filters=num_filters)
# TODO: model.build(input_shape) followed by model.summary() doesn't work
#       (some layers aren't built?).
# TODO: The fact that we can't build the model probably results in the
#       undefined rank error.

# Deployment

In [None]:
# Restore the saved parameters of the model from disk.
weights_path = 'model_weights/Version1.hdf5'
model.load_weights(weights_path)

In [18]:
# ds_future is batched with batch_size == total_months, so the first (and only
# requested) batch is meant to hold the windows of every month.
for whole_data in ds_future.take(1):
    # Inference only: request training=False explicitly so the normalization
    # layers run in inference mode instead of leaving the mode for Keras to infer.
    predictions = model(whole_data, training=False, input_shape=input_shape)

In [None]:
# Persist the predictions for the selected scenario (np.save appends ".npy"
# when the file name has no such extension).
predictions_path = 'results/pred_version1_{}'.format(scenario)
np.save(predictions_path, predictions)

# Plot

In [None]:
# Plot settings for the whole notebook: render animations as interactive
# JavaScript widgets and hide all ticks and tick labels.
plt.rcParams.update({
    "animation.html": "jshtml",  # allow animation for jupyter
    'xtick.bottom': False,
    'xtick.labelbottom': False,
    'ytick.left': False,
    'ytick.labelleft': False,
})

In [None]:
# Borderless figure whose single axes spans the whole canvas.
fig = plt.figure()
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)

# One frame per predicted month, covering the whole projection period.
# Frames are drawn on `ax` explicitly rather than on pyplot's current axes.
# Channel 0 is presumably LAI (see the file name in the save call below) --
# TODO: change to 1 to save NPP.
frames = [
    [ax.imshow(predictions[timeindex, :, :, 0],
               cmap='gray', origin='lower', animated=True)]
    for timeindex in range(total_months)
]

# Requires `import matplotlib.animation`; the submodule is not guaranteed to be
# loaded by `import matplotlib` alone.
ani = matplotlib.animation.ArtistAnimation(fig, frames, interval=100, blit=True, repeat=True)
# ani.save('figs/pred_v1_{}_LAI.gif'.format(scenario), writer='imagemagick', fps=60)
ani
