In [None]:
import boto3
from sssd.core.model_specs import setup_model
import torch
import yaml
from sssd.core.model_specs import MASK_FN
import torch.nn as nn
from sssd.data.utils import get_dataloader
from sssd.utils.utils import calc_diffusion_hyperparams
from sssd.utils.utils import sampling
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, Dict, Union

import math
from tqdm import tqdm

from sssd.core.layers.s4.s4_layer import S4Layer
from sssd.core.utils import calc_diffusion_step_embedding
from sssd.data.generator import ArDataGenerator
from sssd.data.dataloader import ArDataLoader

In [None]:
ar_coefs = [0.8]
series_length = 128
season_period = 12
demo_std = 5
demo_intercept = 100

# The same AR + seasonal process generated twice: once with an intercept term, once without.
data_with_intercept = ArDataGenerator(
    ar_coefs, series_length, std=demo_std, intercept=demo_intercept, season_period=season_period
).generate()
data_without_intercept = ArDataGenerator(
    ar_coefs, series_length, std=demo_std, season_period=season_period
).generate()

# Legend labels are derived from the actual arguments (the old labels claimed "Mean = 3").
fig, ax = plt.subplots()
ax.plot(data_with_intercept, label=f"With Intercept (intercept = {demo_intercept}, std = {demo_std})")
ax.plot(data_without_intercept, label=f"Without Intercept (std = {demo_std})")
ax.set_xlabel("Time Step")
ax.set_ylabel("Value")
ax.set_title("AR Process with and Without Intercept")
ax.legend()
plt.show()

In [None]:
# Dataset / dataloader configuration.
# NOTE: `std` and `intercept` are also read as globals by `sampling` further down.
num_series = 1024
coefficients = [0.8]
series_length = 128
std = 1
intercept = 100
season = 12
batch_size = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_workers = 4
training_rate = 0.8
seeds = [*range(num_series)]  # one seed per generated series

data_loader = ArDataLoader(
    coefficients, num_series, series_length, std, intercept, season,
    batch_size, device, num_workers, training_rate, seeds,
)

train_loader, test_loader = data_loader.train_dataloader, data_loader.test_dataloader

In [None]:
def _load_yaml_config(path):
    """Read the text file at ``path`` and parse it with ``yaml.safe_load``."""
    with open(path, "rt") as handle:
        return yaml.safe_load(handle.read())


model_config = _load_yaml_config("../configs/model.yaml")
training_config = _load_yaml_config("../configs/training.yaml")
inference_config = _load_yaml_config("../configs/inference.yaml")

In [None]:
def update_mask(batch: torch.Tensor, horizon: int = 24) -> torch.Tensor:
    """Build one forecast mask and replicate it for every series in ``batch``.

    The mask is computed once from the first series (``batch[0]``) via
    ``MASK_FN["forecast"](batch[0], horizon)``. Its two axes are swapped, and it is
    repeated along a new leading axis of size ``batch.shape[0]``.

    Args:
        batch: Batch of series. Only ``batch[0]`` is inspected by the mask function.
            In the training loop this is the un-permuted loader batch (presumably
            (batch, length, channels) -- TODO confirm against ArDataLoader).
        horizon: Forecast horizon passed to ``MASK_FN["forecast"]``. Defaults to 24,
            the previously hard-coded value.

    Returns:
        float32 mask on the notebook-level ``device``, with leading dimension
        ``batch.shape[0]``.
    """
    # permute(1, 0) assumes MASK_FN returns a 2-D mask for a single series.
    transposed_mask = MASK_FN["forecast"](batch[0], horizon)
    return (
        transposed_mask.permute(1, 0)
        .repeat(batch.shape[0], 1, 1)
        .to(device, dtype=torch.float32)
    )

In [None]:



def swish(x):
    """Swish activation, elementwise ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


class Conv(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            input_channels,
            output_channels,
            kernel_size,
            dilation=dilation,
            padding=self.padding,
        )
        self.conv = nn.utils.parametrizations.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


class ZeroConv1d(nn.Module):
    """1x1 ``nn.Conv1d`` whose weight and bias both start at exactly zero.

    At initialisation the layer therefore outputs all zeros.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        for param in (self.conv.weight, self.conv.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    """One residual block of the SSSD-S4 imputer.

    Forward pass:
      1. Add a per-channel bias projected from the diffusion-step embedding.
      2. Apply a kernel-3 ``Conv`` to 2 * residual_channels, then S4 layer ``S41``.
      3. Add the conditioning signal projected by a 1x1 ``Conv``.
      4. Apply S4 layer ``S42``, then a tanh/sigmoid gate over the two channel halves.
      5. Return ``((x + res_conv(out)) * sqrt(0.5), skip_conv(out))``.
    """

    def __init__(
        self,
        residual_channels,
        skip_channels,
        diffusion_step_embed_dim_output,
        input_channels,
        s4_max_sequence_length,
        s4_state_dim,
        s4_dropout,
        s4_bidirectional,
        s4_use_layer_norm,
    ):
        super().__init__()
        self.residual_channels = residual_channels

        # Projects the diffusion-step embedding to one value per residual channel.
        self.fc_t = nn.Linear(diffusion_step_embed_dim_output, self.residual_channels)

        self.S41 = S4Layer(
            features=2 * self.residual_channels,
            lmax=s4_max_sequence_length,
            N=s4_state_dim,
            dropout=s4_dropout,
            bidirectional=s4_bidirectional,
            layer_norm=s4_use_layer_norm,
        )

        self.conv_layer = Conv(
            self.residual_channels, 2 * self.residual_channels, kernel_size=3
        )

        self.S42 = S4Layer(
            features=2 * self.residual_channels,
            lmax=s4_max_sequence_length,
            N=s4_state_dim,
            dropout=s4_dropout,
            bidirectional=s4_bidirectional,
            layer_norm=s4_use_layer_norm,
        )

        # The conditioning input carries 2 * input_channels channels
        # (presumably signal + mask concatenated by the caller).
        self.cond_conv = Conv(
            2 * input_channels, 2 * self.residual_channels, kernel_size=1
        )

        # Initialise BEFORE applying weight norm. After the parametrization is registered,
        # `.weight` is recomputed from (g, v) on each access, so in-place init would be lost.
        res_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=1)
        nn.init.kaiming_normal_(res_conv.weight)
        self.res_conv = nn.utils.parametrizations.weight_norm(res_conv)

        skip_conv = nn.Conv1d(residual_channels, skip_channels, kernel_size=1)
        nn.init.kaiming_normal_(skip_conv.weight)
        self.skip_conv = nn.utils.parametrizations.weight_norm(skip_conv)

    def forward(self, input_data):
        x, cond, diffusion_step_embed = input_data
        h = x
        B, C, L = x.shape
        assert C == self.residual_channels

        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([B, self.residual_channels, 1])
        h = h + part_t

        h = self.conv_layer(h)
        # S4Layer is fed a (length, batch, features) layout; the result is permuted back.
        h = self.S41(h.permute(2, 0, 1)).permute(1, 2, 0)

        assert cond is not None
        cond = self.cond_conv(cond)
        # Out-of-place add: `h` is a permuted view of the S4 output.
        h = h + cond

        h = self.S42(h.permute(2, 0, 1)).permute(1, 2, 0)

        # Gated activation: tanh(first half) * sigmoid(second half) over channels.
        out = torch.tanh(h[:, : self.residual_channels, :]) * torch.sigmoid(
            h[:, self.residual_channels :, :]
        )

        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        return (x + res) * math.sqrt(0.5), skip  # normalize for training stability


class ResidualGroup(nn.Module):
    """A stack of ``residual_layers`` ResidualBlocks sharing one step embedding.

    The diffusion step is embedded with ``calc_diffusion_step_embedding`` and passed
    through two Linear + swish projections. The result is fed to every block. Skip
    outputs of all blocks are summed and scaled by ``sqrt(1 / residual_layers)``.
    """

    def __init__(
        self,
        residual_channels,
        skip_channels,
        residual_layers,
        diffusion_step_embed_dim_input,
        diffusion_step_embed_dim_hidden,
        diffusion_step_embed_dim_output,
        input_channels,
        s4_max_sequence_length,
        s4_state_dim,
        s4_dropout,
        s4_bidirectional,
        s4_use_layer_norm,
        device="cuda",
    ):
        super().__init__()
        self.residual_layers = residual_layers
        self.diffusion_step_embed_dim_input = diffusion_step_embed_dim_input

        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_input, diffusion_step_embed_dim_hidden)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_hidden, diffusion_step_embed_dim_output)

        block_kwargs = dict(
            diffusion_step_embed_dim_output=diffusion_step_embed_dim_output,
            input_channels=input_channels,
            s4_max_sequence_length=s4_max_sequence_length,
            s4_state_dim=s4_state_dim,
            s4_dropout=s4_dropout,
            s4_bidirectional=s4_bidirectional,
            s4_use_layer_norm=s4_use_layer_norm,
        )
        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(residual_channels, skip_channels, **block_kwargs)
                for _ in range(self.residual_layers)
            ]
        )

        self.device = device

    def forward(self, input_data):
        noise, conditional, diffusion_steps = input_data

        step_embedding = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_input, device=self.device
        )
        for projection in (self.fc_t1, self.fc_t2):
            step_embedding = swish(projection(step_embedding))

        hidden = noise
        skip_total = 0
        for block in self.residual_blocks:
            hidden, block_skip = block((hidden, conditional, step_embedding))
            skip_total = skip_total + block_skip

        return skip_total * math.sqrt(1.0 / self.residual_layers)


class SSSDS4Imputer(nn.Module):
    """SSSD-S4 noise-prediction network.

    Pipeline: 1x1 ``Conv`` + ReLU on the noisy input, then a ``ResidualGroup``
    conditioned on ``[conditional * mask, mask]``, then 1x1 ``Conv`` + ReLU and a
    zero-initialised 1x1 output projection.
    """

    def __init__(
        self,
        input_channels,
        residual_channels,
        skip_channels,
        output_channels,
        residual_layers,
        diffusion_step_embed_dim_input,
        diffusion_step_embed_dim_hidden,
        diffusion_step_embed_dim_output,
        s4_max_sequence_length,
        s4_state_dim,
        s4_dropout,
        s4_bidirectional,
        s4_use_layer_norm,
        device="cuda",
    ):
        super().__init__()

        self.init_conv = nn.Sequential(
            Conv(input_channels, residual_channels, kernel_size=1), nn.ReLU()
        )

        group_kwargs = dict(
            residual_channels=residual_channels,
            skip_channels=skip_channels,
            residual_layers=residual_layers,
            diffusion_step_embed_dim_input=diffusion_step_embed_dim_input,
            diffusion_step_embed_dim_hidden=diffusion_step_embed_dim_hidden,
            diffusion_step_embed_dim_output=diffusion_step_embed_dim_output,
            input_channels=input_channels,
            s4_max_sequence_length=s4_max_sequence_length,
            s4_state_dim=s4_state_dim,
            s4_dropout=s4_dropout,
            s4_bidirectional=s4_bidirectional,
            s4_use_layer_norm=s4_use_layer_norm,
            device=device,
        )
        self.residual_layer = ResidualGroup(**group_kwargs)

        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(),
            ZeroConv1d(skip_channels, output_channels),
        )

    def forward(self, input_data):
        noise, conditional, mask, diffusion_steps = input_data

        # Masked conditioning signal with the mask itself appended along the channel axis.
        cond_with_mask = torch.cat([conditional * mask, mask.float()], dim=1)

        hidden = self.init_conv(noise)
        skip_sum = self.residual_layer((hidden, cond_with_mask, diffusion_steps))
        return self.final_conv(skip_sum)


In [None]:
# model_config["residual_layers"] = 2
# model_config["residual_channels"] = 4

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the imputer from the "wavenet" section of model.yaml, move it to `device`,
# and wrap it in DataParallel.
net = nn.DataParallel(
    SSSDS4Imputer(**model_config.get("wavenet"), device=device).to(device)
)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

In [None]:
# Diffusion schedule from the "diffusion" section of model.yaml, materialised on `device`.
diffusion_hyperparams = calc_diffusion_hyperparams(device=device, **model_config["diffusion"])

In [None]:
from typing import Dict, Tuple

import torch

from sssd.utils.utils import std_normal


def training_loss(
    model: torch.nn.Module,
    training_data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    diffusion_parameters: Dict[str, torch.Tensor],
    generate_only_missing: int = 1,
    device: Union[torch.device, str] = "cpu",
) -> torch.Tensor:
    """
    Compute the MSE between the injected noise and the model's noise prediction.

    Args:
        model (torch.nn.Module): Noise-prediction network, called as
            ``model((x_t, condition, mask, diffusion_steps))``.
        training_data (tuple): ``(time_series, condition, mask, loss_mask)``.
        diffusion_parameters (dict): Output of ``calc_diffusion_hyperparams``. Must
            contain "T" and "Alpha_bar"; tensors must live on ``device``.
        generate_only_missing (int): If truthy, entries where ``mask`` is 1 take the
            series value instead of Gaussian noise, and the loss is restricted to
            ``loss_mask``. Otherwise pure noise is used and the loss covers every entry
            (default=1).
        device (torch.device | str): Device for the sampled steps and noise
            (default="cpu").

    Returns:
        torch.Tensor: Scalar MSE loss.
    """
    T, alpha_bar = diffusion_parameters["T"], diffusion_parameters["Alpha_bar"]
    time_series, condition, mask, loss_mask = training_data
    batch_size = time_series.shape[0]

    # One random diffusion step per batch element, shaped to broadcast over (C, L).
    diffusion_steps = torch.randint(T, size=(batch_size, 1, 1)).to(device)

    if generate_only_missing:
        noise = (
            time_series * mask.float()
            + std_normal(time_series.shape, device) * (1 - mask).float()
        )
    else:
        noise = std_normal(time_series.shape, device)

    # Forward process q(x_t | x_0): sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise.
    alpha_bar_t = alpha_bar[diffusion_steps]
    transformed_series = torch.sqrt(alpha_bar_t) * time_series + torch.sqrt(1 - alpha_bar_t) * noise

    epsilon_theta = model(
        (transformed_series, condition, mask, diffusion_steps.view(batch_size, 1))
    )

    if generate_only_missing:
        return nn.MSELoss()(epsilon_theta[loss_mask], noise[loss_mask])
    return nn.MSELoss()(epsilon_theta, noise)


In [None]:
losses = []
epochs = 100
# Make sure dropout etc. are active even if a previous cell left the model in eval mode.
net.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}") as pbar:
        for batch in pbar:
            batch = batch.to(device)
            mask = update_mask(batch)
            loss_mask = ~mask.bool()  # loss only on the positions the mask hides

            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()

            optimizer.zero_grad()
            loss = training_loss(
                model=net,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=diffusion_hyperparams,
                generate_only_missing=training_config.get("only_generate_missing"),
                device=device,
            )
            loss.backward()
            optimizer.step()

            # Running contribution to the epoch-mean loss.
            epoch_loss += loss.item() / len(train_loader)
            pbar.set_postfix_str(f"Loss: {epoch_loss:.4f}")  # Update progress bar with loss
    losses.append(epoch_loss)

print(f"Finished training for {len(losses)} epochs.")


In [None]:
epochs = range(1, len(losses) + 1)

# Per-epoch mean training loss.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, losses, marker='o', linestyle='-', color='b', label='Training Loss')
ax.set_title('Training Loss Over Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()

In [None]:

def std_normal(size: Tuple[int], device: Union[torch.device, str]) -> torch.Tensor:
    """
    Draw a tensor of i.i.d. N(0, 1) samples and place it on ``device``.

    NOTE: this redefinition shadows ``std_normal`` imported from ``sssd.utils.utils``.

    Args:
        size (Tuple[int]): Shape of the tensor to draw.
        device (Union[torch.device, str]): Target device, e.g. 'cpu' or 'cuda'.

    Returns:
        torch.Tensor: Standard-normal samples of shape ``size``.
    """
    samples = torch.normal(0, 1, size=size)
    return samples.to(device)


def sampling(
    net: torch.nn.Module,
    size: Tuple[int, int, int],
    diffusion_hyperparams: Dict[str, torch.Tensor],
    cond: torch.Tensor,
    mask: torch.Tensor,
    only_generate_missing: int = 0,
    device: Union[torch.device, str] = "cpu",
    loc: Union[float, torch.Tensor, None] = None,
    scale: Union[float, torch.Tensor, None] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Run the full reverse chain p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t).

    The network is put in eval mode while sampling (disabling dropout) and is
    restored to its previous train/eval mode afterwards.

    Args:
        net (torch.nn.Module): Noise-prediction network, called as
            ``net((x, cond, mask, diffusion_steps))``.
        size (Tuple[int, int, int]): Shape of the tensor to generate.
        diffusion_hyperparams (Dict[str, torch.Tensor]): Output of
            ``calc_diffusion_hyperparams``, with keys "T", "Alpha", "Alpha_bar", "Sigma".
            Tensors must live on ``device``.
        cond (torch.Tensor): Conditioning tensor.
        mask (torch.Tensor): Mask tensor. When ``only_generate_missing == 1``, entries
            where it is 1 are overwritten with ``cond`` before every step.
        only_generate_missing (int, optional): See ``mask`` (default 0).
        device (Union[torch.device, str], optional): Device for noise and step tensors.
        loc (optional): Offset added to the final sample. Defaults to the
            notebook-level global ``intercept``, as before.
        scale (optional): Factor applied to the final sample. Defaults to the
            notebook-level global ``std``, as before.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(loc + scale * x_0, epsilon_theta)``,
        where ``epsilon_theta`` is the prediction from the final (t = 0) step,
        or None if T == 0.
    """
    _dh = diffusion_hyperparams
    T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
    assert len(Alpha) == T
    assert len(Alpha_bar) == T
    assert len(Sigma) == T
    assert len(size) == 3

    # Fall back to the notebook globals so existing calls behave exactly as before.
    if loc is None:
        loc = intercept
    if scale is None:
        scale = std

    x = std_normal(size, device)
    epsilon_theta = None

    was_training = net.training
    net.eval()  # dropout must be off while predicting epsilon during sampling
    try:
        with torch.no_grad():
            for t in range(T - 1, -1, -1):
                if only_generate_missing == 1:
                    x = x * (1 - mask).float() + cond * mask.float()
                # Same reverse step index for every batch element.
                diffusion_steps = (t * torch.ones((size[0], 1))).to(device)
                epsilon_theta = net((x, cond, mask, diffusion_steps))
                # Posterior mean mu_theta(x_t, t).
                x = (
                    x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta
                ) / torch.sqrt(Alpha[t])
                if t > 0:
                    x = x + Sigma[t] * std_normal(size, device)  # variance term for x_{t-1}
    finally:
        net.train(was_training)

    return loc + scale * x, epsilon_theta


In [None]:
# Coefficient on epsilon_theta in the reverse update used by `sampling`:
# (1 - alpha_t) / (sqrt(1 - alpha_bar_t) * sqrt(alpha_t)).
# Unpack the schedule here; previously T/Alpha/Alpha_bar were only defined in a later cell.
_dh = diffusion_hyperparams
T, Alpha, Alpha_bar = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"]
for t in range(T - 1, -1, -1):
    print((1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) / torch.sqrt(Alpha[t]))

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
_dh = diffusion_hyperparams
T, Alpha, Alpha_bar, Sigma = (_dh[key] for key in ("T", "Alpha", "Alpha_bar", "Sigma"))

batch = next(iter(test_loader))
size = batch.shape
cond = batch.to(device)
# Random baseline with scale 5 and offset 100 (the first demo cell's std/intercept,
# not the loader's `std`), used only for the visual comparison below.
x = 100 + 5 * std_normal(size, device)
only_generate_missing = 1
mask = update_mask(batch).permute(0, 2, 1)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


# `batch` is a loader tensor and `x` the random baseline tensor; both are converted to numpy.
batch_mean = batch.numpy().mean(axis=0).squeeze()
generated_series_mean = x.cpu().numpy().mean(axis=0).squeeze()

time_index = np.arange(series_length)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(time_index, batch_mean, label='Batch Mean')
ax.plot(time_index, generated_series_mean, label='Generated Series Mean')
ax.legend()
plt.show()


In [None]:
# import torch
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# from tqdm import tqdm

# fig, ax = plt.subplots()
# im = ax.imshow(x.cpu().squeeze(), cmap='gray', vmin=-1, vmax=1)
# plt.close()  # Prevents the static plot from displaying

# def update_frame(t):
#     global x
#     with torch.no_grad():
#         if only_generate_missing == 1:
#             x = x * (1 - mask).float() + cond * mask.float()
#         diffusion_steps = (t * torch.ones((size[0], 1))).to(device)
#         epsilon_theta = net((x, cond, mask, diffusion_steps))
#         x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
#         if t > 0:
#             x = x + Sigma[t] * std_normal(size, device)
#         im.set_array(x.cpu().squeeze())
#         return [im]

# # Create the animation
# ani = animation.FuncAnimation(fig, update_frame, frames=tqdm(range(T - 1, -1, -1)), blit=True)

# # Display the animation in the notebook
# from IPython.display import HTML
# HTML(ani.to_jshtml())


In [None]:
from tqdm import tqdm

result = []
result2 = []

epochs = 1  # number of independent sampling passes over the whole test set
for epoch in range(epochs):
    # Collect every batch of this pass. Previously the append sat outside the batch
    # loop, so only the final batch's samples were kept.
    generated_batches = []
    epsilon_batches = []
    with tqdm(test_loader, desc=f"Epoch {epoch + 1}") as pbar:
        for batch in pbar:
            mask = update_mask(batch)
            batch = batch.permute(0, 2, 1)

            generated_series, generated_series2 = sampling(
                net=net,
                size=batch.shape,
                diffusion_hyperparams=diffusion_hyperparams,
                cond=batch.to(device),
                mask=mask,
                only_generate_missing=0,
                device=device,
            )
            generated_batches.append(generated_series.detach().cpu().numpy())
            epsilon_batches.append(generated_series2.detach().cpu().numpy())

    # Concatenate along the batch axis before squeezing, so a size-1 final batch
    # does not lose its batch dimension.
    result.append(np.concatenate(generated_batches, axis=0).squeeze())
    result2.append(np.concatenate(epsilon_batches, axis=0).squeeze())



In [None]:
# Stack the sampling passes on a new leading axis and average over it.
stack_result = np.stack(result, axis=0)
pred = stack_result.mean(axis=0)

In [None]:
# `batch` is the last test batch tensor (after permute(0, 2, 1)); `pred` is a numpy array.
batch_mean = batch.numpy().mean(axis=0).squeeze()
generated_series_mean = pred.squeeze().mean(axis=0).squeeze()

time_index = np.arange(series_length)
fig, ax = plt.subplots(figsize=(12, 6))
for values, name in ((batch_mean, 'Batch Mean'), (generated_series_mean, 'Generated Series Mean')):
    ax.plot(time_index, values, label=name)
ax.legend()
plt.show()

In [None]:
# Point forecasts across sampling passes: mean and median over the leading axis.
stack_result = np.stack(result, axis=0)
pred = stack_result.mean(axis=0)
pred_med = np.median(stack_result, axis=0)

In [None]:
# `test_data` was never defined in this notebook (leftover from an earlier version).
# Rebuild it from the test loader by concatenating raw batches along axis 0.
# NOTE(review): rows only line up with `pred` if `test_loader` iterates in a fixed
# order (no shuffling) -- TODO confirm in ArDataLoader.
test_data = np.concatenate([b.cpu().numpy() for b in test_loader], axis=0)
target = test_data[:, -24:, :].transpose(0, 2, 1).squeeze()

In [None]:
# Per-series statistics of the conditioning window (everything before the 24-step horizon).
# The old `:168` slice exceeded series_length and silently included the forecast targets.
context_window = test_data[:, :-24, :]
test_mean = np.mean(context_window, axis=1)
test_std = np.std(context_window, axis=1)

In [None]:
# Score only the forecast horizon: `pred` spans the whole series, `target` only the last steps.
horizon = target.shape[-1]
print(mean_squared_error(target, pred[..., -horizon:]))
print(mean_squared_error(target, pred_med[..., -horizon:]))

In [None]:
# Score only the forecast horizon: `pred` spans the whole series, `target` only the last steps.
horizon = target.shape[-1]
print(mean_absolute_percentage_error(target, pred[..., -horizon:]))
print(mean_absolute_percentage_error(target, pred_med[..., -horizon:]))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Mean +/- std across series for each step of the forecast horizon.
# `pred` spans the full series, so slice it to the horizon covered by `target`.
horizon = target.shape[-1]
pred_horizon = pred[..., -horizon:]

mean_target = np.mean(target, axis=0)
std_target = np.std(target, axis=0)
mean_pred = np.mean(pred_horizon, axis=0)
std_pred = np.std(pred_horizon, axis=0)

steps = np.arange(horizon)

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(steps, mean_target, label='Target', marker='o')
ax.fill_between(steps, mean_target - std_target, mean_target + std_target, alpha=0.2)

ax.plot(steps, mean_pred, label='Prediction', marker='x')
ax.fill_between(steps, mean_pred - std_pred, mean_pred + std_pred, alpha=0.2)

ax.set_xlabel('Forecast step')
ax.set_ylabel('Value')
ax.set_title('Time Series Comparison: Target vs Prediction')
ax.legend()
ax.grid(True)
ax.set_xticks(steps)

plt.show()

