In [109]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy.stats import norm

%matplotlib inline

In [122]:
class NotFittedError(Exception):
    pass


class StructuralCausalModel:
    def __init__(
        self,
        structure:dict,
        size:int
    ):
        self.structure = structure
        self.size = size
        self.noise_terms = [
            target 
            for (target, var_name) in self.structure.items()
            if not var_name
        ]
        self.model_params = dict()        
        
    def fit(self, variables:dict):
        if (
            set(self.structure.keys()) - set(variables.keys())
        ):
            raise ValueError(
                "Please provide all the variables required from the provided structure"
            )
        if (
            set(variables.keys()) -  set(self.structure.keys())
        ):
            raise ValueError(
                "Please provide an appropriate structure that mapps to the provided variables"
            )       
        assert all(
            map(
                lambda a: a.shape[0] == self.size, 
                variables.values()
            )
        ), "All provided variables should have the same length"
        
        for (target, vars_) in self.structure.items():
            if target in self.noise_terms:
                continue
            Y = variables.get(target).reshape(-1,1)
            var_names = [var_ for var_ in vars_]
            X = np.concatenate(
                [variables.get(var_).reshape(-1,1) for var_ in vars_], 
                axis=1
            )
            X = sm.add_constant(X)
            model = sm.OLS(Y, X)
            result = model.fit()
            
            self.model_params[target] = {
                **{"Intercept": result.params[0]},
                **{
                    var_: coef 
                    for (var_, coef) in zip(var_names, result.params[1:])
                }
                    
            }
    
    def sample(
        self, 
        sample_size:int,
        noise_dist=random_var_dist
    ):
        if not self.model_params:
            raise NotFittedError(
                "Please fit the SCM prior sampling"
            )
            
        random_sample_eps_A = noise_dist.rvs(sample_size)
        random_sample_eps_X = noise_dist.rvs(sample_size)
        random_sample_eps_Y = noise_dist.rvs(sample_size)   
        A = (
            self.model_params["A"].get("Intercept") + self.model_params["A"].get("eps_A") * random_sample_eps_A
        )
        X = (
            self.model_params["X"].get("Intercept") + self.model_params["X"].get("A") * A +
            self.model_params["X"].get("eps_X") * random_sample_eps_X
        )
        Y = (
            self.model_params["Y"].get("Intercept") + 
            self.model_params["Y"].get("A") * A +
            self.model_params["X"].get("A") * A +            
            self.model_params["Y"].get("eps_Y") * random_sample_eps_Y
        )     
        return Y
            
        
        
        
structure = {
    "Y": ["X", "A", "eps_Y"],
    "X": ["A", "eps_X"],
    "A": ["eps_A"],
    "eps_A": [],
    "eps_X": [],
    "eps_Y": []
}

SIZE = 10_000
NOISE_EPS = 0.3
random_var_dist = norm(0, NOISE_EPS)
eps_A = random_var_dist.rvs(SIZE)
eps_X = random_var_dist.rvs(SIZE)
eps_Y = random_var_dist.rvs(SIZE)
A = eps_A
X = 2 * A + eps_X
Y = 3 * A + 5 * X + eps_Y

variables = {
    "A": A,
    "X": X,
    "Y": Y,
    "eps_A": eps_A,
    "eps_X": eps_X,
    "eps_Y": eps_Y    
}


scm = StructuralCausalModel(
    structure, 
    size=SIZE
)
scm.fit(variables)

scm.sample(20)

array([ 1.09491246, -0.27845601,  2.40656785, -2.00392536,  1.14228338,
        0.23018398,  0.91591741,  1.23414305,  2.80032669,  0.65528429,
       -0.40725641,  0.369558  , -0.59140744,  1.08663552, -1.81488153,
        0.33992503,  0.41384151, -1.51136998, -1.22007632, -0.35140382])

In [123]:
scm.model_params["Y"]

{'Intercept': 5.724587470723463e-17,
 'X': 5.000000000000005,
 'A': 2.999999999999999,
 'eps_Y': 0.9999999999999993}