In [None]:
import torch
import torch.nn as nn
import yaml
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
from typing import Dict, Tuple, Union

from sssd.core.model_specs import setup_model, MASK_FN
from sssd.utils.utils import (
    calc_diffusion_hyperparams,
    sampling,
    std_normal
)
from sssd.core.imputers.SSSDS4Imputer import SSSDS4Imputer
from sssd.core.layers.s4.s4_layer import S4Layer
from sssd.core.utils import calc_diffusion_step_embedding
from sssd.data.generator import ArDataGenerator
from sssd.data.dataloader import ArDataLoader

## Data Visualization

In [None]:
# Parameters of the illustrative AR process plotted below.
ar_coefs = [0.8]
series_length = 128
season_period = 12
demo_std = 5
demo_intercept = 100

# Generate one series with an explicit intercept ...
data_with_intercept = ArDataGenerator(
    ar_coefs,
    series_length,
    std=demo_std,
    intercept=demo_intercept,
    season_period=season_period,
).generate()

# ... and one series that relies on the generator's default intercept.
data_without_intercept = ArDataGenerator(
    ar_coefs, series_length, std=demo_std, season_period=season_period
).generate()

# Plot the results. The legend text is built from the values actually passed
# to the generator so it cannot drift out of sync with the code (the previous
# hard-coded "Mean = 3" did not match intercept=100).
plt.plot(
    data_with_intercept,
    label=f"With Intercept (Intercept = {demo_intercept}, Std = {demo_std})",
)
plt.plot(
    data_without_intercept,
    label=f"Without Intercept (Std = {demo_std})",
)
plt.xlabel("Time Step")
plt.ylabel("Value")
plt.title("AR Process with and Without Intercept")
plt.legend()
plt.show()

## Simulation Setup

In [None]:
# --- Synthetic AR dataset settings ---
num_series = 1024
series_length = 128
coefficients = [0.8]
intercept = 100
std = 1
season = 12

# --- Loader / runtime settings ---
batch_size = 128
num_workers = 4
training_rate = 0.8
seeds = [*range(num_series)]  # the integers 0 .. num_series - 1
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)

# ArDataLoader is called positionally, so the argument order below matters.
data_loader = ArDataLoader(
    coefficients, num_series, series_length,
    std, intercept, season,
    batch_size, device, num_workers,
    training_rate, seeds,
)

# Expose the two loaders under short names used by the rest of the notebook.
train_loader, test_loader = (
    data_loader.train_dataloader,
    data_loader.test_dataloader,
)

In [None]:
def _read_yaml(path):
    """Read the text file at ``path`` and parse it with ``yaml.safe_load``."""
    with open(path, "rt") as stream:
        raw_text = stream.read()
    return yaml.safe_load(raw_text)


# Model, training and inference settings live in three separate YAML files.
model_config = _read_yaml("../configs/model.yaml")
training_config = _read_yaml("../configs/training.yaml")
inference_config = _read_yaml("../configs/inference.yaml")

In [None]:
def update_mask(batch: torch.Tensor, forecast_horizon: int = 24) -> torch.Tensor:
    """Build a forecast mask replicated for every element of ``batch``.

    The mask is computed once from the first batch element via
    ``MASK_FN["forecast"]``, its two axes are swapped, and the result is
    repeated along a new leading axis so there is one copy per batch element.

    Args:
        batch: Batch tensor. Only ``batch[0]`` and the batch size are read.
        forecast_horizon: Second argument handed to the "forecast" mask
            function (presumably the number of trailing steps to forecast --
            confirm against MASK_FN). Defaults to 24, the value that used to
            be hard-coded here.

    Returns:
        A float32 tensor on the notebook-level ``device`` whose leading
        dimension equals the batch size.
    """
    single_mask = MASK_FN["forecast"](batch[0], forecast_horizon)
    return (
        single_mask.permute(1, 0)
        .repeat(batch.size(0), 1, 1)
        .to(device, dtype=torch.float32)
    )

In [None]:
# Build the imputer from the "wavenet" section of the model config.
# Index with [] (as is done for "diffusion" below) so a missing section raises
# a KeyError naming the section instead of an opaque TypeError from `**None`.
net = SSSDS4Imputer(**model_config["wavenet"], device=device)
net = net.to(device)
net = nn.DataParallel(net)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Diffusion hyperparameters built from the "diffusion" config section;
# they are passed to training_loss and sampling later in the notebook.
diffusion_hyperparams = calc_diffusion_hyperparams(
    **model_config["diffusion"], device=device
)

In [None]:
def training_loss(
    model: torch.nn.Module,
    training_data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    diffusion_parameters: Dict[str, torch.Tensor],
    generate_only_missing: int = 1,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """
    Compute the MSE between the target noise and the model's prediction.

    Args:
        model (torch.nn.Module): The neural network model. It is called with a
            single tuple ``(x_t, condition, mask, diffusion_steps)``.
        training_data (tuple): ``(time_series, condition, mask, loss_mask)``.
            ``loss_mask`` is used to index the prediction and the noise, so it
            is expected to be a boolean tensor of the same shape.
        diffusion_parameters (dict): Diffusion hyperparameters; the keys
            ``"T"`` and ``"Alpha_bar"`` are read. ``Alpha_bar`` is indexed with
            step indices placed on ``device``, so it should live there too.
        generate_only_missing (int): Truthy to keep the observed values
            (``mask == 1``) in the noise tensor and to restrict the loss to
            ``loss_mask`` (default=1). Falsy values, including ``None``, use
            pure noise and the loss over all entries.
        device (str | torch.device): Device for the sampled steps and noise
            (default="cpu").

    Returns:
        torch.Tensor: Scalar training loss.
    """
    # Unpack diffusion hyperparameters
    T, alpha_bar = diffusion_parameters["T"], diffusion_parameters["Alpha_bar"]

    # Unpack training data
    time_series, condition, mask, loss_mask = training_data
    batch_size = time_series.shape[0]

    # Sample one random diffusion step per batch element
    diffusion_steps = torch.randint(T, size=(batch_size, 1, 1)).to(device)

    # Draw the noise; optionally overwrite observed positions with the data.
    noise = std_normal(time_series.shape, device)
    if generate_only_missing:
        # Cast before subtracting so boolean masks are accepted as well
        # (`1 - bool_tensor` is not supported by torch).
        observed = mask.float()
        noise = time_series * observed + noise * (1 - observed)

    # Compute x_t from q(x_t | x_0)
    step_alpha_bar = alpha_bar[diffusion_steps]
    transformed_series = (
        torch.sqrt(step_alpha_bar) * time_series
        + torch.sqrt(1 - step_alpha_bar) * noise
    )

    # Predict epsilon according to epsilon_theta
    epsilon_theta = model(
        (transformed_series, condition, mask, diffusion_steps.view(batch_size, 1))
    )

    # Compute loss (restricted to loss_mask when only missing values are generated)
    criterion = nn.MSELoss()
    if generate_only_missing:
        return criterion(epsilon_theta[loss_mask], noise[loss_mask])
    return criterion(epsilon_theta, noise)


## Training

In [None]:
losses = []  # mean training loss of each epoch, as plain Python floats
epochs = 100
for epoch in range(epochs):
    running_loss = 0.0
    batches_seen = 0
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}") as pbar:
        for batch in pbar:
            batch = batch.to(device)
            mask = update_mask(batch)
            loss_mask = ~mask.bool()  # loss is evaluated where the mask is 0

            # Bring the batch into the same layout as the mask.
            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()

            optimizer.zero_grad()
            loss = training_loss(
                model=net,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=diffusion_hyperparams,
                generate_only_missing=training_config.get("only_generate_missing"),
                device=device,
            )
            loss.backward()
            optimizer.step()

            # .item() yields a Python float, so `losses` holds floats rather
            # than NumPy scalars.
            running_loss += loss.item()
            batches_seen += 1
            # Show the mean over the batches processed so far; dividing the
            # partial sum by the total batch count under-reports mid-epoch.
            pbar.set_postfix_str(f"Loss: {running_loss / batches_seen:.4f}")

    epoch_loss = running_loss / batches_seen if batches_seen else 0.0
    losses.append(epoch_loss)  # Append epoch loss to main list

print(f"Finished training for {len(losses)} epochs.")


In [None]:
# 1-based epoch numbers for the x-axis. A dedicated name is used so the integer
# `epochs` set by the training cell is not rebound to a range object.
epoch_numbers = range(1, len(losses) + 1)

# Plotting the losses
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_numbers, losses, marker='o', linestyle='-', color='b', label='Training Loss')

# Adding title and labels
ax.set_title('Training Loss Over Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()

# Show grid
ax.grid(True)

# Display the plot
plt.show()

In [None]:
# Short alias plus the individual schedule entries of the diffusion hyperparameters.
_dh = diffusion_hyperparams
T, Alpha, Alpha_bar, Sigma = (
    _dh[key] for key in ("T", "Alpha", "Alpha_bar", "Sigma")
)

# First batch of the test loader, its shape, and a copy on the compute device.
batch = next(iter(test_loader))
size = batch.shape
cond = batch.to(device)

# Draw from std_normal with the batch's shape, then scale by 5 and shift by 100.
x = 100 + 5 * std_normal(size, device)

only_generate_missing = 1

# Forecast mask for this batch with its last two axes swapped.
mask = update_mask(batch).permute(0, 2, 1)

In [None]:
# `batch` and `x` are torch tensors: detach them and move them to host memory
# before converting, since Tensor.numpy() raises for CUDA tensors and for
# tensors that require grad.
batch_mean = np.mean(batch.detach().cpu().numpy(), axis=0).squeeze()
generated_series_mean = np.mean(x.detach().cpu().numpy(), axis=0).squeeze()

time_steps = np.arange(series_length)

plt.figure(figsize=(12, 6))
plt.plot(time_steps, batch_mean, label='Batch Mean')
plt.plot(time_steps, generated_series_mean, label='Generated Series Mean')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('Batch Mean vs Generated Series Mean')
plt.legend()
plt.show()


## Testing

In [None]:
result = []   # one array per sampling pass, covering the whole test set
result2 = []  # second output of `sampling`, collected the same way

epochs = 10  # number of independent sampling passes over the test loader
for epoch in range(epochs):
    pass_outputs = []
    pass_outputs2 = []
    with tqdm(test_loader, desc=f"Epoch {epoch + 1}") as pbar:
        for batch in pbar:
            mask = update_mask(batch)
            batch = batch.permute(0, 2, 1)

            generated_series, generated_series2 = sampling(
                net=net,
                size=batch.shape,
                diffusion_hyperparams=diffusion_hyperparams,
                cond=batch.to(device),
                mask=mask,
                only_generate_missing=0,
                device=device,
            )

            # Collect every batch. Previously the append sat outside this loop,
            # so only the last batch of each pass was kept.
            pass_outputs.append(generated_series.detach().cpu().numpy())
            pass_outputs2.append(generated_series2.detach().cpu().numpy())

    # Join the batches of this pass along the batch axis, then drop singleton
    # axes as before.
    # NOTE(review): assumes both outputs of `sampling` are batch-first, and that
    # the test loader yields series in the same order on every pass so passes
    # can be averaged element-wise -- confirm.
    result.append(np.concatenate(pass_outputs, axis=0).squeeze())
    result2.append(np.concatenate(pass_outputs2, axis=0).squeeze())



In [None]:
# Put the per-pass sample arrays on a new leading axis, then average over it.
stack_result = np.stack(result, axis=0)
pred = stack_result.mean(axis=0)

In [None]:
# `batch` is a torch tensor (detach and move it to host memory before
# converting; Tensor.numpy() raises for CUDA tensors and tensors that require
# grad), while `pred` is already a NumPy array.
batch_mean = np.mean(batch.detach().cpu().numpy(), axis=0).squeeze()
generated_series_mean = np.mean(pred.squeeze(), axis=0).squeeze()

time_steps = np.arange(series_length)

plt.figure(figsize=(12, 6))
plt.plot(time_steps, batch_mean, label='Batch Mean')
plt.plot(time_steps, generated_series_mean, label='Generated Series Mean')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('Batch Mean vs Generated Series Mean')
plt.legend()
plt.show()

In [None]:
# Stack the sampling passes on axis 0 and reduce over them two ways:
# element-wise mean and element-wise median.
stack_result = np.stack(result, axis=0)
pred, pred_med = stack_result.mean(axis=0), np.median(stack_result, axis=0)

In [None]:
# Stack every item of the test dataset into one tensor (items on axis 0).
test_data = torch.stack(list(test_loader.dataset))

# Work on a NumPy copy from here on: `Tensor.transpose` only accepts two
# dimensions (the three-axis form below is the NumPy API), and np.mean / np.std
# are not reliable when applied to torch tensors directly.
test_array = test_data.detach().cpu().numpy()

# Last 24 steps of each series, with the last two axes swapped and singleton
# axes removed.
target = test_array[:, -24:, :].transpose(0, 2, 1).squeeze()

# Per-series statistics over axis 1.
# NOTE(review): the `:168` bound exceeds series_length (128 in the setup cell),
# so it presumably selects the whole series -- confirm the intended window.
test_mean = np.mean(test_array[:, :168, :], axis=1)
test_std = np.std(test_array[:, :168, :], axis=1)

In [None]:
# Mean squared error against the target: first for the mean forecast,
# then for the median forecast.
for forecast in (pred, pred_med):
    print(mean_squared_error(target, forecast))

In [None]:
# Mean absolute percentage error against the target: first for the mean
# forecast, then for the median forecast.
for forecast in (pred, pred_med):
    print(mean_absolute_percentage_error(target, forecast))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Coerce both inputs to NumPy arrays so the reductions below behave the same
# whether an earlier cell left them as arrays or as CPU tensors.
target_values = np.asarray(target)
pred_values = np.asarray(pred)

# Calculate mean and standard deviation across axis 0 for each step
mean_target = np.mean(target_values, axis=0)
std_target = np.std(target_values, axis=0)
mean_pred = np.mean(pred_values, axis=0)
std_pred = np.std(pred_values, axis=0)

# x-axis positions, derived from the data instead of a hard-coded 24
hours = np.arange(len(mean_target))

# Plotting
plt.figure(figsize=(12, 6))

# Plot target mean with standard deviation band
plt.plot(hours, mean_target, label='Target', marker='o')
plt.fill_between(hours, mean_target - std_target, mean_target + std_target, alpha=0.2)

# Plot prediction mean with standard deviation band
plt.plot(hours, mean_pred, label='Prediction', marker='x')
plt.fill_between(hours, mean_pred - std_pred, mean_pred + std_pred, alpha=0.2)

plt.xlabel('Hour of the Day')
plt.ylabel('Value')
plt.title('Time Series Comparison: Target vs Prediction')
plt.legend()
plt.grid(True)
plt.xticks(hours)

plt.show()

