In [1]:
import pandas as pd
import igraph as ig
import matplotlib.pyplot as plt
import numpy as np

from sklearn.preprocessing import minmax_scale

from notears import linear, nonlinear, utils

from CausalDisco.analytics import (
    var_sortability,
    r2_sortability,
    snr_sortability
)

In [2]:
def var_sort_lin(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with linearly increasing factors along ``sorting``.

    Column ``sorting[i]`` is multiplied by ``i + 1``: the first column in
    ``sorting`` keeps its scale and the last one is multiplied by ``d``.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    factors = np.linspace(start=1, stop=d, num=d)  # 1, 2, ..., d

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

def var_sort_lin_inv(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with linearly decreasing factors along ``sorting``.

    Column ``sorting[i]`` is multiplied by ``d - i``: the first column in
    ``sorting`` is multiplied by ``d`` and the last one keeps its scale.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    factors = np.linspace(start=d, stop=1, num=d)  # d, d - 1, ..., 1

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

def var_sort_exp(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with exponentially increasing factors along ``sorting``.

    The factors are the powers ``2**1 ... 2**d`` divided by a common constant
    so that the largest factor equals ``d + 1``; each factor is twice the
    previous one. Column ``sorting[i]`` is multiplied by the ``i``-th factor.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    powers_of_two = np.logspace(1, d, d, base=2)
    # Normalise so the last (largest) factor is exactly d + 1.
    factors = powers_of_two / (powers_of_two[-1] / (d + 1))

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

def var_sort_exp_inv(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with exponentially decreasing factors along ``sorting``.

    The factors are the powers ``2**1 ... 2**d`` divided by a common constant
    so that the largest equals ``d + 1``, then reversed: column ``sorting[0]``
    gets the largest factor and each later column gets half the previous one.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    powers_of_two = np.logspace(1, d, d, base=2)
    # Normalise so the largest factor is d + 1, then flip to descending order.
    factors = (powers_of_two / (powers_of_two[-1] / (d + 1)))[::-1]

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

def var_sort_log(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with saturating, increasing factors along ``sorting``.

    The factors start from ``0.5**1 ... 0.5**d``, are turned into an increasing
    sequence by subtracting each value from the maximum, and are then min-max
    scaled into the range ``[1, d]``. Column ``sorting[i]`` is multiplied by
    the ``i``-th factor.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    halvings = np.logspace(1, d, d, base=0.5)
    # Distance from the largest value: 0 first, then growing with ever smaller steps.
    saturating = halvings.max() - halvings
    factors = minmax_scale(saturating, feature_range=(1, d))

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

def var_sort_log_inv(X_norm, d, sorting):
    """Rescale columns of ``X_norm`` with saturating, decreasing factors along ``sorting``.

    The factors start from ``0.5**1 ... 0.5**d``, are turned into an increasing
    sequence by subtracting each value from the maximum, are min-max scaled
    into the range ``[1, d]``, and are finally reversed. Column ``sorting[0]``
    therefore gets the largest factor and the last column the smallest.

    Args:
        X_norm: 2-D NumPy array (the callers in this notebook pass a
            column-standardised sample matrix). It is not modified.
        d: Number of factors to generate; expected to equal ``len(sorting)``.
        sorting: Sequence of column indices, e.g. a topological order.

    Returns:
        A rescaled copy of ``X_norm``.
    """
    halvings = np.logspace(1, d, d, base=0.5)
    # Distance from the largest value: 0 first, then growing with ever smaller steps.
    saturating = halvings.max() - halvings
    # Scale into [1, d] and flip so the factors decrease along `sorting`.
    factors = minmax_scale(saturating, feature_range=(1, d))[::-1]

    rescaled = X_norm.copy()
    rescaled[:, sorting] *= factors
    return rescaled

In [9]:
# Simulation settings (argument meanings follow notears' `utils` signatures).
SEED = 0           # fixed seed so Restart & Run All reproduces the same data
d = 30             # number of nodes in the random DAG
n = 1000           # number of samples drawn from the SEM
s0 = 40            # expected number of edges
graph_type = "ER"

# Seed the random generators before any simulation call; without this every
# run produces a different graph, weights and sample.
utils.set_random_seed(SEED)

B_true = utils.simulate_dag(d, s0, graph_type)
W = utils.simulate_parameter(B_true)
X = utils.simulate_linear_sem(W, n, "gauss")

In [10]:
# Per-column standard deviation of the simulated sample.
np.std(X, axis=0)

array([ 0.99434307,  2.12417004,  3.41407501,  3.45950264, 12.34374197,
        1.02293016,  1.75241781,  2.83139094,  1.86930739,  1.70561319,
        4.69766821,  1.21909482,  2.03095911,  6.19377394,  1.87197683,
        1.00257275,  1.00395284,  3.57353628,  0.99117597,  4.54685261,
        1.84555526,  0.98917805,  5.77993461,  1.4743296 ,  2.41531321,
        1.50326305,  7.1326341 ,  0.96509989,  0.97648681,  0.98540439])

In [11]:
# Directed graph built from the adjacency matrix; every vertex is labelled
# with its index 0 .. d-1.
g = ig.Graph.Adjacency(B_true, loops=False)
g.vs["label"] = [*range(d)]

# Vertex indices in topological order; used below as the column order for
# the controlled rescaling.
sorting = g.topological_sorting()
print(sorting)

[0, 5, 15, 16, 18, 21, 27, 28, 29, 9, 14, 11, 1, 12, 6, 20, 23, 7, 19, 2, 3, 17, 8, 25, 13, 22, 10, 24, 26, 4]


In [12]:
# Each stage gets its own name instead of rebinding `X`, so the cell can be
# re-run and "original" always refers to the simulated sample.

# NORMALIZE: zero mean and unit standard deviation per column.
X_normalised = (X - X.mean(axis=0)) / X.std(axis=0)

# CONTROL VARSORT: rescale the standardised columns along the topological order.
X_controlled_log = var_sort_log(X_normalised, d, sorting)

stages = [
    ("original", X),
    ("normalised", X_normalised),
    ("controlled-log", X_controlled_log),
]

for position, (label, data) in enumerate(stages):
    if position:
        print("---")  # separator between stages, none after the last one
    print(f"VS-{label}", var_sortability(data, W))
    print(f"R2-{label}", r2_sortability(data, W))
    print(f"SNR-{label}", snr_sortability(data, W))

VS-original 0.9696969696969697
R2-original 0.8080808080808081
SNR-original 0.9696969696969697
---
VS-normalised 0.5858585858585859
R2-normalised 0.8080808080808081
SNR-normalised 0.9696969696969697
---
VS-controlled-log 1.0
R2-controlled-log 0.8080808080808081
SNR-controlled-log 0.9696969696969697


In [17]:
# Chain DAG on 10 nodes: A -> B -> C -> ... -> J.
# Ones on the first superdiagonal: entry [i, i + 1] = 1 is the edge i -> i + 1.
simpleDAG = np.eye(10, k=1, dtype=int)

# Random edge weights for the chain, then a Gaussian sample of size n from it.
simpleW = utils.simulate_parameter(simpleDAG)
simpleX = utils.simulate_linear_sem(simpleW, n, "gauss")

g = ig.Graph.Adjacency(simpleDAG, loops=False)
g.vs["label"] = [*range(10)]

simplesorting = g.topological_sorting()

In [19]:
# `simpleX` is never rebound here, so the cell can be re-run and "original"
# always refers to the simulated sample.

# NORMALIZE: zero mean and unit standard deviation per column.
normalised_simpleX = (simpleX - simpleX.mean(axis=0)) / simpleX.std(axis=0)

# CONTROL VARSORT: every control rescales the standardised columns along the
# topological order. Dict order is the order in which results are printed.
controls = {
    "log": var_sort_log,
    "lin": var_sort_lin,
    "exp": var_sort_exp,
    "log-inv": var_sort_log_inv,
    "lin-inv": var_sort_lin_inv,
    "exp-inv": var_sort_exp_inv,
}
n_nodes = len(simplesorting)  # one scaling factor per node in the order

datasets = {"original": simpleX, "normalised": normalised_simpleX}
for control_name, control in controls.items():
    datasets[f"controlled-{control_name}"] = control(
        normalised_simpleX, n_nodes, simplesorting
    )

for position, (label, data) in enumerate(datasets.items()):
    if position:
        print("---")  # separator between datasets, none after the last one
    print(f"VS-{label}", var_sortability(data, simpleW))
    print(f"R2-{label}", r2_sortability(data, simpleW))
    print(f"SNR-{label}", snr_sortability(data, simpleW))

VS-original 0.0
R2-original 0.9111111111111111
SNR-original 0.9555555555555556
---
VS-normalised 0.4888888888888889
R2-normalised 0.9111111111111111
SNR-normalised 0.9555555555555556
---
VS-controlled-log 1.0
R2-controlled-log 0.9111111111111111
SNR-controlled-log 0.9555555555555556
---
VS-controlled-lin 1.0
R2-controlled-lin 0.9111111111111111
SNR-controlled-lin 0.9555555555555556
---
VS-controlled-exp 1.0
R2-controlled-exp 0.9111111111111111
SNR-controlled-exp 0.9555555555555556
---
VS-controlled-log-inv 0.0
R2-controlled-log-inv 0.9111111111111111
SNR-controlled-log-inv 0.9555555555555556
---
VS-controlled-lin-inv 0.0
R2-controlled-lin-inv 0.9111111111111111
SNR-controlled-lin-inv 0.9555555555555556
---
VS-controlled-exp-inv 0.0
R2-controlled-exp-inv 0.9111111111111111
SNR-controlled-exp-inv 0.9555555555555556


In [None]:
# Inspect the scaling factors that var_sort_log_inv applies for the current d.
# Named `log_inv_factors` rather than `vars` so the builtin `vars()` is not
# overwritten in the kernel namespace.
log_inv_factors = np.logspace(1, d, d, base=0.5)
log_inv_factors = log_inv_factors.max() - log_inv_factors
log_inv_factors = minmax_scale(log_inv_factors, feature_range=(1, d))
log_inv_factors = log_inv_factors[::-1]

log_inv_factors

In [None]:
# `varsortability` is not defined in this notebook; the function imported
# from CausalDisco.analytics is `var_sortability`.
var_sortability(X, W)

In [None]:
varsortability(X_