# Reading the Data

In [17]:
import scanpy as sc
adata = sc.read_h5ad("../data/output-dgi-10-10-20MAY2025.h5ad")
adata.layers["counts"] = adata.layers["counts"].todense()
adata = adata[adata.obs["folder"].eq("05-27")]

  utils.warn_names_duplicates("obs")


# Spatial Dataset Generator

In [20]:
import torch
import torch_geometric
from torch_geometric.data import Data, Dataset

import numpy as np
from scipy.spatial import cKDTree

class GraphDatasetGenerator:
    def __init__(self, adata):
        self.adata = adata

    def generate_graph_data(self):
        # Edge index
        edge_index = self.create_edge_index()

        data = Data(
            # Expression
            x=torch.tensor(self.adata.X).float(), 
            counts = torch.tensor(self.adata.layers["counts"]).float(),

            # Batch information
            batch = torch.tensor(self.adata.obs["folder"].cat.codes),

            # Spatial 
            edge_index=edge_index,

            # Pathology
            plin2 = torch.tensor(self.adata.obs["plin2_area"]).float(),
            oro = torch.tensor(self.adata.obs["oil_red_o_area"]).float(),
            lipid_droplet = torch.tensor(self.adata.obs["lipid_droplet_area"]).float(),

            # Distance to nearest amyloid
            distance_to_nearest_amyloid = torch.tensor(self.adata.obs["lipid_droplet_area"]).float()
        )

        return data

    def create_edge_index(self):
        tree = cKDTree(
            self.adata.obs[['x_centroid', 'y_centroid']].values
        )

        _, neighbors = tree.query(
            self.adata.obs[['x_centroid', 'y_centroid']].values, 
            k=31
        )

        rows = np.repeat(
            np.arange(
                len(adata.obs[['x_centroid', 'y_centroid']].values)
            ), 
            30
        )
        cols = neighbors[:, 1:].reshape(-1)
        return torch.tensor([rows, cols], dtype=torch.long)

data = GraphDatasetGenerator(adata).generate_graph_data()

  batch = torch.tensor(self.adata.obs["folder"].cat.codes),
  plin2 = torch.tensor(self.adata.obs["plin2_area"]).float(),
  oro = torch.tensor(self.adata.obs["oil_red_o_area"]).float(),
  lipid_droplet = torch.tensor(self.adata.obs["lipid_droplet_area"]).float(),
  distance_to_nearest_amyloid = torch.tensor(self.adata.obs["lipid_droplet_area"]).float()


# Constructing our model

In [116]:
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, input_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [120]:
class EmbeddingEncoder(nn.Module):
    def __init__(self, input_dim):
        super(EmbeddingEncoder, self).__init__()

        # Expression encoder
        self.expression_encoder = Encoder(input_dim)

        # Pathology heads
        def mlp():
            return nn.Sequential(
                nn.Linear(1, 64),
                nn.ReLU(),
                nn.Linear(64, 8)
            )

        self.pathology_head = nn.ModuleDict({
            "oil_red_o": mlp(),
            "plin2": mlp(),
            "lipid_droplet": mlp(),
            "distance_to_amyloid": mlp()
        })

        self.embedding_head = nn.Linear(
            in_features=input_dim + 8 * 4,
            out_features=input_dim + 8 * 4
        )

    def forward(self, data):
        # Encode expression
        expression = self.expression_encoder(data.x)

        # Encode pathology
        pathology = torch.cat([
            self.pathology_head["oil_red_o"](data.oro.unsqueeze(1)),
            self.pathology_head["plin2"](data.plin2.unsqueeze(1)),
            self.pathology_head["lipid_droplet"](data.lipid_droplet.unsqueeze(1)),
            self.pathology_head["distance_to_amyloid"](data.distance_to_nearest_amyloid.unsqueeze(1)),
        ], dim=1)

        # Concatenate all embeddings
        full_embedding = torch.cat([expression, pathology], dim=1)
        
        return full_embedding

In [121]:
model = EmbeddingEncoder(
    input_dim=data.x.shape[1], 
)
output = model(data)

In [124]:
output

tensor([[-0.0613, -0.0887,  0.0274,  ..., -0.4865,  0.2525, -0.5796],
        [-0.0609, -0.2936, -0.0724,  ..., -0.4865,  0.2525, -0.5796],
        [ 0.0424, -0.0849,  0.0429,  ..., -7.1261,  2.8584,  0.4881],
        ...,
        [ 0.0263, -0.2360,  0.1897,  ..., -0.4865,  0.2525, -0.5796],
        [-0.0336, -0.1834,  0.1013,  ..., -0.4865,  0.2525, -0.5796],
        [ 0.1763,  0.0120,  0.0502,  ..., -0.4865,  0.2525, -0.5796]],
       grad_fn=<CatBackward0>)

# Transformer Layer

In [None]:
class TokenGTTransformerLayer(nn.Module):

    def __init__(self, input_dim, num_heads, dropout=0.1):
        super(TokenGTTransformerLayer, self).__init__()
        self.self_attn = nn.MultiHeadAttention(
            embed_dim = input_dim,
            num_heads = num_heads,
            batch_first = True
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )


    def forward(self, x, attn_mask=None):
        # x: [B, T, D]
        residual = x
        x, _ = self.self_attn(x, x, x, attn_mask=attn_mask)  # self-attention over tokens
        x = self.norm1(x + residual)

        residual = x
        x = self.ff(x)
        x = self.norm2(x + residual)
        return x

In [None]:
class TokenGT(nn.Module):

    def __init__(self, input_dim):
        super(TokenGT, self).__init__()

        self.embedding_encoder = EmbeddingEncoder(input_dim)

# Getting to go

In [None]:
nn.CrossEntropyLoss()(
    data.distance_to_nearest_amyloid < 60,
    data
)

CrossEntropyLoss()

In [None]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch_geometric.utils import to_dense_batch

from tqdm import tqdm

class TokenGT(nn.Module):
    def __init__(self, input_dim, num_batches, num_nodes, d_model=128, nhead=8, num_layers=4):
        super(TokenGT, self).__init__()
        self.d_model = d_model

        # === Node Encoder ===
        self.node_encoder = EmbeddingEncoder(input_dim, num_batches)

        # === Edge Encoder ===
        self.edge_encoder = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

        # === Type Embedding: 0=node, 1=edge, 2=graph token ===
        self.type_embedding = nn.Embedding(3, d_model)

        # === Position (Node ID) Embedding ===
        self.position_embedding = nn.Embedding(num_nodes, d_model)

        # === [GRAPH] token ===
        self.graph_token = nn.Parameter(torch.randn(1, 1, d_model))

        # === Transformer Encoder ===
        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)

        # === Prediction Head ===
        self.output_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)  # Adjust output dimension as needed
        )

    def forward(self, data):
        B = data.x.shape[0]
        N = data.x.size(0)
        E = data.edge_index.size(1)

        # === Node Tokens ===
        node_embed = self.node_encoder(data)  # (N, d_model)

        # === Edge Tokens ===
        src = data.edge_index[0]
        dst = data.edge_index[1]
        edge_embed = self.edge_encoder(torch.cat([node_embed[src], node_embed[dst]], dim=1))  # (E, d_model)

        # === Type and Position Embedding ===
        node_type = self.type_embedding(torch.zeros(N, dtype=torch.long, device=node_embed.device))  # type 0
        edge_type = self.type_embedding(torch.ones(E, dtype=torch.long, device=node_embed.device))   # type 1
        node_pos = self.position_embedding(data.node_id if hasattr(data, "node_id") else torch.arange(N, device=node_embed.device))

        node_tokens = node_embed + node_type + node_pos
        edge_tokens = edge_embed + edge_type

        # === [GRAPH] Token ===
        graph_token = self.graph_token.expand(B, -1, -1)  # (B, 1, d_model)
        
        # === Batch Assembly ===
        tokens = torch.cat([node_tokens, edge_tokens], dim=0)  # (N + E, d_model)
        batch_vec = torch.cat([data.batch, data.batch[src]])   # Match token order to batch size

        # === Dense Batch for Transformer ===
        token_batch, mask = to_dense_batch(tokens, batch_vec)  # (B, T, d_model)

        # Prepend [GRAPH] token
        token_batch = torch.cat([graph_token, token_batch], dim=1)  # (B, T+1, d_model)
        mask = torch.cat([torch.ones((B, 1), dtype=torch.bool, device=mask.device), mask], dim=1)

        # === Transformer ===
        out = self.transformer(token_batch, src_key_padding_mask=~mask)  # (B, T+1, d_model)

        # === Predict from [GRAPH] token ===
        graph_repr = out[:, 0]  # (B, d_model)
        return self.output_head(graph_repr)


class MultiTaskOutputHeads(nn.Module):
    def __init__(self, in_dim, num_pathologies, num_nb_outputs=1, per_output_dispersion=False):
        super().__init__()

        self.num_nb_outputs = num_nb_outputs

        # --- Count prediction (Negative Binomial) ---
        self.count_mu_head = nn.Linear(in_dim, num_nb_outputs)  # predicts log_mu: (B, num_nb_outputs)

        if per_output_dispersion:
            self.log_dispersion = nn.Parameter(torch.zeros(num_nb_outputs))  # separate for each output
        else:
            self.log_dispersion = nn.Parameter(torch.tensor(0.0))  # shared

        # --- Hurdle Log-Normal for pathologies ---
        self.hurdle_zero_logits = nn.Linear(in_dim, num_pathologies)
        self.hurdle_log_mu = nn.Linear(in_dim, num_pathologies)
        self.hurdle_log_sigma = nn.Linear(in_dim, num_pathologies)

        # --- Binary classification: proximity to amyloid ---
        self.amyloid_head = nn.Linear(in_dim, 1)

    def forward(self, center_token):
        # --- NB count ---
        log_mu = self.count_mu_head(center_token)  # (B, num_nb_outputs)
        if self.log_dispersion.ndim == 0:
            log_disp = self.log_dispersion.expand_as(log_mu)  # scalar -> (B, K)
        else:
            log_disp = self.log_dispersion.unsqueeze(0).expand(center_token.size(0), -1)  # (B, K)

        # --- Hurdle Log-Normal ---
        zero_logits = self.hurdle_zero_logits(center_token)
        hln_log_mu = self.hurdle_log_mu(center_token)
        hln_log_sigma = self.hurdle_log_sigma(center_token)

        # --- Binary prediction ---
        amyloid_logit = self.amyloid_head(center_token)

        return {
            "nb": {
                "log_mu": log_mu,
                "log_dispersion": log_disp
            },
            "hurdle_lognorm": {
                "zero_logits": zero_logits,
                "log_mu": hln_log_mu,
                "log_sigma": hln_log_sigma
            },
            "amyloid": {
                "logits": amyloid_logit
            }
        }

from torch_geometric.utils import k_hop_subgraph

def get_subgraph(data, center_nodes, num_hops=1):
    subset, edge_index, mapping, edge_mask = k_hop_subgraph(
        node_idx=center_nodes,
        num_hops=num_hops,
        edge_index=data.edge_index,
        relabel_nodes=True
    )
    sub_data = data.subgraph(subset)
    sub_data.edge_index = edge_index
    sub_data.edge_mask = edge_mask
    sub_data.node_mapping = mapping
    return sub_data

token_gt = TokenGT(data.x.shape[1], 
        adata.obs["folder"].nunique(), 
        data.x.shape[0],
        d_model=96, 
        nhead=8, 
        num_layers=4)


  1%|▏         | 966/69552 [00:04<05:31, 206.60it/s]


KeyboardInterrupt: 

In [95]:

mtoh = MultiTaskOutputHeads(96, 3, data.x.shape[1])

batch_size = 1
device = torch.device("cpu")

example = []
for center in tqdm(torch.randperm(data.num_nodes)):
    sub = get_subgraph(data, center_nodes=[center])
    out = model(sub.to(device))[sub.node_mapping[0]]
    mtoh_out = mtoh(out)
    nb_likelihood = NegativeBinomial(
        logits = mtoh_out['nb']['log_mu'],
        total_count = mtoh_out['nb']['log_dispersion'].exp()
    ).log_prob(data.counts[center]).sum()

    pathology_likelihood = hurdle_normal_log_prob(
        zero_logits = mtoh_out["hurdle_lognorm"]["zero_logits"],
        log_mu = mtoh_out["hurdle_lognorm"]["log_mu"],
        log_sigma = mtoh_out["hurdle_lognorm"]["log_sigma"],
        data = torch.stack([
            data.lipid_droplet,
            data.plin2,
            data.oro,
        ]).T[center]
    )

    loss = nb_likelihood + pathology_likelihood

  0%|          | 4/69552 [00:00<22:30, 51.48it/s]


ValueError: Expected parameter scale (Tensor of shape (3,)) of distribution Normal(loc: torch.Size([3]), scale: torch.Size([3])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([2.8416e-06, 0.0000e+00, 2.3265e-20], grad_fn=<ExpBackward0>)

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

# --- Assume you already have these ---
# model: your main GNN or encoder
# mtoh: MultiTaskOutputHeads(96, 3, data.x.shape[1])
# data: graph with .counts, .lipid_droplet, etc.
# get_subgraph: function for extracting neighborhood
# hurdle_normal_log_prob: your custom function
# NegativeBinomial: a distribution class (e.g., from torch.distributions or custom)

# --- Training setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
mtoh = mtoh.to(device)

optimizer = Adam(list(model.parameters()) + list(mtoh.parameters()), lr=1e-3)
num_epochs = 10

# --- Training loop ---
for epoch in range(num_epochs):
    model.train()
    mtoh.train()

    total_loss = 0.0
    num_nodes = data.num_nodes

    for center in tqdm(torch.randperm(num_nodes), desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()

        # 1. Subgraph extraction
        sub = get_subgraph(data, center_nodes=[center])
        sub = sub.to(device)

        # 2. Forward pass through model
        out = model(sub)[sub.node_mapping[0]]  # center node embedding
        mtoh_out = mtoh(out)

        # 3. Count loss (Negative Binomial)
        nb_likelihood = NegativeBinomial(
            logits=mtoh_out['nb']['log_mu'],
            total_count=mtoh_out['nb']['log_dispersion'].exp()
        ).log_prob(data.counts[center].to(device)).sum()

        # 4. Pathology loss (Hurdle Normal)
        pathology_data = torch.stack([
            data.lipid_droplet,
            data.plin2,
            data.oro,
        ], dim=1)[center].to(device)

        pathology_likelihood = hurdle_normal_log_prob(
            zero_logits=mtoh_out["hurdle_lognorm"]["zero_logits"],
            log_mu=mtoh_out["hurdle_lognorm"]["log_mu"],
            log_sigma=mtoh_out["hurdle_lognorm"]["log_sigma"],
            data=pathology_data
        ).sum()

        # 5. Total negative log-likelihood (we *minimize* -logprob)
        loss = -(nb_likelihood + pathology_likelihood)

        # 6. Backward + optimizer step
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_nodes
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")


Epoch 1:   0%|          | 1/69552 [00:00<7:14:18,  2.67it/s]


ValueError: Expected parameter logits (Tensor of shape (366,)) of distribution NegativeBinomial(total_count: torch.Size([366]), logits: torch.Size([366])) to satisfy the constraint Real(), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<ViewBackward0>)

In [56]:
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions import NegativeBinomial

def hurdle_normal_log_prob(zero_logits, log_mu, log_sigma, data):
    """
    Compute log-likelihood under a Hurdle Normal model.

    Args:
        zero_logits (Tensor): Logits for P(data > 0)
        log_mu (Tensor): Log mean of the normal distribution
        log_sigma (Tensor): Log std dev of the normal distribution
        data (Tensor): Observed values (same shape as other inputs)

    Returns:
        Tensor: log-likelihoods (same shape as input tensors)
    """
    sigma = log_sigma.exp()
    mu = log_mu.exp()
    is_zero = (data == 0)
    
    # P(data = 0) = sigmoid(-zero_logits)
    log_prob_zero = F.logsigmoid(-zero_logits)

    # P(data > 0) = sigmoid(zero_logits) * Normal.log_prob(data)
    normal = Normal(loc=mu, scale=sigma)
    log_prob_nonzero = F.logsigmoid(zero_logits) + normal.log_prob(data)

    return torch.where(is_zero, log_prob_zero, log_prob_nonzero)
