# Broad Structure

Our sampler is going to have three components to it:

1. Learning the sparse latent matrix
    - We will encode this as draws from a bernoulli probability (logits).
    - We should be able to use a polya gamma construction for this!

2. Learning the loading matrix
    - We will use our polya gamma construction here as well.

3. Dispersion parameter
    - Probably, we will use simple MH type updates.

# Indian Buffet Process Class

I think it should be a container for our latent factors, which will have all the coefficients and stuff. However, the math stuff will definitely be done on this thing. 

In [None]:
import scanpy as sc
adata = sc.read_h5ad("/home/jhaberbe/Data/choroid-plexus/new_annotations.h5ad")
adata = adata[adata.X.sum(axis=1) > 300]
adata = adata[adata.obs["Cell.Subtype"].eq("Macrophage")][::5]
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=1000, subset=True, inplace=True)

X = adata.X.todense()

  adata.uns["hvg"] = {"flavor": flavor}


Actual sampling time...

In [42]:
from typing import *
import numpy as np
from scipy.special import gammaln
from tqdm import tqdm, trange
from pypolyagamma import PyPolyaGamma

In [43]:
import numpy as np
from pypolyagamma import PyPolyaGamma

class PolyaGammaSampler:
    def __init__(self):
        self.pg = PyPolyaGamma(seed=np.random.randint(69420))

    def sample(self, b, c):
        """
        Sample PG(b, c), element-wise over arrays.
        """
        b = np.asarray(b, dtype=np.double)
        c = np.asarray(c, dtype=np.double)
        shape = c.shape

        b_flat = b.flatten()
        c_flat = c.flatten()
        out = np.zeros_like(c_flat)
        self.pg.pgdrawv(b_flat, c_flat, out)

        return out.reshape(shape)

In [None]:
class IndianBuffetProcess:
    """"""

    def __init__(self, counts: np.array, alpha: float = 1.0):
        # Count Matrix, np.array
        self.counts = counts
        self.size_factors = np.log(counts.sum(axis=1) / counts.sum(axis=1).mean())

        # Innovation Parameter
        self.alpha = alpha / np.log(self.counts.shape[0])

        # Latent Factor Matrix.
        self.factors = {}

        # Our prior dispersion parameter, right now it will be 5.0, but this can change.
        self.dispersion = np.ones(self.counts.shape[1]) * 5.0

        # Loading updates stuff for polya gamma sampler.
        self.polya_gamma_sampler = PolyaGammaSampler()
        self.prior_precision = 0.01 * np.ones(self.counts.shape[1])  # Precision vector for A_k


    # This section will deal with factors broadly.
    def add_factor(self, index: Union[None, int]):
        i = 0
        while i in self.factors:
            i += 1

        self.factors[i] = LatentFactors(self.counts)

        if index != None:
            self.factors.add_member(index)

    def remove_factor(self, key: int):
        if key in self.factors:
            self.factors.pop(key)

    def sample_new_factors(self, index):
        """Sample new features for the given index."""
        K_new = np.random.poisson(self.alpha / np.log(self.counts.shape[0]))

        if K_new > 0:

            for _ in range(K_new):

                i = 0
                while i in self.factors:
                    i += 1

                self.factors[i] = LatentFactor(self.counts)
                self.factors[i].add_member(index)

    def remove_empty_factors(self, min_membership=2):
        for key in list(self.factors.keys()):  # <-- make a list copy
            if self.factors[key].Z.sum() <= min_membership:
                self.factors.pop(key)

    # This section will contain all the logic used in our Gibbs Sampling
    def update_latent_membership(self, index: int):
        """
        Resample membership for all latent factors for a single observation.

        Handles the case when no factors currently exist.
        """
        y_n = self.counts[index]         # (D,)
        offset_n = self.size_factors[index]  # scalar

        # Precompute log-mu with all current features (we’ll subtract them later)
        current_log_mu = np.full_like(y_n, offset_n)  # (D,)

        for factor in self.factors.values():
            if factor.Z[index]:
                current_log_mu += factor.A  # sum all current active A_k

        for k, factor in self.factors.items():
            # Temporarily remove k-th factor's contribution
            was_active = factor.Z[index]
            if was_active:
                log_mu_0 = current_log_mu - factor.A
            else:
                log_mu_0 = current_log_mu

            # Include factor
            log_mu_1 = log_mu_0 + factor.A

            # Compute log-likelihoods
            ll_0 = self.negative_binomial_log_likelihood(y_n, log_mu_0, self.dispersion)
            ll_1 = self.negative_binomial_log_likelihood(y_n, log_mu_1, self.dispersion)

            # Prior log-odds: based on number of other assignments
            m_k = factor.Z.sum() - was_active  # count excluding this index
            N = self.counts.shape[0]
            log_prior_ratio = np.log((m_k + 1e-10) / (N - m_k + 1e-10))

            # Posterior logit and probability
            logit_p = log_prior_ratio + (ll_1 - ll_0)
            p = 1 / (1 + np.exp(-logit_p))

            # Sample new assignment
            new_z = np.random.rand() < p

            # Update Z and current_log_mu accordingly
            if new_z and not was_active:
                factor.Z[index] = 1
                current_log_mu += factor.A
            elif not new_z and was_active:
                factor.Z[index] = 0
                current_log_mu -= factor.A

    def update_latent_loadings(self):
        """
        Update the log-loadings A_k for each latent factor using Polya-Gamma augmentation.
        """
        for k, factor in self.factors.items():
            Z_k = factor.Z  # shape (N,)
            active_indices = np.where(Z_k == 1)[0]
            if len(active_indices) == 0:
                continue

            # Construct X (design matrix): binary indicator (N_active x 1)
            X = np.ones((len(active_indices), 1))  # we only have one latent feature per update

            # Construct y: (N_active x D)
            Y = self.counts[active_indices]  # shape (N_active, D)

            # Construct offset: for each sample, subtract other factors' contribution
            offset = np.zeros_like(Y)  # shape (N_active, D)
            for j, other in self.factors.items():
                if j == k:
                    continue
                offset += np.outer(other.Z[active_indices], other.A)

            offset += self.size_factors[active_indices][:, None]

            eta = offset + factor.A  # shape (N_active, D)
            omega = self.polya_gamma_sampler.sample(Y + self.dispersion, eta)

            # Posterior covariance and mean for each gene d
            A_k_new = np.zeros_like(factor.A)
            for d in range(self.counts.shape[1]):
                XWX = (X.T * omega[:, d]) @ X
                precision = XWX + self.prior_precision[d]
                posterior_cov = 1.0 / precision

                # Posterior mean
                update_grad = np.sum(X.T * (Y[:, d] - self.dispersion[d]) / 2)
                prior_contrib = self.prior_precision[d] * 0.0  # mean is 0
                posterior_mean = posterior_cov * (update_grad + prior_contrib)

                A_k_new[d] = np.random.normal(posterior_mean, np.sqrt(posterior_cov))

            factor.A = A_k_new


    def update_dispersion(self, proposal_sigma=1.0):
        """Update each gene's dispersion parameter via simple MH step."""
        for d in range(self.counts.shape[1]):
            r_old = self.dispersion[d]
            r_prop = np.abs(r_old + np.random.normal(0, proposal_sigma))  # reflect at 0

            loglike_old = 0.0
            loglike_prop = 0.0

            for n in range(self.counts.shape[0]):
                # Compute current log-mu for this observation
                log_mu = self.size_factors[n]
                for factor in self.factors.values():
                    if factor.Z[n]:
                        log_mu += factor.A[d]

                y_nd = self.counts[n, d]

                loglike_old += self.negative_binomial_log_likelihood(
                    np.array([y_nd]), np.array([log_mu]), np.array([r_old])
                )
                loglike_prop += self.negative_binomial_log_likelihood(
                    np.array([y_nd]), np.array([log_mu]), np.array([r_prop])
                )

            # Optional prior on dispersion — Gamma(r | a0, b0)
            a0, b0 = 2.0, 0.1
            prior_old = (a0 - 1) * np.log(r_old) - b0 * r_old
            prior_prop = (a0 - 1) * np.log(r_prop) - b0 * r_prop

            log_accept_ratio = (loglike_prop + prior_prop) - (loglike_old + prior_old)
            if np.log(np.random.rand()) < log_accept_ratio:
                self.dispersion[d] = r_prop  # accept

    # Actual Gibbs Sampling Logic.
    def gibbs_sampling(self, n_epochs: int = 1000, verbose = True):
        # For the current epoch.
        for current_epoch in range(n_epochs):

            # 1. Update membership and sample new factors for each observation
            for index in trange(self.counts.shape[0]):
                self.update_latent_membership(index)
                self.sample_new_factors(index)

            # 2. Update loadings for each factor
            self.update_latent_loadings()

            # 3. Update dispersions
            self.update_dispersion()

            # 4. Remove any now-empty factors
            
            self.remove_empty_factors(min_membership=max(2 * np.log(current_epoch), 1))

            if verbose:
                print(f"Epoch {current_epoch}, # of Features = {len(self.factors)}, log-likelihood = {self.log_likelihood()}")
    
    @staticmethod
    def negative_binomial_log_likelihood(counts: np.ndarray, log_mu: np.ndarray, dispersion: np.ndarray) -> float:
        """
        Compute NB log-likelihood for a single observation across D dimensions.

        Parameters
        ----------
        counts : (D,) array
            Observed counts for one sample.
        log_mu : (D,) array
            Log of mean counts.
        dispersion : (D,) array
            Dispersion parameters for NB.

        Returns
        -------
        float
            Sum of log-likelihoods over D dimensions.
        """
        mu = np.exp(log_mu)
        r = dispersion
        y = counts

        return np.sum(
            gammaln(y + r) - gammaln(r) - gammaln(y + 1) +
            r * np.log(r / (r + mu)) +
            y * np.log(mu / (r + mu))
        )

    def log_likelihood(self) -> float:
        """
        Compute the total log-likelihood of the entire dataset under the model.

        Sums over all samples n and genes d:
        log p(y_nd | Z, A, dispersion, size_factors)

        Returns
        -------
        float
            Total log-likelihood scalar.
        """
        total_ll = 0.0
        N, D = self.counts.shape

        for n in range(N):
            # Compute log_mu for sample n
            log_mu = np.full(D, self.size_factors[n])
            for factor in self.factors.values():
                if factor.Z[n]:
                    log_mu += factor.A

            # Counts for sample n
            y_n = self.counts[n]

            # Sum Negative Binomial log-likelihood for this sample
            total_ll += self.negative_binomial_log_likelihood(y_n, log_mu, self.dispersion)

        return total_ll


In [84]:
class LatentFactor:

    def __init__(self, counts: np.array):
        # Store a reference to the counts just for fun.
        self.counts = counts

        # Latent Membership
        # How should this be initialized for logits construction.
        # Simplest case, we'll make this binary at first.
        # ChatGPT says Albert & Chib (1993) is good?
        self.Z = np.zeros(counts.shape[0])

        # Latent Loadings
        # My first instinct is gaussian. Starting off with this.
        self.A = np.zeros(counts.shape[1])
    
    def add_member(self, index):
        """Just add the index"""
        # What else you gotta say?
        self.Z[index] = 1

In [87]:
ibp = IndianBuffetProcess(np.array(X)[::5])
ibp.gibbs_sampling(n_epochs=100)

100%|██████████| 808/808 [00:00<00:00, 1235.73it/s]
  A_k_new[d] = np.random.normal(posterior_mean, np.sqrt(posterior_cov))
  self.remove_empty_factors(min_membership=max(2 * np.log(current_epoch), 1))


Epoch 0, # of Features = 8, log-likelihood = -736567.5552646157


  p = 1 / (1 + np.exp(-logit_p))
100%|██████████| 808/808 [00:00<00:00, 1096.64it/s]


Epoch 1, # of Features = 13, log-likelihood = -342346.3627194388


100%|██████████| 808/808 [00:01<00:00, 603.14it/s]


Epoch 2, # of Features = 20, log-likelihood = -322879.3937728793


100%|██████████| 808/808 [00:01<00:00, 438.37it/s]


Epoch 3, # of Features = 13, log-likelihood = -321660.23025927827


100%|██████████| 808/808 [00:01<00:00, 628.75it/s]


Epoch 4, # of Features = 15, log-likelihood = -304573.23373357527


100%|██████████| 808/808 [00:01<00:00, 610.53it/s]


Epoch 5, # of Features = 11, log-likelihood = -300422.0988268419


100%|██████████| 808/808 [00:01<00:00, 681.15it/s] 


Epoch 6, # of Features = 13, log-likelihood = -289709.88834416895


100%|██████████| 808/808 [00:01<00:00, 603.67it/s] 


Epoch 7, # of Features = 14, log-likelihood = -282554.49387919565


100%|██████████| 808/808 [00:01<00:00, 638.87it/s]


Epoch 8, # of Features = 7, log-likelihood = -290704.5307119689


100%|██████████| 808/808 [00:00<00:00, 836.80it/s] 


Epoch 9, # of Features = 7, log-likelihood = -272658.50756252673


100%|██████████| 808/808 [00:00<00:00, 1078.55it/s]


Epoch 10, # of Features = 7, log-likelihood = -268796.003308564


100%|██████████| 808/808 [00:01<00:00, 780.02it/s] 


Epoch 11, # of Features = 7, log-likelihood = -266110.33207483473


100%|██████████| 808/808 [00:00<00:00, 815.19it/s] 


Epoch 12, # of Features = 8, log-likelihood = -265419.9644792972


100%|██████████| 808/808 [00:01<00:00, 726.19it/s] 


Epoch 13, # of Features = 6, log-likelihood = -273276.0028518394


100%|██████████| 808/808 [00:00<00:00, 1047.76it/s]


KeyboardInterrupt: 

In [88]:
for k in ibp.factors:
    print(k, ibp.factors[k].Z.sum())

0 747.0
2 30.0
8 7.0
17 7.0
12 6.0
1 11.0
3 7.0
4 1.0
5 1.0
6 2.0
7 1.0
9 2.0
10 1.0
11 1.0
13 1.0
14 2.0
15 1.0
16 1.0
18 1.0
19 1.0
20 1.0


In [100]:
import numpy as np

def top5_abs_indices(arr: np.ndarray) -> np.ndarray:
    """
    Returns the indices of the 5 elements with largest absolute values in `arr`.
    If `arr` has fewer than 5 elements, returns all indices sorted by abs value descending.
    """
    n = min(5, arr.size)
    # argsort returns indices sorted ascending, so take last n and reverse
    return np.argsort(arr)[-n:][::-1]


The first group is an intercept basically, this ones a gimme

In [101]:
adata.var_names[top5_abs_indices(ibp.factors[0].A)]

Index(['VCAN', 'MARCO', 'HSP90AA1', 'CD83', 'F13A1'], dtype='object')

This one shows that its for real. I've seen PADI2, HSPs and FTL before, so I'm fairly certain this has some meaning.

In [102]:
adata.var_names[top5_abs_indices(ibp.factors[2].A)]

Index(['HSP90AA1', 'FTL', 'NDRG1', 'HSPH1', 'PADI2'], dtype='object')

This is a smaller one, and just as an example, AQP9 is primarily expressed in vasculature and macrophages. Given the associated VCAN signature, you might call this something like "vasculature associated macrophages".

In [103]:
adata.var_names[top5_abs_indices(ibp.factors[1].A)]

Index(['VCAN', 'AC037198.1', 'TFRC', 'ALPK3', 'AQP9'], dtype='object')

So our sampler now works! and gives us some kind of unique and biologically interesting factors! It will take a generation to learn anything, but thats my fault.