In [14]:
from grnkit.data import Dream4MultifactorialDataset, Dream4TimeseriesDataset
from grnkit.evaluate import evaluate
from grnkit.ss.notears import notears_linear
import numpy as np

In [None]:
import numpy as np
import scipy.linalg as slin
import scipy.optimize as sopt
from scipy.special import expit as sigmoid
from tqdm import tqdm


def notears_linear(X, lambda1, loss_type, max_iter=100, h_tol=1e-8, rho_max=1e+16, w_threshold=0.3):
    """Solve min_W L(W; X) + lambda1 ‖W‖_1 s.t. h(W) = 0 using augmented Lagrangian.

    Args:
        X (np.ndarray): [n, d] sample matrix
        lambda1 (float): l1 penalty parameter
        loss_type (str): l2, logistic, poisson
        max_iter (int): max num of dual ascent steps
        h_tol (float): exit if |h(w_est)| <= htol
        rho_max (float): exit if rho >= rho_max
        w_threshold (float): drop edge if |weight| < threshold

    Returns:
        W_est (np.ndarray): [d, d] estimated DAG
    """
    def _loss(W):
        """Evaluate value and gradient of loss."""
        M = X @ W
        if loss_type == 'l2':
            R = X - M
            loss = 0.5 / X.shape[0] * (R ** 2).sum()
            G_loss = - 1.0 / X.shape[0] * X.T @ R
        elif loss_type == 'logistic':
            loss = 1.0 / X.shape[0] * (np.logaddexp(0, M) - X * M).sum()
            G_loss = 1.0 / X.shape[0] * X.T @ (sigmoid(M) - X)
        elif loss_type == 'poisson':
            S = np.exp(M)
            loss = 1.0 / X.shape[0] * (S - X * M).sum()
            G_loss = 1.0 / X.shape[0] * X.T @ (S - X)
        else:
            raise ValueError('unknown loss type')
        return loss, G_loss

    def _h(W):
        """Evaluate value and gradient of acyclicity constraint."""
        E = slin.expm(W * W)  # (Zheng et al. 2018)
        h = np.trace(E) - d
        #     # A different formulation, slightly faster at the cost of numerical stability
        #     M = np.eye(d) + W * W / d  # (Yu et al. 2019)
        #     E = np.linalg.matrix_power(M, d - 1)
        #     h = (E.T * M).sum() - d
        G_h = E.T * W * 2
        return h, G_h

    def _adj(w):
        """Convert doubled variables ([2 d^2] array) back to original variables ([d, d] matrix)."""
        return (w[:d * d] - w[d * d:]).reshape([d, d])

    def _func(w):
        """Evaluate value and gradient of augmented Lagrangian for doubled variables ([2 d^2] array)."""
        W = _adj(w)
        loss, G_loss = _loss(W)
        h, G_h = _h(W)
        obj = loss + 0.5 * rho * h * h + alpha * h + lambda1 * w.sum()
        G_smooth = G_loss + (rho * h + alpha) * G_h
        g_obj = np.concatenate((G_smooth + lambda1, - G_smooth + lambda1), axis=None)
        return obj, g_obj

    n, d = X.shape
    w_est, rho, alpha, h = np.zeros(2 * d * d), 1.0, 0.0, np.inf  # double w_est into (w_pos, w_neg)
    bnds = [(0, 0) if i == j else (0, None) for _ in range(2) for i in range(d) for j in range(d)]
    if loss_type == 'l2':
        X = X - np.mean(X, axis=0, keepdims=True)
    for epoch in tqdm(range(max_iter)):
        w_new, h_new = None, None
        while rho < rho_max:
            sol = sopt.minimize(_func, w_est, method='L-BFGS-B', jac=True, bounds=bnds)
            w_new = sol.x
            h_new, _ = _h(_adj(w_new))
            if h_new > 0.25 * h:
                rho *= 10
            else:
                break
        w_est, h = w_new, h_new
        alpha += rho * h
        print("Epoch {}: {:.4f}, {:.4f}".format(epoch, h, rho))
        if h <= h_tol or rho >= rho_max:
            break
    W_est = _adj(w_est)
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est


In [2]:
dt = Dream4MultifactorialDataset(1)

In [20]:
W_est = notears_linear(dt.expression_data, lambda1=0.000, loss_type='l2')

  1%|▍                                          | 1/100 [00:21<35:49, 21.72s/it]

Epoch 0: 0.1382, 1.0000


  2%|▊                                        | 2/100 [02:36<2:23:43, 88.00s/it]

Epoch 1: 0.0306, 10.0000


  3%|█▏                                      | 3/100 [05:25<3:22:34, 125.30s/it]

Epoch 2: 0.0073, 100.0000


  3%|█▏                                      | 3/100 [07:34<4:05:10, 151.65s/it]


KeyboardInterrupt: 

In [19]:
y_pred = np.array([score[1] for score in np.ndenumerate(W_est)])
evaluate(dt.gold_standard, y_pred)

AUROC:0.4220   AUPRC:0.0427


{'fpr': array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.16030534e-04,
        9.16030534e-04, 2.74809160e-03, 2.74809160e-03, 3.56234097e-03,
        3.56234097e-03, 7.63358779e-03, 7.63358779e-03, 1.49618321e-02,
        1.49618321e-02, 1.87277354e-02, 1.87277354e-02, 1.90330789e-02,
        1.90330789e-02, 1.95419847e-02, 1.95419847e-02, 2.20865140e-02,
        2.20865140e-02, 2.30025445e-02, 9.58778626e-01, 9.60610687e-01,
        9.60610687e-01, 9.67837150e-01, 9.67837150e-01, 9.74961832e-01,
        9.74961832e-01, 9.78524173e-01, 9.78524173e-01, 9.78829517e-01,
        9.78829517e-01, 9.79134860e-01, 9.79134860e-01, 9.79745547e-01,
        9.79745547e-01, 9.79847328e-01, 9.79847328e-01, 9.79949109e-01,
        9.79949109e-01, 9.81984733e-01, 9.81984733e-01, 9.85445293e-01,
        9.85445293e-01, 9.89618321e-01, 9.89618321e-01, 9.90229008e-01,
        9.90229008e-01, 9.90941476e-01, 9.90941476e-01, 9.91145038e-01,
        9.91145038e-01, 9.91755725e-01, 9.91755725e-01, 9

In [4]:
dt.expression_data

array([[0.2608087, 0.3773118, 0.4933734, ..., 0.4778708, 0.0328088,
        0.5699604],
       [0.2744197, 0.3338101, 0.3857391, ..., 0.4468232, 0.0065561,
        0.4677665],
       [0.1749744, 0.313279 , 0.4351399, ..., 0.393398 , 0.0302167,
        0.5934666],
       ...,
       [0.0488341, 0.3096201, 0.4898805, ..., 0.5632502, 0.0055914,
        0.526624 ],
       [0.0642824, 0.3880082, 0.4470457, ..., 0.4530106, 0.4743811,
        0.4024433],
       [0.1245994, 0.1909148, 0.3023966, ..., 0.3945163, 0.0040215,
        0.7057016]])

In [13]:
W_est[2, :]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
if __name__ == '__main__':
    from notears import utils
    utils.set_random_seed(1)

    n, d, s0, graph_type, sem_type = 100, 20, 20, 'ER', 'gauss'
    B_true = utils.simulate_dag(d, s0, graph_type)
    W_true = utils.simulate_parameter(B_true)
    np.savetxt('W_true.csv', W_true, delimiter=',')

    X = utils.simulate_linear_sem(W_true, n, sem_type)
    np.savetxt('X.csv', X, delimiter=',')

    W_est = notears_linear(X, lambda1=0.1, loss_type='l2')
    assert utils.is_dag(W_est)
    np.savetxt('W_est.csv', W_est, delimiter=',')
    acc = utils.count_accuracy(B_true, W_est != 0)
    print(acc)