In [1]:
import wandb
import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO

from pyro.contrib import gp
from pyro.contrib.gp.kernels import Matern52
from pyro.contrib.gp.util import conditional

from IPython.display import clear_output
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import scanpy as sc
adata = sc.read_h5ad("/home/jhaberbe/Projects/spatial-indian-buffet-process/data/16APR2025.h5ad")
adata = adata[adata.obs["cell_type"].eq("Astrocyte") & adata.obs["folder"].eq("99-15")]

  utils.warn_names_duplicates("obs")


In [3]:
import torch
import numpy as np
import pandas as pd

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
    print("CUDA Enabled")

device = torch.device("cpu")

def setup_torch_data(adata, specimen_name: str = "folder"):
    X = torch.tensor(adata[:, (adata.X > 0).mean(axis=0) > 0.05].layers["transcript"])
    size_factor = torch.tensor(np.log((adata.X.sum(axis=1) / adata.X.sum(axis=1).mean())))
    folder = torch.tensor(pd.Categorical(adata.obs[specimen_name]).codes).float()
    return X, size_factor, folder

X, size_factor, folder = setup_torch_data(adata, specimen_name="folder")

X = X.to(device)
size_factor = size_factor.to(device)
folder = folder.to(device)
coordinates = torch.tensor(adata.obs[["x_centroid", "y_centroid"]].values).to(device)
group_assignments = torch.zeros(X.shape[0])

In [4]:
import torch

def select_inducing_points_knn(X: torch.Tensor, M: int) -> torch.Tensor:
    """
    Select M inducing points from X using a farthest-point style KNN heuristic.

    Args:
        X: (N, 2) tensor of 2D coordinates
        M: number of inducing points to select

    Returns:
        (M, 2) tensor of inducing points
    """
    N = X.shape[0]
    inducing_idx = [torch.randint(0, N, (1,)).item()]  # Start with one random point

    for _ in range(1, M):
        dist_matrix = torch.cdist(X[inducing_idx], X)  # (len(inducing_idx), N)
        min_dists, _ = torch.min(dist_matrix, dim=0)   # (N,)
        farthest_idx = torch.argmax(min_dists).item()
        inducing_idx.append(farthest_idx)

    return X[inducing_idx]  # (M, 2)


In [5]:
class SpatialIndianBuffetProcess():

    def __init__(self, adata, n_latent_factors: int = 20, device: str = None, low_rank_approximation_dimension: int = 20, length_scale: int = 100):

        self.feature_names = adata.var_names

        # identify device
        if device == None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else: 
            self.device = device

        # Setup as a tensor, makes life easier, we coerce to a float for ease.
        self.n_latent_factors = torch.tensor(n_latent_factors).float().to(self.device)

        # Low rank dimension of MVN approximmation sampling.
        self.low_rank_approximation_dimension = low_rank_approximation_dimension

        # Determines the degree of correlation between factors due purely based on their distance
        # TODO: Make some heuristic, probably based on co-occurence within some random sasmple.
        self.length_scale = length_scale
    
    def model(self, coordinates, count_matrix, group_assignments):
        N, D = count_matrix.shape
        K = int(self.n_latent_factors.item())
        G = int(torch.max(group_assignments).item() + 1)

        size_factor = torch.log(count_matrix.float().sum(axis=1) / count_matrix.float().sum(axis=1).mean())

        # scaling mu is the strongest impact. 0 cuts a good balance, higher causes K to scale rapidly
        mu = pyro.param("mu", torch.tensor(0.0, device=self.device, dtype=torch.float32))
        tau = pyro.param("tau", torch.tensor(1.0, device=self.device, dtype=torch.float32), constraint=constraints.positive)
        phi = pyro.param("phi", torch.tensor(self.length_scale, device=self.device, dtype=torch.float32), constraint=constraints.positive)

        # Setting up GP (MVN draw).
        kernel = Matern52(input_dim=2, lengthscale=phi)
        cov_matrix = kernel(coordinates)
        scaled_cov_matrix = (1.0 / tau) * cov_matrix
        mean_vec = mu * torch.ones(N, device=self.device, dtype=torch.float32)

        # Sample stick-breaking v_k
        alpha = 1.0
        with pyro.plate("features", K):
            v_k = pyro.sample("v_k", dist.Beta(torch.tensor(1.0, device=self.device), alpha))
            pi_k = torch.cumprod(v_k, dim=0)

        # Sample GP latent features
        with pyro.plate("latent_features", K):
            # Low Rank Adaptation to avoid destroying my device.
            cov_factor = pyro.param("cov_factor", torch.randn(K, N, self.low_rank_approximation_dimension, device=self.device)).to(self.device)
            cov_diag = pyro.param("cov_diag", torch.ones(K, N, device=self.device) * 1e-2, constraint=dist.constraints.positive).to(self.device)

            u_k = pyro.sample(
                "u_k",
                dist.LowRankMultivariateNormal(
                    loc=mean_vec.expand(K, N),       
                    cov_factor=cov_factor,           
                    cov_diag=cov_diag                
                )
            )
        u_k_T = u_k.transpose(0, 1)

        # Sample Z
        with pyro.plate("observations", N):
            z = pyro.sample("z", dist.Bernoulli(probs=pi_k.expand(N, K)).to_event(1))  # [N, K]

        W = pyro.sample(
            "W",
            dist.Normal(torch.zeros(K, D, device=self.device, dtype=torch.float32),
                        torch.ones(K, D, device=self.device, dtype=torch.float32)).to_event(2)
        )

        # Setting up NB draw
        # Multiple folders (when that becomes a problem)
        folder_logit = pyro.param(
            "folder_logit",
            torch.zeros(G, D, device=self.device, dtype=torch.float32)
        )
        r = pyro.sample(
            "r",
            dist.Gamma(torch.full((D,), 2.0, device=self.device, dtype=torch.float32), # 0 makes more sense, right?
                    torch.full((D,), 1.0, device=self.device, dtype=torch.float32)).to_event(1)
        )

        # Compute latent feature contribution to expression.
        # features = z * sigmoid_to_interval(u_k_T) # FIXME: is this interval bounding needed, testing without.
        features = z * u_k_T
        logits = features @ W 
        logits = logits + size_factor.reshape(-1, 1)
        logits = logits + folder_logit[group_assignments.long()]
        # logits = torch.clamp(logits, -15, 15) # FIXME: I don't think clamping is needed here.

        with pyro.plate("data", N):
            # Sample NB.
            pyro.sample("count_matrix", dist.NegativeBinomial(total_count=r, logits=logits).to_event(1), obs=count_matrix)

    def guide(self, coordinates, count_matrix, group_assignments):
        N, D = count_matrix.shape
        K = int(self.n_latent_factors.item())
        size_factor = torch.log(count_matrix.float().sum(axis=1) / count_matrix.float().sum(axis=1).mean())

        # Variational distribution for u_k (latent GPs): mean-field Gaussian
        u_loc = pyro.param(
            "u_loc",
            torch.zeros(K, N, device=self.device),
        )
        u_scale = pyro.param(
            "u_scale",
            0.1 * torch.ones(K, N, device=self.device),
            constraint=constraints.positive
        )
        with pyro.plate("latent_features", K):
            pyro.sample("u_k", dist.Normal(u_loc, u_scale).to_event(1))

        # Variational distribution for stick-breaking IBP: Beta
        v_alpha_q = pyro.param(
            "v_alpha_q",
            torch.ones(K, device=self.device),
            constraint=constraints.positive
        )
        v_beta_q = pyro.param(
            "v_beta_q",
            torch.ones(K, device=self.device),
            constraint=constraints.positive
        )
        with pyro.plate("features", K):
            pyro.sample("v_k", dist.Beta(v_alpha_q, v_beta_q))

        # Variational distribution for W: mean-field Normal
        W_loc = pyro.param(
            "W_loc",
            torch.zeros(K, D, device=self.device)-5,
        )
        W_scale = pyro.param(
            "W_scale",
            0.1 * torch.ones(K, D, device=self.device),
            constraint=constraints.positive
        )
        pyro.sample("W", dist.Normal(W_loc, W_scale).to_event(2))

        # Variational distribution for r: Gamma
        r_alpha_q = pyro.param(
            "r_alpha_q",
            torch.full((D,), 2.0, device=self.device),
            constraint=constraints.positive
        )
        r_beta_q = pyro.param(
            "r_beta_q",
            torch.full((D,), 1.0, device=self.device),
            constraint=constraints.positive
        )
        pyro.sample("r", dist.Gamma(r_alpha_q, r_beta_q).to_event(1))

    def fit(self, coordinates, count_matrix, group_assignments, num_steps=300_000, lr=0.01, clear_param_store = True, wandb_kwargs = {}):
        if clear_param_store == True:
            pyro.clear_param_store()

        wandb.init(**wandb_kwargs, settings=wandb.Settings(_disable_stats=True), reinit=True)

        optimizer = ClippedAdam({"lr": lr, "clip_norm": 5.0})

        svi = SVI(
            model=self.model,
            guide=self.guide,
            optim=optimizer,
            loss=JitTrace_ELBO(num_particles=1),
        )

        try: 
            for step in range(num_steps):
                loss = svi.step(
                    coordinates=coordinates.to(self.device),
                    count_matrix=count_matrix.to(self.device),
                    group_assignments=group_assignments.to(self.device),
                )
                wandb.log({
                    "loss": loss,
                    "mean_logit": pyro.get_param_store()["u_loc"].mean().item(),
                    "feature_sparsity": (pyro.get_param_store()["u_loc"] > 0).float().mean().item(),
                })

                if step % 100 == 0 or step == num_steps - 1:
                    clear_output()
                    print(f"[{step:04d}] ELBO loss: {loss:.2f}")
        except KeyboardInterrupt:
            print("Interrupted Training")
        
        finally:
            wandb.finish()
    
    def return_latent_features(self):
        params = dict(pyro.get_param_store())
        latent_features = pd.DataFrame(
            dict(pyro.get_param_store())["u_loc"].cpu().detach().numpy(),
        ).T
        return latent_features
    
    def return_latent_feature_weights(self):
        weights = pd.DataFrame(
            dict(pyro.get_param_store())["W_loc"].cpu().detach().numpy(),
            columns = self.feature_names
        ).T
        return weights


In [6]:
sibp = SpatialIndianBuffetProcess(adata, device="cpu")

In [7]:
sibp.fit(coordinates, X, group_assignments, num_steps=100_000, clear_param_store=False)

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.
[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run


[73900] ELBO loss: 878703.12
Interrupted Training


MailboxClosedError: 