# Reading the Data

In [None]:
import scanpy as sc
adata = sc.read_h5ad("../data/05-27.h5ad")



In [None]:
# Convert sparse matrix to dense if necessary
adata.X = adata.X.todense()
adata.layers["counts"] = adata.X.copy()

# Normalize and log-transform the data
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  return fn(*args_all, **kw)


In [100]:
adata.obs["batch"] = 0

# Spatial Dataset Generator

In [101]:
import torch
import torch_geometric
from torch_geometric.data import Data, Dataset

import numpy as np
from scipy.spatial import cKDTree

class GraphDatasetGenerator:
    def __init__(self, adata):
        self.adata = adata

    def generate_graph_data(self):
        # Edge index
        edge_index = self.create_edge_index()

        data = Data(
            # Expression
            x=torch.tensor(self.adata.X).float(), 
            counts = torch.tensor(self.adata.layers["counts"]).float(),

            # Batch information
            batch = torch.tensor(self.adata.obs["batch"]),

            # Spatial 
            edge_index=edge_index,

            # Pathology
            plin2 = torch.tensor(self.adata.obs["plin2_area"]).float(),
            oro = torch.tensor(self.adata.obs["oil_red_o_area"]).float(),
            lipid_droplet = torch.tensor(self.adata.obs["lipid_droplet_area"]).float(),

            # Distance to nearest amyloid
            distance_to_nearest_amyloid = torch.tensor(self.adata.obs["lipid_droplet_area"]).float()
        )

        return data

    def create_edge_index(self):
        tree = cKDTree(
            self.adata.obs[['x_centroid', 'y_centroid']].values
        )

        _, neighbors = tree.query(
            self.adata.obs[['x_centroid', 'y_centroid']].values, 
            k=31
        )

        rows = np.repeat(
            np.arange(
                len(adata.obs[['x_centroid', 'y_centroid']].values)
            ), 
            30
        )
        cols = neighbors[:, 1:].reshape(-1)
        return torch.tensor([rows, cols], dtype=torch.long)

data = GraphDatasetGenerator(adata).generate_graph_data()

  batch = torch.tensor(self.adata.obs["batch"]),
  plin2 = torch.tensor(self.adata.obs["plin2_area"]).float(),
  oro = torch.tensor(self.adata.obs["oil_red_o_area"]).float(),
  lipid_droplet = torch.tensor(self.adata.obs["lipid_droplet_area"]).float(),
  distance_to_nearest_amyloid = torch.tensor(self.adata.obs["lipid_droplet_area"]).float()


# Constructing our model

In [None]:
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, input_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class EmbeddingEncoder(nn.Module):
    def __init__(self, input_dim, num_batches):
        super(EmbeddingEncoder, self).__init__()

        # Expression encoder
        self.expression_encoder = Encoder(input_dim)

        # Pathology heads
        def mlp():
            return nn.Sequential(
                nn.Linear(1, 64),
                nn.ReLU(),
                nn.Linear(64, 8)
            )

        self.pathology_head = nn.ModuleDict({
            "oil_red_o": mlp(),
            "plin2": mlp(),
            "lipid_droplet": mlp(),
            "distance_to_amyloid": mlp()
        })

    def forward(self, data):
        # Encode expression
        expression = self.expression_encoder(data.x)

        # Encode pathology
        pathology = torch.cat([
            self.pathology_head["oil_red_o"](data.oro.unsqueeze(1)),
            self.pathology_head["plin2"](data.plin2.unsqueeze(1)),
            self.pathology_head["lipid_droplet"](data.lipid_droplet.unsqueeze(1)),
            self.pathology_head["distance_to_amyloid"](data.distance_to_nearest_amyloid.unsqueeze(1)),
        ], dim=1)

        # Concatenate all embeddings
        full_embedding = torch.cat([expression, pathology], dim=1)
        return full_embedding

model = EmbeddingEncoder(
    input_dim=data.x.shape[1], 
    num_batches=adata.obs["batch"].nunique()
)
output = model(data)

torch.Size([69552, 96])

# Getting to go

In [None]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch_geometric.utils import to_dense_batch

class TokenGT(nn.Module):
    def __init__(self, input_dim, num_batches, num_nodes, d_model=128, nhead=8, num_layers=4):
        super(TokenGT, self).__init__()
        self.d_model = d_model

        # === Node Encoder ===
        self.node_encoder = EmbeddingEncoder(input_dim, num_batches)

        # === Edge Encoder ===
        self.edge_encoder = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

        # === Type Embedding: 0=node, 1=edge, 2=graph token ===
        self.type_embedding = nn.Embedding(3, d_model)

        # === Position (Node ID) Embedding ===
        self.position_embedding = nn.Embedding(num_nodes, d_model)

        # === [GRAPH] token ===
        self.graph_token = nn.Parameter(torch.randn(1, 1, d_model))

        # === Transformer Encoder ===
        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)

        # === Prediction Head ===
        self.output_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)  # Adjust output dimension as needed
        )

    def forward(self, data):
        B = data.x.shape[0]
        N = data.x.size(0)
        E = data.edge_index.size(1)

        # === Node Tokens ===
        node_embed = self.node_encoder(data)  # (N, d_model)

        # === Edge Tokens ===
        src = data.edge_index[0]
        dst = data.edge_index[1]
        edge_embed = self.edge_encoder(torch.cat([node_embed[src], node_embed[dst]], dim=1))  # (E, d_model)

        # === Type and Position Embedding ===
        node_type = self.type_embedding(torch.zeros(N, dtype=torch.long, device=node_embed.device))  # type 0
        edge_type = self.type_embedding(torch.ones(E, dtype=torch.long, device=node_embed.device))   # type 1
        node_pos = self.position_embedding(data.node_id if hasattr(data, "node_id") else torch.arange(N, device=node_embed.device))

        node_tokens = node_embed + node_type + node_pos
        edge_tokens = edge_embed + edge_type

        # === [GRAPH] Token ===
        graph_token = self.graph_token.expand(B, -1, -1)  # (B, 1, d_model)
        
        # === Batch Assembly ===
        tokens = torch.cat([node_tokens, edge_tokens], dim=0)  # (N + E, d_model)
        batch_vec = torch.cat([data.batch, data.batch[src]])   # Match token order to batch size

        # === Dense Batch for Transformer ===
        token_batch, mask = to_dense_batch(tokens, batch_vec)  # (B, T, d_model)

        # Prepend [GRAPH] token
        token_batch = torch.cat([graph_token, token_batch], dim=1)  # (B, T+1, d_model)
        mask = torch.cat([torch.ones((B, 1), dtype=torch.bool, device=mask.device), mask], dim=1)

        # === Transformer ===
        out = self.transformer(token_batch, src_key_padding_mask=~mask)  # (B, T+1, d_model)

        # === Predict from [GRAPH] token ===
        graph_repr = out[:, 0]  # (B, d_model)
        return self.output_head(graph_repr)


class MultiTaskOutputHeads(nn.Module):
    def __init__(self, in_dim, num_pathologies, num_nb_outputs=1, per_output_dispersion=False):
        super().__init__()

        self.num_nb_outputs = num_nb_outputs

        # --- Count prediction (Negative Binomial) ---
        self.count_mu_head = nn.Linear(in_dim, num_nb_outputs)  # predicts log_mu: (B, num_nb_outputs)

        if per_output_dispersion:
            self.log_dispersion = nn.Parameter(torch.zeros(num_nb_outputs))  # separate for each output
        else:
            self.log_dispersion = nn.Parameter(torch.tensor(0.0))  # shared

        # --- Hurdle Log-Normal for pathologies ---
        self.hurdle_zero_logits = nn.Linear(in_dim, num_pathologies)
        self.hurdle_log_mu = nn.Linear(in_dim, num_pathologies)
        self.hurdle_log_sigma = nn.Linear(in_dim, num_pathologies)

        # --- Binary classification: proximity to amyloid ---
        self.amyloid_head = nn.Linear(in_dim, 1)

    def forward(self, center_token):
        # --- NB count ---
        log_mu = self.count_mu_head(center_token)  # (B, num_nb_outputs)
        if self.log_dispersion.ndim == 0:
            log_disp = self.log_dispersion.expand_as(log_mu)  # scalar -> (B, K)
        else:
            log_disp = self.log_dispersion.unsqueeze(0).expand(center_token.size(0), -1)  # (B, K)

        # --- Hurdle Log-Normal ---
        zero_logits = self.hurdle_zero_logits(center_token)
        hln_log_mu = self.hurdle_log_mu(center_token)
        hln_log_sigma = self.hurdle_log_sigma(center_token)

        # --- Binary prediction ---
        amyloid_logit = self.amyloid_head(center_token)

        return {
            "nb": {
                "log_mu": log_mu,
                "log_dispersion": log_disp
            },
            "hurdle_lognorm": {
                "zero_logits": zero_logits,
                "log_mu": hln_log_mu,
                "log_sigma": hln_log_sigma
            },
            "amyloid": {
                "logits": amyloid_logit
            }
        }

from torch_geometric.utils import k_hop_subgraph

def get_subgraph(data, center_nodes, num_hops=1):
    subset, edge_index, mapping, edge_mask = k_hop_subgraph(
        node_idx=center_nodes,
        num_hops=num_hops,
        edge_index=data.edge_index,
        relabel_nodes=True
    )
    sub_data = data.subgraph(subset)
    sub_data.edge_index = edge_index
    sub_data.edge_mask = edge_mask
    sub_data.node_mapping = mapping
    return sub_data

token_gt = TokenGT(data.x.shape[1], 
        adata.obs["batch"].nunique(), 
        data.x.shape[0],
        d_model=96, 
        nhead=8, 
        num_layers=4)


batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

example = []
for center in tqdm(torch.randperm(data.num_nodes)):
    sub = get_subgraph(data, center_nodes=[center])
    out = model(sub.to(device))[sub.node_mapping[0]]
    example.append(mtoh(out))

  1%|          | 460/69552 [00:03<08:50, 130.29it/s]


KeyboardInterrupt: 