In [1]:
import numpy as np
import torch
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib as mpl
import time
import torch.nn as nn

In [2]:
from src.probabilistic_dag_model.probabilistic_dag import ProbabilisticDAG
from src.probabilistic_dag_model.train_dag import train

# Sampling time

In [3]:
# Benchmark how long drawing a single DAG sample takes for a 100-node model.
# time.perf_counter() is used instead of time.time(): it is monotonic and uses the
# highest-resolution clock available, which matters for millisecond-scale intervals.
n_samples = 30
sampling_times = np.zeros(n_samples)
prob_dag_model = ProbabilisticDAG(n_nodes=100,
                                  order_type='topk',  # alternative: 'sinkhorn'
                                  initial_adj=None,
                                  seed=100)
for i in range(n_samples):
    t0 = time.perf_counter()
    # Copying the sample to host memory (.cpu().numpy()) is kept inside the timed
    # region so any pending device work is finished before the clock is read again.
    A = prob_dag_model.sample().detach().cpu().numpy()
    sampling_times[i] = time.perf_counter() - t0
print('Mean sampling time: ', sampling_times.mean())

Mean sampling time:  0.0018559137980143229


# DAG learning with a ground-truth DAG

In [4]:
# Pick the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ground-truth adjacency: ones strictly above the diagonal (entry [i, j] == 1 iff i < j),
# presumably read as the complete DAG consistent with the identity node ordering.
n_nodes = 10
true_dag_adj = torch.triu(torch.ones(n_nodes, n_nodes, device=device), 1)

# NOTE(review): the model is not explicitly moved to `device` in this cell;
# presumably `train` handles placement — confirm before running on CUDA.
model = ProbabilisticDAG(n_nodes=n_nodes,
                         hard=True,
                         order_type='topk',  # alternative: 'sinkhorn'
                         lr=1e-2,
                         seed=0)

In [5]:
# Fit the probabilistic DAG against the known ground-truth adjacency.
# NOTE(review): the training log reports `sampled_nll_loss`, yet the third return value
# is bound to `sampled_mse_losses` — confirm which loss train() actually returns.
# `patience` presumably controls early stopping — verify in train_dag.train.
model, losses, sampled_mse_losses = train(
    model,
    true_dag_adj=true_dag_adj,
    max_epochs=30000,
    patience=100,
)

Epoch 0 -> prob_abs_loss 54.51499938964844 | sampled_nll_loss 1.6532267332077026
Model saved
Epoch 10 -> prob_abs_loss 23.762046813964844 | sampled_nll_loss 1.7032265663146973
Model saved
Epoch 20 -> prob_abs_loss 23.06319808959961 | sampled_nll_loss 1.593226671218872
Model saved
Epoch 30 -> prob_abs_loss 22.320295333862305 | sampled_nll_loss 1.66322660446167
Model saved
Epoch 40 -> prob_abs_loss 21.602474212646484 | sampled_nll_loss 1.6732265949249268
Model saved
Epoch 50 -> prob_abs_loss 20.840595245361328 | sampled_nll_loss 1.5832267999649048
Model saved
Epoch 60 -> prob_abs_loss 17.916603088378906 | sampled_nll_loss 1.6532267332077026
Model saved
Epoch 70 -> prob_abs_loss 17.129655838012695 | sampled_nll_loss 1.6432266235351562
Model saved
Epoch 80 -> prob_abs_loss 17.504119873046875 | sampled_nll_loss 1.603226661682129
Epoch 90 -> prob_abs_loss 16.755226135253906 | sampled_nll_loss 1.573226809501648
Model saved
Epoch 100 -> prob_abs_loss 15.992427825927734 | sampled_nll_loss 1.553