In [1]:
import networkx as nx
import numpy as np
import torch
import utils
from model import DAGAutoencoder
import inference
import pandas as pd
import matplotlib.pyplot as plt
from datasets import reorder_dag, get_full_ordering

In [2]:
# --- Experiment configuration -------------------------------------------
shuffling = 0             # passed to the model's forward as `shuffling` during training steps
seed = 1
standardize = 0           # NOTE(review): not read anywhere in this notebook
sample_size = 10000       # rows simulated for train + validation
batch_size = 50
max_iters = 30000
eval_interval = 100       # evaluate every `eval_interval` training steps
eval_iters = 100          # random batches averaged per split when evaluating
validation_fraction = 0.1
dropout_rate = 0.0
learning_rate = 1e-3

# Seed both RNGs used below: numpy drives data simulation and the train/val split,
# torch drives batch sampling (and presumably model initialisation).
np.random.seed(seed=seed)
torch.manual_seed(seed)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hidden-layer widths are defined where `input_dim` is known (layer-construction code below).

def generate_data_mediation(N, x_to_m=0.5, m_to_y=0.8):
    """Simulate a linear-Gaussian mediation chain X -> M -> Y plus its do(X) counterfactuals.

    Structural equations (noise drawn with ``np.random.randn`` from the numpy global
    RNG, in the order Ux, Um, Uy):
        X = Ux
        M = x_to_m * X + Um
        Y = m_to_y * M + Uy

    Args:
        N: number of samples.
        x_to_m: coefficient of X in the M equation (default 0.5).
        m_to_y: coefficient of M in the Y equation (default 0.8).

    Returns:
        all_data: array of shape (N, 3); columns follow ``var_names``.
        DAGnx: the X -> M -> Y graph after ``reorder_dag`` (topologically sorted).
        var_names: variable names in the DAG's node order.
        causal_ordering: result of ``get_full_ordering``; indexed by variable name.
        ordered_var_types: the per-variable types (all 'cont') sorted by ``causal_ordering``.
        Y0, Y1: Y under do(X=0) / do(X=1) reusing this sample's Um and Uy, so
            ``Y1 - Y0`` equals ``x_to_m * m_to_y`` for every unit (up to float rounding).
    """
    # Noise draw order is kept fixed so a given seed reproduces the same data.
    Ux = np.random.randn(N)
    X = Ux

    Um = np.random.randn(N)
    M = x_to_m * X + Um

    Uy = np.random.randn(N)
    Y = m_to_y * M + Uy

    # Counterfactuals: force X to 0 / 1 while keeping each unit's own noise terms.
    M0 = x_to_m * 0 + Um
    M1 = x_to_m * 1 + Um

    Y0 = m_to_y * M0 + Uy
    Y1 = m_to_y * M1 + Uy

    all_data_dict = {'X': X, 'M': M, 'Y': Y}

    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)
    var_types = {'X': 'cont', 'M': 'cont', 'Y': 'cont'}

    DAGnx = nx.DiGraph()
    DAGnx.add_edges_from([('X', 'M'), ('M', 'Y')])
    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag
    var_names = list(DAGnx.nodes())  # topologically ordered list of variables
    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)
    causal_ordering = get_full_ordering(DAGnx)
    ordered_var_types = dict(sorted(var_types.items(), key=lambda item: causal_ordering[item[0]]))

    return all_data, DAGnx, var_names, causal_ordering, ordered_var_types, Y0, Y1

# Reference ATE of X on Y from a large simulated sample; only the counterfactual
# outcomes (the last two return values) are needed for it.
Y0, Y1 = generate_data_mediation(N=1_000_000)[-2:]
ATE = np.mean(Y1 - Y0)

# The sample actually used for training (Y0/Y1 now hold this sample's counterfactuals).
all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)
print(var_names, ATE)

input_dim = all_data.shape[1]  # one column per variable in var_names

# Autoencoder layer widths: `input_dim` input units, the hidden widths below, and
# `input_dim` output units (the training targets are a copy of the inputs).
# Both ends are always added; `hidden_layers` must not include them.
hidden_layers = [6, 12, 6]
neurons_per_layer = [input_dim, *hidden_layers, input_dim]

utils.assert_neuron_layers(layers=neurons_per_layer, input_size=input_dim)

# Shuffle the row indices once (numpy global RNG) and carve off the first
# `validation_fraction` of them for validation; the remainder is training data.
indices = np.arange(len(all_data))
np.random.shuffle(indices)
n_val = int(validation_fraction * len(indices))
val_inds, train_inds = indices[:n_val], indices[n_val:]

# Both splits become float32 CPU tensors; batches are moved to `device` when sampled.
train_data = torch.from_numpy(all_data[train_inds]).float()
val_data = torch.from_numpy(all_data[val_inds]).float()

# Variable-level adjacency matrix of the DAG (rows/columns follow DAG.nodes(), i.e. var_names).
initial_adj_matrix = nx.to_numpy_array(DAG)

# Expand the adjacency into masks for the layers after the input (presumably one per
# layer transition — TODO confirm against utils.expand_adjacency_matrix).
# NOTE(review): masks pass through float32 before float64; lossless only for values
# exactly representable in float32 (e.g. 0/1).
initial_masks = [
    torch.from_numpy(expanded).float().double()
    for expanded in utils.expand_adjacency_matrix(neurons_per_layer[1:], initial_adj_matrix)
]

model = DAGAutoencoder(
    neurons_per_layer=neurons_per_layer,
    dag=DAG,
    causal_ordering=causal_ordering,
    var_types=var_types,
    dropout_rate=dropout_rate,
).to(device)
model.initialize_masks(initial_masks)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

['X', 'M', 'Y'] 0.3999999999999992


In [3]:
 def get_batch(train_data, val_data, split, device, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data), (batch_size,))
    x = data[ix]
    return x.to(device)

all_var_losses = {}  # per-variable training losses, sampled every `eval_interval` steps
for iter_ in range(max_iters):
    # --- one optimisation step on a random training batch ---
    model.train()

    xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)
    xb_mod = torch.clone(xb.detach())  # reconstruction targets: a detached copy of the inputs
    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=shuffling)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if iter_ % eval_interval == 0:
        # Record the per-variable losses of the step just taken.
        # NOTE(review): if these values are tensors that still require grad, storing them
        # keeps each step's autograd graph alive; consider `.item()` — confirm what model returns.
        for key, value in loss_dict.items():
            all_var_losses.setdefault(key, []).append(value)

        # Estimate the mean train/val loss over `eval_iters` random batches per split.
        # no_grad: these forward passes are evaluation-only, so no autograd graph is built.
        model.eval()
        eval_loss = {}
        with torch.no_grad():
            for split in ['train', 'val']:
                losses = torch.zeros(eval_iters)
                for k in range(eval_iters):
                    xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,
                                   batch_size=batch_size)
                    xb_mod = torch.clone(xb.detach())
                    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=False)
                    losses[k] = loss.item()
                eval_loss[split] = losses.mean()
        print(f"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}")

step 0 of 30000: train_loss 3.0399, val loss 3.2258
step 100 of 30000: train_loss 2.9986, val loss 3.1372
step 200 of 30000: train_loss 2.9429, val loss 2.9852
step 300 of 30000: train_loss 2.4101, val loss 2.4992
step 400 of 30000: train_loss 2.2368, val loss 2.2721
step 500 of 30000: train_loss 2.1131, val loss 2.0585
step 600 of 30000: train_loss 2.0324, val loss 2.0194
step 700 of 30000: train_loss 2.0381, val loss 2.1007
step 800 of 30000: train_loss 2.1115, val loss 2.0330
step 900 of 30000: train_loss 2.0591, val loss 2.0903
step 1000 of 30000: train_loss 2.0327, val loss 2.0160
step 1100 of 30000: train_loss 2.0122, val loss 2.0948
step 1200 of 30000: train_loss 2.0431, val loss 2.0102
step 1300 of 30000: train_loss 2.0317, val loss 2.0281
step 1400 of 30000: train_loss 2.0018, val loss 2.0589
step 1500 of 30000: train_loss 2.0215, val loss 1.9978
step 1600 of 30000: train_loss 2.0242, val loss 2.0173
step 1700 of 30000: train_loss 2.0465, val loss 1.9945
step 1800 of 30000: tr

In [6]:

# Interventional effect of X on the mediator M: run CausalInference.forward on every
# sample under do(X=0) and do(X=1), then average the difference in M.
df = pd.DataFrame(all_data, columns=var_names)
data_dict = df.to_dict(orient='list')
cause_var, effect_var = 'X', 'M'
effect_index = utils.find_element_in_list(var_names, target_string=effect_var)
ci = inference.CausalInference(model=model, device=device)

model.eval()
intervention_nodes_vals_0, intervention_nodes_vals_1 = {'X': 0}, {'X': 1}
D0, D1 = (
    ci.forward(data=all_data, intervention_nodes_vals=vals)
    for vals in (intervention_nodes_vals_0, intervention_nodes_vals_1)
)

m_diff = D1[:, effect_index] - D0[:, effect_index]
est_ATE = m_diff.mean()
# 0.5 is the coefficient of X in the M equation of generate_data_mediation.
print('True ATE X->M:', 0.5, 'Estimated:', est_ATE)

True ATE: 0.5 Estimated: 0.47507479041814804


In [5]:

# Interventional (total) effect of X on Y estimated by the trained model: run
# CausalInference.forward on every sample under do(X=0) and do(X=1), then average
# the difference in Y.
df = pd.DataFrame(all_data, columns=var_names)
data_dict = df.to_dict(orient='list')
cause_var = 'X'
effect_var = 'Y'
effect_index = utils.find_element_in_list(var_names, target_string=effect_var)
ci = inference.CausalInference(model=model, device=device)

model.eval()
intervention_nodes_vals_0 = {cause_var: 0}
intervention_nodes_vals_1 = {cause_var: 1}
D0 = ci.forward(data=all_data, intervention_nodes_vals=intervention_nodes_vals_0)
D1 = ci.forward(data=all_data, intervention_nodes_vals=intervention_nodes_vals_1)

est_ATE = (D1[:, effect_index] - D0[:, effect_index]).mean()
# Ground truth is the large-sample ATE computed in the data cell (0.5 * 0.8 = 0.4);
# the X->M coefficient 0.5 is not the X->Y effect.
print(f'True ATE {cause_var}->{effect_var}:', ATE, 'Estimated:', est_ATE)

0.37255560606718063
