# Modules

In [None]:
# Modules
from tensorflow import keras
from keras.layers import Dense
from keras.layers.core import Activation
from keras.layers import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.core import Flatten
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras.models import Model
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers import add
from keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Softmax

# part 2
from keras.models import Model
from keras.layers import Input
from tensorflow.keras.optimizers import Adam
import numpy as np
from tqdm import tqdm
from numpy import save
from numpy import load
import tensorflow.keras.backend as K
import tensorflow as tf
#from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
import xarray as xr
from tensorflow.keras.utils import to_categorical

# Code from Paper

In [None]:
# Residual block
def res_block_gen(model, kernal_size, filters, strides, initializer):
    """Append one residual block to ``model`` and return the new tensor.

    The convolutional branch is Conv2D -> BatchNormalization -> PReLU ->
    Conv2D -> BatchNormalization. Its output is added element-wise to the
    untouched block input (skip connection).

    Args:
        model: Input tensor of the block; also used as the skip branch.
        kernal_size: Kernel size passed to both convolutions.
        filters: Number of filters of both convolutions.
        strides: Stride passed to both convolutions.
        initializer: Kernel initializer passed to both convolutions.

    Returns:
        The sum of the block input and the convolutional branch output.
    """
    # Both convolutions are configured identically.
    conv_kwargs = dict(
        filters=filters,
        kernel_size=kernal_size,
        strides=strides,
        padding="same",
        kernel_initializer=initializer,
    )

    branch = Conv2D(**conv_kwargs)(model)
    branch = BatchNormalization(momentum=0.5)(branch)
    # Parametric ReLU with the learned slope shared across axes 1 and 2.
    branch = PReLU(
        alpha_initializer='zeros',
        alpha_regularizer=None,
        alpha_constraint=None,
        shared_axes=[1, 2],
    )(branch)
    branch = Conv2D(**conv_kwargs)(branch)
    branch = BatchNormalization(momentum=0.5)(branch)

    return add([model, branch])
    

# Network architecture is based on the generator given in the paper https://arxiv.org/pdf/1609.04802.pdf
class Generator:
    """Builder of the two-headed up-sampling generator network.

    A shared trunk (one convolution, 16 residual blocks and a long skip
    connection) feeds two structurally identical output heads. Each head
    up-samples by 3 * 3 * 5 = 45 along both spatial axes and ends in a
    single-channel convolution.
    """

    # Successive UpSampling2D factors applied inside every output head.
    _UPSAMPLING_FACTORS = (3, 3, 5)

    def __init__(self, noise_shape):
        """Store the input shape.

        Args:
            noise_shape: Shape handed to ``Input(shape=...)``, i.e. the shape
                of one low-resolution sample without the batch axis.
        """
        self.noise_shape = noise_shape

    @staticmethod
    def _prelu():
        """Create the PReLU layer configuration used throughout the network."""
        return PReLU(
            alpha_initializer='zeros',
            alpha_regularizer=None,
            alpha_constraint=None,
            shared_axes=[1, 2],
        )

    @classmethod
    def _output_head(cls, features, init):
        """Build one up-sampling head on top of the shared trunk features.

        Args:
            features: Output tensor of the shared trunk.
            init: Kernel initializer for all convolutions of the head.

        Returns:
            Single-channel output tensor of the head.
        """
        head = features
        for factor in cls._UPSAMPLING_FACTORS:
            head = Conv2D(filters=128, kernel_size=3, strides=1,
                          padding="same", kernel_initializer=init)(head)
            head = UpSampling2D(size=factor)(head)
            head = cls._prelu()(head)

        return Conv2D(filters=1, kernel_size=9, strides=1,
                      padding="same", kernel_initializer=init)(head)

    def generator(self):
        """Assemble and return the (uncompiled) Keras generator model.

        Returns:
            A ``Model`` with one input of shape ``noise_shape`` and two
            outputs ``[output1, output2]``; the original code labels them
            reanalysis and WRF respectively.
        """
        init = RandomNormal(stddev=0.02)

        gen_input = Input(shape=self.noise_shape)
        model = Conv2D(filters=64, kernel_size=3, strides=1,
                       padding="same", kernel_initializer=init)(gen_input)
        model = self._prelu()(model)

        # Kept for the long skip connection around all residual blocks.
        gen_model = model

        # Using 16 residual blocks
        for _ in range(16):
            model = res_block_gen(model, 3, 64, 1, init)

        model = Conv2D(filters=64, kernel_size=3, strides=1,
                       padding="same", kernel_initializer=init)(model)
        model = BatchNormalization(momentum=0.5)(model)
        model = add([gen_model, model])

        model = Conv2D(filters=256, kernel_size=3, strides=1,
                       padding="same", kernel_initializer=init)(model)

        # Task 1: downscaling to reanalysis (same layout as task 2 at the moment).
        output1 = self._output_head(model, init)

        # Task 2: downscaling to WRF.
        output2 = self._output_head(model, init)

        # Output order: reanalysis, WRF
        return Model(inputs=gen_input, outputs=[output1, output2])

# Loss functions

In [None]:
def my_MSE_weighted(y_true, y_pred):
    """Return the mean of the absolute errors weighted by the clipped targets.

    Despite the name, the differences are taken in absolute value and are
    not squared. The per-element weight is ``y_true`` clipped to the range
    ``[log(0.1 + 1), log(100 + 1)]``.
    NOTE(review): the clip bounds suggest targets are log(1 + x)-transformed;
    confirm against the data preparation.

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor.

    Returns:
        Scalar tensor: mean over all elements of ``weights * |y_pred - y_true|``.
    """
    lower_bound = K.log(0.1+1)
    upper_bound = K.log(100.0+1)
    weights = tf.clip_by_value(y_true, lower_bound, upper_bound)

    abs_error = tf.abs(tf.subtract(y_pred, y_true))
    weighted_error = tf.multiply(weights, abs_error)
    return K.mean(weighted_error)

def make_FSS_loss(mask_size, cutoff=0.5, want_hard_discretization=False,
                  steepness=10):
    """Create a loss function based on the Fractions Skill Score (FSS).

    The returned function computes ``MSE_n / MSE_n_ref`` (Eq. (5) and (7) of
    [RL08]), i.e. ``1 - FSS``, so that minimizing it maximizes the FSS.
    It assumes ``y_true`` and ``y_pred`` have the shape (None, N, M, 1).

    Args:
        mask_size: Side length of the square window used to turn the
            discretized fields into densities (average pooling, stride 1,
            'valid' padding).
        cutoff: Threshold used to discretize both fields.
        want_hard_discretization: If True, threshold to exact 0/1 values.
            This is not differentiable, so use it for a metric only, not for
            a training loss. If False, a steep sigmoid is used instead.
        steepness: Slope factor of the sigmoid used for soft discretization.

    Returns:
        A function ``my_FSS_loss(y_true, y_pred)`` returning a scalar tensor.
    """

    def my_FSS_loss(y_true, y_pred):
        # First: DISCRETIZE y_true and y_pred to have only binary values 0/1
        # (or close to those for soft discretization).
        if want_hard_discretization:
            y_true_binary = tf.where(y_true > cutoff, 1.0, 0.0)
            y_pred_binary = tf.where(y_pred > cutoff, 1.0, 0.0)
        else:
            y_true_binary = tf.math.sigmoid(steepness * (y_true - cutoff))
            y_pred_binary = tf.math.sigmoid(steepness * (y_pred - cutoff))

        # Densities O(n)(i,j) and M(n)(i,j), Eq. (2) and (3) of [RL08].
        # Average pooling already includes the factor 1/n^2. Pooling has no
        # weights, so one layer serves both fields.
        pool = tf.keras.layers.AveragePooling2D(
            pool_size=(mask_size, mask_size), strides=(1, 1), padding='valid')
        y_true_density = pool(y_true_binary)
        y_pred_density = pool(y_pred_binary)

        # MSE(n) in Eq. (5) of [RL08]; the mean runs over the batch and all
        # density pixels.
        MSE_n = tf.keras.losses.MeanSquaredError()(y_true_density,
                                                   y_pred_density)

        # MSE_n_ref in Eq. (7) of [RL08]. Means (not sums divided by the
        # per-image pixel count) are used so that the reference is normalized
        # over the batch exactly like MSE_n; otherwise the ratio would shrink
        # proportionally to the batch size.
        MSE_n_ref = (tf.reduce_mean(tf.square(y_true_density))
                     + tf.reduce_mean(tf.square(y_pred_density)))

        # FSS = 1 - (MSE_n / MSE_n_ref), Eq. (6) of [RL08]. FSS lies between
        # 0 and 1 with 1 being optimal, so only the ratio is returned to be
        # minimized.
        if want_hard_discretization:
            # MSE_n_ref == 0 only if both fields contain only zeros; they then
            # match exactly and the ratio is defined as 0.
            return tf.math.divide_no_nan(MSE_n, MSE_n_ref)

        # Epsilon (1e-7 by default) avoids a division by zero.
        return MSE_n / (MSE_n_ref + tf.keras.backend.epsilon())

    return my_FSS_loss

# Window side length handed to make_FSS_loss() when the generator is compiled in train().
mask_size = 6 

# Training Architecture

In [None]:
# Shapes of a single sample (no batch axis) on the low- and high-resolution grids.
image_shape_lr = (2, 3, 1)
image_shape_hr = (90, 135, 1)

# Up-sampling factor between the two grids: 90 / 2 == 135 / 3 == 45.
downscale_factor = image_shape_hr[0] // image_shape_lr[0]

In [None]:
# Random placeholder arrays standing in for the real data sets; replace the
# generator below with actual loading code. The draw order is: training
# inputs/targets, validation inputs/targets, then the WRF targets.
(
    reforecast_train,   # low resolution REFORECAST data for training
    reanalysis_train,   # high resolution REANALYSIS data for training
    reforecast_val,     # low resolution REFORECAST data for validation
    reanalysis_val,     # high resolution REANALYSIS data for validation
    WRF_train,          # high resolution WRF data for training
    WRF_val,            # high resolution WRF data for validation
) = (
    tf.random.normal(shape).numpy()
    for shape in (
        (100, 2, 3, 1),     # 13 , 16
        (100, 90, 135, 1),  # 156, 192
        (10, 2, 3, 1),
        (10, 90, 135, 1),
        (100, 90, 135, 1),
        (10, 90, 135, 1),
    )
)

In [None]:
# batch_size = 64
# len(np.random.randint(0, merra2_train.shape[0], size=batch_size))
# merra2_train[np.random.randint(0, merra2_train.shape[0], size=64)].shape

In [None]:
def train(epochs, batch_size):
    """Train the two-headed generator and log the losses to ``losses.txt``.

    Training and validation data are read from the module-level arrays
    ``reforecast_*`` (inputs), ``WRF_*`` and ``reanalysis_*`` (targets).
    Each epoch runs ``n_train // batch_size`` batches (at least one) whose
    samples are drawn at random with replacement. After every epoch the loss
    of the last training batch and the validation loss are appended to
    ``losses.txt``, which is truncated when training starts.

    Args:
        epochs: Number of epochs to train.
        batch_size: Number of samples per training batch; must be >= 1.

    Returns:
        The trained, compiled generator model.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    x_train_lr = reforecast_train             # reforecast lr
    y_train_WRF = WRF_train                   # WRF hr
    y_train_hr = reanalysis_train             # reanalysis hr

    x_val_lr = reforecast_val                 # reforecast lr val
    y_val_WRF = WRF_val                       # WRF hr val
    y_val_hr = reanalysis_val                 # reanalysis hr val

    n_train = x_train_lr.shape[0]
    # Always run at least one batch so that a batch size larger than the
    # data set still trains (sampling is with replacement) and gen_loss is
    # always defined when it is logged.
    batch_count = max(1, n_train // batch_size)

    generator = Generator(image_shape_lr).generator()
    # The loss list follows the model's output order: the FSS loss is applied
    # to the first output, the weighted error to the second.
    generator.compile(
        loss=[make_FSS_loss(mask_size), my_MSE_weighted],
        optimizer=Adam(learning_rate=0.0001, beta_1=0.9),
        loss_weights=[0.01, 1.0],
        metrics=['mae', 'mse'],
    )

    # Start with an empty log file.
    with open('losses.txt', 'w'):
        pass

    for e in range(1, epochs+1):

        print ('-'*15, 'Epoch %d' % e, '-'*15)

        for _ in tqdm(range(batch_count)):

            rand_nums = np.random.randint(0, n_train, size=batch_size)

            x_lr = x_train_lr[rand_nums]    # reforecast lr
            y_hr = y_train_hr[rand_nums]    # reanalysis hr
            y_WRF = y_train_WRF[rand_nums]  # WRF hr

            # NOTE(review): Generator.generator() labels its outputs
            # (reanalysis, WRF) while the targets here are ordered
            # (WRF, reanalysis); confirm which pairing is intended.
            gen_loss = generator.train_on_batch(x_lr, [y_WRF, y_hr])

        val_loss = generator.evaluate(x_val_lr, [y_val_WRF, y_val_hr], verbose=0)

        # gen_loss is the loss of the last batch of this epoch.
        with open('losses.txt', 'a') as loss_file:
            loss_file.write('epoch%d : generator_loss = %s; validation_loss = %s\n'
                            % (e, str(gen_loss), str(val_loss)))

        # Optional periodic checkpointing (disabled):
        # if e <=20:
        #     if e  % 5== 0:
        #         generator.save('gen_model%d.h5' % e)
        # else:
        #      if e  % 10 == 0:
        #         generator.save('gen_model%d.h5' % e)

    return generator
        


# Smoke run: a single epoch with three samples per batch.
train(epochs=1, batch_size=3)