In [20]:
from torch_geometric.nn import GATv2Conv
import torch.nn as nn
import numpy as np

In [22]:
import anndata as ad
from scipy.sparse import issparse
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import numpy as np
import scipy.sparse as sp


class AnnDataGraphDataset(Dataset):
    """
    Tabular AnnData dataset.

    Args:
        path (str): Path to the AnnData file.
        keep_genes (list): List of genes to keep.
        keep_cells (list): List of cells to keep.
        spatial_coords (list): Anndata.obs columns for spatial coordinates (for regression)
        cell_type (str): Anndata.obs column for cell types.
        max_order (int): Maximum order of neighbors to consider.
        d_threshold (float): Distance threshold (in mm) for considering neighbors.
    """

    def __init__(
        self,
        path,
        keep_genes=None,
        keep_cells=None,
        spatial_coords=["x_ccf", "y_ccf", "z_ccf"],
        cell_type="supertype",
        max_order=2,
        d_threshold=1000,
    ):
        super().__init__()
        self.path = path
        adata = ad.read_h5ad(self.path)
        assert "connectivities" in adata.obsp.keys(), "Spatial connectivities not found. Run `sc.pp.neighbors` first."
        assert "distances" in adata.obsp.keys(), "Spatial distances not found. Run `sc.pp.neighbors` first."

        # filter genes
        if keep_genes is not None:
            adata = adata[:, keep_genes].copy()
        else:
            keep_genes = get_non_blank_genes(adata)
            adata = adata[:, keep_genes].copy()

        # filter cells
        if keep_cells is not None:
            adata = adata[keep_cells, :].copy()

        self.adata = adata
        self.max_order = max_order
        self.d_threshold = d_threshold

        # create binary adjacency matrix without self-loops
        adj = self.adata.obsp["connectivities"].copy()
        adj = adj.astype(bool).astype(int)
        adj[self.adata.obsp["distances"] > self.d_threshold] = 0
        adj.setdiag(0)

        # create adjacency matrices up to max_order
        self.adj_matrices = {}
        self.adj_matrices[1] = adj.copy()
        if self.max_order > 1:
            for i in range(2, self.max_order + 1):
                self.adj_matrices[i] = adj.dot(self.adj_matrices[i - 1])

        self.spatial_coords = spatial_coords
        self.cell_type = cell_type
        self.cell_type_list = adata.obs[cell_type].cat.categories.tolist()
        self.cell_type_labelencoder = LabelEncoder()
        self.cell_type_labelencoder.fit(self.cell_type_list)
        self.data_issparse = issparse(adata.X)

    def get_neighbors(self, idx):
        nhood_idx = []
        for i in range(1, self.max_order + 1):
            nhood_idx.append(np.where(self.adj_matrices[i][idx, :].toarray().flatten())[0])
        nhood_idx = np.concatenate(nhood_idx, axis=0)
        nhood_idx = np.unique(np.concatenate([nhood_idx, [idx]]))
        return nhood_idx

    def __len__(self):
        return self.adata.shape[0]

    def __getitem__(self, idx):
        # get all neighbors
        nhood_idx = self.get_neighbors(idx)
        local_adj = self.adj_matrices[1][np.ix_(nhood_idx, nhood_idx)]
        edgelist = np.array(local_adj.nonzero()).T

        gene_exp = self.adata.X[nhood_idx, :]
        if self.data_issparse:
            gene_exp = gene_exp.toarray().astype(np.float32)
        xyz = self.adata.obs.iloc[nhood_idx][self.spatial_coords].values.astype(np.float32)
        celltype = self.cell_type_labelencoder.transform(self.adata.obs.iloc[nhood_idx][self.cell_type])
        return gene_exp, edgelist, celltype


def get_non_blank_genes(adata):
    keep_genes = adata.var[~adata.var.index.str.startswith("Blank")].index
    return keep_genes

In [23]:
from datetime import datetime
from pathlib import Path


def get_paths(verbose: bool = False) -> dict:
    """
    Get custom paths from config.toml that is in the root directory.
    """

    # get path of this file
    root_path = Path("../")
    config_path = root_path / "config.toml"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = toml.load(config_path)
    config["package_root"] = root_path
    if verbose:
        print(config)
    return config


def get_datetime(expname: str = ""):
    datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    if expname is None:
        expname = datetime_str
    else:
        expname = f"{datetime_str}_{expname}"
    return expname


def get_adata(path: str):
    adata = ad.read_h5ad(path)
    adata.obsm["ccf"] = np.concatenate(
        (
            np.expand_dims(np.array(adata.obs["x_ccf"]), axis=1),
            np.expand_dims(np.array(adata.obs["y_ccf"]), axis=1),
            np.expand_dims(np.array(adata.obs["z_ccf"]), axis=1),
        ),
        axis=1,
    )
    adata.var.set_index("gene_symbol", inplace=True, drop=False)

    return adata

In [25]:
ad = get_adata("../data/VISp_nhood.h5ad")

In [27]:
ad.obs["supertype"]

0               0162 OB Dopa-Gaba_1
1                0135 HPF CR Glut_1
2                0135 HPF CR Glut_1
3                0135 HPF CR Glut_1
4        1149 CBX MLI Megf11 Gaba_1
                    ...            
61879                1193 Endo NN_1
61880                1193 Endo NN_1
61881                1193 Endo NN_1
61882                1193 Endo NN_1
61883                1193 Endo NN_1
Name: supertype, Length: 61884, dtype: category
Categories (126, object): ['0001 CLA-EPd-CTX Car3 Glut_1', '0002 CLA-EPd-CTX Car3 Glut_2', '0003 IT EP-CLA Glut_1', '0004 IT EP-CLA Glut_2', ..., '1196 Monocytes NN_1', '1197 DC NN_1', '1198 B cells NN_1', '1201 T cells NN_4']

In [15]:
import toml


from pathlib import Path

import anndata as ad
import lightning as L
import torch
from torch.utils.data import ConcatDataset, DataLoader, random_split


class AnnDataGraphDataModule(L.LightningDataModule):
    def __init__(self, data_dir: None, file_names: list[str] = ["VISp_nhood.h5ad"], batch_size: int = 1):
        super().__init__()
        if data_dir is None:
            data_dir = get_paths()["data_root"]
            # data_dir = "../data/"
        self.adata_paths = [str(data_dir) + file_name for file_name in file_names]
        for adata_path in self.adata_paths:
            if not Path(adata_path).exists():
                raise FileNotFoundError(f"File not found: {adata_path}")

        self.batch_size = batch_size

    def setup(self, stage: str):
        self.adatas = []
        for adata_path in self.adata_paths:
            self.adatas.append(AnnDataGraphDataset(adata_path))
        self.data_full = ConcatDataset(self.adatas)
        self.data_train, self.data_test = random_split(
            self.data_full, [0.8, 0.2], generator=torch.Generator().manual_seed(0)
        )

        if stage == "fit":
            self.data_train, self.data_val = random_split(
                self.data_train, [0.8, 0.2], generator=torch.Generator().manual_seed(1)
            )

        if stage == "test":  # Note: this is not the test set. Just a quick way to check the model through lightining.
            _, self.data_test = random_split(self.data_full, [0.9, 0.1], generator=torch.Generator().manual_seed(0))

        if stage == "predict":
            self.data_predict = self.data_full

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True, pin_memory=True, num_workers=16)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=16)

    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=16)

    def predict_dataloader(self):
        return DataLoader(self.data_predict, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=16)

In [16]:
class GAT3(torch.nn.Module):
    def __init__(self, hidden_channels, num_features, num_classes):
        super().__init__()
        torch.manual_seed(1234567)
        self.hidden_channels = hidden_channels
        self.num_features = num_features
        self.num_classes = num_classes
        self.conv1 = GATv2Conv(self.num_features, self.hidden_channels, heads=8, concat=False)
        self.conv2 = GATv2Conv(self.hidden_channels, self.num_classes, heads=8, concat=False)
        self.lin1 = nn.Linear(self.num_features, self.num_classes)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x, edge_index):
        residual1 = self.lin1(x)
        out = self.conv1(x, edge_index)
        out = out.relu()
        out = self.dropout(out)
        out = self.conv2(out, edge_index)
        out = out + residual1

        return out

In [17]:
import lightning as L
import torch
import torch.nn as nn
from torchmetrics import MeanSquaredError
from torchmetrics.classification import MulticlassAccuracy
from torch_geometric.nn import GCNConv


class LitGNNv0(L.LightningModule):
    def __init__(self, input_dim, hidden_dim, n_labels, weight_mse=1.0, weight_ce=1.0):
        super(LitGNNv0, self).__init__()

        self.weight_mse = weight_mse
        self.weight_ce = weight_ce

        self.GAT = GAT3(hidden_channels=32, num_features=input_dim, num_classes=n_labels)

        # losses
        self.loss_ce = nn.CrossEntropyLoss()

        # metrics
        self.metric_overall_acc = MulticlassAccuracy(
            num_classes=n_labels, top_k=1, average="weighted", multidim_average="global"
        )
        self.metric_macro_acc = MulticlassAccuracy(
            num_classes=n_labels, top_k=1, average="macro", multidim_average="global"
        )
        self.metric_multiclass_acc = MulticlassAccuracy(
            num_classes=n_labels, top_k=1, average=None, multidim_average="global"
        )

    def forward(self, x, edge_index):
        celltype = self.GAT(x, edge_index)
        return celltype

    def training_step(self, batch, batch_idx):
        # for GNN, batch size should be 1, and there isn't a batch dimension.
        gene_exp, edgelist, celltype = batch
        gene_exp = gene_exp.squeeze(dim=0)
        edgelist = edgelist.squeeze(dim=0).T
        celltype = celltype.squeeze(dim=0)

        celltype_pred = self.forward(gene_exp, edgelist)

        # Calculate losses
        total_loss = self.loss_ce(celltype_pred, celltype.squeeze())

        # Log losses
        self.log("train_total_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # Calculate metrics
        train_overall_acc = self.metric_overall_acc(preds=celltype_pred, target=celltype.reshape(-1))
        train_macro_acc = self.metric_macro_acc(preds=celltype_pred, target=celltype.reshape(-1))

        # Log metrics
        self.log("train_overall_acc", train_overall_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_macro_acc", train_macro_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return total_loss

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        # for GNN, batch size should be 1, and there isn't a batch dimension.
        gene_exp, edgelist, celltype = batch
        gene_exp = gene_exp.squeeze(dim=0)
        edgelist = edgelist.squeeze(dim=0).T
        celltype = celltype.squeeze(dim=0)

        celltype_pred = self.forward(gene_exp, edgelist)

        # Calculate metrics
        val_overall_acc = self.metric_overall_acc(preds=celltype_pred, target=celltype.reshape(-1))
        val_macro_acc = self.metric_macro_acc(preds=celltype_pred, target=celltype.reshape(-1))
        val_metric_multiclass_acc = self.metric_multiclass_acc(preds=celltype_pred, target=celltype.reshape(-1))

        self.log("val_overall_acc", val_overall_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_macro_acc", val_macro_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def on_validation_epoch_end(self):
        pass

    def on_test_epoch_end(self):
        pass

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        return optimizer

In [18]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger


# data parameters, we'll eventually obtain this from the data.
n_genes = 500
n_labels = 94

# paths
paths = get_paths()
expname = get_datetime(expname="VISp_nhood_GNN")
log_path = paths["runs_root"] + f"logs/{expname}"
checkpoint_path = paths["runs_root"] + f"checkpoints/{expname}"

# helpers
tb_logger = TensorBoardLogger(save_dir=log_path)
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_path, monitor="val_rmse_overall", filename="{epoch}-{val_rmse_overall:.2f}"
)

# data, model and fitting
datamodule = AnnDataGraphDataModule(data_dir=paths["data_root"], file_names=["VISp_nhood.h5ad"], batch_size=1)
model = LitGNNv0(input_dim=n_genes, hidden_dim=32, n_labels=n_labels, weight_mse=1.0, weight_ce=0.1)
trainer = L.Trainer(
    limit_train_batches=1000, limit_val_batches=100, max_epochs=5, logger=tb_logger, callbacks=[checkpoint_callback]
)
trainer.fit(model=model, datamodule=datamodule)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
