In [1]:
from tkgngc.embeddings import PretrainedTKGEmbeddingWithTimestamps
from tkgngc.model import NGCWithPretrainedTKGAndTimestamps

In [2]:
import os
import torch
import polars as pl
import pandas as pd
import numpy as np
import seaborn as sns
import warnings
from feature.scalers import ranged_scaler
from datetime import datetime, timedelta
#from mpge.rca import mpge_root_cause_diagnosis
warnings.filterwarnings("ignore", category=UserWarning) 

In [3]:
#os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
#os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

In [4]:
cats_df = pl.read_csv("data/data.csv", separator=",")  
metadata = pl.read_csv('data/metadata.csv',separator=',')
potential_causes = metadata['root_cause'].unique().to_list()

In [5]:
for col in cats_df.columns:
    unique_vals = cats_df[col].n_unique()
    data_type = cats_df[col].dtype
    bad_dtypes = [pl.Date,pl.Datetime,pl.Utf8]
    if ((unique_vals >= 50) & (data_type not in bad_dtypes) ):
        cats_df = cats_df.with_columns(ranged_scaler(cats_df[col]))
    else:
        continue

In [6]:
cats_df = cats_df.with_columns(
    pl.col('timestamp').str.to_datetime("%Y-%m-%d %H:%M:%S"),
    pl.Series("entity_id",range(len(cats_df)))
)
cats_rows_list = metadata.rows(named=True)

In [7]:
cats_rows_list = metadata.rows(named=True)
cats_df.head()


timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category,entity_id
datetime[μs],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64
2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.1078e-14,2.0428e-14,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0,0
2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0,1
2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0,2
2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0,3
2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0008,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0,4


In [8]:

cats_df = cats_df.to_pandas()

In [9]:
cats_df.head()

Unnamed: 0,timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,...,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category,entity_id
0,2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.107825e-14,2.04281e-14,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0,0
1,2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0,1
2,2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0,2
3,2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0,3
4,2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0007999999,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0,4


In [10]:
device = torch.device('cpu')

In [11]:
train_df = cats_df[0:1000000]
train = train_df[['aimp', 'amud', 'arnd', 'asin1', 'asin2', 'adbr', 'adfl',
       'bed1', 'bed2', 'bfo1', 'bfo2', 'bso1', 'bso2', 'bso3', 'ced1', 'cfo1',
       'cso1']]
test_df = cats_df[1000000:]


In [12]:
class tkgngc_data_processing:
    def __init__(self, data, device, num_timestamps=20):
        self.data = data
        self.ordered_column_names = self.data
        self.train_list = self.data.values.tolist()
        self.time_series_tensor = torch.tensor(self.train_list,dtype=torch.float32)
        # Entity and Relation indices
        self.entity_indices = torch.arange(len(self.train_list), dtype=torch.long)
        self.relation_indices = torch.tensor(
            [0 if i % 2 == 0 else 1 for i in range(len(self.entity_indices))],dtype=torch.long
        )

        self.timestamps = entity_indices
        # Timestamp binning
        self.num_timestamps = num_timestamps
        min_time, max_time = min(self.entity_indices),max(self.entity_indices)
        bins = torch.linspace(min_time,max_time,self.num_timestamps+1)
        self.timestamp_indices = torch.tensor(torch.bucketize(self.entity_indices,bins),dtype=torch.long) - 1
        self.timestamp_indices = torch.clamp(self.timestamp_indices, min = 0, max= num_timestamps -1)

        # Edge index

        self.edge_index = torch.tensor(
            [[i,i+1] for i in range(len(self.entity_indices) - 1)],dtype=torch.long).t()



    
        

In [13]:
ordered_column_names = train.columns
train_list = train.values.tolist()
time_series_tensor = torch.tensor(train_list,dtype=torch.float32)

In [14]:
entity_indices = torch.arange(len(train_list), dtype=torch.long)
relation_indices = torch.tensor(
    [0 if i % 2 == 0 else 1 for i in range(len(entity_indices))],dtype=torch.long
)

In [15]:
timestamps = entity_indices

In [16]:
num_timestamps = 20
min_time, max_time = min(entity_indices),max(entity_indices)
bins = torch.linspace(min_time,max_time,num_timestamps+1)
timestamp_indices = torch.tensor(torch.bucketize(entity_indices,bins),dtype=torch.long) - 1
timestamp_indices = torch.clamp(timestamp_indices, min = 0, max= num_timestamps -1)

In [17]:
edge_index = torch.tensor(
    [[i,i+1] for i in range(len(entity_indices) - 1)],dtype=torch.long).t()

In [18]:
time_series_tensor = time_series_tensor.to(device)
entity_indices=entity_indices.to(device)
relation_indices=relation_indices.to(device)
timestamp_indices=timestamp_indices.to(device)
edge_index=edge_index.to(device)

In [19]:
pretrained_tkg = PretrainedTKGEmbeddingWithTimestamps(
    num_entities=int(entity_indices.max().item()+1),
    num_relations=int(relation_indices.max().item()+1),
    embedding_dim=16,
    num_timestamps=num_timestamps,
).to(device)

In [20]:
quads = (
    entity_indices[:-1],  # Head entities
    relation_indices[:-1],  # Relations
    entity_indices[1:],  # Tail entities (shifted example)
    timestamp_indices[:-1],  # Timestamps
)

In [21]:
pretrained_tkg.pretrain(quads, learning_rate=0.01, epochs=500)


Epoch 0, Loss: 3.7462270259857178
Epoch 10, Loss: 3.0316848754882812
Epoch 20, Loss: 2.4488525390625
Epoch 30, Loss: 1.983251690864563
Epoch 40, Loss: 1.6134852170944214
Epoch 50, Loss: 1.3197829723358154
Epoch 60, Loss: 1.0855704545974731
Epoch 70, Loss: 0.8976711630821228
Epoch 80, Loss: 0.7459503412246704
Epoch 90, Loss: 0.6226732134819031
Epoch 100, Loss: 0.521932065486908
Epoch 110, Loss: 0.43918588757514954
Epoch 120, Loss: 0.3709111213684082
Epoch 130, Loss: 0.3143463730812073
Epoch 140, Loss: 0.26730769872665405
Epoch 150, Loss: 0.22805573046207428
Epoch 160, Loss: 0.19519668817520142
Epoch 170, Loss: 0.16760790348052979
Epoch 180, Loss: 0.14438001811504364
Epoch 190, Loss: 0.12477302551269531
Epoch 200, Loss: 0.10818178951740265
Epoch 210, Loss: 0.09410931915044785
Epoch 220, Loss: 0.08214595913887024
Epoch 230, Loss: 0.0719527080655098
Epoch 240, Loss: 0.06324820965528488
Epoch 250, Loss: 0.05579833686351776
Epoch 260, Loss: 0.049407705664634705
Epoch 270, Loss: 0.04391295090

In [36]:
new_metadata = []

for i, row in enumerate(cats_rows_list):
    potential_causes = metadata['root_cause'].unique().to_list()

    start_time = datetime.strptime(row['start_time'],"%Y-%m-%d %H:%M:%S")
    end_time = datetime.strptime(row['end_time'],"%Y-%m-%d %H:%M:%S")
    anomaly = eval(row['affected'])[0]
    root_cause = row['root_cause']
    potential_causes.append(anomaly)
    mod_df = test_df[(test_df['timestamp']>= start_time) & (test_df['timestamp']<= end_time)]
    test = mod_df[['aimp', 'amud', 'arnd', 'asin1', 'asin2', 'adbr', 'adfl',
       'bed1', 'bed2', 'bfo1', 'bfo2', 'bso1', 'bso2', 'bso3', 'ced1', 'cfo1',
       'cso1']]
    test_data = tkgngc_data_processing(data=test, device=device, num_timestamps=20)
    # Instantiate the full model

    entity_emb, relation_emb, _, timestamp_emb = pretrained_tkg(
    test_data.entity_indices, test_data.relation_indices, test_data.entity_indices, test_data.timestamp_indices
)

    model = NGCWithPretrainedTKGAndTimestamps(
        pretrained_tkg=pretrained_tkg,
        input_dim=test_data.time_series_tensor.shape[1],
        hidden_dim=64,
        output_dim=test_data.time_series_tensor.shape[1],
        confounder_latent_dim=17,
        entity_indices=test_data.entity_indices,
        relation_indices=test_data.relation_indices,
        time_series_data=test_data.time_series_tensor,
        timestamp_indices=test_data.timestamp_indices,
        edge_index=test_data.edge_index,
        use_sliding_window=False,
        window_size=10,
        step_size=2,
        regularization_type="l1",
        regularization_strength=0.01,
    )
    """for j in range(50):
        # Forward pass with the processed data
         z, mean, log_var, x_reconstructed = model(
            entity_indices=test_data.entity_indices,
            relation_indices=test_data.relation_indices,
            time_series_data=test_data.time_series_tensor,
            edge_index=test_data.edge_index,
            timestamp_indices=test_data.timestamp_indices,
        )"""
    model.train()
    score_df = pd.DataFrame(np.mean(np.abs(model.z.detach().numpy()),axis=0),
             index=ordered_column_names,columns=['scores']).sort_values(by=['scores'], ascending=False)
    remvove_list = list(set(score_df.index).difference(set(potential_causes)))
    score_df = score_df.drop(remvove_list)

    potential_cause1 = score_df['scores'].index[0]
    potential_cause2 = score_df['scores'].index[1]
    potential_cause3 = score_df['scores'].index[2]
    if root_cause == potential_cause1:
        row['cause_1'] = 1
    if root_cause == potential_cause2:
        row['cause_2'] = 1
    if root_cause == potential_cause3:
        row['cause_3'] = 1
    new_metadata.append(row)

    if i%5 == 0:
        print('Iteration #: ' + str(i))






Iteration #: 0
Iteration #: 5
Iteration #: 10
Iteration #: 15
Iteration #: 20
Iteration #: 25
Iteration #: 30
Iteration #: 35
Iteration #: 40
Iteration #: 45
Iteration #: 50
Iteration #: 55
Iteration #: 60
Iteration #: 65
Iteration #: 70
Iteration #: 75
Iteration #: 80
Iteration #: 85
Iteration #: 90
Iteration #: 95
Iteration #: 100
Iteration #: 105
Iteration #: 110
Iteration #: 115
Iteration #: 120
Iteration #: 125
Iteration #: 130
Iteration #: 135
Iteration #: 140
Iteration #: 145
Iteration #: 150
Iteration #: 155
Iteration #: 160
Iteration #: 165
Iteration #: 170
Iteration #: 175
Iteration #: 180
Iteration #: 185
Iteration #: 190
Iteration #: 195


In [37]:
stats = pl.DataFrame(new_metadata)
agg_stats = stats.select(pl.sum("cause_1", "cause_2",'cause_3'))
agg_stats.select(pl.sum_horizontal(pl.all())).item()/stats.shape[0]

0.38

In [None]:
model.z.shape

In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch.distributions.normal import Normal
from tkgngc.utils import vae_loss

def create_lagged_data(time_series_data, num_lags):
    """
    Create lagged versions of time-series data.
    
    Args:
        time_series_data (torch.Tensor): Time-series data of shape [batch_size, num_nodes, input_dim, time_steps].
        num_lags (int): Number of lags to include.
        
    Returns:
        lagged_data (torch.Tensor): Lagged data of shape [batch_size, num_nodes, input_dim * num_lags, time_steps - num_lags].
        current_data (torch.Tensor): Current time steps of shape [batch_size, num_nodes, input_dim, time_steps - num_lags].
    """
    batch_size, num_nodes, input_dim, time_steps = time_series_data.shape

    # Prepare lagged data
    lagged_data = []
    for lag in range(1, num_lags + 1):
        lagged_data.append(time_series_data[..., :-lag])  # Remove the last `lag` steps

    # Stack lagged data along the feature dimension
    lagged_data = torch.cat(lagged_data, dim=2)  # Shape: [batch_size, num_nodes, input_dim * num_lags, time_steps - num_lags]

    # Current data excludes the first `num_lags` time steps
    current_data = time_series_data[..., num_lags:]  # Shape: [batch_size, num_nodes, input_dim, time_steps - num_lags]

    return lagged_data, current_data

class Sampling(nn.Module):
    def forward(self, z_mean, z_log_var):
        # get the shape of the tensor for the mean and log variance
        batch, dim = z_mean.shape
        # generate a normal random tensor (epsilon) with the same shape as z_mean
        # this tensor will be used for reparameterization trick
        epsilon = Normal(0, 1).sample((batch, dim)).to(z_mean.device)
        # apply the reparameterization trick to generate the samples in the
        # latent space
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

class CrossAttention(nn.Module):
    """Cross-Attention Module."""
    def __init__(self, dim_query, dim_key_value, dim_out, num_heads):
        super(CrossAttention, self).__init__()
        self.query_proj = nn.Linear(dim_query, dim_out)
        self.key_proj = nn.Linear(dim_key_value, dim_out)
        self.value_proj = nn.Linear(dim_key_value, dim_out)
        self.num_heads = num_heads
        self.dim_head = dim_out // num_heads

        assert dim_out % num_heads == 0, "Output dimension must be divisible by the number of heads."

    def forward(self, query, key, value):
        # Linear projections
        Q = self.query_proj(query).view(query.size(0), -1, self.num_heads, self.dim_head).transpose(1, 2)
        K = self.key_proj(key).view(key.size(0), -1, self.num_heads, self.dim_head).transpose(1, 2)
        V = self.value_proj(value).view(value.size(0), -1, self.num_heads, self.dim_head).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(Q, K.transpose(-1, -2)) / (self.dim_head ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Attention output
        attn_output = torch.matmul(attn_weights, V).transpose(1, 2).reshape(query.size(0), -1, self.num_heads * self.dim_head)
        return attn_output

class GraphAttentionGC(nn.Module):
    """Graph Attention Network for Granger Causality."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=4):
        super(GraphAttentionGC, self).__init__()
        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=num_heads)
        self.gat2 = GATv2Conv(hidden_dim * num_heads, output_dim, heads=1, concat=False)

    def forward(self, x, edge_index):
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.gat2(x, edge_index)
        return x

class GrangerCausality(nn.Module):
    """Granger Causality Model with GAT and Cross-Attention in the Decoder."""
    def __init__(self, pretrained_tkg, input_dim, hidden_dim, latent_dim, num_heads, num_nodes, num_lags):
        super(GrangerCausality, self).__init__()

        self.pretrained_tkg = pretrained_tkg
        self.num_nodes = num_nodes
        self.num_lags = num_lags

        # Graph Attention Network in Encoder
        self.gat_encoder = GATv2Conv(input_dim + pretrained_tkg.entity_embedding.embedding_dim, hidden_dim, heads=num_heads)

        # Cross-Attention in Encoder
        self.cross_attention_encoder = CrossAttention(
            dim_query=hidden_dim,
            dim_key_value=(input_dim + pretrained_tkg.entity_embedding.embedding_dim) * num_lags,
            dim_out=hidden_dim,
            num_heads=num_heads,
        )

        # Variational Components
        self.encoder_fc = nn.Linear(hidden_dim, latent_dim * 2)  # Mean and log variance

        # Cross-Attention in Decoder
        self.cross_attention_decoder = CrossAttention(
            dim_query=latent_dim,
            dim_key_value=(input_dim + pretrained_tkg.entity_embedding.embedding_dim) * num_lags,
            dim_out=hidden_dim,
            num_heads=num_heads,
        )

        # Graph Attention in Decoder
        self.gat_decoder = GATv2Conv(hidden_dim, hidden_dim, heads=num_heads)

        # Fully Connected Layer to Reconstruct Input
        self.decoder_fc = nn.Linear(hidden_dim, input_dim)

        # Adjacency Matrix for Causal Graph
        self.adjacency_matrix = nn.Parameter(torch.randn(num_nodes, num_nodes, num_lags))

    def forward(self, entity_indices, relation_indices, timestamp_indices, time_series_data, edge_index, lagged_data):
        """
        Forward pass for Granger causality detection with GAT and cross-attention in the decoder.
        Args:
            entity_indices (torch.Tensor): Indices of entities.
            relation_indices (torch.Tensor): Indices of relations.
            timestamp_indices (torch.Tensor): Indices of timestamps.
            time_series_data (torch.Tensor): Original time-series data.
            edge_index (torch.Tensor): Graph edges.
            lagged_data (torch.Tensor): Lagged time-series data of shape [batch, num_nodes, lagged_features].
        """
        # Pretrained TKG embeddings
        entity_emb, relation_emb, _, timestamp_emb = self.pretrained_tkg(
            entity_indices, relation_indices, entity_indices, timestamp_indices
        )
        entity_emb = entity_emb.unsqueeze(-2).expand(-1, -1, time_series_data.size(-1))  # Match time steps

        # Concatenate TKG embeddings with time-series data
        enriched_features = torch.cat([time_series_data, entity_emb], dim=-1)  # [batch, num_nodes, input_dim + embedding_dim]

        # Encoder: Graph Attention
        x = self.gat_encoder(enriched_features, edge_index)
        x = F.relu(x)

        # Encoder: Cross-Attention
        x, _ = self.cross_attention_encoder(x, lagged_data, lagged_data)
        x = F.relu(x)

        # Latent Space
        q_params = self.encoder_fc(x)
        mean, log_var = torch.chunk(q_params, 2, dim=-1)

        # Reparameterization Trick
        std = torch.exp(0.5 * log_var)
        z = mean + std * torch.randn_like(std)

        # Decoder: Cross-Attention
        x, _ = self.cross_attention_decoder(z, lagged_data, lagged_data)
        x = F.relu(x)

        # Decoder: Graph Attention
        x = self.gat_decoder(x, edge_index)
        x = F.relu(x)

        # Decoder: Fully Connected Reconstruction
        x_reconstructed = self.decoder_fc(x)

        # Learned Adjacency Matrix
        adj = torch.sigmoid(self.adjacency_matrix)  # Values in [0, 1]

        return z, mean, log_var, x_reconstructed, adj

    def loss_function(self, x_reconstructed, time_series_data, mean, log_var, adj, sparsity_weight=0.01, beta=1.0):
        """
        Loss function for Granger causality with GAT and cross-attention in the decoder.
        Args:
            x_reconstructed (torch.Tensor): Reconstructed features.
            time_series_data (torch.Tensor): Original input features.
            mean (torch.Tensor): Mean of latent distribution.
            log_var (torch.Tensor): Log variance of latent distribution.
            adj (torch.Tensor): Learned adjacency matrix.
            sparsity_weight (float): Weight for sparsity regularization.
            beta (float): Weight for KL divergence.
        """
        # Reconstruction Loss
        recon_loss = F.mse_loss(x_reconstructed, time_series_data, reduction='sum')

        # KL Divergence
        kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

        # Sparsity Regularization
        sparsity_loss = sparsity_weight * torch.sum(torch.abs(adj))

        # Total Loss
        return recon_loss + beta * kl_divergence + sparsity_loss

In [43]:
model = GrangerCausality(pretrained_tkg ,input_dim=16, hidden_dim=32, latent_dim=16, num_heads=4, num_nodes=10, num_lags=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


In [44]:
test_data = tkgngc_data_processing(data=test, device=device, num_timestamps=20)
# Instantiate the full model


In [45]:
def create_lagged_data(time_series_data, num_lags):
    """
    Create lagged versions of time-series data.
    
    Args:
        time_series_data (torch.Tensor): Time-series data of shape [batch_size, num_nodes, input_dim, time_steps].
        num_lags (int): Number of lags to include.
        
    Returns:
        lagged_data (torch.Tensor): Lagged data of shape [batch_size, num_nodes, input_dim * num_lags, time_steps - num_lags].
        current_data (torch.Tensor): Current time steps of shape [batch_size, num_nodes, input_dim, time_steps - num_lags].
    """
    batch_size, num_nodes, input_dim, time_steps = time_series_data.shape

    # Prepare lagged data
    lagged_data = []
    for lag in range(1, num_lags + 1):
        lagged_data.append(time_series_data[..., :-lag])  # Remove the last `lag` steps

    # Stack lagged data along the feature dimension
    lagged_data = torch.cat(lagged_data, dim=2)  # Shape: [batch_size, num_nodes, input_dim * num_lags, time_steps - num_lags]

    # Current data excludes the first `num_lags` time steps
    current_data = time_series_data[..., num_lags:]  # Shape: [batch_size, num_nodes, input_dim, time_steps - num_lags]

    return lagged_data, current_data

In [46]:
time_series_data = test_data.time_series_tensor
lagged_data, original = create_lagged_data(time_series_data,10)

ValueError: not enough values to unpack (expected 4, got 2)

In [None]:

# Example data
time_series_data = torch.randn(32, 10, 16)  # Batch of 32 graphs with 10 nodes, 16 features each
lagged_data = torch.randn(32, 10, 80)  # Lagged data: num_lags * input_dim = 5 * 16
edge_index = torch.randint(0, 10, (2, 20))  # Random graph edges

for epoch in range(50):
    optimizer.zero_grad()
    z, mean, log_var, x_reconstructed, adj, attn_weights = model(time_series_data, edge_index, lagged_data)
    loss = model.loss_function(x_reconstructed, time_series_data, mean, log_var, adj)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")