In [None]:
import gc
import glob
import os
import time

import h5py
import keras
import mat73
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import psutil
import scipy.io
import tensorflow as tf
import tensorflow.keras.layers as layers
#from tensorflow.keras.layers import SpatialDropout3D, UpSampling3D, Dropout, RepeatVector, Average
#import tensorflow.keras.layers as layers
from keras import Input, Model
from keras import backend as K
from keras.callbacks import (Callback, EarlyStopping, ModelCheckpoint,
                             ReduceLROnPlateau)
from keras.layers import (Activation, Add, BatchNormalization, Conv2D, Conv3D,
                          Conv3DTranspose, ConvLSTM2D, Dense, Dropout, Flatten,
                          GaussianNoise, Input, Lambda, MaxPooling2D,
                          MaxPooling3D, Permute, Reshape, TimeDistributed,
                          UpSampling2D, add, concatenate)
from keras.models import Model, Sequential
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import mixed_precision
from tqdm import tqdm


In [None]:
# GPU-related environment variables. They are written before any TensorFlow
# configuration call in this cell so they are in place when TF reads them.
os.environ.update({
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
    'TF_GPU_ALLOCATOR': 'cuda_malloc_async',
})

# Turn off the graph layout optimizer.
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})

# Use the 'mixed_float16' policy globally for every Keras layer built below.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

In [None]:
#%% load the data
DATA_FILE = '../data/GD/1Deg_800Sample.mat'  # 8 time step estimation
mat = mat73.loadmat(DATA_FILE)

# Features, laid out as (sample, time sequence, latitude, longitude, channel).
# Channels: 1 = precipitation, 2 = wind velocity in x direction,
# 3 = wind velocity in y direction.
X_1 = mat['X_train']

# Targets, laid out as (sample, time sequence, lat, lon).
y_1 = mat['y_train']

In [None]:
# Held-out test arrays stored in the same .mat file as the training data.
X_test = mat['X_test']
y_test = mat['y_test']
GFS = mat['GFS_test']

# Carve a 15% validation split out of the training samples; the fixed
# random_state keeps the split reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(X_1,
                                                  y_1,
                                                  test_size=0.15,
                                                  random_state=42)
print('Train feature', X_train.shape, 'Train label', y_train.shape)
# Report the validation features (X_val), not the test features, so the two
# shapes on this line describe the same split.
print('Validation feature', X_val.shape, 'Validation label', y_val.shape)


In [None]:
# Drop the raw loaded arrays that are no longer referenced ...
del X_1, y_1, mat
# ... as well as the validation split and the GFS array.
# NOTE(review): X_val / y_val are discarded here, and the training cell
# validates on X_test / y_test instead - confirm this is intentional.
del X_val, y_val, GFS
gc.collect()

In [None]:
#%% Build the ConvLSTM encoder-decoder with skip connections

sample_shape = (12, 120, 120, 3)  # (time, lat, lon, channel)
inputs = Input(shape=sample_shape)

c = 8  # base filter count; deeper stages use 2*c, 4*c and 8*c


def _conv_lstm(filters, name, kernel_size=(3, 3)):
    """Return a ConvLSTM2D layer with the settings shared by every stage."""
    return ConvLSTM2D(filters=filters,
                      kernel_size=kernel_size,
                      padding='same',
                      name=name,
                      activation='relu',
                      return_sequences=True)


def _time_reduce(filters, steps):
    """Return a 'valid' Conv3D that shortens the time axis by `steps - 1`."""
    return Conv3D(filters=filters,
                  kernel_size=(steps, 1, 1),
                  padding='valid',
                  activation='relu')


def _upsample(filters):
    """Return a Conv3DTranspose that doubles lat/lon and keeps the time axis."""
    return Conv3DTranspose(filters=filters,
                           kernel_size=(1, 2, 2),
                           strides=(1, 2, 2),
                           padding='same',
                           activation='relu')


x_init = BatchNormalization()(inputs)  # Try with normalizing the dataset

# ---- Encoder stage 1: c filters, 12 time steps, 120x120 ----
x1 = _conv_lstm(c, 'conv_lstm1')(x_init)
c1 = _conv_lstm(c, 'conv_lstm13')(x1)  # skip connection source
x2 = MaxPooling3D(pool_size=(1, 2, 2))(c1)
x3 = BatchNormalization(center=True, scale=True)(x2)
x4 = Dropout(0.2)(x3)

# ---- Encoder stage 2: 2*c filters, 60x60 ----
x5 = _conv_lstm(2 * c, 'conv_lstm2')(x4)
x6 = _conv_lstm(2 * c, 'conv_lstm22')(x5)
c2 = _conv_lstm(2 * c, 'conv_lstm23')(x6)  # skip connection source
x7 = MaxPooling3D(pool_size=(1, 2, 2))(c2)
x8 = BatchNormalization(center=True, scale=True)(x7)
x9 = Dropout(0.2)(x8)

# ---- Encoder stage 3: 4*c filters, 30x30 ----
x10 = _conv_lstm(4 * c, 'conv_lstm3')(x9)
x11 = _conv_lstm(4 * c, 'conv_lstm32')(x10)
c3 = _conv_lstm(4 * c, 'conv_lstm33')(x11)  # skip connection source
x12 = MaxPooling3D(pool_size=(1, 2, 2))(c3)
x13 = BatchNormalization(center=True, scale=True)(x12)
x14 = Dropout(0.2)(x13)

# ---- Bottleneck: 8*c filters, 15x15 ----
x15 = _conv_lstm(8 * c, 'conv_lstm4')(x14)
x16 = _conv_lstm(8 * c, 'conv_lstm41')(x15)
c4 = _conv_lstm(8 * c, 'conv_lstm42')(x16)
x17 = BatchNormalization(center=True, scale=True)(c4)
x18 = Dropout(0.2)(x17)

# ---- Decoder stage 1: time 12 -> 11, spatial 15 -> 30 ----
x19 = _time_reduce(4 * c, 2)(x18)
x20 = _conv_lstm(4 * c, 'conv_lstm5', kernel_size=(1, 1))(x19)
x21 = _upsample(4 * c)(x20)
# The skip tensor keeps only its last 11 time steps so it matches x21.
x22 = concatenate([x21, c3[:, 1:12, :, :, :]])
x23 = _conv_lstm(4 * c, 'conv_lstm51')(x22)
x24 = _conv_lstm(4 * c, 'conv_lstm52')(x23)
c5 = _conv_lstm(4 * c, 'conv_lstm53')(x24)
x26 = BatchNormalization(center=True, scale=True)(c5)
x27 = Dropout(0.2)(x26)

# ---- Decoder stage 2: time 11 -> 10, spatial 30 -> 60 ----
x28 = _time_reduce(2 * c, 2)(x27)
x29 = _conv_lstm(2 * c, 'conv_lstm6', kernel_size=(1, 1))(x28)
x30 = _upsample(2 * c)(x29)
# Last 10 time steps of the stage-2 skip tensor.
x31 = concatenate([x30, c2[:, 2:12, :, :, :]])
x32 = _conv_lstm(2 * c, 'conv_lstm61')(x31)
x33 = _conv_lstm(2 * c, 'conv_lstm62')(x32)
c6 = _conv_lstm(2 * c, 'conv_lstm63')(x33)
x34 = BatchNormalization(center=True, scale=True)(c6)
x35 = Dropout(0.2)(x34)

# ---- Decoder stage 3: time 10 -> 8, spatial 60 -> 120 ----
x36 = _time_reduce(c, 3)(x35)
x37 = _conv_lstm(c, 'conv_lstm7', kernel_size=(1, 1))(x36)
x38 = _upsample(c)(x37)
# Last 8 time steps of the stage-1 skip tensor.
x39 = concatenate([x38, c1[:, 4:12, :, :, :]])
x40 = _conv_lstm(c, 'conv_lstm71')(x39)
x41 = _conv_lstm(c, 'conv_lstm72')(x40)
x42 = _conv_lstm(c, 'conv_lstm73')(x41)

# ---- Output head: one channel per (time, lat, lon) cell ----
residual_output = Conv3D(1, kernel_size=(1, 1, 1), padding='same')(x42)

# Remove the singleton channel axis first, then make the float32 'linear'
# activation the final layer. Previously the squeeze was applied to
# `residual_output` and overwrote the float32 tensor, so the model emitted
# float16 predictions under the mixed_float16 policy.
squeezed_output = tf.squeeze(residual_output, axis=4)
output = Activation('linear', dtype='float32')(squeezed_output)

# Disabled experiment: add the last input time step as a residual.
# residual_input = x_init[:, :, :, :, 0]
# x = tf.expand_dims(residual_input, axis=4)
# last_timestep_input_residual = layers.Cropping3D(
#     cropping=((11, 0), (0, 0), (0, 0)), data_format="channels_last")(x)
# last_timestep_input_residual = layers.concatenate(
#     [last_timestep_input_residual] * 8, axis=1)
# combined = Add()([last_timestep_input_residual, residual_output])

model = Model(inputs, output)
model.summary()

In [None]:
#%%

# Optimise mean absolute error with Adam; MAE and MSE are also tracked as
# metrics.
adam = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='mean_absolute_error',
              optimizer=adam,
              metrics=['mae', 'mse'])

checkpoint_filepath = 'script_n1.h5'

# Stop once the monitored quantity has not improved for 10 epochs.
early_stopping = EarlyStopping(patience=10, verbose=1)
# Multiply the learning rate by 0.1 after 5 epochs without improvement,
# never going below 1e-5.
reduce_lr = ReduceLROnPlateau(factor=0.1,
                              patience=5,
                              min_lr=0.00001,
                              verbose=1)
# Write only the weights of the best epoch seen so far.
checkpoint = ModelCheckpoint(filepath=checkpoint_filepath,
                             verbose=1,
                             save_best_only=True,
                             save_weights_only=True)
callbacks = [early_stopping, reduce_lr, checkpoint]

#%%
# NOTE(review): validation uses the test arrays (X_test, y_test), so early
# stopping, LR scheduling and checkpoint selection are all driven by the
# test set - confirm this is intended rather than using a validation split.
results = model.fit(X_train,
                    y_train,
                    batch_size=8,
                    epochs=50,
                    callbacks=callbacks,
                    verbose=1,
                    validation_data=(X_test, y_test))
