In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from tqdm.notebook import trange, tqdm

In [3]:
%matplotlib notebook
np.set_printoptions(precision=4, linewidth=500, threshold=500, suppress=True)

In [4]:
from numpy import kron, eye as I, exp, trace as tr, diag
from numpy.linalg import inv, eigh

In [5]:
from utils import vec, mat, get_chain_graph, get_random_graph, matrix_derivative_numerical, mat_pow, diagi

In [6]:
import stan

In [7]:
import nest_asyncio
nest_asyncio.apply()

In [8]:
def get_params(T, N, gamma, beta, random_graph=False, seed=True, p=0.5):
    
    if seed:
        np.random.seed(1)
    
    Y = np.random.normal(size=(N, T)) 
    S = np.random.choice([0, 1], p=[1 - p, p], replace=True, size=(N, T))
    S_ = 1 - S
    Y = Y * S

    K = np.exp(-(np.linspace(0, 3, T)[:, None] - np.linspace(0, 3, T)[None, :]) ** 2) + 1e-4 * I(T)
    
    if random_graph:
        _, LT = get_random_graph(T)
        _, LN = get_random_graph(N)
        
    else:
        _, LT = get_chain_graph(T)
        _, LN = get_chain_graph(N)

    lamLT, UT = eigh(LT)
    lamLN, UN = eigh(LN)
    lamK, V = eigh(K)
        
    lamT = exp(-beta * lamLT) ** 2
    lamN = exp(-beta * lamLN) ** 2

    HT = UT @ diag(lamT) @ UT.T
    HN = UN @ diag(lamN) @ UN.T
    
    J = np.outer(lamN, lamT) / (np.outer(lamN, lamT) + gamma)
    G = np.outer(lamN, lamT)
    
    return T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G

In [9]:
T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G = get_params(T=15, N=10, beta=1, gamma=1.4, random_graph=False)

In [10]:
G

array([[1.    , 0.9163, 0.7076, 0.4658, 0.2662, 0.1353, 0.063 , 0.0278, 0.0121, 0.0053, 0.0025, 0.0013, 0.0007, 0.0005, 0.0004],
       [0.8222, 0.7534, 0.5818, 0.383 , 0.2189, 0.1113, 0.0518, 0.0229, 0.0099, 0.0044, 0.002 , 0.001 , 0.0006, 0.0004, 0.0003],
       [0.4658, 0.4268, 0.3296, 0.217 , 0.124 , 0.063 , 0.0294, 0.013 , 0.0056, 0.0025, 0.0012, 0.0006, 0.0003, 0.0002, 0.0002],
       [0.1923, 0.1762, 0.1361, 0.0896, 0.0512, 0.026 , 0.0121, 0.0053, 0.0023, 0.001 , 0.0005, 0.0002, 0.0001, 0.0001, 0.0001],
       [0.063 , 0.0578, 0.0446, 0.0294, 0.0168, 0.0085, 0.004 , 0.0018, 0.0008, 0.0003, 0.0002, 0.0001, 0.    , 0.    , 0.    ],
       [0.0183, 0.0168, 0.013 , 0.0085, 0.0049, 0.0025, 0.0012, 0.0005, 0.0002, 0.0001, 0.    , 0.    , 0.    , 0.    , 0.    ],
       [0.0053, 0.0049, 0.0038, 0.0025, 0.0014, 0.0007, 0.0003, 0.0001, 0.0001, 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],
       [0.0017, 0.0016, 0.0012, 0.0008, 0.0005, 0.0002, 0.0001, 0.    , 0.    , 0.    , 0.    , 0

In [18]:
graph_code2 = """

data {
  int<lower=1> N;         
  int<lower=1> T;         
  real<lower=0> gamma; 
  
  matrix[N, T] Y;
  matrix<lower=0, upper=1>[N, T] S;
  matrix<lower=0, upper=1>[N, T] G_h;
  matrix[N, N] UN;
  matrix[T, T] UTT;

}


parameters {
  matrix[N, T] Z;
}


model {
  to_vector(Z) ~ normal(0, 1 / sqrt(gamma));
  to_vector(Y) ~ normal(to_vector(S .* (UN * (G_h .* Z) * UTT)), 1);
}

"""

In [19]:
GG = G ** 0.5
GG[0, 0] = 1 - 1e-6

graph_data = {"N": N, 'T': T, 'gamma': gamma, 'Y': Y, 'S': S, 'G_h': GG, 'UN': UN, 'UTT': UT.T}

posterior = stan.build(graph_code2, data=graph_data)


[36mBuilding:[0m 0.3s
[1A[0J[36mBuilding:[0m 0.4s
[1A[0J[36mBuilding:[0m 0.6s
[1A[0J[36mBuilding:[0m 0.7s
[1A[0J[36mBuilding:[0m 0.8s
[1A[0J[36mBuilding:[0m 0.9s
[1A[0J[36mBuilding:[0m 1.0s
[1A[0J[36mBuilding:[0m 1.1s
[1A[0J[36mBuilding:[0m 1.2s
[1A[0J[36mBuilding:[0m 1.3s
[1A[0J[36mBuilding:[0m 1.4s
[1A[0J[36mBuilding:[0m 1.5s
[1A[0J[36mBuilding:[0m 1.6s
[1A[0J[36mBuilding:[0m 1.7s
[1A[0J[36mBuilding:[0m 1.8s
[1A[0J[36mBuilding:[0m 1.9s
[1A[0J[36mBuilding:[0m 2.0s
[1A[0J[36mBuilding:[0m 2.1s
[1A[0J[36mBuilding:[0m 2.2s
[1A[0J[36mBuilding:[0m 2.4s
[1A[0J[36mBuilding:[0m 2.5s
[1A[0J[36mBuilding:[0m 2.6s
[1A[0J[36mBuilding:[0m 2.7s
[1A[0J[36mBuilding:[0m 2.8s
[1A[0J[36mBuilding:[0m 2.9s
[1A[0J[36mBuilding:[0m 3.0s
[1A[0J[36mBuilding:[0m 3.1s
[1A[0J[36mBuilding:[0m 3.2s
[1A[0J[36mBuilding:[0m 3.3s
[1A[0J[36mBuilding:[0m 3.4s
[1A[0J[36mBuilding:[0m 3.5s
[1A[0J[36mBui

In file included from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/model/model_header.hpp:7,
                 from /home/ed/.cache/httpstan/4.7.2/models/tdhwq6xt/model_tdhwq6xt.cpp:2:
/home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/io/dump.hpp: In member function ‘virtual std::vector<std::complex<double> > stan::io::dump::vals_c(const string&) const’:
  694 |       for (comp_iter = 0, real_iter = 0; real_iter < val_r->second.first.size();
      |                                          ~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
  707 |              real_iter < val_i->second.first.size();
      |              ~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/model/indexing.hpp:5,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/model/model_header.hpp:17,
                 from /h

[1A[0J[36mBuilding:[0m 16.5s
[1A[0J[36mBuilding:[0m 16.6s
[1A[0J[36mBuilding:[0m 16.7s
[1A[0J[36mBuilding:[0m 16.8s
[1A[0J[36mBuilding:[0m 16.9s
[1A[0J[36mBuilding:[0m 17.0s
[1A[0J[36mBuilding:[0m 17.1s
[1A[0J[36mBuilding:[0m 17.2s
[1A[0J[36mBuilding:[0m 17.3s


/home/ed/.cache/httpstan/4.7.2/models/tdhwq6xt/model_tdhwq6xt.cpp: In instantiation of ‘void model_tdhwq6xt_namespace::model_tdhwq6xt::transform_inits_impl(VecVar&, VecI&, VecVar&, std::ostream*) const [with VecVar = std::vector<double, std::allocator<double> >; VecI = std::vector<int>; stan::require_std_vector_t<T>* <anonymous> = 0; stan::require_vector_like_vt<std::is_integral, VecI>* <anonymous> = 0; std::ostream = std::basic_ostream<char>]’:
/home/ed/.cache/httpstan/4.7.2/models/tdhwq6xt/model_tdhwq6xt.cpp:675:69:   required from here
  494 |       int pos__ = std::numeric_limits<int>::min();
      |           ^~~~~


[1A[0J[36mBuilding:[0m 17.4s
[1A[0J[36mBuilding:[0m 17.5s
[1A[0J[36mBuilding:[0m 17.6s
[1A[0J[36mBuilding:[0m 17.7s
[1A[0J[36mBuilding:[0m 17.8s
[1A[0J[36mBuilding:[0m 17.9s
[1A[0J[36mBuilding:[0m 18.0s
[1A[0J[36mBuilding:[0m 18.1s
[1A[0J[36mBuilding:[0m 18.2s
[1A[0J[36mBuilding:[0m 18.3s
[1A[0J[36mBuilding:[0m 18.4s
[1A[0J[36mBuilding:[0m 18.5s
[1A[0J[36mBuilding:[0m 18.6s
[1A[0J[36mBuilding:[0m 18.7s
[1A[0J[36mBuilding:[0m 18.8s
[1A[0J[36mBuilding:[0m 19.0s
[1A[0J[36mBuilding:[0m 19.1s
[1A[0J[36mBuilding:[0m 19.2s
[1A[0J[36mBuilding:[0m 19.3s
[1A[0J[36mBuilding:[0m 19.4s
[1A[0J[36mBuilding:[0m 19.5s
[1A[0J[36mBuilding:[0m 19.6s
[1A[0J[36mBuilding:[0m 19.7s
[1A[0J[36mBuilding:[0m 19.8s
[1A[0J[36mBuilding:[0m 19.9s
[1A[0J[36mBuilding:[0m 20.0s
[1A[0J[36mBuilding:[0m 20.1s
[1A[0J[36mBuilding:[0m 20.2s
[1A[0J[36mBuilding:[0m 20.3s
[1A[0J[36mBuilding:[0m 20.4s
[1A[0J[

/home/ed/.cache/httpstan/4.7.2/models/tdhwq6xt/model_tdhwq6xt.cpp: In instantiation of ‘void model_tdhwq6xt_namespace::model_tdhwq6xt::write_array_impl(RNG&, VecR&, VecI&, VecVar&, bool, bool, std::ostream*) const [with RNG = boost::random::additive_combine_engine<boost::random::linear_congruential_engine<unsigned int, 40014, 0, 2147483563>, boost::random::linear_congruential_engine<unsigned int, 40692, 0, 2147483399> >; VecR = Eigen::Matrix<double, -1, 1>; VecI = std::vector<int>; VecVar = std::vector<double, std::allocator<double> >; stan::require_vector_like_vt<std::is_floating_point, VecR>* <anonymous> = 0; stan::require_vector_like_vt<std::is_integral, VecI>* <anonymous> = 0; stan::require_std_vector_vt<std::is_floating_point, VecVar>* <anonymous> = 0; std::ostream = std::basic_ostream<char>]’:
/home/ed/.cache/httpstan/4.7.2/models/tdhwq6xt/model_tdhwq6xt.cpp:606:7:   required from ‘void model_tdhwq6xt_namespace::model_tdhwq6xt::write_array(RNG&, Eigen::Matrix<double, -1, 1>&, Eig

[1A[0J[36mBuilding:[0m 22.6s
[1A[0J[36mBuilding:[0m 22.7s
[1A[0J[36mBuilding:[0m 22.8s
[1A[0J[36mBuilding:[0m 22.9s
[1A[0J[36mBuilding:[0m 23.0s
[1A[0J[36mBuilding:[0m 23.1s
[1A[0J[36mBuilding:[0m 23.2s
[1A[0J[36mBuilding:[0m 23.3s
[1A[0J[36mBuilding:[0m 23.4s


In file included from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/prim/err/check_not_nan.hpp:5,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/prim/err/check_2F1_converges.hpp:5,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/prim/err.hpp:4,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/rev/core/profiling.hpp:9,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/rev/core.hpp:53,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math/rev.hpp:8,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/stan/math.hpp:19,
                 from /home/ed/miniconda3/envs/p310/lib/python3.10/site-packages/httpstan/include/st

[1A[0J[36mBuilding:[0m 23.5s
[1A[0J[36mBuilding:[0m 23.6s
[1A[0J[36mBuilding:[0m 23.7s
[1A[0J[36mBuilding:[0m 23.8s
[1A[0J[36mBuilding:[0m 23.9s
[1A[0J[36mBuilding:[0m 24.0s
[1A[0J[36mBuilding:[0m 24.1s
[1A[0J[36mBuilding:[0m 24.2s
[1A[0J[36mBuilding:[0m 24.3s
[1A[0J[36mBuilding:[0m 24.5s
[1A[0J[36mBuilding:[0m 24.6s
[1A[0J[36mBuilding:[0m 24.7s
[1A[0J[36mBuilding:[0m 24.8s
[1A[0J[36mBuilding:[0m 24.9s
[1A[0J[36mBuilding:[0m 25.0s
[1A[0J[36mBuilding:[0m 25.1s
[1A[0J[36mBuilding:[0m 25.2s
[1A[0J[36mBuilding:[0m 25.3s
[1A[0J[36mBuilding:[0m 25.4s
[1A[0J[36mBuilding:[0m 25.5s
[1A[0J[36mBuilding:[0m 25.6s
[1A[0J[36mBuilding:[0m 25.7s
[1A[0J[36mBuilding:[0m 25.8s
[1A[0J[36mBuilding:[0m 25.9s
[1A[0J[36mBuilding:[0m 26.0s
[1A[0J[36mBuilding:[0m 26.1s
[1A[0J[36mBuilding:[0m 26.2s
[1A[0J[36mBuilding:[0m 26.3s
[1A[0J[36mBuilding:[0m 26.4s
[1A[0J[36mBuilding:[0m 26.5s
[1A[0J[

[32mBuilding:[0m 49.1s, done.
[36mMessages from [0m[36;1mstanc[0m[36m:[0m


In [29]:
fit = posterior.sample(num_chains=4, num_samples=800)

[36mSampling:[0m   0%
[1A[0J[36mSampling:[0m   3% (200/7200)
[1A[0J[36mSampling:[0m   7% (500/7200)
[1A[0J[36mSampling:[0m  17% (1200/7200)
[1A[0J[36mSampling:[0m  42% (3000/7200)
[1A[0J[36mSampling:[0m  64% (4600/7200)
[1A[0J[36mSampling:[0m  85% (6100/7200)
[1A[0J[36mSampling:[0m 100% (7200/7200)
[1A[0J[32mSampling:[0m 100% (7200/7200), done.
[36mMessages received during sampling:[0m
  Gradient evaluation took 8.7e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.87 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 7.4e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.74 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 7.4e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.74 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 0.000106 seconds
  1000 transitions using 10 leapf

In [30]:
fit['Z'].shape

(10, 15, 3200)

In [23]:
Sigma_true_inv = diag(vec(S)) + gamma * kron(inv(HT), inv(HN))
Sigma_true = inv(Sigma_true_inv)
Omega_true = mat(diag(Sigma_true), like=J)

In [33]:
Fs = np.array([(UN @ (G ** 0.5 * fit['Z'][:, :, i]) @ UT.T) for i in range(800)])

In [35]:
fig, ax = plt.subplots(ncols=2)

ax[0].imshow(np.var(Fs, axis=0))
ax[1].imshow(Omega_true)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fc5820cf9d0>

In [11]:
graph_code = """

functions {
  real graph_signal_lpdf(matrix Z, matrix Y, matrix S, matrix G_h, matrix UN, matrix UTT, real gamma, int N, int T) {
    matrix[N, T] F_ = Y - S .* (UN * (G_h .* Z) * UTT);
    return -0.5 * (sum(F_ .* F_) + gamma * sum(Z .* Z));
  }
}


data {
  int<lower=1> N;         
  int<lower=1> T;         
  real<lower=0> gamma; 
  
  matrix[N, T] Y;
  matrix<lower=0, upper=1>[N, T] S;
  matrix<lower=0, upper=1>[N, T] G_h;
  matrix[N, N] UN;
  matrix[T, T] UTT;

}


parameters {
  matrix[N, T] Z;
}


model {
  target += graph_signal_lpdf(Z | Y, S, G_h, UN, UTT, gamma, N, T);       // prior log-density
}

"""

# graph_data = {"N": N, 'T': T, 'gamma': gamma, 'Y': Y, 'S': S, 'G_h': G ** 0.5, 'UN': UN, 'UTT': UT.T}

# posterior = stan.build(graph_code, data=graph_data)
# fit = posterior.sample(num_chains=4, num_samples=1000)


In [36]:
from PIL import Image

im = Image.open("pic.jpg").convert('L')

In [37]:
Y_clean = np.array(im)

N, T = Y_clean.shape

T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G = get_params(T=T, N=N, gamma=5, beta=5, random_graph=False)

S = np.round(UN @ (G * (UN.T @ S @ UT)) @ UT.T + 2e-2)
S_ = 1 - S

Y_clean = (Y_clean - Y_clean.mean()) / Y_clean.std()


Y_noisy = Y_clean + np.random.normal(size=Y_clean.shape)


Y_partial = Y_noisy.copy()
Y_partial[S_.astype(bool)] = np.nan


Y = Y_noisy.copy()
Y[S_.astype(bool)] = 0

# F, its = get_sol(Y, S)


In [38]:
GG = G ** 0.5
GG[0, 0] = 1 - 1e-6

graph_data = {"N": N, 'T': T, 'gamma': gamma, 'Y': Y, 'S': S, 'G_h': GG, 'UN': UN, 'UTT': UT.T}

posterior = stan.build(graph_code2, data=graph_data)


[32mBuilding:[0m found in cache, done.
[36mMessages from [0m[36;1mstanc[0m[36m:[0m


In [39]:
fit = posterior.sample(num_chains=4, num_samples=200)

[36mSampling:[0m   0%
[1A[0J[36mSampling:[0m   0% (1/4800)
[1A[0J[36mSampling:[0m   0% (2/4800)
[1A[0J[36mSampling:[0m   0% (3/4800)
[1A[0J[36mSampling:[0m   0% (4/4800)


In [44]:
T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G = get_params(T=15, N=10, beta=1, gamma=1.4, random_graph=False)

In [76]:
def fz(Z):
    F = UN @ (G ** 0.5 * Z) @ UT.T
    return tr((Y - S * F).T @ (Y - S * F)) + gamma * tr(Z.T @ Z)

def fq(Q):
    F = UN @ Q @ UT.T
    return tr((Y - S * F).T @ (Y - S * F)) + gamma * tr(Q.T @ (G ** -1 * Q))

def matrix_derivative_numerical(f, W):
    
    out = np.zeros_like(W)
    dx = 0.001
    
    T, M = W.shape
    
    for i in range(T):
        for j in range(M):
            
            W_ = W.copy()
            _W = W.copy()
            W_[i, j] += dx / 2
            _W[i, j] -= dx / 2
            out[i, j] = (f(W_) - f(_W)) / dx
    
    return out


def derivz(Z):
    return 2 * gamma * Z - 2 * UN.T @ Y @ UT * G ** 0.5 + 2 * (UN.T @ (S * (UN @ (G ** 0.5 * Z) @ UT.T)) @ UT) * G ** 0.5

def derivq(Q):
    return 2 * G ** -1 * Q - 2 * UN.T @ (S * Y) @ UT + 2 * UN.T @ (S * (UN @ Q @ UT.T)) @ UT

In [77]:
matrix_derivative_numerical(fq, Y)

array([[       0.3307,       -0.3223,       -0.5376,       -7.2048,        8.8829,      -46.4487,        2.2593,      -78.4015,        2.4048,        0.5132,       -0.0068,    -4578.3876,        1.8   ,    -2268.7342,     8672.1615],
       [      -0.8951,        1.5788,        0.2526,       -0.4081,        9.5519,       -0.3672,       64.4751,      111.6269,      141.468 ,      576.5339,     -940.0919,       -0.3793,    -4424.3235,    -1926.0803,        5.8603],
       [      -0.0022,       -4.0415,        0.456 ,      -12.8625,       -2.6034,        0.583 ,       -1.5786,       50.8153,      827.2043,       -1.7707,     -464.7762,    -4231.6242,        0.1503,       -1.1231,      833.7915],
       [     -13.703 ,       -1.2919,       45.3149,        0.0104,       35.6299,       31.3638,      -81.1088,     -599.2522,     -422.2851,        2.2082,     3444.5166,     9693.8845,    18828.9669,     8771.8916,    35210.1495],
       [     -34.5901,        3.089 ,       33.5144,       -0.16

In [78]:
derivq(Y)

array([[       0.3307,       -0.3223,        0.0595,       -5.3621,        6.2822,      -32.8438,        2.2593,      -56.5145,        2.4048,        0.5132,       -0.0068,    -3270.5202,        1.8   ,    -1620.6199,     6194.6876],
       [      -0.8951,        1.5788,        0.2526,       -0.4081,        7.4217,       -0.3672,       46.8076,       80.0974,      100.9165,      411.8071,     -671.7035,       -0.3793,    -3159.9751,    -1376.2409,        5.8603],
       [      -0.0022,       -3.2979,        0.456 ,       -9.7465,       -2.6034,        0.7437,       -1.5786,       36.3462,      590.7865,       -1.7707,     -331.8662,    -3021.9462,        0.1503,       -1.1231,      595.4576],
       [     -11.0526,       -1.2919,       32.9657,        0.0104,       25.983 ,       22.1352,      -57.8604,     -428.3936,     -301.7276,        2.2082,     2459.8109,     6923.688 ,    13449.2175,     6265.2733,    25150.3984],
       [     -25.017 ,        3.089 ,       24.3163,       -0.16