In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import uniform
#from BFM.MCMC_MGP import Gibbs_sampling
#from BFM.MCMC_LH import Gibbs_sampling
from BFM.MCMC_CSP import Gibbs_sampling
from BFM.VI import NGVI
from BFM.utils import FDR_FNR_mcmc, FDR_FNR_VI, FDR_FNR_COV, ESS

In [None]:
# Run the samplers on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device1 = torch.device("cuda:0")
else:
    device1 = torch.device("cpu")

In [None]:
# Problem size: P observed variables, K latent factors, N samples
# (the ground-truth loadings B_0 are P x K and the data matrix X is N x P).
P = 1000
K = 5
N = 100

# Seed the global RNGs so the synthetic data, the sampler draws and every downstream
# metric are reproducible under Restart & Run All. scipy.stats distributions called
# without random_state draw from the numpy global RNG, so np.random.seed also covers
# uniform.rvs. torch.manual_seed seeds the CPU and all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Ground-truth sparse loadings: each entry is kept with probability 1/3 (Bernoulli mask)
# and scaled by a Uniform[0, 1) magnitude.
B_0 = np.random.binomial(1, 1 / 3, size=(P, K)) * np.random.rand(P, K)

# Low-rank-plus-diagonal covariance; idiosyncratic variances are Uniform(0.1, 1.0).
cov_0 = B_0 @ B_0.T + np.diag(uniform.rvs(loc=0.1, scale=0.9, size=P))

# N i.i.d. zero-mean Gaussian observations, one per row of X.
X = np.random.multivariate_normal(mean=np.zeros(P), cov=cov_0, size=N)

In [None]:
# Heat map of the ground-truth covariance cov_0.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cov_0)
fig.colorbar(im, ax=ax)
ax.set_title("Covariance")
plt.show()

In [None]:
# Draw posterior samples with the BFM Gibbs sampler (the CSP variant is the active import;
# the MGP / LH alternatives are commented out in the import cell).
B_sample, sigma2_sample = Gibbs_sampling(X, device = device1)
# Move the draws to host memory. Later cells call .numpy() on them and subtract the derived
# covariance from torch.from_numpy(cov_0) (a CPU tensor), and both fail for CUDA tensors.
# .cpu() returns the tensor unchanged when it is already on the CPU.
B_sample, sigma2_sample = B_sample.cpu(), sigma2_sample.cpu()


In [None]:
# Fit the variational approximation with BFM.VI.NGVI on device1.
mu, Cov, np_sigma, v = NGVI(X, device = device1)
# Move the outputs to host memory. Later cells call .numpy() on quantities derived from them
# and subtract cov_VI from torch.from_numpy(cov_0) (a CPU tensor), and both fail for CUDA
# tensors. .cpu() is a no-op for CPU tensors.
mu, Cov, np_sigma = mu.cpu(), Cov.cpu(), np_sigma.cpu()
# v's type is not visible here (it may be a plain Python scalar), so only move it if it is a tensor.
v = v.cpu() if torch.is_tensor(v) else v

In [None]:
# Covariance implied by the VI fit, assembled from three terms:
#   mu @ mu.T
#       P x P outer product of the variational mean loadings.
#   (v / (v - 2)) * diag(trace(Cov[p]) for each p)
#       adds each row's loading uncertainty on the diagonal only. torch.vmap maps
#       torch.trace over the leading dim of Cov, so Cov is a stack of square matrices
#       (presumably (P, K, K) — TODO confirm). v / (v - 2) matches the variance factor
#       of a Student-t with v dof and is undefined / negative unless v > 2.
#   diag(np_sigma / (0.5 * N))
#       diagonal noise term. Presumably np_sigma / (N/2) is a posterior summary of the
#       idiosyncratic variances — confirm against BFM.VI.NGVI.
cov_VI = mu @ mu.T + (v / (v - 2)) * torch.diag(torch.vmap(torch.trace)(Cov)) + torch.diag(np_sigma / (0.5 * N))

In [None]:
# Monte Carlo estimate of the implied covariance:
#   (1 / S) * sum_s B_s @ B_s.T  +  diag(mean over draws of sigma2),   with S = B_sample.size(0).
# In einsum 'bij,bjk->ik' both b (draw) and j (factor) are absent from the output, so
# they are summed. One call therefore accumulates B_s @ B_s.T over every draw.
# (assumes B_sample is (S, P, K) and sigma2_sample is (S, P) — TODO confirm against BFM.MCMC_CSP)
cov_mcmc = torch.einsum('bij,bjk->ik',B_sample, B_sample.transpose(1,2)) / B_sample.size(0) + torch.diag(sigma2_sample.mean(0))

In [None]:
# Heat map of the VI covariance estimate, for visual comparison with cov_0.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cov_VI)
fig.colorbar(im, ax=ax)
ax.set_title("Covariance_VI")
plt.show()

In [None]:
# Heat map of the MCMC covariance estimate, for visual comparison with cov_0.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cov_mcmc)
fig.colorbar(im, ax=ax)
ax.set_title("Covariance_MCMC")
plt.show()

In [None]:
# Posterior-mean loadings: average of B_sample over dim 0 (presumably the draw axis).
B_mean = B_sample.mean(0)
# One L2 norm per column of B_mean. Presumably columns with near-zero norm are factors
# the prior has shrunk away.
# detach/cpu make .numpy() safe whatever device or grad state the draws have.
plt.plot(torch.norm(B_mean, dim = 0).detach().cpu().numpy(), marker="o")
plt.xlabel("column index k")
plt.ylabel("column L2 norm of B_mean")
plt.title("Loading column norms (MCMC)")
# The original trailing plt.plot() drew nothing and echoed "[]"; show the figure instead.
plt.show()

In [None]:
# One L2 norm per column of the VI mean loadings mu, the VI counterpart of the MCMC plot.
# detach/cpu make .numpy() safe whatever device or grad state mu has.
plt.plot(torch.norm(mu, dim = 0).detach().cpu().numpy(), marker="o")
plt.xlabel("column index k")
plt.ylabel("column L2 norm of mu")
plt.title("Loading column norms (VI)")
# The original trailing plt.plot() drew nothing and echoed "[]"; show the figure instead.
plt.show()

In [None]:
# FDR_FNR_mcmc returns three values. Only the third is kept: num_K_mcmc, presumably the
# number of factors the MCMC fit retains — confirm in BFM.utils.
_, _, num_K_mcmc = FDR_FNR_mcmc(B_sample, B_0)
# Nothing later uses this value, so display it here as the cell's output.
num_K_mcmc

In [None]:
# FDR_FNR_VI returns three values. Only the third is kept: num_K_VI, presumably the
# number of factors the VI fit retains — confirm in BFM.utils.
_, _, num_K_VI = FDR_FNR_VI(mu, Cov, v, B_0)
# Nothing later uses this value, so display it here as the cell's output.
num_K_VI

In [None]:
# Frobenius-norm error between the true covariance and the VI estimate (shown as cell output).
torch.sqrt(torch.sum((torch.from_numpy(cov_0) - cov_VI) ** 2))

In [None]:
# Frobenius-norm error between the true covariance and the MCMC estimate (shown as cell output).
torch.sqrt(torch.sum((torch.from_numpy(cov_0) - cov_mcmc) ** 2))

In [None]:
# Compare the MCMC covariance estimate with the true cov_0 via BFM.utils.FDR_FNR_COV.
# Presumably these are the false-discovery / false-negative rates of the estimated nonzero
# pattern; the thresholding rule lives in BFM.utils and is not visible here.
(FDR_COV_mcmc,
 FNR_COV_mcmc) = FDR_FNR_COV(cov_0, cov_mcmc.numpy())

In [None]:
# Compare the VI covariance estimate with the true cov_0 via BFM.utils.FDR_FNR_COV.
# Presumably these are the false-discovery / false-negative rates of the estimated nonzero
# pattern; the thresholding rule lives in BFM.utils and is not visible here.
(FDR_COV_VI,
 FNR_COV_VI) = FDR_FNR_COV(cov_0, cov_VI.numpy())

In [None]:
# Label each metric; two bare numbers in the output cannot be told apart.
print("FDR_COV_mcmc:", FDR_COV_mcmc)
print("FNR_COV_mcmc:", FNR_COV_mcmc)

In [None]:
# Label each metric; two bare numbers in the output cannot be told apart.
print("FDR_COV_VI:", FDR_COV_VI)
print("FNR_COV_VI:", FNR_COV_VI)

In [None]:
# Effective sample size of the loading draws, via BFM.utils.ESS.
# ess_B is shown as a 2-D image further down, so presumably ESS reduces over the draw axis
# and returns one value per loading entry — confirm in BFM.utils.
ess_B = ESS(B_sample.numpy())

In [None]:
# Effective sample size of the noise-variance draws, via BFM.utils.ESS.
# NOTE(review): ess_sigma2 is not plotted or inspected in any of the visible cells.
ess_sigma2 = ESS(sigma2_sample.numpy())

In [None]:
# Heat map of the effective sample sizes in ess_B.
# aspect=0.1 compresses the vertical axis, presumably because ess_B has far more rows than columns.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(ess_B, aspect=0.1)
fig.colorbar(im, ax=ax)
ax.set_title("ESS_B")
plt.show()