Try out the GNN model


In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_absolute_error
from scipy.sparse import issparse

ModuleNotFoundError: No module named 'torch_geometric'

In [None]:
# Load the RNA modality for sample V11L12-038_A1 and show its AnnData summary.
# NOTE(review): absolute cluster path; consider a configurable DATA_DIR constant.
rna_h5ad_path = "/lustre/groups/ml01/workspace/eirini.giannakoulia/datasets/V11L12-038/V11L12-038_A1/V11L12-038_A1.RNA_MOSCOT_paired_hvg.h5ad"
adata_rna = sc.read_h5ad(rna_h5ad_path)
adata_rna

In [None]:
# Load the MSI modality for the same sample and show its AnnData summary.
# NOTE(review): absolute cluster path; consider a configurable DATA_DIR constant.
msi_h5ad_path = "/lustre/groups/ml01/workspace/eirini.giannakoulia/datasets/V11L12-038/V11L12-038_A1/V11L12-038_A1.MSI_MOSCOT_paired_hvg.h5ad"
adata_msi = sc.read_h5ad(msi_h5ad_path)
adata_msi

In [None]:
# Subset each modality by the precomputed "split" label stored in `.obs`
# (boolean-mask indexing, so each subset is an AnnData view).
adata_rna_train, adata_rna_test = (
    adata_rna[adata_rna.obs["split"] == split_label] for split_label in ("train", "test")
)
adata_msi_train, adata_msi_test = (
    adata_msi[adata_msi.obs["split"] == split_label] for split_label in ("train", "test")
)

In [None]:
# Warped spatial coordinates of the RNA spots; the GCN cell builds its kNN graphs from these.
coords_rna_train, coords_rna_test = (
    subset.obsm["spatial_warp"] for subset in (adata_rna_train, adata_rna_test)
)

In [None]:
class MultiLayerGCN(nn.Module):
    """Stack of GCNConv layers mapping per-node input features to per-node outputs.

    Every layer except the last is followed by ReLU and dropout; the last
    (output) layer is returned as-is, so predictions are unconstrained.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of every hidden GCN layer.
        output_dim: Number of output features per node.
        num_layers: Total number of GCNConv layers, including the output layer.
            Must be >= 1; with 1 the model is a single GCNConv(input_dim, output_dim).
        dropout: Dropout probability applied after each hidden layer
            (active only in training mode).

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.5):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Hidden layers: input_dim -> hidden_dim first, then hidden_dim -> hidden_dim.
        # Built via a running input width so that num_layers really is the total
        # layer count (the previous version always built at least 2 layers).
        self.layers = nn.ModuleList()
        in_dim = input_dim
        for _ in range(num_layers - 1):
            self.layers.append(GCNConv(in_dim, hidden_dim))
            in_dim = hidden_dim

        # Output layer reads from the last hidden width (or input_dim if there are none).
        self.output_layer = GCNConv(in_dim, output_dim)

        self.dropout = dropout

    def forward(self, x, edge_index):
        """Run all GCN layers over the graph.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity in PyTorch Geometric ``(2, E)`` format.

        Returns:
            Output features, one row per node (no final activation).
        """
        for layer in self.layers:
            x = layer(x, edge_index)
            x = F.relu(x)
            # `training=self.training` makes dropout a no-op after model.eval().
            x = F.dropout(x, p=self.dropout, training=self.training)

        return self.output_layer(x, edge_index)

In [None]:
# --- Build spatial kNN graphs, train the GCN, and evaluate on the test split ---
# NOTE(review): this cell reads names that no visible cell defines:
#   device, k_train, k_test, input_dim, hidden_dim, output_dim, num_layers,
#   dropout, lr, epochs, X_train_tensor, Y_train_tensor, X_test_tensor, Y_test_tensor.
# They presumably came from an earlier (since-deleted) cell. Add a config /
# tensor-preparation cell above so the notebook survives Restart & Run All.


def build_knn_edge_index(coords, k, device=None):
    """Build a directed kNN ``edge_index`` of shape ``(2, E)`` over ``coords``.

    Each point ``i`` gets an edge ``i -> j`` for each of its ``k`` nearest
    neighbours ``j`` (scikit-learn ``NearestNeighbors`` queried on the same
    points), except ``j == i``. A point is normally its own nearest neighbour,
    so most points end up with ``k - 1`` edges. Edges are ordered by source
    point, then by neighbour rank. The graph is not symmetrised.

    Args:
        coords: Array of point coordinates, one row per node.
        k: Number of neighbours to query per point (the point itself included).
        device: Optional torch device to move the result to.

    Returns:
        ``torch.long`` tensor of shape ``(2, E)``. Row 0 holds source indices
        and row 1 holds neighbour indices. ``E`` may be 0, but the result is
        always 2-D.
    """
    _, neighbor_idx = NearestNeighbors(n_neighbors=k).fit(coords).kneighbors(coords)
    # Row-major flattening reproduces the "for each point, for each neighbour" order.
    sources = np.repeat(np.arange(neighbor_idx.shape[0]), neighbor_idx.shape[1])
    targets = neighbor_idx.ravel()
    keep = sources != targets  # drop self-loops
    edge_index = torch.as_tensor(np.stack([sources[keep], targets[keep]]), dtype=torch.long)
    if device is not None:
        edge_index = edge_index.to(device)
    return edge_index


# --- Build Training Graph ---
train_edge_index = build_knn_edge_index(coords_rna_train, k_train, device)
train_data = Data(x=X_train_tensor, edge_index=train_edge_index)

# --- Initialize GCN Model ---
model = MultiLayerGCN(input_dim, hidden_dim, output_dim, num_layers=num_layers, dropout=dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.MSELoss()

# --- Train GCN Model (full-batch: every step uses the whole training graph) ---
model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    out = model(train_data.x, train_data.edge_index)
    loss = criterion(out, Y_train_tensor)
    loss.backward()
    optimizer.step()
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# --- Build Test Graph ---
# Built from test coordinates only, so no messages pass between train and test spots.
test_edge_index = build_knn_edge_index(coords_rna_test, k_test, device)
test_data = Data(x=X_test_tensor, edge_index=test_edge_index)

# --- Evaluate the Model ---
model.eval()  # switches off dropout in MultiLayerGCN.forward
with torch.no_grad():  # outputs carry no graph, so no .detach() is needed
    Y_pred_train = model(train_data.x, train_data.edge_index).cpu().numpy()
    Y_pred_test = model(test_data.x, test_data.edge_index).cpu().numpy()

# --- Compute Evaluation Metrics (Test Set) ---
# Pearson/Spearman pool every (node, output feature) value into one vector.
# RMSE/MAE/R2 use scikit-learn's default uniform averaging over output columns.
Y_test_np = Y_test_tensor.cpu().numpy()
pearson_corr = pearsonr(Y_pred_test.flatten(), Y_test_np.flatten())[0]
spearman_corr = spearmanr(Y_pred_test.flatten(), Y_test_np.flatten())[0]
rmse_test = np.sqrt(mean_squared_error(Y_test_np, Y_pred_test))
r2_test = r2_score(Y_test_np, Y_pred_test)
mae_test = mean_absolute_error(Y_test_np, Y_pred_test)

metrics = pd.DataFrame({
    'rmse': [rmse_test],
    'mae': [mae_test],
    'r2': [r2_test],
    'pearson': [pearson_corr],
    'spearman': [spearman_corr]
})

predictions = pd.DataFrame({
    'y_true': Y_test_np.flatten(),
    'y_pred': Y_pred_test.flatten()
})

In [None]:
# Display the one-row test-set metrics table (rmse, mae, r2, pearson, spearman).
metrics