In [10]:
import networkx as nx, numpy as np, pandas as pd
from dowhy import gcm

# Reproducibility: seed NumPy's *global* RNG (rather than a local Generator) so that the
# synthetic data below is identical on every Restart & Run All. Presumably this also makes
# gcm's internal sampling deterministic, since it draws from the global NumPy RNG — TODO confirm.
SEED = 0
N_SAMPLES = 1000
np.random.seed(SEED)

# Generate some "normal" data that we assume is given from our problem domain.
# Ground truth is the chain X -> Y -> Z with Y = 2X + noise and Z = 3Y + noise.
X = np.random.normal(loc=0, scale=1, size=N_SAMPLES)
Y = 2 * X + np.random.normal(loc=0, scale=1, size=N_SAMPLES)
Z = 3 * Y + np.random.normal(loc=0, scale=1, size=N_SAMPLES)
data = pd.DataFrame(dict(X=X, Y=Y, Z=Z))

# Step 1: Model our system as a structural causal model over the graph X -> Y -> Z,
# letting gcm pick a causal mechanism for each node based on the data.
causal_model = gcm.StructuralCausalModel(nx.DiGraph([('X', 'Y'), ('Y', 'Z')]))
gcm.auto.assign_causal_mechanisms(causal_model, data)

# Step 2: Train the causal model on the data from above.
gcm.fit(causal_model, data)

# Step 3: Perform a causal analysis. For example, suppose we observe this anomalous record
# (Y = 100 is far from the ~1.4 that Y = 2X + noise predicts for X = 0.7) ...
anomalous_record = pd.DataFrame(dict(X=[.7], Y=[100.0], Z=[303.0]))
# ... and would like to answer the question:
# "Which node is the root cause of the anomaly in Z?"
anomaly_attribution = gcm.attribute_anomalies(causal_model, "Z", anomalous_record)

# Display the result as the cell output. It is presumably a mapping from node to attribution
# score(s) — TODO confirm against the installed dowhy version.
anomaly_attribution

Fitting causal mechanism of node Z: 100%|██████████| 3/3 [00:00<00:00, 374.79it/s]
Evaluate set function: 8it [00:00, ?it/s]


In [11]:

import networkx as nx

# Assumed causal structure: the chain X -> Y -> Z, added edge by edge to an empty digraph.
causal_graph = nx.DiGraph()
causal_graph.add_edges_from([('X', 'Y'), ('Y', 'Z')])

In [12]:

import dowhy.gcm as gcm

# Wrap the causal graph in a structural causal model; no causal mechanisms are attached yet.
causal_model = gcm.StructuralCausalModel(causal_graph)

In [13]:

# `causal_model` has already been built from `causal_graph` in the preceding cell. Rebuilding
# it here would be a redundant copy of that cell, and on an out-of-order re-run it would silently
# discard any causal mechanisms attached afterwards. Verify the existing model instead.
assert set(causal_model.graph.edges) == set(causal_graph.edges), \
    "causal_model does not wrap the expected causal_graph"

In [14]:
# Option A: let gcm choose a causal mechanism for every node automatically, based on `data`.
# NOTE: the following cell assigns mechanisms manually, which replaces these choices.
gcm.auto.assign_causal_mechanisms(causal_model, based_on=data)


In [15]:
# Option B: assign mechanisms by hand, in node order X, Y, Z.
# X is a root node, modelled by its empirical distribution; Y and Z are each an additive-noise
# model around their own (separately constructed) linear regressor.
for node_name, mechanism in (
    ('X', gcm.EmpiricalDistribution()),
    ('Y', gcm.AdditiveNoiseModel(gcm.ml.create_linear_regressor())),
    ('Z', gcm.AdditiveNoiseModel(gcm.ml.create_linear_regressor())),
):
    causal_model.set_causal_mechanism(node_name, mechanism)

In [16]:

# Learn the parameters of every node's causal mechanism from the observed samples.
gcm.fit(causal_model, data=data)

Fitting causal mechanism of node Z: 100%|██████████| 3/3 [00:00<00:00, 428.56it/s]


In [17]:
# Intervention on Y: every value of Y is replaced by the constant 2.34 (the incoming value
# is ignored), and 1000 samples are drawn from the resulting interventional distribution.
samples = gcm.interventional_samples(
    causal_model,
    {'Y': lambda _current_y: 2.34},
    num_samples_to_draw=1000,
)
samples.head()

Unnamed: 0,X,Y,Z
0,-0.767502,2.34,6.777626
1,0.274842,2.34,6.650801
2,0.483652,2.34,5.100161
3,1.524137,2.34,8.081731
4,-0.456372,2.34,8.083934
