In [1]:
import torch
import numpy as np
from utils.metrics import distance_correlation, id_correlation
from anatome.similarity import svcca_distance
from tqdm import trange

In [2]:
def linear_correlation(X, Y):
    """Return the squared Pearson correlation coefficient between X and Y.

    Only meaningful for single-variable inputs (1-D arrays or (n, 1) column
    vectors). ``np.corrcoef`` with ``rowvar=False`` treats each column as a
    variable and stacks X's columns before Y's. For multi-column inputs the
    ``[0, 1]`` entry is therefore the correlation between the first two
    columns of X, not between X and Y.
    """
    corr_matrix = np.corrcoef(X, Y, rowvar=False)
    r_xy = corr_matrix[0, 1]
    return r_xy ** 2

In [None]:
# Experiment configuration: samples per synthetic dataset and the seed shared
# by numpy's and torch's global RNGs.
N_points = 5000
SEED = 1
np.random.seed(SEED)
# Trailing ';' suppresses Jupyter's display of the returned torch Generator.
torch.manual_seed(SEED);

In [None]:
# Synthetic benchmark. X is 4-column Gaussian noise (std = pi). Y is one of:
#   * "random":  an independent 4-column Gaussian draw,
#   * "partial": the random Y with its first two columns replaced by
#                deterministic functions of X,
#   * "total":   four deterministic functions of X.
# For every trial and setting we record distance correlation, IdCor with the
# MLE and TwoNN estimators (their "corr", "p" and "id" outputs) and
# 1 - SVCCA distance ("lin").


def _append_metrics(X, Y, dcor, idcor_mle, pv_mle, id_mle,
                    idcor_2nn, pv_2nn, id_2nn, lin):
    """Compute every similarity score between X and Y and append each to its list.

    Metrics are evaluated in the same order as the original copy-pasted code
    (dCor, IdCor-MLE, IdCor-TwoNN, SVCCA). Any randomness used inside the
    metric functions is therefore consumed in the same sequence.

    Args:
        X, Y: tensors with the same number of rows (samples).
        dcor: receives ``distance_correlation(X, Y)``.
        idcor_mle, pv_mle, id_mle: receive the "corr", "p" and "id" entries
            of ``id_correlation(..., algorithm="MLE", k=100)``.
        idcor_2nn, pv_2nn, id_2nn: the same entries for ``algorithm="twoNN"``.
        lin: receives ``1 - svcca_distance(X, Y, accept_rate=0.99, backend="svd")``.
    """
    dcor.append(distance_correlation(X, Y))

    res = id_correlation(X, Y, algorithm="MLE", k=100)
    idcor_mle.append(res["corr"])
    pv_mle.append(res["p"])
    id_mle.append(res["id"])

    res = id_correlation(X, Y, algorithm="twoNN")
    idcor_2nn.append(res["corr"])
    pv_2nn.append(res["p"])
    id_2nn.append(res["id"])

    lin.append(1 - svcca_distance(X, Y, accept_rate=0.99, backend="svd"))


random_idcor_mle, random_pv_mle, random_id_mle = [], [], []
random_idcor_2nn, random_pv_2nn, random_id_2nn = [], [], []
random_lin, random_dcor = [], []

partial_idcor_mle, partial_pv_mle, partial_id_mle = [], [], []
partial_idcor_2nn, partial_pv_2nn, partial_id_2nn = [], [], []
partial_lin, partial_dcor = [], []

total_idcor_mle, total_pv_mle, total_id_mle = [], [], []
total_idcor_2nn, total_pv_2nn, total_id_2nn = [], [], []
total_lin, total_dcor = [], []

n_trials = 10
b = np.pi  # std of every Gaussian input column (loop-invariant)
for trial in trange(n_trials):
    random_X = torch.as_tensor(
        np.random.normal(0, b, (N_points, 4)), dtype=torch.float64
    )
    random_Y = torch.as_tensor(
        np.random.normal(0, b, (N_points, 4)), dtype=torch.float64
    )

    # Partial dependence: columns 0-1 are functions of X, columns 2-3 stay
    # independent noise from random_Y.
    partial_Y = random_Y.clone()
    partial_Y[:, :2] = torch.stack(
        [torch.cos(random_X[:, 0] + random_X[:, 2]), torch.sin(random_X[:, 1])], dim=1
    )

    # Total dependence: every column is a deterministic function of X.
    # (Pure torch ops — building it here consumes no RNG state.)
    total_Y = torch.stack(
        [
            torch.cos(random_X[:, 0] + random_X[:, 2]),
            torch.sin(random_X[:, 1]),
            torch.sin(random_X[:, 1] * random_X[:, 3]),
            torch.cos(random_X[:, 2]),
        ],
        dim=1,
    )

    _append_metrics(
        random_X, partial_Y,
        partial_dcor, partial_idcor_mle, partial_pv_mle, partial_id_mle,
        partial_idcor_2nn, partial_pv_2nn, partial_id_2nn, partial_lin,
    )
    _append_metrics(
        random_X, total_Y,
        total_dcor, total_idcor_mle, total_pv_mle, total_id_mle,
        total_idcor_2nn, total_pv_2nn, total_id_2nn, total_lin,
    )
    _append_metrics(
        random_X, random_Y,
        random_dcor, random_idcor_mle, random_pv_mle, random_id_mle,
        random_idcor_2nn, random_pv_2nn, random_id_2nn, random_lin,
    )


In [None]:
# Summary of the "total" (fully dependent) setting across trials: mean and
# std (numpy default ddof=0) of each score, then min/max of each PV series.
(
    total_idcor_mle, total_lin, total_dcor, total_pv_mle,
    total_idcor_2nn, total_pv_2nn, total_id_mle, total_id_2nn,
) = [
    np.array(series)
    for series in (
        total_idcor_mle, total_lin, total_dcor, total_pv_mle,
        total_idcor_2nn, total_pv_2nn, total_id_mle, total_id_2nn,
    )
]
for label, values in (
    ('Lin: ', total_lin),
    ('DCor: ', total_dcor),
    ('IDCor MLE: ', total_idcor_mle),
    ('IDCor 2NN: ', total_idcor_2nn),
    ('ID MLE: ', total_id_mle),
    ('ID 2NN: ', total_id_2nn),
):
    print(label, values.mean(), values.std())
for label, values in (('PV MLE: ', total_pv_mle), ('PV 2NN: ', total_pv_2nn)):
    print(label, values.min(), values.max())

In [None]:
# Summary of the "partial" setting across trials: mean and std (numpy default
# ddof=0) of each score, then min/max of each PV series.
(
    partial_idcor_mle, partial_lin, partial_dcor, partial_pv_mle,
    partial_idcor_2nn, partial_pv_2nn, partial_id_mle, partial_id_2nn,
) = [
    np.array(series)
    for series in (
        partial_idcor_mle, partial_lin, partial_dcor, partial_pv_mle,
        partial_idcor_2nn, partial_pv_2nn, partial_id_mle, partial_id_2nn,
    )
]
for label, values in (
    ('Lin: ', partial_lin),
    ('DCor: ', partial_dcor),
    ('IDCor MLE: ', partial_idcor_mle),
    ('IDCor 2NN: ', partial_idcor_2nn),
    ('ID MLE: ', partial_id_mle),
    ('ID 2NN: ', partial_id_2nn),
):
    print(label, values.mean(), values.std())
for label, values in (('PV MLE: ', partial_pv_mle), ('PV 2NN: ', partial_pv_2nn)):
    print(label, values.min(), values.max())

In [None]:
# Summary of the "random" (independent) setting across trials: mean and std
# (numpy default ddof=0) of each score, then min/max of each PV series.
(
    random_idcor_mle, random_lin, random_dcor, random_pv_mle,
    random_idcor_2nn, random_pv_2nn, random_id_mle, random_id_2nn,
) = [
    np.array(series)
    for series in (
        random_idcor_mle, random_lin, random_dcor, random_pv_mle,
        random_idcor_2nn, random_pv_2nn, random_id_mle, random_id_2nn,
    )
]
for label, values in (
    ('Lin: ', random_lin),
    ('DCor: ', random_dcor),
    ('IDCor MLE: ', random_idcor_mle),
    ('IDCor 2NN: ', random_idcor_2nn),
    ('ID MLE: ', random_id_mle),
    ('ID 2NN: ', random_id_2nn),
):
    print(label, values.mean(), values.std())
for label, values in (('PV MLE: ', random_pv_mle), ('PV 2NN: ', random_pv_2nn)):
    print(label, values.min(), values.max())

In [None]:
from tqdm import tqdm

# Noise-robustness sweep: add Gaussian noise of increasing std to `total_Y`
# and record IdCor (TwoNN) "corr" and "p" for `n_repeats` independent draws.
# NOTE: `random_X` and `total_Y` are whatever the LAST trial of the simulation
# cell left behind, so that cell must have been run first.
noise_levels = np.arange(0.01, 3, 0.1)
n_repeats = 5
# Shape (n_repeats, n_noise_levels). float64 to match the float64 inputs
# rather than torch's float32 default, which would truncate the results.
pv = torch.zeros(n_repeats, len(noise_levels), dtype=torch.float64)
corr = torch.zeros(n_repeats, len(noise_levels), dtype=torch.float64)
np.random.seed(42)
for rep in range(n_repeats):
    # enumerate() has no len(), so pass `total` for a sized progress bar.
    for level_idx, noise in tqdm(enumerate(noise_levels), total=len(noise_levels)):
        noise_sample = np.random.normal(0, noise, (N_points, total_Y.shape[1]))
        noisy_Y = total_Y + torch.as_tensor(noise_sample, dtype=torch.float64)
        res = id_correlation(random_X, noisy_Y, algorithm='twoNN')
        corr[rep, level_idx] = res['corr']
        pv[rep, level_idx] = res['p']

In [None]:
import os

import matplotlib.pyplot as plt

# Left axis: mean ± std of IdCor across repeats at each noise level.
# Right axis: one violin per noise level showing the spread of the "p" output.
# Assumes `corr` and `pv` are (n_repeats, n_noise_levels) tensors produced by
# the sweep cell over the same noise grid as `x` below.
x = np.arange(0.01, 3, 0.1)

fig, ax1 = plt.subplots()

mean_corr = corr.mean(0).numpy()
std_corr = corr.std(0).numpy()  # torch.std defaults to the unbiased (n-1) estimator
ax1.plot(x, mean_corr, 'o-', c='peru', label='Mean Corr')
ax1.fill_between(x, mean_corr - std_corr, mean_corr + std_corr, color='peru', alpha=0.2)
# Raw string: '\s' is an invalid escape in a normal string literal.
ax1.set_xlabel(r'Gaussian noise $\sigma$')
ax1.set_ylabel('$I_d$Cor')

ax2 = ax1.twinx()
# A 2-D array is split column-wise into datasets: one violin per noise level.
ax2.violinplot(pv.numpy(), positions=x, widths=0.05, showmeans=False, showextrema=True, showmedians=False)
ax2.set_ylabel('$p$-value')

# savefig raises FileNotFoundError if the target directory does not exist.
os.makedirs('results', exist_ok=True)
fig.savefig('results/noisy_corr.pdf', dpi=200, bbox_inches='tight', format='pdf')
