In [1]:
import networkx as nx
import numpy as np
import torch
import utils
from model import CFCN
from utils.inference import CausalInference
import pandas as pd
import matplotlib.pyplot as plt
from datasets import reorder_dag, get_full_ordering


In [3]:
shuffling = 0
seed = 1
standardize = 0
sample_size = 1000000
batch_size = 50
max_iters =  6000
eval_interval = 100
eval_iters = 100
validation_fraction = 0.1
np.random.seed(seed=seed)
torch.manual_seed(seed)
device = 'cuda'
dropout_rate = 0.0
learning_rate = 1e-3

neurons_per_layer = [3,6, 6, 6, 3]

def generate_data_mediation(N):
    DAGnx = nx.DiGraph()
    
    Ux = np.random.randn(N)
    X =  Ux

    Um = np.random.randn(N)
    M =  0.9 * X + Um    

    Uy = np.random.randn(N)
    Y =  0.6 * M + 1.2 * X + Uy

    M0 = 0.9 * 0 + Um 
    M1 = 0.9 * 1 + Um

    Y0 = 0.6 * M0 +  1.2 * 0 + Uy 
    Y1 = 0.6 * M1 +  1.2 * 1 + Uy 

    # X-> M = 0.9
    # X-> Y = 1.2 
    # M -> Y = 0.6 
    # partial effect = 0.9*0.6 = .54
    # total effect = .54 + 1.2 = 1.74

    all_data_dict = {'X': X, 'M': M, 'Y': Y}

    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)
    var_types = {'X': 'cont', 'M': 'cont', 'Y': 'cont'}

    DAGnx.add_edges_from([('X', 'M'), ('M', 'Y')])
    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag
    var_names = list(DAGnx.nodes())  # topologically ordered list of variables
    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)
    causal_ordering = get_full_ordering(DAGnx)
    ordered_var_types = dict(sorted(var_types.items(), key=lambda item: causal_ordering[item[0]]))

    return all_data, DAGnx, var_names, causal_ordering, ordered_var_types, Y0, Y1

_, _, _, _, _, Y0, Y1 = generate_data_mediation(N=1000000)
ATE = (Y1 - Y0).mean()  # ATE based off a large sample
all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)
print(var_names, ATE)

['X', 'M', 'Y'] 1.740000000000002


In [4]:

def get_batch(train_data, val_data, split, device, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data), (batch_size,))
    x = data[ix]
    return x.to(device)

In [None]:
for i in range(10):
    all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)
    
    input_dim = all_data.shape[1]
    
    indices = np.arange(0, len(all_data))
    np.random.shuffle(indices)
    
    val_inds = indices[:int(validation_fraction*len(indices))]
    train_inds = indices[int(validation_fraction*len(indices)):]
    train_data = all_data[train_inds]
    val_data = all_data[val_inds]
    
    train_data, val_data = torch.from_numpy(train_data).float(),  torch.from_numpy(val_data).float()
    
    model = CFCN(neurons_per_layer=neurons_per_layer, dag=DAG, causal_ordering=causal_ordering, var_types=var_types, dropout_rate=dropout_rate).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    
    all_var_losses = {}
    for iter_ in range(0, max_iters):
        # train and update the model
        model.train()
    
        xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)
        xb_mod = torch.clone(xb.detach())
        X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=shuffling)
    
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    
    
        if iter_ % eval_interval == 0:  # evaluate the loss (no gradients)
            for key in loss_dict.keys():
                if key not in all_var_losses.keys():
                    all_var_losses[key] = []
                all_var_losses[key].append(loss_dict[key])
    
            model.eval()
            eval_loss = {}
            for split in ['train', 'val']:
                losses = torch.zeros(eval_iters)
                for k in range(eval_iters):
    
                    xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,
                                   batch_size=batch_size)
                    xb_mod = torch.clone(xb.detach())
                    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=False)
                    losses[k] = loss.item()
                eval_loss[split] = losses.mean()
            print(f"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}")
    
    
    df = pd.DataFrame(all_data, columns=var_names)
    data_dict = df.to_dict(orient='list')
    cause_var = 'X'
    effect_var = 'Y'
    effect_index = var_names.index(effect_var)
    ci = CausalInference(dag=DAG)
    
    model.eval()
    intervention_nodes_vals_0 = {'X': 0}
    intervention_nodes_vals_1 = {'X': 1}
    D0 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_0)
    D1 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_1)
    
    
    est_ATE = (D1[:,effect_index] - D0[:,effect_index]).mean()
    print(est_ATE)



1.8558830618858337
