In this notebook, we check what distributions we obtain with different NN initializations.

In [1]:
import numpy as np
import random
import math
from tqdm.auto import tqdm
import torch
import torch.utils.data as data
import pytorch_lightning as pl
pl.seed_everything(42)

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import cm
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.reset_orig()

Global seed set to 42


In [28]:
import sys
sys.path.append("../")

from causal_graphs.graph_utils import *
from causal_graphs.graph_generation import *
from causal_graphs.graph_visualization import *
from causal_graphs.variable_distributions import *
from causal_discovery.multivariable_mlp import create_model

In [29]:
# Generate a random categorical causal graph with 100 variables, each with
# exactly 10 categories (min_categs == max_categs).
# NOTE(review): `use_nn=True` presumably selects neural-network-parameterised
# conditional distributions (the subject of this notebook) -- confirm in
# causal_graphs.graph_generation.
graph = generate_categorical_graph(
    num_vars=100,
    min_categs=10,
    max_categs=10,
    connected=True,
    graph_func=get_graph_func("random_max_10"),
    edge_prob=0.008,
    use_nn=True,
    seed=42,
)

In [30]:
# Draw 100k joint samples as an array; the echoed shape is (num_samples, num_vars).
samples = graph.sample(as_array=True, batch_size=100_000)
samples.shape

(100000, 100)

In [31]:
# Empirical marginal distribution of every variable over its 10 categories.
vals = np.arange(10, dtype=np.int32)
# Compare samples[..., None] against the category ids via broadcasting and
# count the matches over the sample axis -> one row of counts per variable.
counts = np.equal(samples[:, :, np.newaxis], vals).sum(axis=0)
# Normalise each row so that it sums to one.
probs = counts / counts.sum(axis=-1)[:, np.newaxis]

In [6]:
# Despite its name, `ce` is the Shannon entropy (in nats) of each row of
# `probs`, i.e. of each variable's empirical marginal distribution.
# Categories that were never sampled have probability 0; feeding them to
# np.log directly yields 0 * -inf = nan for the whole row. Substituting 1.0
# there makes those terms contribute p * log(1) = 0, the usual convention.
ce = -(probs * np.log(np.where(probs > 0, probs, 1.0))).sum(axis=1)

In [7]:
# Summary statistics of the per-variable entropies; for the extremes the
# index of the corresponding variable is printed as well.
print("Mean", np.mean(ce))
print("Max", np.max(ce), np.argmax(ce))
print("Min", np.min(ce), np.argmin(ce))
print("Median", np.median(ce))

Mean 1.9391802220664622
Max 2.235183420993497 93
Min 0.8664123797698585 7
Median 1.9729833792257185


In [8]:
# Marginal distribution of the variable with the lowest entropy.
probs[np.argmin(ce)]

array([0.0436 , 0.04264, 0.01012, 0.03379, 0.01609, 0.00133, 0.80769,
       0.01728, 0.0179 , 0.00956])

In [9]:
# Raw category counts of the variable with the lowest entropy.
counts[np.argmin(ce)]

array([ 4360,  4264,  1012,  3379,  1609,   133, 80769,  1728,  1790,
         956])

In [10]:
# Marginal distribution of the variable with the highest entropy.
probs[np.argmax(ce)]

array([0.08729, 0.16134, 0.09522, 0.04588, 0.06884, 0.15279, 0.07189,
       0.10397, 0.13865, 0.07413])

In [11]:
# Raw category counts of the variable with the highest entropy.
counts[np.argmax(ce)]

array([ 8729, 16134,  9522,  4588,  6884, 15279,  7189, 10397, 13865,
        7413])

In [12]:
# graph.variables[100].prob_dist.prob_func.embed_module.weight

## Check the importance of an edge

In [34]:
# Checkpoint directory of the experiment to inspect (alternative run kept for reference).
# BASE_PATH = '../experiments/checkpoints/array_job_200_7617862/experiment_2/'
BASE_PATH = '../experiments/checkpoints/2021_04_29__11_19_38/'
# NOTE: this rebinds `graph`; from here on it refers to the graph loaded
# from the checkpoint, not the one generated at the top of the notebook.
graph = CausalDAG.load_from_file(f"{BASE_PATH}graph_1.pt")

In [35]:
# Training data: 100k joint samples from the loaded graph, as a torch tensor.
samples = torch.from_numpy(graph.sample(as_array=True, batch_size=100_000))

In [36]:
class SimpleModel(nn.Module):
    """Small MLP that predicts one categorical variable from categorical inputs.

    Each (input variable, category) pair owns one row of a shared embedding
    table. The embeddings of all input variables are summed and passed
    through a two-layer MLP that returns log-probabilities over
    ``num_categs`` classes.

    Args:
        num_vars: Number of input variables. With ``0`` the model has no
            inputs and feeds the single embedding row 0 to the MLP, i.e. it
            learns an input-independent distribution.
        num_categs: Number of categories per input variable; also the number
            of output classes.
        hidden_dim: Width of the embedding and of the hidden layer
            (defaults to 64, the previously hard-coded value).
    """

    def __init__(self, num_vars, num_categs, hidden_dim=64):
        super().__init__()
        # Remember the true number of inputs; forward() branches on it.
        self.num_vars = num_vars
        # With zero inputs we still allocate one block of embeddings so that
        # forward() has a row to look up.
        num_blocks = 1 if num_vars == 0 else num_vars
        self.embedding = nn.Embedding(num_blocks * num_categs, hidden_dim)
        # Scale by 1/sqrt(#inputs) because forward() sums one embedding per
        # input. Done under no_grad() instead of through `.data`, so the
        # in-place update stays visible to autograd's version tracking.
        with torch.no_grad():
            self.embedding.weight.mul_(1. / math.sqrt(num_blocks))
        self.net = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, num_categs),
            nn.LogSoftmax(dim=-1)
        )
        # Offset of each input variable's block inside the embedding table.
        # Non-persistent: it is derived from the arguments, so it is left out
        # of the state dict.
        pos_trans = torch.arange(num_blocks, dtype=torch.long) * num_categs
        self.register_buffer("pos_trans", pos_trans, persistent=False)

    def forward(self, x):
        """Return log-probabilities over the ``num_categs`` output classes.

        Args:
            x: Integer tensor whose last dimension indexes the input
                variables (size ``num_vars``; may be 0).

        Returns:
            Log-softmax output of the MLP, one row per sample.
        """
        if self.num_vars > 0:
            # Shift category ids into each variable's own embedding block,
            # then sum over the variable dimension.
            embed = self.embedding(x + self.pos_trans[None])
            embed = embed.sum(dim=1)
        else:
            # No inputs: look up row 0 once per sample.
            embed = self.embedding(x.new_zeros(x.shape[:-1]))
        out = self.net(embed)
        return out

In [37]:
# Edge under inspection: node1 (candidate parent) -> node2 (child).
node1, node2 = 64, 272

In [38]:
# Parents of node2: indices with a non-zero entry in column `node2` of the
# adjacency matrix.
parents = torch.from_numpy(np.where(graph.adj_matrix[:, node2])[0])
assert node1 in parents
# Same parent set with the edge node1 -> node2 removed.
parents_excl = parents[parents.ne(node1)]
assert node1 not in parents_excl

print("Parents", parents)
print("Parents excluded", parents_excl)

Parents tensor([64])
Parents excluded tensor([], dtype=torch.int64)


In [43]:
# Empirical marginals of node1 (row 0) and node2 (row 1) over the 10 categories.
# NOTE: rebinds `vals`, `counts` and `probs` from the NumPy cells above.
vals = torch.arange(10, dtype=torch.long)
counts = torch.eq(vals[None], samples[:, (node1, node2), None]).sum(dim=0)
probs = counts / counts.float().sum(dim=1, keepdim=True)
print(probs)

tensor([[0.0597, 0.1461, 0.1488, 0.0287, 0.0363, 0.0838, 0.2011, 0.1309, 0.0604,
         0.1041],
        [0.0278, 0.0448, 0.0394, 0.1028, 0.1555, 0.2010, 0.1138, 0.1850, 0.1100,
         0.0198]])


In [22]:
# Two predictors for node2: one sees all of its parents, the other sees
# every parent except node1. Each gets its own Adam optimizer.
model_full = SimpleModel(num_vars=len(parents), num_categs=10)
optim_full = torch.optim.Adam(model_full.parameters(), lr=5e-3)

model_excl = SimpleModel(num_vars=len(parents_excl), num_categs=10)
optim_excl = torch.optim.Adam(model_excl.parameters(), lr=5e-3)

# The models emit log-probabilities, hence NLLLoss.
loss_module = nn.NLLLoss()

# Shuffled mini-batches over the sampled data.
dataset = data.TensorDataset(samples)
data_loader = data.DataLoader(dataset, shuffle=True, batch_size=128, pin_memory=True)

In [23]:
def eval_batch(batch):
    """Compute the loss of both models for predicting `node2` on one batch.

    Relies on the notebook globals `model_full`, `model_excl`, `parents`,
    `parents_excl`, `node2` and `loss_module`.

    Args:
        batch: 2-D tensor of samples; columns are indexed by variable.

    Returns:
        Tuple ``(loss_full, loss_excl)``: the loss of the model that sees all
        parents and of the model that sees the parents without `node1`.
    """
    targets = batch[:, node2]
    nll_with_edge = loss_module(model_full(batch[:, parents]), targets)
    nll_without_edge = loss_module(model_excl(batch[:, parents_excl]), targets)
    return nll_with_edge, nll_without_edge

# Train both models side by side on identical mini-batches for 10 epochs.
for _ in tqdm(range(10), desc="Epochs", leave=False):
    # The loader yields a one-element list, so unpack the tensor directly.
    for (batch,) in tqdm(data_loader, desc="Iterations", leave=False):
        optim_full.zero_grad()
        optim_excl.zero_grad()
        loss_full, loss_excl = eval_batch(batch)
        # The two models share no parameters, so each update is independent.
        loss_full.backward()
        optim_full.step()
        loss_excl.backward()
        optim_excl.step()

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=10.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Iterations', max=782.0, style=ProgressStyle(description_w…

In [24]:
# Evaluate both models on 100k freshly drawn samples.
eval_samples = torch.from_numpy(graph.sample(as_array=True, batch_size=100_000))
with torch.no_grad():
    nll_full, nll_excl = (loss.item() for loss in eval_batch(eval_samples))

# The difference is the increase in NLL when node1 is dropped from the inputs.
print("NLL all parents", nll_full)
print("NLL without parents", nll_excl)
print("Difference", nll_excl - nll_full)

NLL all parents 1.9130173921585083
NLL without parents 2.092456340789795
Difference 0.17943894863128662
