In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn
import polars as pl
import pandas as pd
import numpy as np
import seaborn as sns
import warnings

from datetime import datetime, timedelta
from torch_geometric.nn import GATv2Conv, GATConv
from torch_geometric.utils import dense_to_sparse
from torch.distributions import Normal, Laplace, RelaxedOneHotCategorical
from torchdiffeq import odeint  # For continuous-time normalizing flows

from feature.scalers import ranged_scaler
from feature.engineering import *
from CARAT.model_utils import *
from CARAT.model import CausalGraphVAE
from CARAT.components import *
from utils.utils import set_seed, logger

import torch
import random
import numpy as np

# --- Reproducibility --------------------------------------------------------
# A single seed value is handed to every random number generator used below.
seed = 42

random.seed(seed)                 # Python's `random` module
np.random.seed(seed)              # NumPy global RNG
torch.manual_seed(seed)           # PyTorch
torch.cuda.manual_seed(seed)      # CUDA (if using GPUs)
torch.cuda.manual_seed_all(seed)  # all CUDA devices, for multi-GPU setups

# Request deterministic cuDNN behaviour and disable its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# --- CUDA / cuBLAS runtime environment variables ------------------------------
os.environ.update({
    'CUBLAS_WORKSPACE_CONFIG': ':16:8',
    'CUDA_LAUNCH_BLOCKING': '1',
    'TORCH_USE_CUDA_DSA': "1",
})

# --- Device selection ---------------------------------------------------------
# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Silence all warnings for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
# Load the signal table and the anomaly metadata from disk.
cats_df = pl.read_csv("data/data.csv", separator=",")
print(cats_df.shape)
metadata = pl.read_csv('data/metadata.csv', separator=',')
potential_causes = metadata['root_cause'].unique().to_list()

# Pass a column through ranged_scaler only when it has at least 50 distinct
# values and is not a date, datetime or string column; everything else is
# left untouched.
bad_dtypes = [pl.Date, pl.Datetime, pl.Utf8]
for col in cats_df.columns:
    if cats_df[col].n_unique() < 50:
        continue
    if cats_df[col].dtype in bad_dtypes:
        continue
    cats_df = cats_df.with_columns(ranged_scaler(cats_df[col]))

(5000000, 20)


In [3]:
# Parse the textual timestamps and attach a running integer id per row.
timestamp_parsed = pl.col('timestamp').str.to_datetime("%Y-%m-%d %H:%M:%S")
entity_ids = pl.Series("entity_id", range(cats_df.height))
cats_df = cats_df.with_columns(timestamp_parsed, entity_ids)

In [4]:
# Materialise the metadata as a list of per-row mappings (column name -> value);
# the evaluation loop further down iterates over this list.
cats_rows_list = metadata.rows(named=True)
# Display the (rows, columns) shape of the signal frame as the cell output.
cats_df.shape

(5000000, 21)

In [5]:
# make_stationary is provided by one of the star imports at the top of the notebook.
# NOTE(review): judging by the printed log it reports an ADF statistic, a p-value and
# a differencing count per float column — confirm against its implementation.
cats_df = make_stationary(cats_df)

⚠️ Skipping column timestamp (not a float column)
Column: aimp | Final ADF: -2211.6294 | p-value: 0.0000 | Diffs: 0
Column: amud | Final ADF: -8.0091 | p-value: 0.0000 | Diffs: 0
Column: arnd | Final ADF: -18.1248 | p-value: 0.0000 | Diffs: 0
Column: asin1 | Final ADF: -5.7565 | p-value: 0.0000 | Diffs: 1
Column: asin2 | Final ADF: -39.8976 | p-value: 0.0000 | Diffs: 3
Column: adbr | Final ADF: -111.5305 | p-value: 0.0000 | Diffs: 0
Column: adfl | Final ADF: -217.2031 | p-value: 0.0000 | Diffs: 0
Column: bed1 | Final ADF: -35.0214 | p-value: 0.0000 | Diffs: 0
Column: bed2 | Final ADF: -44.5328 | p-value: 0.0000 | Diffs: 0
Column: bfo1 | Final ADF: -8.5614 | p-value: 0.0000 | Diffs: 0
Column: bfo2 | Final ADF: -13.3015 | p-value: 0.0000 | Diffs: 0
Column: bso1 | Final ADF: -2164.1026 | p-value: 0.0000 | Diffs: 1
Column: bso2 | Final ADF: -13.5707 | p-value: 0.0000 | Diffs: 0
Column: bso3 | Final ADF: -14.7270 | p-value: 0.0000 | Diffs: 0
Column: ced1 | Final ADF: -727.3289 | p-value: 0.

In [6]:
# Switch from polars to pandas: the following cells use the pandas API
# (set_index, drop, head, astype).
cats_df = cats_df.to_pandas()

In [7]:
# Inspect the distinct values of the 'affected' column. As the output shows, each
# value is a string that looks like a Python list literal (e.g. "['cso1']").
metadata['affected'].unique().to_list()

["['cso1']", "['cfo1']", "['ced1']"]

In [8]:
# Distinct root-cause names taken from the metadata.
# maintain_order=True keeps them in first-appearance order so the list is the
# same on every run; without it polars does not guarantee the order of unique().
potential_causes = metadata['root_cause'].unique(maintain_order=True).to_list()

In [9]:
# Use the timestamp as the index and discard the 'y', 'category' and
# 'entity_id' columns, then preview the first rows.
cats_df = (
    cats_df
    .set_index('timestamp')
    .drop(columns=['y', 'category', 'entity_id'])
)
cats_df.head()


Unnamed: 0_level_0,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
2023-01-01 00:00:06,0.0,0.142857,-0.44448,2e-05,-7.999823e-12,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,7.2e-05,-0.507953,-0.716059,0.0,0.100401,-0.186461
2023-01-01 00:00:07,0.0,0.142857,-0.446078,2e-05,-7.999823e-12,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,8.3e-05,-0.507953,-0.716059,0.0,0.100408,-0.186406
2023-01-01 00:00:08,0.0,0.142857,-0.447166,2e-05,-8.000267e-12,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,9.4e-05,-0.507953,-0.716059,0.0,0.100416,-0.186345
2023-01-01 00:00:09,0.0,0.142857,-0.442843,2e-05,-7.999823e-12,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,0.000104,-0.507953,-0.716059,0.0,0.100426,-0.186278
2023-01-01 00:00:10,0.0,0.142857,-0.44132,2e-05,-8.000045e-12,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,0.000113,-0.507953,-0.716059,0.0,0.100438,-0.186205


In [10]:
# Display the frame's shape after the preprocessing steps above.
cats_df.shape

(4999994, 17)

In [11]:
# Positional split: the first 1,000,000 rows form the training split,
# every remaining row goes to the test split.
split_at = 1000000
train_df = cats_df.iloc[:split_at]
test_df = cats_df.iloc[split_at:]

In [12]:
# Cast every column of the test split to float32.
# NOTE(review): train_df is not cast in this notebook — confirm whether
# TimeSeriesDataset handles the dtype conversion itself.
test_df = test_df.astype('float32')

In [13]:
# Remove a 'time' column from both splits when one exists.
# errors='ignore' turns a missing column into a no-op, which is what the
# previous try/except was for, without the bare `except:` that would also
# hide unrelated failures (typos, KeyboardInterrupt, ...).
train_df = train_df.drop(columns='time', errors='ignore')
test_df = test_df.drop(columns='time', errors='ignore')

In [14]:
# Column bookkeeping used to build the causal prior.
cols = list(test_df.columns)

# Columns that are not listed as potential root causes. The membership test uses
# a set, but the result is built by walking `cols` so it keeps the frame's column
# order. The previous list(set(...)) form produced an order that depends on
# string hashing and can therefore change from one kernel session to the next.
cause_names = set(potential_causes)
non_causal_columns = [col for col in cols if col not in cause_names]

# Positional column indices of both groups, looked up in train_df.
causal_indices = [train_df.columns.get_loc(col) for col in potential_causes]
non_causal_indices = [train_df.columns.get_loc(col) for col in non_causal_columns]

In [15]:
import torch
import random
import numpy as np

# Reset every random number generator to the same seed before the model is built.
seed = 42

# PyTorch (CPU), CUDA (current device and all devices), Python's `random`, NumPy.
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
    np.random.seed,
):
    seed_fn(seed)

# Request deterministic cuDNN behaviour and disable its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [16]:
# Hyperparameters for the pretraining run.
TIME_STEPS = 5
BATCH_SIZE = 10000
hidden_dim = 128
latent_dim = 8
num_nodes = 17  # not referenced in this cell; the model below is sized from train_df

# Dataset and shuffled loader over the training split.
dataset_nominal = TimeSeriesDataset(train_df, device=device, time_steps=TIME_STEPS)
dataloader_nominal = DataLoader(
    dataset_nominal,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=False,
)

# Model and optimizer: input_dim and num_nodes are both the column count of train_df.
model = CausalGraphVAE(
    input_dim=train_df.shape[1],
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
    num_nodes=train_df.shape[1],
    device=device,
    time_steps=TIME_STEPS,
    prior_adj=None,
    instantaneous_weight=0.3,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Train on the nominal data.
print("Pretraining on nominal data...")
model.train_model(
    dataloader_nominal,
    optimizer,
    num_epochs=250,
    patience=25,
    rho_max=1.0,
    alpha_max=1.0,
)


Pretraining on nominal data...
Epoch 1: Loss = 108534.1583
Recon Loss = 57724.8086, KL Loss = 2.0070, Lagrangian Loss = 60.1475
Epoch 51: Loss = 1435.1395
Recon Loss = 1368.7638, KL Loss = 7.1449, Lagrangian Loss = 34.4208
Epoch 101: Loss = 333.6693
Recon Loss = 269.9066, KL Loss = 6.1545, Lagrangian Loss = 52.5123
Epoch 151: Loss = 203.4428
Recon Loss = 146.1718, KL Loss = 5.8149, Lagrangian Loss = 51.6312
Epoch 201: Loss = 160.6211
Recon Loss = 104.1256, KL Loss = 5.6720, Lagrangian Loss = 47.0627


In [17]:
# Take an independent copy of the pretrained adjacency matrix, cut off from autograd.
prior_adj = model.causal_graph.adj_mat.detach().clone()
prior_adj

tensor([[0.0000, 0.6109, 0.5554, 0.5470, 0.6113, 0.5895, 0.5467, 0.5785, 0.5533,
         0.6134, 0.5810, 0.5329, 0.5515, 0.5693, 0.5256, 0.5947, 0.5695],
        [0.6085, 0.0000, 0.5515, 0.5363, 0.5926, 0.5579, 0.5834, 0.5974, 0.5512,
         0.6047, 0.5680, 0.5466, 0.5339, 0.5594, 0.5428, 0.5993, 0.5608],
        [0.5587, 0.5966, 0.0000, 0.5645, 0.5677, 0.5773, 0.5888, 0.5774, 0.5769,
         0.6060, 0.5828, 0.5405, 0.5743, 0.5761, 0.6012, 0.5599, 0.5693],
        [0.5759, 0.5509, 0.5831, 0.0000, 0.5409, 0.5769, 0.6012, 0.5706, 0.6064,
         0.5870, 0.5602, 0.5468, 0.5732, 0.6070, 0.5907, 0.5699, 0.6001],
        [0.6247, 0.5830, 0.5840, 0.5732, 0.0000, 0.5702, 0.5382, 0.5737, 0.5460,
         0.5495, 0.6203, 0.5999, 0.5523, 0.5288, 0.5404, 0.5878, 0.5561],
        [0.5801, 0.5444, 0.5906, 0.5604, 0.5871, 0.0000, 0.5955, 0.5205, 0.5812,
         0.5709, 0.5585, 0.5437, 0.5914, 0.6000, 0.6051, 0.5341, 0.6160],
        [0.5614, 0.5728, 0.5733, 0.5894, 0.5287, 0.5875, 0.0000, 0.579

In [18]:
# Fresh copy of the pretrained adjacency matrix, detached from autograd.
prior_adj = model.causal_graph.adj_mat.detach().clone()

# Keep entry (i, j) only when row i is a potential-cause column, column j is a
# non-causal column and i != j; every other entry is set to zero.
keep_mask = torch.zeros_like(prior_adj, dtype=torch.bool)
cause_rows = torch.tensor(causal_indices, dtype=torch.long, device=prior_adj.device)
effect_cols = torch.tensor(non_causal_indices, dtype=torch.long, device=prior_adj.device)
keep_mask[cause_rows.unsqueeze(1), effect_cols.unsqueeze(0)] = True
keep_mask &= ~torch.eye(*prior_adj.shape, dtype=torch.bool, device=prior_adj.device)
prior_adj[~keep_mask] = 0

In [19]:
# Render the masked prior as a table whose rows and columns are labelled with the signal names.
prior_adj_np = prior_adj.detach().cpu().numpy()
pd.DataFrame(prior_adj_np, index=cols, columns=cols)

Unnamed: 0,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1
aimp,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
amud,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
arnd,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
asin1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
asin2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
adbr,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
adfl,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
bed1,0.585129,0.585867,0.585613,0.564786,0.588158,0.51126,0.56482,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.530803,0.580621,0.527371
bed2,0.535241,0.580195,0.56898,0.57801,0.516187,0.583845,0.619527,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.619278,0.524089,0.600948
bfo1,0.565199,0.593039,0.576741,0.576508,0.555203,0.587609,0.585661,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.57833,0.573487,0.594357


In [None]:
import ast  # stdlib; imported here so this cell stays runnable on its own


def _top_k_causes(causes, column, k=3):
    """Return the index labels of the ``k`` highest values in ``column``.

    Parameters
    ----------
    causes : pandas.DataFrame
        Score table indexed by candidate cause name.
    column : str
        Name of the score column to rank by, highest first.
    k : int, default 3
        Maximum number of labels to return. Fewer are returned when the
        table has fewer than ``k`` rows.

    Returns
    -------
    list
        Index labels ordered from best to worst score.
    """
    return list(causes.sort_values(by=column, ascending=False).index[:k])


# ---------------------------------------------------------------------------
# Per-anomaly evaluation: for every metadata row, fine-tune a fresh
# CausalGraphVAE on the window around the anomaly, rank the potential causes
# under several scores and record whether the labelled root cause is in the
# top three of each ranking.
# ---------------------------------------------------------------------------

cols = list(test_df.columns)
# Built in column order so the index lists are identical on every run
# (list(set(...)) ordering depends on string hashing).
cause_names = set(potential_causes)
non_causal_columns = [col for col in cols if col not in cause_names]
causal_indices = [train_df.columns.get_loc(col) for col in potential_causes]
non_causal_indices = [train_df.columns.get_loc(col) for col in non_causal_columns]
num_nodes = len(train_df.columns)
BATCH_SIZE = 50

# Seconds of padding added before the anomaly start and after its end.
start_len = 8
# Number of lags is fixed at 8. A data-driven choice,
# max(most_frequent(find_optimal_lags_for_dataframe(mod_df)) + 1, 3),
# was previously tried here and is currently disabled.
TIME_STEPS = 8

# (flag prefix written into the metadata row, score column in `causes`),
# listed in the order in which the flags are written.
RANKINGS = [
    ('edge_cause', 'causes'),
    ('counterfactual_cause', 'counterfactuals'),
    ('total_score_cause', 'causal_strength'),
    ('instant_cause', 'instantaneous'),
    ('lag_cause', 'lagged'),
    ('rr_cause', 'RootRank'),
]

new_metadata = []
edge_correct = 0
instantaneous_correct = 0
lagged_correct = 0
counterfactual_correct = 0
rr_correct = 0
total_correct = 0
total_checked = 0
incorrect = []

for i, row in enumerate(cats_rows_list):
    total_checked += 1
    logger.info('Model: '+ str(i))

    # 'affected' is a string holding a list literal such as "['cso1']".
    # ast.literal_eval parses literals only, so file content is never executed
    # as code (this line previously used eval()).
    anomaly = ast.literal_eval(row['affected'])[0]
    print('Anomaly: ' + anomaly)
    anomaly_time = datetime.strptime(row['start_time'], "%Y-%m-%d %H:%M:%S")
    end_time = datetime.strptime(row['end_time'], "%Y-%m-%d %H:%M:%S")
    root_cause = row['root_cause']

    # Slice the padded anomaly window out of the test split.
    start_time = anomaly_time - timedelta(seconds=start_len)
    finish_time = end_time + timedelta(seconds=start_len)
    mod_df = test_df[(test_df.index >= start_time) & (test_df.index <= finish_time)]
    mod_df = mod_df[['aimp', 'amud', 'arnd', 'asin1', 'asin2', 'adbr', 'adfl', 'bed1',
       'bed2', 'bfo1', 'bfo2', 'bso1', 'bso2', 'bso3', 'ced1', 'cfo1', 'cso1']]

    dataset = TimeSeriesDataset(mod_df, device=device, time_steps=TIME_STEPS)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Collect every batch into one tensor for the inference call below.
    # NOTE(review): each batch is inserted at row position `batch_idx` rather than
    # appended, so the row order is interleaved. X_data and T_data receive the same
    # insertions and stay aligned; confirm that infer_causal_effect does not depend
    # on row order before simplifying this.
    X_data = torch.empty(0, device=device)
    T_data = torch.empty(0, device=device)
    for batch_idx, (X_batch, time_batch) in enumerate(dataloader):
        X_data = torch.cat((X_data[:batch_idx], X_batch, X_data[batch_idx:]))
        T_data = torch.cat((T_data[:batch_idx], time_batch, T_data[batch_idx:]))

    # Fresh model per anomaly, initialised with the masked prior from pretraining.
    fine_tuned = CausalGraphVAE(input_dim=num_nodes, hidden_dim=hidden_dim,
                           latent_dim=latent_dim, num_nodes=num_nodes, device=device,
                           time_steps=TIME_STEPS, prior_adj=prior_adj, instantaneous_weight=0.5)
    optimizer = torch.optim.AdamW(fine_tuned.parameters(), lr=0.0001, weight_decay=1e-3)  # L2 Regularization

    fine_tuned.train_model(dataloader, optimizer, num_epochs=2500, patience=20,
                           BATCH_SIZE=BATCH_SIZE, rho_max=2.0, alpha_max=2.0)

    causes = fine_tuned.infer_causal_effect(
        X_data.to(torch.float32).to(device),
        T_data.to(torch.float32).to(device),
        anomaly,
        cols,
        non_causal_indices=non_causal_indices,
    )

    # Restrict the ranking to the columns that may be root causes.
    causes = causes.filter(items=potential_causes, axis=0)

    # Top-3 candidates under each score.
    top3 = {prefix: _top_k_causes(causes, column) for prefix, column in RANKINGS}

    # Flag the rank (1..3) at which the true root cause appears, per ranking.
    # The edge_cause_2 / edge_cause_3 flags were previously compared against the
    # rank-1 candidate because of a copy-paste slip; each rank is now compared
    # against its own candidate.
    # NOTE(review): `row` is the dict stored in cats_rows_list, so the flags are
    # written into that list in place and survive a re-run of this cell.
    for prefix, ranked in top3.items():
        for rank, candidate in enumerate(ranked, start=1):
            if root_cause == candidate:
                row[f'{prefix}_{rank}'] = 1

    new_metadata.append(row)

    # Top-3 hit counters.
    if root_cause in top3['total_score_cause']:
        total_correct += 1
    if root_cause in top3['edge_cause']:
        edge_correct += 1
    if root_cause in top3['counterfactual_cause']:
        counterfactual_correct += 1
    if root_cause in top3['instant_cause']:
        instantaneous_correct += 1
    if root_cause in top3['lag_cause']:
        lagged_correct += 1
    if root_cause in top3['rr_cause']:
        rr_correct += 1

    # Running top-3 accuracy in percent over the anomalies processed so far.
    total_accuracy = total_correct / total_checked * 100
    edge_accuracy = edge_correct / total_checked * 100
    cf_accuracy = counterfactual_correct / total_checked * 100
    instant_accuracy = instantaneous_correct / total_checked * 100
    lag_accuracy = lagged_correct / total_checked * 100
    rr_accuracy = rr_correct / total_checked * 100

    # Remember anomalies missed by both the blended and the edge ranking.
    if root_cause not in top3['total_score_cause'] + top3['edge_cause']:
        incorrect.append(i)
    logger.info(f"Edge Accuracy = {edge_accuracy:.2f}, Instantaneous Accuracy = {instant_accuracy:.2f}, Lagged Accuracy = {lag_accuracy:.2f}, Counterfactual Accuracy = {cf_accuracy:.2f},  Blended Accuracy = {total_accuracy:.2f},  RR Accuracy = {rr_accuracy:.2f}  ")



2025-03-22 08:00:43,227 INFO -- Model: 0


Anomaly: cfo1
Epoch 1: Loss = 2142.1821
Recon Loss = 1544.6382, KL Loss = 0.2464, Lagrangian Loss = 55.2770
Epoch 51: Loss = 130.2133
Recon Loss = 119.9956, KL Loss = 3.2680, Lagrangian Loss = 25.0616
Epoch 101: Loss = 79.3184
Recon Loss = 59.6290, KL Loss = 4.0164, Lagrangian Loss = 27.5530
Epoch 151: Loss = 60.5041
Recon Loss = 38.3329, KL Loss = 3.1908, Lagrangian Loss = 22.2487
Epoch 201: Loss = 50.5777
Recon Loss = 24.8533, KL Loss = 2.2769, Lagrangian Loss = 22.8569
Epoch 251: Loss = 45.0816
Recon Loss = 28.9451, KL Loss = 1.7998, Lagrangian Loss = 18.6154
Epoch 301: Loss = 40.2754
Recon Loss = 19.5170, KL Loss = 1.5060, Lagrangian Loss = 22.2257
Epoch 351: Loss = 37.8796
Recon Loss = 20.6481, KL Loss = 1.3024, Lagrangian Loss = 17.6442
Epoch 401: Loss = 36.2191
Recon Loss = 19.5806, KL Loss = 1.1504, Lagrangian Loss = 21.2405
Epoch 451: Loss = 33.7137
Recon Loss = 12.3846, KL Loss = 0.9730, Lagrangian Loss = 23.6803
Early stopping triggered. Last Epoch: 456
Recon Loss = 16.0156,

2025-03-22 08:05:07,049 INFO -- Edge Accuracy = 100.00, Instantaneous Accuracy = 100.00, Lagged Accuracy = 0.00, Counterfactual Accuracy = 100.00,  Blended Accuracy = 100.00,  RR Accuracy = 100.00  
2025-03-22 08:05:07,049 INFO -- Model: 1


Anomaly: cfo1
Epoch 1: Loss = 1039.7863
Recon Loss = 481.7098, KL Loss = 1.8107, Lagrangian Loss = 45.2394
Epoch 51: Loss = 43.1230
Recon Loss = 17.1533, KL Loss = 1.2755, Lagrangian Loss = 23.7621
Epoch 101: Loss = 24.8164
Recon Loss = 5.3089, KL Loss = 0.2146, Lagrangian Loss = 20.5569
Epoch 151: Loss = 20.1579
Recon Loss = 2.6493, KL Loss = 0.0200, Lagrangian Loss = 18.5106
Epoch 201: Loss = 18.1645
Recon Loss = 1.8682, KL Loss = 0.0041, Lagrangian Loss = 16.5201
Epoch 251: Loss = 16.4155
Recon Loss = 1.6047, KL Loss = 0.0030, Lagrangian Loss = 18.5408
