In [52]:
import pyro
import torch
import patsy
import numpy as np
import scanpy as sc

adata = sc.read_h5ad("/home/jhaberbe/Projects/using_parameters_instead/data/16APR2025.h5ad")
adata.obs.query("folder == '05-27' and cell_type == 'Microglia-PVM'")
adata.obs["log_lipid_droplet_area"] = np.log1p(adata.obs["lipid_droplet_area"])

  utils.warn_names_duplicates("obs")


In [53]:
exog = patsy.dmatrix("log_lipid_droplet_area + near_amyloid", adata.obs, return_type="dataframe")
exog = torch.tensor(exog.values)

In [54]:
intercept = exog[:, 0].float()
log_lipid_droplet_area = exog[:, 1].float()
amyloid = exog[:, 2].float()
counts = torch.tensor(adata[:, "APOE"].layers["transcript"].reshape(-1)).float()
all_counts = torch.tensor(adata.layers["transcript"]).float()
size_factor = torch.tensor(np.log(adata.layers["transcript"].sum(axis=1) / adata.layers["transcript"].sum(axis=1).mean()))

  size_factor = torch.tensor(np.log(adata.layers["transcript"].sum(axis=1) / adata.layers["transcript"].sum(axis=1).mean()))


# Defining our model

In [88]:
import torch
import pyro
import pyro.contrib.gp as gp
from torch.distributions import constraints
import pyro.distributions as dist

import torch
import pyro
import pyro.contrib.gp as gp
from torch.distributions import constraints
import pyro.distributions as dist

def model(amyloid, log_lipid_droplet_area, size_factor, counts):
    device = log_lipid_droplet_area.device

    # Intercept
    intercept = pyro.sample("intercept", dist.Normal(
        torch.tensor(0.0, device=device), torch.tensor(5.0, device=device)))

    # Near Amyloid effect
    amyloid_loc = pyro.param("amyloid", torch.tensor(1.0, device=device))

    # Sparse GP
    kernel = gp.kernels.RBF(
        input_dim=1,
        variance=log_lipid_droplet_area.var(),
        lengthscale=torch.tensor(.1, device=device)
    )
    Xu = torch.linspace(0, 7, 10, device=device).unsqueeze(-1)  # (10, 1)
    gpr = gp.models.VariationalSparseGP(
        log_lipid_droplet_area, Xu=Xu, y=None, kernel=kernel, likelihood=None
    )

    ld_loc, ld_scale = gpr(log_lipid_droplet_area)
    with pyro.plate("data", log_lipid_droplet_area.shape[0]):
        ld_sampled = pyro.sample("ld_sampled", dist.Normal(ld_loc, ld_scale))

    # Logit calculation
    logit = intercept + (amyloid * amyloid_loc) + (log_lipid_droplet_area * ld_sampled) + size_factor
    logit = 20 * torch.tanh(logit / 20)

    # Dispersion
    r = pyro.param("r", torch.tensor(1.0, device=device), constraint=constraints.positive)

    with pyro.plate("observations", len(counts)):
        pyro.sample("counts", dist.NegativeBinomial(total_count=r, logits=logit), obs=counts)


def guide(amyloid, log_lipid_droplet_area, size_factor, counts):
    device = log_lipid_droplet_area.device

    # Intercept
    intercept_loc = pyro.param("intercept_loc", torch.tensor(0.0, device=device))
    intercept_scale = pyro.param("intercept_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
    pyro.sample("intercept", dist.Normal(intercept_loc, intercept_scale))

    # Amyloid
    amyloid_loc = pyro.param("amyloid_loc", torch.tensor(0.0, device=device))
    amyloid_scale = pyro.param("amyloid_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
    pyro.sample("amyloid", dist.Normal(amyloid_loc, amyloid_scale))

    # Sparse GP (for consistency, same kernel setup)
    kernel = gp.kernels.RBF(
        input_dim=1,
        variance=log_lipid_droplet_area.var(),
        lengthscale=torch.tensor(0.05, device=device)
    )
    Xu = torch.linspace(0, 7, 10, device=device).unsqueeze(-1)
    gpr = gp.models.VariationalSparseGP(
        log_lipid_droplet_area, Xu=Xu, y=None, kernel=kernel, likelihood=None
    )
    gpr.guide()

    # Dispersion
    r_loc = pyro.param("r_loc", torch.tensor(1.0, device=device), constraint=constraints.positive)
    r_scale = pyro.param("r_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
    pyro.sample("r", dist.LogNormal(r_loc, r_scale))


In [89]:
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO

pyro.clear_param_store()

device = torch.device("cpu")

adam = pyro.optim.Adam({"lr": .01})
svi = SVI(model, guide, adam, loss=JitTrace_ELBO())

for step in range(1000):
    loss = svi.step(amyloid.to(device), log_lipid_droplet_area.to(device), size_factor.to(device), counts.to(device))
    if step % 2 == 0:
        print(f"Step {step}: Loss = {loss}")

{'amyloid', 'u', 'r'}
  intercept_loc = pyro.param("intercept_loc", torch.tensor(0.0, device=device))
  intercept_scale = pyro.param("intercept_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
  amyloid_loc = pyro.param("amyloid_loc", torch.tensor(0.0, device=device))
  amyloid_scale = pyro.param("amyloid_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
  lengthscale=torch.tensor(0.05, device=device)
  eye.view(-1)[: min(m, n) * n : n + 1] = 1
  r_loc = pyro.param("r_loc", torch.tensor(1.0, device=device), constraint=constraints.positive)
  r_scale = pyro.param("r_scale", torch.tensor(1.0, device=device), constraint=constraints.positive)
  torch.tensor(0.0, device=device), torch.tensor(5.0, device=device)))
  amyloid_loc = pyro.param("amyloid", torch.tensor(1.0, device=device))
  lengthscale=torch.tensor(.1, device=device)
  if X.size(1) != Z.size(1):
  r = pyro.param("r", torch.tensor(1.0, device=device), constraint=constraints.posit

Step 0: Loss = 3478736.6501526376
Step 2: Loss = 3341995.3751685717
Step 4: Loss = 3177742.6264904006
Step 6: Loss = 3125752.372673844
Step 8: Loss = 3075860.1234712563
Step 10: Loss = 3024977.0566820204
Step 12: Loss = 3002045.6886047595
Step 14: Loss = 2874864.2679418996
Step 16: Loss = 3025971.294016575
Step 18: Loss = 2852453.760690337
Step 20: Loss = 2707901.586193286
Step 22: Loss = 2704472.8686299473
Step 24: Loss = 2595439.1647904045
Step 26: Loss = 2578527.26710815
Step 28: Loss = 2535223.0239046705
Step 30: Loss = 2469773.6904095765
Step 32: Loss = 2471122.431046986
Step 34: Loss = 2385397.2027744185
Step 36: Loss = 2361206.8523398163
Step 38: Loss = 2240716.7416331423
Step 40: Loss = 2445570.558201366
Step 42: Loss = 2162063.2421009745
Step 44: Loss = 2196585.639430426
Step 46: Loss = 2117160.076130147
Step 48: Loss = 2196425.103969982
Step 50: Loss = 1989566.6950190696
Step 52: Loss = 2318719.0472960128
Step 54: Loss = 1883485.0870610012
Step 56: Loss = 1880533.645572285
St

KeyboardInterrupt: 

In [90]:
dict(pyro.get_param_store())

{'intercept_loc': tensor(0.8028, requires_grad=True),
 'intercept_scale': tensor(0.4316, grad_fn=<AddBackward0>),
 'amyloid_loc': tensor(0., requires_grad=True),
 'amyloid_scale': tensor(7.4633, grad_fn=<AddBackward0>),
 'u_loc': Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 'u_scale_tril': tensor([[ 0.1334,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [-0.0118,  0.0807,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0089, -0.0052,  0.0932,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [-0.0090,  0.0038, -0.0037,  0.1098,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0086, -0.0050,  0.0051, -0.0056,  0.1247,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [-0.0063,  0.0040, -0.0054,  0.0067, -0.0081,  0.1391,  0.0000,  0.0000,
           0.0000,  0.00

# Sampling our data

In [71]:
import torch
import pyro
import pyro.contrib.gp as gp
from torch.distributions import constraints
import pyro.distributions as dist

def model(amyloid, log_lipid_droplet_area, size_factor, counts):
    device = counts.device
    N, K = counts.shape

    # Intercept: one per feature
    with pyro.plate("intercept_passes", K):
        intercept = pyro.sample("intercept", dist.Normal(torch.zeros(1, device=device), 5.0))

    # Amyloid effect: learnable weight per feature
    amyloid_loc = pyro.param("amyloid", torch.ones(K, device=device))

    # Sparse GP
    kernel = gp.kernels.RBF(
        input_dim=1,
        variance=log_lipid_droplet_area.var().to(device),
        lengthscale=torch.tensor(0.05, device=device)
    )
    gpr = gp.models.VariationalSparseGP(
        log_lipid_droplet_area.unsqueeze(-1),
        Xu=torch.linspace(0, 7, 10, device=device).unsqueeze(-1),
        y=None,
        kernel=kernel,
        likelihood=None
    )
    ld_loc, ld_scale = gpr(log_lipid_droplet_area.unsqueeze(-1))
    with pyro.plate("ld_data", N):
        ld_sampled = pyro.sample("ld_sampled", dist.Normal(ld_loc, ld_scale))  # shape [N]

    # Expand: ld_sampled: [N] → [N, K]
    ld_effect = ld_sampled.unsqueeze(-1) * log_lipid_droplet_area.unsqueeze(-1)

    # Logit: [N, K]
    print(amyloid)
    print(amyloid_loc)
    logit = intercept + (amyloid.unsqueeze(-1) * amyloid_loc.unsqueeze(0)) + ld_effect + size_factor.unsqueeze(-1)
    logit = torch.clamp(logit, -20.0, 20.0)

    # Dispersion (can be shared or per-feature)
    r = pyro.param("r", torch.ones(K, device=device), constraint=constraints.positive)

    with pyro.plate("feature", K), pyro.plate("individual", N):
        pyro.sample("counts", dist.NegativeBinomial(total_count=r, logits=logit), obs=counts)

def guide(amyloid, log_lipid_droplet_area, size_factor, counts):
    device = counts.device
    N, K = counts.shape

    # Intercept
    intercept_loc = pyro.param("intercept_loc", torch.zeros(K, device=device))
    intercept_scale = pyro.param("intercept_scale", torch.ones(K, device=device), constraint=constraints.positive)
    pyro.sample("intercept", dist.Normal(intercept_loc, intercept_scale))

    # Amyloid
    amyloid_loc = pyro.param("amyloid_loc", torch.zeros(K, device=device))
    amyloid_scale = pyro.param("amyloid_scale", torch.ones(K, device=device), constraint=constraints.positive)
    pyro.sample("amyloid", dist.Normal(amyloid_loc, amyloid_scale))

    # GP
    kernel = gp.kernels.RBF(
        input_dim=1,
        variance=log_lipid_droplet_area.var().to(device),
        lengthscale=torch.tensor(0.05, device=device)
    )
    gpr = gp.models.VariationalSparseGP(
        log_lipid_droplet_area.unsqueeze(-1),
        Xu=torch.linspace(0, 7, 10, device=device).unsqueeze(-1),
        y=None,
        kernel=kernel,
        likelihood=None
    )
    gpr.guide()

    # Dispersion (can also be sampled)
    r_loc = pyro.param("r_loc", torch.ones(K, device=device), constraint=constraints.positive)
    r_scale = pyro.param("r_scale", torch.ones(K, device=device), constraint=constraints.positive)
    pyro.sample("r", dist.LogNormal(r_loc, r_scale))


In [None]:
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO

pyro.clear_param_store()

device = torch.device("cpu")

adam = pyro.optim.Adam({"lr": .01})
svi = SVI(model, guide, adam, loss=JitTrace_ELBO())

amyloid = amyloid.to(device)
log_lipid_droplet_area = log_lipid_droplet_area.to(device)size_factor.to(device), all_counts.to(device)

for step in range(100):
    loss = svi.step()
    if step % 2 == 0:
        print(f"Step {step}: Loss = {loss}")

tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
 

ValueError: at site "intercept", invalid log_prob shape
  Expected [], actual [366]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions