In [9]:
import numpy as np
from numpy.random import gamma

def polya_gamma_sample(b, c, trunc=100):
    """
    Sample from PG(b, c) using the infinite sum approximation (truncated).

    Parameters:
        b (float): Shape parameter (often 1).
        c (float): Tilt parameter (often logit).
        trunc (int): Number of terms in the series to approximate.
    
    Returns:
        float: A single sample from PG(b, c)
    """
    pi = np.pi
    out = 0.0
    c = np.abs(c)

    for n in range(1, trunc + 1):
        lambda_n = (n - 0.5)**2 * pi**2 + 0.25 * c**2
        out += gamma(b, 1.0) / lambda_n

    return 0.5 * out

In [99]:
import anndata as ad
adata = ad.read_h5ad("/home/jhaberbe/Data/choroid-plexus/new_annotations.h5ad")
adata = adata[adata.obs["cell_type"].eq("Macrophage")][::10]
X = adata.X.todense()
X = np.array(X)

In [None]:
import numpy as np
from tqdm import trange
from scipy.special import gammaln

# ---- Input: X assumed provided ----
# X: (N, D) Negative Binomial count data
K = 10
N, D = X.shape

# ---- Parameters and latent state initialization ----
size_factors = np.log(X.sum(axis=1) / X.sum(axis=1).mean())  # (N,)
intercept = np.random.normal(size=D)                         # (D,)
dispersion = np.ones(D)                                      # (D,)
Z = np.random.rand(N, K) > 0.5                                # (N, K)
A = np.random.normal(size=(K, D))                             # (K, D)
alpha = 1.0                                                   # IBP concentration

# ---- Utility function ----
def nb_log_likelihood(x, mu, r):
    term1 = gammaln(x + r) - gammaln(r) - gammaln(x + 1)
    term2 = r * np.log(r / (np.exp(mu) + r))
    term3 = x * (mu - np.log(np.exp(mu) + r))
    return (term1 + term2 + term3).sum()

# ---- Gibbs sampling ----
for gibbs_iter in trange(100):
    # --- Sample Z[n, k] ---
    for k in range(K):
        for n in range(N):
            Z[n, k] = 0  # temporarily set to 0

            m_k = Z[:, k].sum()
            prior_z1 = (m_k + alpha / K) / (N + alpha / K)
            prior_z0 = 1.0 - prior_z1

            phi_0 = size_factors[n] + intercept + Z[n] @ A
            phi_1 = phi_0 + A[k]

            ll_0 = nb_log_likelihood(X[n], phi_0, dispersion)
            ll_1 = nb_log_likelihood(X[n], phi_1, dispersion)

            logit = ll_1 - ll_0 + np.log(prior_z1) - np.log(prior_z0)
            p = 1.0 / (1.0 + np.exp(-logit))
            Z[n, k] = np.random.rand() < p

    # --- Sample A[k] via PG augmentation ---
    phi = size_factors[:, None] + intercept[None, :] + Z @ A
    omega = polya_gamma_sample(X + dispersion, phi)
    kappa = X - 0.5 * dispersion[None, :]

    for k in range(K):
        Z_k = Z[:, k]                          # (N,)
        phi_wo_k = phi - np.outer(Z_k, A[k])  # Remove A[k] contribution

        mu_k = np.zeros(D)
        sigma_k = np.zeros(D)

        for d in range(D):
            w = omega[:, d]
            residual = kappa[:, d] - w * phi_wo_k[:, d]
            precision = np.sum(w * Z_k**2)
            sigma2 = 1.0 / precision if precision > 1e-12 else 1e12
            mu = sigma2 * np.sum(Z_k * residual)
            mu_k[d] = mu
            sigma_k[d] = np.sqrt(sigma2)

        A[k] = np.random.normal(loc=mu_k, scale=sigma_k)

    def propose_r(r_old, step_size=0.1):
        log_r_new = np.log(r_old) + np.random.normal(scale=step_size)
        return np.exp(log_r_new)

    def log_nb_likelihood_column(x, r, phi):
        mu = np.exp(phi)
        term1 = gammaln(x + r) - gammaln(r) - gammaln(x + 1)
        term2 = r * np.log(r / (mu + r))
        term3 = x * (phi - np.log(mu + r))
        return term1 + term2 + term3

    # Inside Gibbs loop:
    for d in range(D):
        r_old = dispersion[d]
        r_new = propose_r(r_old)

        ll_old = log_nb_likelihood_column(X[:, d], r_old, phi[:, d]).sum()
        ll_new = log_nb_likelihood_column(X[:, d], r_new, phi[:, d]).sum()

        prior_old = -r_old  # exponential prior with rate=1
        prior_new = -r_new

        log_accept_ratio = ll_new + prior_new - ll_old - prior_old

        if np.log(np.random.rand()) < log_accept_ratio:
            dispersion[d] = r_new


  p = 1.0 / (1.0 + np.exp(-logit))
  term2 = r * np.log(r / (np.exp(mu) + r))
  term2 = r * np.log(r / (np.exp(mu) + r))
  term3 = x * (mu - np.log(np.exp(mu) + r))
  term3 = x * (mu - np.log(np.exp(mu) + r))
  3%|▎         | 3/100 [09:40<5:15:08, 194.93s/it]