In [None]:
class Exponentiate(keras.layers.Layer):
    """Custom layer to exp the sigma and tau estimates inline.

    The layer has no weights; ``call`` simply applies ``tf.math.exp``
    element-wise to whatever tensor it receives.

    Arguments
    ---------
    name : str, optional
        The layer name, forwarded to the base ``Layer``.

    **kwargs :
        Any other base ``Layer`` keyword arguments (e.g. ``dtype``,
        ``trainable``), forwarded unchanged.

    """

    def __init__(self, name=None, **kwargs):
        # Initialise the base class exactly once. Calling the base
        # ``__init__`` a second time without ``name`` re-initialises the
        # layer and replaces the requested name with an auto-generated one.
        super().__init__(name=name, **kwargs)

    def get_config(self):
        """Return the base layer configuration (no extra fields)."""
        return super().get_config()

    def call(self, inputs):
        """Return the element-wise exponential of ``inputs``."""
        return tf.math.exp(inputs)
    

    
class CustomMAE(tf.keras.metrics.Metric):
    """Compute the prediction mean absolute error.

    The "predicted value" is the median of the conditional distribution,
    obtained from ``shash_median`` applied to the predicted parameters.

    Notes
    -----
    * The computation is done by maintaining a running sum of the absolute
        errors and a running count of the samples seen across all batches.
        Both are metric weights, so they are zeroed whenever the metric
        state is reset (Keras does this between epochs during ``fit``).

    * ``sample_weight`` is accepted for API compatibility but is ignored.

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Pass the name by keyword: the position of ``name`` in
        # ``add_weight`` differs between Keras versions.
        self.error = self.add_weight(name="error", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, pred, sample_weight=None):
        """Accumulate the absolute errors of one batch.

        Arguments
        ---------
        y_true : tensor
            The ground truth values; only column 0 is read.

        pred : tensor
            The predicted distribution parameters, one column each, in the
            order mu, sigma, gamma, tau. Columns 2 and 3 are optional.

        sample_weight :
            Ignored.

        """
        mu = pred[:, 0]
        sigma = pred[:, 1]
        n_columns = pred.shape[1]

        # Parameters that are not predicted fall back to gamma = 0, tau = 1.
        gamma = pred[:, 2] if n_columns >= 3 else tf.zeros_like(mu)
        tau = pred[:, 3] if n_columns >= 4 else tf.ones_like(mu)

        predictions = shash_median(mu, sigma, gamma, tau)

        error = tf.math.abs(y_true[:, 0] - predictions)
        batch_error = tf.reduce_sum(error)
        # Count every sample in the batch. Counting only the non-zero
        # errors would drop exact predictions from the denominator and
        # bias the mean upward.
        batch_total = tf.size(error)

        self.error.assign_add(tf.cast(batch_error, tf.float32))
        self.total.assign_add(tf.cast(batch_total, tf.float32))

    def result(self):
        """Return the mean absolute error over all samples seen so far.

        Returns 0 (rather than NaN) when no samples have been accumulated.
        """
        return tf.math.divide_no_nan(self.error, self.total)

    def get_config(self):
        """Return the base metric configuration (no extra fields)."""
        base_config = super().get_config()
        return {**base_config}
   
    
def compute_NLL(y, distr):
    """Return the negative log-likelihood of ``y`` under ``distr``.

    Arguments
    ---------
    y :
        The observed value(s), passed straight to ``distr.log_prob``.

    distr :
        Any object exposing a ``log_prob(y)`` method.

    Returns
    -------
    The negation of ``distr.log_prob(y)``.

    """
    log_likelihood = distr.log_prob(y)
    return -log_likelihood

def compute_shash_NLL(y_true, pred):
    """Negative log-likelihood loss using the sinh-arcsinh normal distribution.

    Arguments
    ---------
    y_true : tensor
        The ground truth values. Indexed as a 2-D tensor; only column 0
        is read.

    pred : tensor
        The predicted local conditional distribution parameter values.
        shape = [batch_size, n_parameters]

    Returns
    -------
    loss : tensor
        The negative log-likelihood given by ``shash_log_prob``, averaged
        over its last axis.

    Notes
    -----
    * The value of n_parameters depends on the chosen form of the conditional
        sinh-arcsinh normal distribution.
            shash2 -> n_parameters = 2, i.e. mu, sigma
            shash3 -> n_parameters = 3, i.e. mu, sigma, gamma
            shash4 -> n_parameters = 4, i.e. mu, sigma, gamma, tau

    * sigma and tau are passed to ``shash_log_prob`` exactly as they appear
        in ``pred``; no exponentiation happens here. If the network learns
        their logs, the conversion presumably happens upstream (e.g. via an
        ``Exponentiate`` layer) -- confirm against the model definition.

    * If gamma is not learned (i.e. shash2), it is set to 0.

    * If tau is not learned (i.e. shash2 or shash3), it is set to 1.

    """
    mu = pred[:, 0]
    sigma = pred[:, 1]

    # Missing columns fall back to the neutral values gamma = 0, tau = 1.
    gamma = pred[:, 2] if pred.shape[1] >= 3 else tf.zeros_like(mu)
    tau = pred[:, 3] if pred.shape[1] >= 4 else tf.ones_like(mu)

    targets = y_true[:, 0]
    negative_log_likelihood = -shash_log_prob(targets, mu, sigma, gamma, tau)
    return tf.reduce_mean(negative_log_likelihood, axis=-1)