In [1]:
import numpy as np
import scanpy as sc

DATA_PATH = "/home/jhaberbe/Projects/using_parameters_instead/data/16APR2025.h5ad"

adata = sc.read_h5ad(DATA_PATH)
# Loading this file triggers anndata's duplicate-obs-names warning; make the
# names unique so index-based lookups/joins on adata.obs are unambiguous.
adata.obs_names_make_unique()

# log1p maps an area of 0 to 0 and compresses large values; the downstream
# models treat 0 as "no staining / no droplets" (presumably areas are >= 0).
adata.obs["log_oil_red_o_area"] = np.log1p(adata.obs["oil_red_o_area"])
adata.obs["log_lipid_droplet_area"] = np.log1p(adata.obs["lipid_droplet_area"])

  utils.warn_names_duplicates("obs")


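# Oil-Red-O

For every (folder, cell type) pair, fit a spatial Gaussian-process model of log(1 + area) over the cell centroids with two latent surfaces: a mean surface and a zero-probability surface. Maps of both are saved. This is done first for Oil-Red-O area and then for lipid-droplet area.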

In [2]:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.nn.module import PyroSample
from pyro.contrib.gp.kernels import Matern52
from pyro.contrib.gp.models import VariationalSparseGP
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO
from pyro.optim import Adam
from sklearn.cluster import KMeans
from torch.functional import F


class SpatialNegativeBinomialGP:
    """Spatial two-part (hurdle) Gaussian-process model for a per-cell value.

    The class name is kept for backward compatibility; the likelihood is NOT
    negative binomial. Two sparse variational GPs over the cell coordinates
    (sharing one Matern-5/2 kernel) are fit jointly with SVI:

    * ``model_logit``: latent logit of P(y is zero), observed through a
      Bernoulli likelihood on the indicator ``not (y > 0)``;
    * ``model_mu``: latent mean of y for cells with ``y > 0``, observed through
      a Gaussian likelihood whose noise scale is learned (Pyro param
      ``"noise_scale"``).

    Both surfaces are centred on offsets computed from the training ``y``
    (mean of the positive values; logit of the zero fraction), so the
    zero-mean GP priors only have to model spatial deviations.

    Args:
        coordinates: ``(N, D)`` tensor of cell coordinates.
        y: tensor with N elements (e.g. shape ``(N, 1)``); presumably >= 0,
            such as ``log1p(area)``.
        num_inducing: requested number of inducing points; capped at N.
        learning_rate: Adam learning rate.
        jitter: diagonal jitter added to the inducing-point covariance.
        seed: seed for Pyro/torch and for the KMeans inducing-point init.
        device: preferred device; falls back to CPU when CUDA is unavailable.
        lengthscale_init: initial kernel lengthscale. Defaults to the mean
            per-axis standard deviation of ``coordinates`` (1.0 if undefined),
            so the kernel starts on the same scale as the coordinates.

    Raises:
        ValueError: if ``coordinates`` is not a non-empty 2-D tensor, or ``y``
            does not hold exactly one value per row of ``coordinates``.
    """

    def __init__(self, coordinates, y, num_inducing=100, learning_rate=1e-2, jitter=1e-3, seed=0, device="cuda",
                 lengthscale_init=None):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.coordinates = coordinates.to(self.device)
        self.y = y.to(self.device)
        if self.coordinates.dim() != 2 or self.coordinates.size(0) == 0:
            raise ValueError("coordinates must be a non-empty (N, D) tensor")
        if self.y.numel() != self.coordinates.size(0):
            raise ValueError(
                f"y has {self.y.numel()} values but coordinates has {self.coordinates.size(0)} rows"
            )

        # KMeans cannot produce more clusters than there are points.
        self.num_inducing = min(num_inducing, self.coordinates.size(0))
        self.learning_rate = learning_rate
        self.jitter = jitter
        self.seed = seed

        # Init randomness (fresh param store per model instance)
        pyro.set_rng_seed(seed)
        pyro.clear_param_store()

        # Empirical offsets for the two latent surfaces.
        y_flat = self.y.reshape(-1)
        positive = y_flat > 0
        self._mu_offset = float(y_flat[positive].mean()) if bool(positive.any()) else 0.0
        # eps clamps the zero fraction into [1e-3, 1 - 1e-3] so the logit stays finite.
        self._logit_offset = float(torch.logit((~positive).float().mean(), eps=1e-3))

        # Inducing inputs via KMeans
        self.inducing_inputs = self._init_inducing_inputs()

        # Kernel setup: initial values are passed to the constructor on the
        # target device (no orphan pyro.param entries).
        if lengthscale_init is None:
            lengthscale_init = self._init_lengthscale()
        self.kernel = Matern52(
            input_dim=self.coordinates.size(1),
            variance=torch.tensor(1.0, device=self.device),
            lengthscale=torch.tensor(float(lengthscale_init), device=self.device),
        )

        # GP for the mean of the positive values, and GP for the zero logit.
        # Both share self.kernel (one lengthscale/variance for both surfaces).
        self.model_mu = self._make_gp("model_mu")
        self.model_logit = self._make_gp("model_logit")

        # Optimizer + SVI. Plain Trace_ELBO: the model calls set_data and uses
        # data-dependent masks, which is safer without JIT tracing.
        self.optimizer = Adam({"lr": self.learning_rate})
        self.svi = SVI(self.model_fn, self.guide_fn, self.optimizer, loss=Trace_ELBO())

    def _init_lengthscale(self):
        """Mean per-axis std of the coordinates, or 1.0 when that is not a positive finite number."""
        if self.coordinates.size(0) < 2:
            return 1.0
        spread = float(self.coordinates.std(dim=0).mean())
        # `spread > 0` is False for NaN; the upper bound rejects inf.
        return spread if 0 < spread < float("inf") else 1.0

    def _make_gp(self, name):
        """Build one whitened sparse variational GP over self.coordinates, named `name` in Pyro."""
        gp = VariationalSparseGP(
            X=self.coordinates,
            y=None,
            kernel=self.kernel,
            # Each GP owns its inducing points; without clone() both Parameters
            # would alias the same storage.
            Xu=self.inducing_inputs.clone(),
            likelihood=None,
            latent_shape=torch.Size([1]),
            whiten=True,
            jitter=self.jitter,
        ).to(self.device)
        # Distinct prefixes keep the two GPs' "u", "u_loc", "Xu" sites apart.
        gp._pyro_name = name
        return gp

    def _init_inducing_inputs(self):
        """KMeans cluster centres of the coordinates (seeded), as a float32 tensor on self.device."""
        kmeans = KMeans(n_clusters=self.num_inducing, random_state=self.seed).fit(
            self.coordinates.detach().cpu().numpy()
        )
        return torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=self.device)

    def model_fn(self, coordinates, y):
        """Pyro model: GP priors on both surfaces plus the hurdle likelihood.

        ``y`` is flattened to ``(N,)``. Entries with ``y > 0`` contribute to the
        Gaussian term; every entry contributes to the Bernoulli zero term.
        """
        y = y.reshape(-1)
        positive = y > 0

        # GPModel.model() samples the (whitened) prior over the inducing values
        # ("<name>.u", matched by guide_fn) and, with y=None, returns the
        # marginal loc/var of f at X, each of shape (1, N).
        self.model_mu.set_data(coordinates, None)
        self.model_logit.set_data(coordinates, None)
        mu_loc, mu_var = self.model_mu.model()
        logit_loc, logit_var = self.model_logit.model()
        mu_loc, mu_var = mu_loc.reshape(-1), mu_var.reshape(-1)
        logit_loc, logit_var = logit_loc.reshape(-1), logit_var.reshape(-1)

        noise_scale = pyro.param(
            "noise_scale", torch.tensor(1.0, device=y.device), constraint=constraints.positive
        )

        # One reparameterised draw of the zero-logit surface; this is the
        # Monte-Carlo scheme pyro.contrib.gp's Bernoulli-type likelihoods use.
        logit = dist.Normal(logit_loc, logit_var.clamp_min(1e-12).sqrt()).rsample() + self._logit_offset

        # Non-positive (or NaN) targets are excluded from the Gaussian term by
        # the mask; replace them with 0 so NaN cannot leak into log-probs/grads.
        y_centred = torch.where(positive, y - self._mu_offset, torch.zeros_like(y))

        with pyro.plate("data", y.shape[0]):
            pyro.sample("is_zero", dist.Bernoulli(logits=logit), obs=(~positive).to(y.dtype))
            with pyro.poutine.mask(mask=positive):
                # As in Pyro's Gaussian GP likelihood, the latent variance is
                # added to the observation-noise variance.
                pyro.sample("obs_pos", dist.Normal(mu_loc, (mu_var + noise_scale ** 2).sqrt()), obs=y_centred)

    def guide_fn(self, coordinates, y):
        """Pyro guide: variational posteriors over both GPs' inducing values."""
        return self.model_mu.guide(), self.model_logit.guide()

    def train(self, num_steps=1000, verbose=True):
        """Run `num_steps` SVI steps on the stored data.

        Prints the ELBO every 200 steps when `verbose`. Returns the list of
        per-step losses (negative ELBO).
        """
        losses = []
        for i in range(num_steps):
            loss = self.svi.step(self.coordinates, self.y)
            losses.append(loss)
            if verbose and i % 200 == 0:
                print(f"[{i}] ELBO: {-loss:.2f}")
        return losses

    def predict(self, Xnew=None):
        """Posterior-mean surfaces at `Xnew` (defaults to the training coordinates).

        Returns:
            ``(mu, prob_zero)``, each a flat ``(M,)`` tensor. ``mu`` is the
            predicted mean of y among positive cells. ``prob_zero`` is the
            sigmoid of the posterior-mean logit, a plug-in estimate of P(y is zero).
        """
        Xnew = self.coordinates if Xnew is None else Xnew.to(self.device)
        with torch.no_grad():
            logit_loc, _ = self.model_logit(Xnew)
            mu_loc, _ = self.model_mu(Xnew)
        mu = mu_loc.reshape(-1) + self._mu_offset
        prob_zero = torch.sigmoid(logit_loc.reshape(-1) + self._logit_offset)
        return mu, prob_zero

    def get_kernel_params(self):
        """Current (constrained) lengthscale and variance of the shared kernel."""
        return {
            "lengthscale": self.kernel.lengthscale.item(),
            "variance": self.kernel.variance.item(),
        }


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
import matplotlib.pyplot as plt

# Fit the spatial GP to log(1 + Oil-Red-O area) separately for every
# (folder, cell type) pair present in the data, and save two maps per pair:
# the fitted mean surface ("mean") and the fitted zero-probability surface ("prob").
ORO_VALUE_COLUMN = "log_oil_red_o_area"
ORO_OUTPUT_ROOT = "/home/jhaberbe/Projects/using_parameters_instead/output/pathology-kde"
ORO_MAX_INDUCING = 100  # same as the model's default num_inducing

oro_failures = []  # (folder, cell_type, error) for every fit that raised
# groupby only yields pairs that actually occur, so no empty subsets.
for (folder, cell_type), group in adata.obs.groupby(["folder", "cell_type"], observed=True):
    # copy() so the new columns below are set on an independent frame.
    subset = group.dropna(subset=["x_centroid", "y_centroid", ORO_VALUE_COLUMN]).copy()
    if subset.empty:
        continue

    try:
        # The model moves tensors to its own device (CUDA if available, else CPU).
        coordinates = torch.tensor(subset[["x_centroid", "y_centroid"]].to_numpy(), dtype=torch.float32)
        y = torch.tensor(subset[[ORO_VALUE_COLUMN]].to_numpy(), dtype=torch.float32)

        # KMeans inside the model cannot place more inducing points than cells.
        snbgp = SpatialNegativeBinomialGP(
            coordinates=coordinates, y=y, num_inducing=min(ORO_MAX_INDUCING, len(subset))
        )
        snbgp.train()
        mean, prob = snbgp.predict(coordinates)
    except Exception as exc:  # best-effort per pair, but record why it was skipped
        oro_failures.append((folder, cell_type, repr(exc)))
        continue

    subset["mean"] = mean.detach().cpu().numpy().reshape(-1)
    subset["prob"] = prob.detach().cpu().numpy().reshape(-1)

    out_dir = os.path.join(ORO_OUTPUT_ROOT, str(folder), str(cell_type).replace("/", "-"))
    os.makedirs(out_dir, exist_ok=True)

    for column, filename in (("mean", "oil-red-o-mean.svg"), ("prob", "oil-red-o-prob.svg")):
        fig, ax = plt.subplots()
        subset.plot.scatter(x="x_centroid", y="y_centroid", c=column, cmap="bwr", s=2, ax=ax)
        ax.set_aspect("equal")  # spatial map: keep x/y units comparable
        ax.set_title(f"{folder} / {cell_type}: Oil-Red-O {column}")
        fig.savefig(os.path.join(out_dir, filename))
        plt.close(fig)

print(f"Oil-Red-O: {len(oro_failures)} fit(s) failed")
for folder, cell_type, message in oro_failures:
    print(f"  {folder} / {cell_type}: {message}")


In [None]:
import os
import matplotlib.pyplot as plt

# Fit the spatial GP to log(1 + lipid-droplet area) separately for every
# (folder, cell type) pair present in the data, and save two maps per pair:
# the fitted mean surface ("mean") and the fitted zero-probability surface ("prob").
LD_VALUE_COLUMN = "log_lipid_droplet_area"
LD_OUTPUT_ROOT = "/home/jhaberbe/Projects/using_parameters_instead/output/ld-kde"
LD_MAX_INDUCING = 100  # same as the model's default num_inducing

ld_failures = []  # (folder, cell_type, error) for every fit that raised
# groupby only yields pairs that actually occur, so no empty subsets.
for (folder, cell_type), group in adata.obs.groupby(["folder", "cell_type"], observed=True):
    # copy() so the new columns below are set on an independent frame.
    subset = group.dropna(subset=["x_centroid", "y_centroid", LD_VALUE_COLUMN]).copy()
    if subset.empty:
        continue

    try:
        # The model moves tensors to its own device (CUDA if available, else CPU).
        coordinates = torch.tensor(subset[["x_centroid", "y_centroid"]].to_numpy(), dtype=torch.float32)
        y = torch.tensor(subset[[LD_VALUE_COLUMN]].to_numpy(), dtype=torch.float32)

        # KMeans inside the model cannot place more inducing points than cells.
        snbgp = SpatialNegativeBinomialGP(
            coordinates=coordinates, y=y, num_inducing=min(LD_MAX_INDUCING, len(subset))
        )
        snbgp.train()
        mean, prob = snbgp.predict(coordinates)
    except Exception as exc:  # best-effort per pair, but record why it was skipped
        ld_failures.append((folder, cell_type, repr(exc)))
        continue

    subset["mean"] = mean.detach().cpu().numpy().reshape(-1)
    subset["prob"] = prob.detach().cpu().numpy().reshape(-1)

    out_dir = os.path.join(LD_OUTPUT_ROOT, str(folder), str(cell_type).replace("/", "-"))
    os.makedirs(out_dir, exist_ok=True)

    for column, filename in (("mean", "ld-mean.svg"), ("prob", "ld-prob.svg")):
        fig, ax = plt.subplots()
        subset.plot.scatter(x="x_centroid", y="y_centroid", c=column, cmap="bwr", s=2, ax=ax)
        ax.set_aspect("equal")  # spatial map: keep x/y units comparable
        ax.set_title(f"{folder} / {cell_type}: lipid droplet {column}")
        fig.savefig(os.path.join(out_dir, filename))
        plt.close(fig)

print(f"Lipid droplets: {len(ld_failures)} fit(s) failed")
for folder, cell_type, message in ld_failures:
    print(f"  {folder} / {cell_type}: {message}")
