In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling
import sys
import os
import torch
import numpy as np
import pandas as pd
import random
import copy
from torch_geometric.utils.dropout import dropout_adj
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath("C:\\Data\\Code\\BioML_manuscript\\data"))
from utils.boolODE_data_to_pyg_data import make_adj_from_df, to_pyg_data

In [5]:
# Dataset location: files are read from <datadir><name>/.
datadir = 'data/'
name = 'dyn_trifurcating'

# Expression table and reference-network table; the first CSV column is used as the index.
df = pd.read_csv(f"{datadir}{name}/ExpressionData.csv", index_col=0)
adj_df = pd.read_csv(f"{datadir}{name}/refNetwork.csv", index_col=0)

# Dense expression matrix and its (rows, columns) shape.
mat = df.to_numpy()
sz = mat.shape

# Edge list / adjacency for the dataset and the PyG data object built from the matrix.
# NOTE(review): the exact semantics of these project helpers are not visible here.
edge_index, adj = make_adj_from_df(datadir, df, name)
true_data = to_pyg_data(mat, sz[0], sz[1], edge_index=edge_index)

# Row count of the node-feature matrix; used downstream as the side length of
# the square (ode_dim x ode_dim) score matrices.
ode_dim = true_data.x.shape[0]

# Hard-coded node-feature width (GAE input dimension).
# presumably equals true_data.x.shape[1] -- TODO confirm
num_features = 3000

In [6]:
# Define Graph Autoencoder (GAE) Model
class GAE(torch.nn.Module):
    """Graph autoencoder with a two-layer GraphSAGE encoder and a bilinear edge decoder.

    Args:
        input_dim: Width of the node-feature vectors fed to the first layer.
        hidden_dim: Width of the hidden layer and of the output embeddings.

    The score (logit) of a directed edge ``i -> j`` is ``z_i . (W z_j)``,
    where ``W`` is the bias-free linear map ``lin1``.
    """

    def __init__(self, input_dim, hidden_dim=16):
        super(GAE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = torch.nn.Dropout(0.3)

        # one linear layer (only weights) for decoding
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

    def encode(self, data):
        """Embed the nodes of ``data`` (reads ``data.x`` and ``data.edge_index``).

        Edge dropout (p=0.2) and feature dropout (p=0.3) are regularisers and
        are only active in training mode; after ``model.eval()`` the encoder
        is deterministic and sees every edge.
        """
        # Tie edge dropout to the module mode: without `training=` the helper
        # defaults to training=True and keeps dropping edges at inference.
        edge_index = dropout_adj(data.edge_index, p=0.2, training=self.training)[0]
        x = self.conv1(data.x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        return self.conv2(x, edge_index)

    def decode(self, z, edge_index):
        """Return one logit per column of ``edge_index`` (row 0 = source, row 1 = target)."""
        # Bilinear form z_src . (W z_dst); not symmetric in (src, dst).
        return (z[edge_index[0]] * self.lin1(z[edge_index[1]])).sum(dim=-1)

    def decode_all(self, z):
        """Return logits for every ordered node pair, flattened in row-major order.

        Entry ``i * n + j`` of the result is the logit of edge ``i -> j``
        (self-pairs included), with ``n = z.shape[0]``.
        """
        num_nodes = z.shape[0]
        # Build the index on the same device as the embeddings.
        full_edge_index = (
            torch.ones((num_nodes, num_nodes), device=z.device).nonzero().t().contiguous()
        )
        return self.decode(z, full_edge_index)

In [7]:
# Train Model
def train(model, data, query, optimizer, criterion):
    """Perform a single optimisation step of link prediction and return the loss.

    Args:
        model: Module exposing ``encode(data)`` and ``decode(z, edge_index)``.
        data: Graph object providing ``x``, ``edge_index`` and ``num_nodes``.
        query: Edge-index tensor (2 x Q) of edges that must never be drawn as
            negative samples.
        optimizer: Optimiser over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Encode first, then sample negatives (same order of random draws as before).
    embeddings = model.encode(data)

    positives = data.edge_index
    n_pos = positives.size(1)

    # Negatives are sampled away from both the observed edges and the query edges,
    # one negative requested per observed edge.
    excluded = torch.cat([positives, query], dim=1)
    negatives = negative_sampling(excluded, data.num_nodes, n_pos)

    candidates = torch.cat([positives, negatives], dim=1)

    # Labels: 1 for real edges, 0 for negative samples
    targets = torch.zeros(n_pos + negatives.size(1))
    targets[:n_pos] = 1.0
    targets = targets.to(data.x.device)

    step_loss = criterion(model.decode(embeddings, candidates), targets)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def train_model(data, query, device, hidden_dim=200, lr=0.0001, epochs=500):
    """Build a fresh GAE and fit it to ``data`` with ``query`` held out of negative sampling.

    Args:
        data: Graph object (``x``, ``edge_index``, ``num_nodes``), already on ``device``.
        query: Edge-index tensor (2 x Q) of candidate edges excluded from the
            negative samples during training.
        device: Device the model is moved to.
        hidden_dim: Hidden / embedding width of the GAE.
        lr: Adam learning rate.
        epochs: Number of optimisation steps.

    Returns:
        The trained model (left in training mode; call ``.eval()`` before inference).
    """
    # Take the input width from the data itself rather than from a notebook
    # global, so the function works for any feature matrix.
    model = GAE(input_dim=data.x.size(-1), hidden_dim=hidden_dim)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    for _ in range(epochs):
        train(model, data, query, optimizer, criterion)

    return model

In [8]:
# Leave-one-edge-out experiment.
# For each reference edge: remove it from the graph, then for every ordered
# node pair that is not a retained edge, train a fresh model with that pair
# excluded from negative sampling and record its predicted edge probability.
# One .mat file is written per held-out edge.
# NOTE: this trains (number of edges) x (number of query pairs) models and is
# very expensive to run end to end.
from scipy.io import savemat

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_edges = true_data.edge_index.size(1)

for i in range(num_edges):
    # Drop the i-th reference edge.
    keep = torch.ones(num_edges, dtype=torch.bool)
    keep[i] = False

    # data has MOST edges
    data = copy.deepcopy(true_data)
    data.edge_index = true_data.edge_index[:, keep]

    # Query edges: every ordered pair (self-pairs included) that is not a
    # retained edge, in row-major order. The held-out edge is one of them.
    known = torch.zeros((ode_dim, ode_dim), dtype=torch.bool)
    known[data.edge_index[0], data.edge_index[1]] = True
    query_edge_index = (~known).nonzero().t().contiguous()
    num_queries = query_edge_index.size(1)

    data = data.to(device)

    # Separate accumulators so the reference adjacency loaded earlier in the
    # notebook is not overwritten.
    score_sum = torch.zeros((ode_dim, ode_dim))
    score_count = torch.zeros((ode_dim, ode_dim))

    print("Training for edge "+str(i+1)+"/"+str(num_edges))
    for k in range(num_queries):
        # Single candidate edge as a 2 x 1 edge index.
        query = query_edge_index[:, k:k + 1].to(device)
        model = train_model(data, query, device)
        model.eval()

        with torch.no_grad():
            z = model.encode(data)
            prob = torch.sigmoid(model.decode(z, query)).cpu()

        src = int(query_edge_index[0, k])
        dst = int(query_edge_index[1, k])
        score_sum[src, dst] += prob[0]
        score_count[src, dst] += 1

    # Pairs that were never queried (the retained edges) are 0/0 = NaN here.
    inferred_adj = (score_sum / score_count).numpy()

    # Name the file after the edge held out in THIS iteration so that results
    # for different edges do not overwrite each other.
    held_src = int(true_data.edge_index[0, i])
    held_dst = int(true_data.edge_index[1, i])
    savemat("dyn_trifurcating_new_"+str(held_src)+"_"+str(held_dst)+".mat",{"inferred_adj": inferred_adj})

    

Training for edge 1/20




KeyboardInterrupt: 