In [1]:
import numpy as np
import math
from scipy.spatial import distance_matrix
import stan

import nest_asyncio
nest_asyncio.apply()
del nest_asyncio

In [2]:
n = 100
np.random.seed(400)
n_pairs = math.comb(n, 2)
# TODO: generalize to multiple covariates
locs  = np.random.uniform(-1, 1, 2*n).reshape(n, 2)
distances = distance_matrix(locs, locs)

In [3]:
# paired features
x1 = np.random.normal(size = n)
x2 = np.random.normal(size = n)
# first column is intercept term
x = np.ones((n_pairs, 2))
counter = 0
for i in range(n - 1):
    for j in range(i+1, n):
        x[counter, 1] = abs(x1[i] - x2[j])
        counter += 1

In [5]:
# individual features
z0 = np.random.normal(size=(n, 2))
z = np.zeros((n_pairs,2))
counter = 0
for i in range(n-1):
    for j in range(i+1, n):
        z[counter] = z0[i] + z0[j]
        counter += 1

In [6]:
# pairs matrix
counter = 0
w = np.zeros((n_pairs, n))
for i in range(n-1):
    for j in range(i+1, n):
        w[counter, i] = 1
        w[counter, j] = 1
        counter = counter + 1

In [7]:
# pair coefficients
beta = np.array([1, -0.3])
# individual covariates
gamma = np.array([2.1, -1.75])
# noise variance
sigma2_eps = 0.1
# baseline spatial variance
tau2 = 2
# individual-effect variance
sigma2_zeta = 0.01

# exponential kernel scale
ell = 0.2
# kernel matrix
K = tau2 * np.exp(-distances / ell)
# spatial factor of GP
eta = np.random.multivariate_normal(np.zeros(n), K)
# incorporate individual effects
theta = eta + np.random.normal(0, np.sqrt(sigma2_zeta), n)
# mean centering
theta = theta - theta.mean()
# mean function
mu = np.dot(x, beta) + np.dot(z, gamma) # + np.dot(w, theta)


In [8]:
# log patristic distances
y = np.random.normal(loc=mu, scale=np.sqrt(sigma2_eps))

In [9]:
data = {
    "N": n,
    "M": n_pairs,
    "P": 2,
    "Q": 2,
    "y": y,
    "X": x,
    "Z": z,
    "W": w,
    "S": distances
}

In [22]:
with open("../models/patristic_distance.stan", "r") as file:
    model_code = file.read()

In [28]:
posterior = stan.build(model_code, data=data, random_seed=1)

Building...

In file included from /Users/pchatha/Library/Caches/httpstan/4.10.1/models/m2ktnixj/model_m2ktnixj.cpp:2:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/stan/model/model_header.hpp:4:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/stan/math.hpp:19:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/stan/math/rev.hpp:4:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/stan/math/prim/fun/Eigen.hpp:23:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/Eigen/Sparse:26:
In file included from /Users/pchatha/epi/gene-pairs/.venv/lib/python3.11/site-packages/httpstan/include/Eigen/SparseCore:61:
      Index count = 0;
            ^
In file included from /Users/pchatha/Library/Caches/httpstan/4.10.1/models/m2ktnixj/model_m2





Building: 22.1s, done.Messages from stanc:
    beginning with # are deprecated and this syntax will be removed in Stan
    2.33.0. Use // to begin line comments; this can be done automatically
    using the auto-format flag to stanc
    beginning with # are deprecated and this syntax will be removed in Stan
    2.33.0. Use // to begin line comments; this can be done automatically
    using the auto-format flag to stanc
    beginning with # are deprecated and this syntax will be removed in Stan
    2.33.0. Use // to begin line comments; this can be done automatically
    using the auto-format flag to stanc
    beginning with # are deprecated and this syntax will be removed in Stan
    2.33.0. Use // to begin line comments; this can be done automatically
    using the auto-format flag to stanc
    0.01 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    is a gamma or inverse-gamma distribution with parameters

In [29]:
fit = posterior.sample(num_chains=1, num_samples=200)

Sampling:   0%
Sampling:   0% (1/1200)
Sampling:   8% (100/1200)
Sampling:  17% (200/1200)
Sampling:  25% (300/1200)
Sampling:  33% (400/1200)
Sampling:  42% (500/1200)
Sampling:  50% (600/1200)
Sampling:  58% (700/1200)
Sampling:  67% (800/1200)
Sampling:  75% (900/1200)

In [27]:
fit["gamma"][-20:].mean(1)

array([ 2.10338061, -1.75422689])

In [25]:
fit.param_names

('beta', 'gamma', 'sigma2_eps', 'tau2', 'sigma2_zeta', 'ell', 'zeta', 'eta')