In [None]:
from causal_learn.causallearn.utils.cit import CIT
import numpy as np

def _binary_all(X_true):
    X = X_true.copy()
    for i in range(X.shape[1]):
        x_tmp = X[:,i]
        var = np.var(x_tmp)
        c = np.median(x_tmp)
        print(f"Random variable: {i}, threshold: {c}")
        x_tmp = np.where(x_tmp > c, 1, 0)
        X[:,i] = x_tmp
    return X

In [None]:
# Here is an example of how to use DCT with one step (method 'dis_test').
# Data: x2 and x3 each depend on x1 only, so in the continuous data
# x2 and x3 are independent given x1.
# A fixed seed makes the data, and hence the p-value, reproducible under
# Restart & Run All.
test = 'dis_test'

samples = 1000
rng = np.random.default_rng(0)
x1 = rng.normal(5, 1, samples)
x2 = rng.normal(-2, 1, samples) + 2 * x1
x3 = x1 + rng.normal(0, 1, samples)

X_true = np.array([x1, x2, x3]).T
data_bin = _binary_all(X_true)

dist_test_obj = CIT(data=data_bin, method=test)
# Presumably tests column 1 vs column 2 given column 0 (x2 vs x3 | x1);
# verify against the CIT call signature.
p_value = dist_test_obj(1, 2, [0])
p_value

Random variable: 0, threshold: 4.987149264691457
Random variable: 1, threshold: 8.02685828770547
Random variable: 2, threshold: 4.9970155100331475
Variance: [2.68566704]
Z-score: [1.11068491]
P-value: [0.266704]


array([0.266704])

In [None]:
# Here is an example of how to use DCT-GMM with one step (phase_two=False).
# Data: x2 and x3 each depend on x1 only, so in the continuous data
# x2 and x3 are independent given x1.
# A fixed seed makes the data, and hence the p-value, reproducible under
# Restart & Run All.
test = 'dct_gmm'

samples = 1000
rng = np.random.default_rng(0)
x1 = rng.normal(5, 1, samples)
x2 = rng.normal(-2, 1, samples) + 2 * x1
x3 = x1 + rng.normal(0, 1, samples)

X_true = np.array([x1, x2, x3]).T
data_bin = _binary_all(X_true)

dist_test_obj = CIT(data=data_bin, method=test, phase_two=False)
# Presumably tests column 1 vs column 2 given column 0 (x2 vs x3 | x1);
# verify against the CIT call signature.
p_value = dist_test_obj(1, 2, [0])
p_value

Random variable: 0, threshold: 5.033235677514188
Random variable: 1, threshold: 8.116467489503759
Random variable: 2, threshold: 5.0356194371960346
Variance: [2.82225945]
P-value: [0.95057661]


array([0.95057661])

In [None]:
# Here is an example of how to use DCT-GMM with two step (phase_two=True).
# Data: x2 and x3 each depend on x1 only, so in the continuous data
# x2 and x3 are independent given x1.
# A fixed seed makes the data, and hence the p-value, reproducible under
# Restart & Run All.

test = 'dct_gmm'

samples = 1000
rng = np.random.default_rng(0)
x1 = rng.normal(5, 1, samples)
x2 = rng.normal(-2, 1, samples) + 2 * x1
x3 = x1 + rng.normal(0, 1, samples)

X_true = np.array([x1, x2, x3]).T
data_bin = _binary_all(X_true)

dist_test_obj = CIT(data=data_bin, method=test, phase_two=True)
# Presumably tests column 1 vs column 2 given column 0 (x2 vs x3 | x1);
# verify against the CIT call signature.
p_value = dist_test_obj(1, 2, [0])
p_value

Random variable: 0, threshold: 4.962782480910894
Random variable: 1, threshold: 7.8841030010414235
Random variable: 2, threshold: 4.989001075940564
Variance: [2.85376592]
P-value: [0.22733319]


array([0.22733319])