<a href="https://colab.research.google.com/github/bchenley/TorchSequence/blob/main/TorchSequence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install torch --quiet
!pip install pytorch_lightning --quiet


In [11]:
import torch
from torchsummary import summary

import pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
%load_ext tensorboard

from pytorch_lightning.loggers import TensorBoardLogger

torch.autograd.set_detect_anomaly(True)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import sklearn.preprocessing as skp

import scipy.io
import scipy as sc
from scipy import signal as sp
from scipy import interpolate as interp
from scipy.special import factorial

import itertools
import math
from datetime import datetime, timedelta

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import sys

import random

from tqdm.auto import tqdm

import copy

import pickle

import time

import pdb


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
# Module-level default device string: 'cuda' when torch.cuda.is_available() reports a GPU, else 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class FeatureTransform():
  '''
  A class for performing feature scaling and transformation operations on data.

  ``fit_transform`` computes scaling statistics along ``dim`` and stores them on the instance
  (``min_``/``max_`` for 'identity' and 'minmax', ``mean_``/``std_`` for 'standard'), with
  ``dim`` reduced away. ``transform`` and ``inverse_transform`` reuse the stored statistics.
  '''

  def __init__(self,
                scale_type = 'minmax', minmax = (0., 1.), dim = 0,
                device = 'cpu', dtype = torch.float32):
    '''
    Initializes the FeatureTransform instance.

    Args:
        scale_type (str): The type of scaling to be applied. Options are 'identity', 'minmax', or 'standard'.
        minmax (sequence): The [min, max] target range used by 'minmax' scaling (stored as a list).
        dim (int): The dimension along which the scaling statistics are computed.
        device (str): Stored on the instance; not used by the methods of this class.
        dtype (torch.dtype): Stored on the instance; not used by the methods of this class.

    Raises:
        ValueError: If ``scale_type`` is not one of the supported options.
    '''

    if scale_type not in ['identity', 'minmax', 'standard']:
        raise ValueError(f"scale_type ({scale_type}) is not set to 'identity', 'minmax', or 'standard'.")

    self.scale_type = scale_type
    # Copy so every instance owns its own range list (the default is a shared object).
    self.minmax = list(minmax)
    self.dim = dim
    self.device, self.dtype = device, dtype

  def _expand(self, stat, X):
    '''
    Re-inserts the reduced dimension into a fitted statistic so it broadcasts against X.

    When ``dim`` is the leading dimension of X, the reduced statistic already lines up with
    X's trailing dimensions and is returned unchanged. For a non-leading ``dim`` it would
    otherwise line up with the wrong axis, so a singleton axis is inserted at ``dim``.
    '''
    if X.dim() == 0 or stat.dim() != X.dim() - 1 or self.dim % X.dim() == 0:
        return stat
    return stat.unsqueeze(self.dim)

  @staticmethod
  def _nonzero(denominator):
    '''
    Replaces zeros by ones so constant features scale to a finite value instead of NaN/inf.
    '''
    return torch.where(denominator == 0, torch.ones_like(denominator), denominator)

  def _scale_minmax(self, X):
    '''
    Maps X from the fitted [min_, max_] range onto [minmax[0], minmax[1]].
    '''
    lo, hi = self.minmax[0], self.minmax[1]
    min_ = self._expand(self.min_, X)
    range_ = self._nonzero(self._expand(self.max_ - self.min_, X))
    return (X - min_) / range_ * (hi - lo) + lo

  def _scale_standard(self, X):
    '''
    Centers X with the fitted mean_ and divides by the fitted std_ (zeros replaced by one).
    '''
    return (X - self._expand(self.mean_, X)) / self._nonzero(self._expand(self.std_, X))

  def identity(self, X):
    '''
    Returns the input data as it is without any scaling (min_/max_ are still recorded).

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The input data unchanged.
    '''
    self.min_, self.max_ = X.min(self.dim).values, X.max(self.dim).values
    return X

  def standardize(self, X):
    '''
    Fits mean_/std_ along ``dim`` and standardizes the input data.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The standardized input data (zero-std features map to 0).
    '''
    self.mean_, self.std_ = X.mean(self.dim), X.std(self.dim)
    return self._scale_standard(X)

  def inverse_standardize(self, X):
    '''
    Applies inverse standardization on the input data.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The inversely standardized input data.
    '''
    return X * self._expand(self.std_, X) + self._expand(self.mean_, X)

  def normalize(self, X):
    '''
    Fits min_/max_ along ``dim`` and rescales the input data to ``minmax``.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The normalized input data (constant features map to minmax[0]).
    '''
    self.min_, self.max_ = X.min(self.dim).values, X.max(self.dim).values
    return self._scale_minmax(X)

  def inverse_normalize(self, X):
    '''
    Applies inverse normalization on the input data.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The inversely normalized input data.
    '''
    lo, hi = self.minmax[0], self.minmax[1]
    min_ = self._expand(self.min_, X)
    range_ = self._expand(self.max_ - self.min_, X)
    return (X - lo) * range_ / (hi - lo) + min_

  def fit_transform(self, X):
    '''
    Fits the scaling parameters based on the input data and transforms the data accordingly.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The transformed input data.

    Raises:
        ValueError: If ``scale_type`` was changed to an unsupported value after construction.
    '''
    if self.scale_type == 'identity':
        return self.identity(X)
    if self.scale_type == 'minmax':
        return self.normalize(X)
    if self.scale_type == 'standard':
        return self.standardize(X)
    raise ValueError(f"scale_type ({self.scale_type}) is not set to 'identity', 'minmax', or 'standard'.")

  def transform(self, X):
    '''
    Transforms the input data based on the previously fitted scaling parameters.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The transformed input data.

    Raises:
        ValueError: If ``scale_type`` was changed to an unsupported value after construction.
    '''
    if self.scale_type == 'identity':
        return X
    if self.scale_type == 'minmax':
        return self._scale_minmax(X)
    if self.scale_type == 'standard':
        return self._scale_standard(X)
    raise ValueError(f"scale_type ({self.scale_type}) is not set to 'identity', 'minmax', or 'standard'.")

  def inverse_transform(self, X):
    '''
    Applies the inverse transformation on the input data.

    Args:
        X (torch.Tensor): The input data.

    Returns:
        torch.Tensor: The inversely transformed input data.

    Raises:
        ValueError: If ``scale_type`` was changed to an unsupported value after construction.
    '''
    if self.scale_type == 'identity':
        return X
    if self.scale_type == 'minmax':
        return self.inverse_normalize(X)
    if self.scale_type == 'standard':
        return self.inverse_standardize(X)
    raise ValueError(f"scale_type ({self.scale_type}) is not set to 'identity', 'minmax', or 'standard'.")


In [None]:
class Loss():
  '''
  A class for computing loss functions.
  '''

  def __init__(self, name='mse', dims=0):
    '''
    Initializes the Loss instance.

    Args:
        name (str): The name of the loss function. Options are 'mae', 'mse', 'mase', 'rmse', 'nmse', 'mape', 'fb'.
        dims (int): The dimension along which the loss is computed (passed to nanmean/nansum).
    '''
    self.name = name
    self.dims = dims

  def __call__(self, y_pred, y_true):
    '''
    Computes the loss based on the predicted and true values, ignoring NaNs in the reductions.

    Args:
        y_pred (torch.Tensor): The predicted values.
        y_true (torch.Tensor): The true values.

    Returns:
        torch.Tensor: The computed loss value, reduced along ``dims``.

    Raises:
        ValueError: If ``name`` is not one of the supported losses.
    '''
    dims = self.dims
    error = y_true - y_pred

    if self.name == 'mae':
        # Mean Absolute Error (L1 loss)
        loss = error.abs().nanmean(dim=dims)
    elif self.name == 'mse':
        # Mean Squared Error
        loss = error.pow(2).nanmean(dim=dims)
    elif self.name == 'mase':
        # Mean Absolute Scaled Error: MAE relative to the MAE of a one-step naive forecast.
        loss = error.abs().nanmean(dim=dims) / (y_true.diff(n=1, dim=dims).abs().nanmean(dim=dims))
    elif self.name == 'rmse':
        # Root Mean Squared Error
        loss = error.pow(2).nanmean(dim=dims).sqrt()
    elif self.name == 'nmse':
        # Normalized Mean Squared Error
        loss = error.pow(2).nanmean(dim=dims) / y_true.pow(2).nanmean(dim=dims)
    elif self.name == 'mape':
        # Mean Absolute Percentage Error
        loss = ((error / y_true).abs() * 100).nanmean(dim=dims)
    elif self.name == 'fb':
        # Fractional Bias
        loss = (y_pred.nansum(dim=dims) - y_true.nansum(dim=dims)) / y_true.nansum(dim=dims)
    else:
        # Previously this fell through to an UnboundLocalError on `loss`.
        raise ValueError(f"name ({self.name}) must be one of 'mae', 'mse', 'mase', 'rmse', 'nmse', 'mape', or 'fb'.")

    return loss


In [None]:
def fft(x, fs=1, dim=0, nfft=None, norm='backward',
        device=None, dtype=torch.complex64):
  '''
  Computes the one-sided Fast Fourier Transform (FFT) of the input signal.

  Args:
      x: The input signal. If not a torch.Tensor (e.g. a NumPy array or pandas DataFrame),
         it is converted to one with the given device and dtype.
      fs: The sampling frequency of the input signal.
      dim: An int selects a 1-D FFT along that dimension; any non-int value selects a
           2-D FFT over the last two dimensions (-2, -1).
      nfft: The number of FFT points. If None, it is set to the size of x along ``dim``
            (for 2-D, the larger of the last two sizes) and a message is printed.
            Odd values are rounded up to the next even number.
      norm: The normalization mode passed to torch.fft ('backward' by default).
      device: The device for the outputs. If None, tensors keep their current device.
      dtype: The complex data type the FFT coefficients are cast to.

  Returns:
      freq: The first nfft/2 non-negative frequencies (1-D), or a tuple of two 'ij'
            meshgrid tensors built from them (2-D).
      x_fft_mag: 2/nfft times the magnitude of the first nfft/2 coefficients (per axis).
      x_fft_phase: The phase of the same coefficients.
  '''
  if not isinstance(x, torch.Tensor):
      if isinstance(x, pd.DataFrame):
          x = x.values
      x = torch.tensor(x).to(device=device, dtype=dtype)

  two_dim = not isinstance(dim, int)
  if two_dim:
      # Any non-integer dim is treated as a 2-D transform over the last two axes.
      dim = (-2, -1)

  if nfft is None:
      nfft = max(x.shape[-2:]) if two_dim else x.shape[dim]
      print(f'nfft set to {nfft}')

  # Round up to an even length so the spectrum splits into two equal halves.
  s = int(nfft) + int(nfft) % 2
  N = s // 2

  freq_1d = torch.fft.fftfreq(s, d=1 / fs).to(device=device)[:N]

  if two_dim:
      x_fft = torch.fft.fftn(x, s=(s, s), dim=dim, norm=norm).to(device=device, dtype=dtype)
      freq = torch.meshgrid(freq_1d, freq_1d, indexing='ij')
      x_fft = x_fft[..., :N, :N]
  else:
      x_fft = torch.fft.fft(x, n=s, dim=dim, norm=norm).to(device=device, dtype=dtype)
      freq = freq_1d
      x_fft = x_fft.split(N, dim=dim)[0]

  x_fft_mag = 2.0 / s * torch.abs(x_fft)
  x_fft_phase = torch.angle(x_fft)

  return freq, x_fft_mag, x_fft_phase


In [None]:
def periodogram(X, sf=1, window='hann', nfft=512,
                detrend=None, return_onesided=True,
                scaling='density', axis=0):
  '''
  Computes the periodogram of a signal using scipy.signal.periodogram.

  Args:
      X: The input signal.
      sf: The sampling frequency of the input signal (forwarded as SciPy's ``fs``).
      window: The window function to apply to the signal.
      nfft: The number of points to compute the FFT. If None, the length of X along
            ``axis`` is used and a message is printed.
      detrend: The detrend specification forwarded to SciPy (None disables detrending).
      return_onesided: If True, returns only the one-sided spectrum for real inputs.
      scaling: The scaling mode for the power spectrum.
      axis: The axis along which to compute the periodogram.

  Returns:
      f: The frequencies at which the periodogram is computed.
      psd: The power spectral density (periodogram) of the signal.
  '''
  if nfft is None:
      nfft = X.shape[axis]
      print(f'nfft set to {nfft}')

  # SciPy names the sampling-frequency argument `fs`; passing `sf=` raised a TypeError.
  f, psd = sp.periodogram(X, fs=sf, window=window, nfft=nfft,
                          detrend=detrend, return_onesided=return_onesided,
                          scaling=scaling, axis=axis)

  return f, psd


In [None]:

def moving_average(X, window):
    '''
    Applies a centered, edge-renormalized moving average (weighted by ``window``) along axis 0.

    For a window of length L, output sample i is the weighted average of X[i - L//2 : i - L//2 + L].
    Near the edges the out-of-range taps are dropped and the remaining weights are rescaled
    to sum to one.

    Args:
        X: The input signal (NumPy array or torch.Tensor; tensors are moved to CPU NumPy).
        window: The weights of the moving average filter (NumPy array or torch.Tensor).

    Returns:
        y: NumPy array with X's shape. Floating dtypes are preserved; integer inputs produce
           float64 so the averages are not truncated.
    '''
    if isinstance(X, torch.Tensor):
        X = X.cpu().numpy()
    if isinstance(window, torch.Tensor):
        window = window.cpu().numpy()

    num_samples = X.shape[0]
    len_window = window.shape[0]

    # Tap offsets relative to the current sample: -h..h for odd L = 2h+1, -h..h-1 for even L = 2h.
    offsets = np.arange(len_window) - len_window // 2

    out_dtype = X.dtype if np.issubdtype(X.dtype, np.inexact) else np.float64
    y = np.empty(X.shape, dtype=out_dtype)

    for i in range(num_samples):
        m = i + offsets
        valid = (m >= 0) & (m < num_samples)

        # Out-of-place division: works for integer windows and never touches the caller's array.
        window_ = window[valid] / window[valid].sum(0)

        y[i] = np.dot(window_.T, X[m[valid]])

    return y


In [None]:

def butter(x, critical_frequency, butter_type = 'low', filter_order = 3, sampling_rate = 1):
    '''
    Applies a zero-phase (forward-backward) Butterworth filter to the input signal along axis 0.

    Args:
        x: The input signal.
        critical_frequency: The critical frequency in the same units as ``sampling_rate``;
            a scalar for 'low'/'high', or a [low, high] pair for 'band'/'bandstop'.
        butter_type: The ``btype`` passed to scipy.signal.butter.
        filter_order: The order of the filter.
        sampling_rate: The sampling rate of the input signal.

    Returns:
        y: The output signal after applying the Butterworth filter with scipy.signal.filtfilt.
    '''
    nyquist = sampling_rate / 2

    # np.asarray lets band filters pass [low, high]; a plain list cannot be divided by a float.
    b, a = sp.butter(N=filter_order,
                     Wn=np.asarray(critical_frequency) / nyquist,
                     btype=butter_type,
                     output='ba')

    y = sp.filtfilt(b, a, x, axis=0)
    return y


In [None]:
def fill(X, steps, interp_kind = 'linear'):
  '''
  Replace NaNs in every column of X by interpolating that column over ``steps``.

  Columns without NaNs are left as they are. Filled columns are written back into X in place,
  and X is also returned.

  Args:
      X: 2-D array whose last axis indexes the columns to fill.
      steps: The time steps associated with the rows of X.
      interp_kind: The interpolation kind handed to Interpolator.

  Returns:
      X: The same array, with NaN gaps filled.
  '''
  for col in range(X.shape[-1]):
      column = X[:, col].copy()
      missing = np.isnan(column)

      if not np.any(missing):
          continue

      present = ~missing
      interpolator = Interpolator(kind=interp_kind)
      interpolator.fit(steps[present], column[present])

      X[:, col] = interpolator.interp_fn(steps)

  return X


In [None]:
class Interpolator():
  '''
  Interpolator for 1-dimensional data, wrapping scipy.interpolate.interp1d.

  Args:
      kind: The kind of interpolation method to use (forwarded to interp1d).
      axis: The axis of y along which to interpolate.

  Attributes:
      interp_fn: The fitted interp1d object (None until ``fit`` is called).

  Methods:
      fit: Fits the interpolation function to the provided data.

  '''
  def __init__(self, kind='linear', axis=0):
      super().__init__()

      self.kind = kind
      self.axis = axis
      self.interp_fn = None

  @staticmethod
  def _to_numpy(values):
      '''
      Converts torch tensors to NumPy arrays (other inputs are returned unchanged).
      '''
      if isinstance(values, torch.Tensor):
          # .cpu() is required: Tensor.numpy() refuses tensors that live on a CUDA device.
          return values.detach().cpu().numpy()
      return values

  def fit(self, x, y):
      '''
      Fits the interpolation function to the provided data.

      Args:
          x: The x-coordinates of the data points (array-like or torch.Tensor on any device).
          y: The y-coordinates of the data points (array-like or torch.Tensor on any device).

      '''
      self.interp_fn = sc.interpolate.interp1d(self._to_numpy(x), self._to_numpy(y),
                                               kind=self.kind, axis=self.axis)


In [None]:
def _flag_outlier_steps(values, abs_max, z_crit):
  '''
  Returns indices i whose step values[i] -> values[i + 1] exceeds abs_max in magnitude or whose z-score exceeds z_crit.
  '''
  diff = np.diff(values)
  z_diff = (diff - diff.mean()) / diff.std()
  return np.where((np.abs(diff) > abs_max) | (np.abs(z_diff) > z_crit))[0]


def remove_outliers(X, steps, abs_max_change=(np.inf,), z_change_critical=(7,), interp_type='linear'):
  '''
  Remove outliers from the input data, column by column, and re-grid onto a shared index range.

  For each column, the first flagged step (absolute change > abs_max_change or change z-score
  > z_change_critical) is resolved by deleting whichever of its two samples lies further from
  the column mean; the test is repeated until no step is flagged. Columns that lost samples are
  re-interpolated onto the integer index range kept by every column; the other columns are
  sliced to that same range.

  Args:
      X: The input data array, shape (num_samples, num_features).
      steps: The corresponding steps array (indexed by sample).
      abs_max_change: Per-column absolute-change thresholds; a single value is applied to all columns.
      z_change_critical: Per-column z-score thresholds; a single value is applied to all columns.
      interp_type: The type of interpolation to use for filling the gaps left by removed outliers.

  Returns:
      Y: The input data with outliers removed and gaps filled.
      steps: The corresponding steps array after outlier removal.

  '''
  num_features = X.shape[-1]

  # list() so a length-1 array/tuple is repeated rather than multiplied element-wise.
  abs_max_change = list(abs_max_change)
  z_change_critical = list(z_change_critical)
  if len(abs_max_change) == 1:
      abs_max_change = abs_max_change * num_features
  if len(z_change_critical) == 1:
      z_change_critical = z_change_critical * num_features

  x, i_x, interpolator = [], [], []

  for i in range(num_features):
      X_i = X[:, i]
      kept = np.arange(X_i.shape[0])

      i_discard = _flag_outlier_steps(X_i, abs_max_change[i], z_change_critical[i])

      interpolator.append(Interpolator(kind=interp_type) if len(i_discard) > 0 else None)

      while len(i_discard) > 0:
          # The first flagged step involves samples j and j + 1; drop the one further from the mean.
          j_discard = i_discard[0] + np.array([0, 1])
          j_discard = j_discard[j_discard < len(X_i)]

          j_discard = j_discard[np.abs(X_i.mean() - X_i[j_discard]).argmax()]
          X_i = np.delete(X_i, j_discard)
          kept = np.delete(kept, j_discard)

          i_discard = _flag_outlier_steps(X_i, abs_max_change[i], z_change_critical[i])

      x.append(X_i.reshape(-1, 1))
      i_x.append(kept)

      print(f"{i+1}/{X.shape[-1]}")

  # Restrict every column to the index range that all columns still cover.
  i_min = max(np.min(idx) for idx in i_x)
  i_max = min(np.max(idx) for idx in i_x)

  i_all = np.arange(i_min, i_max + 1)
  steps = steps[i_all]

  for i, interp in enumerate(interpolator):
      if interp is not None:
          interp.fit(i_x[i], x[i])
          x[i] = interp.interp_fn(i_all)
      else:
          # Nothing was removed from this column, so row k is index k: slice to the shared range
          # (previously left at full length, which broke the concatenation below).
          x[i] = x[i][i_all]

  Y = np.concatenate(x, -1)

  return Y, steps


In [None]:
class BaselineModel():
  '''
  Baseline models for time series prediction.

  Every predictor takes a 2-D ``input`` of shape (num_steps, num_features) and returns a
  tensor with the same number of columns. Row n of the result is the forecast of input[n]
  made from rows before n; rows with no forecast available are NaN.

  Args:
      model_type (str): One of 'naive', 'moving_average', 'ses', 'des', 'tes'.
      naive_steps (int): Lag of the naive model (prediction[n] = input[n - naive_steps]).
      ma_window_size (int): Number of preceding samples averaged by the moving average model.
      decay (float): Level smoothing factor for the exponential smoothing models.
      trend (sequence): [trend smoothing factor, trend weight in the forecast] for 'des'/'tes'.
      period (int): Season length (>= 1) for the 'tes' model.
      seasonal (sequence): [seasonal smoothing factor, seasonal weight in the forecast] for 'tes'.

  '''

  def __init__(self, model_type='naive', naive_steps=1, ma_window_size=20, decay=0.5, trend=(0.5, 1.0), period=1, seasonal=(0.5, 1.0)):
      self.model_type = model_type
      self.naive_steps = naive_steps
      self.decay, self.trend, self.seasonal, self.period = decay, trend, seasonal, period
      self.ma_window_size = ma_window_size

  def _nan_rows(self, num_rows, input):
      '''
      Returns a (num_rows, num_features) block of NaNs with input's dtype and device.
      '''
      return torch.full((num_rows, input.shape[1]), float('nan')).to(input)

  def ma_prediction(self, input):
      '''
      Moving average prediction: row n is the mean of the (up to) ma_window_size rows before n.

      Args:
          input: The input data tensor.

      Returns:
          prediction: The predicted values (row 0 averages an empty window and is NaN).

      '''
      prediction = [input[max(0, n - self.ma_window_size):n].mean(0, keepdim=True)
                    for n in range(input.shape[0])]
      return torch.cat(prediction, 0)

  def naive_prediction(self, input):
    '''
    Naive prediction: input shifted forward by naive_steps rows, NaN-padded at the start.

    Args:
        input: The input data tensor.

    Returns:
        prediction: The predicted values based on the naive model.

    '''
    lagged = input[:max(input.shape[0] - self.naive_steps, 0)]
    return torch.cat((self._nan_rows(self.naive_steps, input), lagged), 0)

  def ses_prediction(self, input):
    '''
    Single exponential smoothing prediction.

    The first forecast (row 1) is input[0]; afterwards
    forecast[n] = decay * input[n - 1] + (1 - decay) * forecast[n - 1].

    Args:
        input: The input data tensor.

    Returns:
        prediction: The predicted values based on the single exponential smoothing model.

    '''
    prediction = [self._nan_rows(1, input)]
    smoothed = None
    for n in range(1, input.shape[0]):
        observed = input[n - 1]
        smoothed = observed if smoothed is None else self.decay * observed + (1 - self.decay) * smoothed
        prediction.append(smoothed.unsqueeze(0))

    return torch.cat(prediction, 0)

  def des_prediction(self, input):
    '''
    Double exponential smoothing prediction (level + trend, both starting at zero).

    Args:
        input: The input data tensor.

    Returns:
        prediction: The predicted values based on the double exponential smoothing model.

    '''
    prediction = [self._nan_rows(1, input)]
    level_prev, trend_prev = 0, 0
    for n in range(1, input.shape[0]):
        level_n = self.decay * input[n - 1] + (1 - self.decay) * (level_prev + trend_prev)
        trend_n = self.trend[0] * (level_n - level_prev) + (1 - self.trend[0]) * trend_prev

        prediction.append((level_n + self.trend[1] * trend_n).unsqueeze(0))

        # The trend must advance too (it was previously frozen at its initial value).
        level_prev, trend_prev = level_n, trend_n

    return torch.cat(prediction, 0)

  def tes_prediction(self, input):
    '''
    Seasonal exponential smoothing prediction (level, trend and a per-period seasonal term).

    Args:
        input: The input data tensor.

    Returns:
        prediction: The predicted values based on the seasonal exponential smoothing model;
            the first ``period`` rows are NaN.

    '''
    prediction = [self._nan_rows(self.period, input)]
    level_prev, trend_prev = 0, 0
    season = torch.zeros_like(input)
    for n in range(self.period, input.shape[0]):
        observed = input[n - 1]
        season_prev = season[n - self.period]

        level_n = self.decay * observed + (1 - self.decay) * (level_prev + trend_prev) + season_prev
        trend_n = self.trend[0] * (level_n - level_prev) + (1 - self.trend[0]) * trend_prev
        # seasonal[0] is the smoothing factor (the whole sequence was used here before).
        season_n = self.seasonal[0] * (observed - level_prev - trend_prev) + (1 - self.seasonal[0]) * season_prev
        season[n] = season_n

        prediction_n = level_n + self.trend[1] * trend_n + self.seasonal[1] * season_n
        prediction.append(prediction_n.unsqueeze(0))

        level_prev, trend_prev = level_n, trend_n

    return torch.cat(prediction, 0)

  def __call__(self, input):
    '''
    Make predictions based on the selected model type.

    Args:
        input: The input data tensor.

    Returns:
        prediction: The predicted values based on the selected model type.

    Raises:
        ValueError: If ``model_type`` is not a supported model.

    '''
    predictors = {
        'naive': self.naive_prediction,
        'moving_average': self.ma_prediction,
        'ses': self.ses_prediction,
        'des': self.des_prediction,
        'tes': self.tes_prediction,
    }
    if self.model_type not in predictors:
        raise ValueError(f"model_type ({self.model_type}) must be one of {sorted(predictors)}.")

    return predictors[self.model_type](input)


In [None]:
class Polynomial(torch.nn.Module):
  '''
  Element-wise polynomial with one set of trainable coefficients per input feature.

  For feature j the output is sum_k coef[j, k] * X[..., j] ** p_k, where the powers p_k run
  from 0 (if zero_order) or 1 up to ``degree``.

  Args:
  - in_features (int): Number of input features.
  - degree (int): Degree of the polynomial.
  - coef_init (torch.Tensor): Initial coefficients of shape (in_features, degree + int(zero_order)).
    If None, coefficients are drawn from a standard normal distribution.
  - coef_train (bool): Whether to train the coefficients.
  - coef_reg (sequence): Regularization parameters for the coefficients. [Regularization weight, norm order p]
  - zero_order (bool): Whether to include the zeroth-order term (constant) in the polynomial.
  - device (str): Device the coefficients are created on ('cpu' or 'cuda').
  - dtype (torch.dtype): Data type the coefficients are created with.
  '''

  def __init__(self,
               in_features, degree=1, coef_init=None, coef_train=True,
               coef_reg=(0.001, 1), zero_order=True,
               device='cpu', dtype=torch.float32):
      super(Polynomial, self).__init__()

      self.to(device=device, dtype=dtype)

      if coef_init is None:
          coef_init = torch.nn.init.normal_(torch.empty(in_features, degree + int(zero_order)))

      coef = torch.nn.Parameter(data=coef_init.to(device=device, dtype=dtype), requires_grad=coef_train)

      self.coef, self.coef_reg = coef, coef_reg
      self.in_features, self.degree = in_features, degree
      self.zero_order = zero_order
      self.device, self.dtype = device, dtype

  def forward(self, X):
    '''
    Evaluate the per-feature polynomial.

    Args:
    - X (torch.Tensor): Input tensor whose last dimension has size in_features.

    Returns:
    - y (torch.Tensor): Output of the same shape as X (one polynomial value per feature).
    '''
    # Follow the coefficients' current device/dtype: self.device/self.dtype are fixed at
    # construction and become stale after module.to(...).
    coef = self.coef
    X = X.to(device=coef.device, dtype=coef.dtype)

    pows = torch.arange(1 - int(self.zero_order), self.degree + 1, device=coef.device, dtype=coef.dtype)

    # X[..., j, None] ** pows -> (..., in_features, num_powers), weighted and summed over powers.
    y = (X.unsqueeze(-1).pow(pows) * coef).sum(-1)

    return y

  def penalize(self):
    '''
    Compute the penalty term for coefficient regularization.

    Returns:
    - penalty (torch.Tensor): coef_reg[0] * ||coef||_p with p = coef_reg[1], or 0 if the coefficients are frozen.
    '''

    return self.coef_reg[0] * torch.norm(self.coef, p=self.coef_reg[1]) * int(self.coef.requires_grad)


In [None]:
class LRU(torch.nn.RNN):
  '''
  Laguerre Recurrent Unit (LRU): a bank of Laguerre-style recursive filters built on nn.RNN.

  The nn.RNN base is only used as a container. Its built-in weights and biases are set to None
  at the end of __init__, and forward()/cell() implement the recursion directly. Each filter
  bank has one relaxation parameter ("relax"); hidden_size is the number of taps per bank.

  Args:
      input_size (int): Number of expected features in the input.
      hidden_size (int): Number of taps (hidden units) per filter bank.
      weight_reg (list): Stored on the instance; not used by the methods of this class.
      weight_norm (int): Stored on the instance; not used by the methods of this class.
      bias (bool): Passed to the input Linear block (only built when input_size > 1).
      relax_init (list): Initial relaxation value per filter bank; its length sets num_filterbanks.
      relax_train (bool): Whether to train the relaxation values.
      relax_minmax (list): [min, max] clamp range per filter bank (a single pair is reused for all banks).
      device (str): Device for the relaxation parameter and hidden states ('cpu' or 'cuda').
      dtype (torch.dtype): Data type for the relaxation parameter and hidden states.
  '''

  def __init__(self,
               input_size, hidden_size, weight_reg=[0.001, 1], weight_norm=2, bias=False,
               relax_init=[0.5], relax_train=True, relax_minmax=[[0.1, 0.9]], device='cpu', dtype=torch.float32):

    super(LRU, self).__init__(input_size=input_size, hidden_size=hidden_size, batch_first=True)

    # Called before relax/input_block exist; relax is placed on device/dtype explicitly below.
    self.to(device=device, dtype=dtype)

    num_filterbanks = len(relax_init)

    # A single [min, max] pair is broadcast to every filter bank.
    if len(relax_minmax) == 1:
        relax_minmax = relax_minmax * num_filterbanks

    relax_init = torch.tensor(relax_init).reshape(num_filterbanks,)

    relax = torch.nn.Parameter(relax_init.to(device=device, dtype=dtype), requires_grad=relax_train)

    # Multi-feature inputs are projected to one drive signal per filter bank.
    # NOTE(review): this Linear is built without device/dtype, so it uses PyTorch's defaults
    # until the module is moved. With input_size == 1 the Identity passes the input through,
    # and cell() only lines up when num_filterbanks == 1 — confirm.
    if input_size > 1:
        input_block = torch.nn.Linear(in_features=input_size, out_features=num_filterbanks, bias=bias)
    else:
        input_block = torch.nn.Identity()

    # NOTE(review): the same line appears twice; the second was presumably meant for bias_ih_l0.
    # Both biases are set to None a few lines below anyway.
    self.bias_hh_l0.requires_grad = False
    self.bias_hh_l0.requires_grad = False

    self.input_size, self.hidden_size = input_size, hidden_size
    self.num_filterbanks = num_filterbanks
    self.input_block = input_block
    self.relax_minmax = relax_minmax
    self.relax = relax
    self.weight_reg, self.weight_norm = weight_reg, weight_norm
    self.device, self.dtype = device, dtype

    # Remove built-in weights and biases (the overridden forward()/cell() never use them)
    self.weight_ih_l0 = None
    self.weight_hh_l0 = None
    self.bias_ih_l0 = None
    self.bias_hh_l0 = None

  def init_hiddens(self, num_samples):
    '''
    Initialize the hidden state of the LRU model with zeros.

    Args:
        num_samples (int): Number of samples in the batch.

    Returns:
        torch.Tensor: Zero tensor of shape (num_filterbanks, num_samples, hidden_size).
    '''
    return torch.zeros((self.num_filterbanks, num_samples, self.hidden_size)).to(device=self.device, dtype=self.dtype)

  def cell(self, input, hiddens=None):
    '''
    LRU cell computation for a single time step.

    Args:
        input (torch.Tensor): Input of shape (num_samples, input_size) for the current time step.
        hiddens (torch.Tensor): Previous hidden state, (num_filterbanks, num_samples, hidden_size);
            zeros if None.

    Returns:
        torch.Tensor: Output of shape (num_samples, num_filterbanks, hidden_size).
        torch.Tensor: Updated hidden state, (num_filterbanks, num_samples, hidden_size).
    '''
    num_samples, input_size = input.shape

    hiddens = hiddens if hiddens is not None else self.init_hiddens(num_samples)

    sq_relax = torch.sqrt(self.relax)

    hiddens_new = torch.zeros_like(hiddens).to(hiddens)

    # Tap 0: h_0[n] = sqrt(relax) * h_0[n-1] + sqrt(1 - relax) * u[n], with u the projected input
    # transposed to (num_filterbanks, num_samples).
    hiddens_new[..., 0] = sq_relax[:, None] * hiddens[..., 0] + (1 - sq_relax ** 2).sqrt()[:, None] * self.input_block(input).t()

    # Higher taps: h_i[n] = sqrt(relax) * (h_i[n-1] + h_{i-1}[n]) - h_{i-1}[n-1]
    for i in range(1, self.hidden_size):
        hiddens_new[..., i] = sq_relax[:, None] * (hiddens[..., i] + hiddens_new[..., i - 1]) - hiddens[..., i - 1]

    output = hiddens_new.permute(1, 0, 2)  # [batch_size, num_filters, hidden_size]

    return output, hiddens_new

  def forward(self, input, hiddens=None):
    '''
    Forward pass of the LRU model, stepping cell() over the time dimension (dim 1).

    Args:
        input (torch.Tensor): Input of shape (num_samples, input_len, input_size).
        hiddens (torch.Tensor): Initial hidden state; zeros if None.

    Returns:
        torch.Tensor: Output of shape (num_samples, input_len, num_filterbanks, hidden_size).
        torch.Tensor: Hidden state after the last time step.
    '''
    num_samples, input_len, input_size = input.shape

    hiddens = self.init_hiddens(num_samples) if hiddens is None else hiddens

    output = []
    for n, input_n in enumerate(input.split(1, 1)):
        output_n, hiddens = self.cell(input_n.squeeze(1), hiddens)
        output.append(output_n.unsqueeze(1))

    output = torch.cat(output, 1)

    return output, hiddens

  def generate_laguerre_functions(self, max_len):
    '''
    Return the filter bank's impulse responses over max_len steps (computed without gradients).

    Args:
        max_len (int): Number of time steps of the impulse response.

    Returns:
        torch.Tensor: forward() output for a unit impulse at t = 0 on every input feature,
            shape (1, max_len, num_filterbanks, hidden_size).
    '''
    with torch.no_grad():
        hiddens = self.init_hiddens(1)

        impulse = torch.zeros((1, max_len, self.input_size)).to(device=self.device, dtype=self.dtype)

        impulse[:, 0, :] = 1

        output, hiddens = self.forward(impulse, hiddens)

        return output

  def clamp_relax(self):
    '''
    Clamp each filter bank's relaxation value in place to its [min, max] range from relax_minmax.
    '''
    for i in range(self.num_filterbanks):
        self.relax[i].data.clamp_(self.relax_minmax[i][0], self.relax_minmax[i][1])


In [None]:
class HiddenLayer(torch.nn.Module):
  '''
  Hidden layer: a linear (or bilinear) map, then an activation, then dropout.

  Args:
      in_features (int or list): Number of input features; a two-element list builds a bilinear
          map over a pair of inputs.
      out_features (int or None): Number of output features. If None or 0, the map is the identity
          and the output has in_features features.
      bias (bool): If True, adds a learnable bias to the linear/bilinear map.
      activation (str): Activation function. Options: 'identity', 'polynomial', 'tanh', 'sigmoid', 'softmax', 'relu'.
      weight_reg (sequence): Regularization for weights used by ``penalize``. [Regularization weight, norm order p]
      weight_norm (int): Norm order used by ``constrain`` to normalize weight rows.
      degree (int): Degree of the polynomial activation function.
      coef_init (torch.Tensor): Initial coefficients for the polynomial activation. If None, they are drawn randomly.
      coef_train (bool): Whether to train the polynomial coefficients.
      coef_reg (sequence): Regularization for the polynomial coefficients. [Regularization weight, norm order p]
      zero_order (bool): Whether to include the zeroth-order term (constant) in the polynomial activation.
      softmax_dim (int): Dimension along which to apply softmax activation.
      dropout_p (float): Dropout probability.
      device (str): Device to use for computation ('cpu' or 'cuda').
      dtype (torch.dtype): Data type of the model parameters.

  Raises:
      ValueError: If ``activation`` is not one of the supported options.

  '''

  def __init__(self, in_features, out_features=None, bias=True, activation='identity',
                weight_reg=(0.001, 1), weight_norm=2, degree=1, coef_init=None, coef_train=True,
                coef_reg=(0.001, 1), zero_order=False, softmax_dim=-1, dropout_p=0.0,
                device='cpu', dtype=torch.float32):
    super(HiddenLayer, self).__init__()

    self.to(device=device, dtype=dtype)

    if out_features is None or out_features == 0:
        out_features = in_features
        f1 = torch.nn.Identity()
    elif isinstance(in_features, list):  # bilinear (must be len = 2)
        class Bilinear(torch.nn.Module):
            def __init__(self, in1_features=in_features[0], in2_features=in_features[1],
                          out_features=out_features, bias=bias, device=device, dtype=dtype):
                super(Bilinear, self).__init__()

                # device/dtype were previously accepted but ignored.
                self.F = torch.nn.Bilinear(in1_features, in2_features, out_features, bias,
                                           device=device, dtype=dtype)

            def forward(self, input):
                # input is a pair (input1, input2).
                input1, input2 = input
                return self.F(input1, input2)

        f1 = Bilinear()
    else:
        f1 = torch.nn.Linear(in_features=in_features, out_features=out_features,
                              bias=bias, device=device, dtype=dtype)

    if activation == 'identity':
        f2 = torch.nn.Identity()
    elif activation == 'polynomial':
        f2 = Polynomial(in_features=out_features, degree=degree, coef_init=coef_init,
                        coef_train=coef_train, coef_reg=coef_reg, zero_order=zero_order,
                        device=device, dtype=dtype)
    elif activation == 'tanh':
        f2 = torch.nn.Tanh()
    elif activation == 'sigmoid':
        f2 = torch.nn.Sigmoid()
    elif activation == 'softmax':
        f2 = torch.nn.Softmax(dim=softmax_dim)
    elif activation == 'relu':
        f2 = torch.nn.ReLU()
    else:
        raise ValueError(f"activation ({activation}) must be 'identity', 'polynomial', 'tanh', 'sigmoid', 'softmax', or 'relu'.")

    self.dropout = torch.nn.Dropout(dropout_p)

    self.F = torch.nn.Sequential(f1, f2)
    self.device, self.dtype = device, dtype
    self.weight_reg, self.weight_norm = weight_reg, weight_norm
    # penalize() reads this for polynomial coefficients; it was never stored before.
    self.coef_reg = coef_reg

  def forward(self, input):
    '''
    Perform a forward pass through the hidden layer.

    Args:
        input (torch.Tensor or pair of tensors): Input tensor (a pair for the bilinear map).

    Returns:
        torch.Tensor: dropout(activation(map(input))).

    '''
    y = self.dropout(self.F(input))
    return y

  def constrain(self):
    '''
    Normalize every weight parameter in place so each row (dim 1 slice) has unit weight_norm-norm.

    '''
    # The old code rebound a local variable and never changed the parameters; copy_ writes
    # the normalized values back into the existing Parameter tensors.
    with torch.no_grad():
        for name, param in self.named_parameters():
            if 'weight' in name:
                param.copy_(torch.nn.functional.normalize(param, p=self.weight_norm, dim=1))

  def penalize(self):
    '''
    Compute the regularization loss for the hidden layer.

    Returns:
        torch.Tensor: Sum of weight_reg[0] * ||W||_{weight_reg[1]} over trainable weights plus
            coef_reg[0] * ||c||_{coef_reg[1]} over trainable polynomial coefficients (0 if none).

    '''
    loss = 0
    for name, param in self.named_parameters():
        if 'weight' in name:
            loss += self.weight_reg[0] * torch.norm(param, p=self.weight_reg[1]) * int(param.requires_grad)
        elif 'coef' in name:
            loss += self.coef_reg[0] * torch.norm(param, p=self.coef_reg[1]) * int(param.requires_grad)

    return loss


In [None]:
class ModulationLayer(torch.nn.Module):
  '''
  Modulation layer that applies different modulation functions to the input.

  The input (optionally prefixed with a constant-1 channel when ``pure``) is mapped
  by a linear HiddenLayer to ``num_modulators`` channels, and each channel is
  multiplied by its basis function evaluated at the requested step(s).

  Args:
      window_len (int): Length of the input window.
      in_features (int): Number of input features.
      associated (bool): Whether the modulators are associated with each other (stored only).
      legendre_degree (int or None): Degree of the Legendre modulation function. If None, Legendre modulation is not applied.
      chebychev_degree (int or None): Degree of the Chebychev modulation function. If None, Chebychev modulation is not applied.
      dt (float): Time step for Fourier modulation.
      num_freqs (int or None): Number of frequencies for Fourier modulation. If None, Fourier modulation is not applied.
      freq_init (torch.Tensor or None): Initial frequencies for Fourier modulation. If None, frequencies are initialized uniformly.
      freq_train (bool): Whether to train the frequencies for Fourier modulation.
      phase_init (torch.Tensor or None): Initial phases for Fourier modulation. If None, phases are initialized as zeros.
      phase_train (bool): Whether to train the phases for Fourier modulation.
      num_sigmoids (int or None): Number of sigmoid functions for Sigmoid modulation. If None, Sigmoid modulation is not applied.
      slope_init (torch.Tensor or None): Initial slopes for Sigmoid modulation. If None, slopes are initialized from a normal distribution.
      slope_train (bool): Whether to train the slopes for Sigmoid modulation.
      shift_init (torch.Tensor or None): Initial shifts for Sigmoid modulation. If None, shifts are initialized from a uniform distribution.
      shift_train (bool): Whether to train the shifts for Sigmoid modulation.
      weight_reg (list): Regularization parameters for the linear function weights. [Regularization weight, regularization exponent]
      weight_norm (int): Norm used when constraining the linear function weights.
      zero_order (bool): Whether to include the zeroth-order term (constant) in the modulation functions.
      bias (bool): If True, adds a learnable bias to the linear function.
      pure (bool): If True, concatenates a constant term to the input.
      device (str): Device to use for computation ('cpu' or 'cuda').
      dtype (torch.dtype): Data type of the model parameters.

  Raises:
      ValueError: If no modulator is enabled (legendre_degree, chebychev_degree,
          num_freqs and num_sigmoids are all None).
  '''

  def __init__(self, window_len, in_features, associated=False, legendre_degree=None, chebychev_degree=None,
               dt=1, num_freqs=None, freq_init=None, freq_train=True, phase_init=None, phase_train=True,
               num_sigmoids=None, slope_init=None, slope_train=True, shift_init=None, shift_train=True,
               weight_reg=[0.001, 1.], weight_norm=2, zero_order=True, bias=True, pure=False,
               device='cpu', dtype=torch.float32):

    super(ModulationLayer, self).__init__()

    # Hermite modulation is not implemented; hermite_idx stays None for attribute compatibility.
    legendre_idx, chebychev_idx, hermite_idx, fourier_idx, sigmoid_idx = None, None, None, None, None

    # m: 1-based position of the most recently added modulator in `modulators`.
    # idx: running column counter recorded in the *_idx bookkeeping.
    # NOTE(review): idx starts at 1, so the recorded column ranges are offset by one
    # from the columns of F -- confirm whether that offset is intended.
    m, idx = 0, 1

    F = []
    modulators = torch.nn.ModuleList([])

    if legendre_degree is not None:
        m += 1
        F_legendre = LegendreModulator(window_len=window_len, scale=True, degree=legendre_degree,
                                       zero_order=zero_order, device=device, dtype=dtype)
        modulators.append(F_legendre)
        F.append(F_legendre.functions)
        legendre_idx = [m, torch.arange(idx, idx + F_legendre.num_modulators)]
        idx += F_legendre.num_modulators

    if chebychev_degree is not None:
        m += 1
        # The zero-order column is only kept when Chebychev is the first modulator,
        # presumably to avoid a duplicate constant column after Legendre.
        F_chebychev = ChebychevModulator(window_len=window_len, scale=True, kind=1, degree=chebychev_degree,
                                          zero_order=zero_order * (len(F) == 0), device=device, dtype=dtype)
        modulators.append(F_chebychev)
        F.append(F_chebychev.functions)
        chebychev_idx = [m, torch.arange(idx, idx + F_chebychev.num_modulators)]
        idx += F_chebychev.num_modulators

    if num_freqs is not None:
        m += 1
        F_fourier = FourierModulator(window_len=window_len, num_freqs=num_freqs, dt=dt,
                                      freq_init=freq_init, freq_train=freq_train,
                                      phase_init=phase_init, phase_train=phase_train,
                                      device=device, dtype=dtype)
        modulators.append(F_fourier)
        F.append(F_fourier.functions)
        fourier_idx = [m, torch.arange(idx, idx + F_fourier.num_modulators)]
        idx += F_fourier.num_modulators

    if num_sigmoids is not None:
        m += 1
        F_sigmoid = SigmoidModulator(window_len=window_len, num_sigmoids=num_sigmoids,
                                      slope_init=slope_init, slope_train=slope_train,
                                      shift_init=shift_init, shift_train=shift_train,
                                      device=device, dtype=dtype)
        modulators.append(F_sigmoid)
        F.append(F_sigmoid.functions)
        sigmoid_idx = [m, torch.arange(idx, idx + F_sigmoid.num_modulators)]
        idx += F_sigmoid.num_modulators

    if not F:
        raise ValueError("ModulationLayer requires at least one of legendre_degree, chebychev_degree, "
                         "num_freqs or num_sigmoids.")

    F = torch.cat(F, -1)

    num_modulators = F.shape[-1]

    # One output channel per basis function; each channel is scaled by its basis in forward().
    linear_fn = HiddenLayer(in_features=in_features + int(pure),
                            out_features=num_modulators,
                            bias=bias,
                            activation='identity',
                            weight_reg=weight_reg,
                            weight_norm=weight_norm,
                            device=device, dtype=dtype)

    self.window_len = window_len
    self.in_features = in_features
    self.associated = associated
    self.weight_reg, self.weight_norm = weight_reg, weight_norm
    self.bias = bias
    self.modulators, self.num_modulators = modulators, num_modulators
    self.pure = pure

    self.linear_fn, self.F = linear_fn, F
    self.legendre_idx, self.chebychev_idx, self.hermite_idx, self.fourier_idx, self.sigmoid_idx = legendre_idx, chebychev_idx, hermite_idx, fourier_idx, sigmoid_idx
    self.dt = dt

    self.device, self.dtype = device, dtype

  def _basis(self):
    '''Concatenate every modulator's freshly generated basis functions along the last dim.'''
    return torch.cat([modulator.generate_basis_functions() for modulator in self.modulators], -1)

  def forward(self, input, steps):
    '''
    Perform a forward pass through the modulation layer.

    Args:
        input (torch.Tensor): Input tensor of shape (num_samples, seq_len, in_features).
        steps (int or index tensor): Row(s) of the basis to apply.

    Returns:
        torch.Tensor: ``basis[steps] * linear_fn(input)`` (broadcast product).
    '''
    num_samples, seq_len, _ = input.shape

    if self.pure:
      # Prepend a constant-1 channel (same dtype/device as the input).
      input_ = torch.cat((input.new_ones((num_samples, seq_len, 1)), input), -1)
    else:
      input_ = input

    # Rebuild the basis on every call. Fourier/sigmoid modulators have trainable
    # parameters, and reusing the tensor built in __init__ would ignore parameter
    # updates and re-traverse an autograd graph that is freed after the first backward.
    self.F = self._basis()

    return self.F[steps] * self.linear_fn(input_)

  def constrain(self):
    '''
    Project the parameters back onto their admissible sets (in place).

    * Each row (dim 1) of every weight in ``linear_fn`` is rescaled to unit
      ``weight_norm``-norm and sign-flipped so the row sum is non-negative.
    * Fourier frequencies are clamped to [0, 1 / (2 * dt)] and phases to [-pi, pi].
    '''
    with torch.no_grad():
      for name, param in self.linear_fn.named_parameters():
        if 'weight' in name:
          # clamp_min avoids 0/0 = NaN for an all-zero row.
          norm = param.norm(self.weight_norm, dim=1, keepdim=True).clamp_min(torch.finfo(param.dtype).tiny)
          param.div_(norm)
          # Flip rows with a negative sum; unlike .sign(), a zero sum keeps the row intact.
          negative = param.sum(dim=1, keepdim=True) < 0
          param.mul_(1 - 2 * negative.to(param.dtype))

      if self.fourier_idx is not None:
        # fourier_idx[0] is the 1-based position of the Fourier modulator in self.modulators.
        fourier = self.modulators[self.fourier_idx[0] - 1]
        fourier.freq.clamp_(0, 1 / (2 * self.dt))
        fourier.phase.clamp_(-torch.pi, torch.pi)

  def penalize(self):
    '''
    Compute the regularization penalty of the linear map.

    Returns:
      torch.Tensor or float: Sum over trainable ``linear_fn`` weights of
      ``weight_reg[0] * ||w||_{weight_reg[1]}`` (0. if there are none).
    '''
    penalty = 0.
    for name, param in self.linear_fn.named_parameters():
      if 'weight' in name:
        penalty += self.weight_reg[0] * torch.norm(param, p=self.weight_reg[1]) * int(param.requires_grad)

    return penalty

class LegendreModulator(torch.nn.Module):
  '''
  Legendre modulation function.

  Evaluates the Legendre polynomials P_0..P_degree on the time grid
  0..window_len-1 (optionally rescaled to [0, 1]) using the three-term recurrence
  q * P_q = (2q - 1) * t * P_{q-1} - (q - 1) * P_{q-2}.

  Args:
      window_len (int): Length of the input window.
      scale (bool): If True, scale the time grid to the range [0, 1]. A single-sample
          window (zero span) is left at t = 0.
      degree (int): Degree of the Legendre polynomial.
      zero_order (bool): If True, include the zeroth-order term (constant) in the modulation function.
      device (str): Device to use for computation ('cpu' or 'cuda').
      dtype (torch.dtype): Data type of the model parameters.
  '''

  def __init__(self, window_len, scale=True, degree=1, zero_order=True, device='cpu', dtype=torch.float32):
    super(LegendreModulator, self).__init__()

    self.degree = degree
    self.zero_order = zero_order
    # Number of basis columns produced (the constant column is dropped when zero_order is False).
    self.num_modulators = degree + int(zero_order)

    self.device, self.dtype = device, dtype

    self.window_len = window_len
    self.scale = scale
    self.functions = self.generate_basis_functions()

  def generate_basis_functions(self):
    '''
    Generate the Legendre basis functions.

    Returns:
        torch.Tensor: Shape (window_len, num_modulators); column q holds P_q(t)
        (shifted by one when zero_order is False). Also stored in ``self.functions``.
    '''
    t = torch.arange(0, self.window_len).view(-1, 1).to(device=self.device, dtype=self.dtype)
    span = t.max() - t.min()
    # With window_len == 1 the span is 0; keep t = 0 instead of producing 0/0 = NaN.
    if self.scale and span > 0:
        t = t / span

    N = len(t)

    y = torch.zeros((N, (self.degree + 1))).to(device=self.device, dtype=self.dtype)

    for q in range(0, (self.degree + 1)):
        if q == 0:
            y[:, 0] = torch.ones((N,)).to(device=self.device, dtype=self.dtype)
        elif q == 1:
            y[:, 1:2] = t * y[:, 0:1]
        else:
            y[:, q:(q + 1)] = ((2 * q - 1) * t * y[:, (q - 1):q] - (q - 1) * y[:, (q - 2):(q - 1)]) / q

    if not self.zero_order:
        y = y[:, 1:]

    self.functions = y

    return y

  def forward(self, X, steps):
    '''
    Apply the Legendre modulation to the input.

    Args:
        X (torch.Tensor): Input tensor, indexed as ``X[:, :, None, :]`` (so it must be 3-D).
        steps (int or index tensor): Row(s) of the basis to apply.

    Returns:
        torch.Tensor: Broadcast product ``X[:, :, None, :] * functions[steps]``.
    '''
    X = X.to(device=self.device, dtype=self.dtype)

    y = X[:, :, None, :] * self.functions[steps]

    return y


class ChebychevModulator(torch.nn.Module):
  '''
  Chebychev modulation function.

  Evaluates Chebychev polynomials of the first (T) or second (U) kind on the time
  grid 0..window_len-1 (optionally rescaled to [0, 1]) using the shared recurrence
  P_q = 2 * t * P_{q-1} - P_{q-2}, with P_0 = 1 and P_1 = kind * t.

  Args:
      window_len (int): Length of the input window.
      scale (bool): If True, scale the time grid to the range [0, 1]. A single-sample
          window (zero span) is left at t = 0.
      kind (int): Kind of Chebychev polynomial to use. Must be 1 or 2.
      degree (int): Degree of the Chebychev polynomial.
      zero_order (bool): If True, include the zeroth-order term (constant) in the modulation function.
      device (str): Device to use for computation ('cpu' or 'cuda').
      dtype (torch.dtype): Data type of the model parameters.

  Raises:
      ValueError: If ``kind`` is not 1 or 2.
  '''

  def __init__(self, window_len, scale=True, kind=1, degree=1, zero_order=True, device='cpu', dtype=torch.float32):
    super(ChebychevModulator, self).__init__()

    # P_1 = kind * t is only the correct seed for T (kind 1) and U (kind 2).
    if kind not in (1, 2):
        raise ValueError(f"kind ({kind}) must be 1 or 2.")

    self.degree = degree
    self.zero_order = zero_order
    # Number of basis columns produced (the constant column is dropped when zero_order is falsy).
    self.num_modulators = degree + int(zero_order)

    self.device, self.dtype = device, dtype

    self.window_len = window_len
    self.scale = scale
    self.kind = kind
    self.functions = self.generate_basis_functions()

  def generate_basis_functions(self):
    '''
    Generate the Chebychev basis functions.

    Returns:
        torch.Tensor: Shape (window_len, num_modulators); column q holds P_q(t)
        (shifted by one when zero_order is falsy). Also stored in ``self.functions``.
    '''
    t = torch.arange(0, self.window_len).view(-1, 1).to(device=self.device, dtype=self.dtype)
    span = t.max() - t.min()
    # With window_len == 1 the span is 0; keep t = 0 instead of producing 0/0 = NaN.
    if self.scale and span > 0:
        t = t / span

    N = len(t)

    y = torch.zeros((N, (self.degree + 1))).to(device=self.device, dtype=self.dtype)

    for q in range(0, (self.degree + 1)):
        if q == 0:
            y[:, 0] = torch.ones((N,)).to(device=self.device, dtype=self.dtype)
        elif q == 1:
            y[:, 1:2] = self.kind * t * y[:, 0:1]
        else:
            y[:, q:(q + 1)] = 2 * t * y[:, (q - 1):q] - y[:, (q - 2):(q - 1)]

    if not self.zero_order:
        y = y[:, 1:]

    self.functions = y

    return y

  def forward(self, X, steps):
    '''
    Apply the Chebychev modulation to the input.

    Args:
        X (torch.Tensor): Input tensor, indexed as ``X[:, :, None, :]`` (so it must be 3-D).
        steps (int or index tensor): Row(s) of the basis to apply.

    Returns:
        torch.Tensor: Broadcast product ``X[:, :, None, :] * functions[steps]``.
    '''
    X = X.to(device=self.device, dtype=self.dtype)

    y = X[:, :, None, :] * self.functions[steps]

    return y

class FourierModulator(torch.nn.Module):
  '''
  Fourier modulation function.

  Basis k is ``sin(2 * pi * freq_k * t + phase_k)`` evaluated at t = dt * (0..window_len-1).

  Args:
    window_len (int): Length of the input window.
    num_freqs (int): Number of frequencies.
    dt (float): Time step.
    freq_init (torch.Tensor or None): Initial frequencies. If None, every frequency is
        1 / (4 * dt), i.e. half the Nyquist frequency 1 / (2 * dt).
    freq_train (bool): Whether to train the frequencies.
    phase_init (torch.Tensor or None): Initial phases. If None, phases are initialized as zeros.
    phase_train (bool): Whether to train the phases.
    device (str): Device to use for computation ('cpu' or 'cuda').
    dtype (torch.dtype): Data type of the model parameters.
  '''

  def __init__(self,
               window_len, num_freqs, dt=1, freq_init=None, freq_train=True, phase_init=None,
               phase_train=True, device='cpu', dtype=torch.float32):
    super(FourierModulator, self).__init__()

    if freq_init is None:
        freq_init = ((1 / dt) / 4) * torch.ones(size=(1, num_freqs))

    if phase_init is None:
        phase_init = torch.zeros(size=(1, num_freqs))

    freq = torch.nn.Parameter(data=freq_init.to(device=device, dtype=dtype), requires_grad=freq_train)
    phase = torch.nn.Parameter(data=phase_init.to(device=device, dtype=dtype), requires_grad=phase_train)

    self.dt = dt
    self.window_len = window_len
    self.freq, self.phase = freq, phase
    self.num_modulators = num_freqs
    self.device, self.dtype = device, dtype

    # Populates self.functions.
    self.generate_basis_functions()

  def generate_basis_functions(self):
    '''
    Generate the Fourier basis functions from the current freq/phase parameters.

    Returns:
        torch.Tensor: Basis of shape (window_len, num_freqs) for the default
        (1, num_freqs) parameters. Also stored in ``self.functions``.
    '''
    t = self.dt * torch.arange(0, self.window_len).view(-1, 1).to(device=self.device, dtype=self.dtype)

    y = torch.sin(2 * torch.pi * t * self.freq + self.phase)

    self.functions = y

    return y

  def forward(self, X, steps):
    '''
    Apply the Fourier modulation to the input.

    The basis is regenerated first so that trained freq/phase values are used.
    ``self.functions`` keeps holding the basis, not the modulated output.

    Args:
        X (torch.Tensor): Input tensor, indexed as ``X[:, :, None, :]`` (so it must be 3-D).
        steps (int or index tensor): Row(s) of the basis to apply.

    Returns:
        torch.Tensor: Broadcast product ``X[:, :, None, :] * basis[steps]``.
    '''
    X = X.to(device=self.device, dtype=self.dtype)

    basis = self.generate_basis_functions()

    return X[:, :, None, :] * basis[steps]

class SigmoidModulator(torch.nn.Module):
  '''
  Sigmoid modulation function.

  Basis k is ``1 / (1 + exp(-slope_k * (t - shift_k * span)))`` on t = 0..window_len-1,
  where ``span = t.max() - t.min()`` when ``scale`` is True and 1 otherwise.

  Args:
      window_len (int): Length of the input window.
      num_sigmoids (int): Number of sigmoid functions.
      scale (bool): If True, ``shift`` is expressed as a fraction of the window span
          (it is multiplied by ``t.max() - t.min()``); otherwise it is in raw time steps.
      slope_init (torch.Tensor or None): Initial slopes. If None, slopes are drawn from N(0, (1 / window_len)^2).
      slope_train (bool): Whether to train the slopes.
      shift_init (torch.Tensor or None): Initial shifts. If None, shifts are drawn from U(-1, 1).
      shift_train (bool): Whether to train the shifts.
      device (str): Device to use for computation ('cpu' or 'cuda').
      dtype (torch.dtype): Data type of the model parameters.
  '''

  def __init__(self, window_len, num_sigmoids, scale=True, slope_init=None, slope_train=True,
                shift_init=None, shift_train=True, device='cpu', dtype=torch.float32):
    super(SigmoidModulator, self).__init__()

    if slope_init is None:
        slope_init = torch.nn.init.normal_(torch.empty((1, num_sigmoids)), mean=0, std=1 / window_len)
    slope = torch.nn.Parameter(data=slope_init.to(device=device, dtype=dtype), requires_grad=slope_train)

    if shift_init is None:
        shift_init = torch.nn.init.uniform_(torch.empty((1, num_sigmoids)), a=-1, b=1)
    shift = torch.nn.Parameter(data=shift_init.to(device=device, dtype=dtype), requires_grad=shift_train)

    self.window_len = window_len
    self.num_modulators = num_sigmoids
    self.scale = scale
    self.slope, self.shift = slope, shift
    self.device, self.dtype = device, dtype

    self.functions = self.generate_basis_functions()

  def generate_basis_functions(self):
    '''
    Generate the sigmoid basis functions from the current slope/shift parameters.

    Returns:
        torch.Tensor: Basis of shape (window_len, num_sigmoids) for the default
        (1, num_sigmoids) parameters. Also stored in ``self.functions``.
    '''
    t = torch.arange(0, self.window_len).view(-1, 1).to(device=self.device, dtype=self.dtype)

    scaler = (t.max() - t.min()) if self.scale else 1

    y = 1 / (1 + torch.exp(-self.slope * (t - self.shift * scaler)))

    self.functions = y

    return y

  def forward(self, X, steps):
    '''
    Apply the sigmoid modulation to the input.

    Args:
        X (torch.Tensor): Input tensor, indexed as ``X[:, :, None, :]`` (so it must be 3-D).
        steps (int or index tensor): Row(s) of the basis to apply.

    Returns:
        torch.Tensor: Broadcast product ``X[:, :, None, :] * basis[steps]``.
    '''
    X = X.to(device=self.device, dtype=self.dtype)

    # Regenerate so trained slope/shift values are used.
    basis = self.generate_basis_functions()

    return X[:, :, None, :] * basis[steps]


In [None]:
class Attention(torch.nn.MultiheadAttention):
  '''
  Custom attention layer based on the torch.nn.MultiheadAttention module.
  This layer supports different types of attention mechanisms: dot, general, and concat.

  Each head has its own query/key/value HiddenLayer projection; per-head outputs
  are concatenated. Head widths are ``round(embed_dim / num_heads)`` for all but the
  last head, which takes the remainder. NOTE(review): the parent
  MultiheadAttention's own projection parameters are created but not used by forward().

  Args:
      embed_dim (int): The input embedding dimension.
      num_heads (int): The number of attention heads.
      query_dim (int, optional): The query embedding dimension. Defaults to None (same as embed_dim).
      key_dim (int, optional): The key embedding dimension. Defaults to None (same as embed_dim).
      value_dim (int, optional): The value embedding dimension. Defaults to None (same as embed_dim).
      attn_type (str, optional): The attention type. Options: 'dot', 'general', 'concat'. Defaults to 'dot'.
      query_weight_reg (List[float], optional): The regularization weights for the query projection layer. Defaults to [0.001, 1].
      query_weight_norm (float, optional): The normalization type for the query projection layer. Defaults to 2.
      query_bias (bool, optional): Whether to include bias in the query projection layer. Defaults to False.
      key_weight_reg (List[float], optional): The regularization weights for the key projection layer. Defaults to [0.001, 1].
      key_weight_norm (float, optional): The normalization type for the key projection layer. Defaults to 2.
      key_bias (bool, optional): Whether to include bias in the key projection layer. Defaults to False.
      value_weight_reg (List[float], optional): The regularization weights for the value projection layer. Defaults to [0.001, 1].
      value_weight_norm (float, optional): The normalization type for the value projection layer. Defaults to 2.
      value_bias (bool, optional): Whether to include bias in the value projection layer. Defaults to False.
      gen_weight_reg (List[float], optional): The regularization weights for the generation weights (general type). Defaults to [0.001, 1].
      gen_weight_norm (float, optional): The normalization type for the generation weights (general type). Defaults to 2.
      gen_bias (bool, optional): Whether to include bias in the generation weights (general type). Defaults to False.
      concat_weight_reg (List[float], optional): The regularization weights for the concatenation layer (concat type). Defaults to [0.001, 1].
      concat_weight_norm (float, optional): The normalization type for the concatenation layer (concat type). Defaults to 2.
      concat_bias (bool, optional): Whether to include bias in the concatenation layer (concat type). Defaults to False.
      average_attn_weights (bool, optional): Whether to average the stored attention weights across heads. Defaults to False.
      is_causal (bool, optional): Whether the attention is causal (supports autoregressive property). Defaults to False.
      dropout_p (float, optional): The dropout probability applied to the output. Defaults to 0.0.
      device (str, optional): The device for the computation. Defaults to 'cpu'.
      dtype (torch.dtype, optional): The data type. Defaults to torch.float32.

  Raises:
      ValueError: If ``attn_type`` is not 'dot', 'general', or 'concat'.
  '''

  def __init__(self,
               embed_dim, num_heads=1,
               query_dim=None, key_dim=None, value_dim=None,
               attn_type="dot",
               query_weight_reg=[0.001, 1], query_weight_norm=2, query_bias=False,
               key_weight_reg=[0.001, 1], key_weight_norm=2, key_bias=False,
               value_weight_reg=[0.001, 1], value_weight_norm=2, value_bias=False,
               gen_weight_reg=[0.001, 1], gen_weight_norm=2, gen_bias=False,
               concat_weight_reg=[0.001, 1], concat_weight_norm=2, concat_bias=False,
               average_attn_weights=False,
               is_causal=False,
               dropout_p=0.0,
               device="cpu",
               dtype=torch.float32):

    super(Attention, self).__init__(embed_dim=embed_dim, num_heads=num_heads)

    # Choose the score function for the attention type; fail fast on unknown types
    # instead of leaving score_fn undefined until the first forward().
    score_fns = {"dot": self.dot_fn, "general": self.general_fn, "concat": self.concat_fn}
    if attn_type not in score_fns:
      raise ValueError(f"attn_type ({attn_type}) must be 'dot', 'general', or 'concat'.")
    self.score_fn = score_fns[attn_type]

    query_dim = query_dim or embed_dim
    key_dim = key_dim or embed_dim
    value_dim = value_dim or embed_dim

    query_blocks = torch.nn.ModuleList([])
    key_blocks = torch.nn.ModuleList([])
    value_blocks = torch.nn.ModuleList([])
    gen_blocks = torch.nn.ModuleList([])
    concat_blocks = torch.nn.ModuleList([])

    # All heads but the last get round(embed_dim / num_heads) features; the last takes the remainder.
    head_dims = np.round(embed_dim / num_heads).astype(int).repeat(num_heads - 1).tolist()
    head_dims += [int(embed_dim - np.sum(head_dims))]

    for dim in head_dims:
      query_blocks.append(HiddenLayer(in_features=embed_dim,
                                      out_features=dim,
                                      bias=query_bias,
                                      activation="identity",
                                      weight_reg=query_weight_reg,
                                      weight_norm=query_weight_norm,
                                      device=device,
                                      dtype=dtype))
      key_blocks.append(HiddenLayer(in_features=embed_dim,
                                    out_features=dim,
                                    bias=key_bias,
                                    activation="identity",
                                    weight_reg=key_weight_reg,
                                    weight_norm=key_weight_norm,
                                    device=device,
                                    dtype=dtype))
      value_blocks.append(HiddenLayer(in_features=embed_dim,
                                      out_features=dim,
                                      bias=value_bias,
                                      activation="identity",
                                      weight_reg=value_weight_reg,
                                      weight_norm=value_weight_norm,
                                      device=device,
                                      dtype=dtype))

      if attn_type == "general":
        gen_blocks.append(
            HiddenLayer(in_features=[dim, dim],
                        out_features=1,
                        bias=gen_bias,
                        activation="identity",
                        weight_reg=gen_weight_reg,
                        weight_norm=gen_weight_norm,
                        device=device,
                        dtype=dtype))

      if attn_type == "concat":
        concat_blocks.append(torch.nn.Sequential(*[HiddenLayer(in_features=2 * dim,
                                                                out_features=dim,
                                                                bias=concat_bias,
                                                                activation="tanh",
                                                                weight_reg=concat_weight_reg,
                                                                weight_norm=concat_weight_norm,
                                                                device=device,
                                                                dtype=dtype),
                                                    HiddenLayer(in_features=dim,
                                                                out_features=1,
                                                                bias=concat_bias,
                                                                activation="identity",
                                                                weight_reg=concat_weight_reg,
                                                                weight_norm=concat_weight_norm,
                                                                device=device,
                                                                dtype=dtype)]))

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.attn_type = attn_type
    self.is_causal = is_causal

    self.query_blocks = query_blocks
    self.key_blocks = key_blocks
    self.value_blocks = value_blocks

    self.dropout = torch.nn.Dropout(dropout_p)

    self.query_weight_reg = query_weight_reg
    self.weight_norm = query_weight_norm
    self.key_weight_reg = key_weight_reg
    self.key_norm = key_weight_norm
    self.value_weight_reg = value_weight_reg
    self.value_norm = value_weight_norm

    self.gen_blocks = gen_blocks
    self.concat_blocks = concat_blocks

    self.gen_weight_reg = gen_weight_reg
    self.gen_norm = gen_weight_norm
    self.concat_weight_reg = concat_weight_reg
    self.concat_norm = concat_weight_norm

    self.average_attn_weights = average_attn_weights

    self.device = device
    self.dtype = dtype

  def dot_fn(self, query, key, block_idx):
    '''
    Compute the scaled dot-product attention score between query and key.

    Args:
        query (torch.Tensor): The query tensor of shape (num_samples, query_len, head_dim).
        key (torch.Tensor): The key tensor of shape (num_samples, key_len, head_dim).
        block_idx (int): The index of the attention block (unused for dot scores).

    Returns:
        torch.Tensor: Scores of shape (num_samples, key_len, query_len); forward()
        normalizes them over dim 1 (keys).
    '''
    score = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
    return score.transpose(-1, -2)

  def general_fn(self, query, key, block_idx):
    '''
    Compute the general (bilinear-style) attention score between query and key.

    A single-step query (query_len == 1) is repeated across all key positions so
    each key is scored against it; otherwise query and key are paired position-wise.

    Args:
        query (torch.Tensor): The query tensor of shape (num_samples, query_len, head_dim).
        key (torch.Tensor): The key tensor of shape (num_samples, key_len, head_dim).
        block_idx (int): The index of the attention block.

    Returns:
        torch.Tensor: One score per position from ``gen_blocks[block_idx]`` (out_features=1).
    '''
    if query.shape[1] == 1:
        query = query.repeat(1, key.shape[1], 1)

    return self.gen_blocks[block_idx]((query, key))

  def concat_fn(self, query, key, block_idx):
    '''
    Compute the concat (additive) attention score between query and key.

    A single-step query (query_len == 1) is repeated across all key positions;
    the concatenated [query, key] features go through tanh and a width-1 projection.

    Args:
        query (torch.Tensor): The query tensor of shape (num_samples, query_len, head_dim).
        key (torch.Tensor): The key tensor of shape (num_samples, key_len, head_dim).
        block_idx (int): The index of the attention block.

    Returns:
        torch.Tensor: One score per position from ``concat_blocks[block_idx]``.
    '''
    if query.shape[1] == 1:
        query = query.repeat(1, key.shape[1], 1)

    return self.concat_blocks[block_idx](torch.cat((query, key), -1))

  def forward(self, query, key, value, attn_mask=None):
    '''
    Perform the forward pass of the attention layer.

    Args:
        query (torch.Tensor): The query tensor of shape (num_samples, query_len, embed_dim).
        key (torch.Tensor): The key tensor of shape (num_samples, key_len, embed_dim).
        value (torch.Tensor): The value tensor of shape (num_samples, key_len, embed_dim).
        attn_mask (torch.Tensor, optional): NOTE(review): currently ignored; the mask
            is built from ``is_causal`` alone. Kept for signature compatibility.

    Returns:
        torch.Tensor: Output of shape (num_samples, query_len, embed_dim) for dot scores.
        The attention weights are stored in ``self.weight``: averaged over heads when
        ``average_attn_weights`` is set, otherwise head blocks concatenated along dim 1.
    '''
    query_len, key_len = query.shape[1], key.shape[1]

    # "May attend" pattern, built as (query_len, key_len) and transposed into the
    # (key_len, query_len) layout of the scores, so it also broadcasts correctly
    # when query_len != key_len (e.g. a single decoding step).
    allowed = torch.ones((query_len, key_len), device=query.device, dtype=torch.bool)
    if self.is_causal:
      allowed = allowed.tril(diagonal=0)  # query i may attend to keys j <= i
    allowed = allowed.transpose(-2, -1)
    # Additive mask: 0 where allowed, -inf where blocked.
    mask = torch.zeros(allowed.shape, device=query.device, dtype=query.dtype).masked_fill(~allowed, -float('inf'))

    outputs, weights = [], []
    for block_idx, (query_block, key_block, value_block) in enumerate(zip(self.query_blocks, self.key_blocks, self.value_blocks)):

      query_h, key_h, value_h = query_block(query), key_block(key), value_block(value)

      score_h = self.score_fn(query_h, key_h, block_idx)

      # Normalize over keys (dim 1 of the (num_samples, key_len, query_len) layout).
      weight_h = torch.softmax(score_h + mask, dim=1)

      outputs.append(torch.bmm(weight_h.transpose(-2, -1), value_h))
      weights.append(weight_h)

    output = self.dropout(torch.cat(outputs, -1))

    # Stack heads on a dedicated axis so averaging is over heads only; averaging the
    # concatenated (head x key) axis would collapse every entry to 1 / key_len.
    stacked = torch.stack(weights, 1)
    # flatten(1, 2) reproduces the head-block concatenation along dim 1.
    self.weight = stacked.mean(1) if self.average_attn_weights else stacked.flatten(1, 2)

    return output

  def penalize(self):
    '''
    Compute the regularization loss for the attention layer.

    Each trainable parameter whose name contains 'weight' is penalized with
    ``reg[0] * ||param||_{reg[1]}``, where ``reg`` is chosen by the first of
    'query', 'key', 'value', 'gen', 'concat' found in its name. Other weights
    (e.g. the parent MultiheadAttention's projections) are skipped.

    Returns:
        torch.Tensor: The regularization loss (the int 0 if nothing matched).
    '''
    loss = 0
    for name, param in self.named_parameters():
      if 'weight' in name:
        if 'query' in name:
          loss += self.query_weight_reg[0] * torch.norm(param, p=self.query_weight_reg[1]) * int(param.requires_grad)
        elif 'key' in name:
          loss += self.key_weight_reg[0] * torch.norm(param, p=self.key_weight_reg[1]) * int(param.requires_grad)
        elif 'value' in name:
          loss += self.value_weight_reg[0] * torch.norm(param, p=self.value_weight_reg[1]) * int(param.requires_grad)
        elif 'gen' in name:
          loss += self.gen_weight_reg[0] * torch.norm(param, p=self.gen_weight_reg[1]) * int(param.requires_grad)
        elif 'concat' in name:
          loss += self.concat_weight_reg[0] * torch.norm(param, p=self.concat_weight_reg[1]) * int(param.requires_grad)

    return loss


In [None]:
class TransformerEncoderLayer(torch.nn.TransformerEncoderLayer):
    '''
    Customized Transformer Encoder Layer built on ``torch.nn.TransformerEncoderLayer``.

    The stock self-attention is replaced by this file's ``Attention`` module, the
    feedforward activation is selectable, and the two residual connections can
    optionally be scaled per feature by a learnable vector.

    Args:
      d_model (int): The input and output feature dimension.
      nhead (int, optional): The number of attention heads. Defaults to 1.
      dim_feedforward (int, optional): The hidden dimension of the feedforward network. Defaults to 2048.
      self_attn_type (str, optional): The type of self-attention, forwarded to ``Attention`` as ``attn_type``. Defaults to 'dot'.
      is_causal (bool, optional): Whether the self-attention is causal (fixed at construction). Defaults to False.
      query_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the query weights. Defaults to [0.001, 1].
      query_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
      query_bias (bool, optional): Whether to use bias in the query weights. Defaults to False.
      key_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the key weights. Defaults to [0.001, 1].
      key_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
      key_bias (bool, optional): Whether to use bias in the key weights. Defaults to False.
      value_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the value weights. Defaults to [0.001, 1].
      value_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
      value_bias (bool, optional): Whether to use bias in the value weights. Defaults to False.
      gen_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the generator weights. Defaults to [0.001, 1].
      gen_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
      gen_bias (bool, optional): Whether to use bias in the generator weights. Defaults to False.
      concat_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the concatenator weights. Defaults to [0.001, 1].
      concat_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
      concat_bias (bool, optional): Whether to use bias in the concatenator weights. Defaults to False.
      average_attn_weights (bool, optional): Whether to average the attention weights over heads. Defaults to False.
      dropout_p (float, optional): Dropout probability of the attention module (also stored on ``self.dropout``). Defaults to 0.0.
      dropout1_p (float, optional): Dropout applied to the activation output inside the feedforward network. Defaults to 0.0.
      dropout2_p (float, optional): Dropout applied to the feedforward output before the residual add. Defaults to 0.0.
      linear1_bias (bool, optional): Whether to use bias in the first linear layer of the feedforward network. Defaults to False.
      linear2_bias (bool, optional): Whether to use bias in the second linear layer of the feedforward network. Defaults to False.
      linear1_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the first linear layer weights. Defaults to [0.001, 1].
      linear1_weight_norm (int, optional): Stored as an attribute; ``penalize()`` uses ``linear1_weight_reg[1]`` as the norm order. Defaults to 2.
      linear2_weight_reg (list[float], optional): Regularization ``[coefficient, p]`` for the second linear layer weights. Defaults to [0.001, 1].
      linear2_weight_norm (int, optional): Stored as an attribute; ``penalize()`` uses ``linear2_weight_reg[1]`` as the norm order. Defaults to 2.
      feedforward_activation (str, optional): 'identity', 'relu', 'gelu' or 'polynomial'. Any other value keeps the
        base class's default activation. With 'identity', ``linear2``, ``norm2`` and ``dropout2`` are removed, so the
        ``linear1`` output (``dim_feedforward`` features) is added directly to the residual; this only lines up when
        ``dim_feedforward == d_model``. Defaults to 'relu'.
      degree (int, optional): Degree of the polynomial activation ('polynomial' only). Defaults to 2.
      coef_init (torch.Tensor, optional): Initial polynomial coefficients ('polynomial' only). Defaults to None.
      coef_train (bool, optional): Whether the polynomial coefficients are trainable ('polynomial' only). Defaults to True.
      coef_reg (list[float], optional): Regularization for the polynomial coefficients ('polynomial' only). Defaults to [0.001, 1.].
      zero_order (bool, optional): Whether to include the zero-order polynomial term ('polynomial' only). Defaults to False.
      scale_self_attn_residual_connection (bool, optional): If True, the self-attention residual is multiplied by a
        learnable per-feature vector (a registered ``Parameter``); otherwise by a constant vector of ones. Defaults to False.
      scale_feedforward_residual_connection (bool, optional): Same as above for the feedforward residual. Defaults to False.
      device (str, optional): The device to run the layer on. Defaults to 'cpu'.
      dtype (torch.dtype, optional): The data type. Defaults to torch.float32.
    '''

    def __init__(self,
                d_model, nhead=1, dim_feedforward=2048,
                self_attn_type='dot',
                is_causal=False,
                query_weight_reg=[0.001, 1], query_weight_norm=2, query_bias=False,
                key_weight_reg=[0.001, 1], key_weight_norm=2, key_bias=False,
                value_weight_reg=[0.001, 1], value_weight_norm=2, value_bias=False,
                gen_weight_reg=[0.001, 1], gen_weight_norm=2, gen_bias=False,
                concat_weight_reg=[0.001, 1], concat_weight_norm=2, concat_bias=False,
                average_attn_weights=False,
                dropout_p=0.0, dropout1_p=0.0, dropout2_p=0.0,
                linear1_bias=False, linear2_bias=False,
                linear1_weight_reg=[0.001, 1], linear1_weight_norm=2,
                linear2_weight_reg=[0.001, 1], linear2_weight_norm=2,
                feedforward_activation='relu',
                degree=2,
                coef_init=None, coef_train=True, coef_reg=[0.001, 1.],
                zero_order=False,
                scale_self_attn_residual_connection=False,
                scale_feedforward_residual_connection=False,
                device='cpu', dtype=torch.float32):

        super(TransformerEncoderLayer, self).__init__(d_model=d_model,
                                                      nhead=nhead,
                                                      dim_feedforward=dim_feedforward,
                                                      device=device,
                                                      dtype=dtype)

        self.dropout.p = dropout_p

        # Replace the stock MultiheadAttention with the custom Attention module.
        self.self_attn = Attention(embed_dim=d_model,
                                   num_heads=nhead,
                                   attn_type=self_attn_type,
                                   query_weight_reg=query_weight_reg,
                                   query_weight_norm=query_weight_norm,
                                   query_bias=query_bias,
                                   key_weight_reg=key_weight_reg,
                                   key_weight_norm=key_weight_norm,
                                   key_bias=key_bias,
                                   value_weight_reg=value_weight_reg,
                                   value_weight_norm=value_weight_norm,
                                   value_bias=value_bias,
                                   gen_weight_reg=gen_weight_reg,
                                   gen_weight_norm=gen_weight_norm,
                                   gen_bias=gen_bias,
                                   concat_weight_reg=concat_weight_reg,
                                   concat_weight_norm=concat_weight_norm,
                                   concat_bias=concat_bias,
                                   average_attn_weights=average_attn_weights,
                                   is_causal=is_causal,
                                   dropout_p=dropout_p,
                                   device=device,
                                   dtype=dtype)

        self.dropout1.p = dropout1_p
        self.dropout2.p = dropout2_p

        if feedforward_activation == 'identity':
            # Collapse the feedforward block to a single linear map (see class docstring).
            self.activation = torch.nn.Identity()
            self.linear2 = torch.nn.Identity()
            self.norm2 = torch.nn.Identity()
            self.dropout2 = torch.nn.Identity()
        elif feedforward_activation == 'relu':
            self.activation = torch.nn.ReLU()
        elif feedforward_activation == 'gelu':
            self.activation = torch.nn.GELU()
        elif feedforward_activation == 'polynomial':
            self.activation = Polynomial(in_features=dim_feedforward,
                                         degree=degree,
                                         coef_init=coef_init,
                                         coef_train=coef_train,
                                         coef_reg=coef_reg,
                                         zero_order=zero_order,
                                         device=device,
                                         dtype=dtype)

        if not linear1_bias:
            self.linear1.bias = None

        self.linear1_weight_reg = linear1_weight_reg
        self.linear1_weight_norm = linear1_weight_norm

        if not isinstance(self.linear2, torch.nn.Identity):
            if not linear2_bias:
                self.linear2.bias = None
            self.linear2_weight_reg = linear2_weight_reg
            self.linear2_weight_norm = linear2_weight_norm

        # Registered on the module so they train (if learnable) and follow .to()/.cuda().
        self._init_residual_scaler('self_attn_residual_scaler',
                                   scale_self_attn_residual_connection,
                                   d_model, device, dtype)
        self._init_residual_scaler('feedforward_residual_scaler',
                                   scale_feedforward_residual_connection,
                                   d_model, device, dtype)

    def _init_residual_scaler(self, name, learnable, d_model, device, dtype):
        '''
        Register the per-feature residual scaler ``name`` on this module.

        If ``learnable``, it is a ``torch.nn.Parameter`` initialised from the
        weight of a throwaway ``Linear(d_model, 1)``, so it is seen by optimizers
        and saved in the state_dict. Otherwise it is a constant vector of ones
        registered as a non-persistent buffer: it follows device moves but is
        not added to the state_dict.
        '''
        if learnable:
            init = torch.nn.Linear(in_features=d_model, out_features=1).weight.detach().squeeze().clone()
            setattr(self, name, torch.nn.Parameter(init.to(device=device, dtype=dtype)))
        else:
            self.register_buffer(name,
                                 torch.ones((d_model,)).to(device=device, dtype=dtype),
                                 persistent=False)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        '''
        Forward pass of the transformer encoder layer (post-norm).

        Args:
            src (torch.Tensor): Input sequence. The custom ``Attention`` treats dim 0 as the
                batch axis and dim 1 as the sequence axis, i.e. ``(batch_size, seq_len, d_model)``.
            src_mask (torch.Tensor, optional): Passed as the 4th positional argument to ``self.self_attn``. Defaults to None.
            src_key_padding_mask (torch.Tensor, optional): Accepted for API compatibility with
                ``torch.nn.TransformerEncoder``; currently ignored. Defaults to None.
            is_causal (bool, optional): Accepted for API compatibility; ignored. Causality is
                fixed by the ``is_causal`` constructor argument. Defaults to False.

        Returns:
            torch.Tensor: Output with the same layout as ``src``.
        '''
        # Self-attention (dropout applied inside) + optionally scaled residual, then norm.
        src = self.self_attn(src, src, src, src_mask) + self.self_attn_residual_scaler * src
        src = self.norm1(src)

        # Feedforward + optionally scaled residual, then norm.
        ff = self.dropout2(self.linear2(self.dropout1(self.activation(self.linear1(src)))))
        src = ff + self.feedforward_residual_scaler * src

        return self.norm2(src)

    def penalize(self):
        '''
        Calculate the regularization loss for the transformer encoder layer.

        Sums the attention module's own penalty plus ``reg[0] * ||W||_{reg[1]}`` for
        ``linear1`` and, when it is a real layer, ``linear2``. Frozen weights contribute 0.

        Returns:
            torch.Tensor: The regularization loss.
        '''
        loss = 0
        if self.self_attn is not None:
            loss += self.self_attn.penalize()
        loss += (self.linear1_weight_reg[0]
                 * torch.norm(self.linear1.weight, p=self.linear1_weight_reg[1])
                 * int(self.linear1.weight.requires_grad))
        # With feedforward_activation='identity', linear2 is an Identity with no weight
        # and linear2_weight_reg is never set, so there is nothing to penalize.
        if not isinstance(self.linear2, torch.nn.Identity):
            loss += (self.linear2_weight_reg[0]
                     * torch.norm(self.linear2.weight, p=self.linear2_weight_reg[1])
                     * int(self.linear2.weight.requires_grad))

        return loss

In [None]:
class TransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
  '''
  Transformer Decoder Layer module that extends `torch.nn.TransformerDecoderLayer`.

  The stock self-attention and cross-attention are replaced by this file's ``Attention``
  module, the feedforward activation is selectable, and the three residual connections can
  optionally be scaled per feature by a learnable vector.

  Args:
    d_model (int): The number of expected features in the input.
    nhead (int, optional): The number of heads in the attention modules. Defaults to 1.
    dim_feedforward (int, optional): The dimension of the feedforward network model. Defaults to 2048.
    self_attn_type (str, optional): The self-attention type. Defaults to 'dot'.
    multihead_attn_type (str, optional): The cross-attention (target -> memory) type. Defaults to 'dot'.
    memory_is_causal (bool, optional): Whether the cross-attention over the memory is causal. Defaults to False.
    tgt_is_causal (bool, optional): Whether the self-attention over the target is causal. Defaults to True.
    query_weight_reg (list, optional): Regularization ``[coefficient, p]`` for query weights. Defaults to [0.001, 1].
    query_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
    query_bias (bool, optional): Whether to include bias in query weight. Defaults to False.
    key_weight_reg (list, optional): Regularization ``[coefficient, p]`` for key weights. Defaults to [0.001, 1].
    key_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
    key_bias (bool, optional): Whether to include bias in key weight. Defaults to False.
    value_weight_reg (list, optional): Regularization ``[coefficient, p]`` for value weights. Defaults to [0.001, 1].
    value_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
    value_bias (bool, optional): Whether to include bias in value weight. Defaults to False.
    gen_weight_reg (list, optional): Regularization ``[coefficient, p]`` for generation weights. Defaults to [0.001, 1].
    gen_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
    gen_bias (bool, optional): Whether to include bias in generation weight. Defaults to False.
    concat_weight_reg (list, optional): Regularization ``[coefficient, p]`` for concatenation weights. Defaults to [0.001, 1].
    concat_weight_norm (int, optional): Forwarded to ``Attention``. Defaults to 2.
    concat_bias (bool, optional): Whether to include bias in concatenation weight. Defaults to False.
    average_attn_weights (bool, optional): Whether to average the attention weights over heads. Defaults to False.
    dropout_p (float, optional): Dropout probability of the self-attention module (also stored on ``self.dropout``). Defaults to 0.
    dropout1_p (float, optional): Dropout probability of the cross-attention module. Defaults to 0.
    dropout2_p (float, optional): Dropout applied to the activation output inside the feedforward network. Defaults to 0.
    dropout3_p (float, optional): Dropout applied to the feedforward output before the residual add. Defaults to 0.
    linear1_bias (bool, optional): Whether to include bias in the first linear layer. Defaults to False.
    linear2_bias (bool, optional): Whether to include bias in the second linear layer. Defaults to False.
    linear1_weight_reg (list, optional): Regularization ``[coefficient, p]`` for the first linear layer weight. Defaults to [0.001, 1].
    linear1_weight_norm (int, optional): Stored as an attribute; ``penalize()`` uses ``linear1_weight_reg[1]`` as the norm order. Defaults to 2.
    linear2_weight_reg (list, optional): Regularization ``[coefficient, p]`` for the second linear layer weight. Defaults to [0.001, 1].
    linear2_weight_norm (int, optional): Stored as an attribute; ``penalize()`` uses ``linear2_weight_reg[1]`` as the norm order. Defaults to 2.
    feedforward_activation (str, optional): 'identity', 'relu', 'gelu' or 'polynomial'. Any other value keeps the
      base class's default activation. With 'identity', ``linear2``, ``norm3`` and ``dropout3`` are removed, so the
      ``linear1`` output (``dim_feedforward`` features) is added directly to the residual; this only lines up when
      ``dim_feedforward == d_model``. Defaults to 'relu'.
    degree (int, optional): Degree of the polynomial activation function. Defaults to 2.
    coef_init (torch.Tensor, optional): Initial coefficients for the polynomial activation function. Defaults to None.
    coef_train (bool, optional): Whether to train the coefficients of the polynomial activation function. Defaults to True.
    coef_reg (list, optional): Regularization parameters for the polynomial activation function coefficients. Defaults to [0.001, 1.].
    zero_order (bool, optional): Whether to include the zero-order term in the polynomial activation function. Defaults to False.
    scale_self_attn_residual_connection (bool, optional): If True, the self-attention residual is multiplied by a
      learnable per-feature vector (a registered ``Parameter``); otherwise by a constant vector of ones. Defaults to False.
    scale_cross_attn_residual_connection (bool, optional): Same as above for the cross-attention residual. Defaults to False.
    scale_feedforward_residual_connection (bool, optional): Same as above for the feedforward residual. Defaults to False.
    device (str, optional): Device on which to allocate tensors. Defaults to 'cpu'.
    dtype (torch.dtype, optional): Desired data type of the tensor. Defaults to torch.float32.
  '''

  def __init__(self,
               d_model, nhead=1, dim_feedforward=2048,
               self_attn_type="dot", multihead_attn_type="dot",
               memory_is_causal=False, tgt_is_causal=True,
               query_weight_reg=[0.001, 1], query_weight_norm=2, query_bias=False,
               key_weight_reg=[0.001, 1], key_weight_norm=2, key_bias=False,
               value_weight_reg=[0.001, 1], value_weight_norm=2, value_bias=False,
               gen_weight_reg=[0.001, 1], gen_weight_norm=2, gen_bias = False,
               concat_weight_reg=[0.001, 1], concat_weight_norm=2, concat_bias=False,
               average_attn_weights=False,
               dropout_p=0.0, dropout1_p=0.0, dropout2_p=0.0, dropout3_p=0.0,
               linear1_bias=False, linear2_bias=False,
               linear1_weight_reg=[0.001, 1], linear1_weight_norm=2,
               linear2_weight_reg=[0.001, 1], linear2_weight_norm=2,
               feedforward_activation="relu",
               degree=2,
               coef_init=None, coef_train=True, coef_reg=[0.001, 1.],
               zero_order=False,
               scale_self_attn_residual_connection=False,
               scale_cross_attn_residual_connection=False,
               scale_feedforward_residual_connection=False,
               device="cpu", dtype=torch.float32):

    super(TransformerDecoderLayer, self).__init__(d_model=d_model,
                                                  nhead=nhead,
                                                  dim_feedforward=dim_feedforward,
                                                  device=device,
                                                  dtype=dtype)

    self.dropout.p = dropout_p

    # Self-attention runs over the target sequence, so its causality is tgt_is_causal
    # (previously memory_is_causal was passed here, letting the target see future steps
    # under the defaults).
    self.self_attn = Attention(embed_dim=d_model,
                               num_heads=nhead,
                               attn_type=self_attn_type,
                               query_weight_reg=query_weight_reg,
                               query_weight_norm=query_weight_norm,
                               query_bias=query_bias,
                               key_weight_reg=key_weight_reg,
                               key_weight_norm=key_weight_norm,
                               key_bias=key_bias,
                               value_weight_reg=value_weight_reg,
                               value_weight_norm=value_weight_norm,
                               value_bias=value_bias,
                               gen_weight_reg=gen_weight_reg,
                               gen_weight_norm=gen_weight_norm,
                               gen_bias=gen_bias,
                               concat_weight_reg=concat_weight_reg,
                               concat_weight_norm=concat_weight_norm,
                               concat_bias=concat_bias,
                               average_attn_weights=average_attn_weights,
                               is_causal=tgt_is_causal,
                               dropout_p=dropout_p,
                               device=device,
                               dtype=dtype)

    # Cross-attention: target queries attend to the memory keys/values.
    self.multihead_attn = Attention(embed_dim=d_model,
                                    num_heads=nhead,
                                    attn_type=multihead_attn_type,
                                    query_weight_reg=query_weight_reg,
                                    query_weight_norm=query_weight_norm,
                                    query_bias=query_bias,
                                    key_weight_reg=key_weight_reg,
                                    key_weight_norm=key_weight_norm,
                                    key_bias=key_bias,
                                    value_weight_reg=value_weight_reg,
                                    value_weight_norm=value_weight_norm,
                                    value_bias=value_bias,
                                    gen_weight_reg=gen_weight_reg,
                                    gen_weight_norm=gen_weight_norm,
                                    gen_bias=gen_bias,
                                    concat_weight_reg=concat_weight_reg,
                                    concat_weight_norm=concat_weight_norm,
                                    concat_bias=concat_bias,
                                    average_attn_weights=average_attn_weights,
                                    is_causal=memory_is_causal,
                                    dropout_p=dropout1_p,
                                    device=device,
                                    dtype=dtype)

    self.dropout2.p = dropout2_p
    self.dropout3.p = dropout3_p

    if feedforward_activation == "identity":
      # Collapse the feedforward block to a single linear map (see class docstring).
      self.activation = torch.nn.Identity()
      self.linear2 = torch.nn.Identity()
      self.norm3 = torch.nn.Identity()
      self.dropout3 = torch.nn.Identity()
    elif feedforward_activation == "relu":
      self.activation = torch.nn.ReLU()
    elif feedforward_activation == "gelu":
      self.activation = torch.nn.GELU()
    elif feedforward_activation == "polynomial":
      self.activation = Polynomial(in_features=dim_feedforward,
                                   degree=degree,
                                   coef_init=coef_init,
                                   coef_train=coef_train,
                                   coef_reg=coef_reg,
                                   zero_order=zero_order,
                                   device=device,
                                   dtype=dtype)

    if not linear1_bias:
      self.linear1.bias = None
    self.linear1_weight_reg, self.linear1_weight_norm = linear1_weight_reg, linear1_weight_norm

    if not isinstance(self.linear2, torch.nn.Identity):
      if not linear2_bias:
        self.linear2.bias = None
      self.linear2_weight_reg, self.linear2_weight_norm = linear2_weight_reg, linear2_weight_norm

    # Registered on the module so they train (if learnable) and follow .to()/.cuda().
    self._init_residual_scaler('self_attn_residual_scaler',
                               scale_self_attn_residual_connection, d_model, device, dtype)
    self._init_residual_scaler('cross_attn_residual_scaler',
                               scale_cross_attn_residual_connection, d_model, device, dtype)
    self._init_residual_scaler('feedforward_residual_scaler',
                               scale_feedforward_residual_connection, d_model, device, dtype)

  def _init_residual_scaler(self, name, learnable, d_model, device, dtype):
    '''
    Register the per-feature residual scaler ``name`` on this module.

    If ``learnable``, it is a ``torch.nn.Parameter`` initialised from the weight
    of a throwaway ``Linear(d_model, 1)``, so it is seen by optimizers and saved
    in the state_dict. Otherwise it is a constant vector of ones registered as a
    non-persistent buffer: it follows device moves but is not added to the
    state_dict.
    '''
    if learnable:
      init = torch.nn.Linear(in_features=d_model, out_features=1).weight.detach().squeeze().clone()
      setattr(self, name, torch.nn.Parameter(init.to(device=device, dtype=dtype)))
    else:
      self.register_buffer(name,
                           torch.ones((d_model,)).to(device=device, dtype=dtype),
                           persistent=False)

  def forward(self,
              tgt, memory,
              tgt_mask=None, memory_mask=None,
              tgt_key_padding_mask=None, memory_key_padding_mask=None,
              tgt_is_causal=False, memory_is_causal=False):
    '''
    Forward pass of the Transformer Decoder Layer (post-norm).

    Args:
        tgt (torch.Tensor): Decoder input. The custom ``Attention`` treats dim 0 as the batch axis
            and dim 1 as the sequence axis, i.e. ``(batch_size, target_sequence_length, d_model)``.
        memory (torch.Tensor): Encoder output, ``(batch_size, input_sequence_length, d_model)``.
        tgt_mask (torch.Tensor, optional): Passed as the 4th positional argument to ``self.self_attn``. Defaults to None.
        memory_mask (torch.Tensor, optional): Passed as the 4th positional argument to ``self.multihead_attn``. Defaults to None.
        tgt_key_padding_mask (torch.Tensor, optional): Accepted for API compatibility; ignored. Defaults to None.
        memory_key_padding_mask (torch.Tensor, optional): Accepted for API compatibility; ignored. Defaults to None.
        tgt_is_causal (bool, optional): Accepted so the layer can be driven by ``torch.nn.TransformerDecoder``;
            ignored. Causality is fixed by the constructor arguments. Defaults to False.
        memory_is_causal (bool, optional): Same as ``tgt_is_causal``; ignored. Defaults to False.

    Returns:
        torch.Tensor: Output with the same layout as ``tgt``.
    '''
    # Self-attention over the target + optionally scaled residual, then norm.
    tgt = self.self_attn(tgt, tgt, tgt, tgt_mask) + self.self_attn_residual_scaler * tgt
    tgt = self.norm1(tgt)

    # Cross-attention to the memory + optionally scaled residual, then norm.
    tgt = self.multihead_attn(tgt, memory, memory, memory_mask) + self.cross_attn_residual_scaler * tgt
    tgt = self.norm2(tgt)

    # Feedforward + optionally scaled residual, then norm.
    ff = self.dropout3(self.linear2(self.dropout2(self.activation(self.linear1(tgt)))))
    tgt = ff + self.feedforward_residual_scaler * tgt

    return self.norm3(tgt)

  def penalize(self):
    '''
    Compute the regularization loss for the decoder layer.

    Sums both attention modules' penalties plus ``reg[0] * ||W||_{reg[1]}`` for
    ``linear1`` and, when it is a real layer, ``linear2``. Frozen weights contribute 0.

    Returns:
        torch.Tensor: The regularization loss.
    '''
    loss = 0
    loss += self.self_attn.penalize()
    loss += self.multihead_attn.penalize()
    loss += (self.linear1_weight_reg[0]
             * torch.norm(self.linear1.weight, p=self.linear1_weight_reg[1])
             * int(self.linear1.weight.requires_grad))
    # With feedforward_activation='identity', linear2 is an Identity with no weight
    # and linear2_weight_reg is never set, so there is nothing to penalize.
    if not isinstance(self.linear2, torch.nn.Identity):
      loss += (self.linear2_weight_reg[0]
               * torch.norm(self.linear2.weight, p=self.linear2_weight_reg[1])
               * int(self.linear2.weight.requires_grad))

    return loss

In [None]:
class SequenceModelBase(torch.nn.Module):
  '''
  Base class for sequence models.

  Wraps one sequence backbone ('identity', 'gru', 'lstm', 'lru', 'cnn' or
  'transformer') behind a common `forward(input, hiddens, encoder_output, mask)`
  interface. RNN-style backbones can optionally attend over an encoder output.

  Args:
    input_size (int): The number of expected features in the input.
    hidden_size (int): The number of features in the hidden state/output.
    seq_len (int, optional): The length of the input sequence (used by the transformer positional encoding). Default is None.
    base_type (str): The type of the base model. Options: 'identity', 'gru', 'lstm', 'lru', 'cnn', 'transformer'.
    num_layers (int, optional): Number of recurrent / transformer layers. Default is 1.
    encoder_bias (bool, optional): Whether to include a bias term in the encoder block. Default is False.
    decoder_bias (bool, optional): Whether to include a bias term in the decoder block. Default is False.
    rnn_bias (bool, optional): If False, then the layer does not use bias weights. Default is True.
    rnn_dropout_p (float, optional): Dropout probability for the base model. Default is 0.
    rnn_bidirectional (bool, optional): If True, becomes a bidirectional RNN. Default is False.
    rnn_attn (bool, optional): Whether to apply attention mechanism on RNN outputs. Default is False.
    rnn_weight_reg (list, optional): [coefficient, norm order] used by `penalize` for every parameter whose name
        contains 'weight' when the base is not a transformer. Default is [0.001, 1].
    rnn_weight_norm (float, optional): Norm order used by `constrain` to row-normalize weight matrices. Default is None (no normalization).
    relax_init (list, optional): Initial relaxation values for LRU. Default is [0.5].
    relax_train (bool, optional): Whether to train relaxation values for LRU. Default is True.
    relax_minmax (list, optional): Minimum and maximum relaxation values for LRU. Default is [0.1, 0.9].
    num_filterbanks (int, optional): Number of filterbanks for LRU (not forwarded to the LRU constructor here). Default is 1.
    cnn_kernel_size (tuple, optional): Size of the convolving kernel for CNN. Default is (1,).
    cnn_stride (tuple, optional): Stride of the convolution for CNN. Default is (1,).
    cnn_padding (tuple, optional): Zero-padding added to both sides of the input for CNN. Default is (0,).
    cnn_dilation (tuple, optional): Spacing between kernel elements for CNN. Default is (1,).
    cnn_groups (int, optional): Number of blocked connections from input channels to output channels for CNN. Default is 1.
    cnn_bias (bool, optional): If False, then the layer does not use bias weights. Default is False.
    encoder_output_size (int, optional): The feature size of `encoder_output`. When given and different from the
        size expected by this model, a linear encoder block projects it. Default is None.
    seq_type (str, optional): Type of the sequence for the transformer base. Options: 'encoder', 'decoder'. Default is 'encoder'.
    transformer_embedding_type (str, optional): Type of embedding for Transformer. Default is 'time'.
    transformer_embedding_bias (bool, optional): Whether to include a bias term in the embedding for Transformer. Default is False.
    transformer_embedding_activation (str, optional): Activation function for Transformer embedding. Default is 'identity'.
    transformer_embedding_weight_reg (list, optional): Regularization settings for Transformer embedding weights. Default is [0.001, 1].
    transformer_embedding_weight_norm (float, optional): Norm type for Transformer embedding weights. Default is 2.
    transformer_embedding_dropout_p (float, optional): Dropout probability for Transformer embedding. Default is 0.0.
    transformer_positional_encoding_type (str, optional): Type of positional encoding for Transformer. Default is 'absolute'.
    transformer_dropout1_p (float, optional): Dropout probability for the first dropout layer in Transformer. Default is 0.
    transformer_dropout2_p (float, optional): Dropout probability for the second dropout layer in Transformer. Default is 0.
    transformer_dropout3_p (float, optional): Dropout probability for the third dropout layer (decoder only). Default is 0.
    transformer_linear1_bias (bool, optional): Whether to include a bias term in the first linear layer in Transformer. Default is False.
    transformer_linear2_bias (bool, optional): Whether to include a bias term in the second linear layer in Transformer. Default is False.
    transformer_linear1_weight_reg (list, optional): Regularization settings for the weights of the first linear layer. Default is [0.001, 1].
    transformer_linear1_weight_norm (float, optional): Norm type for the weights of the first linear layer. Default is 2.
    transformer_linear2_weight_reg (list, optional): Regularization settings for the weights of the second linear layer. Default is [0.001, 1].
    transformer_linear2_weight_norm (float, optional): Norm type for the weights of the second linear layer. Default is 2.
    transformer_feedforward_activation (str, optional): Activation function for the feedforward layer. Default is 'relu'.
    transformer_feedforward_degree (int, optional): Degree of the polynomial feedforward activation. Default is 2.
    transformer_coef_init (None or float, optional): Initial value for the polynomial activation coefficients. Default is None.
    transformer_coef_train (bool, optional): Whether to train the polynomial activation coefficients. Default is True.
    transformer_coef_reg (list, optional): Regularization settings for the polynomial activation coefficients. Default is [0.001, 1.].
    transformer_zero_order (bool, optional): Whether to include the zero-order polynomial term. Default is False.
    transformer_scale_self_attn_residual_connection (bool, optional): Scale the self-attention residual connection. Default is False.
    transformer_scale_cross_attn_residual_connection (bool, optional): Scale the cross-attention residual connection (decoder only). Default is False.
    transformer_scale_feedforward_residual_connection (bool, optional): Scale the feedforward residual connection. Default is False.
    transformer_layer_norm (bool, optional): If False, the stack-level `norm` of the torch Transformer container is set to None. Default is True.
    num_heads (int, optional): Number of attention heads. Default is 1.
    transformer_dim_feedforward (int, optional): Dimension of the feedforward layer in Transformer. Default is 2048.
    self_attn_type (str, optional): Type of self-attention in Transformer. Default is 'dot'.
    multihead_attn_type (str, optional): Type of attention used by the RNN attention mechanism. Default is 'dot'.
    memory_is_causal (bool, optional): Causality flag passed to the transformer layers. Default is True.
    tgt_is_causal (bool, optional): Causality flag for the decoder target / RNN attention. Default is False.
    query_dim, key_dim, value_dim (None or int, optional): Dimensions for the RNN attention mechanism. Default is None.
    query_weight_reg, key_weight_reg, value_weight_reg, gen_weight_reg, concat_weight_reg (list, optional):
        Regularization settings for the attention projections. Default is [0.001, 1].
    query_weight_norm, key_weight_norm, value_weight_norm, gen_weight_norm, concat_weight_norm (float, optional):
        Norm types for the attention projections. Default is 2.
    query_bias, key_bias, value_bias, gen_bias, concat_bias (bool, optional): Bias flags for the attention projections. Default is False.
    attn_dropout_p (float, optional): Dropout probability for attention mechanism. Default is 0.
    average_attn_weights (bool, optional): Whether to average attention weights. Default is False.
    batch_first (bool, optional): Accepted for API compatibility; the GRU/LSTM backbones are always built with batch_first=True. Default is True.
    device (str, optional): The device to run the model on. Default is 'cpu'.
    dtype (torch.dtype, optional): The desired data type of the model's parameters. Default is torch.float32.
  '''

  def __init__(self,
              input_size, hidden_size, seq_len=None,
              base_type='gru', num_layers=1,
              encoder_bias=False, decoder_bias=False,
              rnn_bias=True,
              rnn_dropout_p=0,
              rnn_bidirectional=False,
              rnn_attn=False,
              rnn_weight_reg=[0.001, 1], rnn_weight_norm=None,
              relax_init=[0.5], relax_train=True, relax_minmax=[0.1, 0.9], num_filterbanks=1,
              cnn_kernel_size=(1,), cnn_stride=(1,), cnn_padding=(0,), cnn_dilation=(1,), cnn_groups=1,
              cnn_bias=False,
              encoder_output_size=None, seq_type='encoder',
              transformer_embedding_type='time', transformer_embedding_bias=False,
              transformer_embedding_activation='identity',
              transformer_embedding_weight_reg=[0.001, 1], transformer_embedding_weight_norm=2,
              transformer_embedding_dropout_p=0.0,
              transformer_positional_encoding_type='absolute',
              transformer_dropout1_p=0., transformer_dropout2_p=0., transformer_dropout3_p=0.,
              transformer_linear1_bias=False, transformer_linear2_bias=False,
              transformer_linear1_weight_reg=[0.001, 1], transformer_linear1_weight_norm=2,
              transformer_linear2_weight_reg=[0.001, 1], transformer_linear2_weight_norm=2,
              transformer_feedforward_activation='relu',
              transformer_feedforward_degree=2, transformer_coef_init=None, transformer_coef_train=True,
              transformer_coef_reg=[0.001, 1.], transformer_zero_order=False,
              transformer_scale_self_attn_residual_connection=False,
              transformer_scale_cross_attn_residual_connection=False,
              transformer_scale_feedforward_residual_connection=False,
              transformer_layer_norm=True,
              num_heads=1, transformer_dim_feedforward=2048,
              self_attn_type='dot', multihead_attn_type='dot',
              memory_is_causal=True, tgt_is_causal=False,
              query_dim=None, key_dim=None, value_dim=None,
              query_weight_reg=[0.001, 1], query_weight_norm=2, query_bias=False,
              key_weight_reg=[0.001, 1], key_weight_norm=2, key_bias=False,
              value_weight_reg=[0.001, 1], value_weight_norm=2, value_bias=False,
              gen_weight_reg=[0.001, 1], gen_weight_norm=2, gen_bias=False,
              concat_weight_reg=[0.001, 1], concat_weight_norm=2, concat_bias=False,
              attn_dropout_p=0.,
              average_attn_weights=False,
              batch_first=True,
              device='cpu', dtype=torch.float32):
    super(SequenceModelBase, self).__init__()

    # NOTE: no submodules exist yet, so this call has no effect; submodules
    # below receive device/dtype explicitly.
    self.to(device=device, dtype=dtype)

    self.name = f"{base_type}{num_layers}"

    positional_encoding = None
    encoder_block = None
    if base_type == 'identity':
        base = torch.nn.Identity()
    elif base_type == 'gru':
        base = torch.nn.GRU(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=rnn_bias,
                            dropout=rnn_dropout_p,
                            bidirectional=rnn_bidirectional,
                            device=device, dtype=dtype,
                            batch_first=True)
    elif base_type == 'lstm':
        base = torch.nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             bias=rnn_bias,
                             dropout=rnn_dropout_p,
                             bidirectional=rnn_bidirectional,
                             device=device, dtype=dtype,
                             batch_first=True)
    elif base_type == 'lru':
        base = LRU(input_size=input_size, hidden_size=hidden_size,
                   bias=rnn_bias,
                   relax_init=relax_init, relax_train=relax_train, relax_minmax=relax_minmax,
                   device=device, dtype=dtype)
    elif base_type == 'cnn':
        base = torch.nn.Conv1d(in_channels=input_size,
                               out_channels=hidden_size,
                               kernel_size=cnn_kernel_size,
                               stride=cnn_stride,
                               padding=cnn_padding,
                               dilation=cnn_dilation,
                               groups=cnn_groups,
                               bias=cnn_bias,
                               device=device, dtype=dtype)
    elif base_type == 'transformer':
        embedding = Embedding(num_embeddings=input_size,
                              embedding_dim=hidden_size,
                              embedding_type=transformer_embedding_type,
                              bias=transformer_embedding_bias,
                              activation=transformer_embedding_activation,
                              weight_reg=transformer_embedding_weight_reg,
                              weight_norm=transformer_embedding_weight_norm,
                              dropout_p=transformer_embedding_dropout_p,
                              device=device, dtype=dtype)

        positional_encoding = PositionalEncoding(dim=hidden_size, seq_len=seq_len,
                                                 encoding_type=transformer_positional_encoding_type,
                                                 device=device, dtype=dtype)

        # base[0]: embedding followed by positional encoding; base[1]: layer stack.
        base = torch.nn.ModuleList([torch.nn.Sequential(*[embedding, positional_encoding])])

        # Keyword arguments shared by the encoder and decoder layer constructors.
        layer_kwargs = dict(d_model=hidden_size,
                            nhead=num_heads,
                            dim_feedforward=transformer_dim_feedforward,
                            self_attn_type=self_attn_type,
                            query_weight_reg=query_weight_reg,
                            query_weight_norm=query_weight_norm,
                            query_bias=query_bias,
                            key_weight_reg=key_weight_reg,
                            key_weight_norm=key_weight_norm,
                            key_bias=key_bias,
                            value_weight_reg=value_weight_reg,
                            value_weight_norm=value_weight_norm,
                            value_bias=value_bias,
                            gen_weight_reg=gen_weight_reg,
                            gen_weight_norm=gen_weight_norm,
                            gen_bias=gen_bias,
                            concat_weight_reg=concat_weight_reg,
                            concat_weight_norm=concat_weight_norm,
                            concat_bias=concat_bias,
                            average_attn_weights=average_attn_weights,
                            dropout_p=attn_dropout_p,
                            dropout1_p=transformer_dropout1_p,
                            dropout2_p=transformer_dropout2_p,
                            linear1_weight_reg=transformer_linear1_weight_reg,
                            linear1_weight_norm=transformer_linear1_weight_norm,
                            linear2_weight_reg=transformer_linear2_weight_reg,
                            linear2_weight_norm=transformer_linear2_weight_norm,
                            linear1_bias=transformer_linear1_bias,
                            linear2_bias=transformer_linear2_bias,
                            feedforward_activation=transformer_feedforward_activation,
                            degree=transformer_feedforward_degree,
                            coef_init=transformer_coef_init,
                            coef_train=transformer_coef_train,
                            coef_reg=transformer_coef_reg,
                            zero_order=transformer_zero_order,
                            scale_self_attn_residual_connection=transformer_scale_self_attn_residual_connection,
                            scale_feedforward_residual_connection=transformer_scale_feedforward_residual_connection,
                            device=device, dtype=dtype)

        if seq_type == 'encoder':
            encoder_layer = TransformerEncoderLayer(is_causal=memory_is_causal, **layer_kwargs)
            base.append(torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers))
        elif seq_type == 'decoder':
            decoder_layer = TransformerDecoderLayer(memory_is_causal=memory_is_causal,
                                                    tgt_is_causal=tgt_is_causal,
                                                    dropout3_p=transformer_dropout3_p,
                                                    scale_cross_attn_residual_connection=transformer_scale_cross_attn_residual_connection,
                                                    **layer_kwargs)
            base.append(torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers))

            # Project the encoder memory to hidden_size only when its size is
            # known and differs (a None size would otherwise build an invalid layer).
            if encoder_output_size is not None and encoder_output_size != hidden_size:
                encoder_block = HiddenLayer(in_features=encoder_output_size,
                                            out_features=hidden_size,
                                            activation='identity',
                                            bias=encoder_bias,
                                            device=device, dtype=dtype)
        else:
            raise ValueError(f"'{seq_type}' is not a valid value. `seq_type` must be 'encoder' or 'decoder'.")

        if not transformer_layer_norm:
            base[1].norm = None

    else:
        raise ValueError(f"'{base_type}' is not a valid value. `base_type` must be 'gru', 'lstm', 'lru', 'cnn', or 'transformer'.")

    attn_mechanism, decoder_block = None, None
    if rnn_attn:
        attn_mechanism = Attention(embed_dim=hidden_size,
                                   num_heads=num_heads,
                                   query_dim=query_dim, key_dim=key_dim, value_dim=value_dim,
                                   attn_type=multihead_attn_type,
                                   query_weight_reg=query_weight_reg, query_weight_norm=query_weight_norm,
                                   query_bias=query_bias,
                                   key_weight_reg=key_weight_reg, key_weight_norm=key_weight_norm, key_bias=key_bias,
                                   value_weight_reg=value_weight_reg, value_weight_norm=value_weight_norm, value_bias=value_bias,
                                   is_causal=tgt_is_causal, dropout_p=attn_dropout_p,
                                   device=device, dtype=dtype)

        rnn_output_size = hidden_size * (1 + rnn_bidirectional)
        if encoder_output_size is not None and encoder_output_size != rnn_output_size:
            # GRU/LSTM emit rnn_output_size features per step; for other bases
            # (e.g. LRU) the features are scaled by the number of relaxation values.
            # (The previous expression multiplied hidden_size twice for GRU/LSTM.)
            multiplier = 1 if base_type in ('lstm', 'gru') else len(relax_init)
            encoder_block = HiddenLayer(in_features=encoder_output_size,
                                        out_features=multiplier * rnn_output_size,
                                        activation='identity',
                                        bias=encoder_bias,
                                        device=device, dtype=dtype)

        decoder_block = HiddenLayer(in_features=2 * hidden_size,
                                    out_features=hidden_size,
                                    activation='identity',
                                    bias=decoder_bias,
                                    device=device, dtype=dtype)

    self.device, self.dtype = device, dtype
    self.input_size, self.hidden_size = input_size, hidden_size
    self.base, self.base_type, self.num_layers = base, base_type, num_layers
    self.seq_type, self.seq_len = seq_type, seq_len
    self.positional_encoding = positional_encoding
    self.rnn_attn, self.attn_mechanism = rnn_attn, attn_mechanism
    self.encoder_block, self.decoder_block = encoder_block, decoder_block
    self.relax_minmax = relax_minmax
    self.rnn_weight_reg, self.rnn_weight_norm = rnn_weight_reg, rnn_weight_norm

  def init_hiddens(self, num_samples):
    '''
    Initialize zero hidden states for the base model.

    Args:
        num_samples (int): The number of samples (batch size) in the input.

    Returns:
        hiddens (list or torch.Tensor):
          - 'lru': tensor of shape (num_filterbanks, num_samples, hidden_size).
          - 'gru': tensor of shape (num_layers * num_directions, num_samples, hidden_size).
          - 'lstm': list [h0, c0], each of the GRU shape above.
    '''
    if self.base_type == 'lru':
        return torch.zeros((self.base.num_filterbanks, num_samples, self.base.hidden_size),
                           device=self.device, dtype=self.dtype)

    # torch GRU/LSTM expect one initial state per layer *and* direction.
    num_directions = 2 if getattr(self.base, 'bidirectional', False) else 1
    shape = (self.base.num_layers * num_directions, num_samples, self.base.hidden_size)

    if self.base_type == 'lstm':
        return [torch.zeros(shape, device=self.device, dtype=self.dtype),
                torch.zeros(shape, device=self.device, dtype=self.dtype)]

    return torch.zeros(shape, device=self.device, dtype=self.dtype)

  def forward(self, input, hiddens=None, encoder_output=None, mask=None):
    '''
    Forward pass of the sequence model.

    Args:
        input (torch.Tensor): Input tensor of shape (num_samples, input_len, input_size).
        hiddens (list or torch.Tensor, optional): Hidden states of the base model. Zero states are created for
            'lru'/'lstm'/'gru' when None. Default is None.
        encoder_output (torch.Tensor, optional): Encoder output used as attention keys/values (RNN attention) or
            as transformer decoder memory. Projected by the encoder block when one exists. Default is None.
        mask (torch.Tensor, optional): Mask passed to the transformer encoder stack. Default is None.

    Returns:
        output (torch.Tensor): Output tensor of shape (num_samples, input_len, output_size).
        hiddens (list or torch.Tensor): Updated hidden states of the base model (unchanged for non-recurrent bases).
    '''
    num_samples, input_len, _ = input.shape

    if hiddens is None and self.base_type in ('lru', 'lstm', 'gru'):
        hiddens = self.init_hiddens(num_samples)

    if self.encoder_block is not None and encoder_output is not None:
        encoder_output = self.encoder_block(encoder_output)

    if self.base_type == 'identity':
        output = input
    elif self.base_type in ('lru', 'lstm', 'gru'):
        output, hiddens = self.base(input, hiddens)

        output = output.reshape(num_samples, input_len, -1)

        if self.rnn_attn:
            # Query with the last step of the base output; keys/values come from the encoder.
            attn_output = self.attn_mechanism(query=output[:, -1:],
                                              key=encoder_output,
                                              value=encoder_output)

            # Broadcast a single-step attention output over the whole sequence.
            if attn_output.shape[1] == 1:
                attn_output = attn_output.repeat(1, output.shape[1], 1)

            # Combine attention output and base output, then project with the decoder block.
            output = self.decoder_block(torch.cat((attn_output, output), -1))

    elif self.base_type == 'cnn':
        # Causal left-padding: a dilated kernel spans dilation * (kernel_size - 1)
        # past samples, so pad by exactly that amount to keep the length.
        causal_pad = self.base.dilation[0] * (self.base.kernel_size[0] - 1)
        input_t_pad = torch.nn.functional.pad(input.transpose(1, 2), (causal_pad, 0))
        output = self.base(input_t_pad).transpose(1, 2)
    elif self.base_type == 'transformer':
        input_embedding_pe = self.base[0](input)

        if self.seq_type == 'decoder':
            output = self.base[1](tgt=input_embedding_pe, memory=encoder_output)
        else:
            output = self.base[1](src=input_embedding_pe, mask=mask)

    return output, hiddens

  def constrain(self):
    '''
    Apply constraints to the model in place.

    - 'lru': delegates to the LRU's `clamp_relax()`.
    - otherwise, when `rnn_weight_norm` is set: every weight matrix (parameter whose name contains 'weight' and
      that has at least 2 dimensions) is normalized along dim 1 with norm order `rnn_weight_norm`.
    '''
    if self.base_type == 'lru':
        self.base.clamp_relax()
    elif self.rnn_weight_norm is not None:
        # Write through to the parameters (rebinding a local name would not modify them).
        with torch.no_grad():
            for name, param in self.named_parameters():
                if 'weight' in name and param.dim() > 1:
                    param.copy_(torch.nn.functional.normalize(param, p=self.rnn_weight_norm, dim=1))

  def penalize(self):
    '''
    Compute the regularization penalty.

    - 'transformer': sums `penalize()` of the embedding/positional-encoding modules and of every layer in the
      stack, for those modules that define `penalize`.
    - otherwise: `rnn_weight_reg[0] * ||W||_p` (p = `rnn_weight_reg[1]`) over every parameter whose name contains
      'weight'; frozen parameters (requires_grad False) contribute zero.

    Returns:
        loss (torch.Tensor or int): Regularization loss (0 when nothing contributes).
    '''
    loss = 0
    if self.base_type == 'transformer':
        # base[0] is a Sequential(embedding, positional_encoding); it has no penalize() itself.
        for module in self.base[0]:
            if hasattr(module, 'penalize'):
                loss = loss + module.penalize()
        # torch Transformer containers keep their layers in `.layers`; the container is not iterable.
        for layer in self.base[1].layers:
            if hasattr(layer, 'penalize'):
                loss = loss + layer.penalize()
    else:
        for name, param in self.named_parameters():
            if 'weight' in name:
                loss = loss + self.rnn_weight_reg[0] * torch.norm(param, p=self.rnn_weight_reg[1]) * int(
                    param.requires_grad)

    return loss


In [None]:
class SequenceModel(torch.nn.Module):
  '''
  Multi-input, multi-output sequence model.

  The input tensor is split along its last dimension into `num_inputs` groups of widths `input_size`.
  Each group is processed by its own `SequenceModelBase`, followed by an optional per-input `HiddenLayer`.
  That hidden layer is an identity when `hidden_out_features[i] == 0`. The per-input hidden outputs are
  concatenated and passed through an optional interaction `HiddenLayer` and a `ModulationLayer`.

  Each of the `num_outputs` outputs has its own output `HiddenLayer`, which is an identity when
  `output_size[i] <= 0`. An "associated" output `i` reads the hidden output of input `i`; any other output
  reads the modulation output.

  List-valued per-input arguments are those whose names contain 'seq_type', 'seq_len', 'input_size',
  'base_', 'decoder_', 'hidden_' or 'attn_'. List-valued per-output arguments are those whose names
  contain 'output_'. Either kind may be given as a length-1 list, which is broadcast to `num_inputs` or
  `num_outputs` entries respectively.
  '''

  def __init__(self,
               num_inputs, num_outputs,
               #
               input_size = [1], output_size = [1], seq_len = [None],
               stateful = False,
               dt = 1,
               ## Sequence base parameters
               # type
               base_hidden_size = [1],
               base_type = ['gru'], base_num_layers = [1],
               base_enc2dec_bias = [False],
               encoder_output_size = None,
               # GRU/LSTM parameters
               base_rnn_bias = [True],
               base_rnn_dropout_p = [0],
               base_rnn_bidirectional = [False],
               base_rnn_attn = [False],
               base_encoder_bias = [False], base_decoder_bias = [False],
               base_rnn_weight_reg = [[0.001, 1]], base_rnn_weight_norm = [None],
               # LRU parameters
               base_relax_init = [[0.5]], base_relax_train = [True], base_relax_minmax = [[0.1, 0.9]], base_num_filterbanks = [1],
               # CNN parameters
               base_cnn_kernel_size = [(1,)], base_cnn_stride = [(1,)], base_cnn_padding = [(0,)], base_cnn_dilation = [(1,)], base_cnn_groups = [1], base_cnn_bias = [False],
               # Transformer parameters
               base_seq_type = ['encoder'],
               base_transformer_embedding_type = ['time'], base_transformer_embedding_bias = [False], base_transformer_embedding_activation = ['identity'],
               base_transformer_embedding_weight_reg = [[0.001, 1]], base_transformer_embedding_weight_norm = [2], base_transformer_embedding_dropout_p = [0.0],
               base_transformer_positional_encoding_type = ['absolute'],
               base_transformer_dropout1_p = [0.], base_transformer_dropout2_p = [0.], base_transformer_dropout3_p = [0.],
               base_transformer_linear1_bias = [False], base_transformer_linear2_bias = [False],
               base_transformer_linear1_weight_reg = [[0.001, 1]], base_transformer_linear1_weight_norm = [2],
               base_transformer_linear2_weight_reg = [[0.001, 1]], base_transformer_linear2_weight_norm = [2],
               base_transformer_feedforward_activation = ['relu'],
               base_transformer_feedforward_degree = [2], base_transformer_coef_init = [None], base_transformer_coef_train = [True], base_transformer_coef_reg = [[0.001, 1.]], base_transformer_zero_order = [False],
               base_transformer_scale_self_attn_residual_connection = [False],
               base_transformer_scale_cross_attn_residual_connection = [False],
               base_transformer_scale_feedforward_residual_connection = [False],
               base_transformer_layer_norm = [True],
               # attention parameters
               base_num_heads = [1], base_transformer_dim_feedforward = [2048],
               base_self_attn_type = ['dot'], base_multihead_attn_type = ['dot'],
               base_memory_is_causal = [False], base_tgt_is_causal = [True],
               base_query_dim = [None], base_key_dim = [None], base_value_dim = [None],
               base_query_weight_reg = [[0.001, 1]], base_query_weight_norm = [2], base_query_bias = [False],
               base_key_weight_reg = [[0.001, 1]], base_key_weight_norm = [2], base_key_bias = [False],
               base_value_weight_reg = [[0.001, 1]], base_value_weight_norm = [2], base_value_bias = [False],
               base_gen_weight_reg = [[0.001, 1]], base_gen_weight_norm = [2], base_gen_bias = [False],
               base_concat_weight_reg = [[0.001, 1]], base_concat_weight_norm = [2], base_concat_bias = [False],
               base_attn_dropout_p = [0.], base_average_attn_weights = [False],
               base_constrain = False, base_penalize = False,
               ##
               # hidden layer parameters
               hidden_out_features = [0], hidden_bias = [False], hidden_activation = ['identity'], hidden_degree = [1],
               hidden_coef_init = [None], hidden_coef_train = [True], hidden_coef_reg = [[0.001, 1]], hidden_zero_order = [False],
               hidden_softmax_dim = [-1],
               hidden_constrain = [False], hidden_penalize = [False],
               hidden_dropout_p = [0.],
               # interaction layer
               interaction_out_features = 0, interaction_bias = False, interaction_activation = 'identity',
               interaction_degree = 1, interaction_coef_init = True, interaction_coef_train = True,
               interaction_coef_reg = [0.001, 1], interaction_zero_order = False, interaction_softmax_dim = -1,
               interaction_constrain = False, interaction_penalize = False,
               interaction_dropout_p = 0.,
               # modulation layer
               modulation_window_len = None, modulation_associated = False,
               modulation_legendre_degree = None, modulation_chebychev_degree = None,
               modulation_num_freqs = None, modulation_freq_init = None, modulation_freq_train = True,
               modulation_phase_init = None, modulation_phase_train = True,
               modulation_num_sigmoids = None,
               modulation_slope_init = None, modulation_slope_train = True, modulation_shift_init = None, modulation_shift_train = True,
               modulation_weight_reg = [0.001, 1.0], modulation_weight_norm = 2,
               modulation_zero_order = True,
               modulation_bias = True, modulation_pure = False,
               # output layer
               output_associated = [True],
               output_bias = [True], output_activation = ['identity'], output_degree = [1],
               output_coef_init = [None], output_coef_train = [True], output_coef_reg = [[0.001, 1]], output_zero_order = [False], output_softmax_dim = [-1],
               output_constrain = [False], output_penalize = [False],
               output_dropout_p = [0.],
               #
               device = 'cpu', dtype = torch.float32):

    # Snapshot the constructor arguments before any other local name is created.
    arguments = dict(locals())
    arguments.pop('self', None)

    from types import SimpleNamespace

    super(SequenceModel, self).__init__()

    self.to(device = device, dtype = dtype)

    # Broadcast length-1 per-input / per-output lists. Every list is copied so that the in-place updates
    # below (e.g. `output_size[i] = ...`) never mutate the caller's list or a shared default argument.
    per_input_keys = ['seq_type', 'seq_len', 'input_size', 'base_', 'decoder_', 'hidden_', 'attn_']
    per_output_keys = ['output_size', 'output_']
    for name, value in list(arguments.items()):
      if not isinstance(value, list):
        continue
      if any(key in name for key in per_input_keys):
        count = num_inputs
      elif any(key in name for key in per_output_keys):
        count = num_outputs
      else:
        continue
      value = value * count if len(value) == 1 else list(value)
      arguments[name] = value
      setattr(self, name, value)

    # `base_constrain` / `base_penalize` default to scalars but `constrain` / `penalize` index them per input.
    for name in ['base_constrain', 'base_penalize']:
      if not isinstance(arguments[name], list):
        arguments[name] = [arguments[name]] * num_inputs

    # Attribute-style access to the (broadcast) constructor arguments.
    cfg = SimpleNamespace(**arguments)

    # Feature width emitted by each sequence base, and by each per-input hidden layer (which is an identity,
    # and therefore passes the base width through, when `hidden_out_features[i] == 0`).
    base_widths = [self._base_out_features(cfg, i) for i in range(num_inputs)]
    hidden_widths = [cfg.hidden_out_features[i] if cfg.hidden_out_features[i] > 0 else base_widths[i]
                     for i in range(num_inputs)]

    seq_base, hidden_layer = torch.nn.ModuleList([]), torch.nn.ModuleList([])
    for i in range(num_inputs):
      # input-associated sequence layer
      seq_base.append(self._make_seq_base(cfg, i))

      # input-associated hidden layer
      if cfg.hidden_out_features[i] > 0:
        hidden_layer_i = HiddenLayer(# linear transformation
                                     in_features = base_widths[i], out_features = cfg.hidden_out_features[i],
                                     bias = cfg.hidden_bias[i],
                                     # activation
                                     activation = cfg.hidden_activation[i],
                                     # polynomial parameters
                                     degree = cfg.hidden_degree[i],
                                     coef_init = cfg.hidden_coef_init[i], coef_train = cfg.hidden_coef_train[i], coef_reg = cfg.hidden_coef_reg[i],
                                     zero_order = cfg.hidden_zero_order[i],
                                     # softmax parameter
                                     softmax_dim = cfg.hidden_softmax_dim[i],
                                     dropout_p = cfg.hidden_dropout_p[i],
                                     device = device, dtype = dtype)
      else:
        hidden_layer_i = torch.nn.Identity()

      hidden_layer.append(hidden_layer_i)

    # interaction layer: consumes the concatenated per-input hidden outputs
    if interaction_out_features > 0:
      interaction_layer = HiddenLayer(# linear transformation
                                      in_features = int(sum(hidden_widths)), out_features = interaction_out_features,
                                      bias = interaction_bias,
                                      # activation
                                      activation = interaction_activation,
                                      # polynomial parameters
                                      degree = interaction_degree,
                                      coef_init = interaction_coef_init, coef_train = interaction_coef_train, coef_reg = interaction_coef_reg,
                                      zero_order = interaction_zero_order,
                                      # softmax parameter
                                      softmax_dim = interaction_softmax_dim,
                                      dropout_p = interaction_dropout_p,
                                      device = device, dtype = dtype)
    else:
      interaction_layer = torch.nn.Identity()

    # modulation layer: receives the interaction output, or the concatenated hidden outputs when the
    # interaction layer is an identity
    modulation_in_features = interaction_out_features if interaction_out_features > 0 else int(sum(hidden_widths))

    modulation_layer = ModulationLayer(window_len = modulation_window_len,
                                       in_features = modulation_in_features,
                                       associated = modulation_associated,
                                       legendre_degree = modulation_legendre_degree,
                                       chebychev_degree = modulation_chebychev_degree,
                                       dt = dt,
                                       num_freqs = modulation_num_freqs, freq_init = modulation_freq_init, freq_train = modulation_freq_train,
                                       phase_init = modulation_phase_init, phase_train = modulation_phase_train,
                                       num_sigmoids = modulation_num_sigmoids,
                                       slope_init = modulation_slope_init, slope_train = modulation_slope_train,
                                       shift_init = modulation_shift_init, shift_train = modulation_shift_train,
                                       weight_reg = modulation_weight_reg, weight_norm = modulation_weight_norm,
                                       zero_order = modulation_zero_order,
                                       bias = modulation_bias, pure = modulation_pure,
                                       device = device, dtype = dtype)

    # output layer
    output_size = cfg.output_size
    output_layer = torch.nn.ModuleList([])
    for i in range(num_outputs):
      # In `process`, an associated output reads hidden output `i`; otherwise it reads the modulation output.
      if modulation_layer is not None:
        # NOTE(review): `modulation_layer` is always constructed above, so this branch is always taken and
        # the alternatives below are unreachable. Confirm `num_modulators` matches the width `process`
        # actually feeds to output layer `i`.
        output_in_features_i = modulation_layer.num_modulators
      elif cfg.output_associated[i]:
        output_in_features_i = hidden_widths[i]
      else:
        output_in_features_i = modulation_in_features

      if output_size[i] > 0:
        output_layer_i = HiddenLayer(# linear transformation
                                     in_features = output_in_features_i, out_features = output_size[i],
                                     bias = cfg.output_bias[i],
                                     # activation
                                     activation = cfg.output_activation[i],
                                     # polynomial parameters
                                     degree = cfg.output_degree[i],
                                     coef_init = cfg.output_coef_init[i], coef_train = cfg.output_coef_train[i], coef_reg = cfg.output_coef_reg[i],
                                     zero_order = cfg.output_zero_order[i],
                                     # softmax parameter
                                     softmax_dim = cfg.output_softmax_dim[i],
                                     dropout_p = cfg.output_dropout_p[i],
                                     device = device, dtype = dtype)
      else:
        # An identity output layer passes its input through, so record the width it will emit.
        # NOTE(review): for non-associated outputs this assumes the modulation layer preserves feature width.
        output_layer_i = torch.nn.Identity()
        output_size[i] = hidden_widths[i] if cfg.output_associated[i] else modulation_in_features

      output_layer.append(output_layer_i)

    self.num_inputs, self.num_outputs = num_inputs, num_outputs
    self.input_size, self.output_size = cfg.input_size, output_size

    self.stateful = stateful
    self.seq_base, self.base_hidden_size, self.base_type = seq_base, cfg.base_hidden_size, cfg.base_type
    self.base_constrain, self.base_penalize = cfg.base_constrain, cfg.base_penalize

    self.hidden_layer, self.hidden_out_features = hidden_layer, cfg.hidden_out_features
    self.hidden_constrain, self.hidden_penalize = cfg.hidden_constrain, cfg.hidden_penalize

    self.interaction_layer, self.interaction_out_features = interaction_layer, interaction_out_features
    self.interaction_constrain, self.interaction_penalize = interaction_constrain, interaction_penalize

    self.modulation_layer = modulation_layer

    self.output_layer = output_layer
    self.output_associated = cfg.output_associated
    self.output_constrain, self.output_penalize = cfg.output_constrain, cfg.output_penalize

    self.device, self.dtype = device, dtype

  @staticmethod
  def _base_out_features(cfg, i):
    '''Return the feature width produced by the sequence base of input `i` (see `__init__` for `cfg`).'''
    hidden_size = cfg.base_hidden_size[i]
    if hidden_size is None or hidden_size <= 0:
      # NOTE(review): without a hidden size the base is assumed to emit its input width.
      return cfg.input_size[i]

    base_type = cfg.base_type[i]
    if base_type == 'lru':
      # One block of `hidden_size` features per relaxation value.
      relax_init = cfg.base_relax_init[i]
      return hidden_size * (len(relax_init) if hasattr(relax_init, '__len__') else 1)
    if base_type in ['gru', 'lstm']:
      return hidden_size * (1 + int(cfg.base_rnn_bidirectional[i]))
    if base_type in ['transformer', 'cnn']:
      # NOTE(review): assumes the transformer embedding / CNN output channels equal `hidden_size`.
      return hidden_size
    return cfg.input_size[i]

  @staticmethod
  def _make_seq_base(cfg, i):
    '''Build the `SequenceModelBase` for input `i` from the broadcast constructor arguments in `cfg`.'''
    return SequenceModelBase(input_size = cfg.input_size[i],
                             hidden_size = cfg.base_hidden_size[i],
                             seq_len = cfg.seq_len[i],
                             # type
                             base_type = cfg.base_type[i], num_layers = cfg.base_num_layers[i],
                             encoder_bias = cfg.base_encoder_bias[i], decoder_bias = cfg.base_decoder_bias[i],
                             # GRU/LSTM parameters
                             rnn_bias = cfg.base_rnn_bias[i],
                             rnn_dropout_p = cfg.base_rnn_dropout_p[i],
                             rnn_bidirectional = cfg.base_rnn_bidirectional[i],
                             rnn_attn = cfg.base_rnn_attn[i],
                             rnn_weight_reg = cfg.base_rnn_weight_reg[i], rnn_weight_norm = cfg.base_rnn_weight_norm[i],
                             # LRU parameters
                             relax_init = cfg.base_relax_init[i], relax_train = cfg.base_relax_train[i], relax_minmax = cfg.base_relax_minmax[i], num_filterbanks = cfg.base_num_filterbanks[i],
                             # CNN parameters
                             cnn_kernel_size = cfg.base_cnn_kernel_size[i], cnn_stride = cfg.base_cnn_stride[i], cnn_padding = cfg.base_cnn_padding[i], cnn_dilation = cfg.base_cnn_dilation[i], cnn_groups = cfg.base_cnn_groups[i], cnn_bias = cfg.base_cnn_bias[i],
                             # Transformer parameters
                             encoder_output_size = cfg.encoder_output_size, seq_type = cfg.base_seq_type[i],
                             transformer_embedding_type = cfg.base_transformer_embedding_type[i], transformer_embedding_bias = cfg.base_transformer_embedding_bias[i], transformer_embedding_activation = cfg.base_transformer_embedding_activation[i],
                             transformer_embedding_weight_reg = cfg.base_transformer_embedding_weight_reg[i], transformer_embedding_weight_norm = cfg.base_transformer_embedding_weight_norm[i], transformer_embedding_dropout_p = cfg.base_transformer_embedding_dropout_p[i],
                             transformer_positional_encoding_type = cfg.base_transformer_positional_encoding_type[i],
                             transformer_dropout1_p = cfg.base_transformer_dropout1_p[i], transformer_dropout2_p = cfg.base_transformer_dropout2_p[i], transformer_dropout3_p = cfg.base_transformer_dropout3_p[i],
                             transformer_linear1_bias = cfg.base_transformer_linear1_bias[i], transformer_linear2_bias = cfg.base_transformer_linear2_bias[i],
                             transformer_linear1_weight_reg = cfg.base_transformer_linear1_weight_reg[i], transformer_linear1_weight_norm = cfg.base_transformer_linear1_weight_norm[i],
                             transformer_linear2_weight_reg = cfg.base_transformer_linear2_weight_reg[i], transformer_linear2_weight_norm = cfg.base_transformer_linear2_weight_norm[i],
                             transformer_feedforward_activation = cfg.base_transformer_feedforward_activation[i],
                             transformer_feedforward_degree = cfg.base_transformer_feedforward_degree[i], transformer_coef_init = cfg.base_transformer_coef_init[i], transformer_coef_train = cfg.base_transformer_coef_train[i], transformer_coef_reg = cfg.base_transformer_coef_reg[i], transformer_zero_order = cfg.base_transformer_zero_order[i],
                             transformer_scale_self_attn_residual_connection = cfg.base_transformer_scale_self_attn_residual_connection[i],
                             transformer_scale_cross_attn_residual_connection = cfg.base_transformer_scale_cross_attn_residual_connection[i],
                             transformer_scale_feedforward_residual_connection = cfg.base_transformer_scale_feedforward_residual_connection[i],
                             transformer_layer_norm = cfg.base_transformer_layer_norm[i],
                             # attention parameters
                             num_heads = cfg.base_num_heads[i], transformer_dim_feedforward = cfg.base_transformer_dim_feedforward[i],
                             self_attn_type = cfg.base_self_attn_type[i], multihead_attn_type = cfg.base_multihead_attn_type[i],
                             memory_is_causal = cfg.base_memory_is_causal[i], tgt_is_causal = cfg.base_tgt_is_causal[i],
                             query_dim = cfg.base_query_dim[i], key_dim = cfg.base_key_dim[i], value_dim = cfg.base_value_dim[i],
                             query_weight_reg = cfg.base_query_weight_reg[i], query_weight_norm = cfg.base_query_weight_norm[i], query_bias = cfg.base_query_bias[i],
                             key_weight_reg = cfg.base_key_weight_reg[i], key_weight_norm = cfg.base_key_weight_norm[i], key_bias = cfg.base_key_bias[i],
                             value_weight_reg = cfg.base_value_weight_reg[i], value_weight_norm = cfg.base_value_weight_norm[i], value_bias = cfg.base_value_bias[i],
                             gen_weight_reg = cfg.base_gen_weight_reg[i], gen_weight_norm = cfg.base_gen_weight_norm[i], gen_bias = cfg.base_gen_bias[i],
                             concat_weight_reg = cfg.base_concat_weight_reg[i], concat_weight_norm = cfg.base_concat_weight_norm[i], concat_bias = cfg.base_concat_bias[i],
                             attn_dropout_p = cfg.base_attn_dropout_p[i],
                             average_attn_weights = cfg.base_average_attn_weights[i],
                             # always batch first
                             batch_first = True,
                             #
                             device = cfg.device, dtype = cfg.dtype)

  @staticmethod
  def _flag(flags, i):
    '''Return entry `i` of a per-item flag list, or the flag itself when a single scalar was given.'''
    return flags[i] if isinstance(flags, (list, tuple)) else flags

  def __repr__(self):
    total_num_params = 0
    total_num_trainable_params = 0
    lines = []
    for name, param in self.named_parameters():
      trainable = 'Trainable' if param.requires_grad else 'Untrainable'
      lines.append(f"{name}: shape = {param.shape}. {param.numel()} parameters. {trainable}")
      total_num_params += param.numel()
      if param.requires_grad: total_num_trainable_params += param.numel()

    lines.append("-------------------------------------")
    lines.append(f"{total_num_params} total parameters.")
    lines.append(f"{total_num_trainable_params} total trainable parameters.")

    return '\n'.join(lines)

  def init_hiddens(self):
    '''Return one empty (None) hidden state per input.'''
    return [None for _ in range(self.num_inputs)]

  def process(self,
              input, hiddens,
              steps = None,
              encoder_output = None):
    '''
    Run one pass of every layer.

    Args:
        input (torch.Tensor): 3-D tensor (num_samples, input_len, sum(input_size)); split along the last
            dimension into the per-input groups.
        hiddens (list): One hidden state per input; entry `i` is replaced in place by the new state.
        steps (torch.Tensor, optional): Passed to the final modulation call.
        encoder_output (torch.Tensor, optional): Passed to every sequence base.

    Returns:
        (output, hiddens): Concatenated outputs of all output layers, and the updated hidden-state list.
    '''
    input_len = input.shape[1]

    # Per-input hidden-layer outputs, kept separate so associated outputs can select their own input.
    hidden_outputs = []
    for i, input_i in enumerate(input.split(self.input_size, -1)):
      base = self.seq_base[i]

      # Recurrent decoders only consume the newest step; earlier steps are carried in `hiddens[i]`.
      if base.base_type in ['gru', 'lstm', 'lru'] and base.seq_type == 'decoder':
        input_i = input_i[:, -1:]

      base_output_i, hiddens[i] = base(input = input_i,
                                       hiddens = hiddens[i],
                                       encoder_output = encoder_output)

      # Left-pad with zeros so every input group spans `input_len` steps and the groups can be concatenated.
      pad_len = max(0, input_len - base_output_i.shape[1])
      base_output_i = torch.nn.functional.pad(base_output_i, (0, 0, pad_len, 0), "constant", 0)

      hidden_outputs.append(self.hidden_layer[i](base_output_i))

    hidden_output = torch.cat(hidden_outputs, -1)

    interaction_output = self.interaction_layer(hidden_output)

    # NOTE(review): the modulation layer is called here without `steps` and again on the final output
    # with `steps`; confirm both calls match ModulationLayer's intended use.
    modulation_output = self.modulation_layer(interaction_output)

    output = []
    for i in range(self.num_outputs):
      # An associated output reads only the hidden output of its own input; otherwise the modulation output.
      output_input_i = hidden_outputs[i] if self.output_associated[i] else modulation_output
      output.append(self.output_layer[i](output_input_i))

    output = torch.cat(output, -1)

    if self.modulation_layer is not None:
      output = self.modulation_layer(output, steps)

    return output, hiddens

  def forward(self,
              input, steps = None,
              hiddens = None,
              target = None,
              output_window_idx = None,
              input_mask = None, output_mask = None,
              output_input_idx = None, input_output_idx = None,
              encoder_output= None):
    '''
    Run the model as an encoder (single pass) or as an autoregressive decoder (one step at a time).

    Args:
        input (torch.Tensor): 3-D tensor (num_samples, input_len, sum(input_size)).
        steps (torch.Tensor, optional): 2-D tensor sliced along dim 1 alongside the input.
        hiddens (list, optional): Initial hidden states; fresh ones are created when None or empty.
        target (torch.Tensor, optional): Decoder only. When given, its `input_output_idx` features are fed
            back into the `output_input_idx` input features (teacher forcing); otherwise the model's own
            predictions are fed back.
        output_window_idx (list, optional): The longest entry sets how many trailing steps are returned.
        input_mask: Unused.
        output_mask (torch.Tensor, optional): Multiplied element-wise with the output.
        output_input_idx, input_output_idx (list, optional): Feature indices used for decoder feedback;
            feedback is disabled when `output_input_idx` is None or empty.
        encoder_output (torch.Tensor, optional): Passed to every sequence base.

    Returns:
        (output, hiddens): The output restricted to the last `max_output_len` steps, and the hidden states.
    '''
    # Move tensors to the model device.
    input = input.to(device = self.device)
    steps = steps.to(device = self.device) if steps is not None else None
    output_mask = output_mask.to(device = self.device) if output_mask is not None else None

    input_len = input.shape[1]

    # Number of trailing steps to return: the longest output window, or the whole input length.
    max_output_len = max(len(idx) for idx in output_window_idx) if output_window_idx is not None else input_len

    # Start from fresh hidden states when none were supplied; states passed by the caller are used as-is.
    if hiddens is None or len(hiddens) == 0:
      hiddens = self.init_hiddens()

    if 'encoder' in [base.seq_type for base in self.seq_base]: # model is an encoder
      output, hiddens = self.process(input = input,
                                     steps = steps,
                                     hiddens = hiddens,
                                     encoder_output = encoder_output)
    else: # model is a decoder
      # Decode on a private copy so feedback writes never modify the caller's tensor.
      input_, output = input.clone(), []
      feed_back = output_input_idx is not None and len(output_input_idx) > 0
      for n in range(max_output_len):
        output_n, hiddens = self.process(input = input_[:, :(n+1)].clone(),
                                         steps = steps[:, :(n+1)] if steps is not None else None,
                                         hiddens = hiddens,
                                         encoder_output = encoder_output)

        output.append(output_n[:, -1:])

        if feed_back and n < (max_output_len-1):
          input_[:, (n+1):(n+2), output_input_idx] = target[:, n:(n+1), input_output_idx] if target is not None else output[-1][..., input_output_idx]

      output = torch.cat(output, 1)

    # Only keep the outputs for the maximum output sequence length
    output = output[:, -max_output_len:]

    # Apply the output mask if specified
    if output_mask is not None:
      output = output*output_mask

    return output, hiddens

  def constrain(self):
    '''Call `constrain()` on every layer whose constrain flag is set.'''
    for i in range(self.num_inputs):
      if self._flag(self.base_constrain, i):
        self.seq_base[i].constrain()

      if self._flag(self.hidden_constrain, i):
        self.hidden_layer[i].constrain()

    if self.interaction_constrain:
      self.interaction_layer.constrain()

    for i in range(self.num_outputs):
      if self._flag(self.output_constrain, i):
        self.output_layer[i].constrain()

  def penalize(self):
    '''Return the sum of `penalize()` over every layer whose penalize flag is set (0 when none are).'''
    loss = 0

    for i in range(self.num_inputs):
      if self._flag(self.base_penalize, i):
        loss += self.seq_base[i].penalize()

      if self._flag(self.hidden_penalize, i):
        loss += self.hidden_layer[i].penalize()

    if self.interaction_penalize:
      loss += self.interaction_layer.penalize()

    for i in range(self.num_outputs):
      if self._flag(self.output_penalize, i):
        loss += self.output_layer[i].penalize()

    return loss

  def generate_impulse_response(self, seq_len):
    '''
    Compute, for every input feature, the response to a unit impulse at step 0.

    For input `i` and feature `f`, a (1, seq_len, input_size[i]) zero tensor with a 1 at [0, 0, f] is passed
    through `seq_base[i]` and projected with the weight of `hidden_layer[i].F[0]`.

    Returns:
        list: `impulse_response[i][f]` is a (seq_len, out_features) tensor.
    '''
    impulse_response = []
    with torch.no_grad():
      for i in range(self.num_inputs):
        responses_i = []
        for f in range(self.input_size[i]):
          impulse = torch.zeros((1, seq_len, self.input_size[i])).to(device = self.device, dtype = self.dtype)
          impulse[0, 0, f] = 1.

          base_output, _ = self.seq_base[i](input = impulse)

          # NOTE(review): assumes the hidden layer exposes its linear map as `F[0]`.
          weight = self.hidden_layer[i].F[0].weight

          responses_i.append(base_output.squeeze(0) @ weight.t())
        impulse_response.append(responses_i)

    return impulse_response

In [None]:
class Seq2SeqModel(torch.nn.Module):
  '''
  Sequence-to-sequence model that chains an encoder and a decoder.

  The encoder consumes the input window. The decoder is driven by the last input
  step (optionally passed through a learned projection), zero-padded to the
  longest decoder base ``seq_len``. It receives the encoder's hidden states,
  either unchanged or through a learned linear map.

  Args:
      encoder (torch.nn.Module): The encoder module. Reads ``num_inputs``,
          ``input_size``, ``output_size``, ``base_type``, and, when
          ``learn_decoder_hiddens`` is set, ``base_hidden_size`` /
          ``base_rnn_bidirectional``.
      decoder (torch.nn.Module): The decoder module. Reads ``num_inputs``,
          ``num_outputs``, ``input_size``, ``output_size``, ``base_type``,
          ``seq_base`` (each with ``seq_len``), and, for the hidden map,
          ``base_hidden_size`` / ``base_rnn_bidirectional``.
      learn_decoder_init_input (bool, optional): Learn a map from the encoder's last input step to the decoder's first input. Defaults to False.
      learn_decoder_hiddens (bool, optional): Learn a map from the encoder's states (or output) to the decoder's initial states. Defaults to False.
      enc2dec_bias (bool, optional): Whether to use bias in the init-input map. Defaults to True.
      enc2dec_hiddens_bias (bool, optional): Whether to use bias in the hidden-state map. Defaults to True.
      enc2dec_dropout_p (float, optional): Dropout probability for the init-input map. Defaults to 0.
      enc2dec_hiddens_dropout_p (float, optional): Dropout probability for the hidden-state map. Defaults to 0.
      device (str, optional): Device for the learned maps (e.g. 'cpu', 'cuda'). Defaults to 'cpu'.
      dtype (torch.dtype, optional): Data type of the learned maps. Defaults to torch.float32.
  '''

  # Base types whose sequence base carries a recurrent hidden state.
  _RNN_TYPES = ('gru', 'lstm', 'lru')

  def __init__(self,
               encoder, decoder,
               learn_decoder_init_input=False, learn_decoder_hiddens=False,
               enc2dec_bias=True, enc2dec_hiddens_bias=True,
               enc2dec_dropout_p=0., enc2dec_hiddens_dropout_p=0.,
               device='cpu', dtype=torch.float32):

    super(Seq2SeqModel, self).__init__()

    # Optional projection of the encoder's last input step onto the decoder's input features.
    enc2dec_init_input_block = None
    if learn_decoder_init_input:
      enc2dec_init_input_block = HiddenLayer(in_features=sum(encoder.input_size),
                                             out_features=sum(decoder.input_size),
                                             bias=enc2dec_bias,
                                             activation='identity',
                                             dropout_p=enc2dec_dropout_p,
                                             device=device,
                                             dtype=dtype)

    # Optional map from the encoder's recurrent states (or its output when the
    # encoder has no RNN base) to the decoder's initial states.
    enc2dec_hiddens_block = None
    if learn_decoder_hiddens:
      if any(type_ in self._RNN_TYPES for type_ in encoder.base_type):
        enc2dec_hiddens_input = sum(
            self._state_size(encoder.base_type[i],
                             encoder.base_hidden_size[i],
                             encoder.base_rnn_bidirectional[i])
            for i in range(encoder.num_inputs)
            if encoder.base_type[i] in self._RNN_TYPES)
      else:
        enc2dec_hiddens_input = sum(encoder.output_size)

      # One slot per decoder input (RNN or not); forward() walks the same layout.
      enc2dec_hiddens_output = sum(
          self._state_size(type_, hidden_size, bidirectional)
          for type_, hidden_size, bidirectional in zip(decoder.base_type,
                                                       decoder.base_hidden_size,
                                                       decoder.base_rnn_bidirectional))

      enc2dec_hiddens_block = HiddenLayer(in_features=enc2dec_hiddens_input,
                                          out_features=enc2dec_hiddens_output,
                                          bias=enc2dec_hiddens_bias,
                                          activation='identity',
                                          dropout_p=enc2dec_hiddens_dropout_p,
                                          device=device,
                                          dtype=dtype)

    self.encoder = encoder
    self.decoder = decoder

    self.num_inputs = encoder.num_inputs
    self.num_outputs = decoder.num_outputs
    self.input_size = encoder.input_size
    self.output_size = decoder.output_size
    self.base_type = encoder.base_type
    self.enc2dec_init_input_block = enc2dec_init_input_block
    self.enc2dec_hiddens_block = enc2dec_hiddens_block
    self.device = device
    self.dtype = dtype

  @staticmethod
  def _state_size(base_type, hidden_size, bidirectional):
    '''Flattened size of one base's state: x2 for LSTM (h and c), x2 if bidirectional.'''
    return (1 + int(base_type == 'lstm')) * hidden_size * (1 + int(bidirectional))

  def _map_encoder_to_decoder_hiddens(self, encoder_hiddens, encoder_output, num_samples):
    '''
    Projects the encoder's states through ``enc2dec_hiddens_block`` and splits the
    result into one initial state per decoder input.

    Returns:
        list: ``decoder.num_inputs`` entries. Each is ``None`` for non-RNN inputs,
        an ``(h, c)`` tuple for LSTM inputs, or a single tensor for other RNN inputs.
    '''
    flat_states = []
    if encoder_hiddens is not None:
      for eh in encoder_hiddens:
        if eh is None:
          continue
        # An LSTM state may arrive as an (h, c) pair; __init__ sizes the block for both.
        parts = eh if isinstance(eh, (tuple, list)) else (eh,)
        # NOTE(review): reshape (not permute) folds the state into (num_samples, 1, -1);
        # confirm this matches the encoder's hidden layout for multi-layer/bidirectional bases.
        flat_states.extend(part.reshape(num_samples, 1, -1) for part in parts)

    # Without any recurrent state, fall back to the encoder output (matches the
    # in_features chosen in __init__ for encoders without an RNN base).
    block_input = torch.cat(flat_states, -1) if flat_states else encoder_output
    block_output = self.enc2dec_hiddens_block(block_input)

    decoder_hiddens = [None for _ in range(self.decoder.num_inputs)]
    j = 0
    for i in range(self.decoder.num_inputs):
      type_i = self.decoder.base_type[i]
      size_i = self._state_size(type_i,
                                self.decoder.base_hidden_size[i],
                                self.decoder.base_rnn_bidirectional[i])
      if type_i in self._RNN_TYPES:
        state_i = block_output[..., j:(j + size_i)]
        # LSTM gets its slot split into equal (h, c) halves.
        decoder_hiddens[i] = tuple(state_i.split(size_i // 2, -1)) if type_i == 'lstm' else state_i
      # Advance for every input so offsets follow the layout sized in __init__.
      j += size_i

    return decoder_hiddens

  def forward(self,
              input,
              steps=None,
              hiddens=None,
              input_mask=None, output_mask=None,
              output_input_idx=None, input_output_idx=None,
              encoder_output=None,
              target=None,
              output_window_idx=None):
    '''
    Forward pass of the Seq2SeqModel.

    Args:
        input (torch.Tensor): Input tensor of shape (num_samples, input_len, input_size).
        steps (torch.Tensor, optional): Step tensor. The first ``input_len`` columns go
            to the encoder; columns from ``input_len - 1`` onward go to the decoder.
            Defaults to None.
        hiddens (list, optional): Initial hidden states for the encoder. Defaults to None.
        input_mask (torch.Tensor, optional): Mask passed to the encoder. Defaults to None.
        output_mask (torch.Tensor, optional): Mask passed to the decoder. Defaults to None.
        output_input_idx (list, optional): Passed through to the decoder. Defaults to [].
        input_output_idx (list, optional): Passed through to the decoder. Defaults to [].
        encoder_output (torch.Tensor, optional): Ignored; it is replaced by the encoder's
            own output. Kept for interface compatibility.
        target (torch.Tensor, optional): Passed through to the decoder. Defaults to None.
        output_window_idx (list, optional): Passed through to the decoder. Defaults to None.

    Returns:
        torch.Tensor: Decoder output.
        list: Hidden states returned by the encoder.
    '''
    output_input_idx = [] if output_input_idx is None else output_input_idx
    input_output_idx = [] if input_output_idx is None else input_output_idx

    num_samples, input_len, _ = input.shape

    encoder_steps = steps[:, :input_len] if steps is not None else None
    decoder_steps = steps[:, (input_len - 1):] if steps is not None else None

    encoder_output, encoder_hiddens = self.encoder(input=input,
                                                   steps=encoder_steps,
                                                   hiddens=hiddens,
                                                   input_mask=input_mask)

    hiddens = encoder_hiddens

    if self.enc2dec_hiddens_block is not None:
      decoder_hiddens = self._map_encoder_to_decoder_hiddens(encoder_hiddens, encoder_output, num_samples)
    else:
      decoder_hiddens = encoder_hiddens

    max_output_len = int(max(base.seq_len for base in self.decoder.seq_base))

    # The decoder is seeded with the last input step, zero-padded along time to max_output_len.
    last_step = input[:, -1:]
    decoder_init_input = (self.enc2dec_init_input_block(last_step)
                          if self.enc2dec_init_input_block is not None else last_step)

    decoder_init_input = torch.nn.functional.pad(decoder_init_input, (0, 0, 0, max_output_len - 1), "constant", 0)

    decoder_output, _ = self.decoder(input=decoder_init_input,
                                     steps=decoder_steps,
                                     hiddens=decoder_hiddens,
                                     target=target,
                                     output_window_idx=output_window_idx,
                                     output_mask=output_mask,
                                     output_input_idx=output_input_idx,
                                     input_output_idx=input_output_idx,
                                     encoder_output=encoder_output)

    return decoder_output, hiddens


In [None]:
class Embedding(torch.nn.Module):
    '''
    Embedding layer that maps inputs to continuous vectors.

    Args:
        num_embeddings (int): For 'time', the number of input features (``in_features``
            of the underlying HiddenLayer). For 'category', the vocabulary size.
        embedding_dim (int): Dimensionality of the embedding vectors.
        embedding_type (str, optional): 'time' (HiddenLayer) or 'category'
            (torch.nn.Embedding lookup table). Defaults to 'time'.
        bias (bool, optional): Bias in the 'time' HiddenLayer. Defaults to False.
        activation (str, optional): Activation of the 'time' HiddenLayer. Defaults to 'identity'.
        weight_reg (List[float], optional): Regularization terms for the 'time' weights.
            Defaults to [0.001, 1].
        weight_norm (float, optional): Norm order used by the 'time' weight regularizer. Defaults to 2.
        dropout_p (float, optional): Dropout of the 'time' HiddenLayer. Defaults to 0.0.
        device (str, optional): Device of the embedding parameters (both types). Defaults to 'cpu'.
        dtype (torch.dtype, optional): Dtype of the embedding parameters (both types). Defaults to torch.float32.

    Raises:
        ValueError: If ``embedding_type`` is not 'time' or 'category'.
    '''

    def __init__(self,
                 num_embeddings, embedding_dim, embedding_type='time',
                 bias=False, activation='identity',
                 weight_reg=None, weight_norm=2,
                 dropout_p=0.0,
                 device='cpu', dtype=torch.float32):
        super(Embedding, self).__init__()

        # Fresh list per instance instead of a shared mutable default.
        if weight_reg is None:
            weight_reg = [0.001, 1]

        if embedding_type == 'time':
            # Dense projection of real-valued features.
            embedding = HiddenLayer(in_features=num_embeddings,
                                    out_features=embedding_dim,
                                    bias=bias,
                                    activation=activation,
                                    weight_reg=weight_reg, weight_norm=weight_norm,
                                    dropout_p=dropout_p,
                                    device=device, dtype=dtype)
        elif embedding_type == 'category':
            # Lookup table; device/dtype are honoured here too (previously ignored).
            embedding = torch.nn.Embedding(num_embeddings, embedding_dim,
                                           device=device, dtype=dtype)
        else:
            raise ValueError(f"Unsupported embedding type: {embedding_type}")

        self.embedding = embedding
        self.embedding_type = embedding_type

    def forward(self, input, input_mask=None):
        '''
        Forward pass of the embedding layer.

        Args:
            input (torch.Tensor): Integer indices for 'category'. For 'time', features
                whose last dimension is ``num_embeddings``.
            input_mask (torch.Tensor, optional): Multiplied element-wise with ``input``
                before embedding, or None for no masking.

        Returns:
            torch.Tensor: The embedded input, with a trailing ``embedding_dim`` dimension.
        '''
        if input_mask is not None:
            masked = input * input_mask
            # Category indices must stay integral; a float mask would otherwise
            # promote them and make the table lookup fail.
            input = masked.to(input.dtype) if self.embedding_type == 'category' else masked

        return self.embedding(input)


In [None]:
class PositionalEncoding(torch.nn.Module):
  '''
  Positional encoding layer that adds positional information to the input.

  Args:
    dim (int): Dimensionality of the input (odd values are supported).
    seq_len (int): Maximum length of the input sequence.
    encoding_type (str, optional): 'absolute' (sinusoidal) or 'relative'
                                   (position + feature offset, normalized to a max of 1).
                                   Defaults to 'absolute'.
    device (str, optional): Device on which the encoding is created. Defaults to 'cpu'.
    dtype (torch.dtype, optional): Data type of the encoding. Defaults to torch.float32.

  Raises:
    ValueError: If ``encoding_type`` is not supported.
  '''

  def __init__(self,
               dim, seq_len, encoding_type='absolute',
               device='cpu', dtype=torch.float32):
      super(PositionalEncoding, self).__init__()

      self.dim, self.seq_len = dim, seq_len
      self.encoding_type = encoding_type
      self.device, self.dtype = device, dtype

      # Non-persistent buffer: follows module.to(...) but stays out of the state_dict.
      self.register_buffer('positional_encoding', self.generate_positional_encoding(), persistent=False)

  def generate_positional_encoding(self):
      '''
      Generates the positional encoding based on the encoding type.

      Returns:
        torch.Tensor: Positional encoding tensor of shape (seq_len, dim).
      '''

      position = torch.arange(self.seq_len).unsqueeze(1).to(device=self.device, dtype=self.dtype)

      if self.encoding_type == 'absolute':
          positional_encoding = torch.zeros((self.seq_len, self.dim)).to(device=self.device, dtype=self.dtype)

          scaler = torch.exp(torch.arange(0, self.dim, 2) * -(math.log(10000.0) / self.dim)).to(
              device=self.device, dtype=self.dtype)

          # (seq_len, ceil(dim / 2)) angles: sines fill even columns, cosines odd ones.
          angles = position * scaler
          positional_encoding[:, 0::2] = torch.sin(angles)
          # For odd dim there is one fewer odd column than even column.
          positional_encoding[:, 1::2] = torch.cos(angles[:, :self.dim // 2])

      elif self.encoding_type == 'relative':
          offsets = torch.arange(self.dim).reshape(1, -1).to(device=self.device, dtype=self.dtype)
          positional_encoding = (position.repeat(1, self.dim) + offsets) / self.seq_len

          positional_encoding = positional_encoding / positional_encoding.max()

      else:
          raise ValueError(f"Unsupported encoding type: {self.encoding_type}")

      return positional_encoding

  def forward(self, input):
      '''
      Forward pass of the positional encoding layer.

      Args:
        input (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim), seq_len <= self.seq_len.

      Returns:
        torch.Tensor: Input tensor with the first seq_len encoding rows added.
      '''

      return input + self.positional_encoding[:input.shape[1], :]


In [None]:
class SequenceDataset(torch.utils.data.Dataset):
  '''
  Dataset that slices aligned input/output windows out of time-indexed series.

  Every sample covers ``total_window_size`` consecutive time indices; consecutive
  samples are ``stride`` steps apart.

  Args:
      data (dict): Maps each name in ``input_names``, ``output_names`` and ``step_name``
          to a series whose first dimension is time (torch tensors are assumed; the
          last dimension is read as the feature size).
      input_names (list): Names of the input data. An input that is also an output
          is lagged by one step.
      output_names (list): Names of the output data.
      step_name (str): Name of the step data. Defaults to 'steps'.
      input_len (int or list): Input window lengths. A scalar or single-element list
          is replicated for all inputs; -1 means the full data length.
      output_len (int or list): Output window lengths. A scalar or single-element list
          is replicated for all outputs; -1 means the longest input window.
      shift (int or list): Output shifts. A scalar or single-element list is
          replicated for all outputs.
      stride (int): Number of time steps between consecutive samples.
      init_input (torch.Tensor or None): Written to row 0 of a lagged input whose
          window reaches one step before the data. Defaults to None.
      print_summary (bool): Whether to print the window layout. Defaults to False.
      device (str): Device on which the dataset is allocated. Defaults to 'cpu'.
      dtype (torch.dtype): Data type of the input/output samples. Defaults to torch.float32.
  '''

  def __init__(self,
               data: dict,
               input_names, output_names, step_name='steps',
               input_len=(1,), output_len=(1,), shift=(0,), stride=1,
               init_input=None,
               print_summary=False,
               device='cpu', dtype=torch.float32):

    num_inputs, num_outputs = len(input_names), len(output_names)

    input_len = self._expand(input_len, num_inputs)
    output_len = self._expand(output_len, num_outputs)
    shift = self._expand(shift, num_outputs)

    data_len = data[input_names[0]].shape[0]

    # -1 sentinels: full data length for inputs, longest input window for outputs.
    input_len = [data_len if n == -1 else n for n in input_len]
    output_len = [max(input_len) if n == -1 else n for n in output_len]

    input_size = [data[name].shape[-1] for name in input_names]
    output_size = [data[name].shape[-1] for name in output_names]

    max_input_len = max(input_len)
    max_output_len = max(output_len)
    max_shift = max(shift)

    # True when some output is also fed back as an input.
    has_ar = np.isin(output_names, input_names).any()

    # Input windows are right-aligned so they all end at max_input_len - 1.
    input_window_idx = [torch.arange(max_input_len - input_len[i], max_input_len).to(device=device,
                                                                                     dtype=torch.long)
                        for i in range(num_inputs)]

    # Output windows end at the same point, offset by each output's shift.
    output_window_idx = [torch.arange(max_input_len - output_len[i], max_input_len).to(device=device,
                                                                                       dtype=torch.long) + shift[i]
                         for i in range(num_outputs)]

    total_window_size = torch.cat(output_window_idx).max().item() + 1
    total_window_idx = torch.arange(total_window_size).to(device=device, dtype=torch.long)

    start_step = max_input_len - max_output_len + max_shift + int(has_ar)

    if print_summary:
      print('\n'.join([f'Data length: {data_len}',
                       f'Window size: {total_window_size}',
                       f'Step indices: {total_window_idx.tolist()}',
                       '\n'.join([f'Input indices for {input_names[i]}: {input_window_idx[i].tolist()}' for i in
                                  range(num_inputs)]),
                       '\n'.join(
                           [f'Output indices for {output_names[i]}: {output_window_idx[i].tolist()}' for i in
                            range(num_outputs)])]))

    self.data = data
    self.input_names, self.output_names, self.step_name = input_names, output_names, step_name
    self.has_ar = has_ar
    self.data_len = data_len
    self.num_inputs, self.num_outputs = num_inputs, num_outputs
    self.input_size, self.output_size = input_size, output_size
    self.start_step = start_step
    self.shift, self.stride = shift, stride
    self.total_window_size, self.total_window_idx = total_window_size, total_window_idx
    self.input_len, self.input_window_idx = input_len, input_window_idx
    self.output_len, self.output_window_idx = output_len, output_window_idx
    self.init_input = init_input
    self.device, self.dtype = device, dtype

    self.input_samples, self.output_samples, self.steps_samples = self.get_samples()

  @staticmethod
  def _expand(values, count):
    '''Returns ``values`` as a list, replicating a scalar or single-element sequence ``count`` times.'''
    if np.isscalar(values):
      values = [values]
    values = list(values)
    return values * count if len(values) == 1 else values

  def get_samples(self):
    '''
    Generates input, output, and steps samples for the dataset.

    Returns:
        tuple: ``(input_samples, output_samples, steps_samples)``, stacked over samples.
        ``input_samples`` is (num_samples, max_input_len, sum(input_size)).
        ``output_samples`` is (num_samples, num_unique_output_indices, sum(output_size)),
        with rows ordered by sorted unique output window index.
        ``steps_samples`` stacks the step data over each full window.
    '''

    input_samples, output_samples, steps_samples = [], [], []

    # Output rows follow the sorted unique output indices (the same layout as the
    # dataloader's output_mask). searchsorted gives each window index its row,
    # which also works when output windows are not contiguous.
    unique_output_window_idx = torch.cat(self.output_window_idx).unique()
    output_rows = [torch.searchsorted(unique_output_window_idx, idx) for idx in self.output_window_idx]

    max_input_len = max(self.input_len)
    total_input_size = int(np.sum(self.input_size))
    total_output_size = int(np.sum(self.output_size))

    # Inputs that are also outputs are read one step earlier.
    ar_offset = [int(name in self.output_names) for name in self.input_names]

    window_idx_n = self.total_window_idx

    num_samples = 0
    while window_idx_n.max() < self.data_len:
      num_samples += 1

      steps_samples.append(self.data[self.step_name][window_idx_n])

      # input
      input_n = torch.zeros((max_input_len, total_input_size)).to(device=self.device, dtype=self.dtype)

      j = 0
      for i in range(self.num_inputs):
        cols = slice(j, j + self.input_size[i])
        input_window_idx_i = self.input_window_idx[i]
        data_idx_i = window_idx_n[input_window_idx_i] - ar_offset[i]

        # A lagged input's first window starts one step before the data; seed it from init_input.
        if (data_idx_i[0] == -1) & (self.init_input is not None):
          input_n[0, cols] = self.init_input[cols]

        valid = data_idx_i >= 0
        # Advanced indexing already copies, so the full series need not be cloned per sample.
        input_n[input_window_idx_i[valid], cols] = self.data[self.input_names[i]][data_idx_i[valid]]

        j += self.input_size[i]

      input_samples.append(input_n)

      # output
      output_n = torch.zeros((len(unique_output_window_idx), total_output_size)).to(device=self.device,
                                                                                    dtype=self.dtype)

      j = 0
      for i in range(self.num_outputs):
        data_idx_i = window_idx_n[self.output_window_idx[i]]
        output_n[output_rows[i], j:(j + self.output_size[i])] = self.data[self.output_names[i]][data_idx_i]

        j += self.output_size[i]

      output_samples.append(output_n)

      window_idx_n = num_samples * self.stride + self.total_window_idx

    input_samples = torch.stack(input_samples)
    output_samples = torch.stack(output_samples)
    steps_samples = torch.stack(steps_samples)

    self.num_samples = num_samples

    return input_samples, output_samples, steps_samples

  def __len__(self):
    '''
    Returns the number of samples in the dataset.

    Returns:
        int: Number of samples in the dataset.
    '''
    return self.num_samples

  def __getitem__(self, idx):
    '''
    Returns a sample from the dataset at the given index.

    Args:
        idx (int): Index of the sample.

    Returns:
        tuple: A tuple containing the input, output, and steps for the sample.
    '''
    return self.input_samples[idx], self.output_samples[idx], self.steps_samples[idx]


class SequenceDataloader:
  '''
  Dataloader class for sequence data.

  Args:
      input_names (list): Names of the input data.
      output_names (list): Names of the output data.
      step_name (str): Name of the step data.
      data (dict): Dictionary containing input and output data. An empty dict yields
          a dataloader with no batches.
      batch_size (int): Batch size; -1 means one batch holding the whole dataset. Defaults to 1.
      input_len (list): List of input sequence lengths. If a single value is provided, it is replicated for all inputs.
      output_len (list): List of output sequence lengths. If a single value is provided, it is replicated for all outputs.
      shift (list): List of output shifts. If a single value is provided, it is replicated for all outputs.
      stride (int): Stride value. Defaults to 1.
      init_input (torch.Tensor or None): Initial input for padding. Defaults to None.
      print_summary (bool): Whether to print summary information. Defaults to False.
      device (str): Device on which the dataloader is allocated. Defaults to 'cpu'.
      dtype (torch.dtype): Data type of the dataloader. Defaults to torch.float32.
  '''

  def __init__(self,
               input_names, output_names, step_name,
               data: dict,
               batch_size=1,
               input_len=(1,), output_len=(1,), shift=(0,), stride=1,
               init_input=None,
               print_summary=False,
               device='cpu', dtype=torch.float32):

    self.data = data
    self.batch_size = batch_size
    self.input_names, self.output_names, self.step_name = input_names, output_names, step_name
    self.input_len, self.output_len, self.shift, self.stride = input_len, output_len, shift, stride
    self.init_input = init_input
    self.print_summary = print_summary
    self.device, self.dtype = device, dtype

    self.dl = self.get_dataloader

  def collate_fn(self, batch):
    '''
    Collate function for the dataloader.

    A final batch smaller than ``self.batch_size`` is padded up to it: inputs and
    outputs with zeros, steps with -1.

    Args:
        batch (list): List of ``(input, output, steps)`` samples.

    Returns:
        tuple: ``(input, output, steps, batch_size)``, where ``batch_size`` is the
        number of real (unpadded) samples.
    '''

    input_samples, output_samples, steps_samples = zip(*batch)

    batch_size = len(input_samples)

    def pad_fn(samples, fill_value):
      # Append (self.batch_size - batch_size) filler samples shaped like the first one.
      first = samples[0]
      filler = tuple(
          torch.full(first.shape, fill_value=fill_value).to(device=first.device, dtype=first.dtype)
          if isinstance(first, torch.Tensor)
          else np.full(first.shape, fill_value=fill_value)
          for _ in range(self.batch_size - batch_size))
      return samples + filler

    if batch_size % self.batch_size != 0:
      input_samples = pad_fn(input_samples, 0)
      output_samples = pad_fn(output_samples, 0)
      steps_samples = pad_fn(steps_samples, -1)

    input = torch.stack(input_samples)
    output = torch.stack(output_samples)
    steps = torch.stack(steps_samples)

    return input, output, steps, batch_size

  @property
  def get_dataloader(self):
    '''
    Property that builds the dataloader and records the dataset's layout on ``self``.

    Returns:
        torch.utils.data.DataLoader: DataLoader for the sequence dataset (no batches
        when ``data`` is empty).
    '''

    if len(self.data) > 0:
      ds = SequenceDataset(data=self.data,
                           input_names=self.input_names, output_names=self.output_names,
                           step_name=self.step_name,
                           input_len=self.input_len, output_len=self.output_len,
                           shift=self.shift, stride=self.stride,
                           init_input=self.init_input,
                           print_summary=self.print_summary,
                           device=self.device, dtype=self.dtype)

      self.batch_size = len(ds) if self.batch_size == -1 else self.batch_size

      self.input_size, self.output_size = ds.input_size, ds.output_size
      self.num_inputs, self.num_outputs = ds.num_inputs, ds.num_outputs
      self.data_len, self.num_samples = ds.data_len, ds.num_samples
      self.total_window_size, self.total_window_idx = ds.total_window_size, ds.total_window_idx
      self.shift, self.stride = ds.shift, ds.stride
      self.input_len, self.input_window_idx = ds.input_len, ds.input_window_idx
      self.output_len, self.output_window_idx = ds.output_len, ds.output_window_idx

      self.unique_output_window_idx = torch.cat(ds.output_window_idx, 0).unique()
      self.total_input_len = len(torch.cat(ds.input_window_idx, 0).unique())
      self.total_output_len = len(self.unique_output_window_idx)

      # Marks, per output feature block, which unique output rows that output actually fills.
      self.output_mask = torch.zeros((self.total_output_len, np.sum(self.output_size)), device=self.device,
                                     dtype=self.dtype)
      j = 0
      for i, window_idx_i in enumerate(ds.output_window_idx):
        rows_i = [k for k, l in enumerate(self.unique_output_window_idx) if l in window_idx_i]
        self.output_mask[rows_i, j:(j + self.output_size[i])] = 1

        j += self.output_size[i]

    else:
      class NoDataset(torch.utils.data.Dataset):
        def __init__(self):
          pass

        def __getitem__(self, index):
          pass

        def __len__(self):
          return 0

      # NOTE(review): this also clears the user's input_len/output_len/shift/stride,
      # so re-reading this property after filling self.data needs those reset.
      self.input_size, self.output_size = None, None
      self.num_inputs, self.num_outputs = None, None
      self.data_len, self.num_samples = None, None
      self.total_window_size, self.total_window_idx = None, None
      self.shift, self.stride = None, None
      self.input_len, self.input_window_idx = None, None
      self.output_len, self.output_window_idx = None, None

      self.total_input_len, self.total_output_len = None, None
      self.unique_output_window_idx = None

      self.output_mask = None

      ds = NoDataset()

    # DataLoader rejects batch_size=-1; an empty dataset yields no batches anyway.
    loader_batch_size = 1 if self.batch_size == -1 else self.batch_size

    dl = torch.utils.data.DataLoader(ds,
                                     batch_size=loader_batch_size,
                                     shuffle=False,
                                     collate_fn=self.collate_fn)

    self.num_batches = len(dl)

    return dl


In [None]:
class DataModule(pl.LightningDataModule):
  '''
  LightningDataModule that turns time-indexed feature data into train, validation and
  test `SequenceDataloader`s.

  `prepare_data` converts every feature to a tensor (1-D features become (N, 1)),
  optionally fuses features, fits one transform per feature and builds the
  input/output column index maps. `setup('fit')` splits the rows into contiguous
  train/val/test slices and optionally pads them. The `*_dataloader` methods wrap each
  split in a `SequenceDataloader`.
  '''

  def __init__(self,
               data,
               time_name, input_names, output_names,
               fuse_features=None, transforms=None,
               pct_train_val_test=(1., 0., 0.),
               batch_size=-1,
               input_len=(1,), output_len=(1,), shift=(0,), stride=1,
               dt=1,
               time_unit='s',
               pad_data=False,
               print_summary=True,
               device='cpu', dtype=torch.float32):
    '''
    Initializes a DataModule object.

    Args:
        data (str, pd.DataFrame or dict): Path to a pickled data object, a pandas DataFrame whose
            columns are individual features, or a mapping from feature name to array/tensor.
        time_name (str): Name of the column in the data that represents time.
        input_names (list): List of input feature names.
        output_names (list): List of output feature names.
        fuse_features (list, optional): Substrings; every feature whose name contains one of them is
            concatenated into a single feature of that name. Defaults to None.
        transforms (dict, optional): Mapping from feature name to a `FeatureTransform`. The key 'all'
            supplies a template copied for every feature without its own entry; features with neither
            get an identity transform. Defaults to None (identity for every feature).
        pct_train_val_test (sequence, optional): Fractions of the rows used for training, validation and
            testing, respectively. Defaults to (1., 0., 0.).
        batch_size (int, optional): Batch size for the dataloaders. If -1, each split is treated as a
            single batch. Defaults to -1.
        input_len (int or sequence, optional): Input sequence length per input feature. A single value is
            used for all input features. Defaults to (1,).
        output_len (int or sequence, optional): Output sequence length per output feature. A single value
            is used for all output features. Defaults to (1,).
        shift (int or sequence, optional): Output sequence shift per output feature. A single value is
            used for all output features. Defaults to (0,).
        stride (int, optional): Stride value for creating input-output pairs. Defaults to 1.
        dt (int, optional): Time step size. Defaults to 1.
        time_unit (str, optional): Time unit of the data. Defaults to 's'.
        pad_data (bool, optional): Whether to pad the splits by `start_step` rows in `setup`.
            Defaults to False.
        print_summary (bool, optional): Forwarded to every `SequenceDataloader`. Defaults to True.
        device (str, optional): Device to use for tensor operations. Defaults to 'cpu'.
        dtype (torch.dtype, optional): Data type of the feature tensors. Defaults to torch.float32.
    '''
    super().__init__()

    # Copy list-like arguments so the module never aliases a caller's list. A bare
    # scalar (e.g. input_len=5) is accepted and wrapped in a one-element list.
    input_len = self._as_list(input_len)
    output_len = self._as_list(output_len)
    shift = self._as_list(shift)

    self.time_name = time_name
    self.input_names = input_names
    self.output_names = output_names
    self.fuse_features = fuse_features
    self.transforms = transforms
    self.pct_train_val_test = list(pct_train_val_test)
    self.batch_size = batch_size
    self.input_len = input_len
    self.output_len = output_len
    self.max_input_len = np.max(input_len).item()
    self.max_output_len = np.max(output_len).item()
    self.shift = shift
    self.stride = stride
    self.max_shift = np.max(shift).item()
    self.dt = dt
    self.time_unit = time_unit
    self.pad_data = pad_data
    # Used as the padding length in `setup` when `pad_data` is True.
    self.start_step = np.max([0, (self.max_input_len - self.max_output_len + self.max_shift)]).item()
    self.print_summary = print_summary
    self.data = data
    self.device = device
    self.dtype = dtype
    self.predicting = False
    self.data_prepared = False

  @staticmethod
  def _as_list(value):
    '''Return `value` as a new Python list; a scalar becomes a one-element list.'''
    return np.atleast_1d(value).tolist()

  def prepare_data(self):
    '''
    Prepares the data for training, validation, and testing.

    Converts the input data to a dictionary of tensors, fuses features if requested,
    fits/applies a transform to every input/output feature, adds a 'steps' index and
    builds `output_input_idx` / `input_output_idx`. Does nothing when predicting or
    when the data has already been prepared.

    NOTE(review): this hook assigns module state; Lightning only guarantees that
    `prepare_data` runs on one process in distributed setups — confirm single-process use.
    '''
    if self.predicting or self.data_prepared:
      return

    self.input_output_names = np.unique(self.input_names + self.output_names).tolist()

    if isinstance(self.data, str):
      # The string is treated as a path to pickled data.
      # NOTE: unpickling can execute arbitrary code; only load trusted files.
      with open(self.data, "rb") as file:
        self.data = pickle.load(file)

    if isinstance(self.data, pd.DataFrame):
      # Each column is an individual feature; keep only the ones that are used.
      self.data = self.data.filter(items=[self.time_name] + self.input_output_names)

    # Convert every feature (everything except time) to a tensor; 1-D features become (N, 1).
    data = {self.time_name: self.data[self.time_name]}
    for key in self.data:
      if key == self.time_name:
        continue
      value = self.data[key]
      if isinstance(value, torch.Tensor):
        value = value.to(device=self.device, dtype=self.dtype)
      else:
        value = torch.tensor(np.array(value)).to(device=self.device, dtype=self.dtype)
      data[key] = value.unsqueeze(1) if value.ndim == 1 else value
    self.data = data

    self.feature_names = None
    if self.fuse_features is not None:
      self.feature_names = {}
      for feature in self.fuse_features:
        # Concatenate every feature whose name contains `feature`. The time column is not a
        # tensor, so it is excluded even if its name happens to match.
        self.data[feature] = torch.cat([self.data[name] for name in self.data
                                        if (name != self.time_name) and (feature in name)], -1)
        input_output_names_with_feature = [name for name in self.input_output_names if feature in name]
        if len(input_output_names_with_feature) > 0:
          for name in input_output_names_with_feature:
            # Names are expected to look like '<prefix>_<feature_name>'.
            _, feature_name = name.split('_', 1)
            self.feature_names.setdefault(feature, []).append(feature_name)

            self.input_output_names.remove(name)

            if name in self.data:
              del self.data[name]

          if any(feature in name for name in self.input_names):
            self.input_names = [name for name in self.input_names if feature not in name] + [feature]
          if any(feature in name for name in self.output_names):
            self.output_names = [name for name in self.output_names if feature not in name] + [feature]
        else:
          self.feature_names[feature] = []

    self.input_output_names = np.unique(self.input_names + self.output_names).tolist()
    self.num_inputs, self.num_outputs = len(self.input_names), len(self.output_names)
    self.input_size = [self.data[name].shape[-1] for name in self.input_names]
    self.output_size = [self.data[name].shape[-1] for name in self.output_names]
    self.max_input_size, self.max_output_size = np.max(self.input_size), np.max(self.output_size)

    # Broadcast single-value window settings to every input/output feature.
    if len(self.input_len) == 1:
      self.input_len = self.input_len * self.num_inputs
    if len(self.output_len) == 1:
      self.output_len = self.output_len * self.num_outputs
    if len(self.shift) == 1:
      self.shift = self.shift * self.num_outputs

    # True when at least one output is also fed back as an input.
    self.has_ar = np.isin(self.output_names, self.input_names).any()

    # Resolve one transform per feature. A missing entry copies the 'all' template (a
    # separate instance per feature, so fitting one feature does not overwrite the
    # statistics of another) or falls back to an identity transform.
    if self.transforms is None:
      self.transforms = {}
    for name in self.input_output_names:
      if name not in self.transforms:
        if 'all' in self.transforms:
          self.transforms[name] = copy.deepcopy(self.transforms['all'])
        else:
          self.transforms[name] = FeatureTransform(scale_type='identity')

    self.data_len = self.data[self.input_output_names[0]].shape[0]

    for name in self.input_output_names:
      self.data[name] = self.transforms[name].fit_transform(self.data[name])

    self.data['steps'] = torch.arange(self.data_len).to(device=self.device, dtype=torch.long)

    # Columns of the concatenated input that belong to features that are also outputs.
    j = 0
    output_input_idx = []
    for name, size in zip(self.input_names, self.input_size):
      input_idx = torch.arange(j, (j + size)).to(dtype=torch.long)
      if name in self.output_names:
        output_input_idx.append(input_idx)
      j += size
    output_input_idx = torch.cat(output_input_idx, -1) if len(output_input_idx) > 0 else []

    # Columns of the concatenated output that belong to features that are also inputs.
    j = 0
    input_output_idx = []
    for name, size in zip(self.output_names, self.output_size):
      output_idx = torch.arange(j, (j + size)).to(dtype=torch.long)
      if name in self.input_names:
        input_output_idx.append(output_idx)
      j += size
    input_output_idx = torch.cat(input_output_idx, -1) if len(input_output_idx) > 0 else []

    self.input_output_idx, self.output_input_idx = input_output_idx, output_input_idx
    self.data_prepared = True

  def setup(self, stage=None):
    '''
    Sets up the data module for a specific stage of training.

    For stage 'fit' (when not predicting), splits the prepared data into contiguous
    train/val/test slices according to `pct_train_val_test`, optionally pads them
    (`pad_data`), builds initial inputs for the validation/test splits and stores
    everything on the module. Other stages do nothing.

    Args:
        stage (str, optional): The current stage of training ('fit' or 'predict'). Defaults to None.
    '''
    if (stage != 'fit') or self.predicting:
      return

    split_names = [self.time_name, 'steps'] + self.input_output_names

    # Split the data into contiguous slices.
    train_len = int(self.pct_train_val_test[0] * self.data_len)
    val_len = int(self.pct_train_val_test[1] * self.data_len)

    train_data = {name: self.data[name][:train_len] for name in split_names}

    if self.pct_train_val_test[1] > 0:
      val_data = {name: self.data[name][train_len:(train_len + val_len)] for name in split_names}
    else:
      val_data = {}

    if self.pct_train_val_test[2] > 0:
      test_data = {name: self.data[name][(train_len + val_len):] for name in split_names}
      test_len = len(next(iter(test_data.values())))
    else:
      test_data, test_len = {}, 0

    self.train_len, self.val_len, self.test_len = train_len, val_len, test_len

    train_init_input, val_init_input, test_init_input = None, None, None

    if self.pad_data and (self.start_step > 0):
      pad_dim = self.start_step

      # Extend the training steps by `pad_dim` and prepend `pad_dim` zero rows to every
      # feature (padding along dim -2, the row dim for 2-D features).
      train_data['steps'] = torch.cat((train_data['steps'],
                                       torch.arange(1, 1 + pad_dim).to(device=self.device, dtype=torch.long) + train_data['steps'][-1]), 0)
      for name in self.input_output_names:
        train_data[name] = torch.nn.functional.pad(train_data[name], (0, 0, pad_dim, 0), mode='constant', value=0)

      if len(val_data) > 0:
        # Validation is prefixed with the last `pad_dim` training rows.
        train_steps = train_data['steps']
        val_data['steps'] = torch.cat((train_steps[-pad_dim:],
                                       torch.arange(1, 1 + len(val_data['steps']), device=train_steps.device) + train_steps[-1]))
        for name in self.input_output_names:
          val_data[name] = torch.cat((train_data[name][-pad_dim:], val_data[name]), 0)

        val_init_input = [train_data[name][-(pad_dim + 1)] for name in self.input_names]

      if len(test_data) > 0:
        # Test is prefixed with the last `pad_dim` rows of the preceding split.
        data_ = val_data if len(val_data) > 0 else train_data
        prev_steps = data_['steps']
        test_data['steps'] = torch.cat((prev_steps[-pad_dim:],
                                        torch.arange(1, 1 + len(test_data['steps']), device=prev_steps.device) + prev_steps[-1]))
        for name in self.input_output_names:
          test_data[name] = torch.cat((data_[name][-pad_dim:], test_data[name]), 0)

        test_init_input = [data_[name][-(pad_dim + 1)] for name in self.input_names]

    else:
      # No padding: when some outputs are also inputs, each later split starts from
      # the last row of the split before it.
      data_ = val_data if len(val_data) > 0 else train_data
      if self.has_ar:
        if len(val_data) > 0:
          val_init_input = [train_data[name][-1] for name in self.input_names]
        if len(test_data) > 0:
          test_init_input = [data_[name][-1] for name in self.input_names]

    if val_init_input is not None:
      val_init_input = torch.cat(val_init_input, -1)
    if test_init_input is not None:
      test_init_input = torch.cat(test_init_input, -1)

    # Stored for every configuration (previously only when padding was enabled).
    self.train_data, self.val_data, self.test_data = train_data, val_data, test_data
    self.train_init_input, self.val_init_input, self.test_init_input = train_init_input, val_init_input, test_init_input

  def _resolve_batch_size(self, data):
    '''Batch size for one split: 1 for an empty split, the whole split when `batch_size == -1`.'''
    if len(data) == 0:
      return 1
    if self.batch_size == -1:
      # DataLoader rejects batch_size=0, so a present-but-empty split still gets 1.
      return max(1, len(data['steps']))
    return self.batch_size

  def _make_sequence_dataloader(self, data, batch_size, init_input):
    '''Wrap one data split in a `SequenceDataloader` configured from this module.'''
    return SequenceDataloader(input_names=self.input_names,
                              output_names=self.output_names,
                              step_name='steps',
                              data=data,
                              batch_size=batch_size,
                              input_len=self.input_len,
                              output_len=self.output_len,
                              shift=self.shift,
                              stride=self.stride,
                              init_input=init_input,
                              print_summary=self.print_summary,
                              device=self.device,
                              dtype=self.dtype)

  def train_dataloader(self):
    '''
    Returns the training dataloader.

    Returns:
        torch.utils.data.DataLoader: The training dataloader, or None while predicting.
    '''
    if self.predicting:
      return None

    self.train_batch_size = self._resolve_batch_size(self.train_data)
    self.train_dl = self._make_sequence_dataloader(self.train_data, self.train_batch_size, self.train_init_input)
    self.num_train_batches = self.train_dl.num_batches

    self.train_output_mask = self.train_dl.output_mask
    self.train_input_window_idx, self.train_output_window_idx = self.train_dl.input_window_idx, self.train_dl.output_window_idx
    self.train_total_input_len, self.train_total_output_len = self.train_dl.total_input_len, self.train_dl.total_output_len

    self.train_unique_output_window_idx = self.train_dl.unique_output_window_idx

    print("Training Dataloader Created.")

    return self.train_dl.dl

  def val_dataloader(self):
    '''
    Returns the validation dataloader.

    Returns:
        torch.utils.data.DataLoader: The validation dataloader, or None while predicting.
    '''
    if self.predicting:
      return None

    self.val_batch_size = self._resolve_batch_size(self.val_data)
    self.val_dl = self._make_sequence_dataloader(self.val_data, self.val_batch_size, self.val_init_input)
    self.num_val_batches = self.val_dl.num_batches

    self.val_output_mask = self.val_dl.output_mask
    self.val_input_window_idx, self.val_output_window_idx = self.val_dl.input_window_idx, self.val_dl.output_window_idx
    self.val_total_input_len, self.val_total_output_len = self.val_dl.total_input_len, self.val_dl.total_output_len

    self.val_unique_output_window_idx = self.val_dl.unique_output_window_idx

    return self.val_dl.dl

  def test_dataloader(self):
    '''
    Returns the test dataloader.

    The dataloader is only built when `self.predicting` is True and no `test_dl` exists yet.
    NOTE(review): any other call (including a second call while predicting) returns None.

    Returns:
        torch.utils.data.DataLoader: The test dataloader, or None.
    '''
    if not self.predicting or hasattr(self, 'test_dl'):
      return None

    self.test_batch_size = self._resolve_batch_size(self.test_data)
    self.test_dl = self._make_sequence_dataloader(self.test_data, self.test_batch_size, self.test_init_input)
    self.num_test_batches = self.test_dl.num_batches

    self.test_output_mask = self.test_dl.output_mask
    self.test_input_window_idx, self.test_output_window_idx = self.test_dl.input_window_idx, self.test_dl.output_window_idx
    self.test_total_input_len, self.test_total_output_len = self.test_dl.total_input_len, self.test_dl.total_output_len

    self.test_unique_output_window_idx = self.test_dl.unique_output_window_idx

    return self.test_dl.dl


In [None]:
class SequenceModule(pl.LightningModule):
  def __init__(self,
               model,
               opt, loss_fn, metric_fn = None,
               constrain = False, penalize = False,
               track = False,
               model_dir = None):
    '''
    LightningModule that trains `model` with manual optimization and carries hidden
    states between batches.

    Args:
        model: The sequence model. This module calls `model.forward(...)` (expected to return
            `(output, hiddens)`) and reads `num_inputs`, `base_type` and `output_size`; it also calls
            `constrain()` / `penalize()` when those flags are set.
        opt (torch.optim.Optimizer): Optimizer returned by `configure_optimizers`.
        loss_fn: Callable returning an unreduced loss whose last dimension is split by
            `model.output_size`; must expose a `name` attribute (used as a history key prefix).
        metric_fn (optional): Object with a `name` attribute, consulted by `plot_history`. Defaults to None.
        constrain (bool): Call `model.constrain()` before every training step.
        penalize (bool): Add `model.penalize()` to the training loss.
        track (bool): Record per-step training losses/parameters and per-epoch validation losses.
        model_dir (str, optional): Stored as `self.model_dir`.
    '''
    super().__init__()

    # The optimizer is stepped by hand in `training_step` so hidden states can be
    # carried (and detached) across batches.
    self.automatic_optimization = False

    self.model = model

    self.opt, self.loss_fn, self.metric_fn = opt, loss_fn, metric_fn

    self.constrain, self.penalize = constrain, penalize

    self.train_history, self.val_history = None, None
    # Both counters exist from the start; `on_train_batch_end` increments the step
    # counter even when tracking is disabled.
    self.current_train_step = 0
    self.current_val_epoch = 0

    # per-step loss buffers, reduced at the end of each epoch
    self.train_step_loss = []
    self.val_step_loss = []
    self.test_step_loss = []

    self.hiddens = None

    self.track = track

    self.model_dir = model_dir

  def forward(self,
              input,
              hiddens = None,
              steps = None,
              target = None,
              output_window_idx = None,
              output_mask = None,
              output_input_idx = None, input_output_idx = None,
              encoder_output= None):
    '''
    Delegate to `self.model.forward`, passing every argument by keyword.

    Returns:
        tuple: `(output, hiddens)` as returned by the wrapped model.
    '''
    model_kwargs = dict(input = input,
                        steps = steps,
                        hiddens = hiddens,
                        target = target,
                        output_window_idx = output_window_idx,
                        output_mask = output_mask,
                        output_input_idx = output_input_idx,
                        input_output_idx = input_output_idx,
                        encoder_output = encoder_output)

    prediction, next_hiddens = self.model.forward(**model_kwargs)
    return prediction, next_hiddens

  ## Configure optimizers
  def configure_optimizers(self):
    '''Give Lightning the optimizer that was supplied at construction time.'''
    optimizer = self.opt
    return optimizer
  ##

  ## train model
  def on_train_start(self):
    '''Record the wall-clock time (from `time.time()`) at which training started.'''
    started_at = time.time()
    self.run_time = started_at

  def training_step(self, batch, batch_idx):
    '''
    Run one manually-optimized training step.

    Args:
        batch (tuple): `(input_batch, output_batch, steps_batch, batch_size)`; only the first
            `batch_size` rows of each tensor are used.
        batch_idx (int): Index of the batch (unused).

    Returns:
        dict: `{"loss": loss}` where `loss` holds one summed loss per output
        (`on_train_batch_end` detaches it before logging).
    '''
    # constrain model if desired
    if self.constrain: self.model.constrain()

    # unpack batch
    input_batch, output_batch, steps_batch, batch_size = batch

    def fit_to_batch(state):
      # Keep the first `batch_size` entries of dim 1, or zero-pad dim 1 up to `batch_size`.
      if state.shape[1] >= batch_size:
        return state[:, :batch_size].contiguous()
      return torch.nn.functional.pad(state.contiguous(), pad=(0, 0, 0, batch_size - state.shape[1]), mode='constant', value=0)

    # match carried recurrent hiddens to the current batch size
    if self.hiddens is not None:
      for i in range(self.model.num_inputs):
        base_type = self.model.base_type[i]
        if (base_type in ['gru', 'lstm', 'lru']) and (self.hiddens[i] is not None):
          if base_type == 'lstm':
            self.hiddens[i] = [fit_to_batch(s) for s in self.hiddens[i]]
          else:
            self.hiddens[i] = fit_to_batch(self.hiddens[i])

    input_batch = input_batch[:batch_size]
    output_batch = output_batch[:batch_size]
    steps_batch = steps_batch[:batch_size]

    datamodule = self.trainer.datamodule
    output_mask = datamodule.train_output_mask

    # forward pass (target provided during training)
    output_pred_batch, self.hiddens = self.forward(input = input_batch,
                                                   steps = steps_batch,
                                                   hiddens = self.hiddens,
                                                   target = output_batch,
                                                   output_window_idx = datamodule.train_output_window_idx,
                                                   output_input_idx = datamodule.output_input_idx,
                                                   input_output_idx = datamodule.input_output_idx,
                                                   output_mask = output_mask)

    # summed loss for each output
    loss = self.loss_fn(output_pred_batch*output_mask, output_batch*output_mask)
    loss = torch.stack([l.sum() for l in loss.split(self.model.output_size, -1)], 0)

    # add penalty loss if desired
    if self.penalize: loss += self.model.penalize()

    # Use Lightning's optimizer wrapper and `manual_backward` so precision/strategy
    # plugins apply and the global step counter advances.
    opt = self.optimizers()
    opt.zero_grad()
    self.manual_backward(loss.sum())
    opt.step()

    # Store a detached copy; keeping the graph-attached tensor for the whole epoch
    # would retain each step's autograd graph.
    self.train_step_loss.append(loss.detach())

    return {"loss": loss}

  def on_train_batch_start(self, batch, batch_idx):
    '''
    Detach carried recurrent hidden states before each training batch so gradients
    do not flow back into earlier batches.
    '''
    if self.hiddens is None:
      return

    recurrent_types = ['gru', 'lstm', 'lru']
    for idx in range(self.model.num_inputs):
      kind = self.model.base_type[idx]
      state = self.hiddens[idx]
      if (kind not in recurrent_types) or (state is None):
        continue
      # LSTM states are a collection of tensors; the other types carry a single tensor.
      self.hiddens[idx] = [s.detach() for s in state] if kind == 'lstm' else state.detach()

  def on_train_batch_end(self, outputs, batch, batch_idx):
    '''
    Log the summed loss of the batch that just finished. When `self.track` is set, also
    append the batch's per-output losses and a snapshot of every trainable parameter
    to `self.train_history`.

    Args:
        outputs (dict): The dict returned by `training_step`; `outputs['loss']` holds one
            loss value per output.
        batch: The batch passed to `training_step` (unused).
        batch_idx (int): Index of the batch (unused).
    '''
    # per-output loss of the finished batch, cut from the autograd graph
    train_step_loss = outputs['loss'].detach()

    # log and display the sum of the per-output losses
    self.log('train_step_loss', train_step_loss.sum(), on_step = True, prog_bar = True)

    # The counter used to be created only when tracking starts, so with track=False
    # the increment at the bottom raised AttributeError.
    if not hasattr(self, 'current_train_step'):
      self.current_train_step = 0

    if self.track:
      datamodule = self.trainer.datamodule
      loss_names = [self.loss_fn.name + '_' + name for name in datamodule.output_names]

      if self.train_history is None:
        # first tracked batch: restart the counter and allocate empty histories
        self.current_train_step = 0
        self.train_history = {'steps': torch.empty((0, 1)).to(device = train_step_loss.device,
                                                              dtype = torch.long)}
        for loss_name in loss_names:
          self.train_history[loss_name] = torch.empty((0, 1)).to(train_step_loss)

        for name, param in self.model.named_parameters():
          if param.requires_grad:
            self.train_history[name] = torch.empty((0, param.numel())).to(param)

      # Record every batch, including the first one (previously dropped, which shifted
      # every later epoch boundary by one step). 'steps' keeps its integer dtype.
      steps = self.train_history['steps']
      step_row = torch.full((1, 1), self.current_train_step, dtype = steps.dtype, device = steps.device)
      self.train_history['steps'] = torch.cat((steps, step_row), 0)

      for i, loss_name in enumerate(loss_names):
        history = self.train_history[loss_name]
        self.train_history[loss_name] = torch.cat((history, train_step_loss[i].reshape(1, 1).to(history)), 0)

      for name, param in self.model.named_parameters():
        if param.requires_grad:
          history = self.train_history[name]
          self.train_history[name] = torch.cat((history, param.detach().reshape(1, -1).to(history)), 0)

    self.current_train_step += 1

  def on_train_epoch_start(self):
    '''Start every training epoch with no carried hidden state and an empty loss buffer.'''
    self.hiddens, self.train_step_loss = None, list()

  def on_train_epoch_end(self):
    '''
    Log the epoch's training loss (per-output mean over the epoch's steps, summed over
    outputs) and empty the per-step buffer.
    '''
    per_step = torch.stack(self.train_step_loss)
    epoch_mean = per_step.mean(0)

    self.log('train_epoch_loss', epoch_mean.sum(), on_epoch = True, prog_bar = True)

    self.train_step_loss.clear()
  ## End of Training

  ## Validate Model
  def validation_step(self, batch, batch_idx):
    '''
    Evaluate one validation batch, carrying hidden states from the previous batch.

    Args:
        batch (tuple): `(input_batch, output_batch, steps_batch, batch_size)`; only the first
            `batch_size` rows of each tensor are used.
        batch_idx (int): Index of the batch (unused).

    Returns:
        dict: `{"loss": loss}` with one summed loss per output (previously the dict was
        built but never returned).
    '''
    # unpack batch
    input_batch, output_batch, steps_batch, batch_size = batch

    def fit_to_batch(state):
      # Keep the first `batch_size` entries of dim 1, or zero-pad dim 1 up to `batch_size`.
      if state.shape[1] >= batch_size:
        return state[:, :batch_size].contiguous()
      return torch.nn.functional.pad(state.contiguous(), pad=(0, 0, 0, batch_size - state.shape[1]), mode='constant', value=0)

    # match carried recurrent hiddens to the current batch size
    if self.hiddens is not None:
      for i in range(self.model.num_inputs):
        base_type = self.model.base_type[i]
        if (base_type in ['gru', 'lstm', 'lru']) and (self.hiddens[i] is not None):
          if base_type == 'lstm':
            self.hiddens[i] = [fit_to_batch(s) for s in self.hiddens[i]]
          else:
            self.hiddens[i] = fit_to_batch(self.hiddens[i])

    input_batch = input_batch[:batch_size]
    output_batch = output_batch[:batch_size]
    steps_batch = steps_batch[:batch_size]

    datamodule = self.trainer.datamodule
    output_mask = datamodule.val_output_mask

    # forward pass without a target
    output_pred_batch, self.hiddens = self.forward(input = input_batch,
                                                   steps = steps_batch,
                                                   hiddens = self.hiddens,
                                                   target = None,
                                                   output_window_idx = datamodule.val_output_window_idx,
                                                   output_input_idx = datamodule.output_input_idx,
                                                   input_output_idx = datamodule.input_output_idx,
                                                   output_mask = output_mask)

    # summed loss for each output
    loss = self.loss_fn(output_pred_batch*output_mask, output_batch*output_mask)
    loss = torch.stack([l.sum() for l in loss.split(self.model.output_size, -1)], 0)

    self.val_step_loss.append(loss)

    return {"loss": loss}

  def on_validation_epoch_end(self):
    '''
    Log the epoch's validation loss (per-output mean over batches, summed over outputs).

    When `self.track` is set, the first call only allocates `self.val_history`; every
    later call appends the current epoch index and the per-output losses. The buffer
    is cleared and `current_val_epoch` advanced on every call.
    '''
    stacked = torch.stack(self.val_step_loss)
    epoch_loss = stacked.mean(0)

    self.log('val_epoch_loss', epoch_loss.sum(), on_step = False, on_epoch = True, prog_bar = True)

    if self.track:
      datamodule = self.trainer.datamodule
      loss_keys = [self.loss_fn.name + '_' + datamodule.output_names[k] for k in range(datamodule.num_outputs)]

      if self.val_history is not None:
        epoch_row = torch.tensor(self.current_val_epoch).reshape(1, 1).to(epoch_loss)
        self.val_history['epochs'] = torch.cat((self.val_history['epochs'], epoch_row), 0)

        for k, key in enumerate(loss_keys):
          loss_row = epoch_loss[k].cpu().reshape(1, 1).to(epoch_loss)
          self.val_history[key] = torch.cat((self.val_history[key], loss_row), 0)
      else:
        history = {'epochs': torch.empty((0, 1)).to(device = epoch_loss.device, dtype = torch.long)}
        for key in loss_keys:
          history[key] = torch.empty((0, 1)).to(epoch_loss)
        self.val_history = history

    self.val_step_loss.clear()

    self.current_val_epoch += 1
  ## End of validation

  ## Test Model
  def test_step(self, batch, batch_idx):
    '''
    Evaluate one test batch, carrying hidden states from the previous batch.

    Args:
        batch (tuple): `(input_batch, output_batch, steps_batch, batch_size)`; only the first
            `batch_size` rows of each tensor are used.
        batch_idx (int): Index of the batch (unused).

    Returns:
        dict: `{"loss": loss}` with one summed loss per output (previously the dict was
        built but never returned).
    '''
    # unpack batch
    input_batch, output_batch, steps_batch, batch_size = batch

    def fit_to_batch(state):
      # Keep the first `batch_size` entries of dim 1, or zero-pad dim 1 up to `batch_size`.
      if state.shape[1] >= batch_size:
        return state[:, :batch_size].contiguous()
      return torch.nn.functional.pad(state.contiguous(), pad=(0, 0, 0, batch_size - state.shape[1]), mode='constant', value=0)

    # match carried recurrent hiddens to the current batch size
    if self.hiddens is not None:
      for i in range(self.model.num_inputs):
        base_type = self.model.base_type[i]
        if (base_type in ['gru', 'lstm', 'lru']) and (self.hiddens[i] is not None):
          if base_type == 'lstm':
            self.hiddens[i] = [fit_to_batch(s) for s in self.hiddens[i]]
          else:
            self.hiddens[i] = fit_to_batch(self.hiddens[i])

    input_batch = input_batch[:batch_size]
    output_batch = output_batch[:batch_size]
    steps_batch = steps_batch[:batch_size]

    datamodule = self.trainer.datamodule
    output_mask = datamodule.test_output_mask

    # forward pass without a target
    output_pred_batch, self.hiddens = self.forward(input = input_batch,
                                                   steps = steps_batch,
                                                   hiddens = self.hiddens,
                                                   target = None,
                                                   output_window_idx = datamodule.test_output_window_idx,
                                                   output_input_idx = datamodule.output_input_idx,
                                                   input_output_idx = datamodule.input_output_idx,
                                                   output_mask = output_mask)

    # summed loss for each output
    loss = self.loss_fn(output_pred_batch*output_mask, output_batch*output_mask)
    loss = torch.stack([l.sum() for l in loss.split(self.model.output_size, -1)], 0)

    self.test_step_loss.append(loss)

    return {"loss": loss}

  def on_test_epoch_end(self):
    '''
    Log the test loss (per-output mean over batches, summed over outputs) after
    emptying the per-batch buffer.
    '''
    buffer = self.test_step_loss
    mean_loss = torch.stack(buffer).mean(0)
    buffer.clear()

    self.log('test_epoch_loss', mean_loss.sum(), on_epoch = True, prog_bar = True)
  ## End of Testing

  ## plot history
  def plot_history(self, history = None, plot_train_history_by = 'epochs'):
    '''
    Plot tracked training history, one subplot per key, with validation history overlaid
    when it exists and the x-axis is epochs.

    Args:
        history (list of str, optional): Keys of `self.train_history` to plot. Defaults to every
            per-output loss key (`<loss_fn.name>_<output name>`); if none are found, falls back
            to `[self.loss_fn.name]`.
        plot_train_history_by (str): 'epochs' averages the per-step history over consecutive
            chunks of `len(train_dl.dl)` steps; any other value plots raw steps.
    '''
    if history is None:
      prefix = self.loss_fn.name + '_'
      # `train_history` stores losses under '<loss name>_<output name>', never under the
      # bare loss name, so the old default always raised KeyError.
      history = [key for key in self.train_history if key.startswith(prefix)] or [self.loss_fn.name]

    def to_numpy(values):
      # matplotlib cannot consume CUDA tensors; move them to host memory first.
      return values.detach().cpu().numpy() if isinstance(values, torch.Tensor) else np.asarray(values)

    if plot_train_history_by == 'epochs':
      num_batches = len(self.trainer.datamodule.train_dl.dl)
      num_epochs = len(self.train_history['steps']) // num_batches
      train_history = {'epochs': torch.arange(num_epochs).to(dtype = torch.long)}
      for key, values in self.train_history.items():
        if key != 'steps':
          chunk_means = torch.cat([chunk.mean(0, keepdim = True) for chunk in values.split(num_batches, 0)], 0)
          train_history[key] = chunk_means[:num_epochs]
      x_label = 'epochs'
    else:
      x_label = 'steps'
      train_history = self.train_history

    # `metric_fn` is optional; only consult its name when it exists.
    val_names = [self.loss_fn.name] + ([self.metric_fn.name] if self.metric_fn is not None else [])

    num_params = len(history)
    fig = plt.figure(figsize = (5, 5*num_params))
    for ax_i, param in enumerate(history, start = 1):
      ax = fig.add_subplot(num_params, 1, ax_i)
      ax.plot(to_numpy(train_history[x_label]), to_numpy(train_history[param]), label = 'Train')

      plot_val = (self.val_history is not None
                  and x_label == 'epochs'
                  and any(name in param for name in val_names)
                  and param in self.val_history)
      if plot_val:
        N = min(self.val_history[x_label].shape[0], self.val_history[param].shape[0])
        ax.plot(to_numpy(self.val_history[x_label][:N]), to_numpy(self.val_history[param][:N]), label = 'Val')

      ax.set_title(param)
      ax.set_ylabel(param)
      ax.legend()
      # grid on every subplot (plt.grid() only affected the last one)
      ax.grid(True)
  ##

  ## Prediction
  def predict_step(self, batch, batch_idx):
    '''
    Runs the model on a single prediction batch.

    Args:
        batch (tuple): (input_batch, output_batch, steps_batch, batch_size); only the
                       first `batch_size` samples of each tensor are used.
        batch_idx (int): Index of the batch (unused).

    Returns:
        tuple: (output_batch, output_pred_batch, output_steps_batch), where
               output_steps_batch holds the last `output_batch.shape[1]` steps of
               each sample.

    Side effects:
        Replaces `self.hiddens` with the hidden state returned by `forward`, so it is
        carried into the next batch.
    '''
    # unpack batch
    input_batch, output_batch, steps_batch, batch_size = batch

    # resize carried-over recurrent hidden states to this batch's size by truncating or
    # zero-padding dim 1 (assumes hidden tensors are 3-D with the batch on dim 1 - TODO confirm)
    if self.hiddens is not None:
      for i in range(self.model.num_inputs):
        if (self.model.base_type[i] in ['gru', 'lstm', 'lru']) and (self.hiddens[i] is not None):
          if self.model.base_type[i] == 'lstm':
            # an LSTM state is a sequence of tensors; each one is resized
            if self.hiddens[i][0].shape[1] >= batch_size:
              self.hiddens[i] = [s[:, :batch_size].contiguous() for s in self.hiddens[i]]
            else:
              self.hiddens[i] = [torch.nn.functional.pad(s.contiguous(), pad=(0, 0, 0, batch_size-s.shape[1]), mode='constant', value=0) for s in self.hiddens[i]]
          else:
            if self.hiddens[i].shape[1] >= batch_size:
              self.hiddens[i] = self.hiddens[i][:, :batch_size].contiguous()
            else:
              self.hiddens[i] = torch.nn.functional.pad(self.hiddens[i].contiguous(), pad=(0, 0, 0, batch_size-self.hiddens[i].shape[1]), mode='constant', value=0)

    input_batch = input_batch[:batch_size]
    output_batch = output_batch[:batch_size]
    steps_batch = steps_batch[:batch_size]

    output_len = output_batch.shape[1]

    # forward pass; mask and window index are set per split by `predict`
    output_pred_batch, self.hiddens = self.forward(input = input_batch,
                                                   steps = steps_batch,
                                                   hiddens = self.hiddens,
                                                   target = None,
                                                   output_window_idx = self.predict_output_window_idx,
                                                   output_input_idx = self.trainer.datamodule.output_input_idx,
                                                   input_output_idx = self.trainer.datamodule.input_output_idx,
                                                   output_mask = self.predict_output_mask)

    # the output window covers the last `output_len` steps of each sample
    output_steps_batch = steps_batch[:, -output_len:]

    return output_batch, output_pred_batch, output_steps_batch

  def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx = 0):
    '''
    Buffers the results of one `predict_step` call.

    `outputs` is the (target, prediction, steps) tuple returned by `predict_step`; each
    element is appended to its per-epoch list so `on_predict_epoch_end` can concatenate
    them.
    '''
    target, prediction, steps = outputs[0], outputs[1], outputs[2]
    self.step_target.append(target)
    self.output_pred_batch.append(prediction)
    self.output_steps_batch.append(steps)

  def on_predict_epoch_end(self):
    '''
    Concatenates the buffered prediction batches along dim 0.

    Sets `self.target`, `self.prediction` and `self.output_steps`, then empties the
    per-batch buffers in place.
    '''
    buffers = (self.step_target, self.output_pred_batch, self.output_steps_batch)

    self.target = torch.cat(buffers[0], dim = 0)
    self.prediction = torch.cat(buffers[1], dim = 0)
    self.output_steps = torch.cat(buffers[2], dim = 0)

    for buffer in buffers:
      buffer.clear()

  def on_predict_epoch_start(self):
    '''Starts every prediction epoch with fresh, empty per-batch buffers.'''
    self.output_pred_batch = []
    self.step_target = []
    self.output_steps_batch = []

  def predict(self,
              reduction = 'mean',
              baseline_model = None):
    '''
    Generates predictions for the training, validation and testing data.

    For each split, the datamodule's dataloader is run through `trainer.predict`. The
    windowed outputs are then reduced to one row per step with `generate_reduced_output`
    (which applies the datamodule's inverse transforms). Finally, per-output actuals,
    predictions, losses and (when `metric_fn` is set) metrics are collected.

    Args:
        reduction (str): How outputs that share a step are combined ('mean' or
                         'median'); forwarded to `generate_reduced_output`.
        baseline_model (callable, optional): Called on each split's reduced target to
                                             produce a baseline prediction.

    Side effects:
        Sets `self.train_prediction_data`, `self.val_prediction_data` and
        `self.test_prediction_data` (the latter two are None when the split has no
        batches), as well as `self.baseline_model`. The datamodule's `predicting` flag
        and the trainer's progress bar are restored even if prediction raises.
    '''
    dm = self.trainer.datamodule
    self.baseline_model = baseline_model

    dm.predicting = True
    self.trainer.enable_progress_bar = False

    try:
      # number of leading training samples to drop: start_step when the data was padded, else 0
      pad_dim = dm.start_step*int(dm.pad_data)

      val_prediction, val_target, val_time, val_baseline_pred = None, None, None, None
      test_prediction, test_target, test_time, test_baseline_pred = None, None, None, None

      with torch.no_grad():
        # training data
        train_prediction, train_target, train_baseline_pred = self._predict_split(dm, dm.train_dl.dl,
                                                                                  dm.train_output_mask,
                                                                                  dm.train_output_window_idx,
                                                                                  reduction,
                                                                                  pad_dim = pad_dim)
        train_time = dm.train_data[dm.time_name][pad_dim:]

        # validation data
        if len(dm.val_dl.dl) > 0:
          val_prediction, val_target, val_baseline_pred = self._predict_split(dm, dm.val_dl.dl,
                                                                              dm.val_output_mask,
                                                                              dm.val_output_window_idx,
                                                                              reduction)
          val_time = dm.val_data[dm.time_name]

        # testing data (the test dataloader is built on demand if it does not exist yet)
        if not hasattr(dm, 'test_dl'):
          dm.test_dataloader()
        if len(dm.test_dl.dl) > 0:
          test_prediction, test_target, test_baseline_pred = self._predict_split(dm, dm.test_dl.dl,
                                                                                 dm.test_output_mask,
                                                                                 dm.test_output_window_idx,
                                                                                 reduction)
          test_time = dm.test_data[dm.time_name]

      train_prediction_data = self._build_prediction_data(dm, train_time, train_target, train_prediction, train_baseline_pred)
      val_prediction_data = None if val_prediction is None else \
                            self._build_prediction_data(dm, val_time, val_target, val_prediction, val_baseline_pred)
      test_prediction_data = None if test_prediction is None else \
                             self._build_prediction_data(dm, test_time, test_target, test_prediction, test_baseline_pred)

      self.train_prediction_data, self.val_prediction_data, self.test_prediction_data = train_prediction_data, val_prediction_data, test_prediction_data

    finally:
      self.trainer.enable_progress_bar = True
      dm.predicting = False

  def _predict_split(self, dm, dataloader, output_mask, output_window_idx, reduction, pad_dim = 0):
    '''Runs `trainer.predict` on one split and returns its reduced (prediction, target, baseline prediction).'''
    # predict_step reads these to mask and window the model output for this split
    self.predict_output_mask = output_mask
    self.predict_output_window_idx = output_window_idx

    self.trainer.predict(self, dataloader)

    # on_predict_epoch_end stores the concatenated batches on self; drop the leading `pad_dim` samples
    self.prediction, self.target, self.output_steps = self.prediction[pad_dim:], self.target[pad_dim:], self.output_steps[pad_dim:]

    prediction, _ = self.generate_reduced_output(self.prediction, self.output_steps,
                                                 reduction = reduction, transforms = dm.transforms)
    target, _ = self.generate_reduced_output(self.target, self.output_steps,
                                             reduction = reduction, transforms = dm.transforms)

    baseline_pred = self.baseline_model(target) if self.baseline_model is not None else None

    return prediction, target, baseline_pred

  def _build_prediction_data(self, dm, time, target, prediction, baseline_pred):
    '''
    Splits reduced target/prediction columns per output and scores them.

    Returns a dict with the time column under `dm.time_name` and, for every output
    `name`: `{name}_actual`, `{name}_prediction`, `{name}_{loss}`,
    `{name}_baseline_prediction`, `{name}_baseline_{loss}`, plus `{name}_{metric}` and
    `{name}_baseline_{metric}` when a metric function is configured. Baseline entries
    are None when no baseline prediction was made.
    '''
    loss_name = self.loss_fn.name
    metric_name = self.metric_fn.name if self.metric_fn is not None else None

    def score(name, pred, actual):
      # a leading singleton dim is added; Loss is evaluated with dims=(0, 1)
      return Loss(name, dims = (0, 1))(pred.unsqueeze(0), actual.unsqueeze(0))

    data = {dm.time_name: time}

    j = 0
    for i, output_name in enumerate(dm.output_names):
      cols = slice(j, j + dm.output_size[i])
      target_i, prediction_i = target[:, cols], prediction[:, cols]

      data[f"{output_name}_actual"] = target_i
      data[f"{output_name}_prediction"] = prediction_i
      data[f"{output_name}_{loss_name}"] = score(loss_name, prediction_i, target_i)
      if metric_name is not None:
        data[f"{output_name}_{metric_name}"] = score(metric_name, prediction_i, target_i)

      baseline_i, baseline_loss_i, baseline_metric_i = None, None, None
      if baseline_pred is not None:
        baseline_i = baseline_pred[:, cols]
        baseline_loss_i = score(loss_name, baseline_i, target_i)
        if metric_name is not None:
          baseline_metric_i = score(metric_name, baseline_i, target_i)

      data[f"{output_name}_baseline_prediction"] = baseline_i
      data[f"{output_name}_baseline_{loss_name}"] = baseline_loss_i
      if metric_name is not None:
        data[f"{output_name}_baseline_{metric_name}"] = baseline_metric_i

      j += dm.output_size[i]

    return data

  ##
  def plot_predictions(self,
                       output_feature_units = None,
                       include_baseline = False):
    '''
    Plots actual vs. predicted values for every output feature across the train, val
    and test data collected by `predict`.

    The grid has one column per output and one row per feature (up to the largest
    output size); cells beyond an output's size are hidden.

    Args:
        output_feature_units (dict, optional): Maps an output name to a sequence of unit
                                               labels, one per feature, used in y-labels.
        include_baseline (bool): Also draw the baseline prediction and its scores, when
                                 `predict` produced one.

    The figure is stored in `self.actual_prediction_plot`.
    '''
    dm = self.trainer.datamodule
    time_name = dm.time_name
    output_names = dm.output_names
    feature_names = dm.feature_names
    output_size = dm.output_size
    num_outputs = len(output_names)
    max_output_size = np.max(output_size)

    rows, cols = max_output_size, num_outputs
    # squeeze=False keeps `ax` 2-D (rows x cols), so ax[f, i] is valid for every layout
    fig, ax = plt.subplots(rows, cols, figsize = (10*num_outputs, 5*max_output_size), squeeze = False)

    # (label, data, span color, label the line styles?) - only the training split labels the lines
    splits = [('Train', self.train_prediction_data, 'gray', True)]
    if self.val_prediction_data is not None:
      splits.append(('Val', self.val_prediction_data, 'blue', False))
    if self.test_prediction_data is not None:
      splits.append(('Test', self.test_prediction_data, 'red', False))

    for i, output_name in enumerate(output_names):

      # per-feature names are only used for multi-feature outputs listed in feature_names
      names_i = None
      if (feature_names is not None) and (output_size[i] > 1) and (output_name in feature_names):
        names_i = feature_names[output_name]

      units_i = None
      if (output_feature_units is not None) and (output_name in output_feature_units):
        units_i = output_feature_units[output_name]

      for f in range(output_size[i]):
        ax_if = ax[f, i]
        feature_name_if = names_i[f] if names_i is not None else None
        units_if = units_i[f] if units_i is not None else None

        for split_label, data, facecolor, label_lines in splits:
          self._plot_prediction_split(ax_if, data, data[time_name], output_name, f,
                                      split_label, facecolor, include_baseline, label_lines)

        if (f == 0) and (feature_names is not None):
          ax_if.set_title(output_name)
        if f == output_size[i] - 1:
          ax_if.set_xlabel(f"Time [{dm.time_unit}]")

        name_if = output_name if feature_names is None else feature_name_if
        if name_if is not None:
          ylabel = f"{name_if} [{units_if}]" if units_if is not None else f"{name_if}"
        else:
          ylabel = f"[{units_if}]" if units_if is not None else None
        ax_if.set_ylabel(ylabel)

        ax_if.legend(loc='upper left', bbox_to_anchor=(1.02, 1), ncol=1)
        ax_if.grid()

    # hide the cells of outputs that have fewer features than the tallest column
    for i in range(num_outputs):
      for f in range(output_size[i], rows):
        ax[f, i].axis("off")

    fig.tight_layout()

    self.actual_prediction_plot = fig

  def _plot_prediction_split(self, ax, data, time, output_name, f, split_label, facecolor, include_baseline, label_lines):
    '''Draws one split (actual, prediction, optional baseline, shaded span with scores) for feature `f` of `output_name`.'''
    loss_name = self.loss_fn.name
    metric_name = self.metric_fn.name if self.metric_fn is not None else None

    def score(key):
      return np.round(data[key][f].item(), 2)

    def line_label(text):
      # only the first (training) split names the line styles, so the legend lists them once
      return {'label': text} if label_lines else {}

    loss = score(f"{output_name}_{loss_name}")
    metric = score(f"{output_name}_{metric_name}") if metric_name is not None else None

    ax.plot(time, data[f"{output_name}_actual"][:, f], '-k', **line_label('Actual'))
    ax.plot(time, data[f"{output_name}_prediction"][:, f], '-r', **line_label('Prediction'))

    if metric is not None:
      span_label = f"{split_label} ({loss_name} = {loss}, {metric_name} = {metric})"
    else:
      span_label = f"{split_label} ({loss_name} = {loss})"

    baseline = data.get(f"{output_name}_baseline_prediction")
    if include_baseline and (baseline is not None):
      baseline_loss = score(f"{output_name}_baseline_{loss_name}")
      ax.plot(time, baseline[:, f], '--g', linewidth = 1.0, **line_label('Baseline'))
      if metric_name is not None:
        baseline_metric = score(f"{output_name}_baseline_{metric_name}")
        span_label = span_label + f", Baseline ({loss_name} = {baseline_loss}, {metric_name} = {baseline_metric})"
      else:
        span_label = span_label + f", Baseline ({loss_name} = {baseline_loss})"

    ax.axvspan(time.min(), time.max(), facecolor = facecolor, alpha = 0.2, label = span_label)
  ##

  ## forecast
  def forecast(self, num_forecast_steps = 1, hiddens = None):
    '''
    Recursively forecasts `num_forecast_steps` steps past the last available sample.

    The last sample of the first non-empty dataloader (test, then val, then train) seeds
    the model. Each iteration shifts the input window forward by `max_output_len` steps,
    fills the fed-back input channels (`output_input_idx`) of the new positions with the
    previous prediction (`input_output_idx`), and runs the model again.

    Args:
        num_forecast_steps (int): Number of future steps to return.
        hiddens: Initial hidden state passed to `forward`.

    Returns:
        tuple: (forecast_reduced, forecast_steps_reduced) as returned by
               `generate_reduced_output` with the datamodule's transforms.

    Raises:
        ValueError: If no test/val/train dataloader yields a batch.
    '''
    dm = self.trainer.datamodule

    with torch.no_grad():
      # pick the final batch of the first split that actually has data
      last_sample = None
      for loader in (getattr(dm, 'test_dl', None), getattr(dm, 'val_dl', None), getattr(dm, 'train_dl', None)):
        if loader is None:
          continue
        for batch in loader.dl:
          last_sample = batch
        if last_sample is not None:
          break
      if last_sample is None:
        raise ValueError("forecast: no non-empty dataloader is available to seed the forecast")

      input_batch, _, steps_batch, batch_size = last_sample

      # the last real (non-padding) sample, kept as a batch of one
      window, window_steps = input_batch[:batch_size][-1:], steps_batch[:batch_size][-1:]

      max_output_len = dm.max_output_len
      max_input_size, max_output_size = dm.max_input_size, dm.max_output_size
      output_mask = dm.train_output_mask
      output_input_idx, input_output_idx = dm.output_input_idx, dm.input_output_idx

      output, hiddens = self.forward(input = window,
                                     steps = window_steps,
                                     hiddens = hiddens,
                                     target = None,
                                     output_len = max_output_len,
                                     output_mask = output_mask)

      forecast = torch.empty((1, 0, max_output_size)).to(output)
      forecast_steps = torch.empty((1, 0)).to(window_steps)

      # out-of-place: `window_steps` is a view into the dataloader's batch tensor
      window_steps = window_steps + max_output_len

      while forecast.shape[1] < num_forecast_steps:

        new_input = torch.zeros((1, max_output_len, max_input_size)).to(window)

        if len(output_input_idx) > 0:
          new_input[:, :, output_input_idx] = output[:, -max_output_len:, input_output_idx]

        window = torch.cat((window[:, max_output_len:], new_input), 1)

        output, hiddens = self.forward(input = window,
                                       steps = window_steps,
                                       hiddens = hiddens,
                                       target = None,
                                       output_len = max_output_len,
                                       output_mask = output_mask)

        forecast = torch.cat((forecast, output[:, -max_output_len:]), 1)
        forecast_steps = torch.cat((forecast_steps, window_steps[:, -max_output_len:]), 1)

        window_steps = window_steps + max_output_len

      forecast, forecast_steps = forecast[:, -num_forecast_steps:], forecast_steps[:, -num_forecast_steps:]
      forecast_reduced, forecast_steps_reduced = self.generate_reduced_output(forecast, forecast_steps,
                                                                              transforms = dm.transforms)

    return forecast_reduced, forecast_steps_reduced


  ##
  def generate_reduced_output(self, output, output_steps, reduction='mean', transforms=None):
    '''
    Collapses overlapping windowed outputs into one row per step.

    Several windows can produce a value for the same step; all values that share a
    step are combined with `reduction`.

    Args:
        output (torch.Tensor): Outputs indexed as output[b, t, :]; the first
                               sum(self.model.output_size) features are used.
        output_steps (torch.Tensor): 2-D step labels where output_steps[b, t] labels
                                     output[b, t]. Entries equal to -1 are ignored.
        reduction (str): 'median' for the per-feature (lower) median; anything else
                         gives the mean.
        transforms (dict, optional): Maps output names to objects with an
                                     `inverse_transform` method, applied to each
                                     output's columns after reduction.

    Returns:
        tuple: (output_reduced, unique_output_steps). unique_output_steps holds the
               sorted unique steps excluding -1; output_reduced has one row per such
               step and sum(self.model.output_size) columns.
    '''
    total_size = int(np.sum(self.model.output_size))

    # flatten (batch, time) so row r of `values` is labelled by steps[r]
    steps = output_steps.reshape(-1)
    values = output[:, :output_steps.shape[1], :total_size].reshape(-1, total_size)

    keep = steps != -1
    values = values[keep]
    unique_output_steps, group = torch.unique(steps[keep], sorted = True, return_inverse = True)

    if len(unique_output_steps) == 0:
      output_reduced = torch.zeros((0, total_size)).to(output)
    else:
      # stable sort groups rows by step while preserving their row-major order within a group
      group_sorted, order = torch.sort(group, stable = True)
      counts = torch.bincount(group_sorted, minlength = len(unique_output_steps))
      chunks = values[order].split(counts.tolist(), 0)

      if reduction == 'median':
        output_reduced = torch.stack([chunk.median(0)[0] for chunk in chunks], 0)
      else:
        output_reduced = torch.stack([chunk.mean(0) for chunk in chunks], 0)

    # optionally map each output's columns back to the original scale
    if transforms is not None:
      j = 0
      for i in range(self.model.num_outputs):
        cols = slice(j, j + self.model.output_size[i])
        output_name_i = self.trainer.datamodule.output_names[i]
        output_reduced[:, cols] = transforms[output_name_i].inverse_transform(output_reduced[:, cols])
        j += self.model.output_size[i]

    return output_reduced, unique_output_steps

  def fit(self,
          datamodule,
          max_epochs = 20,
          callbacks = None):
    '''
    Trains the module with a freshly created `pl.Trainer`.

    Args:
        datamodule: The Lightning datamodule that supplies the dataloaders.
        max_epochs (int): Maximum number of training epochs.
        callbacks (list, optional): Lightning callbacks. `None` entries are dropped
                                    (the former default was `[None]`); a non-list value
                                    is passed to the trainer unchanged.

    On KeyboardInterrupt the model's weights are kept, and the model is moved back to
    `self.model.device` / `self.model.dtype`.
    '''
    # a None entry is not a valid Lightning callback; filter it out (an empty list becomes None)
    if isinstance(callbacks, (list, tuple)):
      callbacks = [callback for callback in callbacks if callback is not None] or None

    try:
      self.trainer = pl.Trainer(max_epochs = max_epochs,
                                accelerator = 'gpu' if self.model.device == 'cuda' else 'cpu',
                                callbacks = callbacks)

      self.trainer.fit(self,
                       datamodule = datamodule)

    except KeyboardInterrupt:
      # keep the weights trained so far, but restore the model's configured device/dtype
      state_dict = self.model.state_dict()
      self.model.to(device = self.model.device,
                    dtype = self.model.dtype)
      self.model.load_state_dict(state_dict)
