In [1]:
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers, regularizers
from keras.layers import *

import tensorflow as tf

from keras import backend as K

# Auxiliary

# Model

## Building block

### Resnet

In [2]:
def resnet_unit(feat_dim, kernel_size, x_in):
    """Apply one identity-shortcut residual unit to ``x_in``.

    The residual branch is conv -> ReLU -> conv (same padding, He-uniform
    kernel and bias init). Its output is added to ``x_in`` and the sum goes
    through a final ReLU.

    Args:
        feat_dim: number of filters in both convolutions. The branch output is
            added to ``x_in`` without a projection, so this must equal the
            channel count of ``x_in``.
        kernel_size: spatial kernel size of both convolutions.
        x_in: input feature map (Keras tensor).

    Returns:
        ``ReLU(x_in + branch(x_in))``.
    """
    conv_kwargs = dict(padding="same",
                       kernel_initializer='he_uniform',
                       bias_initializer='he_uniform')
    branch = keras.Sequential([
        Conv2D(feat_dim, kernel_size, **conv_kwargs),
        ReLU(),
        Conv2D(feat_dim, kernel_size, **conv_kwargs),
    ])
    shortcut_sum = x_in + branch(x_in)
    return ReLU()(shortcut_sum)

def resnet_block(x_in, feat_dim, kernel_size, reps, pooling = True):
    """Two conv+ReLU stem layers, ``reps`` residual units, then optional pooling.

    The stem projects ``x_in`` to ``feat_dim`` channels so that the identity
    shortcuts inside ``resnet_unit`` have matching channel counts.

    Args:
        x_in: input feature map (Keras tensor).
        feat_dim: number of filters for every convolution in the block.
        kernel_size: spatial kernel size for every convolution.
        reps: number of residual units appended after the stem (0 allowed).
        pooling: if truthy, finish with a 2x2 max-pool with stride 2.

    Returns:
        The block output tensor with ``feat_dim`` channels.
    """
    x = x_in
    # Stem: two conv -> ReLU layers that bring the input to feat_dim channels.
    for _ in range(2):
        x = Conv2D(feat_dim,
                   kernel_size,
                   padding="same",
                   kernel_initializer='he_uniform',
                   bias_initializer='he_uniform')(x)
        x = ReLU()(x)
    for _ in range(reps):
        x = resnet_unit(feat_dim, kernel_size, x)
    if pooling:
        x = MaxPooling2D(2, 2)(x)
    return x

### Convolution block

In [3]:
def conv_unit(feat_dim, kernel_size, x):
    """Spatial convolution followed by a 1x1 convolution, both LeakyReLU(0.2).

    Args:
        feat_dim: filters for both convolutions.
        kernel_size: kernel size of the first (spatial) convolution; the
            second convolution is always 1x1.
        x: input feature map (Keras tensor).

    Returns:
        The transformed tensor with ``feat_dim`` channels.
    """
    for ksize in (kernel_size, 1):
        x = Conv2D(feat_dim,
                   ksize,
                   activation=LeakyReLU(0.2),
                   padding="same",
                   kernel_initializer='he_uniform',
                   bias_initializer='he_uniform')(x)
    return x

def conv_block_down(x, feat_dim, reps, kernel_size, mode = 'normal'):
    """Encoder stage: optional 2x2 max-pool, then ``reps`` conv units.

    Args:
        x: input feature map (Keras tensor).
        feat_dim: filters for each ``conv_unit``.
        reps: number of ``conv_unit`` applications.
        kernel_size: spatial kernel size for each ``conv_unit``.
        mode: ``'down'`` max-pools (2x2, stride 2) before the convolutions;
            any other value keeps the resolution unchanged.

    Returns:
        The stage output tensor.
    """
    out = MaxPooling2D(2,2)(x) if mode == 'down' else x
    for _ in range(reps):
        out = conv_unit(feat_dim, kernel_size, out)
    return out

def conv_block_up_w_concat(x, x1, feat_dim, reps, kernel_size, mode = 'normal'):
    """Decoder stage with a skip connection.

    Optionally upsamples ``x`` by 2, concatenates the skip tensor ``x1`` along
    the channel axis, then applies ``reps`` conv units.

    Args:
        x: decoder feature map (Keras tensor).
        x1: skip-connection tensor. It must have the same spatial size as ``x``
            after the optional upsampling, for the concatenation to succeed.
        feat_dim: filters for each ``conv_unit``.
        reps: number of ``conv_unit`` applications.
        kernel_size: spatial kernel size for each ``conv_unit``.
        mode: ``'up'`` applies 2x ``UpSampling2D`` first; any other value skips it.

    Returns:
        The stage output tensor.
    """
    upsampled = UpSampling2D((2,2))(x) if mode == 'up' else x
    merged = Concatenate()([upsampled, x1])
    for _ in range(reps):
        merged = conv_unit(feat_dim, kernel_size, merged)
    return merged

def conv_block_up_wo_concat(x, feat_dim, reps, kernel_size, mode = 'normal'):
    """Decoder stage without a skip connection.

    Args:
        x: input feature map (Keras tensor).
        feat_dim: filters for each ``conv_unit``.
        reps: number of ``conv_unit`` applications.
        kernel_size: spatial kernel size for each ``conv_unit``.
        mode: ``'up'`` applies 2x ``UpSampling2D`` before the convolutions;
            any other value keeps the resolution unchanged.

    Returns:
        The stage output tensor.
    """
    out = UpSampling2D((2,2))(x) if mode == 'up' else x
    for _ in range(reps):
        out = conv_unit(feat_dim, kernel_size, out)
    return out

### SPADE

In [4]:
class SPADE(layers.Layer):
    """Spatially-adaptive (de)normalization conditioned on a mask.

    Normalization and modulation:
    - The input is normalized with per-channel mean/variance computed over the
      batch and both spatial axes.
    - It is then modulated by per-pixel ``gamma`` (scale) and ``beta`` (shift)
      maps predicted from a conditioning mask.

    Conditioning path:
    - The mask is resized with nearest-neighbour interpolation to the input's
      spatial size.
    - It goes through a 3x3 conv with ReLU, producing a shared hidden map.
    - Two 3x3 conv heads produce ``gamma`` and ``beta``.

    Args:
        filters: channel count of the hidden map, ``gamma`` and ``beta``. These
            are combined channel-wise with the normalized input, so this must
            match (or broadcast against) the input's channel count.
        epsilon: constant added to the variance for numerical stability.
    """

    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.epsilon = epsilon
        # Shared hidden map from the mask, then separate gamma / beta heads.
        self.conv = self._make_conv(filters, activation="relu")
        self.conv_gamma = self._make_conv(filters)
        self.conv_beta = self._make_conv(filters)

    @staticmethod
    def _make_conv(filters, activation=None):
        """Build the 3x3 same-padded, He-initialized, L1/L2-regularized conv used by SPADE."""
        return layers.Conv2D(filters, 3, padding="same", activation=activation,
                             kernel_initializer='he_uniform',
                             bias_initializer='he_uniform',
                             kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4))

    def build(self, input_shape):
        # Kept for backward compatibility. `call` now resizes to the runtime
        # spatial size, so the layer also works when H/W are None or differ
        # between calls.
        self.resize_shape = input_shape[1:3]

    @staticmethod
    def _spatial_size(tensor):
        """Return (H, W) of a 4-D tensor: static ints if known, else a dynamic int32 tensor."""
        static_size = tuple(tensor.shape[1:3])
        if None in static_size:
            return tf.shape(tensor)[1:3]
        return static_size

    def call(self, input_tensor, raw_mask):
        mask = tf.image.resize(raw_mask, self._spatial_size(input_tensor), method="nearest")
        hidden = self.conv(mask)

        gamma = self.conv_gamma(hidden)
        beta = self.conv_beta(hidden)
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        normalized = (input_tensor - mean) / tf.sqrt(var + self.epsilon)

        return gamma * normalized + beta

    def get_config(self):
        # Needed so models containing this layer can be saved and re-loaded.
        config = super().get_config()
        config.update({"filters": self.filters, "epsilon": self.epsilon})
        return config

def spade_generator_unit(x, mask, feats_in, kernel, upsampling = True, noise_stddev = 0.05):
    """Gaussian noise -> SPADE -> conv (LeakyReLU 0.2) -> optional 2x upsampling.

    Args:
        x: feature map to modulate (Keras tensor). Its channel count must
            match ``feats_in``, because SPADE combines gamma/beta with it
            channel-wise.
        mask: conditioning mask passed to ``SPADE``.
        feats_in: filters for the SPADE layer and the following convolution.
        kernel: kernel size of the convolution after SPADE.
        upsampling: if truthy, upsample the output by 2 in both spatial dims.
        noise_stddev: standard deviation of the ``GaussianNoise`` applied to
            ``x`` (Keras applies it only in training mode). Defaults to the
            original hard-coded 0.05.

    Returns:
        The unit's output tensor with ``feats_in`` channels.
    """
    x = GaussianNoise(noise_stddev)(x)
    modulated = SPADE(feats_in)(x, mask)
    output = Conv2D(feats_in, kernel, padding='same', activation=LeakyReLU(0.2),
                    kernel_initializer='he_uniform',
                    bias_initializer='he_uniform',
                    kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4))(modulated)
    if upsampling:
        output = UpSampling2D(size=(2, 2))(output)
    return output

### Adaptive Instance Normalization

In [5]:
class AdaIN(layers.Layer):
    """Adaptive Instance Normalization conditioned on a style vector.

    Normalization and modulation:
    - Each sample's feature map is normalized per channel over its own spatial
      axes (instance normalization).
    - It is then scaled by ``1 + gamma`` and shifted by ``beta``.
    - ``gamma`` and ``beta`` come from a small dense network applied to
      ``style_vector``.

    Args:
        filters: width of the dense layers. It must equal the channel count of
            the feature map being modulated.
        epsilon: constant added to the variance for numerical stability.

    Notes:
        ``style_vector`` is presumably shaped (batch, style_dim); TODO confirm
        with callers.
        The previous version took moments over the batch axis too (axes
        0, 1, 2). That is batch normalization, not the instance normalization
        this layer's name promises.
    """

    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.epsilon = epsilon
        self.dense = layers.Dense(filters,
                                  activation = 'relu',
                                  kernel_initializer='he_uniform',
                                  bias_initializer='he_uniform')
        self.dense_gamma = layers.Dense(filters,
                                        kernel_initializer='he_uniform',
                                        bias_initializer='he_uniform')
        self.dense_beta = layers.Dense(filters,
                                       kernel_initializer='he_uniform',
                                       bias_initializer='he_uniform')

    def call(self, input_tensor, style_vector):
        hidden = self.dense(style_vector)
        gamma = self.dense_gamma(hidden)
        beta = self.dense_beta(hidden)

        # Instance statistics: per sample, per channel, over H and W.
        mean, var = tf.nn.moments(input_tensor, axes=(1, 2), keepdims=True)
        normalized = (input_tensor - mean) / tf.sqrt(var + self.epsilon)

        # (batch, filters) -> (batch, 1, 1, filters) to broadcast over H and W.
        pool_shape = [-1, 1, 1, self.filters]
        gamma = tf.reshape(gamma, pool_shape) + 1.0
        beta = tf.reshape(beta, pool_shape)

        return gamma * normalized + beta

    def get_config(self):
        # Needed so models containing this layer can be saved and re-loaded.
        config = super().get_config()
        config.update({"filters": self.filters, "epsilon": self.epsilon})
        return config

### Feature extraction CNN

In [6]:
def feature_extraction_unet(n_out_features = 128, n_base_features = 64, input_shape = (256,512,5)):
    """Build the U-Net-style feature extractor.

    Encoder:
    - Five conv stages. Stages 2-5 max-pool by 2 first, and the widths are
      base, 2x, 4x, 8x and 16x base.
    - The spatial dims of ``input_shape`` should therefore be divisible by 16,
      so the skip concatenations line up.

    Decoder:
    - Four 2x-upsampling stages with widths 8x, 4x and 2x base, then
      ``n_out_features``.
    - Only the 2nd and 4th decoder stages concatenate a skip connection, from
      encoder stages 3 and 1 respectively.
    - A final 1x1 conv unit produces the output features.

    Args:
        n_out_features: channel count of the returned feature map.
        n_base_features: width of the first encoder stage.
        input_shape: (H, W, C) of the input.

    Returns:
        keras.Model mapping an ``input_shape`` tensor to (H, W, n_out_features).
    """
    inputs = keras.Input(shape = input_shape)

    # Encoder: level 0 keeps full resolution; levels 1-4 pool, then widen.
    x = conv_block_down(inputs, feat_dim = n_base_features, reps = 1, kernel_size = 3)
    encoder_outputs = [x]
    for level in range(1, 5):
        x = conv_block_down(x,
                            feat_dim = n_base_features * 2 ** level,
                            reps = 1,
                            kernel_size = 3,
                            mode = 'down')
        encoder_outputs.append(x)

    # Decoder plan: (width, skip tensor or None) for each 2x upsampling stage.
    decoder_plan = [
        (n_base_features * 8, None),
        (n_base_features * 4, encoder_outputs[2]),
        (n_base_features * 2, None),
        (n_out_features, encoder_outputs[0]),
    ]
    for width, skip in decoder_plan:
        if skip is None:
            x = conv_block_up_wo_concat(x,
                                        feat_dim = width,
                                        reps = 1,
                                        kernel_size = 3,
                                        mode = 'up')
        else:
            x = conv_block_up_w_concat(x, skip,
                                       feat_dim = width,
                                       reps = 1,
                                       kernel_size = 3,
                                       mode = 'up')

    # 1x1 head at full resolution.
    feature_out = conv_block_up_wo_concat(x,
                                          feat_dim = n_out_features,
                                          reps = 1,
                                          kernel_size = 1,
                                          mode = 'normal')
    return keras.Model(inputs, feature_out)

### Mapping and reconstruction CNN

In [14]:
def mapping_and_recon_cnn(n_base_features = 64, input_shape = (256,512,128), mask_shape = None):
    """Build the SPADE-conditioned residual mapping/reconstruction network.

    There are three stages of [SPADE unit -> resnet_block with 1x1 kernels],
    with widths base, 2x and 4x base, followed by a 1x1 conv to a single
    output channel. Every SPADE unit is conditioned on the same mask input.

    Args:
        n_base_features: width of the first residual stage.
        input_shape: (H, W, C) of the feature-map input. The first SPADE unit
            uses C as its width, since it modulates the input channel-wise.
            Previously this was hard-coded to 128, which broke for any other C.
        mask_shape: (H, W, C_mask) of the conditioning mask. Defaults to
            ``(H, W, 1)`` taken from ``input_shape``. Previously it was
            hard-coded to (256, 512, 1) regardless of ``input_shape``.

    Returns:
        keras.Model taking ``[features, mask]`` and returning an (H, W, 1) map.
    """
    if mask_shape is None:
        mask_shape = tuple(input_shape[:2]) + (1,)
    inputs = keras.Input(shape = input_shape)
    inputs_2 = keras.Input(shape = mask_shape)
    spade1 = spade_generator_unit(inputs,
                                  inputs_2,
                                  input_shape[-1],
                                  1,
                                  upsampling = False)
    conv1 = resnet_block(spade1,
                         feat_dim = n_base_features,
                         kernel_size = 1,
                         reps = 2,
                         pooling = False)
    # Each SPADE width matches the channel count of the tensor it modulates.
    spade2 = spade_generator_unit(conv1,
                                  inputs_2,
                                  n_base_features,
                                  1,
                                  upsampling = False)
    conv2 = resnet_block(spade2,
                         feat_dim = n_base_features*2,
                         kernel_size = 1,
                         reps = 2,
                         pooling = False)

    spade3 = spade_generator_unit(conv2,
                                  inputs_2,
                                  n_base_features*2,
                                  1,
                                  upsampling = False)
    conv3 = resnet_block(spade3,
                         feat_dim = n_base_features*4,
                         kernel_size = 1,
                         reps = 2,
                         pooling = False)
    conv_out = Conv2D(1,1,padding='same')(conv3)
    mapping_resnet = keras.Model([inputs,inputs_2], conv_out)
    return mapping_resnet
# mapping_and_recon_cnn().summary()
# PARCv2

### Advection layer

In [57]:
class Advection(layers.Layer):
    """Non-trainable advection term ``v . grad(phi)`` via finite differences.

    ``tf.image.image_gradients`` returns the forward differences ``(dy, dx)``
    along the height and width axes. These are concatenated as ``[dy, dx]``,
    multiplied element-wise by ``velocity_field``, and summed over channels
    (keepdims).

    Assumptions (TODO confirm with the data pipeline):
    - ``state_variable`` is (batch, H, W, 1), so that ``[dy, dx]`` has 2
      channels.
    - ``velocity_field`` is (batch, H, W, 2), with channel 0 multiplying the
      height-direction derivative and channel 1 the width-direction one.

    Returns:
        A (batch, H, W, 1) tensor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trainable = False

    def call(self, state_variable, velocity_field):
        dy, dx = tf.image.image_gradients(state_variable)
        spatial_deriv = tf.concat([dy, dx], axis=-1)
        # Leftover debug `print(advect.shape)` removed: it printed on every
        # trace and polluted notebook output.
        return tf.reduce_sum(tf.multiply(spatial_deriv, velocity_field), axis=-1, keepdims=True)


(None, 256, 512, 1)
Model: "model_38"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_95 (InputLayer)          [(None, 256, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 input_96 (InputLayer)          [(None, 256, 512, 2  0           []                               
                                )]                                                                
                                                                                                  
 advection_40 (Advection)       (None, 256, 512, 1)  0           ['input_95[0][0]',               
                                                                  'inpu

### Integration CNN

In [15]:
def integrator_cnn(n_base_features = 128, input_shape = (256,512,1)):
    """Build the residual integration network.

    There are three resnet_blocks with 1x1 kernels, two residual units each
    and no pooling. Their widths are 1x, 2x and 4x ``n_base_features``. A
    final 1x1 conv maps to one channel.

    Args:
        n_base_features: width of the first residual stage.
        input_shape: (H, W, C) of the input.

    Returns:
        keras.Model mapping an ``input_shape`` tensor to (H, W, 1).
    """
    inputs = keras.Input(shape = input_shape)
    features = inputs
    for multiplier in (1, 2, 4):
        features = resnet_block(features,
                                feat_dim = n_base_features * multiplier,
                                kernel_size = 1,
                                reps = 2,
                                pooling = False)
    conv_out = Conv2D(1, 1, padding='same')(features)
    return keras.Model(inputs, conv_out)
# integrator_cnn().summary()

### PARCv2

In [None]:
class PARC_2(keras.Model):
    """Latent-space surrogate trained with autoencoder and latent-evolution losses.

    Each ``train_step`` minimizes the sum of the following MSE terms:
    - ``in_ae_loss`` / ``out_ae_loss``: encoder->decoder reconstruction of the
      input and target fields.
    - ``latent_loss``: evolved latent vs. encoded target.
    - ``recon_loss``: decoded evolved latent vs. target field.

    NOTE(review): ``resnet_based_encoder``, ``spade_based_generator`` and
    ``latent_evolution_model`` are not defined in this notebook as shown. They
    must exist before the class is instantiated.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder = resnet_based_encoder()
        self.decoder = spade_based_generator()
        self.latent_evolution = latent_evolution_model()
        # One running-mean tracker per reported loss. They are listed in
        # `metrics` so Keras can reset them between epochs.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.in_ae_loss_tracker = keras.metrics.Mean(name="in_ae_loss")
        self.out_ae_loss_tracker = keras.metrics.Mean(name="out_ae_loss")
        self.latent_loss_tracker = keras.metrics.Mean(name="latent_loss")
        self.recon_loss_tracker = keras.metrics.Mean(name="recon_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.in_ae_loss_tracker,
            self.out_ae_loss_tracker,
            self.latent_loss_tracker,
            self.recon_loss_tracker,
        ]

    def train_step(self, data):
        """One optimization step on ``data = (input_field, target_field, ...)``."""
        input_temp, output_temp = data[0], data[1]
        mse = tf.keras.losses.MeanSquaredError()
        with tf.GradientTape() as tape:
            # Autoencoder reconstructions of the input and target fields.
            latent_temp_in = self.encoder(input_temp)
            in_ae_loss = mse(input_temp, self.decoder(latent_temp_in))
            latent_temp_out = self.encoder(output_temp)
            out_ae_loss = mse(output_temp, self.decoder(latent_temp_out))
            ae_loss = in_ae_loss + out_ae_loss

            # NOTE(review): the encoded target is passed as a second
            # positional argument. For a functional keras.Model that slot is
            # `training`, so this probably should be a list of inputs. Confirm
            # against latent_evolution_model's call signature.
            pred_latent_temp_out = self.latent_evolution(latent_temp_in, latent_temp_out)
            latent_loss = mse(pred_latent_temp_out, latent_temp_out)
            recon_loss = mse(output_temp, self.decoder(pred_latent_temp_out))

            total_loss = recon_loss + ae_loss + latent_loss

        for tracker, value in (
            (self.total_loss_tracker, total_loss),
            (self.in_ae_loss_tracker, in_ae_loss),
            (self.out_ae_loss_tracker, out_ae_loss),
            (self.latent_loss_tracker, latent_loss),
            (self.recon_loss_tracker, recon_loss),
        ):
            tracker.update_state(value)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return {
            "total_loss": self.total_loss_tracker.result(),
            "in_ae_loss": self.in_ae_loss_tracker.result(),
            "out_ae_loss": self.out_ae_loss_tracker.result(),
            "latent_loss": self.latent_loss_tracker.result(),
            "recon_loss": self.recon_loss_tracker.result(),
        }
        

# Data pipeline

# Training

# Validation