In [1]:
from tkgngc.embeddings import PretrainedTKGEmbeddingWithTimestamps
from tkgngc.model import NGCWithPretrainedTKGAndTimestamps

In [2]:
import ast
import os
import warnings
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import torch

from feature.scalers import ranged_scaler
#from mpge.rca import mpge_root_cause_diagnosis
warnings.filterwarnings("ignore", category=UserWarning) 

In [3]:
#os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
#os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

In [4]:
# Raw sensor readings and the incident metadata (one row per labelled anomaly).
cats_df = pl.read_csv("data/data.csv", separator=",")
metadata = pl.read_csv('data/metadata.csv', separator=',')
# Distinct labelled root causes across all incidents.
potential_causes = metadata.get_column('root_cause').unique().to_list()

In [5]:
# Rescale every column that has at least MIN_UNIQUE_TO_SCALE distinct values with
# the project's ranged_scaler. Low-cardinality columns and date/datetime/string
# columns are left untouched.
MIN_UNIQUE_TO_SCALE = 50
NON_NUMERIC_DTYPES = (pl.Date, pl.Datetime, pl.Utf8)

columns_to_scale = [
    col
    for col in cats_df.columns
    if cats_df[col].n_unique() >= MIN_UNIQUE_TO_SCALE
    and cats_df[col].dtype not in NON_NUMERIC_DTYPES
]
# One with_columns call replaces all selected columns at once.
# This relies on ranged_scaler keeping the Series name, as the scaled head() output shows.
cats_df = cats_df.with_columns([ranged_scaler(cats_df[col]) for col in columns_to_scale])

In [6]:
cats_df = cats_df.with_columns(
    # Turn the timestamp strings into a Datetime column.
    pl.col('timestamp').str.to_datetime(format="%Y-%m-%d %H:%M:%S"),
    # The row position is used as the entity id.
    pl.Series("entity_id", range(cats_df.height)),
)
# One dict per incident (metadata row), keyed by column name.
cats_rows_list = metadata.rows(named=True)

In [7]:
# `cats_rows_list` was already built in the previous cell; here we only preview the frame.
cats_df.head()


timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category,entity_id
datetime[μs],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64
2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.1078e-14,2.0428e-14,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0,0
2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0,1
2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0,2
2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0,3
2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0008,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0,4


In [8]:

# From here on `cats_df` is a pandas DataFrame (it was a polars frame above).
# The later cells use pandas-style row slicing and boolean masks on it.
cats_df = cats_df.to_pandas()

In [9]:
# NOTE(review): the saved output below starts with an "Unnamed: 0" column.
# The polars preview of this same frame a few cells up has no such column,
# so this output looks stale (from a different load). Re-run to confirm.
cats_df.head()

Unnamed: 0,timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,...,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category,entity_id
0,2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.107825e-14,2.04281e-14,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0,0
1,2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0,1
2,2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0,2
3,2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0,3
4,2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0007999999,0.0,0.0,-0.32802,-0.369237,...,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0,4


In [10]:
# Target device for the tensors and model that later cells move with `.to(device)`.
device = torch.device('cpu')

In [11]:
# The leading TRAIN_ROWS rows pretrain the TKG; everything after them is held out
# and sliced per incident window in the evaluation loop.
TRAIN_ROWS = 1_000_000
FEATURE_COLUMNS = [
    'aimp', 'amud', 'arnd', 'asin1', 'asin2', 'adbr', 'adfl',
    'bed1', 'bed2', 'bfo1', 'bfo2', 'bso1', 'bso2', 'bso3',
    'ced1', 'cfo1', 'cso1',
]

train_df = cats_df.iloc[:TRAIN_ROWS]
train = train_df[FEATURE_COLUMNS]
test_df = cats_df.iloc[TRAIN_ROWS:]


In [12]:
class tkgngc_data_processing:
    """Convert a window of feature rows into the tensors the TKG / NGC models take.

    Each row of ``data`` becomes one entity, indexed by its position in the window.
    - Relations alternate 0/1 by row parity.
    - Consecutive rows are linked by a chain graph (row i -> row i + 1).
    - Rows are assigned to ``num_timestamps`` equal-width buckets over the row index.

    Args:
        data: pandas DataFrame of numeric features, one row per time step.
        device: torch device that every produced tensor is placed on.
        num_timestamps: number of time buckets used for ``timestamp_indices``.

    Attributes:
        ordered_column_names: column labels of ``data``, i.e. the order of the
            feature axis of ``time_series_tensor``.
        train_list: ``data`` as a nested Python list (row-major).
        time_series_tensor: float32 tensor of shape (num_rows, num_features).
        entity_indices: long tensor ``0 .. num_rows - 1``.
        relation_indices: long tensor ``0, 1, 0, 1, ...``.
        timestamps: the raw (un-bucketed) time coordinate, equal to ``entity_indices``.
        timestamp_indices: long tensor of bucket ids in ``[0, num_timestamps - 1]``.
        edge_index: long tensor of shape (2, max(num_rows - 1, 0)).
    """

    def __init__(self, data, device, num_timestamps=20):
        self.data = data
        # The column labels, not the whole frame.
        self.ordered_column_names = self.data.columns
        self.train_list = self.data.values.tolist()
        # Build from a 2-D array instead of the nested list so an empty window
        # still has shape (0, num_features) rather than collapsing to (0,).
        self.time_series_tensor = torch.tensor(
            self.data.to_numpy(dtype="float32"), dtype=torch.float32
        ).to(device)
        num_rows = len(self.train_list)

        # Entity and relation indices, computed with vectorised ops (no Python loop).
        self.entity_indices = torch.arange(num_rows, dtype=torch.long, device=device)
        self.relation_indices = self.entity_indices % 2

        # Use this window's own indices.
        # The old code read the notebook-level `entity_indices` global here.
        self.timestamps = self.entity_indices

        # Timestamp binning.
        # Tensor reductions instead of Python's builtin min/max, which iterate element
        # by element and raise on an empty window.
        self.num_timestamps = num_timestamps
        if num_rows:
            min_time = float(self.entity_indices.min())
            max_time = float(self.entity_indices.max())
        else:
            min_time = max_time = 0.0
        bins = torch.linspace(min_time, max_time, self.num_timestamps + 1, device=device)
        # bucketize already returns int64, so no torch.tensor() re-wrap is needed.
        # After the "- 1", bucket k covers (bins[k], bins[k + 1]].
        self.timestamp_indices = torch.bucketize(self.entity_indices.to(bins.dtype), bins) - 1
        # The first row equals bins[0] and would get -1; clamp keeps ids in range.
        self.timestamp_indices = torch.clamp(self.timestamp_indices, min=0, max=num_timestamps - 1)

        # Edge index: chain graph row i -> row i + 1, shape (2, num_rows - 1).
        # It is (2, 0) when there are fewer than two rows.
        self.edge_index = torch.stack((
            torch.arange(0, max(num_rows - 1, 0), dtype=torch.long, device=device),
            torch.arange(1, max(num_rows, 1), dtype=torch.long, device=device),
        ))



    
        

In [13]:
# Keep the feature order: later cells use it to label the per-feature scores.
ordered_column_names = train.columns
# Row-major nested list, then a (rows, features) float32 tensor built from it.
train_list = train.to_numpy().tolist()
time_series_tensor = torch.tensor(train_list, dtype=torch.float32)

In [14]:
# One entity per training row; the relation type alternates 0/1 by row parity.
# `% 2` on the long arange gives the same long tensor as the previous per-row list
# comprehension, without a Python loop over every training row.
entity_indices = torch.arange(len(train_list), dtype=torch.long)
relation_indices = entity_indices % 2

In [15]:
# Raw time coordinate = row position. Kept for reference only: the models are fed
# the bucketed `timestamp_indices` computed in the next cell.
timestamps = entity_indices

In [16]:
# Split the training rows into `num_timestamps` equal-width time buckets.
num_timestamps = 20
# Tensor reductions rather than Python's builtin min/max, which iterate over every
# element. Plain floats give torch.linspace scalar bounds.
min_time, max_time = float(entity_indices.min()), float(entity_indices.max())
bins = torch.linspace(min_time, max_time, num_timestamps + 1)
# bucketize already returns int64, so no torch.tensor() re-wrap (which only copies
# and emits a UserWarning). After the "- 1", bucket k covers (bins[k], bins[k+1]].
timestamp_indices = torch.bucketize(entity_indices.to(bins.dtype), bins) - 1
# The first row equals bins[0] and lands at -1; clamp keeps every id in range.
timestamp_indices = torch.clamp(timestamp_indices, min=0, max=num_timestamps - 1)

In [17]:
# Chain graph linking row i -> row i + 1, shape (2, num_rows - 1).
# Vectorised aranges replace a Python list with one [i, i + 1] pair per training row.
num_rows = len(entity_indices)
edge_index = torch.stack((
    torch.arange(0, max(num_rows - 1, 0), dtype=torch.long),
    torch.arange(1, max(num_rows, 1), dtype=torch.long),
))

In [18]:
# Move every model input onto the selected device in one pass.
time_series_tensor, entity_indices, relation_indices, timestamp_indices, edge_index = [
    tensor.to(device)
    for tensor in (time_series_tensor, entity_indices, relation_indices, timestamp_indices, edge_index)
]

In [19]:
# Seed torch's global RNG so any random parameter initialisation here, and in the
# models built by later cells, is repeatable on Restart & Run All.
torch.manual_seed(42)
pretrained_tkg = PretrainedTKGEmbeddingWithTimestamps(
    # Entity and relation ids are 0-based, so the table sizes are max id + 1.
    num_entities=int(entity_indices.max().item() + 1),
    num_relations=int(relation_indices.max().item() + 1),
    embedding_dim=16,
    num_timestamps=num_timestamps,
).to(device)

In [20]:
# TKG pretraining facts as (head, relation, tail, timestamp) tensors.
# Each row t is linked to row t + 1, so heads, relations and timestamps drop the
# last row and the tails drop the first.
quads = (entity_indices[:-1], relation_indices[:-1], entity_indices[1:], timestamp_indices[:-1])

In [21]:
# Fit the TKG embeddings on the consecutive-row quadruples.
pretrained_tkg.pretrain(quads, epochs=500, learning_rate=0.01)


Epoch 0, Loss: 4.284931182861328
Epoch 10, Loss: 3.510828733444214
Epoch 20, Loss: 2.865093231201172
Epoch 30, Loss: 2.3369147777557373
Epoch 40, Loss: 1.9087989330291748
Epoch 50, Loss: 1.5630245208740234
Epoch 60, Loss: 1.2839479446411133
Epoch 70, Loss: 1.0583704710006714
Epoch 80, Loss: 0.8755102753639221
Epoch 90, Loss: 0.7267642021179199
Epoch 100, Loss: 0.6053422093391418
Epoch 110, Loss: 0.5058901309967041
Epoch 120, Loss: 0.4241730272769928
Epoch 130, Loss: 0.35682472586631775
Epoch 140, Loss: 0.3011571168899536
Epoch 150, Loss: 0.2550145387649536
Epoch 160, Loss: 0.2166621834039688
Epoch 170, Loss: 0.18469928205013275
Epoch 180, Loss: 0.1579911708831787
Epoch 190, Loss: 0.13561610877513885
Epoch 200, Loss: 0.11682309955358505
Epoch 210, Loss: 0.10099862515926361
Epoch 220, Loss: 0.08764007687568665
Epoch 230, Loss: 0.0763346329331398
Epoch 240, Loss: 0.06674236804246902
Epoch 250, Loss: 0.05858272686600685
Epoch 260, Loss: 0.051623567938804626
Epoch 270, Loss: 0.0456724017858

In [None]:
new_metadata = []

for i, row in enumerate(cats_rows_list):
    # Work on a copy. `row` is one of the dicts in `cats_rows_list`, and mutating it
    # in place would carry cause_* flags over into a re-run of this cell.
    record = dict(row)

    start_time = datetime.strptime(record['start_time'], "%Y-%m-%d %H:%M:%S")
    end_time = datetime.strptime(record['end_time'], "%Y-%m-%d %H:%M:%S")
    # `affected` holds a stringified list read from metadata.csv. Parse it as a literal
    # instead of eval()-ing file contents as code.
    anomaly = ast.literal_eval(record['affected'])[0]
    root_cause = record['root_cause']

    # Held-out rows inside this incident's time window (inclusive on both ends).
    mod_df = test_df[(test_df['timestamp'] >= start_time) & (test_df['timestamp'] <= end_time)]
    # Select features in the same order as `ordered_column_names`, which labels the
    # score rows below, so each score is matched to the right feature.
    test = mod_df[list(ordered_column_names)]
    test_data = tkgngc_data_processing(data=test, device=device, num_timestamps=num_timestamps)

    # Instantiate the full model on this window.
    model = NGCWithPretrainedTKGAndTimestamps(
        pretrained_tkg=pretrained_tkg,
        input_dim=test_data.time_series_tensor.shape[1],
        hidden_dim=64,
        output_dim=test_data.time_series_tensor.shape[1],
        confounder_latent_dim=17,
        entity_indices=test_data.entity_indices,
        relation_indices=test_data.relation_indices,
        time_series_data=test_data.time_series_tensor,
        timestamp_indices=test_data.timestamp_indices,
        edge_index=test_data.edge_index,
        use_sliding_window=False,
        window_size=10,
        step_size=2,
        regularization_type="l1",
        regularization_strength=0.01,
    )
    # NOTE(review): on a plain torch.nn.Module, .train() only switches to training
    # mode. These scores rely on NGCWithPretrainedTKGAndTimestamps overriding train()
    # to fit the model and populate `model.z`; confirm this in tkgngc.model.
    model.train()

    # Per-feature score = mean |z| over rows, ranked highest first.
    # .cpu() lets this also work when `device` is a GPU.
    score_df = pd.DataFrame(
        np.mean(np.abs(model.z.detach().cpu().numpy()), axis=0),
        index=ordered_column_names,
        columns=['scores'],
    ).sort_values(by=['scores'], ascending=False)
    # The anomalous variable itself is not a candidate root cause.
    score_df = score_df.drop(anomaly)

    # Record a 0/1 hit flag for every rank so each record has all three columns.
    # Setting the key only on a hit leaves the column missing entirely when no
    # incident ever hits at that rank.
    top_causes = score_df.index[:3].tolist()
    for rank in range(3):
        hit = rank < len(top_causes) and root_cause == top_causes[rank]
        record[f'cause_{rank + 1}'] = int(hit)
    new_metadata.append(record)

    if i % 5 == 0:
        print('Iteration #: ' + str(i))





Iteration #: 0
Iteration #: 5


In [None]:
# Top-3 hit rate: the share of incidents whose labelled root cause was ranked
# 1st, 2nd or 3rd by the per-feature scores.
# infer_schema_length=None scans every record, so a cause_* key that first appears
# in a later record is not dropped from the schema.
stats = pl.DataFrame(new_metadata, infer_schema_length=None)
hit_columns = [col for col in ("cause_1", "cause_2", "cause_3") if col in stats.columns]
# A rank that never produced a hit may have no column at all; count it as zero hits.
agg_stats = stats.select([pl.col(col).sum() for col in hit_columns])
total_hits = sum(value or 0 for value in agg_stats.row(0)) if hit_columns else 0
total_hits / stats.height if stats.height else float("nan")