# Read Data

In [1]:
import numpy as np
import scanpy as sc

adata = sc.read_h5ad(
    "/home/jhaberbe/Projects/Personal/TokenGT/data/output-dgi-10-10-20MAY2025.h5ad"
)
adata = adata[adata.layers["transcript"].sum(axis=1) > 20].copy()
adata.obs["log_plin2_area"] = np.log1p(adata.obs["plin2_area"])
adata.obs["log_oil_red_o_area"] = np.log1p(adata.obs["oil_red_o_area"])
adata.obs["log_lipid_droplet_area"] = np.log1p(adata.obs["lipid_droplet_area"])

adata.X = adata.layers["transcript"].copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
adata.X seems to be already log-transformed.


In [2]:
from scipy.spatial import cKDTree
adata.obs["z_centroid"] = adata.obs["folder"].replace({
    k: i * 10_000
    for i, k in enumerate(adata.obs["folder"].unique())
})

ckd_tree = cKDTree(adata.obs[["x_centroid", "y_centroid", "z_centroid"]])

_, neighbor_indicies = ckd_tree.query(adata.obs[["x_centroid", "y_centroid", "z_centroid"]], k = 31)
neighbor_indicies = neighbor_indicies.tolist()

  adata.obs["z_centroid"] = adata.obs["folder"].replace({
  adata.obs["z_centroid"] = adata.obs["folder"].replace({


# Data Set 

In [None]:
import torch

class SpatialSingleCellDataSet:
    
    def __init__(
        self, 
        counts,
        log_normalized,
        plin2_area,
        oil_red_o_area,
        lipid_droplet_area,
        near_amyloid,
        neighbor_indices,
        specimen_ids
    ):
        # Gene Expression Information
        self.counts = self._to_tensor(counts, torch.float)
        self.log_normalized = self._to_tensor(log_normalized, torch.float)

        self.size_factors = (self.counts.sum(axis=1) / self.counts.sum(axis=1).mean()).log()

        # Pathology Information
        self.plin2_area = self._to_tensor(plin2_area, torch.float)
        self.oil_red_o_area = self._to_tensor(oil_red_o_area, torch.float)
        self.lipid_droplet_area = self._to_tensor(lipid_droplet_area, torch.float)
        self.near_amyloid = self._to_tensor(near_amyloid, torch.float)

        # Neighborhood Information
        self.specimen_ids = self._to_tensor(specimen_ids, torch.long)
        self.neighbor_indices = self._to_tensor(neighbor_indices, torch.long)

    @staticmethod
    def _to_tensor(x, dtype=torch.float):
        if isinstance(x, torch.Tensor):
            return x.detach().clone().to(dtype)
        else:
            return torch.tensor(x, dtype=dtype)

    def __len__(self):
        return self.counts.size(0)

    def __getitem__(self, idx):
        return {
            # Expression Information
            "counts": self.counts[idx],
            "log_normalized": self.log_normalized[idx],
            "size_factors": self.size_factors[idx],

            # Pathology Information
            "plin2_area": self.plin2_area[idx],
            "oil_red_o_area": self.oil_red_o_area[idx],
            "lipid_droplet_area": self.lipid_droplet_area[idx],
            "near_amyloid": self.near_amyloid[idx],

            # Neighborhood Information
            "neighbor_indices": self.neighbor_indices[idx],

            # Cell Metadata
            "specimen_ids": self.specimen_ids[idx],
        }

counts = torch.tensor(adata.layers["transcript"])
log_normalized = torch.tensor(adata.X)

plin2_area = torch.tensor(adata.obs["plin2_area"].values).log1p()
oil_red_o_area = torch.tensor(adata.obs["oil_red_o_area"].values).log1p()
lipid_droplet_area = torch.tensor(adata.obs["lipid_droplet_area"].values).log1p()
near_amyloid = torch.tensor(adata.obs["near_amyloid"].values).float()

neighbor_indices = torch.tensor(neighbor_indicies)
specimen_ids = torch.tensor(adata.obs["folder"].cat.codes.values)

dataset = SpatialSingleCellDataSet(
    counts,
    log_normalized,
    plin2_area,
    oil_red_o_area,
    lipid_droplet_area,
    near_amyloid,
    neighbor_indices,
    specimen_ids
)

input_data = dataset[0]

dataset[input_data["neighbor_indices"]]

In [121]:
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F

# Gradient Reversal Layer implementation
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)  # Identity forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambda_, None

def grad_reverse(x, lambda_=1.0):
    return GradReverse.apply(x, lambda_)

# Variational Encoder (unchanged)
class VariationalEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        self.mu = nn.Linear(hidden_dim, embedding_dim)
        self.log_var = nn.Linear(hidden_dim, embedding_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        h = self.shared(x)
        log_mu = self.mu(h)
        log_var = self.log_var(h)
        return log_mu, log_var

In [122]:

# Spatial Decoder (same as before, you can copy your existing decoder here)
class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, n_genes, n_batches):
        super().__init__()
        self.batch_emb = nn.Embedding(n_batches, embedding_dim)
        self.shared = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU()
        )
        self.nb_mu = nn.Linear(hidden_dim, n_genes)
        self.log_theta = nn.Parameter(torch.zeros(1))
        self.hurdle_logits = nn.ModuleDict({
            "plin2": nn.Linear(hidden_dim, 1),
            "oil_red_o": nn.Linear(hidden_dim, 1),
            "lipid_droplet": nn.Linear(hidden_dim, 1)
        })
        self.hurdle_mu = nn.ModuleDict({
            "plin2": nn.Linear(hidden_dim, 1),
            "oil_red_o": nn.Linear(hidden_dim, 1),
            "lipid_droplet": nn.Linear(hidden_dim, 1)
        })
        self.hurdle_log_var = nn.ModuleDict({
            "plin2": nn.Linear(hidden_dim, 1),
            "oil_red_o": nn.Linear(hidden_dim, 1),
            "lipid_droplet": nn.Linear(hidden_dim, 1)
        })
        self.near_amyloid_logit = nn.Linear(hidden_dim, 1)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, z, specimen_ids):
        batch_embedding = self.batch_emb(specimen_ids)
        h = self.shared(torch.cat([z, batch_embedding], dim=-1))
        log_mu_counts = self.nb_mu(h)
        log_theta = self.log_theta.expand_as(log_mu_counts)
        hurdle_out = {}
        for k in self.hurdle_logits.keys():
            hurdle_out[k] = {
                "logit_p": self.hurdle_logits[k](h),
                "mu": self.hurdle_mu[k](h),
                "log_var": self.hurdle_log_var[k](h)
            }
        near_amyloid_logit = self.near_amyloid_logit(h)
        return {
            "log_mu_counts": log_mu_counts,
            "log_theta": log_theta,
            "hurdle": hurdle_out,
            "near_amyloid_logit": near_amyloid_logit
        }

In [123]:
# Discriminator to predict specimen_ids from latent z
class Discriminator(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, n_batches):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_batches)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, z):
        return self.net(z)


In [125]:
# Full VAE with adversarial batch correction
class VAEWithAdversarial(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim, n_genes, n_batches):
        super().__init__()
        self.encoder = VariationalEncoder(input_dim, hidden_dim, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_dim, n_genes, n_batches)
        self.discriminator = Discriminator(embedding_dim, hidden_dim // 2, n_batches)

    def reparameterize(self, log_mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return log_mu + eps * std

    def forward(self, input_data):
        x = input_data["log_normalized"]
        specimen_ids = input_data["specimen_ids"]

        log_mu, log_var = self.encoder(x)
        z = self.reparameterize(log_mu, log_var)

        outputs = self.decoder(z, specimen_ids)
        outputs["size_factors"] = input_data["size_factors"]
        return outputs, log_mu, log_var, z

    def discriminate(self, z, lambda_grl=1.0):
        # Apply gradient reversal on z before discriminator
        z_rev = grad_reverse(z, lambda_grl)
        logits = self.discriminator(z_rev)
        return logits


In [129]:
import torch
import torch.nn.functional as F

# 1) Negative Binomial loss for counts

def negative_binomial_loss(x, log_mu, log_theta, size_factors):
    """
    x: observed counts (integer tensor)
    log_mu: decoded log mean (float tensor)
    log_theta: scalar or tensor, shared dispersion (log scale)
    size_factors: tensor, log scale size factors (same shape as batch size)
    """
    # Adjust log_mu by adding size factors (log scale)
    # size_factors shape: (batch_size,)
    # log_mu shape: (batch_size, n_genes)
    log_mu_adj = log_mu + size_factors.unsqueeze(-1)  # broadcast size_factors

    theta = torch.exp(log_theta)  # dispersion

    # logits for NB parameterization: logits = log_mu - log(mu + theta)
    logits = log_mu_adj - torch.log(torch.exp(log_mu_adj) + theta)

    nb_dist = torch.distributions.NegativeBinomial(total_count=theta, logits=logits)
    # Negative log likelihood (sum over genes and batch)
    neg_log_likelihood = -nb_dist.log_prob(x).mean()

    return neg_log_likelihood


# 2) Hurdle normal loss for each pathology feature

def hurdle_normal_loss(x, logit_p, mu, log_var):
    p = torch.sigmoid(logit_p)
    is_zero = (x == 0).float()

    bern_loss = -(is_zero * torch.log(1 - p + 1e-8) + (1 - is_zero) * torch.log(p + 1e-8))

    std = torch.exp(0.5 * log_var)
    const = torch.log(torch.tensor(2 * torch.pi, device=x.device, dtype=x.dtype))
    gaussian_nll = 0.5 * ( ((x - mu) / std) ** 2 + 2 * torch.log(std) + const )
    gaussian_nll = gaussian_nll * (1 - is_zero)

    total_loss = (bern_loss + gaussian_nll).mean()
    return total_loss


# 3) Logistic loss for near_amyloid (binary classification)

def near_amyloid_loss(logits, labels):
    """
    logits: raw logits output from decoder (before sigmoid)
    labels: binary labels (0/1 float tensor)
    """
    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels, reduction='sum')
    return loss


# 4) KL divergence between latent posterior and prior

def kl_divergence(log_mu, log_var):
    """
    Standard VAE KL divergence
    """
    kl = -0.5 * torch.sum(1 + log_var - log_mu.pow(2) - log_var.exp())
    return kl

def discriminator_loss(discriminator_logits, specimen_ids):
    """
    Cross-entropy loss for discriminator predicting specimen_ids.
    """
    return F.cross_entropy(discriminator_logits, specimen_ids)

def compute_total_loss(
    outputs,
    input_data,
    log_mu,
    log_var,
    discriminator_logits=None,
    weight_kl=1.0,
    weight_nb=1.0,
    weight_hurdle=1.0,
    weight_amyloid=1.0,
    weight_adv=0.0,
):
    """
    Computes total loss for VAE model, optionally including adversarial loss.

    If `discriminator_logits` is provided, compute adversarial loss and include weighted in total loss.
    """

    # Negative Binomial loss (counts)
    nb_loss = negative_binomial_loss(
        x=input_data["counts"],
        log_mu=outputs["log_mu_counts"],
        log_theta=outputs["log_theta"],
        size_factors=input_data["size_factors"]
    )

    # Hurdle normal losses (sum over 3 features)
    hurdle_loss = 0.0
    for key in ["plin2", "oil_red_o", "lipid_droplet"]:
        h = outputs["hurdle"][key]
        x = input_data[f"{key}_area"]
        hurdle_loss += hurdle_normal_loss(x, h["logit_p"], h["mu"], h["log_var"])

    # Near amyloid logistic loss
    amyloid_loss = near_amyloid_loss(outputs["near_amyloid_logit"], input_data["near_amyloid"])

    # KL divergence
    kl_loss = kl_divergence(log_mu, log_var)

    # Adversarial loss (optional)
    if discriminator_logits is not None:
        adv_loss = discriminator_loss(discriminator_logits, input_data["specimen_ids"])
    else:
        adv_loss = torch.tensor(0.0, device=log_mu.device)

    # Total loss
    total_loss = (
        weight_nb * nb_loss
        + weight_hurdle * hurdle_loss
        + weight_amyloid * amyloid_loss
        + weight_kl * kl_loss
        + weight_adv * adv_loss
    )

    return total_loss, {
        "total_loss": total_loss.item(),
        "nb_loss": nb_loss.item(),
        "hurdle_loss": hurdle_loss.item(),
        "amyloid_loss": amyloid_loss.item(),
        "kl_loss": kl_loss.item(),
        "adv_loss": adv_loss.item() if isinstance(adv_loss, torch.Tensor) else adv_loss
    }


In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 2048
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

vae = VAE(
    input_dim=366,
    hidden_dim=64,
    embedding_dim=16,
    n_genes=366,
    n_batches=12
).to(device)
vae.train()

discriminator = Discriminator(embedding_dim=16, hidden_dim=32, n_batches=12).to(device)
discriminator.train()

optimizer_vae = Adam(vae.parameters(), lr=1e-3)
optimizer_disc = Adam(discriminator.parameters(), lr=1e-3)

num_epochs = 10
adv_weight = 1.0
lambda_grl = 1.0

for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_idx, batch_samples in enumerate(data_loader):
        # Convert list of dicts to dict of batched tensors and send to device
        batch_data = {key: torch.stack([sample[key] for sample in batch_samples]).to(device) for key in batch_samples[0].keys()}

        # --- Step 1: Train discriminator ---
        optimizer_disc.zero_grad()
        with torch.no_grad():
            log_mu, log_var = vae.encoder(batch_data["log_normalized"])
            z = vae.reparameterize(log_mu, log_var)

        disc_logits = discriminator(z.detach())
        disc_loss = torch.nn.functional.cross_entropy(disc_logits, batch_data["specimen_ids"])
        disc_loss.backward()
        optimizer_disc.step()

        # --- Step 2: Train VAE ---
        optimizer_vae.zero_grad()

        log_mu, log_var = vae.encoder(batch_data["log_normalized"])
        z = vae.reparameterize(log_mu, log_var)

        outputs = vae.decoder(z, batch_data["specimen_ids"])
        outputs["size_factors"] = batch_data["size_factors"]

        z_rev = grad_reverse(z, lambda_grl)
        adv_logits = discriminator(z_rev)

        loss, loss_items = compute_total_loss(
            outputs=outputs,
            input_data=batch_data,
            log_mu=log_mu,
            log_var=log_var,
            discriminator_logits=adv_logits,
            weight_adv=adv_weight
        )

        loss.backward()
        optimizer_vae.step()

        running_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}], Loss: {loss.item():.4f}, Disc Loss: {disc_loss.item():.4f}")

    epoch_loss = running_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {epoch_loss:.4f}")

In [134]:
import torch
from torch.utils.data import DataLoader

def extract_embeddings(model, dataset, batch_size=256, use_mean=True, device=None):
    """
    Extract latent embeddings for all samples in `dataset`.

    Args:
        model: trained VAE model with encoder
        dataset: dataset object (e.g. SpatialSingleCellDataSet)
        batch_size: batch size for DataLoader
        use_mean: if True, use encoder's mean (log_mu) as embedding,
                  else sample from latent distribution
        device: torch device (e.g. 'cuda' or 'cpu'), default auto-detect

    Returns:
        embeddings: Tensor of shape (n_samples, embedding_dim)
    """

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model.to(device)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)

    all_embeddings = []

    with torch.no_grad():
        for batch_samples in loader:
            batch_data = {key: torch.stack([sample[key] for sample in batch_samples]).to(device) for key in batch_samples[0].keys()}

            log_mu, log_var = model.encoder(batch_data["log_normalized"])

            if use_mean:
                embeddings = log_mu
            else:
                std = torch.exp(0.5 * log_var)
                eps = torch.randn_like(std)
                embeddings = log_mu + eps * std

            all_embeddings.append(embeddings.cpu())

    return torch.cat(all_embeddings, dim=0)

token_embeddings = extract_embeddings(vae, dataset)

# Embeddings

In [141]:
token_embeddings[::100].shape

torch.Size([5798, 16])

In [151]:
import torch
from torch.utils.data import Dataset

class SpatialEmbeddingDataset(Dataset):
    def __init__(self, embeddings, neighbor_indices):
        """
        Args:
            embeddings: Tensor of shape (n_cells, embedding_dim)
            neighbor_indices: LongTensor of shape (n_cells, n_neighbors)
        """
        self.embeddings = embeddings
        self.neighbor_indices = neighbor_indices

    def __len__(self):
        return self.embeddings.size(0)

    def __getitem__(self, idx):
        """
        Returns:
            {
                "center": embedding of center cell (embedding_dim,),
                "neighbors": embeddings of neighbors (n_neighbors, embedding_dim),
                "center_idx": index of the center cell
            }
        """
        center_embedding = self.embeddings[idx]
        neighbor_idxs = self.neighbor_indices[idx]
        neighbor_embeddings = self.embeddings[neighbor_idxs]

        return {
            "center": center_embedding,
            "neighbors": neighbor_embeddings,
            "center_idx": idx
        }

spatial_embedding_dataset = SpatialEmbeddingDataset(
    embeddings=token_embeddings,
    neighbor_indices=neighbor_indices[:, 1:]
)

In [163]:
from nflows.flows import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform

def build_simple_flow(embedding_dim, n_transforms=4, hidden_dim=64):
    base_dist = StandardNormal([embedding_dim])
    transforms = []
    for _ in range(n_transforms):
        transforms.append(
            MaskedAffineAutoregressiveTransform(
                features=embedding_dim,
                hidden_features=hidden_dim
            )
        )
    # Combine into a proper composite transform
    transform = CompositeTransform(transforms)
    return Flow(transform, base_dist)


In [164]:
class SpatialTransformerFlow(nn.Module):
    def __init__(self, embedding_dim, n_heads=4, n_layers=2, n_neighbors=30):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_neighbors = n_neighbors

        # Learnable CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Positional encoding (simple learnable)
        self.positional_enc = nn.Parameter(torch.randn(1, n_neighbors + 2, embedding_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=n_heads, dim_feedforward=embedding_dim * 4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Normalizing Flow for predicting next-token embedding
        self.flow = build_simple_flow(embedding_dim)

    def forward(self, center, neighbors):
        """
        center: (B, embedding_dim)
        neighbors: (B, n_neighbors, embedding_dim)
        """
        B = center.size(0)
        device = center.device

        # Start with [CLS] + center
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        current_seq = torch.cat([cls_tokens, center.unsqueeze(1)], dim=1)

        all_log_probs = []

        for step in range(neighbors.size(1)):
            seq = current_seq + self.positional_enc[:, :current_seq.size(1), :]

            h = self.transformer(seq)  # (B, seq_len, D)
            context = h[:, -1, :]  # last token representation

            # Predict probability of the next token using the flow
            next_token = neighbors[:, step, :]
            log_prob = self.flow.log_prob(next_token)  # unconditional for now
            all_log_probs.append(log_prob)

            # Append the ground truth neighbor to the sequence for next step
            current_seq = torch.cat([current_seq, next_token.unsqueeze(1)], dim=1)

        all_log_probs = torch.stack(all_log_probs, dim=1)  # (B, n_neighbors)
        return all_log_probs

def perplexity_loss(log_probs):
    """
    log_probs: (B, n_neighbors), log p(next_token)
    Returns: scalar loss (average perplexity)
    """
    nll = -log_probs  # negative log-likelihood per step
    entropy = nll.mean()  # average entropy across batch & steps
    perplexity = torch.exp(entropy)
    return perplexity

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam

model = SpatialTransformerFlow(embedding_dim=token_embeddings.size(1), n_neighbors=30)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

loader = DataLoader(spatial_embedding_dataset, batch_size=64, shuffle=True)

for epoch in range(10):
    model.train()
    total_loss = 0.0
    for batch in tqdm(loader):
        center = batch["center"].to(device)
        neighbors = batch["neighbors"].to(device)

        optimizer.zero_grad()
        log_probs = model(center, neighbors)
        loss = perplexity_loss(log_probs)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Avg Perplexity = {total_loss / len(loader):.4f}")