In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
import sys
import os
import torch
import numpy as np
import pandas as pd
import random
import copy
from torch_geometric.utils.dropout import dropout_adj
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import wandb
sys.path.append(os.path.abspath("C:\\Data\\Code\\BioML_refactor\\models"))
sys.path.append(os.path.abspath("C:\\Data\\Code\\BioML_refactor\\data"))
sys.path.append(os.path.abspath("C:\\Data\\Code\\BioML_refactor\\utils"))
from boolODE_data_to_pyg_data import make_adj_from_df, to_pyg_data
from data_creation import create_dataset, to_pyg_data_true

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
datadir = 'data/'
name = 'dyn_bifurcating'
df=pd.read_csv(datadir + name + '/ExpressionData.csv', index_col=0)

adj_df = pd.read_csv(datadir + name + '/refNetwork.csv', index_col=0)

mat = df.to_numpy()

sz = df.to_numpy().shape
edge_index, adj = make_adj_from_df(datadir,df, name)
true_data = to_pyg_data(mat, sz[0], sz[1], edge_index=edge_index)

ode_dim = true_data.x.shape[0]

num_features = 3000

Data(x=[7, 3000], edge_index=[2, 12])


In [7]:
# generate imputed data where some edges are missing (horrible implementation)
p = 0.1
rem = []
while len(rem) != 1:
    rem = []
    imputed_edge_index = true_data.edge_index
    while len(rem) == 0:
        for k in range(0,len(true_data.edge_index[0])):
            if random.random() < p:
                rem.append(k)

mask = torch.ones(imputed_edge_index.shape[1], dtype=torch.bool)
mask[rem] = False

imputed_edge_index = imputed_edge_index[:,mask]

# data has MOST edges
data = copy.deepcopy(true_data)
data.edge_index = imputed_edge_index

print(true_data.edge_index[:,rem])

tensor([[3],
        [6]])


In [8]:
# Define Graph Autoencoder (GAE) Model
class GAE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim=16):
        super(GAE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = torch.nn.Dropout(0.3)

        # one linear layer (only weights) for decoding
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

    # encode node features
    def encode(self, data):
        #x = self.dropout(data.x)
        edge_index = dropout_adj(data.edge_index, p = 0.2)[0]
        #edge_index = data.edge_index
        x = self.conv1(data.x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        return self.conv2(x, edge_index)

    # decode specific edges
    def decode(self, z, edge_index):
        return (z[edge_index[0]] * self.lin1(z[edge_index[1]])).sum(dim=-1)  # Inner product
    
    # decode all edges for full adjacency matrix inference
    def decode_all(self,z):
        adj_matrix = torch.ones((z.shape[0], z.shape[0]))
        full_edge_index = adj_matrix.nonzero().t().contiguous()

        return (z[full_edge_index[0]] * self.lin1(z[full_edge_index[1]])).sum(dim=-1)

In [9]:
# Train Model
def train(model, data, query, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data)
    neg_edges = negative_sampling(torch.cat([data.edge_index,query],dim=1), data.num_nodes, data.edge_index.size(1))

    edges = torch.cat([data.edge_index, neg_edges], dim=1)
    
    # Labels: 1 for real edges, 0 for negative samples
    labels = torch.cat([torch.ones(data.edge_index.size(1)), torch.zeros(neg_edges.size(1))]).to(data.x.device)
    preds = model.decode(z, edges)
    
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_model(data,query, device):
    model = GAE(input_dim=num_features,hidden_dim=200)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = torch.nn.BCEWithLogitsLoss()

    loss_vec = []

    for epoch in range(500):
        loss = train(model, data, query, optimizer, criterion)
        
        loss_vec.append(loss)

    return model

In [12]:
adj_matrix = torch.ones((ode_dim, ode_dim))
query_edge_index = adj_matrix.nonzero().t().contiguous()

print(query_edge_index)

rem_query = []
for k in range(0,len(query_edge_index[0])):
    for j in range(0,len(data.edge_index[0])):
        if query_edge_index[0][k] == data.edge_index[0][j] and query_edge_index[1][k] == data.edge_index[1][j]:
            rem_query.append(k)

mask = torch.ones(query_edge_index.shape[1], dtype=torch.bool)
mask[rem_query] = False

query_edge_index = query_edge_index[:,mask]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = data.to(device)

adj = torch.zeros((ode_dim, ode_dim))
count_adj = torch.zeros((ode_dim, ode_dim))

for k in range(0,len(query_edge_index[0])):
    print(str(k+1)+'/'+str(len(query_edge_index[0])))
    query = torch.tensor([[query_edge_index[0][k]],[query_edge_index[1][k]]])
    query = query.to(device)
    model = train_model(data,query, device)
    model.eval()
    
    z = model.encode(data)
    dec = model.decode(z, query)


    for k in range(0, len(query[0])):
        adj[query[0][k], query[1][k]] = torch.sigmoid(dec[k])
        count_adj[query[0][k], query[1][k]] += 1

tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
         3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
         6],
        [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,
         3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5,
         6]])
1/38
2/38
3/38
4/38
5/38
6/38
7/38
8/38
9/38
10/38
11/38
12/38
13/38
14/38
15/38
16/38
17/38
18/38
19/38
20/38
21/38
22/38
23/38
24/38
25/38
26/38
27/38
28/38
29/38
30/38
31/38
32/38
33/38
34/38
35/38
36/38
37/38
38/38


In [None]:
from scipy.io import savemat
inferred_adj = (adj/count_adj).detach().numpy()

print(true_data.edge_index[:,rem])
print()
print(inferred_adj[3,6])
print(inferred_adj)
# savemat("dyn_bifurcating_3_6.mat",{"inferred_adj": inferred_adj})

inferred_adj = inferred_adj[~np.isnan(inferred_adj)]
inferred_adj = torch.tensor(inferred_adj)

new_edges = torch.topk(inferred_adj.view(-1), 20)[1]  # Get top-k new edges
print(new_edges)
print(torch.topk(inferred_adj.view(-1), 38)) 

print(inferred_adj)

tensor([[3],
        [6]])

0.00036959173
[[9.9953568e-01           nan 2.4665186e-01 2.5782844e-12 7.8993517e-10
  1.6207440e-02 2.9352617e-01]
 [2.1613001e-15 9.9751943e-01           nan 1.5519875e-06 4.3758081e-07
  3.3765574e-04 4.9630828e-02]
 [1.4942311e-03 6.0407910e-14 1.0000000e+00           nan           nan
  9.9199706e-01 4.0135999e-11]
 [          nan 9.5236230e-01 3.5255669e-13           nan           nan
            nan 3.6959173e-04]
 [          nan 3.7981442e-07 3.7033052e-10           nan           nan
  9.9981290e-01 3.1388536e-04]
 [4.0403061e-02 4.1739629e-03 9.1271488e-07 1.2963502e-03 1.8535373e-05
  9.4620958e-03 4.5242172e-04]
 [4.1429153e-08 1.7850245e-04 2.1379344e-07 5.4521451e-04 5.5758834e-02
  2.5165857e-05 3.5186542e-06]]
tensor([14, 22,  0,  7, 15, 17,  5,  1, 35, 11, 24,  4, 29, 25, 12, 27, 34, 30,
        19, 10])
torch.return_types.topk(
values=tensor([1.0000e+00, 9.9981e-01, 9.9954e-01, 9.9752e-01, 9.9200e-01, 9.5236e-01,
        2.9353e-01, 2.4665e