# Trainer

In [2]:
from typing import Optional, Union

from copy import deepcopy
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

from gluonts.core.component import validated
from pts import Trainer


class TrainerForecasting(Trainer):
    """
    Optimisation loop for forecasting networks.

    Trains ``net`` with Adam and a ``OneCycleLR`` learning-rate schedule for
    ``epochs`` epochs of at most ``num_batches_per_epoch`` batches each. When a
    validation loader is given, the average validation loss is computed after
    every epoch. The weights with the lowest validation loss are restored at the
    end, and training stops early after ``patience`` consecutive epochs without
    improvement. Without a validation loader the final weights are kept.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size; stored for callers (e.g. predictor construction).
        num_batches_per_epoch: Maximum number of batches drawn per epoch from the
            training loader (the same cap is applied to the validation loader).
        learning_rate: Initial learning rate given to Adam.
        weight_decay: Weight decay given to Adam.
        maximum_learning_rate: Peak learning rate of the one-cycle schedule.
        clip_gradient: If not None, gradients are clipped to this global L2 norm.
        patience: Number of consecutive non-improving validation epochs after
            which training stops. ``None`` disables early stopping.
        device: Device every input tensor is moved to before the forward pass.
    """

    @validated()
    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-6,
        maximum_learning_rate: float = 1e-2,
        clip_gradient: Optional[float] = None,
        patience: Optional[int] = None,
        device: Optional[Union[torch.device, str]] = None,
        **kwargs,
    ) -> None:
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.maximum_learning_rate = maximum_learning_rate
        self.clip_gradient = clip_gradient
        self.patience = patience
        self.device = device

    @staticmethod
    def _loss_from_output(output):
        """Return the loss tensor; networks may return a list/tuple whose first item is the loss."""
        if isinstance(output, (list, tuple)):
            return output[0]
        return output

    def _inputs_from_batch(self, data_entry):
        """Move every value of a batch dict to ``self.device``, keeping dict order."""
        return [v.to(self.device) for v in data_entry.values()]

    def _train_epoch(self, net, train_iter, optimizer, lr_scheduler, epoch_label) -> float:
        """Run one training epoch; return the average batch loss (NaN if no batch was seen)."""
        net.train()
        cumm_epoch_loss = 0.0
        avg_epoch_loss = float("nan")

        with tqdm(train_iter, total=self.num_batches_per_epoch) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                optimizer.zero_grad()

                loss = self._loss_from_output(net(*self._inputs_from_batch(data_entry)))
                loss.backward()

                if self.clip_gradient is not None:
                    nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)

                optimizer.step()
                lr_scheduler.step()

                cumm_epoch_loss += loss.item()
                avg_epoch_loss = cumm_epoch_loss / batch_no
                it.set_postfix(
                    {
                        "epoch": epoch_label,
                        "avg_loss": avg_epoch_loss,
                    },
                    refresh=False,
                )

                # Cap the epoch length; this also bounds the number of
                # lr_scheduler.step() calls to epochs * num_batches_per_epoch,
                # the total the OneCycleLR schedule was built for.
                if batch_no == self.num_batches_per_epoch:
                    break

        return avg_epoch_loss

    def _validate_epoch(self, net, validation_iter, epoch_label, avg_epoch_loss) -> Optional[float]:
        """Return the average validation loss, or None if the loader yielded no batch."""
        # eval() disables dropout etc. so the validation loss is not noisy.
        net.eval()
        cumm_epoch_loss_val = 0.0
        avg_epoch_loss_val = None

        with torch.no_grad(), tqdm(
            validation_iter, total=self.num_batches_per_epoch, colour="green"
        ) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                loss = self._loss_from_output(net(*self._inputs_from_batch(data_entry)))

                cumm_epoch_loss_val += loss.item()
                avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
                it.set_postfix(
                    {
                        "epoch": epoch_label,
                        "avg_loss": avg_epoch_loss,
                        "avg_val_loss": avg_epoch_loss_val,
                    },
                    refresh=False,
                )

                if batch_no == self.num_batches_per_epoch:
                    break

        net.train()
        return avg_epoch_loss_val

    def __call__(
        self,
        net: nn.Module,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader] = None,
    ) -> None:
        """
        Train ``net`` in place.

        Args:
            net: Module whose forward takes the batch values positionally and
                returns a loss tensor (or a list/tuple starting with the loss).
            train_iter: Iterable of batch dicts used for training.
            validation_iter: Optional iterable of batch dicts used for model
                selection and early stopping.
        """
        optimizer = Adam(net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        lr_scheduler = OneCycleLR(
            optimizer,
            max_lr=self.maximum_learning_rate,
            steps_per_epoch=self.num_batches_per_epoch,
            epochs=self.epochs,
        )

        # Early stopping state; best_state stays None unless validation ran,
        # so a run without validation keeps its final (trained) weights.
        best_loss = float("inf")
        best_state = None
        waiting = 0

        for epoch_no in range(self.epochs):
            epoch_label = f"{epoch_no + 1}/{self.epochs}"

            avg_epoch_loss = self._train_epoch(net, train_iter, optimizer, lr_scheduler, epoch_label)

            if validation_iter is None:
                continue

            avg_epoch_loss_val = self._validate_epoch(net, validation_iter, epoch_label, avg_epoch_loss)
            if avg_epoch_loss_val is None:
                # Empty validation loader: nothing to compare against.
                continue

            if avg_epoch_loss_val < best_loss:
                best_loss = avg_epoch_loss_val
                best_state = deepcopy(net.state_dict())
                waiting = 0
                continue

            waiting += 1
            if self.patience is not None and waiting >= self.patience:
                print(f'Early stopping at epoch {epoch_no}')
                break

        if best_state is not None:
            net.load_state_dict(best_state)

# python -m tsdiff.train_forecasting --seed 1 --dataset electricity_nips --network timegrad_rnn --noise ou --epochs 100

# Model Estimator

In [3]:
from typing import Any, Callable, List, Optional

import torch

from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.util import copy_parameters
from gluonts.model.predictor import Predictor
from gluonts.transform import (
    Transformation,
    Chain,
    InstanceSplitter,
    ExpectedNumInstanceSampler,
    ValidationSplitSampler,
    TestSplitSampler,
    RenameFields,
    AsNumpyArray,
    ExpandDimArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    VstackFeatures,
    SetFieldIfNotPresent,
    TargetDimIndicator,
)
from gluonts.core.component import validated

from pts.feature import (
    fourier_time_features_from_frequency,
    lags_for_fourier_time_features_from_frequency,
)
from pts.model import PyTorchEstimator
from pts.model.utils import get_module_forward_input_names



In [4]:
class ScoreEstimator(PyTorchEstimator):
    """
    GluonTS-style estimator that wires a user-supplied training network and
    prediction network into the data pipeline.

    ``training_net`` and ``prediction_net`` are factories (e.g. module classes)
    called with the shared hyper-parameters returned by ``_network_kwargs``;
    the prediction factory additionally receives ``num_parallel_samples``.

    Args:
        training_net: Factory for the network used during training.
        prediction_net: Factory for the network used at prediction time.
        noise: Forwarded unchanged to both network factories.
        input_size: Forwarded to both network factories.
        freq: Pandas-style frequency string; drives default lags/time features.
        prediction_length: Forecast horizon.
        target_dim: Forwarded to both network factories.
        trainer: Trainer instance; a fresh ``TrainerForecasting()`` when None.
        context_length: Conditioning length; defaults to ``prediction_length``.
        cardinality: Forwarded to the networks; defaults to ``[1]``.
        hidden_dim: Forwarded to the networks as ``conditioning_length``.
        pick_incomplete: If True, training/validation windows may start before
            a full ``history_length`` of past data is available.
        lags_seq: Lag indices; derived from ``freq`` when None.
        time_features: Time features; derived from ``freq`` when None.
        old: Stored on the instance; not used inside this class.
        Remaining arguments are stored and forwarded to the network factories.
    """

    def __init__(
        self,
        training_net: Callable,
        prediction_net: Callable,
        noise: str,
        input_size: int,
        freq: str,
        prediction_length: int,
        target_dim: int,
        trainer: Optional[TrainerForecasting] = None,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        cell_type: str = "GRU",
        num_parallel_samples: int = 100,
        dropout_rate: float = 0.1,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: int = 5,
        hidden_dim: int = 100,
        diff_steps: int = 100,
        loss_type: str = "l2",
        beta_end=0.1,
        beta_schedule="linear",
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        scaling: bool = True,
        pick_incomplete: bool = False,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        old: bool = False,
        time_feat_dim: int = 4,
        **kwargs,
    ) -> None:
        # Build defaults per instance instead of sharing objects created once
        # at class-definition time (mutable default arguments).
        if trainer is None:
            trainer = TrainerForecasting()
        super().__init__(trainer=trainer, **kwargs)

        self.training_net = training_net
        self.prediction_net = prediction_net
        self.noise = noise

        self.old = old

        self.freq = freq
        self.context_length = context_length if context_length is not None else prediction_length

        self.input_size = input_size
        self.prediction_length = prediction_length
        self.target_dim = target_dim
        self.time_feat_dim = time_feat_dim
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.num_parallel_samples = num_parallel_samples
        self.dropout_rate = dropout_rate
        # Copy so later mutation of the caller's list cannot affect this estimator.
        self.cardinality = list(cardinality) if cardinality is not None else [1]
        self.embedding_dimension = embedding_dimension

        self.conditioning_length = hidden_dim
        self.diff_steps = diff_steps
        self.loss_type = loss_type
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.residual_layers = residual_layers
        self.residual_channels = residual_channels
        self.dilation_cycle_length = dilation_cycle_length

        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )

        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )

        # The past window must cover the context plus the largest lag.
        self.history_length = self.context_length + max(self.lags_seq)
        self.pick_incomplete = pick_incomplete
        self.scaling = scaling

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

        self.validation_sampler = ValidationSplitSampler(
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

    def create_transformation(self) -> Transformation:
        """Return the per-series preprocessing chain (target as 2-D array, observed mask, time features, static/dim indicators)."""
        return Chain(
            [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=2,
                ),
                # maps the target to (1, T)
                # if the target data is uni dimensional
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME],
                ),
                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str):
        """
        Return the window splitter for ``mode`` ("training", "validation" or "test").

        Past/future targets are renamed with a ``_cdf`` suffix after splitting.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def _network_kwargs(self) -> dict:
        """Hyper-parameters passed to both the training and the prediction network factories."""
        return dict(
            noise=self.noise,
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            diff_steps=self.diff_steps,
            loss_type=self.loss_type,
            beta_end=self.beta_end,
            beta_schedule=self.beta_schedule,
            residual_layers=self.residual_layers,
            residual_channels=self.residual_channels,
            dilation_cycle_length=self.dilation_cycle_length,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            conditioning_length=self.conditioning_length,
            time_feat_dim=self.time_feat_dim,
        )

    def create_training_network(self, device: torch.device):
        """Instantiate the training network on ``device``."""
        return self.training_net(**self._network_kwargs()).to(device)

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: Any,
        device: torch.device,
    ) -> Predictor:
        """
        Build a ``PyTorchPredictor`` whose network is a fresh prediction network
        with parameters copied from ``trained_network``.
        """
        prediction_network = self.prediction_net(
            **self._network_kwargs(),
            num_parallel_samples=self.num_parallel_samples,
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
        )


# Noise 

In [5]:
from typing import Union
from torchtyping import TensorType

import numpy as np
import scipy.fftpack
from functools import lru_cache

import torch
import torch.nn as nn


class Normal(nn.Module):
    """
    Isotropic standard Gaussian noise with ``dim`` features.
    """

    def __init__(self, dim: int, **kwargs):
        super().__init__()
        self.dim = dim

    def forward(self, *shape, **kwargs):
        """Draw N(0, I) samples of shape ``(*shape, dim)`` on the default device; kwargs are ignored."""
        sample_shape = (*shape, self.dim)
        return torch.randn(sample_shape)

    def covariance(self, **kwargs):
        """Return the ``dim x dim`` identity matrix; kwargs are ignored."""
        identity = torch.eye(self.dim)
        return identity


class Wiener(nn.Module):
    """
    Wiener process / Brownian motion sampled at given time points.

    Each increment is drawn as ``N(0, dt)``, where ``dt`` is the gap to the
    previous time point (the first gap is measured from 0) clamped below at
    1e-5. Increments are accumulated along the sequence axis.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(
        self,
        t: Union[TensorType['seq_len'], TensorType[..., 'seq_len', 1]],
        **kwargs,
    ) -> Union[TensorType['seq_len'], TensorType[..., 'seq_len', 'dim']]:
        is_flat = t.dim() == 1

        # Give every channel its own copy of the time axis: (..., seq_len, dim).
        times = t.unsqueeze(-1) if is_flat else t
        times = times.repeat_interleave(self.dim, dim=-1)

        origin = torch.zeros_like(times[..., :1, :]).to(times)
        gaps = torch.diff(times, dim=-2, prepend=origin)
        increments = torch.randn_like(gaps) * gaps.clamp(1e-5).sqrt()
        path = torch.cumsum(increments, dim=-2)

        # A flat time vector with a single channel yields a flat path.
        return path.squeeze(-1) if (is_flat and self.dim == 1) else path


class OrnsteinUhlenbeck(nn.Module):
    """
    Ornstein-Uhlenbeck process with unit marginal variance.

    Samples have covariance ``exp(-theta * |t_i - t_j|)`` between time points.

    Args:
        dim: Number of independent channels.
        theta: Decay rate of the correlation; higher value = spikier paths (float)
    """
    def __init__(self, dim: int, theta: float = 0.5):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Registered submodule; not used by the methods of this class.
        self.wiener = Wiener(dim)

    def forward(
        self,
        *args,
        t: TensorType[..., 'seq_len', 1],
        **kwargs,
    ) -> TensorType[..., 'seq_len', 'dim']:
        """Draw one path per batch element at times ``t``; positional args are ignored."""
        batch_shape = t.shape[:-2]

        gaps = torch.diff(t, dim=-2, prepend=torch.zeros_like(t[...,:1,:]))
        decay = torch.exp(-self.theta * gaps)

        # Start from the stationary N(0, 1) distribution, then apply the exact
        # AR(1) transition for every time gap.
        state = torch.randn(*batch_shape, 1, self.dim).to(t)
        path = []
        for step in range(decay.shape[-2]):
            innovation = torch.randn(*batch_shape, 1, self.dim).to(t)
            rho = decay[..., step, None, :]
            state = rho * state + torch.sqrt(1 - rho ** 2) * innovation
            path.append(state)

        return torch.cat(path, dim=-2)

    def covariance(
        self,
        t: TensorType[..., 'seq_len', 1],
        diag_epsilon: float = 1e-4,
        **kwargs,
    ) -> TensorType[..., 'seq_len', 'seq_len']:
        """Return ``exp(-theta * |t_i - t_j|)`` plus ``diag_epsilon`` on the diagonal for numerical stability."""
        times = t.squeeze(-1)
        n_points = times.shape[-1]
        gaps = (times.unsqueeze(-1) - times.unsqueeze(-2)).abs()
        jitter = diag_epsilon * torch.eye(n_points).to(times)
        return torch.exp(-gaps * self.theta) + jitter

    def covariance_cholesky(self, t: TensorType[..., 'seq_len', 1]) -> TensorType[..., 'seq_len', 'seq_len']:
        """Lower Cholesky factor of ``covariance(t)``."""
        cov = self.covariance(t)
        return torch.linalg.cholesky(cov)

    def covariance_inverse(self, t: TensorType[..., 'seq_len', 1]) -> TensorType[..., 'seq_len', 'seq_len']:
        """Matrix inverse of ``covariance(t)``."""
        cov = self.covariance(t)
        return torch.linalg.inv(cov)


class GaussianProcess(nn.Module):
    """
    Gaussian random field for one-dimensional (temporal) data.

    Uses the squared-exponential kernel ``exp(-((t_i - t_j) / sigma)^2)``.
    """
    def __init__(self, dim: int, sigma: float = 0.1):
        super().__init__()
        self.dim = dim
        self.sigma = sigma

    def forward(
        self,
        *args,
        t: TensorType[..., 'N', 1],
        **kwargs,
    ) -> TensorType[..., 'N', 'dim']:
        """Draw ``dim`` correlated channels at times ``t`` by colouring white noise with the Cholesky factor."""
        # Dense Cholesky: cost grows cubically with N; for very large N a
        # sparse GP approximation would be preferable.
        chol = self.covariance_cholesky(t)
        white = torch.randn(*t.shape[:-1], self.dim).to(t)
        return chol @ white

    def covariance(
        self,
        t: TensorType[..., 'N', 1],
        diag_epsilon: float = 1e-4,
        **kwargs,
    ) -> TensorType[..., 'N', 'N']:
        """Kernel matrix over ``t`` plus ``diag_epsilon`` on the diagonal; a flat ``t`` is given a trailing axis."""
        missing_axis = t.shape[-1] != 1 or len(t.shape) < 2
        if missing_axis:
            t = t.unsqueeze(-1)
        n_points = t.shape[-2]
        scaled = (t - t.transpose(-1, -2)) / self.sigma
        jitter = diag_epsilon * torch.eye(n_points).to(t)
        return torch.exp(-torch.square(scaled)) + jitter

    def covariance_cholesky(self, t: TensorType[..., 'N', 1]) -> TensorType[..., 'N', 'N']:
        """Lower Cholesky factor of ``covariance(t)``."""
        cov = self.covariance(t)
        return torch.linalg.cholesky(cov)

    def covariance_inverse(self, t: TensorType[..., 'N', 1]) -> TensorType[..., 'N', 'N']:
        """Matrix inverse of ``covariance(t)``."""
        cov = self.covariance(t)
        return torch.linalg.inv(cov)


# Beta Schedulers

In [6]:
from typing import Callable
import torch
import torch.nn as nn
from torch import Tensor

def get_beta_scheduler(name: str) -> Callable:
    """
    Return the beta-schedule class registered under ``name``.

    Args:
        name: Schedule name; currently only ``'linear'`` is supported.

    Returns:
        The schedule class (``BetaLinear`` for ``'linear'``).

    Raises:
        ValueError: If ``name`` is not a known schedule. Previously an unknown
            name silently returned None, which only failed later when the
            schedule was called.
    """
    if name == 'linear':
        return BetaLinear
    raise ValueError(f"Unknown beta schedule: {name!r}. Supported schedules: 'linear'.")

def get_loss_weighting(name: str) -> Callable:
    """
    Return the loss-weighting function registered under ``name``.

    ``'exponential'`` maps to ``exponential_loss_weighting``; every other value
    (including None) yields None. ``ContinuousDiffusion`` treats a None
    weighting as a constant weight of 1.
    """
    return exponential_loss_weighting if name == 'exponential' else None

class BetaLinear(nn.Module):
    """
    Linear noise schedule ``beta(t)`` interpolating from ``start`` at t=0 to
    ``end`` at t=1. Input t is always from interval [0, 1].

    Args:
        start: Value at t = 0 (float)
        end: Value at t = 1 (float)
    """
    def __init__(self, start: float, end: float):
        super().__init__()
        self.start, self.end = start, end

    def forward(self, t: Tensor) -> Tensor:
        """Evaluate ``beta(t)``."""
        weight_start = 1 - t
        return self.start * weight_start + self.end * t

    def integral(self, t: Tensor) -> Tensor:
        """Return the integral of ``beta`` from 0 to ``t``."""
        slope = self.end - self.start
        return 0.5 * slope * t.square() + self.start * t


def exponential_loss_weighting(beta_fn, i):
    """
    Return the loss weight ``1 - exp(-B(i))``, where ``B(i) = beta_fn.integral(i)``.

    Args:
        beta_fn: Schedule object exposing ``integral``.
        i: Diffusion time tensor.
    """
    cumulative_beta = beta_fn.integral(i)
    return 1 - torch.exp(-cumulative_beta)

# Diffusion Models

## Continuous Diffusion

In [7]:
from typing import Callable, Tuple, Optional, Union
from torchtyping import TensorType
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.distributions as td

from torchsde import sdeint
from torchdiffeq import odeint


class ContinuousDiffusion(nn.Module):
    """
    Continuous diffusion using SDEs (https://arxiv.org/abs/2011.13456)

    A clean sample ``x`` at diffusion time ``i`` is perturbed to
    ``x * exp(-B/2) + sqrt(1 - exp(-B)) * noise`` with ``B = beta_fn.integral(i)``.
    For time-series noise, ``noise = L @ eps`` where ``L`` is the Cholesky
    factor of ``noise_fn.covariance(**kwargs)`` and ``eps`` is white noise.

    Args:
        dim: Dimension of data
        beta_fn: Scheduler for noise levels; must be callable and provide ``integral``
        t1: Final diffusion time
        noise_fn: Type of noise; for time series it must provide ``covariance(**kwargs)``
        loss_weighting: Function ``(beta_fn, i) -> weight``; None means a constant weight of 1
        is_time_series: Whether the noise is correlated along the sequence axis
        predict_gaussian_noise: Whether the model predicts the white noise ``eps``
            (True) or the possibly correlated noise ``L @ eps`` (False)
    """
    def __init__(
        self,
        dim: int,
        beta_fn: Callable,
        t1: float = 1.0,
        noise_fn: Callable = None,
        loss_weighting: Callable = None,
        is_time_series: bool = False,
        predict_gaussian_noise: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.t1 = t1
        self.predict_gaussian_noise = predict_gaussian_noise
        self.is_time_series = is_time_series

        self.beta_fn = beta_fn
        self.noise = noise_fn
        self.loss_weighting = partial(loss_weighting or (lambda beta, i: 1), beta_fn)

    def forward(
        self,
        x: TensorType[..., 'dim'],
        i: TensorType[..., 1],
        _return_all: Optional[bool] = False, # For internal use only
        **kwargs,
    ) -> Tuple[TensorType[..., 'dim'], TensorType[..., 'dim']]:
        """
        Perturb ``x`` to diffusion time ``i``.

        Returns ``(y, target)`` where ``target`` is the white noise if
        ``predict_gaussian_noise`` else the (possibly correlated) noise. With
        ``_return_all`` returns ``(y, noise, mean, std, cov-or-None)``.
        """
        noise_gaussian = torch.randn_like(x)

        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
            noise = L @ noise_gaussian
        else:
            noise = noise_gaussian

        beta_int = self.beta_fn.integral(i)

        mean = x * torch.exp(-beta_int / 2)
        std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt()

        y = mean + std * noise

        if _return_all:
            return y, noise, mean, std, cov if self.is_time_series else None

        if self.predict_gaussian_noise:
            return y, noise_gaussian
        else:
            return y, noise

    def get_loss(
        self,
        model: Callable,
        x: TensorType[..., 'dim'],
        **kwargs,
    ) -> TensorType[..., 'dim']:
        """
        Denoising score-matching loss at a random diffusion time per sample.

        Returns the unreduced, weighted squared error between the model's
        prediction and the target noise (same shape as ``x``).
        """
        # One diffusion time per leading-batch element, broadcast over the rest.
        i = torch.rand(x.shape[0], *(1,) * len(x.shape[1:])).expand_as(x[...,:1]).to(x)
        i = i * self.t1

        x_noisy, noise = self.forward(x, i, **kwargs)

        pred_noise = model(x_noisy, i=i, **kwargs)
        loss = self.loss_weighting(i) * (pred_noise - noise)**2

        return loss

    def _get_score(self, model, x, i, L=None, **kwargs):
        """
        Returns score: ∇_xs log p(xs)
        """
        if isinstance(i, float):
            i = torch.Tensor([i]).to(x)
        if i.shape[:-1] != x.shape[:-1]:
            i = i.view(*(1,) * len(x.shape)).expand_as(x[...,:1])

        beta_int = self.beta_fn.integral(i)
        std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt()

        noise = model(x, i=i, **kwargs)

        if L is not None:
            # We have to compute the score using -Sigma.inv() @ noise / std
            # assuming noise~N(0, Sigma).
            # If `predict_gaussian_noise=False`, compute (LL^T).inv()
            # Else, we can simplify (LL^T).inv() @ L @ noise
            # to (L^T).inv() @ noise, where noise~N(0, I).
            # So we anyways have to do (L^T).inv(), and sometimes L.inv()
            if not self.predict_gaussian_noise:
                noise = torch.linalg.solve_triangular(L, noise, upper=False)
            noise = torch.linalg.solve_triangular(L.transpose(-1, -2), noise, upper=True)

        score = -noise / std
        return score

    @torch.no_grad()
    def log_prob(
        self,
        model: Callable,
        x: Union[TensorType[..., 'dim'], TensorType[..., 'seq_len', 'dim']],
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Log-likelihood of ``x`` via the probability-flow ODE, integrated from
        0 to ``t1`` with fixed-step RK4, using ``num_samples`` random probes
        for the divergence estimate.

        Side effects: switches ``model`` to train mode (left that way) and
        stores the probes in ``self._e``. For time series the result is
        divided by the sequence length.
        """
        model.train() # Allows backprop through RNN
        self._e = torch.randn(num_samples, *x.shape).to(x)

        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        def drift(i, state):
            y, _ = state
            with torch.set_grad_enabled(True):
                y = y.requires_grad_(True)
                score = self._get_score(model, y, i=i, L=L, **kwargs)
                if self.is_time_series:
                    # Have to include `cov` since g(t) = "scalar" * L @ dW
                    score = cov @ score
                dy = -0.5 * self.beta_fn(i) * (y + score)
                divergence = divergence_approx(dy, y, self._e, num_samples=num_samples)
            return dy, -divergence

        interval = torch.Tensor([0, self.t1]).to(x)

        states = odeint(drift, (x, torch.zeros_like(x).to(x)), interval,
            method='rk4', options={'step_size': .01})
        y, div = states[0][-1], states[1][-1]

        if self.is_time_series:
            p0 = td.Independent(torch.distributions.MultivariateNormal(
                torch.zeros_like(y).transpose(-1, -2),
                cov.unsqueeze(-3).repeat_interleave(self.dim, dim=-3),
            ), 1)
            log_prob = p0.log_prob(y.transpose(-1, -2)) - div.sum([-1, -2])
            log_prob = log_prob / x.shape[-2]
        else:
            p0 = td.Independent(td.Normal(torch.zeros_like(y), torch.ones_like(y)), 1)
            log_prob = p0.log_prob(y) - div.sum(-1)

        return log_prob.unsqueeze(-1)

    @torch.no_grad()
    def sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        use_ode: bool = True,
        **kwargs,
    ) -> TensorType['num_samples', 'dim']:
        """Draw samples with the probability-flow ODE (default) or the reverse SDE."""
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        sampler = self.ode_sample if use_ode else self.sde_sample
        return sampler(model, num_samples, device, **kwargs)

    @torch.no_grad()
    def ode_sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        **kwargs,
    ) -> TensorType['num_samples', 'dim']:
        """Integrate the probability-flow ODE from ``t1`` to 0 (fixed-step RK4) starting from prior noise."""
        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        def drift(i, y):
            score = self._get_score(model, y, i=i, L=L, **kwargs)
            if self.is_time_series:
                # Have to include `cov` since g(t) = "scalar" * L @ dW
                score = cov @ score
            return -0.5 * self.beta_fn(i) * (y + score)

        x = self.noise(*num_samples, **kwargs).to(device)
        t = torch.Tensor([self.t1, 0]).to(device)
        y = odeint(drift, x, t, method='rk4', options={'step_size': .01})[1]

        return y

    @torch.no_grad()
    def sde_sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        **kwargs,
    ) -> TensorType['num_samples', 'dim']:
        """
        Integrate the reverse-time SDE from ``t1`` to 0 with ``sdeint``.

        ``sdeint`` works on 2-D states, so samples are mapped to and from a
        solver layout:
          * non-time-series: ``(..., dim) -> (prod(...), dim)``, diagonal noise;
          * time series: ``(..., seq_len, dim) -> (prod(...) * dim, seq_len)``,
            one row per channel, general noise with diffusion ``sqrt(beta) * L``.
        Previously the state was viewed back with ``.view(*shape)`` without
        undoing the transpose, which mixed channels and time steps (and failed
        for 2-D non-time-series samples).
        """
        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        is_time_series = self.is_time_series

        x = self.noise(*num_samples, **kwargs).to(device)
        shape = x.shape

        def to_state(tensor):
            # Sample layout -> 2-D solver layout.
            if is_time_series:
                # (..., seq_len, dim) -> (... * dim, seq_len): rows ordered batch-major, channel-minor
                return tensor.transpose(-1, -2).reshape(-1, shape[-2])
            return tensor.reshape(-1, shape[-1])

        def from_state(state):
            # Inverse of `to_state`.
            if is_time_series:
                return state.reshape(*shape[:-2], shape[-1], shape[-2]).transpose(-1, -2)
            return state.reshape(*shape)

        class SDE(nn.Module):
            noise_type = 'general' if is_time_series else 'diagonal'
            sde_type = 'ito'

            def __init__(self, beta_fn, _get_score):
                super().__init__()
                self.beta_fn = beta_fn
                self._get_score = _get_score

            def f(self, i, inp):
                i = -i  # solver time runs from -t1 to 0
                sample = from_state(inp)

                score = self._get_score(model, sample, i=i, L=L, **kwargs)
                if is_time_series:
                    score = cov @ score

                dx = self.beta_fn(i) * (0.5 * sample + score)
                return to_state(dx)

            def g(self, i, inp):
                i = -i
                beta = -self.beta_fn(i).sqrt()

                if is_time_series:
                    # One (seq_len x seq_len) diffusion matrix per state row;
                    # rows of the same batch element share L.
                    return (beta * L).repeat_interleave(shape[-1], dim=0)
                return beta.view(1, 1).repeat(np.prod(shape[:-1]), shape[-1]).to(device)

        sde = SDE(self.beta_fn, self._get_score)
        interval = torch.Tensor([-self.t1, 0]).to(device) # Time from -t1 to 0

        step_size = self.t1 / 100
        y = sdeint(sde, to_state(x), interval, dt=step_size)[-1]

        return from_state(y).contiguous()


class ContinuousGaussianDiffusion(ContinuousDiffusion):
    """
    Continuous diffusion using isotropic Gaussian (``Normal``) noise.

    ``predict_gaussian_noise`` is accepted for signature compatibility but
    ignored: it is always set to True, since for white noise the Gaussian
    noise and the applied noise coincide.
    """
    def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise=None, **kwargs):
        white_noise = Normal(dim)
        super().__init__(
            dim,
            beta_fn,
            noise_fn=white_noise,
            predict_gaussian_noise=True,
            **kwargs,
        )


class ContinuousOUDiffusion(ContinuousDiffusion):
    """
    Continuous diffusion whose noise is an Ornstein-Uhlenbeck process
    (correlated along the sequence axis, so ``is_time_series`` is forced on).

    Args:
        theta: Forwarded to ``OrnsteinUhlenbeck``.
    """
    def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise: bool = False, theta: float = 0.5, **kwargs):
        ou_noise = OrnsteinUhlenbeck(dim, theta=theta)
        options = dict(
            dim=dim,
            beta_fn=beta_fn,
            noise_fn=ou_noise,
            predict_gaussian_noise=predict_gaussian_noise,
            is_time_series=True,
        )
        super().__init__(**options, **kwargs)


class ContinuousGPDiffusion(ContinuousDiffusion):
    """
    Continuous diffusion whose noise is a Gaussian process with a
    squared-exponential kernel (``is_time_series`` is forced on).

    Args:
        sigma: Kernel length scale forwarded to ``GaussianProcess``.
    """
    def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise: bool = False, sigma: float = 0.1, **kwargs):
        gp_noise = GaussianProcess(dim, sigma=sigma)
        options = dict(
            dim=dim,
            beta_fn=beta_fn,
            noise_fn=gp_noise,
            predict_gaussian_noise=predict_gaussian_noise,
            is_time_series=True,
        )
        super().__init__(**options, **kwargs)


def divergence_approx(output, input, e, num_samples=1):
    """
    Stochastic (Hutchinson-style) divergence terms of ``output`` w.r.t. ``input``.

    For each probe ``e[k]`` computes ``(e[k]^T J) * e[k]`` elementwise, where
    ``J = d output / d input``, and averages over the ``num_samples`` probes.
    Summing the result over the data dimensions gives the trace estimate.

    Args:
        output: Tensor produced from ``input`` with autograd tracking.
        input: Tensor that requires grad (name kept for backward compatibility).
        e: Probe vectors; ``e[k]`` must match ``output``'s shape.
        num_samples: Number of probes to use from ``e``.

    Returns:
        Detached tensor shaped like ``input``.
    """
    total = 0
    for k in range(num_samples):
        # retain_graph lets later probes reuse the graph; create_graph is not
        # needed because the vector-Jacobian product is detached immediately.
        vjp = torch.autograd.grad(output, input, e[k], retain_graph=True)[0]
        total += vjp.detach() * e[k]
    return total / num_samples


## Discrete Diffusion

In [8]:
from typing import Any, Callable, Tuple, Union


import torch
import torch.nn as nn

class DiscreteDiffusion(nn.Module):
    """
    Discrete diffusion (https://arxiv.org/abs/2006.11239)

    The noise schedule is ``betas = beta_fn(linspace(0, 1, num_steps))`` and
    ``alphas = cumprod(1 - betas)``. Both are registered as module buffers.

    Args:
        dim: Dimension of data
        num_steps: Number of diffusion steps
        beta_fn: Scheduler for noise levels
        noise_fn: Type of noise. Called as ``noise_fn(*shape, **kwargs)`` to draw
            noise while sampling; in time-series mode ``noise_fn.covariance(**kwargs)``
            is also used
        parallel_elbo: Whether to compute ELBO in parallel or not
        is_time_series: If True, Gaussian noise ``eps`` is correlated as ``L @ eps``,
            with ``L`` the Cholesky factor of ``noise_fn.covariance(**kwargs)``
        predict_gaussian_noise: If True, the model's training target is the
            uncorrelated ``eps``; otherwise it is the (possibly correlated) noise
        **kwargs: Accepted and ignored
    """
    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        noise_fn: Callable,
        parallel_elbo: bool = False,
        is_time_series: bool = False,
        predict_gaussian_noise: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.num_steps = num_steps
        self.parallel_elbo = parallel_elbo
        self.is_time_series = is_time_series
        self.predict_gaussian_noise = predict_gaussian_noise

        # Schedule evaluated on a uniform grid over [0, 1];
        # alphas[k] = prod_{j <= k} (1 - betas[j]).
        betas = beta_fn(torch.linspace(0, 1, num_steps))
        alphas = torch.cumprod(1 - betas, dim=0)

        # Buffers move with the module on .to(device) and are saved in state_dict.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)

        self.noise = noise_fn

    def forward(
        self,
        x: TensorType[..., 'dim'],  # noqa: F821
        i: TensorType[..., 1],
        **kwargs,
    ) -> Tuple[TensorType[..., 'dim'], TensorType[..., 'dim']]:  # noqa: F821
        """
        Diffuse clean data ``x`` to step ``i`` in closed form.

        Args:
            x: Clean data
            i: Diffusion step index per element (cast to ``long`` for lookup)
            **kwargs: Forwarded to ``self.noise.covariance`` in time-series mode

        Returns:
            ``(y, target)`` with ``y = sqrt(alpha_i) * x + sqrt(1 - alpha_i) * noise``.
            ``target`` is the uncorrelated Gaussian draw when
            ``predict_gaussian_noise`` is set, otherwise ``noise`` itself.
        """
        noise_gaussian = torch.randn_like(x)

        if self.is_time_series:
            # Correlate the Gaussian draw with the Cholesky factor of the noise
            # covariance, so that cov(L @ eps) = L L^T = covariance.
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
            noise = L @ noise_gaussian
        else:
            noise = noise_gaussian

        alpha = self.alphas[i.long()].to(x)
        y = torch.sqrt(alpha) * x + torch.sqrt(1 - alpha) * noise

        if self.predict_gaussian_noise:
            return y, noise_gaussian
        else:
            return y, noise

    def get_loss(
        self,
        model: Callable,
        x: TensorType[..., 'dim'],
        **kwargs,
    ) -> TensorType[..., 'dim']:
        """
        Denoising loss at a randomly chosen step per batch element.

        Args:
            model: Called as ``model(x_noisy, i=i, **kwargs)``; predicts the noise target
            x: Clean data; the first axis is treated as the batch axis
            **kwargs: Forwarded to ``forward`` and to ``model``

        Returns:
            Elementwise squared error between the prediction and the noise target
            (not reduced).
        """
        i = torch.randint(0, self.num_steps, size=(x.shape[0],))
        # Broadcast one step index per batch element to x's shape with a trailing 1.
        i = i.view(-1, *(1,) * len(x.shape[1:])).expand_as(x[...,:1]).to(x)

        x_noisy, noise = self.forward(x, i, **kwargs)

        pred_noise = model(x_noisy, i=i, **kwargs)
        loss = (pred_noise - noise)**2

        return loss

    @torch.no_grad()
    def sample(
        self,
        model: Callable,
        num_samples: Union[int, Tuple],
        device: str = 'cpu',
        **kwargs,
    ) -> TensorType['*num_samples', 'dim']:
        """
        Ancestral sampling from step ``num_steps - 1`` down to 0.

        Starts from ``self.noise(*num_samples, **kwargs)``. At each step it applies
        ``x <- (x - beta / sqrt(1 - alpha) * eps_hat) / sqrt(1 - beta) + sqrt(beta) * z``,
        where ``z`` is fresh noise (zero at the final step).

        Args:
            model: Called as ``model(x, i=i, **kwargs)``; returns predicted noise
            num_samples: Leading shape of the samples (an int is treated as ``(n,)``)
            device: Device the samples are created on
            **kwargs: Forwarded to ``self.noise`` and ``model``
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        x = self.noise(*num_samples, **kwargs).to(device)

        if self.is_time_series and self.predict_gaussian_noise:
            # The model predicts uncorrelated noise; map it through L to get the
            # correlated noise used by the forward process.
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        for diff_step in reversed(range(0, self.num_steps)):
            alpha = self.alphas[diff_step]
            beta = self.betas[diff_step]

            # An alternative can be:
            # alpha_prev = self.alphas[diff_step - 1]
            # sigma = beta * (1 - alpha_prev) / (1 - alpha)
            sigma = beta

            if diff_step == 0:
                z = 0
            else:
                z = self.noise(*num_samples, **kwargs).to(device)

            i = torch.Tensor([diff_step]).expand_as(x[...,:1]).to(device)
            pred_noise = model(x, i=i, **kwargs)

            if L is not None:
                pred_noise = L @ pred_noise

            x = (x - beta * pred_noise / (1 - alpha).sqrt()) / (1 - beta).sqrt() + sigma.sqrt() * z

        return x

    @torch.no_grad()
    def log_prob(
        self,
        model: Callable,
        x: TensorType[..., 'dim'],
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        ELBO lower bound on ``log p(x)``.

        Dispatches to ``_elbo_parallel`` or ``_elbo_sequential`` according to
        ``self.parallel_elbo``. Both return ``reduce_elbo(...)``, i.e. one value
        per batch element with a trailing singleton axis.
        """
        if self.is_time_series and self.predict_gaussian_noise:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        func = self._elbo_parallel if self.parallel_elbo else self._elbo_sequential
        return func(model, x, num_samples=num_samples, L=L, **kwargs)

    def _elbo_parallel(
        self,
        model: Callable,
        x: TensorType[..., 'dim'],
        L: TensorType[..., 'seq_len', 'seq_len'],
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Computes ELBO over all diffusion steps in parallel,
        then averages over `num_samples` runs.
        If diffusion `num_steps` large (and `num_samples` small)
        it will be heavy on the GPU memory.

        Args:
            model: Denoising diffusion model
            x: Clean input data
            L: Cholesky factor applied to predicted noise, or None
            num_samples: How many times to compute ELBO, final
                result is averaged over all ELBO samples
            **kwargs: Can be time, latent etc. depending on a model
        """
        elbo = 0

        # Step indices with shape [num_steps, *x.shape[:-1], 1].
        i = expand_to_x(torch.arange(self.num_steps), x).expand(-1, *x[...,:1].shape).contiguous()
        alphas = expand_to_x(self.alphas, x)
        betas = expand_to_x(self.betas, x)

        # Only the kwargs need an explicit leading step axis. `x` is broadcast
        # against `i` inside `forward`, so its expanded copy is not used.
        _, kwargs = expand_x_and_kwargs(x, kwargs, self.num_steps)

        for _ in range(num_samples):
            # Get diffused outputs
            xt, _ = self.forward(x, i, **kwargs) # [num_steps, ..., dim]

            # Output predicted noise
            epsilon = model(xt, i=i, **kwargs)

            if L is not None:
                epsilon = L @ epsilon

            # p(x_{t-1} | p_t)
            p_mu = get_p_mu(xt, betas, alphas, epsilon)
            px = td.Independent(td.Normal(p_mu[1:], betas[1:].sqrt()), 1)

            # p(x_0 | x_1)
            log_prob_x0_x1 = td.Independent(td.Normal(p_mu[0], betas[0].sqrt()), 1).log_prob(x)
            assert log_prob_x0_x1.shape == x.shape[:-1]

            # q(x_{t-1} | x_0, x_t), t > 1
            qx = get_qx(x.unsqueeze(0), xt[1:], alphas[1:], alphas[:-1], betas[1:])

            # KL[q(x_{t-1} | p_t) || p(x_{t-1} | p_t)]
            kl_q_p = td.kl_divergence(qx, px).sum(0)
            assert kl_q_p.shape == x.shape[:-1]

            # ELBO
            elbo_contribution = (log_prob_x0_x1 - kl_q_p) / num_samples
            elbo += elbo_contribution

        elbo = reduce_elbo(elbo, x)
        return elbo

    def _elbo_sequential(
        self,
        model: Callable,
        x: TensorType[..., 'dim'],
        L: TensorType[..., 'seq_len', 'seq_len'],
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Computes ELBO as a sum of diffusion steps - sequentially.

        The `num_samples` Monte-Carlo draws are run in parallel along a new
        leading axis and averaged, so the result matches `_elbo_parallel` in
        expectation.

        Args:
            model: Denoising diffusion model
            x: Clean input data
            L: Cholesky factor applied to predicted noise, or None
            num_samples: How many times to compute ELBO, final
                result is averaged over all ELBO samples
            **kwargs: Can be time, latent etc. depending on a model
        """
        elbo = 0

        # Keep the caller's `x` for the final reduction. `reduce_elbo` decides
        # from the tensor's rank whether to divide by its second-to-last axis,
        # and the expanded copy has one extra leading axis. Passing the copy
        # divided non-time-series ELBOs by the batch size.
        x_rep, kwargs = expand_x_and_kwargs(x, kwargs, num_samples)

        for i in range(self.num_steps):
            # Prepare variables
            beta = self.betas[i].to(x_rep)
            alpha = self.alphas[i].to(x_rep)
            step = torch.Tensor([i]).expand_as(x_rep[...,:1]).to(x_rep)

            # Diffuse and predict noise
            xt, _ = self.forward(x_rep, i=step, **kwargs)
            epsilon = model(xt, i=step, **kwargs)

            if L is not None:
                epsilon = L @ epsilon

            assert xt.shape == x_rep.shape == epsilon.shape

            # p(x_{t-1} | p_t)
            p_mu = get_p_mu(xt, beta, alpha, epsilon)
            px = td.Independent(td.Normal(p_mu, beta.sqrt()), 1)

            if i == 0:
                elbo = elbo + px.log_prob(x_rep).mean(0)
            else:
                prev_alpha = self.alphas[i - 1].to(x_rep)

                # q(x_{t-1} | x_0, x_t), t > 1
                qx = get_qx(x_rep, xt, alpha, prev_alpha, beta)

                # KL[q(x_{t-1} | p_t) || p(x_{t-1} | p_t)]
                kl = td.kl_divergence(qx, px).mean(0)
                elbo = elbo - kl

        elbo = reduce_elbo(elbo, x)
        return elbo


class GaussianDiffusion(DiscreteDiffusion):
    """
    Discrete diffusion whose noise is drawn from ``Normal(dim)``.

    All extra keyword arguments are forwarded unchanged to ``DiscreteDiffusion``.
    """
    def __init__(self, dim: int, num_steps: int, beta_fn: Callable, **kwargs):
        gaussian_noise = Normal(dim)
        super().__init__(
            dim,
            num_steps,
            beta_fn,
            noise_fn=gaussian_noise,
            **kwargs,
        )


class OUDiffusion(DiscreteDiffusion):
    """
    Discrete diffusion whose noise comes from an ``OrnsteinUhlenbeck`` process.

    Always runs in time-series mode (``is_time_series=True``).

    Args:
        dim: Dimension of data
        num_steps: Number of diffusion steps
        beta_fn: Scheduler for noise levels
        predict_gaussian_noise: Forwarded unchanged to ``DiscreteDiffusion``
        theta: Passed to ``OrnsteinUhlenbeck`` as ``theta``
        **kwargs: Extra arguments for ``DiscreteDiffusion``
    """
    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool,
        theta: float = 0.5,
        **kwargs,
    ):
        ou_noise = OrnsteinUhlenbeck(dim, theta=theta)
        super().__init__(
            dim=dim,
            num_steps=num_steps,
            beta_fn=beta_fn,
            noise_fn=ou_noise,
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )


class GPDiffusion(DiscreteDiffusion):
    """
    Discrete diffusion whose noise comes from a ``GaussianProcess``.

    Always runs in time-series mode (``is_time_series=True``).

    Args:
        dim: Dimension of data
        num_steps: Number of diffusion steps
        beta_fn: Scheduler for noise levels
        predict_gaussian_noise: Forwarded unchanged to ``DiscreteDiffusion``
        sigma: Passed to ``GaussianProcess`` as ``sigma``
        **kwargs: Extra arguments for ``DiscreteDiffusion``
    """
    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool,
        sigma: float = 0.1,
        **kwargs,
    ):
        gp_noise = GaussianProcess(dim, sigma=sigma)
        super().__init__(
            dim=dim,
            num_steps=num_steps,
            beta_fn=beta_fn,
            noise_fn=gp_noise,
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )


def expand_to_x(inputs, x):
    """
    View ``inputs`` as ``(len, 1, ..., 1)``, with one singleton axis per axis of ``x``.

    The result is cast to ``x``'s dtype and device, so it broadcasts against a
    tensor with a new leading axis prepended to ``x``.
    """
    singleton_axes = (1,) * len(x.shape)
    reshaped = inputs.view(-1, *singleton_axes)
    return reshaped.to(x)

def expand_x_and_kwargs(x, kwargs, N):
    """
    Repeat ``x`` and every tensor in ``kwargs`` ``N`` times along a new leading axis.

    Non-tensor values in ``kwargs`` pass through unchanged. The caller's
    ``kwargs`` dict is NOT modified; a new dict is returned.

    Args:
        x: Tensor to expand
        kwargs: Keyword arguments (e.g. latent, t); tensor values are expanded like ``x``
        N: Number of repetitions

    Returns:
        ``(x_expanded, kwargs_expanded)``, where every expanded tensor has shape
        ``(N, *original_shape)``.
    """
    def _repeat(t):
        return t.unsqueeze(0).repeat_interleave(N, dim=0)

    expanded_kwargs = {
        key: _repeat(value) if torch.is_tensor(value) else value
        for key, value in kwargs.items()
    }
    return _repeat(x), expanded_kwargs

def reduce_elbo(
    elbo: TensorType['batch', Any],
    x: TensorType[Any],
) -> TensorType['batch', 1]:
    """
    Sum ``elbo`` over every axis except the first (batch) axis.

    If ``x`` has more than two axes, the sum is divided by ``x.shape[-2]``
    (presumably the sequence length for time-series data; TODO confirm).

    Returns:
        Tensor of shape ``(batch, 1)``.
    """
    per_item = elbo.view(elbo.shape[0], -1).sum(1)
    if len(x.shape) <= 2:
        return per_item[:, None]
    return (per_item / x.shape[-2])[:, None]

def get_p_mu(xt, beta, alpha, epsilon):
    """
    Mean of the reverse step ``p(x_{t-1} | x_t)`` given predicted noise ``epsilon``.

    ``mu = 1 / sqrt(1 - beta) * (x_t - beta / sqrt(1 - alpha) * epsilon)``
    """
    scale = 1 / (1 - beta).sqrt()
    noise_coef = beta / (1 - alpha).sqrt()
    return scale * (xt - noise_coef * epsilon)

def get_qx(x, xt, alpha, prev_alpha, beta):
    """
    Diffusion posterior ``q(x_{t-1} | x_0, x_t)``.

    Returned as a Normal that is ``Independent`` over the last axis, with
        mean = sqrt(prev_alpha) * beta / (1 - alpha) * x
             + sqrt(1 - beta) * (1 - prev_alpha) / (1 - alpha) * xt
        var  = beta * (1 - prev_alpha) / (1 - alpha)
    """
    one_minus_alpha = 1 - alpha
    from_x0 = torch.sqrt(prev_alpha) * beta / one_minus_alpha * x
    from_xt = torch.sqrt(1 - beta) * (1 - prev_alpha) / one_minus_alpha * xt
    mean = from_x0 + from_xt

    variance = beta * (1 - prev_alpha) / one_minus_alpha
    std = variance.expand_as(mean).sqrt()
    return td.Independent(td.Normal(mean, std), 1)


# All Architectures

## Epsilon Theta

In [9]:
import math

import torch
from torch import nn
import torch.nn.functional as F


class DiffusionEmbeddingET(nn.Module):
    """
    Sinusoidal embedding of a diffusion step index, followed by a two-layer SiLU MLP.

    Args:
        dim: Number of frequencies; the raw table has ``2 * dim`` columns (sin, cos)
        proj_dim: Output width of the MLP
        max_steps: Number of rows (step indices) in the precomputed table
    """
    def __init__(self, dim, proj_dim, max_steps=500):
        super().__init__()
        table = self._build_embedding(dim, max_steps)
        # Non-persistent buffer: rebuilt at construction, not stored in state_dict.
        self.register_buffer("embedding", table, persistent=False)
        self.projection1 = nn.Linear(dim * 2, proj_dim)
        self.projection2 = nn.Linear(proj_dim, proj_dim)

    def forward(self, diffusion_step):
        hidden = self.embedding[diffusion_step]
        for layer in (self.projection1, self.projection2):
            hidden = F.silu(layer(hidden))
        return hidden

    def _build_embedding(self, dim, max_steps):
        # Entry (k, j) is k * 10^(4 j / dim); the sin and cos of it are concatenated.
        step_idx = torch.arange(max_steps).unsqueeze(1)   # [T, 1]
        freq_idx = torch.arange(dim).unsqueeze(0)         # [1, dim]
        angles = step_idx * 10.0 ** (freq_idx * 4.0 / dim)  # [T, dim]
        return torch.cat((angles.sin(), angles.cos()), dim=1)


class ResidualBlockET(nn.Module):
    """
    Gated residual block on 1D sequences (sigmoid/tanh gating).

    Args:
        hidden_size: Width of the diffusion-step embedding
        residual_channels: Channel count of the residual stream
        dilation: Dilation (and matching circular padding) of the kernel-3 convolution

    ``forward`` returns ``((x + residual) / sqrt(2), skip)``.
    """
    def __init__(self, hidden_size, residual_channels, dilation):
        super().__init__()
        doubled = 2 * residual_channels
        # Kernel 3 with padding == dilation keeps the sequence length unchanged.
        self.dilated_conv = nn.Conv1d(
            residual_channels,
            doubled,
            3,
            padding=dilation,
            dilation=dilation,
            padding_mode="circular",
        )
        self.diffusion_projection = nn.Linear(hidden_size, residual_channels)
        # Kernel 1 with padding=2 lengthens the conditioner by 4 positions.
        self.conditioner_projection = nn.Conv1d(
            1, doubled, 1, padding=2, padding_mode="circular"
        )
        self.output_projection = nn.Conv1d(residual_channels, doubled, 1)

        for conv in (self.conditioner_projection, self.output_projection):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, x, conditioner, diffusion_step):
        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond = self.conditioner_projection(conditioner)

        h = self.dilated_conv(x + step_bias) + cond
        gate, filt = torch.chunk(h, 2, dim=1)
        h = torch.sigmoid(gate) * torch.tanh(filt)

        h = F.leaky_relu(self.output_projection(h), 0.4)
        residual, skip = torch.chunk(h, 2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip


class CondUpsamplerET(nn.Module):
    """
    Two-layer MLP mapping a conditioning vector of length ``cond_length``
    to ``target_dim``, with LeakyReLU(0.4) after each layer.
    """
    def __init__(self, cond_length, target_dim):
        super().__init__()
        hidden = target_dim // 2
        self.linear1 = nn.Linear(cond_length, hidden)
        self.linear2 = nn.Linear(hidden, target_dim)

    def forward(self, x):
        for layer in (self.linear1, self.linear2):
            x = F.leaky_relu(layer(x), 0.4)
        return x


class EpsilonTheta(nn.Module):
    """
    Noise-prediction network made of a stack of gated residual blocks.

    Args:
        target_dim: Length of the input/output vector
        cond_length: Length of the conditioning vector
        time_emb_dim: Number of sinusoidal frequencies for the step embedding
        residual_layers: Number of residual blocks
        residual_channels: Channel width of the residual stream
        dilation_cycle_length: Block k uses dilation ``2 ** (k % dilation_cycle_length)``
        residual_hidden: Width of the step-embedding MLP
    """
    def __init__(
        self,
        target_dim,
        cond_length,
        time_emb_dim=16,
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        residual_hidden=64,
    ):
        super().__init__()
        # Kernel 1 with padding=2 lengthens the sequence by 4. The two unpadded
        # kernel-3 convolutions at the end shorten it by 4 again.
        self.input_projection = nn.Conv1d(
            1, residual_channels, 1, padding=2, padding_mode="circular"
        )
        self.diffusion_embedding = DiffusionEmbeddingET(
            time_emb_dim, proj_dim=residual_hidden
        )
        self.cond_upsampler = CondUpsamplerET(
            target_dim=target_dim, cond_length=cond_length
        )
        self.residual_layers = nn.ModuleList(
            ResidualBlockET(
                residual_channels=residual_channels,
                dilation=2 ** (layer_idx % dilation_cycle_length),
                hidden_size=residual_hidden,
            )
            for layer_idx in range(residual_layers)
        )
        self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)
        self.output_projection = nn.Conv1d(residual_channels, 1, 3)

        nn.init.kaiming_normal_(self.input_projection.weight)
        nn.init.kaiming_normal_(self.skip_projection.weight)
        # Zero output weights: before training, the prediction equals the output bias.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, inputs, time, cond):
        h = F.leaky_relu(self.input_projection(inputs), 0.4)
        step_emb = self.diffusion_embedding(time)
        cond_up = self.cond_upsampler(cond)

        skips = []
        for block in self.residual_layers:
            h, skip = block(h, cond_up, step_emb)
            skips.append(skip)

        h = torch.stack(skips).sum(dim=0) / math.sqrt(len(self.residual_layers))
        h = F.leaky_relu(self.skip_projection(h), 0.4)
        return self.output_projection(h)


## Dotdict

In [10]:
class dotdict(dict):
    """
    ``dict`` subclass that exposes keys as attributes.

    Reading a missing attribute returns ``None`` (like ``dict.get``) instead of
    raising. Setting or deleting an attribute sets or deletes the key; deleting
    a missing one raises ``KeyError``.
    """
    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)


## Architecture

In [11]:
from torchtyping import TensorType

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pts.model import weighted_average
from pts.model.time_grad import TimeGradTrainingNetwork, TimeGradPredictionNetwork
from pts.model.time_grad.epsilon_theta import DiffusionEmbedding


class TimeGradTrainingNetwork_AutoregressiveOld(TimeGradTrainingNetwork):
    """
    Unmodified pts ``TimeGradTrainingNetwork``.

    It only removes the ``time_feat_dim`` and ``noise`` keyword arguments
    (both required; ``KeyError`` if absent) before calling the parent
    constructor, which presumably does not accept them.
    """
    def __init__(self, **kwargs):
        for consumed in ('time_feat_dim', 'noise'):
            kwargs.pop(consumed)
        super().__init__(**kwargs)

class TimeGradPredictionNetwork_AutoregressiveOld(TimeGradPredictionNetwork):
    """
    Unmodified pts ``TimeGradPredictionNetwork``.

    It only removes the ``time_feat_dim`` and ``noise`` keyword arguments
    (both required; ``KeyError`` if absent) before calling the parent
    constructor, which presumably does not accept them.
    """
    def __init__(self, **kwargs):
        for consumed in ('time_feat_dim', 'noise'):
            kwargs.pop(consumed)
        super().__init__(**kwargs)


class DenoiseWrapper(nn.Module):
    """
    Adapter that flattens inputs for a per-vector denoiser and restores the shape.

    Args:
        denoise_fn: Called as ``denoise_fn(x, i, latent)`` with ``x`` of shape
            ``(N, 1, x_dim)``, ``i`` of shape ``(N,)`` (long) and ``latent`` of
            shape ``(N, 1, latent_dim)``. Its output must reshape to the input shape.
        target_dim: Width of the optional time embedding (must match ``x``'s last axis)
        time_input: If truthy, ``DiffusionEmbedding(t)`` is added to ``x`` first
    """
    def __init__(self, denoise_fn, target_dim, time_input):
        super().__init__()
        self.denoise_fn = denoise_fn
        self.time_input = time_input
        if time_input:
            self.time_embedding = DiffusionEmbedding(dim=target_dim, proj_dim=target_dim, max_steps=100)

    def forward(self, x, t=None, i=None, latent=None, **kwargs):
        original_shape = x.shape

        if self.time_input:
            x = x + self.time_embedding(t.squeeze(-1).long())

        flat_x = x.view(-1, 1, x.shape[-1])
        flat_steps = i.view(-1).long()
        flat_latent = latent.reshape(-1, 1, latent.shape[-1])

        denoised = self.denoise_fn(flat_x, flat_steps, flat_latent)
        return denoised.view(*original_shape)


################################################################################################
#### TimeGrad RNN encoder --> prediction all at once using time positional encoding
#### using the past prediction window sized RNN context
################################################################################################
class TimeGradTrainingNetwork_All(TimeGradTrainingNetwork):
    """
    TimeGrad variant that denoises the whole prediction window jointly.

    The pts encoder output (``unroll_encoder``) is projected to
    ``conditioning_length`` by ``rnn_state_proj`` and used as ``latent`` for
    the diffusion model.

    Extra keyword arguments beyond ``TimeGradTrainingNetwork``'s:
        noise: 'normal', 'ou' or 'gp'; selects GaussianDiffusion, OUDiffusion
            or GPDiffusion (anything else raises NotImplementedError)
        time_feat_dim: Consumed here and not forwarded to the parent
    """
    def __init__(self, **kwargs):
        args = dotdict(kwargs)
        noise = args.noise

        # These two are consumed by this class and not forwarded to the parent.
        kwargs.pop('time_feat_dim')
        kwargs.pop('noise')
        # nn.Module.__init__ must run before any attribute is assigned on self.
        super().__init__(**kwargs)

        self.noise = noise
        # For non-'normal' noise, DenoiseWrapper adds an embedding of `t` to its input.
        self.time_input = (self.noise != 'normal')
        self.rnn_state_proj = nn.Linear(args.num_cells, args.conditioning_length)

        if self.noise == 'normal':
            diffusion = GaussianDiffusion
        elif self.noise == 'ou':
            diffusion = OUDiffusion
        elif self.noise == 'gp':
            diffusion = GPDiffusion
        else:
            raise NotImplementedError

        # `sigma` is used only by GPDiffusion; for the other classes it lands in
        # DiscreteDiffusion's **kwargs and is ignored.
        self.diffusion = diffusion(args.target_dim, args.diff_steps, BetaLinear(1e-4, args.beta_end), sigma=0.05, predict_gaussian_noise=True)

        denoise_fn = EpsilonTheta(
            target_dim=args.target_dim,
            cond_length=args.conditioning_length,
            residual_layers=args.residual_layers,
            residual_channels=args.residual_channels,
            dilation_cycle_length=args.dilation_cycle_length,
        )

        self.denoise_fn = DenoiseWrapper(denoise_fn, args.target_dim, self.time_input)

    def get_rnn_state(self, **kwargs):
        """Run the pts encoder and project its outputs; returns ``(latent, scale)``."""
        rnn_outputs, _, scale, _, _ = self.unroll_encoder(**kwargs)
        rnn_outputs = self.rnn_state_proj(rnn_outputs)
        return rnn_outputs, scale

    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
        future_target_cdf: torch.Tensor,
        future_observed_values: torch.Tensor,
    ) -> TensorType[()]:
        """
        Diffusion training loss for the prediction window.

        The target is the last ``prediction_length`` future steps, standardised
        by the mean/std (std clamped to >= 1e-4) of the last ``prediction_length``
        past steps. Returns a scalar: the observed-weighted average over axis 1,
        then the mean.
        """
        latent, scale = self.get_rnn_state(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=None,
            target_dimension_indicator=target_dimension_indicator,
        )

        mean = past_target_cdf[...,-self.prediction_length:,:].mean(1, keepdim=True)
        std = past_target_cdf[...,-self.prediction_length:,:].std(1, keepdim=True).clamp(1e-4)

        # target = future_target_cdf[...,-self.prediction_length:,:] / scale
        target = (future_target_cdf[...,-self.prediction_length:,:] - mean) / std
        # target = (future_target_cdf[...,-self.prediction_length:,:] - past_target_cdf[...,-1:,:] - mean) / std
        # target = (future_target_cdf[...,-self.prediction_length:,:] - past_target_cdf[...,-1:,:]) / scale

        # Position index 0..prediction_length-1 of each future step.
        t = torch.arange(self.prediction_length).view(1, -1, 1).repeat(target.shape[0], 1, 1).to(target)
        loss = self.diffusion.get_loss(self.denoise_fn, target, t=t, latent=latent, future_time_feat=future_time_feat)

        # A step counts only if every target dimension is observed.
        loss_weights, _ = future_observed_values.min(dim=-1, keepdim=True)
        loss = weighted_average(loss, weights=loss_weights, dim=1)

        return loss.mean()


class TimeGradPredictionNetwork_All(TimeGradTrainingNetwork_All):
    """
    Sampling counterpart of ``TimeGradTrainingNetwork_All``.

    Draws ``num_parallel_samples`` joint samples of the prediction window per
    series. Samples are mapped back to data scale with the same mean/std
    normalisation that the training network uses.
    """
    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_samples = num_parallel_samples

    # Inference only. `diffusion.sample` already runs without autograd, so
    # building a graph through the encoder only costs memory. This matches
    # TimeGradPredictionNetwork_Autoregressive.forward.
    @torch.no_grad()
    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Returns:
            Samples with a leading ``(batch, num_parallel_samples)`` pair of axes,
            on the original data scale.
        """
        mean = past_target_cdf[...,-self.prediction_length:,:].mean(1, keepdim=True)
        std = past_target_cdf[...,-self.prediction_length:,:].std(1, keepdim=True).clamp(1e-4)

        latent, scale = self.get_rnn_state(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=None,
            target_dimension_indicator=target_dimension_indicator,
        )

        # Replicate conditioning per sample; rows stay grouped by batch element.
        num_samples = (self.num_samples * latent.shape[0], *latent.shape[1:-1])
        latent = latent.repeat_interleave(self.num_samples, dim=0)
        future_time_feat = future_time_feat.repeat_interleave(self.num_samples, dim=0)

        t = torch.arange(self.prediction_length).view(*(1,) * len(latent.shape[:-3]), -1, 1)
        t = t.expand_as(latent[...,:1]).to(latent)

        samples = self.diffusion.sample(
            self.denoise_fn,
            num_samples=num_samples,
            latent=latent,
            t=t,
            future_time_feat=future_time_feat,
            device=latent.device,
        )

        samples = samples.unflatten(0, (-1, self.num_samples))
        samples = samples * std.unsqueeze(1) + mean.unsqueeze(1)
        # samples = samples * scale.unsqueeze(1)
        # samples = samples + past_target_cdf[...,-1:,:].unsqueeze(1)
        return samples


################################################################################################
#### TimeGrad Autoregressive -> predicts one by one
################################################################################################
class TimeGradTrainingNetwork_Autoregressive(TimeGradTrainingNetwork_All):
    """
    TimeGrad variant trained to denoise every step of context + future, each
    conditioned on its own encoder output.

    NOTE(review): ``get_loss`` is not given ``t``, but ``DenoiseWrapper`` needs
    ``t`` when ``time_input`` is set (non-'normal' noise). This presumably
    works only with ``noise='normal'``; TODO confirm.
    """
    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
        future_target_cdf: torch.Tensor,
        future_observed_values: torch.Tensor,
    ) -> TensorType[()]:
        latent, scale = self.get_rnn_state(
            target_dimension_indicator=target_dimension_indicator,
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=future_target_cdf,
        )

        context = past_target_cdf[..., -self.context_length:, :]
        scaled_target = torch.cat([context, future_target_cdf], 1) / scale

        per_element_loss = self.diffusion.get_loss(
            self.denoise_fn,
            scaled_target,
            latent=latent,
            future_time_feat=future_time_feat,
        )

        # Padded past steps count as unobserved.
        past_mask = torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1))
        observed = torch.cat((past_mask[:, -self.context_length:, ...], future_observed_values), dim=1)
        # A step counts only if every target dimension is observed.
        loss_weights = observed.min(dim=-1, keepdim=True).values

        return weighted_average(per_element_loss, weights=loss_weights, dim=1).mean()

class TimeGradPredictionNetwork_Autoregressive(TimeGradTrainingNetwork_Autoregressive):
    """
    Autoregressive sampler: generates the prediction window one step at a time.

    Each sampled step is appended to the target history, and lag features for
    later steps are read from that history.
    """
    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_samples = num_parallel_samples
        # Lags shifted by one position. Presumably this matches how pts
        # get_lagged_subsequences indexes the growing history; TODO confirm.
        self.shifted_lags = [l - 1 for l in self.lags_seq]

    @torch.no_grad()
    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample ``prediction_length`` future steps for every series.

        NOTE(review): ``diffusion.sample`` is not given ``t``. ``DenoiseWrapper``
        needs it when ``time_input`` is set (non-'normal' noise), so this
        presumably supports only ``noise='normal'``; TODO confirm.

        Returns:
            Samples concatenated along axis 1 (one entry per predicted step), then
            unflattened on axis 0 into ``(batch, num_samples, ...)``.
        """

        # Padded past steps count as unobserved.
        past_observed_values = torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1))

        # Replicate every input num_samples times along axis 0 so all sample
        # paths run as one batch; rows stay grouped by original batch element.
        past_time_feat = past_time_feat.repeat_interleave(self.num_samples, dim=0)
        past_target_cdf = past_target_cdf.repeat_interleave(self.num_samples, dim=0)
        past_observed_values = past_observed_values.repeat_interleave(self.num_samples, dim=0)
        past_is_pad = past_is_pad.repeat_interleave(self.num_samples, dim=0)
        future_time_feat = future_time_feat.repeat_interleave(self.num_samples, dim=0)
        target_dimension_indicator = target_dimension_indicator.repeat_interleave(self.num_samples, dim=0)

        # Encode the observed history once; the returned state seeds the step loop.
        _, begin_states, scale, _, _ = self.unroll_encoder(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=None,
            target_dimension_indicator=target_dimension_indicator,
        )

        samples = []
        for i in range(self.prediction_length):
            # Lag features for step i, taken from the history extended by the
            # samples drawn so far.
            lags = self.get_lagged_subsequences(
                sequence=past_target_cdf,
                sequence_length=self.history_length + i,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            # Advance the RNN by one step, threading its state through the loop.
            latent, begin_states, _, _ = self.unroll(
                begin_state=begin_states,
                lags=lags,
                scale=scale,
                time_feat=future_time_feat[:, i : i + 1],
                target_dimension_indicator=target_dimension_indicator,
                unroll_length=1,
            )
            latent = self.rnn_state_proj(latent)

            # One diffusion sample per path, mapped back to data scale.
            sample = self.diffusion.sample(
                self.denoise_fn,
                num_samples=latent.shape[:-1],
                latent=latent,
                device=latent.device,
            )
            sample = sample * scale

            samples.append(sample)
            # Feed the sample back so later steps can use it as a lag.
            past_target_cdf = torch.cat([past_target_cdf, sample], dim=1)

        samples = torch.cat(samples, dim=1)
        samples = samples.unflatten(0, (-1, self.num_samples))

        return samples


################################################################################################
#### TimeGrad RNN encoder --> predicting all at once with RNN+TimeGrad decoder
#### RNN initial state is the last state from the encoder
################################################################################################
class TimeGradTrainingNetwork_RNN(TimeGradTrainingNetwork_All):
    """
    ``TimeGradTrainingNetwork_All`` with a GRU decoder over the prediction window.

    The last projected encoder state initialises every layer of a 2-layer GRU.
    The GRU is driven by projected ``future_time_feat``, and its outputs
    become the diffusion ``latent``.
    """
    def __init__(self, **kwargs):
        args = dotdict(kwargs)
        super().__init__(**kwargs)

        width = args.conditioning_length
        self.num_rnn_layers = 2
        self.proj_inputs = nn.Sequential(
            nn.Linear(args.time_feat_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )
        self.prediction_rnn = nn.GRU(
            width,
            width,
            num_layers=self.num_rnn_layers,
            bidirectional=False,
            batch_first=True,
        )

    def get_rnn_state(self, **kwargs):
        encoder_states, _, scale, _, _ = self.unroll_encoder(**kwargs)
        last_state = self.rnn_state_proj(encoder_states)[..., -1, :]
        # Same initial hidden state for every GRU layer (layer axis first).
        h0 = last_state.unsqueeze(0).repeat_interleave(self.num_rnn_layers, dim=0)

        decoder_inputs = self.proj_inputs(kwargs['future_time_feat'])
        outputs, _ = self.prediction_rnn(decoder_inputs, h0)
        return outputs, scale

class TimeGradPredictionNetwork_RNN(TimeGradTrainingNetwork_RNN):
    """Sampling counterpart of ``TimeGradTrainingNetwork_RNN``; reuses ``TimeGradPredictionNetwork_All.forward``."""

    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_samples = num_parallel_samples

    forward = TimeGradPredictionNetwork_All.forward


################################################################################################
#### TimeGrad RNN encoder --> predicting all at once with Transformer decoder+TimeGrad net
################################################################################################
class TimeGradTrainingNetwork_Transformer(TimeGradTrainingNetwork_All):
    """
    ``TimeGradTrainingNetwork_All`` with a 2-layer Transformer decoder over the
    prediction window.

    Queries are position embeddings concatenated with ``future_time_feat``,
    then projected. The projected encoder outputs serve as the attention memory.
    """
    def __init__(self, **kwargs):
        args = dotdict(kwargs)
        super().__init__(**kwargs)

        width = args.conditioning_length
        self.pos_enc = DiffusionEmbedding(dim=width, proj_dim=width, max_steps=100)
        self.proj_time_feat = nn.Linear(args.time_feat_dim + width, width)

        layer = nn.TransformerDecoderLayer(width, nhead=1, dim_feedforward=width, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=2)

    def get_rnn_state(self, **kwargs):
        memory, _, scale, _, _ = self.unroll_encoder(**kwargs)
        memory = self.rnn_state_proj(memory)

        batch_size = memory.shape[0]
        positions = torch.arange(self.prediction_length).view(1, -1).repeat(batch_size, 1).to(memory)
        pos_emb = self.pos_enc(positions.long())

        queries = self.proj_time_feat(torch.cat([pos_emb, kwargs['future_time_feat']], -1))
        decoded = self.transformer_decoder(tgt=queries, memory=memory)
        return decoded, scale

class TimeGradPredictionNetwork_Transformer(TimeGradTrainingNetwork_Transformer):
    """Sampling counterpart of ``TimeGradTrainingNetwork_Transformer``; reuses ``TimeGradPredictionNetwork_All.forward``."""

    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_samples = num_parallel_samples

    forward = TimeGradPredictionNetwork_All.forward


################################################################################################
#### TimeGrad RNN encoder --> predicting all at once with 2D conv similar to TimeGrad version
################################################################################################
class ResidualBlockTG(nn.Module):
    """
    Gated 2D residual block used by the TimeGrad-style 2D-conv denoiser.

    The diffusion-step embedding ``i`` and time embedding ``t`` are projected
    to ``residual_channels`` and broadcast over the last axis of ``x``. The
    latent gets a 1x1 convolution and is added before the gate.
    ``forward`` returns ``((x + residual) / sqrt(2), skip)``.

    Args:
        dim: Unused here; kept for interface compatibility
        hidden_size: Width of the step/time embeddings
        residual_channels: Channel count of the residual stream
        dilation: Dilation of the 3x3 convolution
        padding_mode: Padding mode for all convolutions (all use padding='same')
    """
    def __init__(self, dim, hidden_size, residual_channels, dilation, padding_mode):
        super().__init__()
        doubled = 2 * residual_channels
        self.step_projection = nn.Linear(hidden_size, residual_channels)
        self.time_projection = nn.Linear(hidden_size, residual_channels)

        self.x_step_proj = nn.Sequential(
            nn.Conv2d(residual_channels, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode),
            nn.LeakyReLU(0.4),
        )
        self.x_time_proj = nn.Sequential(
            nn.Conv2d(residual_channels, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode),
            nn.LeakyReLU(0.4),
        )

        self.latent_projection = nn.Conv2d(
            1, doubled, kernel_size=1, padding='same', padding_mode=padding_mode,
        )
        self.dilated_conv = nn.Conv2d(
            residual_channels,
            doubled,
            kernel_size=3,
            dilation=dilation,
            padding='same',
            padding_mode=padding_mode,
        )
        self.output_projection = nn.Conv2d(
            residual_channels, doubled, kernel_size=1, padding='same', padding_mode=padding_mode,
        )

    def forward(self, x, t=None, i=None, latent=None):
        # (..., T, hidden) -> (..., channels, T, 1), broadcast over x's last axis.
        step_bias = self.step_projection(i).transpose(-1, -2).unsqueeze(-1)
        latent_bias = self.latent_projection(latent.unsqueeze(1))

        h = x + step_bias
        h = h + self.x_step_proj(h)

        time_bias = self.time_projection(t).transpose(-1, -2).unsqueeze(-1)
        h = h + self.x_time_proj(h + time_bias)

        h = self.dilated_conv(h) + latent_bias
        gate, filt = h.chunk(2, dim=1)
        h = torch.sigmoid(gate) * torch.tanh(filt)

        h = F.leaky_relu(self.output_projection(h), 0.4)
        residual, skip = h.chunk(2, dim=1)
        return (x + residual) / math.sqrt(2), skip

class DenoisingModelTG(nn.Module):
    """Convolutional denoiser built from a stack of ``ResidualBlockTG`` layers.

    Args:
        dim: Output width of the latent MLP (added to the conv feature maps).
        residual_channels: Channels of the residual stream.
        latent_dim: Size of the last axis of ``latent``.
        residual_hidden: Width of the step/time embeddings.
        residual_layers: Number of residual blocks.
        time_input: Stored as ``self.time_input``; not read by this module.
        padding_mode: Padding mode for all convolutions.
        time_feat_dim: Size of the last axis of ``future_time_feat``. The default (4)
            reproduces the previously hard-coded ``time_proj`` input width of 5
            (``time_feat_dim`` covariates + 1 normalised time index).
        dilation_cycle_length: Residual layer ``k`` uses dilation
            ``2 ** (k % dilation_cycle_length)``. The default (2) is the previously
            hard-coded cycle.
    """

    def __init__(
        self,
        dim,
        residual_channels,
        latent_dim,
        residual_hidden,
        residual_layers,
        time_input,
        padding_mode='circular',
        time_feat_dim: int = 4,
        dilation_cycle_length: int = 2,
    ):
        super().__init__()
        if dilation_cycle_length < 1:
            raise ValueError(f"dilation_cycle_length must be >= 1, got {dilation_cycle_length}")
        self.time_input = time_input

        self.input_projection = nn.Conv2d(1, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode)
        self.step_embedding = DiffusionEmbedding(residual_hidden, proj_dim=residual_hidden)
        # Registered but not used in forward(); kept so parameters / state_dict keys
        # and the parameter-initialisation order are unchanged.
        self.time_embedding = DiffusionEmbedding(residual_hidden, proj_dim=residual_hidden, max_steps=24)
        self.latent_projection = nn.Sequential(
            nn.Linear(latent_dim, dim // 2),
            nn.LeakyReLU(0.4),
            nn.Linear(dim // 2, dim),
            nn.LeakyReLU(0.4),
        )

        self.residual_layers = nn.ModuleList([
            ResidualBlockTG(
                dim,
                residual_hidden,
                residual_channels,
                dilation=2 ** (k % dilation_cycle_length),
                padding_mode=padding_mode,
            )
            for k in range(residual_layers)
        ])

        self.skip_projection = nn.Conv2d(
            residual_channels, residual_channels, kernel_size=3, padding='same', padding_mode=padding_mode,
        )
        self.output_projection = nn.Conv2d(
            residual_channels, 1, kernel_size=3, padding='same', padding_mode=padding_mode,
        )

        # Input = time covariates concatenated with one normalised time-index column.
        self.time_proj = nn.Sequential(
            nn.Linear(time_feat_dim + 1, residual_hidden),
            nn.LeakyReLU(0.4),
            nn.Linear(residual_hidden, residual_hidden),
            nn.LeakyReLU(0.4),
        )

    def forward(self, x, t=None, i=None, latent=None, future_time_feat=None):
        """Run the denoiser.

        Args:
            x: Noisy input, presumably 3-D ``(batch, H, W)``; a channel axis is inserted
                at position 1 for the convolutions and the original shape is restored
                on the output.
            t: Time-index tensor, divided by its maximum and concatenated with
                ``future_time_feat`` along the last axis. Required despite the default.
            i: Diffusion-step tensor; its last axis is squeezed (if size 1) and cast to
                ``long`` before embedding. Required.
            latent: Conditioning tensor whose last axis has size ``latent_dim``. Required.
            future_time_feat: Time covariates whose last axis has size ``time_feat_dim``.

        Returns:
            Tensor viewed back to ``x``'s original shape.
        """
        shape = x.shape

        x = x.unsqueeze(1)
        x = self.input_projection(x)
        x = F.leaky_relu(x, 0.4)

        i = self.step_embedding(i.squeeze(-1).long())

        # Normalise the time index by its maximum. A zero maximum (e.g. an all-zero
        # index when prediction_length == 1) previously produced 0/0 = NaN that
        # propagated through every layer; divide by 1 in that case instead.
        t_max = t.max()
        t_scale = torch.where(t_max == 0, torch.ones_like(t_max), t_max)
        t = self.time_proj(torch.cat([future_time_feat, t / t_scale], -1))

        latent = self.latent_projection(latent)

        skip_agg = 0
        for layer in self.residual_layers:
            x, skip = layer(x, t=t, i=i, latent=latent)
            skip_agg = skip_agg + skip

        # Rescale the summed skips so their magnitude does not grow with depth.
        x = skip_agg / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.leaky_relu(x, 0.4)
        x = self.output_projection(x).squeeze(1)

        x = x.view(*shape)
        return x

class TimeGradTrainingNetwork_CNN(TimeGradTrainingNetwork_All):
    """``TimeGradTrainingNetwork_All`` whose denoiser is replaced by ``DenoisingModelTG``.

    Every keyword argument is forwarded unchanged to the parent constructor; the
    denoiser is then rebuilt from a subset of those same arguments.
    """

    def __init__(self, **kwargs):
        cfg = dotdict(kwargs)
        super().__init__(**kwargs)
        # conditioning_length sizes both the latent input and the embedding width.
        hidden = cfg.conditioning_length
        self.denoise_fn = DenoisingModelTG(
            dim=cfg.target_dim,
            residual_channels=cfg.residual_channels,
            latent_dim=hidden,
            residual_hidden=hidden,
            time_input=self.time_input,
            residual_layers=cfg.residual_layers,
        )

class TimeGradPredictionNetwork_CNN(TimeGradTrainingNetwork_CNN):
    """Prediction-time counterpart of ``TimeGradTrainingNetwork_CNN``.

    Borrows ``TimeGradPredictionNetwork_All.forward`` for sampling and records the
    number of parallel samples and the lag sequence shifted down by one step.
    """

    forward = TimeGradPredictionNetwork_All.forward

    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_samples = num_parallel_samples
        # Each lag moved one step earlier.
        self.shifted_lags = [lag - 1 for lag in self.lags_seq]


In [12]:
class NotSupportedModelNoiseCombination(Exception):
    """Raised when a network architecture cannot be paired with the requested noise type."""


def get_network(network, noise):
    """Return the ``(training_net, prediction_net)`` classes for a network name.

    Args:
        network: One of 'timegrad', 'timegrad_old', 'timegrad_all', 'timegrad_rnn',
            'timegrad_transformer', 'timegrad_cnn'.
        noise: Noise type; 'timegrad' and 'timegrad_old' only accept 'normal'.

    Returns:
        A ``(training_net, prediction_net)`` tuple of network classes.

    Raises:
        NotSupportedModelNoiseCombination: 'timegrad' / 'timegrad_old' requested with
            ``noise != 'normal'``.
        ValueError: ``network`` is not one of the supported names. (Previously this
            fell through to ``return`` with unbound locals: ``UnboundLocalError``.)
    """
    if network in ('timegrad', 'timegrad_old') and noise != 'normal':
        raise NotSupportedModelNoiseCombination(
            f"network {network!r} only supports noise='normal', got noise={noise!r}"
        )

    if network == 'timegrad':
        return TimeGradTrainingNetwork_Autoregressive, TimeGradPredictionNetwork_Autoregressive
    if network == 'timegrad_old':
        return TimeGradTrainingNetwork_AutoregressiveOld, TimeGradPredictionNetwork_AutoregressiveOld
    if network == 'timegrad_all':
        return TimeGradTrainingNetwork_All, TimeGradPredictionNetwork_All
    if network == 'timegrad_rnn':
        return TimeGradTrainingNetwork_RNN, TimeGradPredictionNetwork_RNN
    if network == 'timegrad_transformer':
        return TimeGradTrainingNetwork_Transformer, TimeGradPredictionNetwork_Transformer
    if network == 'timegrad_cnn':
        return TimeGradTrainingNetwork_CNN, TimeGradPredictionNetwork_CNN
    raise ValueError(f"Unknown network {network!r}")

# Score Training

In [13]:
from torchtyping import TensorType

import torch
import torch.nn as nn
from pts.modules import MeanScaler

class ScoreTrainingNetwork(nn.Module):
    """
    Score training network.

    Args:
        context_length: Size of history
        prediction_length: Size of prediction
        target_dim: Dimension of data
        time_feat_dim: Dimension of covariates
        conditioning_length: Hidden dimension
        beta_end: Final diffusion scale
        diff_steps: Number of diffusion steps
        residual_layers: Number of residual layers
        residual_channels: Number of residual channels
        dilation_cycle_length: Dilation cycle length (accepted but not used here)
    """
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        target_dim: int,
        time_feat_dim: int,
        conditioning_length: int,
        beta_end: float,
        diff_steps: int,
        residual_layers: int,
        residual_channels: int,
        dilation_cycle_length: int,
        **kwargs,
    ):
        super().__init__()
        self.context_length = context_length
        self.prediction_length = prediction_length

        # hidden_dim = conditioning_length
        # self.context_rnn = nn.GRU(target_dim + time_feat_dim, hidden_dim, num_layers=2, bidirectional=True, batch_first=True)

        self.diffusion = OUDiffusion(target_dim, BetaLinear(1e-4, beta_end), diff_steps)
        # DenoisingModelTG has no defaults for residual_layers / time_input; they were
        # previously omitted, so constructing this network raised TypeError.
        # DenoisingModelTG only stores time_input, so its value does not affect compute.
        # NOTE(review): dim is target_dim + time_feat_dim while OUDiffusion is built with
        # target_dim — confirm against what OUDiffusion.get_loss passes to denoise_fn.
        self.denoise_fn = DenoisingModelTG(
            dim=target_dim + time_feat_dim,
            residual_channels=residual_channels,
            latent_dim=conditioning_length,
            residual_hidden=conditioning_length,
            residual_layers=residual_layers,
            time_input=True,
        )

        self.scaler = MeanScaler(keepdim=True)

    def forward(
        self,
        target_dimension_indicator: TensorType['batch', 'dim'],
        past_time_feat:             TensorType['batch', 'history_length', 'feat_dim'],
        past_target_cdf:            TensorType['batch', 'history_length', 'dim'],
        past_observed_values:       TensorType['batch', 'history_length', 'dim'],
        past_is_pad:                TensorType['batch', 'history_length'],
        future_time_feat:           TensorType['batch', 'prediction_length', 'feat_dim'],
        future_target_cdf:          TensorType['batch', 'prediction_length', 'dim'],
        future_observed_values:     TensorType['batch', 'prediction_length', 'dim'],
    ) -> TensorType[()]:
        """Return the mean diffusion loss over the batch.

        The history is cut to the last ``context_length`` steps and used (with padded
        steps marked unobserved) to compute a scale. Both history and target are divided
        by that scale, and the per-step loss is weighted by whether every dimension of
        the future step is observed.
        """
        # Keep only the last ``context_length`` history steps. (past_time_feat is not
        # used by this network, so it is not sliced.)
        past_target_cdf = past_target_cdf[..., -self.context_length:, :]
        past_observed_values = past_observed_values[..., -self.context_length:, :]
        past_is_pad = past_is_pad[..., -self.context_length:]

        # Padded steps count as unobserved when computing the scale.
        past_observed_values = torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1))
        _, scale = self.scaler(past_target_cdf, past_observed_values)

        history = past_target_cdf / scale
        target = future_target_cdf / scale

        # Time index 0..prediction_length-1, one copy per batch element.
        t = torch.arange(self.prediction_length).view(1, -1, 1).repeat(target.shape[0], 1, 1).to(target)

        loss = self.diffusion.get_loss(self.denoise_fn, target, t=t, history=history, covariates=future_time_feat)

        # A step contributes only if all of its dimensions are observed.
        loss_weights, _ = future_observed_values.min(dim=-1, keepdim=True)
        loss = weighted_average(loss, weights=loss_weights, dim=1)

        return loss.mean()


class ScorePredictionNetwork(ScoreTrainingNetwork):
    """Sampling counterpart of ``ScoreTrainingNetwork``.

    Draws ``num_parallel_samples`` forecast samples per input row and multiplies them
    by the scale returned from ``get_rnn_state``.

    NOTE(review): ``forward`` calls ``self.get_rnn_state``, but no such method is
    defined on ``ScoreTrainingNetwork`` in this file. Unless something else provides
    it, ``forward`` will raise ``AttributeError`` — confirm.
    """
    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        # Number of samples drawn per input row in forward().
        self.num_parallel_samples = num_parallel_samples

    def forward(
        self,
        target_dimension_indicator: TensorType['batch', 'dim'],
        past_time_feat:             TensorType['batch', 'history_length', 'feat_dim'],
        past_target_cdf:            TensorType['batch', 'history_length', 'dim'],
        past_observed_values:       TensorType['batch', 'history_length', 'dim'],
        past_is_pad:                TensorType['batch', 'history_length'],
        future_time_feat:           TensorType['batch', 'prediction_length', 'feat_dim'],
    ) -> TensorType['batch', 'num_samples', 'prediction_length', 'dim']:
        """Sample forecasts.

        Returns:
            Samples unflattened on axis 0 to ``(-1, num_parallel_samples, ...)`` and
            multiplied by ``scale.unsqueeze(1)``.

        NOTE(review): unlike ``ScoreTrainingNetwork.forward``, the history is not
        cut to ``context_length`` here — confirm this is intended.
        """

        # Padded history steps are treated as unobserved.
        past_observed_values = torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1))

        rnn_states, scale = self.get_rnn_state(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            future_time_feat=future_time_feat,
        )

        # Time index 0..prediction_length-1, repeated rnn_states.shape[0] * num_parallel_samples
        # times along axis 0 so every sample gets its own copy.
        t = torch.arange(self.prediction_length).view(1, -1, 1)
        t = t.repeat(rnn_states.shape[0] * self.num_parallel_samples, 1, 1).to(rnn_states)

        # Each state is repeated num_parallel_samples times consecutively, matching the unflatten below.
        rnn_states = rnn_states.repeat_interleave(self.num_parallel_samples, dim=0)

        samples = self.diffusion.sample(self.denoise_fn, t=t, latent=rnn_states)
        samples = samples.unflatten(0, (-1, self.num_parallel_samples)) * scale.unsqueeze(1)

        return samples


# Main Call

In [14]:
import argparse
import numpy as np
import torch
from copy import deepcopy

from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator


import warnings

# Suppress every FutureWarning for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Use the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def energy_score(forecast, target):
    """Sample-based energy score of a forecast against an observation.

    ``ES = mean ||X - y|| - 0.5 * mean ||X - X'||``. Norms are taken over the last
    axis, X and X' range over the samples stacked along axis 0 of ``forecast``, and
    the pairwise mean includes the zero self-distances.

    Args:
        forecast: Array-like of samples; axis 0 indexes samples and the last axis is
            the one the norm is taken over (e.g. ``(num_samples, T, dim)``).
        target: Observation broadcastable against ``forecast`` (e.g. ``(T, dim)``).

    Returns:
        The energy score as a NumPy float.
    """
    forecast = np.asarray(forecast)
    target = np.asarray(target)

    obs_dist = np.mean(np.linalg.norm(forecast - target, axis=-1))
    # Pairwise term, one sample at a time. Materialising the full (S, S, ...) difference
    # tensor needs O(S**2) memory (~700 MB for 100 x 24 x 370 float64); this needs O(S).
    # Every row has the same number of entries, so the mean of the per-row means equals
    # the mean over all pairs.
    pair_dist = np.mean(
        [np.mean(np.linalg.norm(sample - forecast, axis=-1)) for sample in forecast]
    )
    return obs_dist - 0.5 * pair_dist

## Parameters

In [15]:
# Store arguments in a dictionary
# Experiment configuration. ``args`` is the single source of truth; the module-level
# names below are unpacked from it so the two copies can no longer drift apart.
args = {
    'seed': 1,
    'dataset': "electricity_nips",
    'network': "timegrad_rnn",  # Choose from ['timegrad', 'timegrad_old', 'timegrad_all', 'timegrad_rnn', 'timegrad_transformer', 'timegrad_cnn']
    'noise': "gp",  # Choose from ['normal', 'ou', 'gp']
    'diffusion_steps': 100,
    'epochs': 100,
    'learning_rate': 1e-3,
    'batch_size': 64,
    'num_cells': 100,
    'hidden_dim': 100,
    'residual_layers': 8
}

# Direct assignments for Jupyter notebook (derived from ``args``)
seed = args['seed']
dataset = args['dataset']
network = args['network']
noise = args['noise']
diffusion_steps = args['diffusion_steps']
epochs = args['epochs']
learning_rate = args['learning_rate']
batch_size = args['batch_size']
num_cells = args['num_cells']
hidden_dim = args['hidden_dim']
residual_layers = args['residual_layers']

## Set Data

In [16]:
np.random.seed(seed)
torch.manual_seed(seed)

# Added to the estimator's input_size: -4 for exchange_rate_nips, 4 for every other dataset.
# NOTE(review): the reason for the negative value is not visible here — confirm.
covariance_dim = -4 if dataset == 'exchange_rate_nips' else 4

# Load data. ``dataset`` is rebound here from the dataset name to the loaded dataset
# object; later cells read ``dataset.metadata`` from it.
dataset = get_dataset(dataset, regenerate=False)
target_dim = int(dataset.metadata.feat_static_cat[0].cardinality)

train_grouper = MultivariateGrouper(max_target_dim=min(2000, target_dim))
test_grouper = MultivariateGrouper(
    num_test_dates=int(len(dataset.test) / len(dataset.train)),
    max_target_dim=min(2000, target_dim),
)
dataset_train = train_grouper(dataset.train)
dataset_test = test_grouper(dataset.test)

# Hold out the last ``val_window`` steps of each training entry as validation data;
# the training entry keeps everything before that window.
val_window = 20 * dataset.metadata.prediction_length
dataset_train = list(dataset_train)
dataset_val = []
for i, entry in enumerate(dataset_train):
    x = deepcopy(entry)
    x['target'] = x['target'][:, -val_window:]
    dataset_val.append(x)
    entry['target'] = entry['target'][:, :-val_window]

In [17]:
np.shape(dataset_train[0]['target'])

(370, 5353)

In [18]:
np.shape(dataset_val[0]['target'])

(370, 480)

## Training

In [19]:
training_net, prediction_net = get_network(network=network, noise=noise)

In [20]:
# Build the estimator from the configuration above. Prediction and context length
# both come from the dataset metadata; beta_end is set to 20 / diffusion_steps.
# NOTE(review): input_size = target_dim * 4 + covariance_dim — the meaning of the
# factor 4 is not visible in this chunk; confirm against the network's input layout.
estimator = ScoreEstimator(
    training_net=training_net,
    prediction_net=prediction_net,
    noise=noise,
    target_dim=target_dim,
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length,
    cell_type='GRU',
    num_cells=num_cells,
    hidden_dim=hidden_dim,
    residual_layers=residual_layers,
    input_size=target_dim * 4 + covariance_dim,
    freq=dataset.metadata.freq,
    loss_type='l2',
    scaling=True,
    diff_steps=diffusion_steps,
    beta_end=20 / diffusion_steps,
    beta_schedule='linear',
    num_parallel_samples=100,
    pick_incomplete=True,
    trainer=TrainerForecasting(
        device=device,
        epochs=epochs,
        learning_rate=learning_rate,
        num_batches_per_epoch=100,
        batch_size=batch_size,
        patience=10,  # passed to TrainerForecasting's early-stopping logic
    ),
)

In [21]:
# Training
# predictor = estimator.train(dataset_train, dataset_val, num_workers=8)

# Data Loaders Outside the GluonTS Estimator

## All Steps as Separate Top-Level Code

In [30]:
from typing import NamedTuple, Optional
from functools import partial

import numpy as np

import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader

from gluonts.env import env
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.itertools import maybe_len

from pts import Trainer
from pts.model import get_module_forward_input_names
from pts.dataset.loader import TransformedIterableDataset

# NOTE: The triple-quoted string below is inactive scratch code. It is evaluated only
# as a bare expression (the notebook echoes it as the cell output) and never executed.
# It sketches the estimator's data-loading steps as free functions and would not run
# as written: e.g. ``create_training_network`` still takes ``self``, and
# ``train_model`` returns ``validation_iter_dataset`` even when it was never assigned.
# A class-based version follows in ``ScoreEstimatorLOAD``.
"""
from typing import NamedTuple, Optional
from functools import partial

import numpy as np

import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader

from gluonts.env import env
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.itertools import maybe_len

from pts import Trainer
from pts.model import get_module_forward_input_names
from pts.dataset.loader import TransformedIterableDataset

def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ):
        transformation = self.create_transformation()

        trained_net = self.create_training_network(self.trainer.device)

        input_names = get_module_forward_input_names(trained_net)

        with env._let(max_idle_transforms=maybe_len(training_data) or 0):
            training_instance_splitter = self.create_instance_splitter("training")
        training_iter_dataset = TransformedIterableDataset(
            dataset=training_data,
            transform=transformation
            + training_instance_splitter
            + SelectFields(input_names),
            is_train=True,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
        )

        if validation_data is not None:
            with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
                validation_instance_splitter = self.create_instance_splitter("validation")
            validation_iter_dataset = TransformedIterableDataset(
                dataset=validation_data,
                transform=transformation
                + validation_instance_splitter
                + SelectFields(input_names),
                is_train=True,
                cache_data=cache_data,
            )
        return training_iter_dataset, validation_iter_dataset
        
from gluonts.env import env
from gluonts.dataset.common import Dataset
from typing import NamedTuple, Optional
from gluonts.core.component import validated
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.itertools import maybe_len

lags_seq = None
time_features = None

training_net=training_net
prediction_net=prediction_net
noise=noise
target_dim=target_dim
prediction_length=dataset.metadata.prediction_length
context_length=dataset.metadata.prediction_length
cell_type='GRU'
num_cells=num_cells
hidden_dim=hidden_dim
residual_layers=residual_layers
input_size=target_dim * 4 + covariance_dim
freq=dataset.metadata.freq
loss_type='l2'
scaling=True
diff_steps=diffusion_steps
beta_end=20 / diffusion_steps
beta_schedule='linear'
num_parallel_samples=100
pick_incomplete=True

class TrainOutput(NamedTuple):
    transformation: Transformation
    trained_net: nn.Module
    predictor: PyTorchPredictor

def create_transformation() -> Transformation:
    return Chain(
        [
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=2,
            ),
            # maps the target to (1, T)
            # if the target data is uni dimensional
            ExpandDimArray(
                field=FieldName.TARGET,
                axis=None,
            ),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=prediction_length,
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME],
            ),
            SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
            TargetDimIndicator(
                field_name="target_dimension_indicator",
                target_field=FieldName.TARGET,
            ),
            AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
        ]
    )

def create_instance_splitter(train_sampler,validation_sampler,mode: str):
    assert mode in ["training", "validation", "test"]

    instance_sampler = {
        "training": train_sampler,
        "validation": validation_sampler,
        "test": TestSplitSampler(),
    }[mode]

    return InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=history_length,
        future_length=prediction_length,
        time_series_fields=[
            FieldName.FEAT_TIME,
            FieldName.OBSERVED_VALUES,
        ],
    ) + (
        RenameFields(
            {
                f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
            }
        )
    )


def create_training_network(self, training_net, device: torch.device):
    return training_net(
        noise=self.noise,
        input_size=self.input_size,
        target_dim=self.target_dim,
        num_layers=self.num_layers,
        num_cells=self.num_cells,
        cell_type=self.cell_type,
        history_length=self.history_length,
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        dropout_rate=self.dropout_rate,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        diff_steps=self.diff_steps,
        loss_type=self.loss_type,
        beta_end=self.beta_end,
        beta_schedule=self.beta_schedule,
        residual_layers=self.residual_layers,
        residual_channels=self.residual_channels,
        dilation_cycle_length=self.dilation_cycle_length,
        lags_seq=self.lags_seq,
        scaling=self.scaling,
        conditioning_length=self.conditioning_length,
        time_feat_dim=self.time_feat_dim,
    ).to(device)

lags_seq = (
    lags_seq
    if lags_seq is not None
    else lags_for_fourier_time_features_from_frequency(freq_str=freq)
)

time_features = (
    time_features
    if time_features is not None
    else fourier_time_features_from_frequency(freq)
)

history_length = context_length + max(lags_seq)
pick_incomplete = pick_incomplete
scaling = scaling

train_sampler = ExpectedNumInstanceSampler(
    num_instances=1.0,
    min_past=0 if pick_incomplete else history_length,
    min_future=prediction_length,
)

validation_sampler = ValidationSplitSampler(
    min_past=0 if pick_incomplete else history_length,
    min_future=prediction_length,
) 
"""

'\nfrom typing import NamedTuple, Optional\nfrom functools import partial\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils import data\nfrom torch.utils.data import DataLoader\n\nfrom gluonts.env import env\nfrom gluonts.core.component import validated\nfrom gluonts.dataset.common import Dataset\nfrom gluonts.model.estimator import Estimator\nfrom gluonts.torch.model.predictor import PyTorchPredictor\nfrom gluonts.transform import SelectFields, Transformation\nfrom gluonts.itertools import maybe_len\n\nfrom pts import Trainer\nfrom pts.model import get_module_forward_input_names\nfrom pts.dataset.loader import TransformedIterableDataset\n\ndef train_model(\n        self,\n        training_data: Dataset,\n        validation_data: Optional[Dataset] = None,\n        num_workers: int = 0,\n        prefetch_factor: int = 2,\n        shuffle_buffer_length: Optional[int] = None,\n        cache_data: bool = False,\n        **kwargs,\n    ):\n        transformation

## Wrapped in an Estimator Class

In [None]:
from typing import NamedTuple, Optional
from functools import partial

import numpy as np

import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader

from gluonts.env import env
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.itertools import maybe_len

from pts import Trainer
from pts.model import get_module_forward_input_names
from pts.dataset.loader import TransformedIterableDataset

class ScoreEstimatorLOAD(PyTorchEstimator):
    def __init__(
        self,
        training_net: Callable,
        prediction_net: Callable,
        noise: str,
        input_size: int,
        freq: str,
        prediction_length: int,
        target_dim: int,
        trainer: TrainerForecasting = TrainerForecasting(),
        context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        cell_type: str = "GRU",
        num_parallel_samples: int = 100,
        dropout_rate: float = 0.1,
        cardinality: List[int] = [1],
        embedding_dimension: int = 5,
        hidden_dim: int = 100,
        diff_steps: int = 100,
        loss_type: str = "l2",
        beta_end=0.1,
        beta_schedule="linear",
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        scaling: bool = True,
        pick_incomplete: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        old: bool = False,
        time_feat_dim: int = 4,
        **kwargs,
    ) -> None:
        super().__init__(trainer=trainer, **kwargs)

        self.training_net = training_net
        self.prediction_net = prediction_net
        self.noise = noise

        self.old = old

        self.freq = freq
        self.context_length = context_length if context_length is not None else prediction_length

        self.input_size = input_size
        self.prediction_length = prediction_length
        self.target_dim = target_dim
        self.time_feat_dim = time_feat_dim
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.num_parallel_samples = num_parallel_samples
        self.dropout_rate = dropout_rate
        self.cardinality = cardinality
        self.embedding_dimension = embedding_dimension

        self.conditioning_length = hidden_dim
        self.diff_steps = diff_steps
        self.loss_type = loss_type
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.residual_layers = residual_layers
        self.residual_channels = residual_channels
        self.dilation_cycle_length = dilation_cycle_length

        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )

        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )

        self.history_length = self.context_length + max(self.lags_seq)
        self.pick_incomplete = pick_incomplete
        self.scaling = scaling

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

        self.validation_sampler = ValidationSplitSampler(
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

    def create_transformation(self) -> Transformation:
        return Chain(
            [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=2,
                ),
                # maps the target to (1, T)
                # if the target data is uni dimensional
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME],
                ),
                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str):
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def create_training_network(self, device: torch.device):
        return self.training_net(
            noise=self.noise,
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            diff_steps=self.diff_steps,
            loss_type=self.loss_type,
            beta_end=self.beta_end,
            beta_schedule=self.beta_schedule,
            residual_layers=self.residual_layers,
            residual_channels=self.residual_channels,
            dilation_cycle_length=self.dilation_cycle_length,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            conditioning_length=self.conditioning_length,
            time_feat_dim=self.time_feat_dim,
        ).to(device)

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: Any,
        device: torch.device,
    ) -> Predictor:
        prediction_network = self.prediction_net(
            noise=self.noise,
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            diff_steps=self.diff_steps,
            loss_type=self.loss_type,
            beta_end=self.beta_end,
            beta_schedule=self.beta_schedule,
            residual_layers=self.residual_layers,
            residual_channels=self.residual_channels,
            dilation_cycle_length=self.dilation_cycle_length,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            conditioning_length=self.conditioning_length,
            num_parallel_samples=self.num_parallel_samples,
            time_feat_dim=self.time_feat_dim,
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
        )

    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ):
        """Build the transformed training and (optional) validation datasets.

        Despite its name, this method performs no optimisation. It chains
        ``self.create_transformation()``, an instance splitter, and a
        ``SelectFields`` restricted to the training network's ``forward``
        inputs. It then wraps each dataset in a ``TransformedIterableDataset``.

        Args:
            training_data: Raw training dataset.
            validation_data: Optional raw validation dataset.
            num_workers: Currently unused; accepted for interface compatibility.
            prefetch_factor: Currently unused; accepted for interface
                compatibility.
            shuffle_buffer_length: Forwarded to the training
                ``TransformedIterableDataset`` only.
            cache_data: Forwarded to both ``TransformedIterableDataset``
                instances.
            **kwargs: Ignored.

        Returns:
            A ``(training_iter_dataset, validation_iter_dataset)`` tuple. The
            second element is ``None`` when ``validation_data`` is ``None``.
        """
        transformation = self.create_transformation()

        # The network is instantiated only to discover which fields its
        # forward() consumes, so that SelectFields can drop everything else.
        trained_net = self.create_training_network(self.trainer.device)
        input_names = get_module_forward_input_names(trained_net)

        # NOTE(review): training data is split with the "validation" mode,
        # exactly like the validation branch below. Confirm this is intended
        # rather than a "training" split; kept as-is to preserve behaviour.
        with env._let(max_idle_transforms=maybe_len(training_data) or 0):
            training_instance_splitter = self.create_instance_splitter("validation")
        training_iter_dataset = TransformedIterableDataset(
            dataset=training_data,
            transform=transformation
            + training_instance_splitter
            + SelectFields(input_names),
            is_train=True,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
        )

        # Default to None so the return below works when no validation data
        # is supplied (otherwise the name would be unbound).
        validation_iter_dataset = None
        if validation_data is not None:
            with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
                validation_instance_splitter = self.create_instance_splitter("validation")
            validation_iter_dataset = TransformedIterableDataset(
                dataset=validation_data,
                transform=transformation
                + validation_instance_splitter
                + SelectFields(input_names),
                is_train=True,
                cache_data=cache_data,
            )

        return training_iter_dataset, validation_iter_dataset

    def train(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ):
        """Public entry point that delegates unchanged to ``train_model``.

        The two datasets are passed positionally. Every other option, plus any
        extra keyword arguments, is forwarded as a keyword argument.

        Returns:
            Whatever ``self.train_model`` returns.
        """
        # The named options can never collide with **kwargs, because Python
        # binds those names to the explicit parameters above.
        forwarded = dict(
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
        )
        forwarded.update(kwargs)
        return self.train_model(training_data, validation_data, **forwarded)

In [25]:
# Echo target_dim so the notebook displays its value for inspection.
target_dim

370

In [26]:
# Configure the estimator. training_net, prediction_net, noise, num_cells,
# hidden_dim, residual_layers, covariance_dim, diffusion_steps, epochs,
# learning_rate, batch_size and device are expected to come from earlier cells.
estimator_load = ScoreEstimatorLOAD(
    training_net=training_net,
    prediction_net=prediction_net,
    noise=noise,
    target_dim=target_dim,
    prediction_length=dataset.metadata.prediction_length,
    # NOTE(review): context length is set equal to the prediction length;
    # confirm this is intended.
    context_length=dataset.metadata.prediction_length,
    cell_type='GRU',
    num_cells=num_cells,
    hidden_dim=hidden_dim,
    residual_layers=residual_layers,
    # presumably must match the feature width the network consumes — verify.
    input_size=target_dim * 4 + covariance_dim,
    freq=dataset.metadata.freq,
    loss_type='l2',
    scaling=True,
    diff_steps=diffusion_steps,
    # Final beta is scaled inversely with the number of diffusion steps.
    beta_end=20 / diffusion_steps,
    beta_schedule='linear',
    num_parallel_samples=100,
    pick_incomplete=True,
    trainer=TrainerForecasting(
        device=device,
        epochs=epochs,
        learning_rate=learning_rate,
        num_batches_per_epoch=100,
        batch_size=batch_size,
        # Stored as TrainerForecasting.patience (used by its early-stopping logic).
        patience=10,
    ),
)

In [27]:
# Per the train/train_model definitions above, this builds the transformed
# training and validation datasets; it does not fit the model. The
# num_workers argument is accepted but not used by train_model.
train_dataset, val_dataset  = estimator_load.train(dataset_train, dataset_val, num_workers=8)

In [28]:
# Peek at the first transformed validation instance to see which fields
# remain after SelectFields. Only one entry is needed, so stop right away;
# an empty dataset simply prints nothing.
for data_entry in val_dataset:
    print(data_entry.keys())
    break
    

dict_keys(['target_dimension_indicator', 'past_time_feat', 'past_target_cdf', 'past_observed_values', 'past_is_pad', 'future_time_feat', 'future_target_cdf', 'future_observed_values'])


In [29]:
# Peek at the first transformed training instance to see which fields
# remain after SelectFields. Only one entry is needed, so stop right away;
# an empty dataset simply prints nothing.
for data_entry in train_dataset:
    print(data_entry.keys())
    break
    
    

dict_keys(['target_dimension_indicator', 'past_time_feat', 'past_target_cdf', 'past_observed_values', 'past_is_pad', 'future_time_feat', 'future_target_cdf', 'future_observed_values'])
