Edges based on gene expression. Same idea as LR edges. This isn't the best idea because we want to select the genes during training. However, if we find that some gene edges are very informative when used in a GNN, biologists could always select those genes as part of the experiment.

In [9]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data
from torch_geometric.utils import add_remaining_self_loops

import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GCNConv
import torch.nn as nn
from graphFeatureSelect.utils import get_adata

from sklearn.model_selection import StratifiedKFold
import seaborn as sns
from torch_geometric.utils import add_remaining_self_loops, from_scipy_sparse_matrix
from scipy.sparse import csr_array


custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", font_scale=0.5, rc=custom_params)
%config InlineBackend.figure_format="retina"

In [3]:
def train_gnn_concrete(model, optimizer, data, criterion, temp):
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index, temp)  # Perform a single forward pass.
    loss = criterion(
        out[data.train_mask], data.y[data.train_mask]
    )  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss


def val_gnn_concrete(model, data):
    model.eval()
    temp = 0.01
    out = model(data.x, data.edge_index, temp)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    val_correct = pred[data.val_mask] == data.y[data.val_mask]  # Check against ground-truth labels.
    val_acc = int(val_correct.sum()) / int(data.val_mask.sum())  # Derive ratio of correct predictions.
    return val_acc


def test_gnn_concrete(model, data):
    model.eval()
    temp = 0.01
    out = model(data.x, data.edge_index, temp)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc

In [4]:
class GATnet_concrete(torch.nn.Module):
    def __init__(self, n_mask, hidden_channels, num_features, num_classes):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GATv2Conv(num_features, hidden_channels)
        self.conv2 = GATv2Conv(hidden_channels, num_classes)
        self.n_mask = n_mask
        self.num_features = num_features
        self.num_classes = num_classes
        self.concrete = nn.Parameter(torch.randn(self.n_mask, self.num_features))

    def forward(self, x, edge_index, temp):
        mask = F.gumbel_softmax(self.concrete, tau=temp, hard=True)
        mask = torch.sum(mask, axis=0)
        mask = torch.clamp(mask, min=0, max=1)
        x = mask * x
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    def softmax(self):
        return F.softmax(self.concrete, dim=1)

In [5]:
adata = get_adata()



In [6]:
display(adata.obs[["brain_section_label", "z_section"]].sort_values("z_section").value_counts().to_frame().head(4))
one_sec = adata[adata.obs["z_section"] == 5.0, :]
df = one_sec.obs.copy()
num_nodes = df.shape[0]
cell_type = "supertype"

Unnamed: 0_level_0,Unnamed: 1_level_0,count
brain_section_label,z_section,Unnamed: 2_level_1
C57BL6J-638850.30,5.0,9242
C57BL6J-638850.29,4.8,8713
C57BL6J-638850.28,4.6,7780
C57BL6J-638850.31,5.4,6939


In [7]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

train_idx, test_idx = next(skf.split(np.arange(df.shape[0]), df[cell_type].values))

train_mask = np.zeros(df.shape[0], dtype=bool)
train_mask[train_idx] = True
train_mask = torch.tensor(train_mask, dtype=torch.bool)

test_mask = np.zeros(df.shape[0], dtype=bool)
test_mask[test_idx] = True
test_mask = torch.tensor(test_mask, dtype=torch.bool)

labels = torch.tensor(df[cell_type].cat.codes.values, dtype=torch.long)



In [8]:
# Graph construction hyperparameters
d = 40 / 1000  # (in mm)
L_thr = 0.0
R_thr = 0.0
lr_gene_pairs = [["Tac2", "Tacr3"], ["Penk", "Oprd1"], ["Pdyn", "Oprd1"], ["Pdyn", "Oprk1"], ["Grp", "Grpr"]]
n_layers = len(lr_gene_pairs)
num_nodes = df.shape[0]
cell_type = "supertype"

In [10]:
edge_index_list = [None] * n_layers
df["participant"] = np.zeros(num_nodes, dtype=bool)

# Edgelist from multi-layer graphs
for i in range(n_layers):
    ligand, receptor = lr_gene_pairs[i]
    df["L"] = one_sec[:, one_sec.var["gene_symbol"] == ligand].X.toarray().ravel()
    df["R"] = one_sec[:, one_sec.var["gene_symbol"] == receptor].X.toarray().ravel()

    df["L"] = (df["L"] > L_thr).astype(bool)
    df["R"] = (df["R"] > R_thr).astype(bool)

    A = df["L"].values.reshape(-1, 1) @ df["R"].values.reshape(1, -1)
    Dx = (df["x_reconstructed"].values.reshape(-1, 1) - df["x_reconstructed"].values.reshape(1, -1)) ** 2
    Dy = (df["y_reconstructed"].values.reshape(-1, 1) - df["y_reconstructed"].values.reshape(1, -1)) ** 2
    D = np.sqrt(Dx + Dy)
    del Dx, Dy

    # cells are connected only if within distance d
    A[D > d] = 0

    # participant should have more than one connection
    df["participant"] = df["participant"] + (A.sum(axis=1) > 1)

    # construct directed graph from adjacency matrix
    edge_index_list[i], _ = from_scipy_sparse_matrix(csr_array(A))


# Squash the multi-layer graph into a single layer graph
edge_index_squashed = set(edge_index_list[0].T)
for i in range(1, len(edge_index_list)):
    edge_index_squashed = set(edge_index_list[i].T).union(edge_index_squashed)
edge_index_squashed = list(edge_index_squashed)
edge_index_list.append(edge_index_squashed)

In [11]:
edgelist_squashed = torch.stack(edge_index_list[-1], dim=0)
edgelist_squashed = add_remaining_self_loops(edgelist_squashed.T)[0]

In [13]:
num_genes_considered = 100
genes_to_select = 10

In [19]:
one_sec_x = torch.tensor(one_sec.X.todense(), dtype=torch.float)
one_sec_x_subset = torch.tensor(one_sec.X.todense(), dtype=torch.float)[
    :, :num_genes_considered
]  # only use first n genes to feature select from

data_gene_concrete = Data(
    x=one_sec_x_subset, edge_index=edgelist_squashed, y=labels, train_mask=train_mask, test_mask=test_mask
)
data_gene_concrete_full = Data(
    x=one_sec_x, edge_index=edgelist_squashed, y=labels, train_mask=train_mask, test_mask=test_mask
)

In [20]:
model_gene_concrete = GATnet_concrete(
    n_mask=genes_to_select,
    hidden_channels=32,
    num_features=data_gene_concrete.x.shape[1],
    num_classes=torch.unique(data_gene_concrete.y).size()[0],
)
model_gene_concrete_full = GATnet_concrete(
    n_mask=genes_to_select,
    hidden_channels=32,
    num_features=data_gene_concrete_full.x.shape[1],
    num_classes=torch.unique(data_gene_concrete_full.y).size()[0],
)

In [16]:
def linear_temp_schedule(epoch):
    return 10 * (1 - epoch / 1000) + 1e-3


def exp_decay_temp_schedule(epoch, total_epoch):
    start_temp = 10
    end_temp = 0.01
    temp = start_temp * (end_temp / start_temp) ** (epoch / total_epoch)
    return temp


# Article from Ian Covert says using temp = 0.1 throughout training is effective as well.

In [17]:
# Train model_gene_concrete to predict cell types
def gene_edges_concrete(model_gene, data_gene):
    optimizer_gene = torch.optim.Adam(model_gene.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 1001):
        loss = train_gnn_concrete(model_gene, optimizer_gene, data_gene, criterion, 0.1)
        val_acc = test_gnn_concrete(model_gene, data_gene)
        if epoch % 200 == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val acc: {val_acc:.4f}")

    model_gene.eval()
    val_acc = test_gnn_concrete(model_gene, data_gene)
    print(f"Gene edge accuracy: {val_acc:.4f}")
    return val_acc


gene_edges_concrete(model_gene_concrete, data_gene_concrete)

Epoch: 200, Loss: 3.2590, Val acc: 0.1920
Epoch: 400, Loss: 3.1533, Val acc: 0.1449
Epoch: 600, Loss: 3.1360, Val acc: 0.3407
Epoch: 800, Loss: 2.9140, Val acc: 0.3121
Epoch: 1000, Loss: 3.0935, Val acc: 0.2731
Gene edge accuracy: 0.2704


0.2704164413196322

In [18]:
for i in range(len(model_gene_concrete.concrete)):
    print(torch.max(F.softmax(model_gene_concrete.concrete[i])))
    print(torch.argmax(F.softmax(model_gene_concrete.concrete[i])))

tensor(0.0570, grad_fn=<MaxBackward1>)
tensor(47)
tensor(0.1102, grad_fn=<MaxBackward1>)
tensor(40)
tensor(0.0660, grad_fn=<MaxBackward1>)
tensor(39)
tensor(0.1195, grad_fn=<MaxBackward1>)
tensor(78)
tensor(0.0577, grad_fn=<MaxBackward1>)
tensor(67)
tensor(0.0572, grad_fn=<MaxBackward1>)
tensor(48)
tensor(0.0459, grad_fn=<MaxBackward1>)
tensor(32)
tensor(0.0707, grad_fn=<MaxBackward1>)
tensor(75)
tensor(0.0580, grad_fn=<MaxBackward1>)
tensor(73)
tensor(0.0567, grad_fn=<MaxBackward1>)
tensor(74)


  print(torch.max(F.softmax(model_gene_concrete.concrete[i])))
  print(torch.argmax(F.softmax(model_gene_concrete.concrete[i])))


In [22]:
# Train model_gene_concrete_full to predict cell types
def gene_edges_concrete_full(model_gene_full, data_gene):
    optimizer_gene = torch.optim.Adam(model_gene_full.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 2001):
        loss = train_gnn_concrete(model_gene_full, optimizer_gene, data_gene, criterion, 0.1)
        val_acc = test_gnn_concrete(model_gene_full, data_gene)
        if epoch % 200 == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val acc: {val_acc:.4f}")

    model_gene_full.eval()
    val_acc = test_gnn_concrete(model_gene_full, data_gene)
    print(f"Gene edge accuracy: {val_acc:.4f}")
    return val_acc


gene_edges_concrete(model_gene_concrete_full, data_gene_concrete_full)

Epoch: 200, Loss: 3.2190, Val acc: 0.1833
Epoch: 400, Loss: 3.1513, Val acc: 0.1466
Epoch: 600, Loss: 3.1007, Val acc: 0.2445
Epoch: 800, Loss: 3.0798, Val acc: 0.1747
Epoch: 1000, Loss: 2.9429, Val acc: 0.1520
Gene edge accuracy: 0.2985


0.298539751216874

In [85]:
for i in range(len(model_gene_concrete_full.concrete)):
    print(torch.max(F.softmax(model_gene_concrete_full.concrete[i])))
    print(torch.argmax(F.softmax(model_gene_concrete_full.concrete[i])))

tensor(0.0117, grad_fn=<MaxBackward1>)
tensor(83)
tensor(0.0156, grad_fn=<MaxBackward1>)
tensor(266)
tensor(0.0313, grad_fn=<MaxBackward1>)
tensor(326)
tensor(0.0137, grad_fn=<MaxBackward1>)
tensor(21)
tensor(0.0244, grad_fn=<MaxBackward1>)
tensor(28)
tensor(0.0207, grad_fn=<MaxBackward1>)
tensor(495)
tensor(0.0057, grad_fn=<MaxBackward1>)
tensor(178)
tensor(0.0072, grad_fn=<MaxBackward1>)
tensor(56)
tensor(0.0095, grad_fn=<MaxBackward1>)
tensor(517)
tensor(0.0125, grad_fn=<MaxBackward1>)
tensor(300)


  print(torch.max(F.softmax(model_self_concrete_full.concrete[i])))
  print(torch.argmax(F.softmax(model_self_concrete_full.concrete[i])))
