In [1]:
import ast
import os
import warnings
from datetime import datetime, timedelta

import numpy as np
import polars as pl
import seaborn as sns
import torch

from feature_eng.scalers import ranged_scaler
warnings.filterwarnings("ignore", category=UserWarning) 
import pandas as pd

In [2]:
from dataclasses import dataclass

import numpy as np
import networkx as nx

import torch
import pytorch_lightning as ptl

from torch.utils.data import DataLoader
from tensordict import TensorDict

from castle.datasets import DAG, IIDSimulation 
from castle.common import GraphDAG
from castle.metrics import MetricsDAG

import causica.distributions as cd

from causica.functional_relationships import ICGNN
from causica.training.auglag import AugLagLossCalculator, AugLagLR, AugLagLRConfig
from causica.graph.dag_constraint import calculate_dagness

from causica.datasets.variable_types import VariableTypeEnum
from causica.datasets.tensordict_utils import tensordict_shapes

import matplotlib.pyplot as plt
# Figure style and the colour palette shared by the notebook's plots.
plt.style.use(style='fivethirtyeight')
COLORS = ['#00B0F0', '#FF0000', '#B0F000']

# Reproducibility: seed NumPy's global RNG and Lightning's global seed.
# seed_everything returns the seed, which is what this cell displays.
SEED = 11
np.random.seed(seed=SEED)
ptl.seed_everything(seed=SEED)


Global seed set to 11


11

In [3]:
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.graph.GraphNode import GraphNode
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Raw multivariate time series (timestamp, sensor columns, y, category); the preview
# below shows one row per second.
cats_df = pl.read_csv(
    "data/data.csv",
    separator=",",
)

In [5]:
# Labelled anomalies: start/end time, root-cause variable, affected variables, category.
metadata = pl.read_csv(
    'data/metadata.csv',
    separator=',',
)


In [6]:
# First rows of the raw series (polars' default head size is 5).
cats_df.head(5)

timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""2023-01-01 00:00:00""",0.0,1.0,20.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""2023-01-01 00:00:01""",0.0,1.0,20.080031,2e-05,0.0002,0.0,0.0,0.0,0.0,0.0,4.9939e-07,0.000789,0.0,0.0,0.0,2.1e-05,0.001229,0.0,0.0
"""2023-01-01 00:00:02""",0.0,1.0,20.276562,4e-05,0.0004,0.0,0.0,0.0,0.0,0.0,1e-06,0.003115,0.0,0.0,0.0,0.000104,0.004833,0.0,0.0
"""2023-01-01 00:00:03""",0.0,1.0,20.730938,6e-05,0.0006,0.0,0.0,0.0,0.0,0.0,3e-06,0.006914,0.0,0.0,0.0,0.000285,0.010688,0.0,0.0
"""2023-01-01 00:00:04""",0.0,1.0,21.118101,8e-05,0.0008,0.0,0.0,0.0,0.0,0.0,5e-06,0.012123,0.0,0.0,0.0,0.000601,0.018669,0.0,0.0


In [7]:
# First labelled anomalies.
metadata.head(5)


start_time,end_time,root_cause,affected,category
str,str,str,str,i64
"""2023-01-12 15:11:45""","""2023-01-12 15:20:05""","""bso3""","""['cfo1']""",12
"""2023-01-12 16:27:46""","""2023-01-12 17:51:06""","""bso3""","""['cfo1']""",1
"""2023-01-12 18:19:35""","""2023-01-12 18:36:15""","""bfo2""","""['cso1']""",8
"""2023-01-12 20:46:32""","""2023-01-12 20:51:32""","""bed2""","""['ced1']""",7
"""2023-01-13 05:57:10""","""2023-01-13 06:02:10""","""bfo1""","""['cfo1']""",9


In [8]:
# Rescale every numeric column that has at least MIN_UNIQUE_TO_SCALE distinct values with
# ranged_scaler; low-cardinality columns and date/string columns are left untouched.
# `timestamp` is still a string here (it is parsed in the next cell), so it is skipped.
# NOTE(review): ranged_scaler presumably returns a same-named Series, so with_columns
# replaces the column in place — confirm.
# NOTE(review): cats_df is overwritten, so re-running this cell rescales already-scaled
# data; re-run from the CSV loading cell instead.
MIN_UNIQUE_TO_SCALE = 50
SKIP_DTYPES = [pl.Date, pl.Datetime, pl.Utf8]  # loop-invariant, built once

for col in cats_df.columns:
    series = cats_df[col]
    if series.dtype not in SKIP_DTYPES and series.n_unique() >= MIN_UNIQUE_TO_SCALE:
        cats_df = cats_df.with_columns(ranged_scaler(series))

In [9]:
# Parse string timestamps such as "2023-01-01 00:00:00" into a Datetime column.
timestamp_expr = pl.col("timestamp").str.to_datetime("%Y-%m-%d %H:%M:%S")
cats_df = cats_df.with_columns(timestamp_expr)


In [10]:
# Preview after scaling and timestamp parsing.
cats_df.head(5)


timestamp,aimp,amud,arnd,asin1,asin2,adbr,adfl,bed1,bed2,bfo1,bfo2,bso1,bso2,bso3,ced1,cfo1,cso1,y,category
datetime[μs],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2023-01-01 00:00:00,0.0,0.142857,-0.5,-4.1078e-14,2.0428e-14,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180547,-0.507953,-0.716059,-0.774361,0.100389,-0.186623,0.0,0.0
2023-01-01 00:00:01,0.0,0.142857,-0.495998,2e-05,0.0002,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.18054,-0.507953,-0.716059,-0.774361,0.100389,-0.186618,0.0,0.0
2023-01-01 00:00:02,0.0,0.142857,-0.486172,4e-05,0.0004,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180519,-0.507953,-0.716059,-0.774361,0.10039,-0.186604,0.0,0.0
2023-01-01 00:00:03,0.0,0.142857,-0.463453,6e-05,0.0006,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180484,-0.507953,-0.716059,-0.774361,0.100391,-0.18658,0.0,0.0
2023-01-01 00:00:04,0.0,0.142857,-0.444095,8e-05,0.0008,0.0,0.0,-0.32802,-0.369237,-0.738163,-0.767181,-0.180437,-0.507953,-0.716059,-0.774361,0.100393,-0.186548,0.0,0.0


In [11]:
# Earliest observation in the series.
cats_df.get_column('timestamp').min()

datetime.datetime(2023, 1, 1, 0, 0)

In [12]:
# One dict per labelled anomaly, keyed by the metadata columns
# (start_time, end_time, root_cause, affected, category).
cats_rows_list = metadata.rows(named=True)
# Preferred compute device, falling back to CPU when CUDA is unavailable.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Re-seed with the notebook-wide SEED rather than a second hard-coded 11, so a single
# constant controls reproducibility.
seed = SEED
ptl.seed_everything(seed)

Global seed set to 11


11

In [13]:
# Show the device selected in the setup cell ('cuda:0' if CUDA is available, otherwise 'cpu').
device

'cuda:0'

In [14]:
@dataclass(frozen=True)
class TrainingConfig:
    """Hyper-parameters for fitting the Causica SEM on each anomaly window.

    Every field carries a type annotation so that ``@dataclass`` registers it as a real
    field. Without annotations they were plain class attributes: the dataclass had no
    fields, ``repr()`` showed nothing, and ``TrainingConfig(batch_size=32)`` raised
    ``TypeError``. Defaults are unchanged.
    """

    noise_dist: cd.ContinuousNoiseDist = cd.ContinuousNoiseDist.SPLINE  # continuous_noise_dist for create_noise_modules
    batch_size: int = 64  # DataLoader batch size
    max_epoch: int = 500  # training epochs per anomaly window
    gumbel_temp: float = 0.25  # temperature passed to relaxed_sample
    averaging_period: int = 10  # NOTE(review): not read anywhere in the visible notebook
    prior_sparsity_lambda: float = 5.0  # sparsity_lambda of the GibbsDAGPrior
    init_rho: float = 1.0  # initial rho for AugLagLossCalculator
    init_alpha: float = 0.0  # initial alpha for AugLagLossCalculator


training_config = TrainingConfig()
auglag_config = AugLagLRConfig()

In [None]:
%%time
# For every labelled anomaly:
#   1. Fit a Causica SEM (ICGNN functional relationships, ENCO adjacency distribution,
#      augmented-Lagrangian DAG-ness constraint) on the data window ending at that anomaly.
#   2. Estimate the ATE of every other variable on the affected variable.
#   3. Record whether the labelled root cause is ranked 1st/2nd/3rd by |ATE|
#      (cause_1..cause_3 = 0/1).
TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'
TOP_K = 3  # number of ranked candidate causes checked against the label
LOG_EVERY_N_EPOCHS = 0  # set > 0 to print training diagnostics every N epochs

new_metadata = []
failed_rows = []  # (row index, error text) for anomalies that could not be scored
iteration = 0  # number of successfully scored anomalies (drives the progress print)
# NOTE(review): CUDA reads this variable when its context is created; if CUDA was already
# initialised earlier in the session this may have no effect — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Once any fit fails, all remaining fits run on CPU (sticky, as before).
previous_fail = False

for row_idx, row in enumerate(cats_rows_list):
    # ---- Window selection and data extraction (device independent, not retried) ----
    try:
        if row_idx == 0:
            # First anomaly: the window starts one anomaly-duration before its start time.
            start_time = datetime.strptime(row['start_time'], TIMESTAMP_FORMAT)
            end_time = datetime.strptime(row['end_time'], TIMESTAMP_FORMAT)
            delta = end_time - start_time
            start_time = start_time - delta
        else:
            # Later anomalies: the window starts right after the previous window's end.
            start_time = end_time + timedelta(seconds=1)
            end_time = datetime.strptime(row['end_time'], TIMESTAMP_FORMAT)

        # `affected` is a stringified list such as "['cfo1']"; literal_eval parses it
        # without executing arbitrary code (unlike eval).
        anomaly = ast.literal_eval(row['affected'])[0]
        root_cause = row['root_cause']

        model_df = (
            cats_df
            .filter((pl.col('timestamp') >= start_time) & (pl.col('timestamp') <= end_time))
            .drop(['timestamp', 'y', 'category'])
        )
        if model_df.height == 0:
            raise ValueError(f"no observations between {start_time} and {end_time}")
        out_cols = model_df.columns
        if anomaly not in out_cols:
            raise KeyError(f"affected variable {anomaly!r} is not a model column")
        cats_np = model_df.to_numpy()
        # Candidate causes: every model column except the affected one, in column order.
        # Iterating a set of strings varies between processes, which would change the
        # intervention order (and thus RNG draws) from run to run.
        treatment_columns = [col for col in out_cols if col != anomaly]
    except Exception as e:
        failed_rows.append((row_idx, repr(e)))
        print(f"row {row_idx}: could not prepare data: {e}")
        continue

    # ---- Model fitting + ATE ranking; retry the SAME row on CPU if the GPU attempt fails ----
    if previous_fail or not torch.cuda.is_available():
        candidate_devices = ['cpu']
    else:
        candidate_devices = ['cuda:0', 'cpu']

    for device in candidate_devices:
        try:
            # One (n, 1) tensor per column, keyed by column name.
            data_tensors = {
                col: torch.tensor(cats_np[:, col_idx].reshape(-1, 1))
                for col_idx, col in enumerate(out_cols)
            }
            dataset_train = TensorDict(data_tensors, torch.Size([cats_np.shape[0]]))
            # Move the whole window to the device at once (move per batch for bigger data).
            dataset_train = dataset_train.apply(lambda t: t.to(dtype=torch.float32, device=device)).to(device)

            dataloader_train = DataLoader(
                dataset=dataset_train,
                collate_fn=lambda x: x,  # identity: batches are used as the dataset returns them
                batch_size=training_config.batch_size,
                shuffle=True,
                drop_last=False,
            )

            num_nodes = len(dataset_train.keys())

            # DAG prior with a sparsity penalty weight of prior_sparsity_lambda.
            prior = cd.GibbsDAGPrior(
                num_nodes=num_nodes,
                sparsity_lambda=training_config.prior_sparsity_lambda,
                # expert_graph_container=expert_knowledge
            )

            # Distribution over adjacency matrices.
            adjacency_dist = cd.ENCOAdjacencyDistributionModule(num_nodes)

            # Functional relationships between variables.
            icgnn = ICGNN(
                variables=tensordict_shapes(dataset_train),
                embedding_size=8,  # previously 32
                out_dim_g=8,  # previously 32
                norm_layer=torch.nn.LayerNorm,
                res_connection=True,
            )

            # Every variable is treated as continuous with the configured noise distribution.
            types_dict = {var_name: VariableTypeEnum.CONTINUOUS for var_name in dataset_train.keys()}
            noise_submodules = cd.create_noise_modules(
                shapes=tensordict_shapes(dataset_train),
                types=types_dict,
                continuous_noise_dist=training_config.noise_dist,
            )
            noise_module = cd.JointNoiseModule(noise_submodules)

            sem_module = cd.SEMDistributionModule(
                adjacency_module=adjacency_dist,
                functional_relationships=icgnn,
                noise_module=noise_module,
            )
            sem_module.to(device)

            # One Adam parameter group per sub-module, each with its own initial learning rate.
            modules = {
                "icgnn": sem_module.functional_relationships,
                "vardist": sem_module.adjacency_module,
                "noise_dist": sem_module.noise_module,
            }
            parameter_list = [
                {"params": module.parameters(), "lr": auglag_config.lr_init_dict[name], "name": name}
                for name, module in modules.items()
            ]
            optimizer = torch.optim.Adam(parameter_list)

            # Augmented-Lagrangian schedule and loss for the DAG-ness constraint.
            scheduler = AugLagLR(config=auglag_config)
            auglag_loss = AugLagLossCalculator(
                init_alpha=training_config.init_alpha,
                init_rho=training_config.init_rho,
            )

            assert len(dataset_train.batch_size) == 1, "Only 1D batch size is supported"
            num_train_samples = len(dataset_train)

            for epoch in range(training_config.max_epoch):
                for batch_idx, batch in enumerate(dataloader_train):
                    optimizer.zero_grad()

                    # Soft sample of an SEM (temperature gumbel_temp) from the current distribution.
                    sem_distribution = sem_module()
                    sem, *_ = sem_distribution.relaxed_sample(
                        torch.Size([]),
                        temperature=training_config.gumbel_temp,
                    )

                    batch_log_prob = sem.log_prob(batch).mean()
                    sem_distribution_entropy = sem_distribution.entropy()
                    prior_term = prior.log_prob(sem.graph.to(device))

                    # Negative batch log-likelihood plus (-entropy - log prior) scaled by 1/N.
                    objective = (-sem_distribution_entropy - prior_term) / num_train_samples - batch_log_prob
                    constraint = calculate_dagness(sem.graph)
                    loss = auglag_loss(objective, constraint / num_train_samples)

                    loss.backward()
                    optimizer.step()

                    scheduler.step(
                        optimizer=optimizer,
                        loss=auglag_loss,
                        loss_value=loss.item(),
                        lagrangian_penalty=constraint.item(),
                    )

                    if LOG_EVERY_N_EPOCHS and epoch % LOG_EVERY_N_EPOCHS == 0 and batch_idx == 0:
                        print(
                            f"epoch:{epoch} loss:{loss.item():.5g} nll:{-batch_log_prob.detach().cpu().numpy():.5g} "
                            f"dagness:{constraint.item():.5f} num_edges:{(sem.graph > 0.0).sum()} "
                            f"alpha:{auglag_loss.alpha:.5g} rho:{auglag_loss.rho:.5g} "
                            f"step:{scheduler.outer_opt_counter}|{scheduler.step_counter} "
                            f"num_lr_updates:{scheduler.num_lr_updates}"
                        )

            vardist = adjacency_dist()
            pred_dag = vardist.mode.cpu().numpy()  # mode of the adjacency distribution (kept for inspection)

            # NOTE(review): `sem` is the relaxed sample from the final training step, not a
            # point estimate of the learned SEM distribution — confirm this is intended.
            num_samples = 20000
            sample_shape = torch.Size([num_samples])
            estimated_ate = {}
            with torch.no_grad():  # sampling only: no autograd graph needed
                for treatment in treatment_columns:
                    intervention_a = TensorDict({treatment: torch.tensor([1.0]).to(device)}, batch_size=tuple())
                    intervention_b = TensorDict({treatment: torch.tensor([0.0]).to(device)}, batch_size=tuple())

                    rev_a_samples = sem.do(interventions=intervention_a).sample(sample_shape)[anomaly]
                    rev_b_samples = sem.do(interventions=intervention_b).sample(sample_shape)[anomaly]

                    # Difference of interventional means and its standard error.
                    ate_mean = rev_a_samples.mean(0) - rev_b_samples.mean(0)
                    ate_std = torch.sqrt((rev_a_samples.cpu().var(0) + rev_b_samples.cpu().var(0)) / num_samples)

                    estimated_ate[treatment] = (
                        ate_mean.cpu().numpy()[0],
                        ate_std.cpu().numpy()[0],
                    )

            # Rank candidates by |mean ATE|, largest first.
            col_names = list(estimated_ate)
            effects = [np.abs(ate[0]) for ate in estimated_ate.values()]
            top_causes = (
                pd.DataFrame({"variable": col_names, 'effect': effects})
                .sort_values(by='effect', ascending=False)
                .head(TOP_K)['variable']
                .reset_index(drop=True)
            )

            # Work on a copy: `row` is the dict stored in cats_rows_list, so mutating it would
            # leak cause_* flags into later re-runs of this cell. Always write 0/1 so every
            # record has the same keys.
            scored_row = dict(row)
            for rank in range(TOP_K):
                is_hit = rank < len(top_causes) and root_cause == top_causes[rank]
                scored_row[f'cause_{rank + 1}'] = int(is_hit)

            # Release the largest per-window objects before the next fit.
            del sem, dataset_train
            torch.cuda.empty_cache()

            new_metadata.append(scored_row)
            if iteration % 50 == 0:
                print(iteration)
            iteration += 1
            break
        except Exception as e:
            previous_fail = True
            print(f"row {row_idx}: fit failed on {device}: {e}")
            torch.cuda.empty_cache()
    else:
        failed_rows.append((row_idx, 'fit failed on every device'))
        
        

Updating alpha to: 3.7537364959716797
Updating alpha to: 5.414556503295898
Updating alpha to: 27.072782516479492
0
Updating alpha to: 41.005699157714844
Updating alpha to: 205.02849578857422
Updating alpha to: 1025.142478942871
Updating alpha to: 5125.7123947143555
Updating alpha to: 25628.561973571777
Updating alpha to: 128142.80986785889
Updating alpha to: 14.789104461669922
Updating alpha to: 18.26199722290039
Updating alpha to: 19.448284149169922
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating rho, dag penalty prev:  1.1862869263
Updating alpha to: 39.206478118896484
Updating alpha to: 45.216800689697266
Updating rho, dag penalty prev:  6.0103225708
Updating rho, dag penalty prev:  6.0103225708
Updating rho, dag penalty prev:  6.0103225708
Updating 

In [None]:
# (free, total) GPU memory in bytes; None when CUDA is unavailable, so the cell also
# runs on CPU-only machines instead of raising.
cuda_mem_info = torch.cuda.mem_get_info() if torch.cuda.is_available() else None
cuda_mem_info

In [None]:
# Re-estimate the ATE of each candidate treatment on the affected variable, using the
# globals left by the training cell (`sem`, `anomaly`, `treatment_columns`, `device`).
# NOTE(review): the training loop `del`s `sem` after each successfully scored window, so
# this cell only works if the last window failed or the loop was interrupted.
assert 'sem' in globals(), "`sem` is not defined: run (or adapt) the training cell first"

estimated_ate = {}
num_samples = 20000
sample_shape = torch.Size([num_samples])

with torch.no_grad():  # sampling only: no autograd graph needed
    for treatment in treatment_columns:
        # do(treatment = 1) versus do(treatment = 0)
        intervention_a = TensorDict({treatment: torch.tensor([1.0]).to(device)}, batch_size=tuple())
        intervention_b = TensorDict({treatment: torch.tensor([0.0]).to(device)}, batch_size=tuple())

        rev_a_samples = sem.do(interventions=intervention_a).sample(sample_shape)[anomaly]
        rev_b_samples = sem.do(interventions=intervention_b).sample(sample_shape)[anomaly]

        # Difference of interventional means and its standard error
        # sqrt((var_a + var_b) / n) for two independent samples of size n.
        ate_mean = rev_a_samples.mean(0) - rev_b_samples.mean(0)
        ate_std = torch.sqrt((rev_a_samples.cpu().var(0) + rev_b_samples.cpu().var(0)) / num_samples)

        estimated_ate[treatment] = (
            ate_mean.cpu().numpy()[0],
            ate_std.cpu().numpy()[0],
        )
estimated_ate

In [None]:
# Split the ATE table into parallel lists: candidate names and |mean ATE| magnitudes.
col_names = list(estimated_ate.keys())
effects = [np.abs(ate[0]) for ate in estimated_ate.values()]
    

In [None]:
# Rank candidates by |ATE| (largest first) and keep the three strongest names, re-indexed from 0.
ranked_effects = pd.DataFrame({"variable": col_names, 'effect': effects}).sort_values(by='effect', ascending=False)
top_causes = ranked_effects.head(3)['variable'].reset_index(drop=True)

In [None]:
# Highest-ranked candidate root cause (label 0 after reset_index).
top_causes.loc[0]

In [None]:
# Top-ranked candidate root causes (largest |ATE| first).
top_causes