In [1]:
import numpy as np
import pandas as pd
import torch
from load_msi import create_msi
from visualization import (
        generate_nx_graph, generate_khop_subgraph, generate_pyvis_graph, 
        visualize_tsne_embeddings, visualize_pca_embeddings)
from training import split_data, create_minibatches, train_link_predictor
from torch_geometric.utils import to_networkx, degree
from torch_geometric.nn import to_hetero

from gnn import GAT, GIN, DotProductLinkPredictor
import torch_geometric.transforms as T
print(torch.__version__)

1.13.1+cu117


In [2]:
msi = create_msi("msi", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data")

data = msi.hetero_data
data = T.ToUndirected()(data)

In [3]:
# Generate and visualize graphs (HIV: id = C0019693, gid = 840)
##G, G_undirected = generate_nx_graph(msi.data)
##subG = generate_khop_subgraph(msi, G, 840, 1)
##py_graph = generate_pyvis_graph(subG)
##py_graph.show("test.html")

In [4]:
# Set the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [5]:
edge_type = ("dz", "implicates", "prot")
id_map = msi.mappings["dz_id_gid_map"]
name_map = msi.mappings["dz_id_name_map"]



degrees = degree(data[edge_type].edge_index[0]).numpy()
sorted_gids = np.argsort(degrees)[::-1]
sorted_degs = np.sort(degrees)[::-1]


#for gid, deg in zip(sorted_gids, sorted_degs):
#    id = list(id_map.keys())[list(id_map.values()).index(gid)]
#    name = name_map[id]
#    print(f"Disease: {name:<45}   |   1-hop neighbors: {deg:>5}   |   id: {id}")


In [6]:
# Prepare data for GNN.
edge_types = ("drug", "treats", "dz")
rev_edge_types = ("dz", "rev_treats", "drug")

# Split the data into train, validation, and test sets.
train_data, val_data, test_data = split_data(
        data, val_size=0.1, test_size=0.1, supervision_ratio=0.3,
        edge_types=edge_types, neg_sampling_ratio=1.0,
        rev_edge_types=rev_edge_types)


# Create minibatches.
train_loader = create_minibatches(
        train_data, edge_types=edge_types, num_khop_neighbors=[50,25,10], neg_sampling_ratio=1.0, batch_size=16)
val_loader = create_minibatches(
        val_data, edge_types=edge_types, num_khop_neighbors = [50,25,10], batch_size=16)



In [7]:
# Instantiate the model and generate some initial node embeddings prior to training.
model = GAT(4, 64, 16)
model = to_hetero(model, data.metadata(), aggr='sum').to(device)
predictor = DotProductLinkPredictor(32, 32, 1).to(device)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
#model.reset_parameters()
#predictor.reset_parameters()

# Train and validate the model for a given number of epochs.
for epoch in range(1, 50+1):
    
    # Training..
    loss, auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=train_loader, optimizer=optimizer, device=device, train=True)
    
    # Validation.
    val_loss, val_auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=val_loader,  optimizer=optimizer, device=device, train=False)
    
    print(f"Epoch     : {epoch}")
    print(f" Loss     : {loss:<10.4f} |     AUC:     {auc:>10.4f}")
    print(f" Val_Loss : {val_loss:<10.4f} |     Val_AUC: {val_auc:>10.4f}")

100%|██████████| 89/89 [00:32<00:00,  2.71it/s]
100%|██████████| 74/74 [00:10<00:00,  7.36it/s]


Epoch     : 1
 Loss     : 38.7276    |     AUC:         0.5138
 Val_Loss : 0.9099     |     Val_AUC:     0.4989


100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 74/74 [00:10<00:00,  7.32it/s]


Epoch     : 2
 Loss     : 0.8807     |     AUC:         0.5207
 Val_Loss : 0.8202     |     Val_AUC:     0.5069


100%|██████████| 89/89 [00:30<00:00,  2.93it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 3
 Loss     : 0.8447     |     AUC:         0.4860
 Val_Loss : 0.8477     |     Val_AUC:     0.5154


100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 74/74 [00:10<00:00,  7.21it/s]


Epoch     : 4
 Loss     : 0.8095     |     AUC:         0.5015
 Val_Loss : 0.7692     |     Val_AUC:     0.5039


100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 5
 Loss     : 2.1530     |     AUC:         0.5140
 Val_Loss : 0.7850     |     Val_AUC:     0.4927


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 6
 Loss     : 0.7613     |     AUC:         0.4922
 Val_Loss : 0.7423     |     Val_AUC:     0.4999


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.17it/s]


Epoch     : 7
 Loss     : 0.7507     |     AUC:         0.5010
 Val_Loss : 0.7325     |     Val_AUC:     0.4946


100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 74/74 [00:10<00:00,  7.13it/s]


Epoch     : 8
 Loss     : 0.7736     |     AUC:         0.4981
 Val_Loss : 0.7251     |     Val_AUC:     0.5045


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 9
 Loss     : 0.7247     |     AUC:         0.5012
 Val_Loss : 0.7385     |     Val_AUC:     0.5065


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.21it/s]


Epoch     : 10
 Loss     : 0.7074     |     AUC:         0.5072
 Val_Loss : 0.7411     |     Val_AUC:     0.5078


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.18it/s]


Epoch     : 11
 Loss     : 0.6227     |     AUC:         0.6863
 Val_Loss : 0.4864     |     Val_AUC:     0.7862


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.18it/s]


Epoch     : 12
 Loss     : 0.4925     |     AUC:         0.8085
 Val_Loss : 0.3402     |     Val_AUC:     0.9043


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.16it/s]


Epoch     : 13
 Loss     : 0.4445     |     AUC:         0.8342
 Val_Loss : 0.4182     |     Val_AUC:     0.8811


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.16it/s]


Epoch     : 14
 Loss     : 0.4429     |     AUC:         0.8542
 Val_Loss : 0.3255     |     Val_AUC:     0.9131


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.17it/s]


Epoch     : 15
 Loss     : 0.3917     |     AUC:         0.8677
 Val_Loss : 0.2595     |     Val_AUC:     0.9404


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.22it/s]


Epoch     : 16
 Loss     : 0.3690     |     AUC:         0.8850
 Val_Loss : 0.5363     |     Val_AUC:     0.9364


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.19it/s]


Epoch     : 17
 Loss     : 0.3752     |     AUC:         0.8838
 Val_Loss : 0.3027     |     Val_AUC:     0.9292


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.29it/s]


Epoch     : 18
 Loss     : 0.3845     |     AUC:         0.8749
 Val_Loss : 0.2566     |     Val_AUC:     0.9367


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.14it/s]


Epoch     : 19
 Loss     : 0.3587     |     AUC:         0.8856
 Val_Loss : 0.2656     |     Val_AUC:     0.9379


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.21it/s]


Epoch     : 20
 Loss     : 0.3657     |     AUC:         0.8877
 Val_Loss : 0.2905     |     Val_AUC:     0.9299


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.16it/s]


Epoch     : 21
 Loss     : 0.3502     |     AUC:         0.8926
 Val_Loss : 0.2984     |     Val_AUC:     0.9320


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.19it/s]


Epoch     : 22
 Loss     : 0.3589     |     AUC:         0.8854
 Val_Loss : 0.3125     |     Val_AUC:     0.9257


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.26it/s]


Epoch     : 23
 Loss     : 0.3757     |     AUC:         0.8814
 Val_Loss : 0.3204     |     Val_AUC:     0.9468


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.14it/s]


Epoch     : 24
 Loss     : 0.3479     |     AUC:         0.8892
 Val_Loss : 0.2562     |     Val_AUC:     0.9322


100%|██████████| 89/89 [00:31<00:00,  2.87it/s]
100%|██████████| 74/74 [00:10<00:00,  7.22it/s]


Epoch     : 25
 Loss     : 0.3478     |     AUC:         0.8905
 Val_Loss : 0.2718     |     Val_AUC:     0.9447


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.15it/s]


Epoch     : 26
 Loss     : 0.3494     |     AUC:         0.8871
 Val_Loss : 0.2642     |     Val_AUC:     0.9417


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.17it/s]


Epoch     : 27
 Loss     : 0.3390     |     AUC:         0.8919
 Val_Loss : 0.2358     |     Val_AUC:     0.9512


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 28
 Loss     : 0.3445     |     AUC:         0.8988
 Val_Loss : 0.3275     |     Val_AUC:     0.9344


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.21it/s]


Epoch     : 29
 Loss     : 0.3465     |     AUC:         0.8913
 Val_Loss : 0.5057     |     Val_AUC:     0.9430


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 30
 Loss     : 0.3360     |     AUC:         0.9062
 Val_Loss : 0.2785     |     Val_AUC:     0.9556


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.16it/s]


Epoch     : 31
 Loss     : 0.3362     |     AUC:         0.9013
 Val_Loss : 0.3014     |     Val_AUC:     0.9421


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 32
 Loss     : 0.3474     |     AUC:         0.8888
 Val_Loss : 0.2548     |     Val_AUC:     0.9481


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.25it/s]


Epoch     : 33
 Loss     : 0.3386     |     AUC:         0.9071
 Val_Loss : 0.3418     |     Val_AUC:     0.9476


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.10it/s]


Epoch     : 34
 Loss     : 0.3265     |     AUC:         0.9036
 Val_Loss : 0.3175     |     Val_AUC:     0.9471


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.09it/s]


Epoch     : 35
 Loss     : 0.3435     |     AUC:         0.8924
 Val_Loss : 0.2310     |     Val_AUC:     0.9606


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.18it/s]


Epoch     : 36
 Loss     : 0.3415     |     AUC:         0.8983
 Val_Loss : 0.2647     |     Val_AUC:     0.9390


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.18it/s]


Epoch     : 37
 Loss     : 0.3184     |     AUC:         0.9030
 Val_Loss : 0.2478     |     Val_AUC:     0.9436


100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 74/74 [00:10<00:00,  7.24it/s]


Epoch     : 38
 Loss     : 0.3433     |     AUC:         0.8936
 Val_Loss : 0.2563     |     Val_AUC:     0.9412


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 39
 Loss     : 0.3372     |     AUC:         0.8940
 Val_Loss : 0.2536     |     Val_AUC:     0.9355


100%|██████████| 89/89 [00:30<00:00,  2.87it/s]
100%|██████████| 74/74 [00:10<00:00,  7.21it/s]


Epoch     : 40
 Loss     : 0.3393     |     AUC:         0.8982
 Val_Loss : 0.2751     |     Val_AUC:     0.9206


100%|██████████| 89/89 [00:31<00:00,  2.87it/s]
100%|██████████| 74/74 [00:10<00:00,  7.15it/s]


Epoch     : 41
 Loss     : 0.3300     |     AUC:         0.9016
 Val_Loss : 0.2625     |     Val_AUC:     0.9479


100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 74/74 [00:10<00:00,  7.24it/s]


Epoch     : 42
 Loss     : 0.3363     |     AUC:         0.8967
 Val_Loss : 0.6504     |     Val_AUC:     0.8410


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.20it/s]


Epoch     : 43
 Loss     : 0.3430     |     AUC:         0.9022
 Val_Loss : 0.3031     |     Val_AUC:     0.9406


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.22it/s]


Epoch     : 44
 Loss     : 0.3383     |     AUC:         0.8926
 Val_Loss : 0.2531     |     Val_AUC:     0.9432


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.19it/s]


Epoch     : 45
 Loss     : 0.3577     |     AUC:         0.8942
 Val_Loss : 0.2532     |     Val_AUC:     0.9458


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 74/74 [00:10<00:00,  7.19it/s]


Epoch     : 46
 Loss     : 0.3392     |     AUC:         0.9083
 Val_Loss : 0.2495     |     Val_AUC:     0.9382


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.16it/s]


Epoch     : 47
 Loss     : 0.3361     |     AUC:         0.8994
 Val_Loss : 0.2423     |     Val_AUC:     0.9512


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.24it/s]


Epoch     : 48
 Loss     : 0.3192     |     AUC:         0.9156
 Val_Loss : 0.2240     |     Val_AUC:     0.9647


100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 74/74 [00:10<00:00,  7.17it/s]


Epoch     : 49
 Loss     : 0.3385     |     AUC:         0.9021
 Val_Loss : 0.2565     |     Val_AUC:     0.9599


100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 74/74 [00:10<00:00,  7.13it/s]

Epoch     : 50
 Loss     : 0.3046     |     AUC:         0.9172
 Val_Loss : 0.2780     |     Val_AUC:     0.9456





In [16]:

msi = create_msi("msi-hiv", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data", 
                 extra_dz_data="../data/hiv/hiv_protein30.tsv")

hiv_data = msi.hetero_data
hiv_data = T.ToUndirected()(hiv_data)
#(HIV: id = C0019693, gid = 840)
hiv_id = "C0019693"
hiv_gid = 840



In [17]:
def get_node_embeddings(model, data):
    model.eval()
    z = model(data.x_dict, data.edge_index_dict)
    return z

In [18]:
hiv_data.to("cpu")
model.to("cpu")
predictor.to("cpu")

DotProductLinkPredictor(
  (lin1): Linear(32, 32, bias=True)
  (lin2): Linear(32, 1, bias=True)
)

In [19]:
torch.cuda.empty_cache()
#print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
#print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
#print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
#print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [26]:
z = get_node_embeddings(model, hiv_data)
z = model(hiv_data.x_dict, hiv_data.edge_index_dict)
z_drugs = z["drug"]
z_dz = z["dz"]
hiv_emb = z["dz"][hiv_gid]

In [27]:
from numpy import dot
from numpy.linalg import norm



#hiv_emb = hiv_emb.cpu().detach().numpy()
drug_edge_preds = []
dz_edge_preds = []
for drug_emb in z_drugs:
    drug_edge_pred = predictor(drug_emb, hiv_emb)
    drug_edge_pred = drug_edge_pred.detach().numpy()
    drug_edge_preds.append(drug_edge_pred)
drug_sorted_edges = np.sort(drug_edge_preds)[::-1]
drug_sorted_ind = np.argsort(drug_edge_preds)[::-1]


for dz_emb in z_dz:
    dz_edge_pred = predictor(dz_emb, hiv_emb)
    dz_edge_pred = dz_edge_pred.detach().numpy()
    dz_edge_preds.append(dz_edge_pred)
dz_sorted_edges = np.sort(dz_edge_preds)[::-1]
dz_sorted_ind = np.argsort(dz_edge_preds)[::-1]




In [28]:
import pandas as pd
id_list = []
drug_list = []
sim_list = []
for ind, edge in zip(drug_sorted_ind, drug_sorted_edges):
    id = [i for i in msi.mappings["drug_id_gid_map"] if msi.mappings["drug_id_gid_map"][i] == ind][0]
    drug = msi.mappings["drug_id_name_map"][id]
    id_list.append(id)
    drug_list.append(drug)
    sim_list.append(edge)
    print(f"Drug: {drug:<30}    |    HIV edge pred: {edge:>10.4}")

res_dict = {"id":id_list, "name" : drug_list, "similarity" : sim_list}
res_df = pd.DataFrame(res_dict)

print("\n\n\n")

for ind, edge in zip(dz_sorted_ind, dz_sorted_edges):
    id = [i for i in msi.mappings["dz_id_gid_map"] if msi.mappings["dz_id_gid_map"][i] == ind][0]
    dz = msi.mappings["dz_id_name_map"][id]
    print(f"DZ: {dz:<30}    |    HIV edge pred: {edge:>10.4}")

res_dict = {"id":id_list, "name" : drug_list, "similarity" : sim_list}
res_df = pd.DataFrame(res_dict)


Drug: nan                               |    HIV edge pred:      2.233
Drug: Dexfenfluramine                   |    HIV edge pred:      2.072
Drug: nan                               |    HIV edge pred:      2.054
Drug: enoximone                         |    HIV edge pred:      1.985
Drug: nan                               |    HIV edge pred:      1.898
Drug: mianserin                         |    HIV edge pred:      1.873
Drug: viloxazine                        |    HIV edge pred:      1.856
Drug: risperidone                       |    HIV edge pred:      1.832
Drug: hyaluronidase (ovine)             |    HIV edge pred:      1.821
Drug: ethynodiol-diacetate              |    HIV edge pred:      1.815
Drug: Azelaic Acid                      |    HIV edge pred:      1.814
Drug: tolbutamide                       |    HIV edge pred:      1.812
Drug: carprofen                         |    HIV edge pred:      1.797
Drug: mitiglinide                       |    HIV edge pred:      1.762
Drug: 

In [None]:
# Write results to file
#res_df.to_csv("../results/res1_loss0.47_auc0.76.tsv", sep="\t", index=False)