In [None]:
# https://github.com/Nixtla/neuralforecast/blob/main/nbs/models.deepar.ipynb
# https://github.com/Nixtla/neuralforecast/blob/main/nbs/losses.pytorch.ipynb

# DeepAR

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Distribution
from torch.distributions import NegativeBinomial, Normal, Poisson
from torch.distributions import constraints

In [3]:
def level_to_outputs(level):
    qs = sum([[50-l/2, 50+l/2] for l in level], [])
    output_names = sum([[f'-lo-{l}', f'-hi-{l}'] for l in level], [])

    # Sort in increasing order
    sort_idx = np.argsort(qs)
    quantiles = np.array(qs)[sort_idx]
    output_names = list(np.array(output_names)[sort_idx])
    
    # Add median by default
    quantiles = np.concatenate([np.array([50]), quantiles])
    quantiles = torch.Tensor(quantiles) / 100
    output_names.insert(0, '-median')

    return quantiles, output_names

In [None]:
def normal_scale_decouple(output, loc=None, scale=None, eps: float = 0.2):
    mean, std = output
    std = F.softplus(std)
    if (loc is not None) and (scale is not None):
        mean = (mean * scale) + loc
        std = (std + eps) * scale
    return mean, std


def nbinomial_scale_decouple(output, loc=None, scale=None):
    mu, alpha = output
    mu = F.softplus(mu) + 1e-08
    alpha = F.softplus(alpha) + 1e-08
    if (loc is not None) and (scale is not None):
        mu = mu * scale + loc
        alpha = alpha / (scale + 1.)

    total_count = 1.0 / alpha
    probs = (mu * alpha / (1.0 + mu * alpha)) + 1e-8
    return total_count, probs


def poisson_scale_decouple(output, loc=None, scale=None):
    eps = 1e-10
    rate, _ = output
    if (loc is not None) and (scale is not None):
        rate = (rate * scale) + loc
    rate = F.softplus(rate) + eps
    return (rate, )
    

In [8]:
# DistributionLoss

class DistributionLoss(nn.Module):
    def __init__(
        self,
        distribution: str,
        level=[80, 90],  # Confidence levels of prediction intervals
        num_samples=1000,
        return_params=False,
        **distribution_kwargs,
    ):
        super().__init__()
        qs, output_names = level_to_outputs(level)
        qs = torch.Tensor(qs)
        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)
        self.output_names = output_names

        available_distributions = dict(
            Normal=Normal,
            NegativeBinomial=NegativeBinomial,
            Poisson=Poisson
        )
        scale_decouples = dict(
            Normal=normal_scale_decouple,
            NegativeBinomial=nbinomial_scale_decouple,
            Poisson=poisson_scale_decouple,
        )
        param_names = dict(
            Normal=["-loc", "-scale"],
            NegativeBinomial=["-total_count", "-logits"], 
            Poisson=["-loc"],
        )
        
        assert distribution in available_distributions
        self.distribution = distribution
        self._base_distribution = available_distributions[distribution]
        self.scale_decouple = scale_decouples[distribution]
        self.param_names = param_names[distribution]
        self.outputsize_multiplier = len(self.param_names)
        self.num_samples = num_samples
        self.return_params = return_params
        if self.return_params:
            self.output_names = self.output_names + self.param_names
        self.distribution_kwargs = distribution_kwargs
        
    
    def _domain_map(self, input: torch.Tensor):
        """
        Maps output of neural network to domain of distribution loss
        """
        output = torch.tensor_split(input, self.outputsize_multiplier, dim=2)
        return output
        
    
    def get_distribution(self, distr_args, **distribution_kwargs) -> Distribution:
        distr = self._base_distribution(*distr_args, **distribution_kwargs)
        self.distr_mean = distr.mean
        
        if self.distribution in ('Poisson', 'NegativeBinomial'):
              distr.support = constraints.nonnegative
        return distr
    
    def sample(self, distr_args: torch.Tensor, num_samples=None):
        if num_samples is None:
            num_samples = self.num_samples
        
        # Instantiate Scale Decoupled Distribution
        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)
        samples = distr.sample(sample_shape=(num_samples,))
        samples = samples.permute(1, 2, 3, 0)  # [samples, B, H, N] -> [B, H, N, samples]
        
        # Compute mean and quantiles
        sample_mean = torch.mean(samples, dim=-1, keepdim=True)
        quants = torch.quantile(samples, self.quantiles, dim=-1)
        quants = quants.permute(1, 2, 3, 0)  # [Q, B, H, N] -> [B, H, N, Q]

        return samples, sample_mean, quants
    
    def __call__(self, y: torch.Tensor, distr_args: torch.Tensor):
        # Instantiate Scale Decoupled Distribution
        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)
        loss_values = -distr.log_prob(y)
        return loss_values.mean()



In [9]:
distr_loss = DistributionLoss(distribution="Normal", level=[80, 90])

In [11]:
distr_loss.quantiles

Parameter containing:
tensor([0.5000, 0.0500, 0.1000, 0.9000, 0.9500])

In [12]:
distr_loss.output_names

['-median',
 np.str_('-lo-90'),
 np.str_('-lo-80'),
 np.str_('-hi-80'),
 np.str_('-hi-90')]

In [15]:
distr = distr_loss.get_distribution(torch.Tensor([0.0, 1.0]))

In [20]:
distr.sample(sample_shape=(100, 3))

tensor([[ 1.8742,  1.0081, -0.0538],
        [-0.1986, -0.7672,  2.0808],
        [ 0.1821,  0.3600,  0.2400],
        [-0.0665,  0.6896, -0.7004],
        [ 2.1068, -0.3302, -0.8646],
        [-0.2016, -0.2594, -0.3241],
        [-0.4470,  0.1718,  1.6635],
        [ 2.2733,  1.1294, -0.1248],
        [ 1.5230,  0.2701,  0.4043],
        [ 0.0521,  0.0704, -0.8703],
        [ 0.1721, -0.6784, -0.1979],
        [-0.9488, -0.8515, -1.3582],
        [-1.2985,  0.6006, -0.0716],
        [ 0.8799, -0.7836, -0.1932],
        [-0.8165, -0.6849,  0.9421],
        [-0.7901,  0.0940, -0.0192],
        [-0.0060, -2.0451,  0.4674],
        [-0.1573, -0.1187,  1.0787],
        [-0.0842, -1.9280, -0.4245],
        [-1.4284, -0.2805,  2.5322],
        [ 0.4390, -0.9358,  0.7516],
        [ 0.0729, -0.0640,  1.1803],
        [ 0.9518, -0.9017, -0.8626],
        [ 0.6758,  1.1942, -1.2869],
        [ 2.7938, -0.9241, -0.9130],
        [-0.0929,  1.0451,  1.5427],
        [-0.7759, -0.2445, -0.6072],
 