In [67]:
import torch
import pathlib
import numpy as np
import scanpy as sc
import pandas as pd
import anndata as ad

adata = sc.read_h5ad("/home/jhaberbe/Projects/nb-gaussian-processes/data/16APR2025.h5ad")
subset = adata[adata.obs["cell_type"].eq("L2/3 IT") & adata.obs["folder"].eq("05-27")]
coordinates = torch.tensor(subset.obs[["x_centroid", "y_centroid"]].values).float().to("cuda")
counts = torch.tensor(subset.layers["transcript"]).float().to("cuda")
size_factor = torch.log(counts.sum(axis=1) / counts.sum(axis=1).mean()).float().to("cuda")

  utils.warn_names_duplicates("obs")


In [None]:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.nn.module import PyroSample
from pyro.contrib.gp.kernels import Matern52
from pyro.contrib.gp.models import VariationalSparseGP
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO
from pyro.optim import Adam
from sklearn.cluster import KMeans

import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, hidden_dims=[64]):
        super().__init__()
        layers = []
        dims = [latent_dim] + hidden_dims
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dims[-1], output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)

class SpatialNegativeBinomialGP:
    def __init__(self, coordinates, counts, size_factor, num_inducing=100, learning_rate=1e-3, jitter=1e-3, seed=0, device="cuda"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.coordinates = coordinates.to(self.device)
        self.counts = counts.to(self.device)
        self.size_factor = size_factor.to(self.device)
        self.num_inducing = num_inducing
        self.learning_rate = learning_rate
        self.jitter = jitter

        # Init randomness
        pyro.set_rng_seed(seed)
        pyro.clear_param_store()

        # Inducing inputs via KMeans
        self.inducing_inputs = self._init_inducing_inputs()

        # Kernel setup
        self.kernel = Matern52(input_dim=self.coordinates.size(1))
        self.kernel.lengthscale = pyro.param("lengthscale", torch.tensor(1.0, device=self.device), constraint=constraints.positive)
        self.kernel.variance = pyro.param("variance", torch.tensor(1.0, device=self.device), constraint=constraints.positive)

        # GP model
        self.model = VariationalSparseGP(
            X=self.coordinates,
            y=None,
            kernel=self.kernel,
            Xu=self.inducing_inputs,
            likelihood=None,
            latent_shape=torch.Size([16]),
            whiten=True,
            jitter=self.jitter,
        )

        self.decoder = Decoder(latent_dim=16, output_dim=counts.shape[1]).to(self.device)

        # Optimizer + SVI
        self.optimizer = Adam({"lr": self.learning_rate})
        self.svi = SVI(self.model_fn, self.guide_fn, self.optimizer, loss=JitTrace_ELBO())

    def _init_inducing_inputs(self):
        kmeans = KMeans(n_clusters=self.num_inducing).fit(self.coordinates.cpu().numpy())
        return torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=self.device)

    def model_fn(self, coordinates, counts, size_Factor):
        N, K = counts.shape

        # Feature-wise dispersion parameter (K-dimensional)
        r = pyro.param("r", torch.ones(counts.shape[1], device=self.device), constraint=constraints.positive)

        f_loc, f_var = self.model(coordinates)
        z_sample = pyro.distributions.Normal(f_loc, (f_var + 1e-5).sqrt()).rsample()
        logits = self.decoder(z_sample.T)  # or .T if shape mismatch
        logits = size_factor.unsqueeze(-1) + logits

        with pyro.plate("cells", counts.shape[0], dim=-2):
            with pyro.plate("features", counts.shape[1], dim=-1):
                r_broadcast = r.repeat(N, K)  # broadcast r to match logits/counts
                pyro.sample("counts", dist.NegativeBinomial(total_count=r_broadcast, logits=logits), obs=counts)

    def guide_fn(self, coordinates, counts, size_factor):
        return self.model.guide()

    def train(self, num_steps=1000, verbose=True):
        for i in range(num_steps):
            loss = self.svi.step(self.coordinates, self.counts, self.size_factor)
            if verbose and i % 10 == 0:
                r_val = pyro.param("r").item()
                print(f"[{i}] ELBO: {-loss:.2f} | r={r_val:.3f}")

    def predict(self, Xnew=None):
        Xnew = self.coordinates if Xnew is None else Xnew.to(self.device)
        with torch.no_grad():
            f_loc, f_var = self.model(Xnew)
            return f_loc, f_var

    def get_dispersion(self):
        return pyro.param("r").item()

    def get_kernel_params(self):
        return {
            "lengthscale": self.kernel.lengthscale.item(),
            "variance": self.kernel.variance.item(),
        }


In [127]:
snbgp = SpatialNegativeBinomialGP(coordinates=coordinates, counts=counts, size_factor=size_factor)
snbgp.train()

RuntimeError: The size of tensor a (133956) must match the size of tensor b (366) at non-singleton dimension 1
     Trace Shapes:            
      Param Sites:            
                 r         366
                Xu     100   2
             u_loc      16 100
      u_scale_tril 16  100 100
kernel.lengthscale            
   kernel.variance            
     Sample Sites:            
        cells dist           |
             value    5121   |
     features dist           |
             value     366   |

# Lets be broader

In [None]:
import pyro
import pyro.distributions as dist
import torch
import pyro.contrib.gp as gp
from pyro.contrib.gp.models import VariationalSparseGP, VariationalGP
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO
from torch import nn

class VAE_GP(PyroModule):
    def __init__(self, coordinates, latent_dim=16, device=None):
        super().__init__()
        self.latent_dim = latent_dim

        # Set the device (CPU or GPU)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move coordinates to the same device
        coordinates = torch.tensor(coordinates).float().to(self.device)

        # Define the GP kernel and sparse GP
        self.kernel = gp.kernels.RBF(coordinates.shape[1])
        self.likelihood = gp.likelihoods.Gaussian()

        # Variational sparse GP setup
        self.sparse_gp = gp.models.VariationalGP(coordinates, y=None, kernel=self.kernel, latent_shape=torch.Size([latent_dim]), likelihood=self.likelihood)

        # Decoder (negative binomial likelihood)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Softplus()  # Ensure positive output for count data
        ).to(self.device)  # Move decoder to the same device

    def model(self, coordinates, counts):
        # Ensure coordinates and counts are on the correct device
        coordinates = coordinates.to(self.device)
        counts = counts.to(self.device)

        # Sample latent function (GP output) at the coordinates
        f = self.sparse_gp(coordinates)

        # Decoder to generate counts using a negative binomial
        mu = self.decoder(f)  # The mean of the negative binomial
        theta = torch.exp(mu)  # theta is exp(mu) for the negative binomial

        with pyro.plate("data", len(coordinates)):
            # Negative binomial distribution for counts
            pyro.sample("counts", dist.NegativeBinomial(total_count=1.0, logits=mu), obs=counts)

    def guide(self, coordinates, counts):
        pass

    def train(self, coordinates, counts, num_steps=1000):
        # Ensure coordinates and counts are on the correct device
        coordinates = torch.tensor(coordinates).float().to(self.device)
        counts = torch.tensor(counts).float().to(self.device)

        adam = pyro.optim.Adam({"lr": -1.001})
        svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())

        for step in range(num_steps):
            loss = svi.step(coordinates, counts)
            if step % 99 == 0:
                print(f"Step {step}: Loss = {loss}")



  from .autonotebook import tqdm as notebook_tqdm


In [159]:
vae_gp = VAE_GP(coordinates)

vae_gp.train(coordinates, counts)

  coordinates = torch.tensor(coordinates).float().to(self.device)
  coordinates = torch.tensor(coordinates).float().to(self.device)
  counts = torch.tensor(counts).float().to(self.device)


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
               Trace Shapes:             
                Param Sites:             
             sparse_gp.f_loc      16 5121
      sparse_gp.f_scale_tril 16 5121 5121
sparse_gp.kernel.lengthscale             
   sparse_gp.kernel.variance             
               Sample Sites:             

In [162]:
dict(pyro.get_param_store())

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.56 GiB. GPU 0 has a total capacity of 23.51 GiB of which 303.81 MiB is free. Process 207331 has 702.00 MiB memory in use. Including non-PyTorch memory, this process has 21.49 GiB memory in use. Of the allocated memory 19.67 GiB is allocated by PyTorch, and 1.34 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)