In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling
import sys
import os
import torch
import numpy as np
import pandas as pd
import random
import copy
from torch_geometric.utils.dropout import dropout_adj
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath("C:\\Data\\Code\\BioML_manuscript\\data"))
from utils.boolODE_data_to_pyg_data import make_adj_from_df, to_pyg_data

In [None]:
# set user parameters:
#   datadir: directory in which data in the BoolODE format is available
#   name: name of the directory in which the data is located (subdirectory of datadir)
#   filenm: name under which the results should be saved for this network, note: output/"+filenm+"/"+filenm+"/" should exist before running!
#   num_features: amount of cells available for the data (758 for hESC network)
datadir = 'data/'
name = 'hESC'
filenm = 'hESC'

df=pd.read_csv(datadir + name + '/ExpressionData.csv', index_col=0)
adj_df = pd.read_csv(datadir + name + '/refNetwork.csv', index_col=0)

mat = df.to_numpy()

sz = df.to_numpy().shape
edge_index, adj = make_adj_from_df(datadir,df, name)
true_data = to_pyg_data(mat, sz[0], sz[1], edge_index=edge_index)

ode_dim = true_data.x.shape[0]

num_features = 758

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import degree

# method to extract a subnetwork with the top N nodes from a larger pytorch geometric data file
def top_n_nodes_by_degree(data, N):

    num_nodes = data.num_nodes
    # compute node degree
    deg = degree(data.edge_index[0], num_nodes=num_nodes)

    # get indices of top N nodes
    top_n_indices = deg.topk(N).indices
    node_mask = torch.zeros(num_nodes, dtype=torch.bool)
    node_mask[top_n_indices] = True

    old_to_new = -torch.ones(num_nodes, dtype=torch.long)
    old_to_new[top_n_indices] = torch.arange(N)

    # keep only edges between the top-N nodes
    src, dst = data.edge_index
    edge_mask = node_mask[src] & node_mask[dst]
    new_edge_index = data.edge_index[:, edge_mask]
    new_edge_index = old_to_new[new_edge_index]

    # subsample node features and create new data object
    new_data = Data(
        x=data.x[top_n_indices] if data.x is not None else None,
        edge_index=new_edge_index
    )

    for key in data.keys:
        if key in ['x', 'edge_index']:
            continue
        attr = data[key]
        if torch.is_tensor(attr) and attr.size(0) == num_nodes:
            new_data[key] = attr[top_n_indices]
        else:
            new_data[key] = attr

    return new_data

In [None]:
# we create subnetworks for different numbers of nodes
from torch_geometric.utils import train_test_split_edges

data100 = top_n_nodes_by_degree(true_data, 100)
data100 = train_test_split_edges(data100)

data200 = top_n_nodes_by_degree(true_data, 200)
data200 = train_test_split_edges(data200)

data500 = top_n_nodes_by_degree(true_data, 500)
data500 = train_test_split_edges(data500)

data1000 = top_n_nodes_by_degree(true_data, 1000)
data1000 = train_test_split_edges(data1000)

data2000 = top_n_nodes_by_degree(true_data, 2000)
data2000 = train_test_split_edges(data2000)

data5000 = top_n_nodes_by_degree(true_data, 5000)
data5000 = train_test_split_edges(data5000)

In [30]:
# Define Graph Autoencoder (GAE) Model
class GAE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim=16):
        super(GAE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = torch.nn.Dropout(0.3)

        # one linear layer (only weights) for decoding
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

    # encode node features
    def encode(self, data):
        #x = self.dropout(data.x)
        edge_index = dropout_adj(data.train_pos_edge_index, p = 0.2)[0]
        #edge_index = data.edge_index
        x = self.conv1(data.x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        return self.conv2(x, edge_index)

    # decode specific edges
    def decode(self, z, edge_index):
        return (z[edge_index[0]] * self.lin1(z[edge_index[1]])).sum(dim=-1)  # Inner product
    
    # decode all edges for full adjacency matrix inference
    def decode_all(self,z):
        adj_matrix = torch.ones((z.shape[0], z.shape[0]))
        full_edge_index = adj_matrix.nonzero().t().contiguous()

        return (z[full_edge_index[0]] * self.lin1(z[full_edge_index[1]])).sum(dim=-1)

In [None]:
# train Model
def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data)

    # here, we ensure that negative edges used for validating the model do not appear in the negative edges used to train the model (avoiding data leakage)
    neg_edges = negative_sampling(torch.cat([data.train_pos_edge_index,data.val_neg_edge_index],dim=1), data.x.shape[0], data.train_pos_edge_index.size(1))

    edges = torch.cat([data.train_pos_edge_index, neg_edges], dim=1)
    
    # labels: 1 for real edges, 0 for negative samples
    labels = torch.cat([torch.ones(data.train_pos_edge_index.size(1)), torch.zeros(neg_edges.size(1))]).to(data.x.device)
    preds = model.decode(z, edges)
    
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

def validate(model, data, criterion):
    model.eval()
    z = model.encode(data)

    edges = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index], dim=1)
    labels = torch.cat([torch.ones(data.val_pos_edge_index.size(1)), torch.zeros(neg_edges.size(1))]).to(data.x.device)

    preds = model.decode(z, edges)

    val_loss = criterion(preds, labels)

    return val_loss.item()


In [None]:
from sklearn.metrics import roc_auc_score

# method to compute AUROC score for validation data (can be changed to test data)
def auroc(model, data, criterion):
    model.eval()
    z = model.encode(data)
    neg_edges = negative_sampling(data.train_pos_edge_index, data.x.shape[0], data.val_pos_edge_index.size(1))
    edges = torch.cat([data.val_pos_edge_index, neg_edges], dim=1)

    labels = torch.cat([torch.ones(data.val_pos_edge_index.size(1)), torch.zeros(neg_edges.size(1))]).to(data.x.device)

    preds = model.decode(z, edges)

    preds = preds.sigmoid()

    labels = labels.cpu()
    preds = preds.cpu()
    labels = labels.detach().numpy()
    preds = preds.detach().numpy()

    auroc = roc_auc_score(labels, preds)

    return [auroc, labels, preds]

In [None]:
from scipy.io import savemat

num_features = 758
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# set data to run the evaluation on
nm_cur = "data100"
data_cur = data100

data_cur.to(device)

auroc_scores = []

# repeat model training 10 times, saving AUROC scores for validation set (can be changed to test set)
for k in range(0,10):

    print("training model " + str(k) + "...")
    model = GAE(input_dim=num_features,hidden_dim=200)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = torch.nn.BCEWithLogitsLoss()

    loss_vec = []

    for epoch in range(500):
        loss = train(model, data_cur, optimizer, criterion)
        
        loss_vec.append(loss)

        val_loss = validate(model, data_cur, criterion)
        print(val_loss)

    ans = auroc(model,data_cur,criterion)
    auroc_scores.append(ans[0])

savemat(nm_cur+"_auroc.mat",{"auroc":auroc_scores})



6.95760440826416
val loss:
13.132295608520508
13.67640495300293
val loss:
7.955120086669922
9.454837799072266
val loss:
1.8988947868347168
3.59898042678833
val loss:
3.631420850753784
5.5261006355285645
val loss:
4.386188983917236
6.373952388763428
val loss:
1.9604909420013428
3.266230583190918
val loss:
0.8643813133239746
1.9806842803955078
val loss:
1.518283724784851
2.089463710784912
val loss:
2.184612989425659
2.815073013305664
val loss:
2.186833620071411
2.8975400924682617
val loss:
1.667678713798523
2.1806788444519043
val loss:
1.1074415445327759
1.8132374286651611
val loss:
0.7549586296081543
1.5588256120681763
val loss:
0.8122979402542114
1.4901893138885498
val loss:
0.9041725397109985
1.7439045906066895
val loss:
0.7923570275306702
1.6361552476882935
val loss:
0.5819644927978516
1.2791329622268677
val loss:
0.5037238597869873
1.045715093612671
val loss:
0.5146893262863159
0.9149564504623413
val loss:
0.5464773774147034
0.9608250260353088
val loss:
0.5597227811813354
0.85042721