# Indian Buffet Process

In [25]:
IndianBuffetProcess

__main__.IndianBuffetProcess

In [33]:
import sys
sys.path.append("../../src/crp")
from table import DirichletMultinomialTable, ChineseRestaurantTable, NegativeBinomialTable

In [68]:
import pandas as pd

class NegativeBinomialTable(ChineseRestaurantTable):
    """
    Represents a table in a Negative Binomial model.

    Attributes:
        data (np.ndarray): A 2D array storing the table's data.
        members (set): A set of unique indices representing the table's members.
        alpha (np.ndarray): A 1D array representing the shape parameters.
        beta (np.ndarray): A 1D array representing the rate parameters.
        reference_total (float): The total count of self.data.
    """

    def __init__(self, data: np.ndarray):
        """
        Initializes a NegativeBinomialTable object with the given data.

        Args:
            data (np.ndarray): A 2D array to be stored in the table.

        Raises:
            TypeError: If the data is not a numpy array.
        """
        if not isinstance(data, np.ndarray):
            raise TypeError("Data must be a numpy array")
        self.data = np.array(data)
        self.members = set()
        self.alpha = np.ones(self.data.shape[1])  # prior shape
        self.beta = np.ones(self.data.shape[1])
        self.reference_total = np.mean(np.sum(self.data, axis=1))

    def add_member(self, index: int):
        """
        Adds a member to the table at the specified index.

        Args:
            index (int): The index at which the member is to be added.

        Raises:
            ValueError: If the index is not a valid index.
        """
        if index < 0:
            raise ValueError("Index must be a non-negative integer")
        if index not in self.members:
            self.members.add(index)
            self.alpha += self.data[index]
            self.beta += 1  # One new data point

    def remove_member(self, index: int):
        """
        Removes a member from the table at the specified index.

        Args:
            index (int): The index at which the member is to be removed.

        Raises:
            ValueError: If the index is not a valid index.
        """
        if index < 0:
            raise ValueError("Index must be a non-negative integer")
        if index in self.members:
            self.members.remove(index)
            self.alpha -= self.data[index]
            self.beta -= 1

    def _gamma_poisson_log_likelihood(self, count: np.ndarray, alpha: np.ndarray, beta: np.ndarray) -> float:
        """
        Calculates the log likelihood of the Negative Binomial model.

        Args:
            count (np.ndarray): A 1D array representing the count.
            alpha (np.ndarray): A 1D array representing the shape parameters.
            beta (np.ndarray): A 1D array representing the rate parameters.

        Returns:
            float: The log likelihood of the Negative Binomial model.
        """
        count = np.asarray(count).reshape(-1)
        alpha = np.asarray(alpha).reshape(-1)
        beta = np.asarray(beta).reshape(-1)

        # Compute size factor from total count vs. mean total count of self.data
        total = np.sum(count)
        reference_total = self.reference_total
        size_factor = total / reference_total if reference_total > 0 else 1.0
        log_sf = np.log(size_factor)

        # Gamma-Poisson log-likelihood with offset
        term1 = gammaln(count + alpha)
        term2 = -gammaln(count + 1)
        term3 = -gammaln(alpha)
        term4 = alpha * np.log(beta / (beta + np.exp(log_sf)))
        term5 = count * np.log(np.exp(log_sf) / (beta + np.exp(log_sf)))

        return np.sum(term1 + term2 + term3 + term4 + term5)

    def return_parameters(self, index: int = None) -> pd.DataFrame:
        """
        Returns the parameters of the table.

        Args:
            index (int, optional): The index at which to return the parameters. Defaults to None.

        Returns:
            pd.DataFrame: A pandas DataFrame containing the shape and rate parameters.
        """
        parameters = pd.DataFrame({
            "log_mean": self.log_mean(),
            "dispersion": self.dispersion()
        })

        if index is not None:
            parameters.index = index
            return parameters
        else:
            return parameters

    def log_mean(self):
        return np.log(self.alpha) - np.log(self.beta)

    def dispersion(self):
        return self.alpha  # shape parameter r

    def mean(self):
        return self.alpha / self.beta

    def log_likelihood(self, index: int, posterior: bool = False) -> float:
        """
        Calculates the log likelihood of the table.

        Args:
            index (int): The index of the table.
            posterior (bool, optional): A flag indicating whether to calculate the posterior. Defaults to False.

        Returns:
            float: The log likelihood of the table.
        """
        x = self.data[index]
        if posterior:
            alpha = self.alpha + x
            beta = self.beta + 1
        else:
            alpha = self.alpha
            beta = self.beta

        return self._gamma_poisson_log_likelihood(x, alpha, beta)

    def predict(self, count: np.ndarray) -> float:
        """
        Makes predictions using the table's data.

        Args:
            count (np.ndarray): A 1D array representing the count.

        Returns:
            float: The prediction.
        """
        return self._gamma_poisson_log_likelihood(count, self.alpha, self.beta)

# Load some count data

In [204]:
import numpy as np
import scanpy as sc

adata = sc.read_h5ad("/home/jhaberbe/Data/choroid-plexus/new_annotations.h5ad")
adata = adata[adata.obs["Cell.Subtype"].eq("Macrophage")]
adata = adata[adata.X.sum(axis=1) > 300]
sc.pp.highly_variable_genes(adata, flavor="seurat_v3")

counts = adata[:, adata.var.highly_variable].X.todense()
size_factors = np.log(counts.sum(axis=1) / counts.sum(axis=1).mean())

  adata.uns["hvg"] = {"flavor": flavor}


# Our inference machinery

First, we learn the base distribution, which we'll use as our intercept. We then pull that out. We will be performing updates to classes iteratively.

Now, our goal is to use these latent features to then learn the individual features. Each latent feature will be a coefficient, we then update based on class membership.

In [303]:
import torch
import torch.nn as nn
from torch.distributions import NegativeBinomial

class LatentFeature:
    def __init__(self, X, device):
        self.members = set()
        self.log_mu = nn.Parameter(torch.zeros(X.shape[1], device=device))  # On GPU

    def add_member(self, index):
        self.members.add(index)

    def remove_member(self, index):
        self.members.discard(index)

    def num_members(self):
        return len(self.members)

class IBP:
    def __init__(self, X, alpha, device=None):
        self.device = device or X.device
        self.data = X.to(self.device)  # Ensure data is on GPU
        self.N, self.D = X.shape
        self.alpha = alpha

        self.size_factor = (X.sum(dim=1) / X.sum(dim=1).mean()).log().to(self.device)

        self.latent_features = {}
        self.membership = {i: set() for i in range(self.N)}

        self.log_mu_intercept = nn.Parameter(torch.zeros(self.D, device=self.device))
        self.log_disp = nn.Parameter(torch.ones(self.D, device=self.device))

        self.null_ll_cache = {}
        self.optimizer = torch.optim.Adam([self.log_mu_intercept, self.log_disp], lr=0.1)

    def add_class(self):
        k = 0
        while k in self.latent_features:
            k += 1
        self.latent_features[k] = LatentFeature(self.data, device=self.device)
        return k

    def compute_logits(self, sample_index):
        logits = self.size_factor[sample_index] + self.log_mu_intercept
        for k in self.membership[sample_index]:
            logits += self.latent_features[k].log_mu
        return logits

    def log_likelihood(self, sample_index):
        logits = self.compute_logits(sample_index)
        x = self.data[sample_index]
        theta = self.log_disp.exp()
        logit_param = logits - (theta + logits.exp()).log()
        nb = NegativeBinomial(total_count=theta, logits=logit_param)
        return nb.log_prob(x).sum()

    def fit_intercept(self, n_iter=100):
        for _ in trange(n_iter):
            self.optimizer.zero_grad()
            loss = 0
            for i in range(self.N):
                logits = self.size_factor[i] + self.log_mu_intercept
                theta = self.log_disp.exp()
                x = self.data[i]
                logit_param = logits - (theta + logits.exp()).log()
                nb = NegativeBinomial(total_count=theta, logits=logit_param)
                loss -= nb.log_prob(x).sum()
            loss.backward()
            self.optimizer.step()

    def assignment(self, index):
        x = self.data[index]
        self.null_ll_cache[index] = self.log_likelihood(index)

        # Remove index from current features
        to_remove = list(self.membership[index])
        for k in to_remove:
            self.membership[index].remove(k)
            self.latent_features[k].remove_member(index)

            # Delete latent feature if empty
            if self.latent_features[k].num_members() == 0:
                del self.latent_features[k]
        
        base_ll = self.null_ll_cache[index]

        # Reconsider assignment to existing features
        existing_keys = list(self.latent_features.keys())  # avoid mutation during loop
        for k in existing_keys:
            self.membership[index].add(k)
            self.latent_features[k].add_member(index)

            proposed_ll = self.log_likelihood(index)
            log_p = proposed_ll - base_ll
            p = torch.sigmoid(log_p)

            if torch.rand(1, device=self.device).item() < p:
                base_ll = proposed_ll
            else:
                self.membership[index].remove(k)
                self.latent_features[k].remove_member(index)

                if self.latent_features[k].num_members() == 0:
                    del self.latent_features[k]

        # Sample new features
        lambda_new = self.alpha / self.N
        num_new = torch.poisson(torch.tensor(lambda_new, device=self.device)).item()
        for _ in range(int(num_new)):
            new_k = self.add_class()
            self.latent_features[new_k].add_member(index)
            self.membership[index].add(new_k)

        # Optional: Update parameters locally
        self.optimizer.zero_grad()
        loss = -self.log_likelihood(index)
        loss.backward()
        self.optimizer.step()


In [None]:
ibp = IBP(X=torch.tensor(X, dtype=torch.float32).to("cuda"), alpha=1.0, device="cuda")
ibp.fit_intercept()

for step in range(100):
    for i in trange(X.shape[0]):
        ibp.assignment(i)

  ibp = IBP(X=torch.tensor(X, dtype=torch.float32).to("cuda"), alpha=1.0, device="cuda")
 11%|â–ˆ         | 11/100 [00:06<00:52,  1.71it/s]

#### Explicit formula for the value of p(Z)

In [164]:
from scipy.special import gammaln
def feature_probability(feature, total = 1000):
    structure_term = (feature.alpha / feature.n_classes)

    term_1 = (np.log(structure_term) + gammaln(len(feature.members) + structure_term)) 
    term_2 = gammaln(total - len(feature.members) + 1)
    term_3 = gammaln(total + 1 + structure_term)
    
    return term_1 + term_2 - term_3

def total_probability(feature_dict):
    return np.sum([
        feature_probability(feature_dict[key]) 
        for key in feature_dict 
    ])


In [168]:
total_probability(latent_features)

np.float64(-6439.465591024762)