In [1]:
###################################
## comment
###################################
# few things to check -
#         1. data generation => n, d, s0, graph_type, sem_type = 200, 10, 40, 'ER', 'mlp'
#         2. env generation =>  noise_scales = [0.2, 1, 2, 5, 10]
#         3. nhidden=10, lambda1=0.01, lambda2=0.01, w_threshold=0.4, std_lambda=50, learning_rate=0.1
#         4. in utils.count_accuracy() => is_dag() is disabled/enabled?

In [2]:
###################################
## install and import
###################################

In [3]:
# !pip install python-igraph
from notears.locally_connected import LocallyConnected
from notears.lbfgsb_scipy import LBFGSBScipy
from notears.trace_expm import trace_expm
import torch
import torch.nn as nn
import numpy as np
import math
import notears.utils as ut
from sklearn.preprocessing import StandardScaler
from scipy.optimize import minimize

In [4]:
###################################
## class
###################################

In [5]:
class NotearsMLP(nn.Module):
    def __init__(self, dims, bias=True):
        super(NotearsMLP, self).__init__()
        assert len(dims) >= 2
        assert dims[-1] == 1
        d = dims[0]
        self.dims = dims
        # fc1: variable splitting for l1
        self.fc1_pos = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_neg = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_pos.weight.bounds = self._bounds()
        self.fc1_neg.weight.bounds = self._bounds()
        # fc2: local linear layers
        layers = []
        for l in range(len(dims) - 2):
            layers.append(LocallyConnected(d, dims[l + 1], dims[l + 2], bias=bias))
        self.fc2 = nn.ModuleList(layers)

    def _bounds(self):
        d = self.dims[0]
        bounds = []
        for j in range(d):
            for m in range(self.dims[1]):
                for i in range(d):
                    if i == j:
                        bound = (0, 0)
                    else:
                        bound = (0, None)
                    bounds.append(bound)
        return bounds

    def forward(self, x):  # [n, d] -> [n, d]
        x = self.fc1_pos(x) - self.fc1_neg(x)  # [n, d * m1]
        x = x.view(-1, self.dims[0], self.dims[1])  # [n, d, m1]
        for fc in self.fc2:
            x = torch.sigmoid(x)  # [n, d, m1]
            x = fc(x)  # [n, d, m2]
        x = x.squeeze(dim=2)  # [n, d]
        return x

    def h_func(self):
        """Constrain 2-norm-squared of fc1 weights along m1 dim to be a DAG"""
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        h = trace_expm(A) - d  # (Zheng et al. 2018)
        # A different formulation, slightly faster at the cost of numerical stability
        # M = torch.eye(d) + A / d  # (Yu et al. 2019)
        # E = torch.matrix_power(M, d - 1)
        # h = (E.t() * M).sum() - d
        return h

    def l2_reg(self):
        """Take 2-norm-squared of all parameters"""
        reg = 0.
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        reg += torch.sum(fc1_weight ** 2)
        for fc in self.fc2:
            reg += torch.sum(fc.weight ** 2)
        return reg

    def fc1_l1_reg(self):
        """Take l1 norm of fc1 weight"""
        reg = torch.sum(self.fc1_pos.weight + self.fc1_neg.weight)
        return reg

    @torch.no_grad()
    def fc1_to_adj(self) -> np.ndarray:  # [j * m1, i] -> [i, j]
        """Get W from fc1 weights, take 2-norm over m1 dim"""
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        W = torch.sqrt(A)  # [i, j]
        W = W.cpu().detach().numpy()  # [i, j]
        return W

class NotearsSobolev(nn.Module):
    def __init__(self, d, k):
        """d: num variables k: num expansion of each variable"""
        super(NotearsSobolev, self).__init__()
        self.d, self.k = d, k
        self.fc1_pos = nn.Linear(d * k, d, bias=False)  # ik -> j
        self.fc1_neg = nn.Linear(d * k, d, bias=False)
        self.fc1_pos.weight.bounds = self._bounds()
        self.fc1_neg.weight.bounds = self._bounds()
        nn.init.zeros_(self.fc1_pos.weight)
        nn.init.zeros_(self.fc1_neg.weight)
        self.l2_reg_store = None

    def _bounds(self):
        # weight shape [j, ik]
        bounds = []
        for j in range(self.d):
            for i in range(self.d):
                for _ in range(self.k):
                    if i == j:
                        bound = (0, 0)
                    else:
                        bound = (0, None)
                    bounds.append(bound)
        return bounds

    def sobolev_basis(self, x):  # [n, d] -> [n, dk]
        seq = []
        for kk in range(self.k):
            mu = 2.0 / (2 * kk + 1) / math.pi  # sobolev basis
            psi = mu * torch.sin(x / mu)
            seq.append(psi)  # [n, d] * k
        bases = torch.stack(seq, dim=2)  # [n, d, k]
        bases = bases.view(-1, self.d * self.k)  # [n, dk]
        return bases

    def forward(self, x):  # [n, d] -> [n, d]
        bases = self.sobolev_basis(x)  # [n, dk]
        x = self.fc1_pos(bases) - self.fc1_neg(bases)  # [n, d]
        self.l2_reg_store = torch.sum(x ** 2) / x.shape[0]
        return x

    def h_func(self):
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j, ik]
        fc1_weight = fc1_weight.view(self.d, self.d, self.k)  # [j, i, k]
        A = torch.sum(fc1_weight * fc1_weight, dim=2).t()  # [i, j]
        h = trace_expm(A) - d  # (Zheng et al. 2018)
        # A different formulation, slightly faster at the cost of numerical stability
        # M = torch.eye(self.d) + A / self.d  # (Yu et al. 2019)
        # E = torch.matrix_power(M, self.d - 1)
        # h = (E.t() * M).sum() - self.d
        return h

    def l2_reg(self):
        reg = self.l2_reg_store
        return reg

    def fc1_l1_reg(self):
        reg = torch.sum(self.fc1_pos.weight + self.fc1_neg.weight)
        return reg

    @torch.no_grad()
    def fc1_to_adj(self) -> np.ndarray:
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j, ik]
        fc1_weight = fc1_weight.view(self.d, self.d, self.k)  # [j, i, k]
        A = torch.sum(fc1_weight * fc1_weight, dim=2).t()  # [i, j]
        W = torch.sqrt(A)  # [i, j]
        W = W.cpu().detach().numpy()  # [i, j]
        return W

In [6]:
###################################
## function
###################################

In [14]:
def squared_loss(output, target):
    n = target.shape[0]
    loss = 0.5 / n * torch.sum((output - target) ** 2)
    return loss

def notears_nonlinear(model: nn.Module,
                      X: np.ndarray,
                      lambda1: float = 0.,
                      lambda2: float = 0.,
                      max_iter: int = 100,
                      h_tol: float = 1e-8,
                      rho_max: float = 1e+16,
                      w_threshold: float = 0.1):
    model.w_threshold = w_threshold
    rho, alpha, h = 1.0, 0.0, np.inf
    for _ in range(max_iter):
        rho, alpha, h = dual_ascent_step(model, X, lambda1, lambda2,
                                         rho, alpha, h, rho_max)
        if h <= h_tol or rho >= rho_max:
            break
    W_est = model.fc1_to_adj()
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est

def notears_nonlinear_with_loss_std(model: nn.Module,
                                    X_list: list,  
                                    lambda1: float = 0.0,
                                    lambda2: float = 0.0,
                                    max_iter: int = 100,
                                    h_tol: float = 1e-8,
                                    rho_max: float = 1e+16,
                                    w_threshold: float = 0.3,
                                    std_lambda: float = 1.0):  
    model.w_threshold = w_threshold    
    rho, alpha, h = 1.0, 0.0, np.inf
    for iter_no in range(max_iter):
        rho, alpha, h = dual_ascent_step_with_loss_std(
            model, X_list, lambda1, lambda2, std_lambda, rho, alpha, h, rho_max, iter_no)
        if h <= h_tol or rho >= rho_max:
            break
    W_est = model.fc1_to_adj()
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est

def dual_ascent_step(model, X, lambda1, lambda2, rho, alpha, h, rho_max):
    """Perform one step of dual ascent in augmented Lagrangian."""
    h_new = None
    optimizer = LBFGSBScipy(model.parameters())
    X_torch = torch.from_numpy(X)
    while rho < rho_max:
        def closure():
            optimizer.zero_grad()
            X_hat = model(X_torch)
            loss = squared_loss(X_hat, X_torch)
            h_val = model.h_func()
            penalty = 0.5 * rho * h_val * h_val + alpha * h_val
            l2_reg = 0.5 * lambda2 * model.l2_reg()
            l1_reg = lambda1 * model.fc1_l1_reg()
            primal_obj = loss + penalty + l2_reg + l1_reg
            # primal_obj = loss + penalty
            primal_obj.backward()
            return primal_obj
        optimizer.step(closure)  # NOTE: updates model in-place
        with torch.no_grad():
            h_new = model.h_func().item()
        if h_new > 0.25 * h:
            rho *= 10
        else:
            break
    alpha += rho * h_new
    return rho, alpha, h_new

def dual_ascent_step_with_loss_std(model, X_list, lambda1, lambda2, std_lambda, rho, alpha, h, rho_max, iter_no):
    """Perform one step of dual ascent in augmented Lagrangian, with consistent gradient-based learning (CGLearn) for each predictor."""
    
    crp = std_lambda  # Consistency ratio percentile
    h_new = None
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Using Adam optimizer instead of LBFGSBScipy
    X_tensors = [torch.from_numpy(X).float().to(torch.double).to(model.fc1_pos.weight.device) for X in X_list]  # Ensure tensors are on the correct device
    
    while rho < rho_max:
        feature_l2_norms_pos_per_predictor = [[] for _ in range(model.dims[0])]  # Separate L2 norms for fc1_pos
        feature_l2_norms_neg_per_predictor = [[] for _ in range(model.dims[0])]  # Separate L2 norms for fc1_neg
        list_all_grads = []  # To store gradients for the entire model (all layers)

        # Step 1: Compute loss and gradients for each environment
        for X in X_tensors:
            # Loss calculation
            X_hat = model(X)
            loss_mse = squared_loss(X_hat, X)
            h_val = model.h_func()
            penalty = 0.5 * rho * h_val * h_val + alpha * h_val
            l2_reg = 0.5 * lambda2 * model.l2_reg()
            l1_reg = lambda1 * model.fc1_l1_reg()
            final_loss = (loss_mse + penalty + l2_reg + l1_reg) / len(X_tensors)  # Averaged across environments

            # Grad calculation
            optimizer.zero_grad()
            final_loss.backward()

            # Collect all gradients for the entire model (for all layers)
            grads = []
            for param in model.parameters():
                grads.append(param.grad.clone().flatten())
            all_grads = torch.cat(grads)  # Flatten and concatenate all gradients for all parameters
            list_all_grads.append(all_grads)  # Store gradients for this dataset

            # Compute and store gradients per predictor for the first hidden layer
            for i in range(model.dims[0]):  # Iterate over each predictor (input feature)
                # Get the gradients for the first hidden layer (fc1_pos and fc1_neg)
                pos_g = model.fc1_pos.weight.grad.view(model.dims[0], -1, model.dims[0])[i, :, :]  # Shape [5, 10]
                neg_g = model.fc1_neg.weight.grad.view(model.dims[0], -1, model.dims[0])[i, :, :]  # Shape [5, 10]

                # Compute L2 norms of the gradients across the hidden neurons for each feature
                l2_norms_pos = torch.norm(pos_g, dim=0, p=2).detach()  # L2 norm across hidden neurons, result shape: [10]
                l2_norms_neg = torch.norm(neg_g, dim=0, p=2).detach()  # Same for fc1_neg

                # Store L2 norms for each predictor
                feature_l2_norms_pos_per_predictor[i].append(l2_norms_pos)
                feature_l2_norms_neg_per_predictor[i].append(l2_norms_neg)

        # Step 2: Compute mean of all gradients for the entire model across all datasets
        list_all_grads = torch.stack(list_all_grads)  # Stack all gradients for all datasets
        mean_all_grads = torch.mean(list_all_grads, dim=0)  # Mean gradient for the entire model

        # Step 3: Gradient consistency calculations (CGLearn) per predictor
        consistency_masks_pos = []
        consistency_masks_neg = []
        
        for i in range(model.dims[0]):  # Iterate over each predictor (input feature)
            # Consistency check for fc1_pos weights of predictor i
            feature_l2_norms_pos = torch.stack(feature_l2_norms_pos_per_predictor[i])  # Stack L2 norms across datasets
            mean_norms_pos = torch.mean(feature_l2_norms_pos, dim=0)  # Mean of L2 norms for each feature
            std_norms_pos = torch.std(feature_l2_norms_pos, dim=0) + 1e-8  # Standard deviation of L2 norms for each feature
            cr_pos = torch.abs(mean_norms_pos) / std_norms_pos  # Consistency ratio
            ct_pos = np.percentile(cr_pos.cpu().numpy(), crp)  # Threshold based on the percentile
            consistency_mask_pos = torch.where(cr_pos >= ct_pos, torch.tensor(1., device=model.fc1_pos.weight.device), torch.tensor(0., device=model.fc1_pos.weight.device))            
            consistency_masks_pos.append(consistency_mask_pos.repeat(model.dims[1], 1))  # Broadcast the mask
            
            # Consistency check for fc1_neg weights of predictor i
            feature_l2_norms_neg = torch.stack(feature_l2_norms_neg_per_predictor[i])  # Same for fc1_neg
            mean_norms_neg = torch.mean(feature_l2_norms_neg, dim=0)
            std_norms_neg = torch.std(feature_l2_norms_neg, dim=0) + 1e-8
            cr_neg = torch.abs(mean_norms_neg) / std_norms_neg  # Consistency ratio
            ct_neg = np.percentile(cr_neg.cpu().numpy(), crp)
            consistency_mask_neg = torch.where(cr_neg >= ct_neg, torch.tensor(1., device=model.fc1_neg.weight.device), torch.tensor(0., device=model.fc1_neg.weight.device))
            consistency_masks_neg.append(consistency_mask_neg.repeat(model.dims[1], 1))
            
        # Step 4: Apply the masks to gradients before updating parameters
        cmp = torch.stack(consistency_masks_pos).view(-1, model.dims[0])     
        cmn = torch.stack(consistency_masks_neg).view(-1, model.dims[0])
        
        start_index = 0
        for name, param in model.named_parameters():
            param_numel = param.numel()
            mean_grad = mean_all_grads[start_index: start_index + param_numel].view_as(param)
            if 'fc1_pos.weight' in name:
                param.grad = mean_grad * cmp       
            elif 'fc1_neg.weight' in name:
                param.grad = mean_grad * cmn    
            else:
                param.grad = mean_grad
            start_index += param_numel

        # Step 5: Now that the gradients have been modified, perform the optimization step
        optimizer.step()

        # Step 6: Check convergence and adjust rho
        with torch.no_grad():
            h_new = model.h_func().item()
        if h_new > 0.25 * h:
            rho *= 10
        else:
            break
            
        # ww = model.fc1_to_adj()
        # ww[np.abs(ww) < model.w_threshold] = 0
        # print(ww)

    alpha += rho * h_new
    return rho, alpha, h_new


In [15]:
###################################
## initialize
###################################

In [16]:
torch.set_default_dtype(torch.double)
np.set_printoptions(precision=3)
ut.set_random_seed(123)
n, d, s0, graph_type, sem_type = 200, 10, 40, 'ER', 'mlp'
B_true = ut.simulate_dag(d, s0, graph_type)
np.savetxt('inputs/W_true.csv', B_true, delimiter=',')
# noise_scale = np.ones(d)
# X = ut.simulate_nonlinear_sem(B_true, n, sem_type, noise_scale)
# np.savetxt('X.csv', X, delimiter=',')
noise_scales = [0.2, 1, 2, 5, 10]
for i, noise_scale_value in enumerate(noise_scales):
    noise_scale = np.full(d, noise_scale_value)  
    X = ut.simulate_nonlinear_sem(B_true, n, sem_type, noise_scale)
    np.savetxt(f'inputs/X_{i}.csv', X, delimiter=',')  
    
X_0 = np.loadtxt('inputs/X_0.csv', delimiter=',')
X_1 = np.loadtxt('inputs/X_1.csv', delimiter=',')
X_2 = np.loadtxt('inputs/X_2.csv', delimiter=',')
X_3 = np.loadtxt('inputs/X_3.csv', delimiter=',')
X_4 = np.loadtxt('inputs/X_4.csv', delimiter=',')
X_list = [X_0, X_1, X_2, X_3, X_4]  # List of datasets
scaler = StandardScaler()
X_list_standardized = [scaler.fit_transform(X) for X in X_list]
X_combined = np.vstack([X for X in X_list_standardized]) 
print(X_combined.shape)
noise_scale_value_test = np.random.uniform(min(noise_scales), max(noise_scales))
noise_scale_test = np.full(d, noise_scale_value_test)
X_test = ut.simulate_nonlinear_sem(B_true, n, sem_type, noise_scale_test)
print(X_test.shape)

(1000, 10)
(200, 10)


In [17]:
###################################
## NOTEARS
###################################

In [18]:
list_fdr, list_shd, list_tpr, list_nnz = [], [], [], []
for i in range(5):
    model = NotearsMLP(dims=[d, 10, 1], bias=True)
    W_est = notears_nonlinear(model, X_combined, lambda1=0.01, lambda2=0.01)
    # assert ut.is_dag(W_est)
    np.savetxt('outputs/W_est.csv', W_est, delimiter=',')
    acc = ut.count_accuracy(B_true, W_est != 0)
    print(i, acc)
    list_fdr.append(acc['fdr'])
    list_shd.append(acc['shd'])
    list_tpr.append(acc['tpr'])
    list_nnz.append(acc['nnz'])    
print()
print()
print(f'FDR: {np.mean(list_fdr):.4f} ± {np.std(list_fdr):.4f}')
print(f'SHD: {np.mean(list_shd):.4f} ± {np.std(list_shd):.4f}')
print(f'TPR: {np.mean(list_tpr):.4f} ± {np.std(list_tpr):.4f}')
print(f'NNZ: {np.mean(list_nnz):.4f} ± {np.std(list_nnz):.4f}')
X_hat_test = model(torch.from_numpy(X_test).float().to(torch.double))
loss_test = squared_loss(X_hat_test, torch.from_numpy(X_test).float().to(torch.double))
print(f'Test squared loss on original values: {loss_test.item():.4f}')

0 {'fdr': 0.36666666666666664, 'tpr': 0.475, 'fpr': 2.2, 'shd': 24, 'nnz': 30}
1 {'fdr': 0.36666666666666664, 'tpr': 0.475, 'fpr': 2.2, 'shd': 24, 'nnz': 30}
2 {'fdr': 0.36666666666666664, 'tpr': 0.475, 'fpr': 2.2, 'shd': 24, 'nnz': 30}
3 {'fdr': 0.36666666666666664, 'tpr': 0.475, 'fpr': 2.2, 'shd': 24, 'nnz': 30}
4 {'fdr': 0.36666666666666664, 'tpr': 0.475, 'fpr': 2.2, 'shd': 24, 'nnz': 30}


FDR: 0.3667 ± 0.0000
SHD: 24.0000 ± 0.0000
TPR: 0.4750 ± 0.0000
NNZ: 30.0000 ± 0.0000
Test squared loss on original values: 726.0170


In [19]:
###################################
## invariant NOTEARS
###################################

In [20]:
list_fdr, list_shd, list_tpr, list_nnz = [], [], [], []
for i in range(5):
    model = NotearsMLP(dims=[d, 10, 1], bias=True)
    W_est = notears_nonlinear_with_loss_std(model, X_list_standardized, lambda1=0.01, lambda2=0.01, std_lambda=50)
    # assert ut.is_dag(W_est)
    np.savetxt('outputs/W_est2.csv', W_est, delimiter=',')
    acc = ut.count_accuracy(B_true, W_est != 0)
    print(i, acc)
    list_fdr.append(acc['fdr'])
    list_shd.append(acc['shd'])
    list_tpr.append(acc['tpr'])
    list_nnz.append(acc['nnz'])        
print()
print()
print(f'FDR: {np.mean(list_fdr):.4f} ± {np.std(list_fdr):.4f}')
print(f'SHD: {np.mean(list_shd):.4f} ± {np.std(list_shd):.4f}')
print(f'TPR: {np.mean(list_tpr):.4f} ± {np.std(list_tpr):.4f}')
print(f'NNZ: {np.mean(list_nnz):.4f} ± {np.std(list_nnz):.4f}')
X_hat_test = model(torch.from_numpy(X_test).float().to(torch.double))
loss_test = squared_loss(X_hat_test, torch.from_numpy(X_test).float().to(torch.double))
print(f'Test squared loss on original values: {loss_test.item():.4f}')

0 {'fdr': 0.6842105263157895, 'tpr': 0.3, 'fpr': 5.2, 'shd': 37, 'nnz': 38}
1 {'fdr': 0.42105263157894735, 'tpr': 0.275, 'fpr': 1.6, 'shd': 33, 'nnz': 19}
2 {'fdr': 0.4, 'tpr': 0.45, 'fpr': 2.4, 'shd': 28, 'nnz': 30}
3 {'fdr': 0.5714285714285714, 'tpr': 0.3, 'fpr': 3.2, 'shd': 33, 'nnz': 28}
4 {'fdr': 0.6388888888888888, 'tpr': 0.325, 'fpr': 4.6, 'shd': 35, 'nnz': 36}


FDR: 0.5431 ± 0.1142
SHD: 33.2000 ± 2.9933
TPR: 0.3300 ± 0.0620
NNZ: 30.2000 ± 6.7052
Test squared loss on original values: 721.1958
