# Demo of ALM and QPM for NOTEARS-MLP
- Code modified from https://github.com/xunzheng/notears/blob/ba61337bd0e5410c04cc708be57affc191a8c424/notears/nonlinear.py#L213

In [1]:
import sys
sys.path.append('..')    # To import notears from parent directory

import numpy as np
import torch

from notears.nonlinear import NotearsMLP, notears_nonlinear
from notears import utils

# Global numeric configuration for the rest of the notebook.
# Show NumPy arrays with 3 digits of precision when printed.
np.set_printoptions(precision=3)
# Make float64 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float64)

## Generate data

In [2]:
# Fix the seed first so the simulated graph and samples below are repeatable.
utils.set_random_seed(123)

# Simulation settings, one per line for readability.
# NOTE(review): meanings follow the argument names used by notears.utils
# (sample count, node count, edge count, graph family, SEM family) --
# confirm against the utils docstrings.
n = 200
d = 5
s0 = 9
graph_type = 'ER'
sem_type = 'mim'

# Ground-truth graph, then data drawn from a nonlinear SEM over that graph.
B_true = utils.simulate_dag(d, s0, graph_type)
X = utils.simulate_nonlinear_sem(B_true, n, sem_type)

## Constrained optimization with quadratic penalty method

In [3]:
%%time
model = NotearsMLP(dims=[d, 10, 1], bias=True)
W_est = notears_nonlinear(model, X, lambda1=0.01, lambda2=0.01, opt_type='qpm')
assert utils.is_dag(W_est)
acc = utils.count_accuracy(B_true, W_est != 0)
print(acc)

{'fdr': 0.1, 'tpr': 1.0, 'fpr': 1.0, 'shd': 1, 'nnz': 10}
CPU times: user 57.2 s, sys: 40.5 s, total: 1min 37s
Wall time: 24.7 s


## Constrained optimization with augmented Lagrangian method

In [4]:
%%time
model = NotearsMLP(dims=[d, 10, 1], bias=True)
W_est = notears_nonlinear(model, X, lambda1=0.01, lambda2=0.01, opt_type='alm')
assert utils.is_dag(W_est)
acc = utils.count_accuracy(B_true, W_est != 0)
print(acc)

{'fdr': 0.1, 'tpr': 1.0, 'fpr': 1.0, 'shd': 1, 'nnz': 10}
CPU times: user 1min 19s, sys: 1min 1s, total: 2min 21s
Wall time: 35.7 s
