This notebook uses graph neural networks to predict cell types (subclass, supertype).

We consider cells as nodes and their gene expression as initial node features.

Graphs are constructed in 3 different ways: self edges, edges from a k-nearest-neighbors graph, and peptidergic communication networks (directed multilayer graphs).

Cell type prediction is challenging when few genes are available.

We find that gene expression in neighboring cells can improve classification in this regime

In [1]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data
from torch_geometric.utils import add_remaining_self_loops, from_scipy_sparse_matrix

from cci.gnn import GATnet, GCNnet, train_gnn, val_gnn, test_gnn
from scipy.sparse import csr_array
from sklearn.model_selection import StratifiedKFold
import seaborn as sns
from cci.utils import get_adata, get_new_gene_subsets
from collections import Counter

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", font_scale=0.5, rc=custom_params)
%config InlineBackend.figure_format="retina"

adata = get_adata("VISp")



In [2]:
# Create a dataframe from a slice of VISp with the most cells
display(adata.obs[["brain_section_label", "z_section"]].sort_values("z_section").value_counts().to_frame().head(4))
one_sec = adata[adata.obs["z_section"] == 5.0, :]
df = one_sec.obs.copy()

Unnamed: 0_level_0,Unnamed: 1_level_0,count
brain_section_label,z_section,Unnamed: 2_level_1
C57BL6J-638850.30,5.0,9242
C57BL6J-638850.29,4.8,8713
C57BL6J-638850.28,4.6,7780
C57BL6J-638850.31,5.4,6939


In [4]:
# Graph construction hyperparameters
d = 40 / 1000  # (in mm)
L_thr = 0.0
R_thr = 0.0
lr_gene_pairs = [["Tac2", "Tacr3"], ["Penk", "Oprd1"], ["Pdyn", "Oprd1"], ["Pdyn", "Oprk1"], ["Grp", "Grpr"]]
n_layers = len(lr_gene_pairs)
num_nodes = df.shape[0]
cell_type = "supertype"

In [6]:
edge_index_list = [None] * n_layers
df["participant"] = np.zeros(num_nodes, dtype=bool)

# Edgelist from multi-layer graphs
for i in range(n_layers):
    ligand, receptor = lr_gene_pairs[i]
    df["L"] = one_sec[:, one_sec.var["gene_symbol"] == ligand].X.toarray().ravel()
    df["R"] = one_sec[:, one_sec.var["gene_symbol"] == receptor].X.toarray().ravel()

    df["L"] = (df["L"] > L_thr).astype(bool)
    df["R"] = (df["R"] > R_thr).astype(bool)

    A = df["L"].values.reshape(-1, 1) @ df["R"].values.reshape(1, -1)
    Dx = (df["x_reconstructed"].values.reshape(-1, 1) - df["x_reconstructed"].values.reshape(1, -1)) ** 2
    Dy = (df["y_reconstructed"].values.reshape(-1, 1) - df["y_reconstructed"].values.reshape(1, -1)) ** 2
    D = np.sqrt(Dx + Dy)
    del Dx, Dy

    # cells are connected only if within distance d
    A[D > d] = 0

    # participant should have more than one connection
    df["participant"] = df["participant"] + (A.sum(axis=1) > 1)

    # construct directed graph from adjacency matrix
    edge_index_list[i], _ = from_scipy_sparse_matrix(csr_array(A))


# Squash the multi-layer graph into a single layer graph  
edge_index_squashed = set(edge_index_list[0].T)
for i in range(1, len(edge_index_list)):
    edge_index_squashed = set(edge_index_list[i].T).union(edge_index_squashed)
edge_index_squashed = list(edge_index_squashed)
edge_index_list.append(edge_index_squashed)

In [7]:
# Get stratified splits based on cell type label
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Keep cells who are participants in the multilayer graph (more than 1 connection)
keep = df["participant"].values
train_idx, test_idx = next(skf.split(np.arange(df[keep].shape[0]), df[cell_type][keep].values))

train_mask = np.zeros(df.shape[0], dtype=bool)
train_mask[train_idx] = True
train_mask = torch.tensor(train_mask, dtype=torch.bool)

test_mask = np.zeros(df.shape[0], dtype=bool)
test_mask[test_idx] = True
test_mask = torch.tensor(test_mask, dtype=torch.bool)

labels = torch.tensor(df[cell_type].cat.codes.values, dtype=torch.long)

multilayer_data = Data(x=None,
            edge_index_list=edge_index_list,
            y=labels,
            train_mask=train_mask,
            test_mask=test_mask)



In [8]:
# Edgelist from k nearest neighbors 
k = 10
X = df[[ "x_reconstructed", "y_reconstructed"]]

A = kneighbors_graph(X, n_neighbors=k, mode="distance", include_self=False)

edgelist_knn = torch.tensor(list(zip(A.tocoo().row, A.tocoo().col)))
edgelist_knn = add_remaining_self_loops(edgelist_knn.T)[0]

In [9]:
# Given a knn adjacency matrix, find how many cells are connected to other cells of the same type
celltype_match_cnt = []
nodes_to_include = []
nodes_excluded = []
for i in range(num_nodes):
    ind = np.argwhere(A[i].toarray().flatten() > 0).flatten()
    neighbors = df.iloc[ind, df.columns.get_loc(cell_type)]
    neighbors = Counter(list(neighbors))
    origin = df.iloc[i, df.columns.get_loc(cell_type)]
    origin_neighbor_match_cnt = neighbors.get(origin, 0)
    celltype_match_cnt.append(origin_neighbor_match_cnt)

pd.DataFrame(celltype_match_cnt).value_counts()

0 
0     2685
1     1343
2      912
3      800
4      782
5      743
6      683
7      630
8      426
9      198
10      40
Name: count, dtype: int64

In [10]:
# Edgelist consisting of self-edges
edgelist_self = torch.tensor([[i, i] for i in range(num_nodes)])
edgelist_self = edgelist_self.T

# Edgelist from squashed mult-layer graph
edgelist_squashed = torch.stack(edge_index_list[-1], dim=0)
edgelist_squashed = add_remaining_self_loops(edgelist_squashed.T)[0]

In [63]:
# Generate new gene subsets by passing a list that specifies the size of each subset (with ratios)
gene_subsets_new = get_new_gene_subsets(adata, [0.01, 0.01, 0.01])

In [43]:
# Subsets of genes used to generate results shown below
gene_subsets = [
[['Slc1a3', 'Sp8', 'Prom1', 'Gja1', 'Mkx'], 
['Slc1a3', 'Sp8', 'Prom1', 'Gja1', 'Mkx', 'Blank-2', 'Glis3', 'Acta2', 'Ramp3', 'Sla'], 
['Slc1a3', 'Sp8', 'Prom1', 'Gja1', 'Mkx', 'Blank-2', 'Glis3', 'Acta2', 'Ramp3', 'Sla', 'St3gal1', 'Gpc3', 'Ptger3', 'Kcnj5', 'Medag']]]

In [65]:
# Specify which subset to use as features
gene_subset_idx = 0
one_sec_gene_subset = one_sec[:, gene_subsets_new[gene_subset_idx]]
one_sec_x = torch.tensor(one_sec_gene_subset.X.todense(), dtype=torch.float)

# Create PYG data objects from the 3 different edgelists

data_knn = Data(x=one_sec_x,
            edge_index=edgelist_knn,
            y=labels,
            train_mask=train_mask,
            test_mask=test_mask)

data_self = Data(x=one_sec_x,
            edge_index=edgelist_self,
            y=labels,
            train_mask=train_mask,
            test_mask=test_mask)

data_squashed = Data(x=one_sec_x,
                edge_index=edgelist_squashed,
                y=labels,
                train_mask=train_mask,
                test_mask=test_mask)

In [66]:
# Create GATv2 models
model_knn = GATnet(hidden_channels=32, num_features = data_knn.x.shape[1] , num_classes = torch.unique(data_knn.y).size()[0] )
model_self = GATnet(hidden_channels=32, num_features = data_self.x.shape[1] , num_classes = torch.unique(data_self.y).size()[0] )
model_squashed = GATnet(hidden_channels=32, num_features = data_squashed.x.shape[1] , num_classes = torch.unique(data_squashed.y).size()[0] )

In [67]:
#Train GNN to predict celltyped with self edges
def self_edges(model_self, data_self):
    optimizer_self = torch.optim.Adam(model_self.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 101):
        loss = train_gnn(model_self, optimizer_self, data_self, criterion)
        val_acc = test_gnn(model_self, data_self)
        # print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val acc: {val_acc:.4f}")

    model_self.eval()
    val_acc = test_gnn(model_self, data_self)
    print(f'Self edge accuracy: {val_acc:.4f}')
    return val_acc
self_edges(model_self, data_self)    

Self edge accuracy: 0.1710


In [68]:
#Train GNN to predict celltyped with KNN edges
def knn_edges(model_knn, data_knn):
    optimizer_knn = torch.optim.Adam(model_knn.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 101):
        loss = train_gnn(model_knn, optimizer_knn, data_knn, criterion)
        val_acc = test_gnn(model_knn, data_knn)
        # print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val acc: {val_acc:.4f}")

    model_knn.eval()
    val_acc = test_gnn(model_knn, data_knn)
    print(f'KNN accuracy: {val_acc:.4f}')
    return val_acc
knn_edges(model_knn, data_knn)

KNN accuracy: 0.1898


In [69]:
##Train GNN to predict celltypes with a squashed multilayer graph
def multilayer_edges(model_lr, data_lr):
    optimizer_lr = torch.optim.Adam(model_lr.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 101):
        loss = train_gnn(model_lr, optimizer_lr, data_lr, criterion)
        val_acc = test_gnn(model_lr, data_lr)
        # print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val acc: {val_acc:.4f}")

    model_lr.eval()
    val_acc = test_gnn(model_lr, data_lr)
    print(f'Multilayer graph accuracy: {val_acc:.4f}')
    return val_acc
multilayer_edges(model_squashed, data_squashed)    

Multilayer graph accuracy: 0.1804
