# Testing Causal Discovery

In [1]:
import numpy as np
import random
from tqdm.notebook import tqdm
from copy import deepcopy
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import cm
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.reset_orig()

In [2]:
import sys
sys.path.append("../")

from multivariable_mlp import *
from graph_fitting import *
from graph_scoring import *
from graph_update import *
from graph_discovery import *
from causal_graphs.graph_generation import generate_categorical_graph, generate_chain, generate_random_graph
from causal_graphs.graph_visualization import visualize_graph
from causal_graphs.graph_utils import adj_matrix_to_edges

In [3]:
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

## Multi-variable MLP

### Layer test

In [4]:
BATCH_SIZE = 2
NUM_VARS = 4
input_data = torch.randn(BATCH_SIZE, NUM_VARS)

In [5]:
linear_layer = MultivarLinear(c_in=NUM_VARS, c_out=NUM_VARS, extra_dims=[NUM_VARS])

In [6]:
out = linear_layer(input_data)
print("Out", out.shape)

Out torch.Size([2, 4, 4])


In [7]:
input_data = torch.randn(BATCH_SIZE, NUM_VARS)
out1 = linear_layer(input_data)
input_data[0,0] = -10
out2 = linear_layer(input_data)
out_equals = (out1 == out2)
assert out_equals[1:].all(), "Batch not independent"
assert (~out_equals[0]).all(), "Not all inputs were influenced"

### MLP test

In [8]:
input_mask_module = InputMask(None)

In [9]:
mlp = MultivarMLP(input_dims=NUM_VARS, 
                  hidden_dims=[64, 64], 
                  output_dims=1, 
                  extra_dims=[NUM_VARS],
                  pre_layers=input_mask_module,
                  actfn=lambda : nn.LeakyReLU(0.1))
print(mlp)

MultivarMLP(
  (layers): ModuleList(
    (0): InputMask()
    (1): MultivarLinear(c_in=4, c_out=64, extra_dims=[4])
    (2): LeakyReLU(negative_slope=0.1)
    (3): MultivarLinear(c_in=64, c_out=64, extra_dims=[4])
    (4): LeakyReLU(negative_slope=0.1)
    (5): MultivarLinear(c_in=64, c_out=1, extra_dims=[4])
  )
)


In [10]:
mask = torch.bernoulli(torch.zeros(BATCH_SIZE, NUM_VARS, NUM_VARS)+0.5)
mask[:,torch.arange(mask.shape[1]),torch.arange(mask.shape[2])] = 0.
out = mlp(input_data, mask=mask)
out.shape

torch.Size([2, 4, 1])

In [11]:
input_data[:,0] = 0

In [12]:
out2 = mlp(input_data, mask=mask)
out_equal = (out == out2)

In [13]:
# Check something

### MLP categorical

In [14]:
NUM_CATEGS = 10
input_data = torch.randint(NUM_CATEGS, size=(BATCH_SIZE, NUM_VARS), dtype=torch.long)

In [15]:
embed = EmbedLayer(num_vars=NUM_VARS,
                   num_categs=NUM_CATEGS,
                   hidden_dim=64,
                   input_mask=input_mask_module,
                   share_embeds=False,
                   sparse_embeds=True)
print(embed)

EmbedLayer(
  (input_mask): InputMask()
  (embedding): Embedding(160, 64)
)


In [16]:
out = embed(input_data, mask=mask)
print(out.shape)
print(out.flatten(0,1).std(0))
print(out)

torch.Size([2, 4, 64])
tensor([0.9217, 1.2478, 0.9326, 0.7497, 0.6657, 0.6379, 0.6909, 1.0763, 0.3857,
        0.8445, 0.5484, 0.9653, 0.9020, 1.2382, 0.8558, 1.0464, 2.1659, 2.1048,
        1.4442, 0.7675, 0.8855, 0.9104, 1.3654, 1.3076, 1.7468, 1.4325, 0.6232,
        0.9345, 0.5947, 1.1745, 1.1164, 0.3979, 0.6818, 1.9140, 1.1503, 1.1840,
        0.9970, 0.3121, 0.8908, 0.7076, 1.0906, 0.7537, 1.3655, 1.3164, 1.9447,
        0.7903, 1.4882, 1.4770, 0.4753, 1.4019, 1.8230, 1.4317, 0.4038, 1.2740,
        1.2223, 1.8110, 0.8602, 1.1388, 0.8527, 0.6196, 0.7831, 0.5557, 0.8361,
        1.3762], grad_fn=<StdBackward1>)
tensor([[[ 0.4782, -1.0840,  0.8132,  1.6424,  0.2828,  1.1481,  1.0430,
           2.7088, -0.7327, -0.0379,  0.0641, -0.9786,  0.3501, -1.3878,
          -1.2287, -0.8569,  1.3064, -1.7483, -0.9376,  0.2871, -1.1425,
          -0.1643, -3.6239, -3.5391, -1.4775,  1.5825,  0.1483,  1.9029,
           0.3089,  0.1018,  0.9233,  0.7635, -0.0615,  1.6846,  2.1429,
           

In [17]:
embed.sparse_embeds = False
out = embed(input_data, mask=mask)
print(out.shape)
print(out.flatten(0,1).std(0))
print(out)

torch.Size([2, 4, 64])
tensor([0.7945, 2.2059, 1.9231, 2.2923, 1.2264, 1.0312, 1.4934, 1.6251, 2.2076,
        1.2514, 1.0858, 1.8541, 1.8954, 3.4863, 2.0298, 1.5322, 3.6553, 2.3711,
        1.6388, 1.7063, 2.8773, 2.4683, 2.0014, 1.4685, 1.5479, 1.8328, 1.2035,
        2.6172, 0.5401, 0.9879, 1.8226, 1.0069, 1.9381, 2.4227, 2.2222, 2.0826,
        1.0361, 2.4461, 1.2708, 1.5977, 2.2510, 0.8937, 2.7040, 2.5244, 1.8126,
        0.5556, 1.7591, 1.0348, 1.0649, 1.8053, 2.0368, 2.1577, 2.1166, 1.4252,
        1.2361, 1.5452, 2.0223, 0.6172, 1.5423, 1.5789, 1.4353, 1.8903, 2.1281,
        1.8252], grad_fn=<StdBackward1>)
tensor([[[-4.5362e-01,  1.5029e+00,  4.3225e-01,  1.6117e+00, -6.0602e-01,
           2.0990e+00,  5.0617e-01,  5.3775e+00,  3.5724e-01, -2.0461e+00,
           7.8914e-01,  2.5200e-02,  3.4035e-01, -2.1932e+00, -7.8100e-01,
           1.3508e+00,  1.1449e+00, -4.0495e+00, -2.2813e+00, -4.4673e-01,
          -3.5669e+00,  2.2372e+00, -2.5506e+00, -2.0740e+00, -1.2954e+00,
 

In [18]:
mlp = MultivarMLP(input_dims=embed.output_dim, 
                  hidden_dims=[64, 64], 
                  output_dims=NUM_CATEGS, 
                  extra_dims=[NUM_VARS],
                  pre_layers=embed)
mlp.eval()
print(mlp)

TypeError: __init__() missing 1 required positional argument: 'actfn'

In [None]:
out = mlp(input_data, mask=mask)
print(out.shape)

In [None]:
a = torch.zeros(8, 3, 2)
a.chunk(5, dim=0)

## Graph fitting

### Dataset generation

In [None]:
class CategoricalData(torch.utils.data.Dataset):
    
    def __init__(self, graph, dataset_size):
        super().__init__()
        self.graph = graph
        self.var_names = [v.name for v in self.graph.variables]
        data = graph.sample(batch_size=dataset_size, as_array=True)
        self.data = torch.from_numpy(data).long()
        
    
    def __len__(self):
        return self.data.shape[0]
    
    
    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
graph = generate_categorical_graph(num_vars=NUM_VARS,
                                   min_categs=NUM_CATEGS,
                                   max_categs=NUM_CATEGS,
                                   edge_prob=0.0,
                                   connected=True,
                                   seed=42)
visualize_graph(graph, show_plot=True, figsize=(3, 2), layout="circular")

In [None]:
dataset = CategoricalData(graph, dataset_size=64*128)

In [None]:
dataset[0]

In [None]:
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

### Graph fitting

In [None]:
model = create_model(num_vars=NUM_VARS, num_categs=NUM_CATEGS, hidden_dims=[64, 64])
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

In [None]:
fittingModule = GraphFitting(model, optimizer, data_loader)

In [None]:
gamma = nn.Parameter(torch.randn(NUM_VARS, NUM_VARS))
sample_matrix = torch.sigmoid(gamma).detach()

In [None]:
avg_loss = 0
for _ in tqdm(range(100)):
    avg_loss += fittingModule.fit_step(sample_matrix)
print("Average loss", avg_loss / 100)

## Graph scoring

In [None]:
scoringModule = GraphScoring(model=model, graph=graph, N_s=4, C_s=2, batch_size=64)

In [None]:
gammagrad, logregret, var_idx = scoringModule.score(gamma)

In [None]:
print("Variable to perform intervention on:", var_idx)

In [None]:
print("Shape", gammagrad.shape, logregret.shape)

## Graph Update

In [None]:
updateModule = GraphUpdate(lambda_sparse=0.1, lambda_DAG=0.1)

In [None]:
updateModule.update(gammagrad, logregret, gamma, var_idx)

In [None]:
gamma.grad

In [None]:
print(graph)

In [None]:
graph.adj_matrix

## Graph Discovery

### Toy graph

In [None]:
NUM_VARS = 3
toy_graph = generate_categorical_graph(num_vars=NUM_VARS, min_categs=3, max_categs=3, graph_func=generate_random_graph, edge_prob=0.4, seed=0)
visualize_graph(toy_graph, show_plot=True, figsize=(2, 2), layout="graphviz")

In [None]:
def uniform(inputs, batch_size):
    return np.ones((3,))/3

print(toy_graph.variables[0])
toy_graph.variables[0].prob_dist.prob_func = uniform
for i in range(3):
    print("Output %i: %4.2f" % (i, toy_graph.variables[0].prob_dist.prob(None, i)))

In [None]:
def noisy_identity(inputs, batch_size, noise_level=0.1):
    probs = np.zeros((batch_size, 3))
    for val in inputs.values():
        val_grid = np.array([noise_level]*3)
        val_grid = np.repeat(val_grid[None], batch_size, axis=0)
        val_grid[np.arange(batch_size), val] = 1 - noise_level*2
        probs += val_grid
    probs /= len(inputs)
    return probs

print(toy_graph.variables[1])
toy_graph.variables[1].prob_dist.prob_func = lambda *args, **kwargs: noisy_identity(*args, **kwargs, noise_level=0.25)
print("---")
for a in range(3):
    for i in range(3):
        print("Prob for val=%i if A=%i: %4.2f" % (i, a, toy_graph.variables[1].prob_dist.prob({"A": a}, i)))
    print("---")

In [None]:
if NUM_VARS == 3:
    print(toy_graph.variables[2])
    toy_graph.variables[2].prob_dist.prob_func = noisy_identity
    print("---")
    for a in range(3):
        for c in range(3):
            for i in range(3):
                print("Prob for val=%i if A=%i,C=%i: %4.2f" % (i, a, c, toy_graph.variables[2].prob_dist.prob({"A":a, "C": c}, i)))
            print("---")

In [None]:
print("Adjacency matrix:")
toy_graph.adj_matrix.astype(np.int32)

In [None]:
toy_graph.sample(batch_size=8, as_array=False)

In [None]:
discModule = GraphDiscovery(graph=toy_graph, 
                            model_iters=1000, 
                            gamma_iters=50, 
                            dataset_size=10000, 
                            N_s=10,
                            C_s=20,
                            lambda_sparse=0.02, 
                            lambda_DAG=0.1,
                            hidden_dims=[64],
                            lr_gamma=2e-1,
                            betas_gamma=(0.1,-1.0),
                            guide_inter=True
                           )

In [None]:
gamma = discModule.discover_graph(num_epochs=1)

In [None]:
print(gamma)

### More advanced graph

In [None]:
NUM_VARS = 4
NUM_CATEGS = 5

graph = generate_categorical_graph(num_vars=NUM_VARS,
                                   min_categs=NUM_CATEGS,
                                   max_categs=NUM_CATEGS,
                                   edge_prob=0.3,
                                   connected=True,
                                   inputs_independent=False,
                                   use_nn=True, 
                                   seed=123)
visualize_graph(graph, show_plot=True, figsize=(8, 5), layout="graphviz", filename="example_graph_8_nodes.pdf")

In [None]:
%%time
discModule = GraphDiscovery(graph=graph, 
                            model_iters=400, 
                            gamma_iters=50, 
                            dataset_size=100000, 
                            lambda_sparse=0.1, 
                            lambda_DAG=2.0,
                            hidden_dims=[64],
                            betas_gamma=(0.1, -1.0),
                            lr_gamma=2e-2)
discModule.print_gamma_statistics()

In [None]:
gamma = discModule.discover_graph(num_epochs=1)

In [None]:
rounded_gamma = torch.round(gamma * 100)/100
rounded_gamma[torch.arange(gamma.shape[0]), torch.arange(gamma.shape[1])] = 0
print(rounded_gamma.detach())

In [None]:
print("Predicted adjacency matrix")
(gamma>0).int()

In [None]:
print("True adjacency matrix")
graph.adj_matrix.astype(np.int32)

In [None]:
print("Predicted graph")
copied_graph = deepcopy(graph)
copied_graph.adj_matrix = (gamma > 0.0).numpy()
copied_graph.edges = adj_matrix_to_edges(copied_graph.adj_matrix)
visualize_graph(copied_graph, show_plot=True, figsize=(4, 3), layout="circular", filename="example_graph_8_nodes_predicted.pdf")

In [None]:
discModule.print_gamma_statistics()