## Counterfactual - Rung 3 in the Ladder of Causation

In [1]:
import networkx as nx
import numpy as np
import pandas as pd
from dowhy import gcm
from typing import Dict, List, Tuple
from itertools import combinations

# Step 1: Define the causal model and generate data
# Key attribute of an invertible SCM: it is invertible with respect to noise, i.e. the
# noise of each node can be reconstructed from the observed data.
# Graph structure: A -> C, B -> C, C -> Y
causal_graph = nx.DiGraph([('A', 'C'), ('B', 'C'), ('C', 'Y')])
causal_model = gcm.InvertibleStructuralCausalModel(causal_graph)

# Generate synthetic data from a seeded generator so the draws are repeatable
rng = np.random.default_rng(42)
n_samples = 1000
A = rng.normal(0, 0.5, size=n_samples)
B = rng.normal(0, 0.5, size=n_samples)
C = A + 10 * B + rng.normal(0, 0.5, size=n_samples)
Y = 2 * C + rng.normal(0, 0.5, size=n_samples)
# Collect the generated columns into the training data frame
training_data = pd.DataFrame({'A': A, 'B': B, 'C': C, 'Y': Y})

# Set up and fit the causal model (learn the generative model)
# The causal mechanisms are assigned manually: A and B (no parents in the graph) get
# gcm.EmpiricalDistribution(); C and Y get gcm.AdditiveNoiseModel(gcm.ml.create_linear_regressor()).
# Alternatively, let DoWhy assign the mechanisms automatically:
# gcm.auto.assign_causal_mechanisms(causal_model, training_data)
causal_model.set_causal_mechanism('A', gcm.EmpiricalDistribution())
causal_model.set_causal_mechanism('B', gcm.EmpiricalDistribution())
causal_model.set_causal_mechanism('C', gcm.AdditiveNoiseModel(gcm.ml.create_linear_regressor()))
causal_model.set_causal_mechanism('Y', gcm.AdditiveNoiseModel(gcm.ml.create_linear_regressor()))
gcm.fit(causal_model, training_data)

# Step 2: Compute control values
def compute_control_values(data: pd.DataFrame) -> Dict[str, float]:
    """Return the median of every column of ``data``, keyed by column name.

    The median of each column is used as that variable's control
    (favorable-event) value.

    Args:
        data: Frame whose columns are the variables of interest.

    Returns:
        Mapping from column name to the median of that column, in column order.
    """
    medians = {}
    for column in data.columns:
        medians[column] = np.median(data[column])
    return medians

# Dictionary mapping each column of the training data to its control value (the column median)
control_values = compute_control_values(training_data)

  from .autonotebook import tqdm as notebook_tqdm
Fitting causal mechanism of node Y: 100%|██████████| 4/4 [00:00<00:00, 85.05it/s]


In [2]:
# Step 3: Counterfactual prediction function of the target variable for one intervention 
def counterfactual_prediction(causal_model: gcm.InvertibleStructuralCausalModel,
                              observed_instance: pd.DataFrame,
                              intervention: Dict[str, callable],
                              target: str) -> float:
    """Return the counterfactual value of ``target`` under one intervention.

    Args:
        causal_model: Fitted invertible structural causal model.
        observed_instance: Observed data passed to ``gcm.counterfactual_samples``;
            only the first resulting row is used.
        intervention: Mapping from node name to the intervention function
            applied to that node.
        target: Name of the column to read from the counterfactual samples.

    Returns:
        The first value of the ``target`` column in the counterfactual samples.
    """
    counterfactuals = gcm.counterfactual_samples(
        causal_model,
        intervention,
        observed_data=observed_instance,
    )
    target_column = counterfactuals[target]
    return target_column.values[0]

# Step 4: Calculate attribution scores
def calculate_attribution_scores(causal_model: gcm.InvertibleStructuralCausalModel,
                                 observed_instance: pd.DataFrame,
                                 control_values: Dict[str, float],
                                 target: str) -> Dict[str, float]:
    """Score every non-target node by intervening on it with its control value.

    For each node other than ``target``, the node is set to its control value,
    the counterfactual value of ``target`` is computed, and the score is the
    absolute distance between that counterfactual value and the control value
    of ``target``.

    Args:
        causal_model: Fitted invertible structural causal model.
        observed_instance: Observed data for the event being analysed.
        control_values: Control value for each node, including ``target``.
        target: Name of the target node; it is not scored itself.

    Returns:
        Mapping from node name to its attribution score.
    """
    attribution = {}
    for node in causal_model.graph.nodes:
        if node == target:
            continue
        # The lambda reads ``node`` lazily, which is safe here because the
        # intervention is consumed before the loop advances.
        set_to_control = {node: lambda x: control_values[node]}
        counterfactual_target = counterfactual_prediction(
            causal_model, observed_instance, set_to_control, target
        )
        # Absolute deviation of the counterfactual target from its control value
        attribution[node] = abs(counterfactual_target - control_values[target])
    return attribution

# Step 5: Identify similarity groups
def identify_similarity_groups(scores: Dict[str, float], similarity_threshold: float) -> List[List[str]]:
    """Partition variables into groups of similar attribution score.

    Variables are visited from lowest to highest score. A variable joins the
    current group when its score differs from the score of the group's first
    (lowest-scoring) member by at most ``similarity_threshold``; otherwise it
    starts a new group.

    Args:
        scores: Mapping from variable name to attribution score.
        similarity_threshold: Maximum allowed absolute score difference
            between a variable and the first member of its group.

    Returns:
        Groups ordered from lowest to highest score, each group itself in
        ascending score order. An empty ``scores`` mapping yields ``[]``.
    """
    # Nothing to group; indexing the first sorted variable would raise otherwise.
    if not scores:
        return []

    ordered_vars = sorted(scores, key=scores.get)  # lowest score first
    groups = [[ordered_vars[0]]]

    for var in ordered_vars[1:]:
        anchor = groups[-1][0]  # compare against the group's first member, not the latest
        if abs(scores[var] - scores[anchor]) <= similarity_threshold:
            groups[-1].append(var)
        else:
            groups.append([var])

    return groups

# Step 6: Find earliest ancestor
def find_earliest_ancestor(group: List[str], causal_graph: nx.DiGraph) -> str:
    """Pick the member of ``group`` that lies furthest upstream in the graph.

    The preferred result is the group member that is an ancestor of (or equal
    to) every member of the group. When no such member exists, e.g. the group
    contains variables that are not on one causal path, the candidates are
    instead the members that have no ancestor inside the group. Remaining
    ties are broken by the node order of ``causal_graph``.

    Args:
        group: Non-empty list of node names, all present in ``causal_graph``.
        causal_graph: Directed graph describing the causal structure.

    Returns:
        The name of the selected group member.

    Raises:
        ValueError: If ``group`` is empty.
    """
    if not group:
        raise ValueError("group must contain at least one variable")

    members = set(group)
    # Build the node -> position lookup once instead of re-listing the nodes per comparison.
    position = {node: index for index, node in enumerate(causal_graph.nodes())}

    # Members that are an ancestor-or-self of every member of the group.
    candidates = set(members)
    for node in group:
        candidates &= set(nx.ancestors(causal_graph, node)) | {node}

    if not candidates:
        # No single member is upstream of all others; without this fallback min()
        # would fail on an empty set. Use the members with no ancestor in the group.
        in_group_roots = {
            node for node in members
            if not (set(nx.ancestors(causal_graph, node)) & members)
        }
        candidates = in_group_roots or members

    return min(candidates, key=position.__getitem__)

# Step 7: Disentangle scores for each group
def disentangle_scores(scores: Dict[str, float],
                       similarity_groups: List[List[str]],
                       causal_graph: nx.DiGraph) -> Dict[str, float]:
    """Keep one representative score per similarity group.

    Each group is represented by the node returned by
    ``find_earliest_ancestor`` for that group, together with that node's
    score.

    Args:
        scores: Mapping from variable name to attribution score.
        similarity_groups: Groups of variable names with similar scores.
        causal_graph: Directed graph used to pick each group's representative.

    Returns:
        Mapping from each group's representative to its score, in group order.
    """
    representatives = (
        find_earliest_ancestor(members, causal_graph) for members in similarity_groups
    )
    return {node: scores[node] for node in representatives}

# Step 8: Counterfactual Root Cause Analysis
def counterfactual_rca(causal_model: gcm.InvertibleStructuralCausalModel,
                       observed_instance: pd.DataFrame,
                       control_values: Dict[str, float],
                       target: str,
                       similarity_threshold: float) -> Tuple[str, Dict[str, float]]:
    """Run the counterfactual root cause analysis for one observed event.

    The pipeline scores every non-target node, groups nodes with similar
    scores, keeps one representative per group, and reports the
    representative with the lowest score as the root cause.

    Args:
        causal_model: Fitted invertible structural causal model.
        observed_instance: Observed data for the event being analysed.
        control_values: Control value for each node, including ``target``.
        target: Name of the target node.
        similarity_threshold: Score difference within which nodes are grouped.

    Returns:
        A tuple of the root-cause node name and the disentangled scores.
    """
    attribution = calculate_attribution_scores(
        causal_model, observed_instance, control_values, target
    )
    groups = identify_similarity_groups(attribution, similarity_threshold)
    disentangled = disentangle_scores(attribution, groups, causal_model.graph)

    # The root cause is the representative with the smallest score.
    lowest_scoring = min(disentangled, key=lambda name: disentangled[name])
    return lowest_scoring, disentangled


In [3]:

# Test the algorithm with the provided events.
# Each event gives one observed value for every node of the graph (A, B, C, Y).
events = {
    'event_a': {'A': 1.00, 'B': 0.1, 'C': 4.75, 'Y': 12.5},
    'event_b': {'A': 0.75, 'B': 0.5, 'C': 4.75, 'Y': 12.5},
    'event_c': {'A': 0.75, 'B': 0.1, 'C': 4.75, 'Y': 12.5}
}

for event_name, event_data in events.items():
    print(f"\nAnalyzing {event_name}:")
    # Wrap the event in a single-row DataFrame, the form counterfactual_rca is given here
    observed_instance = pd.DataFrame([event_data])
    # Y is the target; scores within 0.1 of each other are treated as one similarity group
    root_cause, disentangled_scores = counterfactual_rca(causal_model, observed_instance, control_values, 'Y', similarity_threshold=0.1)
    print(f"Identified root cause: {root_cause}")
    print("Disentangled scores (lower is more causal):")
    # List the representatives from lowest to highest score
    for var, score in sorted(disentangled_scores.items(), key=lambda x: x[1]):
        print(f"{var}: {score}")


Analyzing event_a:
Identified root cause: C
Disentangled scores (lower is more causal):
C: 2.952756090215799
B: 10.705453547260689
A: 11.081312757094434

Analyzing event_b:
Identified root cause: B
Disentangled scores (lower is more causal):
B: 2.7320768140865006
C: 2.952756090215799
A: 11.596372904659507

Analyzing event_c:
Identified root cause: C
Disentangled scores (lower is more causal):
C: 2.952756090215799
B: 10.705453547260692
A: 11.59637290465951
