# Broad Structure

Our sampler has three components:

1. Learning the sparse latent matrix
    - Each entry is a Bernoulli draw whose probability is parameterized through logits.
    - We should be able to use a Pólya–Gamma augmentation for this.

2. Learning the loading matrix
    - We will use the Pólya–Gamma augmentation here as well.

3. Dispersion parameter
    - We will probably use simple Metropolis–Hastings (MH) updates.

# Indian Buffet Process Class

This class acts as a container for the latent factors, each of which holds its membership vector and loading coefficients. The sampling math (the Gibbs updates) is also implemented on this class.

In [None]:
import numpy as np
import scanpy as sc

adata = sc.read_h5ad("/home/jhaberbe/Data/choroid-plexus/new_annotations.h5ad")

# Keep cells with more than 300 total counts. On a sparse matrix `X.sum(axis=1)`
# is an (N, 1) matrix, so flatten it into a plain 1-D boolean mask.
total_counts = np.asarray(adata.X.sum(axis=1)).ravel()
adata = adata[total_counts > 300]

# Keep every 5th macrophage. `.copy()` turns the AnnData view into a real object,
# so highly_variable_genes(inplace=True) can modify it without the
# "modifying a view" warning.
adata = adata[adata.obs["Cell.Subtype"].eq("Macrophage")][::5].copy()
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=1000, subset=True, inplace=True)

# Dense copy of the count matrix used by the sampler below.
X = adata.X.todense()

  adata.uns["hvg"] = {"flavor": flavor}


Now for the actual sampling.

In [29]:
from typing import *
import numpy as np
from scipy.special import gammaln
from tqdm import tqdm, trange

In [35]:
class PolyaGammaSampler:
    """Approximate Pólya-Gamma PG(b, c) sampler.

    Uses the truncated infinite-convolution representation

        omega = 1 / (2 pi^2) * sum_{k=1}^{K} g_k / ((k - 1/2)^2 + (c / (2 pi))^2),
        g_k ~ Gamma(b, 1),

    with K = ``truncation``. Truncation drops the (positive) tail of the series,
    so draws are very slightly too small on average.
    """

    # Upper bound on the number of gamma variates materialised at once in `sample`,
    # keeping memory bounded for large (N, D) inputs.
    _MAX_DRAWS_PER_CHUNK = 1 << 20

    def __init__(self):
        pass

    def sample_scalar(self, b, c, truncation=100):
        """Draw a single PG(b, c) variate using ``truncation`` series terms."""
        k = np.arange(1, truncation + 1)
        gammas = np.random.gamma(b, 1.0, size=truncation)
        denominators = (k - 0.5) ** 2 + (c / (2 * np.pi)) ** 2
        return 0.5 * float(np.sum(gammas / denominators)) / (np.pi**2)

    def sample(self, b_array, c_array, truncation=100):
        """Draw element-wise PG(b, c) variates.

        Parameters
        ----------
        b_array, c_array : array_like
            Shape and tilting parameters; broadcast against each other.
        truncation : int
            Number of series terms per draw.

        Returns
        -------
        np.ndarray
            Float array with the broadcast shape of ``b_array`` and ``c_array``.
        """
        b, c = np.broadcast_arrays(
            np.asarray(b_array, dtype=float), np.asarray(c_array, dtype=float)
        )
        out_shape = b.shape
        b_flat = b.ravel()
        c_flat = c.ravel()

        # (k - 1/2)^2 does not depend on the element, so compute it once.
        k_terms = (np.arange(1, truncation + 1) - 0.5) ** 2
        totals = np.empty(b_flat.size, dtype=float)

        chunk = max(1, self._MAX_DRAWS_PER_CHUNK // max(truncation, 1))
        for start in range(0, b_flat.size, chunk):
            stop = start + chunk
            b_chunk = b_flat[start:stop, None]  # (m, 1), broadcasts across terms
            c_chunk = c_flat[start:stop, None]
            gammas = np.random.gamma(b_chunk, 1.0, size=(b_chunk.shape[0], truncation))
            denominators = k_terms + (c_chunk / (2 * np.pi)) ** 2
            totals[start:stop] = np.sum(gammas / denominators, axis=1)

        return (0.5 * totals / (np.pi**2)).reshape(out_shape)


pg = PolyaGammaSampler()

In [40]:
class IndianBuffetProcess:
    """Gibbs sampler for a negative-binomial latent-feature model with an IBP-style prior.

    For observation ``n`` (a row of ``counts``) and column ``d``:

        log mu[n, d] = size_factors[n] + sum_k Z_k[n] * A_k[d]
        counts[n, d] ~ NB(mean = mu[n, d], dispersion = dispersion[d])

    Each entry of ``self.factors`` carries a membership vector ``Z`` (one entry
    per row) and a log-loading vector ``A`` (one entry per column).
    """

    def __init__(self, counts: np.array, alpha: float = 1.0):
        # Count matrix: rows are observations (N), columns are genes (D).
        self.counts = counts
        # Log library-size offsets; exp(size_factors) has mean 1 across rows.
        totals = counts.sum(axis=1)
        self.size_factors = np.log(totals / totals.mean())

        # Innovation parameter, scaled by 1 / log(N).
        self.alpha = alpha / np.log(self.counts.shape[0])

        # Latent factors keyed by small non-negative ints.
        self.factors = {}

        # Per-gene NB dispersion, initialised to 5.0.
        self.dispersion = np.ones(self.counts.shape[1]) * 5.0

        # Pólya-Gamma sampler for the loading updates, and the per-gene precision
        # of the zero-mean Gaussian prior on each A_k.
        self.polya_gamma_sampler = PolyaGammaSampler()
        self.prior_precision = 0.01 * np.ones(self.counts.shape[1])

    # ------------------------------------------------------------------ factors

    def _next_free_key(self) -> int:
        """Return the smallest non-negative int not yet used as a factor key."""
        key = 0
        while key in self.factors:
            key += 1
        return key

    def add_factor(self, index: Union[None, int] = None) -> int:
        """Create an empty factor, optionally with observation ``index`` as a member.

        Returns the key under which the new factor was stored.
        """
        key = self._next_free_key()
        self.factors[key] = LatentFactor(self.counts)
        if index is not None:
            self.factors[key].add_member(index)
        return key

    def remove_factor(self, key: int):
        """Remove the factor stored under ``key``; missing keys are ignored."""
        self.factors.pop(key, None)

    def sample_new_factors(self, index):
        """Draw a Poisson number of new factors, each containing only ``index``."""
        # NOTE(review): self.alpha is already divided by log(N) in __init__, so the
        # rate here is alpha / log(N)**2. The textbook IBP Gibbs rate is alpha / N;
        # kept unchanged until the intended scaling is confirmed.
        n_new = np.random.poisson(self.alpha / np.log(self.counts.shape[0]))
        for _ in range(n_new):
            self.add_factor(index)

    def remove_empty_factors(self):
        """Drop every factor with no members."""
        # Collect keys first: deleting while iterating a dict raises RuntimeError.
        empty_keys = [key for key, factor in self.factors.items() if factor.Z.sum() == 0]
        for key in empty_keys:
            del self.factors[key]

    # ---------------------------------------------------------------- likelihood

    def _log_mu_matrix(self) -> np.ndarray:
        """Return the (N, D) float matrix of log-means under the current factors."""
        D = self.counts.shape[1]
        log_mu = np.asarray(self.size_factors, dtype=float)[:, None] + np.zeros((1, D))
        for factor in self.factors.values():
            log_mu += np.outer(factor.Z != 0, factor.A)
        return log_mu

    @staticmethod
    def _nb_log_likelihood_terms(counts, log_mu, dispersion) -> np.ndarray:
        """Element-wise NB log-pmf (mean/dispersion form); arguments broadcast."""
        y = np.asarray(counts, dtype=float)
        log_mu = np.asarray(log_mu, dtype=float)
        r = np.asarray(dispersion, dtype=float)
        log_r = np.log(r)
        # log(r + mu) computed in log space, so a large log_mu cannot overflow exp().
        log_r_plus_mu = np.logaddexp(log_r, log_mu)
        return (
            gammaln(y + r) - gammaln(r) - gammaln(y + 1)
            + r * (log_r - log_r_plus_mu)
            + y * (log_mu - log_r_plus_mu)
        )

    @staticmethod
    def negative_binomial_log_likelihood(counts: np.ndarray, log_mu: np.ndarray, dispersion: np.ndarray) -> float:
        """
        Compute the NB log-likelihood summed over all elements.

        Parameters
        ----------
        counts : array
            Observed counts (e.g. one row of the count matrix, shape (D,)).
        log_mu : array
            Log of the mean counts; broadcasts against ``counts``.
        dispersion : array
            NB dispersion parameters r; broadcasts against ``counts``.

        Returns
        -------
        float
            Sum of r*log(r/(r+mu)) + y*log(mu/(r+mu)) + log-binomial terms.
        """
        return float(np.sum(
            IndianBuffetProcess._nb_log_likelihood_terms(counts, log_mu, dispersion)
        ))

    def log_likelihood(self) -> float:
        """
        Compute the total NB log-likelihood of the whole count matrix.

        Sums log p(y_nd | Z, A, dispersion, size_factors) over all n and d.

        Returns
        -------
        float
            Total log-likelihood.
        """
        return float(np.sum(
            self._nb_log_likelihood_terms(self.counts, self._log_mu_matrix(), self.dispersion)
        ))

    # -------------------------------------------------------------- Gibbs steps

    def update_latent_membership(self, index: int):
        """
        Resample every factor's membership bit for observation ``index``.

        For factor k, the row's NB log-likelihood with and without A_k is combined
        with the prior odds m_k / (N - m_k), where m_k counts the members of k
        other than ``index``. Does nothing when no factors exist.
        """
        y_n = self.counts[index]
        N = self.counts.shape[0]

        # Float accumulator. np.full_like(y_n, ...) would inherit an integer dtype
        # from integer counts, truncating the offset and rejecting `+= factor.A`.
        current_log_mu = np.full(np.shape(y_n), self.size_factors[index], dtype=float)
        for factor in self.factors.values():
            if factor.Z[index]:
                current_log_mu += factor.A

        for factor in self.factors.values():
            was_active = bool(factor.Z[index])
            log_mu_0 = current_log_mu - factor.A if was_active else current_log_mu
            log_mu_1 = log_mu_0 + factor.A

            ll_0 = self.negative_binomial_log_likelihood(y_n, log_mu_0, self.dispersion)
            ll_1 = self.negative_binomial_log_likelihood(y_n, log_mu_1, self.dispersion)

            # IBP prior odds; the 1e-10 keeps the log finite when m_k is 0 or N.
            m_k = factor.Z.sum() - was_active
            log_prior_ratio = np.log((m_k + 1e-10) / (N - m_k + 1e-10))

            logit_p = log_prior_ratio + (ll_1 - ll_0)
            # Logistic sigmoid written with tanh: no exp() overflow for extreme logits.
            p = 0.5 * (1.0 + np.tanh(0.5 * logit_p))
            new_z = np.random.rand() < p

            if new_z and not was_active:
                factor.Z[index] = 1
                current_log_mu += factor.A
            elif was_active and not new_z:
                factor.Z[index] = 0
                current_log_mu -= factor.A

    def update_latent_loadings(self):
        """
        Resample each factor's log-loadings A_k via Pólya-Gamma augmentation.

        NB(mean mu, r) is NB(r, p) with logit(p) = psi = log mu - log r. Given
        omega ~ PG(y + r, psi), the likelihood in A_k[d] is Gaussian:

            precision = sum_n omega[n, d] + prior_precision[d]
            mean      = sum_n (kappa[n, d] - omega[n, d] * c[n, d]) / precision

        with kappa = (y - r) / 2 and c = psi - A_k, i.e. size factor, other
        factors' loadings and -log r. Only rows with Z_k == 1 contribute. The
        prior on A_k[d] is N(0, 1 / prior_precision[d]).
        """
        D = self.counts.shape[1]
        log_r = np.log(self.dispersion)

        for k, factor in self.factors.items():
            active = np.flatnonzero(factor.Z == 1)
            if active.size == 0:
                continue

            Y = np.asarray(self.counts[active], dtype=float)  # (N_active, D)

            # Log-mean contribution of everything except factor k.
            offset = np.asarray(self.size_factors[active], dtype=float)[:, None] + np.zeros((1, D))
            for j, other in self.factors.items():
                if j != k:
                    offset += np.outer(other.Z[active], other.A)

            # Part of the NB logit that does not involve A_k.
            psi_offset = offset - log_r
            omega = self.polya_gamma_sampler.sample(Y + self.dispersion, psi_offset + factor.A)

            kappa = 0.5 * (Y - self.dispersion)
            precision = omega.sum(axis=0) + self.prior_precision
            posterior_mean = (kappa - omega * psi_offset).sum(axis=0) / precision  # prior mean is 0

            factor.A = posterior_mean + np.random.standard_normal(D) / np.sqrt(precision)

    def update_dispersion(self, proposal_sigma=1.0):
        """Update each gene's dispersion with an independent Metropolis-Hastings step.

        The proposal is a random walk reflected at 0. It is symmetric, so only the
        likelihood and the Gamma(a0=2, rate b0=0.1) prior enter the acceptance ratio.
        """
        D = self.counts.shape[1]
        log_mu = self._log_mu_matrix()

        r_old = self.dispersion.copy()
        r_prop = np.abs(r_old + np.random.normal(0.0, proposal_sigma, size=D))

        loglike_old = self._nb_log_likelihood_terms(self.counts, log_mu, r_old).sum(axis=0)
        loglike_prop = self._nb_log_likelihood_terms(self.counts, log_mu, r_prop).sum(axis=0)

        a0, b0 = 2.0, 0.1
        prior_old = (a0 - 1) * np.log(r_old) - b0 * r_old
        prior_prop = (a0 - 1) * np.log(r_prop) - b0 * r_prop

        log_accept_ratio = (loglike_prop + prior_prop) - (loglike_old + prior_old)
        accept = np.log(np.random.rand(D)) < log_accept_ratio
        self.dispersion[accept] = r_prop[accept]

    def gibbs_sampling(self, n_epochs: int = 1000, verbose = True):
        """Run ``n_epochs`` sweeps of the sampler.

        Each sweep:
        (1) resamples memberships and proposes new factors for every observation;
        (2) resamples loadings;
        (3) resamples dispersions;
        (4) drops factors with no members.

        ``verbose`` controls the progress bar and the per-sweep log-likelihood print.
        """
        for current_epoch in range(n_epochs):
            for index in trange(self.counts.shape[0], disable=not verbose):
                self.update_latent_membership(index)
                self.sample_new_factors(index)

            self.update_latent_loadings()
            self.update_dispersion()
            self.remove_empty_factors()

            if verbose:
                print(f"Epoch {current_epoch}, log-likelihood = {self.log_likelihood()}")


In [41]:
class LatentFactor:
    """A single latent feature.

    Holds one membership indicator per row of ``counts`` (``Z``) and one
    log-loading per column of ``counts`` (``A``). Both start at zero: the
    factor has no members and no effect.
    """

    def __init__(self, counts: np.array):
        n_rows, n_cols = counts.shape[0], counts.shape[1]

        # Kept as a reference only; nothing in this class reads it.
        self.counts = counts

        # Float 0/1 membership indicators; every observation starts outside.
        # NOTE(review): a latent-logit (Albert & Chib 1993 style) initialisation
        # was considered; a plain all-zeros start is used for now.
        self.Z = np.zeros(n_rows)

        # Per-column log-loadings, starting at zero (Gaussian prior assumed elsewhere).
        self.A = np.zeros(n_cols)

    def add_member(self, index):
        """Mark observation ``index`` as belonging to this factor."""
        self.Z[index] = 1

In [None]:
# Fit the model to a dense ndarray copy of the count matrix `X` for a single sweep.
ibp = IndianBuffetProcess(counts=np.array(X), alpha=1.0)
ibp.gibbs_sampling(n_epochs=1, verbose=True)