In [1]:
import os
import torch
import polars as pl
import pandas as pd
import numpy as np
import seaborn as sns
import warnings
from feature.scalers import ranged_scaler
from datetime import datetime, timedelta
#from mpge.rca import mpge_root_cause_diagnosis
warnings.filterwarnings("ignore", category=UserWarning) 

In [2]:
from tkgngc.embeddings import PretrainedTKGEmbeddingWithTimestamps
from tkgngc.data_processing import TKGNGCDataProcessor
from tkgngc.model import train_model

In [3]:
cats_df = pl.read_csv("data/data.csv", separator=",")  
metadata = pl.read_csv('data/metadata.csv',separator=',')
potential_causes = metadata['root_cause'].unique().to_list()

for col in cats_df.columns:
    unique_vals = cats_df[col].n_unique()
    data_type = cats_df[col].dtype
    bad_dtypes = [pl.Date,pl.Datetime,pl.Utf8]
    if ((unique_vals >= 50) & (data_type not in bad_dtypes) ):
        cats_df = cats_df.with_columns(ranged_scaler(cats_df[col]))
    else:
        continue

In [4]:
cats_df = cats_df.with_columns(
    pl.col('timestamp').str.to_datetime("%Y-%m-%d %H:%M:%S"),
    pl.Series("entity_id",range(len(cats_df)))
)
cats_rows_list = metadata.rows(named=True)

In [5]:
cats_rows_list = metadata.rows(named=True)
cats_df.head()

timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category,entity_id
datetime[μs],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64
2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.1078e-14,2.0428e-14,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0,0
2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0,1
2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0,2
2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0,3
2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0008,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0,4


In [6]:
cats_df = cats_df.to_pandas()

In [7]:
cats_df=cats_df.set_index('timestamp')
cats_df = cats_df.drop(['y','category','entity_id'],axis=1)
cats_df.head()


Unnamed: 0_level_0,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.107825e-14,2.04281e-14,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623
2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618
2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604
2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658
2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0007999999,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548


In [8]:
device = torch.device('cpu')

In [9]:
train_df = cats_df[0:1000000]
test_df = cats_df[1000000:]


In [10]:
tkgnc_data = TKGNGCDataProcessor(train_df,device,num_timestamps=20, lags=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['time'] = self.data.index
  return torch.tensor(


In [11]:
pretrained_tkg = PretrainedTKGEmbeddingWithTimestamps(
    num_entities=int(tkgnc_data.entity_indices.max().item()+1),
    num_relations=int(tkgnc_data.relation_indices.max().item()+1),
    embedding_dim=8,
    num_timestamps=tkgnc_data.num_timestamps,
).to(device)

In [12]:
quads = (
    tkgnc_data.entity_indices[:-1],  # Head entities
    tkgnc_data.relation_indices,  # Relations
    tkgnc_data.entity_indices[1:],  # Tail entities (shifted example)
    tkgnc_data.timestamp_indices[:-1],  # Timestamps
)

In [13]:
pretrained_tkg.pretrain(quads, learning_rate=0.01, epochs=500)


Epoch 0, Loss: 3.6063
Epoch 10, Loss: 2.9710
Epoch 20, Loss: 2.4408
Epoch 30, Loss: 2.0068
Epoch 40, Loss: 1.6544
Epoch 50, Loss: 1.3690
Epoch 60, Loss: 1.1373
Epoch 70, Loss: 0.9488
Epoch 80, Loss: 0.7948
Epoch 90, Loss: 0.6682
Epoch 100, Loss: 0.5637
Epoch 110, Loss: 0.4771
Epoch 120, Loss: 0.4049
Epoch 130, Loss: 0.3445
Epoch 140, Loss: 0.2938
Epoch 150, Loss: 0.2511
Epoch 160, Loss: 0.2150
Epoch 170, Loss: 0.1845
Epoch 180, Loss: 0.1585
Epoch 190, Loss: 0.1365
Epoch 200, Loss: 0.1178
Epoch 210, Loss: 0.1018
Epoch 220, Loss: 0.0882
Epoch 230, Loss: 0.0766
Epoch 240, Loss: 0.0667
Epoch 250, Loss: 0.0582
Epoch 260, Loss: 0.0509
Epoch 270, Loss: 0.0447
Epoch 280, Loss: 0.0394
Epoch 290, Loss: 0.0348
Epoch 300, Loss: 0.0308
Epoch 310, Loss: 0.0274
Epoch 320, Loss: 0.0245
Epoch 330, Loss: 0.0220
Epoch 340, Loss: 0.0198
Epoch 350, Loss: 0.0179
Epoch 360, Loss: 0.0162
Epoch 370, Loss: 0.0148
Epoch 380, Loss: 0.0135
Epoch 390, Loss: 0.0124
Epoch 400, Loss: 0.0114
Epoch 410, Loss: 0.0105
Epo

In [34]:
import warnings
import pandas as pd

# Suppress FutureWarning messages
warnings.simplefilter(action='ignore', category=FutureWarning)
pd.options.mode.chained_assignment = None  # Suppresses the warning
#warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)
new_metadata = []

for i, row in enumerate(cats_rows_list):
    potential_causes = metadata['root_cause'].unique().to_list()

    start_time = datetime.strptime(row['start_time'],"%Y-%m-%d %H:%M:%S")
    end_time = datetime.strptime(row['end_time'],"%Y-%m-%d %H:%M:%S")
    anomaly = eval(row['affected'])[0]
    root_cause = row['root_cause']
    potential_causes.append(anomaly)
    mod_df = test_df[(test_df.index>= start_time) & (test_df.index<= end_time)]
    test_data = TKGNGCDataProcessor(mod_df,device,num_timestamps=20, lags=1)
    # Instantiate the full model

    """entity_emb, relation_emb, _, timestamp_emb = pretrained_tkg(
    test_data.entity_indices, test_data.relation_indices, test_data.entity_indices, test_data.timestamp_indices
)"""
    z, mean, log_var, x_reconstructed, causal_effect, adj_df = train_model(test_data,pretrained_tkg)
    
    adj_df = pd.DataFrame(adj_df.detach().numpy(),index=test_data.feature_columns,columns=test_data.feature_columns)
    causes = adj_df.loc[potential_causes][anomaly].sort_values(ascending=False)
    potential_cause1 = causes.index[0]
    potential_cause2 = causes.index[1]
    potential_cause3 = causes.index[2]
    if root_cause == potential_cause1:
        row['cause_1'] = 1
    if root_cause == potential_cause2:
        row['cause_2'] = 1
    if root_cause == potential_cause3:
        row['cause_3'] = 1
    new_metadata.append(row)

    if i%5 == 0:
        print('Iteration #: ' + str(i))


Iteration #: 0
Iteration #: 5
Iteration #: 10
Iteration #: 15
Iteration #: 20
Iteration #: 25
Iteration #: 30
Iteration #: 35
Iteration #: 40
Iteration #: 45
Iteration #: 50
Iteration #: 55
Iteration #: 60
Iteration #: 65
Iteration #: 70
Iteration #: 75
Iteration #: 80
Iteration #: 85
Iteration #: 90
Iteration #: 95
Iteration #: 100
Iteration #: 105
Iteration #: 110
Iteration #: 115
Iteration #: 120
Iteration #: 125
Iteration #: 130
Iteration #: 135
Iteration #: 140
Iteration #: 145
Iteration #: 150
Iteration #: 155
Iteration #: 160
Iteration #: 165
Iteration #: 170
Iteration #: 175
Iteration #: 180
Iteration #: 185
Iteration #: 190
Iteration #: 195


In [35]:
stats = pl.DataFrame(new_metadata)
agg_stats = stats.select(pl.sum("cause_1", "cause_2",'cause_3'))
agg_stats.select(pl.sum_horizontal(pl.all())).item()/stats.shape[0]

0.485