In [2]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Conv2DTranspose, Lambda
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()

class InstanceNormalization(Layer):
    """Normalize every sample of a batch independently.

    For each sample the layer computes the mean and standard deviation over
    all axes except the batch axis (axis 0) and, when given, `axis`. The
    output is `(inputs - mean) / (std + epsilon)`, optionally multiplied by
    a learned `gamma` and shifted by a learned `beta`.

    # Arguments
        axis: Integer or None. Axis that is kept separate when computing
            the statistics (typically the features axis); `gamma` and
            `beta` then hold one value per entry of that axis.
            With `axis=None` the statistics span every non-batch axis and
            `gamma`/`beta` are single scalars.
            `axis=0` (the batch axis) is rejected in `build`.
        epsilon: Small float added to the standard deviation so the
            division never hits zero.
        center: If True, add the learned offset `beta`.
        scale: If True, multiply by the learned factor `gamma`.
        beta_initializer: Initializer for `beta`.
        gamma_initializer: Initializer for `gamma`.
        beta_regularizer: Optional regularizer for `beta`.
        gamma_regularizer: Optional regularizer for `gamma`.
        beta_constraint: Optional constraint for `beta`.
        gamma_constraint: Optional constraint for `gamma`.
    # Input shape
        Arbitrary rank; the rank is pinned via `InputSpec` once built.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](
        https://arxiv.org/abs/1607.08022)
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

        # Plain hyper-parameters.
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        # Resolve string / dict identifiers into Keras objects.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate `axis` and create the `gamma` / `beta` weights."""
        rank = len(input_shape)

        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
        if self.axis is not None and rank == 2:
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=rank)

        # One parameter per entry of `axis`, or a single scalar without it.
        param_shape = (1,) if self.axis is None else (input_shape[self.axis],)

        # gamma is created before beta so the weight order stays stable.
        self.gamma = (
            self.add_weight(shape=param_shape,
                            name='gamma',
                            initializer=self.gamma_initializer,
                            regularizer=self.gamma_regularizer,
                            constraint=self.gamma_constraint)
            if self.scale else None
        )
        self.beta = (
            self.add_weight(shape=param_shape,
                            name='beta',
                            initializer=self.beta_initializer,
                            regularizer=self.beta_regularizer,
                            constraint=self.beta_constraint)
            if self.center else None
        )
        self.built = True

    def call(self, inputs, training=None):
        """Return the normalized (and optionally scaled/shifted) inputs.

        `training` is accepted for API compatibility but not used.
        """
        static_shape = K.int_shape(inputs)
        rank = len(static_shape)

        # Reduce over everything except `axis` (if any) and the batch axis.
        # `axis` is removed first so that index 0 still refers to the batch.
        axes = list(range(rank))
        if self.axis is not None:
            axes.pop(self.axis)
        axes.pop(0)

        mu = K.mean(inputs, axes, keepdims=True)
        sigma = K.std(inputs, axes, keepdims=True) + self.epsilon
        outputs = (inputs - mu) / sigma

        # Shape that lets the 1-D parameters broadcast against `inputs`.
        param_view = [1] * rank
        if self.axis is not None:
            param_view[self.axis] = static_shape[self.axis]

        if self.scale:
            outputs = outputs * K.reshape(self.gamma, param_view)
        if self.center:
            outputs = outputs + K.reshape(self.beta, param_view)
        return outputs

    def get_config(self):
        """Return the serializable configuration of the layer."""
        own_config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        parent_config = super().get_config()
        # Entries of this layer win over same-named parent entries.
        return {**parent_config, **own_config}

def conv_block_seq_res_fixe(inputs, filters, kernel_size, strides, name, bn=True, In=True, ResCon=True):
    """Three-convolution bottleneck block with fixed channel widths.

    The block applies Conv1D(64, kernel_size) -> Conv1D(64, 1) ->
    Conv1D(256, kernel_size), each with "same" padding and the given
    `strides`, with ReLU after the first two convolutions and after the
    final (optional) residual add / instance normalization.

    Args:
        inputs: Input tensor of the block.
        filters: Accepted for interface compatibility but not used; the
            channel widths are hard-coded to 64, 64 and 256.
        kernel_size: Kernel size of the first and third convolution.
        strides: Stride applied to all three convolutions.
        name: Prefix for the names of the created layers.
        bn: If True, insert BatchNormalization after every convolution.
        In: If True, apply InstanceNormalization before the last ReLU.
        ResCon: If True, add `inputs` to the output of the third
            convolution (presumably requires `inputs` to already match the
            shape of that 256-channel output -- verify against callers).

    Returns:
        Output tensor of the block.
    """
    layers = tf.keras.layers

    def _batch_norm(tensor, index):
        # Optional BatchNormalization step; a no-op when `bn` is falsy.
        if not bn:
            return tensor
        return layers.BatchNormalization(name=f"{name}_BatchNorm{index}")(tensor)

    # Stage 1: full-size kernel, 64 channels.
    x = layers.Conv1D(64, kernel_size, strides=strides, padding="same", name=f"{name}_Conv1D1")(inputs)
    x = _batch_norm(x, 1)
    x = layers.Activation("relu", name=f"{name}_ReLU1")(x)

    # Stage 2: pointwise convolution, 64 channels.
    x = layers.Conv1D(64, 1, strides=strides, padding="same", name=f"{name}_Conv1D2")(x)
    x = _batch_norm(x, 2)
    x = layers.Activation("relu", name=f"{name}_ReLU2")(x)

    # Stage 3: full-size kernel, expand to 256 channels (no ReLU yet).
    x = layers.Conv1D(256, kernel_size, strides=strides, padding="same", name=f"{name}_Conv1D3")(x)
    x = _batch_norm(x, 3)

    # Skip connection around the three convolutions.
    if ResCon:
        x = layers.Add()([x, inputs])

    if In:
        x = InstanceNormalization(name=f"{name}_InstNorm2")(x)

    return layers.Activation("relu", name=f"{name}_ReLU3")(x)
    

def Conv1DTranspose(input_tensor, filters, kernel_size, strides=2, padding='same', activation=None):
    """Transposed 1-D convolution emulated with `Conv2DTranspose`.

    A dummy axis is inserted at position 2, a 2-D transposed convolution
    with kernel `(kernel_size, 1)` and strides `(strides, 1)` is applied,
    and the dummy axis is removed again.

    Args:
        input_tensor: tensor, with the shape (batch_size, time_steps, dims)
        filters: int, number of output channels
        kernel_size: int, size of the convolution kernel along the time axis
        strides: int, convolution step size along the time axis
        padding: 'same' | 'valid'
        activation: activation forwarded to `Conv2DTranspose`

    Returns:
        Tensor produced by the transposed convolution with the dummy axis
        squeezed out.
    """
    # (batch, time, dims) -> (batch, time, 1, dims)
    expanded = Lambda(lambda t: K.expand_dims(t, axis=2))(input_tensor)

    transposed_conv = Conv2DTranspose(
        filters=filters,
        kernel_size=(kernel_size, 1),
        strides=(strides, 1),
        padding=padding,
        activation=activation,
    )
    upsampled = transposed_conv(expanded)

    # Drop the dummy axis again.
    return Lambda(lambda t: K.squeeze(t, axis=2))(upsampled)


In [3]:
class Sampling(Layer):
    """Draw a latent sample with the reparameterization trick.

    Called on a pair `[z_mu, z_log_var]`, the layer returns
    `z_mu + exp(0.5 * z_log_var) * eps`, where `eps` is standard normal
    noise shaped from the first two dimensions of `z_mu`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        mean, log_variance = inputs

        # Dynamic sizes, so the layer works with an unknown batch size.
        n_samples = tf.shape(mean)[0]
        n_latent = tf.shape(mean)[1]

        # Standard normal noise, one value per latent entry.
        noise = tf.random.normal(shape=(n_samples, n_latent))

        # exp(0.5 * log_var) is the standard deviation.
        return mean + tf.exp(0.5 * log_variance) * noise

In [8]:
def create_model(model, config, width, optimizer):
    """Build and compile the 1-D convolutional VAE with encoder skip connections.

    Args:
        model: Architecture selector. Only the string "VAE" triggers model
            construction; any other value is returned unchanged.
        config: Ignored -- it is overwritten with "fixe_filter" on the first
            statement of the body.
        width: Length of the input sequences; the network input has shape
            (width, 1).
        optimizer: Optimizer forwarded to `compile`.

    Returns:
        The compiled `tf.keras.Model` when `model == "VAE"`, otherwise the
        `model` argument exactly as it was passed in.
    """
    
    # NOTE(review): this discards whatever the caller passed as `config`;
    # only the "fixe_filter" branch below can ever run.
    config = "fixe_filter"
    
    if model == "VAE":
        # The loss/metric functions below close over `z_mu` and `z_log_var`,
        # which are assigned later in this scope while the encoder is built.
        def KL_loss(y_true, y_pred):
            """KL regularization term; `y_true` and `y_pred` are not used."""
            # Regularization term
            kl_loss = - .5 * tf.reduce_sum(1 + z_log_var - K.square(z_mu) - K.exp(z_log_var), axis=-1)
            return kl_loss

        # NOTE(review): Recon_loss is defined but not passed to `compile` below.
        def Recon_loss(data_orig, data_reconstructed):
            """Mean squared error between original and reconstructed data."""
            reconstruction_loss = tf.reduce_mean((data_orig - data_reconstructed)**2)
            return reconstruction_loss

        def vae_loss(data_orig, data_reconstructed):
            """Training loss: mean squared reconstruction error plus the KL term."""
            reconstruction_loss = tf.reduce_mean((data_orig - data_reconstructed)**2)

            kl_loss = - .5 * tf.reduce_sum(1 + z_log_var - K.square(z_mu) - K.exp(z_log_var), axis=-1)

            return reconstruction_loss + kl_loss
            
        if config == "fixe_filter":
            # NOTE(review): `start_filter_num` is passed as `filters` to
            # conv_block_seq_res_fixe, whose body hard-codes its channel
            # widths and does not read that argument.
            start_filter_num = 256
            kernel_size = 3
            latent_dim = 1
            # Presumably 2**6 for the six pooling layers below (assuming the
            # Keras default pool size of 2) -- confirm.
            div = 64

            x = tf.keras.Input(shape=(width,1))
            print("Input shape: ", x.shape)
            # eps = tf.keras.Input(shape=(latent_dim,), name="eps")

            # ---- Encoder: seven conv blocks separated by six max-pooling layers.
            # The first block has no residual connection (ResCon=False).
            conv_seq1 = conv_block_seq_res_fixe(x, start_filter_num, kernel_size, 1, "conv_seq1", ResCon=False)
            pool1 = tf.keras.layers.MaxPooling1D(name="pool1")(conv_seq1)

            conv_seq2 = conv_block_seq_res_fixe(pool1, start_filter_num, kernel_size, 1, "conv_seq2")
            pool2 = tf.keras.layers.MaxPooling1D(name="pool2")(conv_seq2)

            conv_seq3 = conv_block_seq_res_fixe(pool2, start_filter_num, kernel_size, 1, "conv_seq3")
            pool3 = tf.keras.layers.MaxPooling1D(name="pool3")(conv_seq3)

            conv_seq4 = conv_block_seq_res_fixe(pool3, start_filter_num, kernel_size, 1, "conv_seq4")
            pool4 = tf.keras.layers.MaxPooling1D(name="pool4")(conv_seq4)

            conv_seq5 = conv_block_seq_res_fixe(pool4, start_filter_num, kernel_size, 1, "conv_seq5")
            pool5 = tf.keras.layers.MaxPooling1D(name="pool5")(conv_seq5)

            # The two deepest encoder blocks skip instance normalization.
            conv_seq6 = conv_block_seq_res_fixe(pool5, start_filter_num, kernel_size, 1, "conv_seq6", In=False)
            pool6 = tf.keras.layers.MaxPooling1D(name="pool6")(conv_seq6)

            conv_seq7 = conv_block_seq_res_fixe(pool6, start_filter_num, kernel_size, 1, "conv_seq7", In=False)
            flatten1 = tf.keras.layers.Flatten()(conv_seq7)

            # ---- Latent distribution parameters (also read by the loss closures).
            z_mu = tf.keras.layers.Dense(latent_dim, name="z_mu")(flatten1)
            z_log_var = tf.keras.layers.Dense(latent_dim, name="z_log_var")(flatten1)

            ###############################################################################
            # Earlier, explicit formulation of the sampling step, kept for
            # reference; it is replaced by the Sampling layer below.
            # normalize log variance to std dev
            # z_sigma = tf.keras.layers.Lambda(lambda t: K.exp(.5*t), name="z_sigma")(z_log_var)
            # eps = tf.keras.Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)), name="eps")
            # eps = tf.keras.layers.Lambda(lambda _: tf.random.normal(shape=(tf.shape(x)[0], latent_dim)), name="eps")(x)


            # z_eps = tf.keras.layers.Multiply(name="z_eps")([z_sigma, eps])
            z = Sampling(name="sampling")([z_mu, z_log_var])
            # z = tf.keras.layers.Add(name="z")([z_mu, z_eps])
            #latent_conv = tf.keras.layers.Dense(width//64, name="latent_conv")(z)

            # NOTE(review): `z` has `latent_dim` (= 1) features, so this reshape
            # presumably only succeeds when width // div == latent_dim -- confirm
            # the supported values of `width`.
            reshape1 = tf.keras.layers.Reshape([width//div,1], name="reshape1")(z)
            
            ###############################################################################
            #New for conditional VAE
            # ---- Decoder: each stage runs a conv block, concatenates the encoder
            # feature map of the matching depth and upsamples with a stride-2
            # transposed convolution.
            dconv_seq4 = conv_block_seq_res_fixe(reshape1, start_filter_num, kernel_size, 1, "dconv_seq4", In=False, ResCon=False)
            dconc5 = tf.keras.layers.concatenate([dconv_seq4, conv_seq7], name="dconc5")
            deconv1 = Conv1DTranspose(dconc5, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq5 = conv_block_seq_res_fixe(deconv1, start_filter_num, kernel_size, 1, "dconv_seq5", In=False)
            dconc7 = tf.keras.layers.concatenate([dconv_seq5, conv_seq6], name="dconc7")
            deconv2 = Conv1DTranspose(dconc7, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq6 = conv_block_seq_res_fixe(deconv2, start_filter_num, kernel_size, 1, "dconv_seq6", In=False)
            dconc9 = tf.keras.layers.concatenate([dconv_seq6, conv_seq5], name="dconc9")
            deconv3 = Conv1DTranspose(dconc9, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq7 = conv_block_seq_res_fixe(deconv3, start_filter_num, kernel_size, 1, "dconv_seq7", In=False)
            dconc11 = tf.keras.layers.concatenate([dconv_seq7, conv_seq4], name="dconc11")
            deconv4 = Conv1DTranspose(dconc11, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq8 = conv_block_seq_res_fixe(deconv4, start_filter_num, kernel_size, 1, "dconv_seq8", In=False)
            dconc13 = tf.keras.layers.concatenate([dconv_seq8, conv_seq3], name="dconc13")
            deconv5 = Conv1DTranspose(dconc13, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq9 = conv_block_seq_res_fixe(deconv5, start_filter_num, kernel_size, 1, "dconv_seq9", In=False)
            dconc15 = tf.keras.layers.concatenate([dconv_seq9, conv_seq2], name="dconc15")
            deconv6 = Conv1DTranspose(dconc15, start_filter_num, kernel_size=3, strides=2, padding='same')

            dconv_seq10 = conv_block_seq_res_fixe(deconv6, start_filter_num, kernel_size, 1, "dconv_seq10", In=False)
            dconc17 = tf.keras.layers.concatenate([dconv_seq10, conv_seq1], name="dconc17")

            # Single-channel output; the ReLU restricts predictions to >= 0.
            x_pred = tf.keras.layers.Conv1D(1, 3, padding="same", activation="relu", name="x_pred")(dconc17)

            # From here on `model` is the Keras model, no longer the selector string.
            model = tf.keras.Model(inputs=x, outputs=x_pred)
            model.summary()
            
        # vae_loss is both the training loss and a reported metric.
        model.compile(optimizer=optimizer, loss=vae_loss, metrics=[KL_loss, vae_loss, "mean_absolute_error"])
        
        
    return model

class CustomStopper(tf.keras.callbacks.EarlyStopping):
    """EarlyStopping that stays inactive during the first epochs.

    The parent's end-of-epoch logic only runs for epochs whose index is
    strictly greater than `start_epoch`; earlier epochs are skipped.
    A progress message is printed on every epoch.
    """

    def __init__(self, monitor='val_loss', min_delta=0.0001, patience=10, verbose=1, mode='auto', start_epoch=5):
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
        )
        # Epoch index that must be exceeded before early stopping is evaluated.
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        print("On epoch End!")
        # Warm-up phase: do not let the parent track or stop anything yet.
        if epoch <= self.start_epoch:
            return
        super().on_epoch_end(epoch, logs)
        print("On epoch End after starting point!")
    
class AdditionalValidationSets(tf.keras.callbacks.Callback):
    """Evaluate the model on extra validation sets after every epoch.

    The results are collected in `self.history` under keys of the form
    `<validation_set_name>_loss` and `<validation_set_name>_<metric name>`,
    next to a copy of the regular training logs. `self.epoch` holds the
    epoch indices seen so far.
    """

    def __init__(self, validation_sets, verbose=0, batch_size=None):
        """
        :param validation_sets:
        a list of 3-tuples (validation_data, validation_targets, validation_set_name)
        or 4-tuples (validation_data, validation_targets, sample_weights, validation_set_name)
        :param verbose:
        verbosity mode, 1 or 0
        :param batch_size:
        batch size to be used when evaluating on the additional datasets
        :raises ValueError:
        if an entry of validation_sets is neither a 3-tuple nor a 4-tuple
        """
        super().__init__()
        self.validation_sets = validation_sets
        # Fail fast on malformed entries. The accepted lengths must match
        # the two layouts unpacked in on_epoch_end (3 or 4 elements).
        for validation_set in self.validation_sets:
            if len(validation_set) not in (3, 4):
                raise ValueError(
                    'Each validation set must be a 3-tuple (data, targets, name) '
                    'or a 4-tuple (data, targets, sample_weights, name); '
                    'got {} elements.'.format(len(validation_set)))
        self.epoch = []
        self.history = {}
        self.verbose = verbose
        self.batch_size = batch_size

    def on_train_begin(self, logs=None):
        """Reset the recorded epochs and history at the start of training."""
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch logs and evaluate every additional validation set."""
        logs = logs or {}
        self.epoch.append(epoch)

        # record the same values as History() as well
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        # evaluate on the additional validation sets
        for validation_set in self.validation_sets:
            if len(validation_set) == 3:
                validation_data, validation_targets, validation_set_name = validation_set
                sample_weights = None
            elif len(validation_set) == 4:
                validation_data, validation_targets, sample_weights, validation_set_name = validation_set
            else:
                # Only reachable if validation_sets was modified after __init__.
                raise ValueError(
                    'Each validation set must be a 3-tuple or a 4-tuple; '
                    'got {} elements.'.format(len(validation_set)))

            results = self.model.evaluate(x=validation_data,
                                          y=validation_targets,
                                          verbose=self.verbose,
                                          sample_weight=sample_weights,
                                          batch_size=self.batch_size)

            # A model compiled without metrics yields a bare scalar loss
            # instead of a list; normalize so the loop below handles both.
            if not isinstance(results, (list, tuple)):
                results = [results]

            for i, result in enumerate(results):
                if i == 0:
                    valuename = validation_set_name + '_loss'
                else:
                    # NOTE(review): assumes model.metrics does not contain the
                    # loss itself, so result i maps to metrics[i-1] -- confirm
                    # for the TensorFlow version / execution mode in use.
                    valuename = validation_set_name + '_' + str(self.model.metrics[i-1].name)
                self.history.setdefault(valuename, []).append(result)

def reconstruct(y, width, strides, merge_type="mean"):
    """Merge overlapping sliding windows back into one 1-D signal.

    Window `i` is taken to start at sample `i * strides` of the output.
    Each window is cut into chunks of `strides` samples; chunk `d` of every
    window is written into column `d` of a NaN-filled scratch matrix, so
    all estimates of the same output sample end up in the same row. The
    rows are then reduced while ignoring the NaN placeholders.

    Args:
        y: Array indexed as `y[window, sample, 0]`; the first `width`
            samples of every window are used.
        width: Number of samples per window.
        strides: Offset in samples between consecutive windows.
        merge_type: "mean" averages overlapping estimates; any other value
            takes their median.

    Returns:
        1-D array of length `width + (y.shape[0] - 1) * strides`. Samples
        not covered by any window (possible when `strides > width`) are NaN.
    """
    n_windows = y.shape[0]
    len_total = width + (n_windows - 1) * strides

    # Ceiling division: when `width` is not a multiple of `strides` the
    # trailing partial chunk needs its own column. Floor division would drop
    # those samples and leave the end of the output as NaN.
    depth = (width + strides - 1) // strides

    yr = np.full((len_total, depth), np.nan)

    for d in range(depth):
        chunk_start = d * strides
        # The last chunk may be shorter than `strides`.
        chunk_stop = min(chunk_start + strides, width)
        for i in range(n_windows):
            offset = i * strides
            yr[offset + chunk_start:offset + chunk_stop, d] = y[i, chunk_start:chunk_stop, 0]

    if merge_type == "mean":
        return np.nanmean(yr, axis=1)
    return np.nanmedian(yr, axis=1)