In [1]:
import numpy as np
import pandas as pd
import torch
from load_msi import create_msi
from visualization import (
        generate_nx_graph, generate_khop_subgraph, generate_pyvis_graph, 
        visualize_tsne_embeddings, visualize_pca_embeddings)
from training import split_data, create_minibatches, train_link_predictor
from torch_geometric.utils import to_networkx, degree
from torch_geometric.nn import to_hetero

from gnn import GAT, GIN, DotProductLinkPredictor
import torch_geometric.transforms as T
print(torch.__version__)

1.13.1+cu117


In [2]:
msi = create_msi("msi", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data")

data = msi.hetero_data
#data = T.ToUndirected()(data)

In [3]:
# Generate and visualize graphs (HIV: id = C0019693, gid = 840)
##G, G_undirected = generate_nx_graph(msi.data)
##subG = generate_khop_subgraph(msi, G, 840, 1)
##py_graph = generate_pyvis_graph(subG)
##py_graph.show("test.html")

In [4]:
# Set the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [5]:
print("Average number of 1-hop neighbors per node of edge type\n")
for edge_type in data.edge_types:
    degrees = degree(data[edge_type].edge_index[0]).numpy()
    print(f"{edge_type} : {np.mean(degrees):.1f}")

Average number of 1-hop neighbors per node of edge type

('drug', 'binds', 'prot') : 5.2
('dz', 'implicates', 'prot') : 30.0
('prot', 'associates', 'prot') : 64.0
('prot', 'partOf', 'func') : 2.9
('func', 'partOf', 'func') : 4.6
('drug', 'treats', 'dz') : 3.6
('prot', 'rev_binds', 'drug') : 0.7
('prot', 'rev_implicates', 'dz') : 2.1
('func', 'rev_partOf', 'prot') : 5.4
('dz', 'rev_treats', 'drug') : 7.1


In [6]:
# Prepare data for GNN.
edge_types = ("drug", "treats", "dz")
rev_edge_types = ("dz", "rev_treats", "drug")

# Split the data into train, validation, and test sets.
train_data, val_data, test_data = split_data(
        data, val_size=0.1, test_size=0.1, supervision_ratio=0.3,
        edge_types=edge_types, neg_sampling_ratio=1.0,
        rev_edge_types=rev_edge_types)




# Create minibatches.
train_loader = create_minibatches(
        train_data, edge_types=edge_types, #num_khop_neighbors=[50,25,10], 
        neg_sampling_ratio=1.0, batch_size=16)
val_loader = create_minibatches(
        val_data, edge_types=edge_types, #num_khop_neighbors = [50,25,10],
        batch_size=16)
test_loader = create_minibatches(
        test_data, edge_types=edge_types, #num_khop_neighbors = [50,25,10],
        batch_size=16)



In [7]:
# Instantiate the model and generate some initial node embeddings prior to training.
model = GAT(4, 64, 16)
model = to_hetero(model, data.metadata(), aggr='sum').to(device)
predictor = DotProductLinkPredictor(32, 32, 1).to(device)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
#model.reset_parameters()
#predictor.reset_parameters()

# Train and validate the model for a given number of epochs.
for epoch in range(1, 50+1):
    
    # Training..
    loss, auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=train_loader, optimizer=optimizer, device=device, train=True)
    
    # Validation.
    val_loss, val_auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=val_loader,  optimizer=optimizer, device=device, train=False)
    
    print(f"Epoch     : {epoch}")
    print(f" Loss     : {loss:<10.4f} |     AUC:     {auc:>10.4f}")
    print(f" Val_Loss : {val_loss:<10.4f} |     Val_AUC: {val_auc:>10.4f}")

100%|██████████| 89/89 [00:09<00:00,  9.48it/s]
100%|██████████| 74/74 [00:02<00:00, 25.75it/s]


Epoch     : 1
 Loss     : 1.1122     |     AUC:         0.7034
 Val_Loss : 0.6810     |     Val_AUC:     0.7468


100%|██████████| 89/89 [00:08<00:00, 10.27it/s]
100%|██████████| 74/74 [00:02<00:00, 25.58it/s]


Epoch     : 2
 Loss     : 0.6159     |     AUC:         0.7678
 Val_Loss : 0.6110     |     Val_AUC:     0.7652


100%|██████████| 89/89 [00:08<00:00, 10.22it/s]
100%|██████████| 74/74 [00:02<00:00, 25.90it/s]


Epoch     : 3
 Loss     : 0.5707     |     AUC:         0.8020
 Val_Loss : 0.5942     |     Val_AUC:     0.7908


100%|██████████| 89/89 [00:08<00:00, 10.33it/s]
100%|██████████| 74/74 [00:02<00:00, 25.67it/s]


Epoch     : 4
 Loss     : 0.5576     |     AUC:         0.8159
 Val_Loss : 0.5143     |     Val_AUC:     0.8263


100%|██████████| 89/89 [00:08<00:00, 10.47it/s]
100%|██████████| 74/74 [00:02<00:00, 26.45it/s]


Epoch     : 5
 Loss     : 0.5664     |     AUC:         0.8110
 Val_Loss : 0.5296     |     Val_AUC:     0.8217


100%|██████████| 89/89 [00:08<00:00, 10.43it/s]
100%|██████████| 74/74 [00:02<00:00, 26.05it/s]


Epoch     : 6
 Loss     : 0.5566     |     AUC:         0.8018
 Val_Loss : 0.5057     |     Val_AUC:     0.8461


100%|██████████| 89/89 [00:08<00:00, 10.48it/s]
100%|██████████| 74/74 [00:02<00:00, 26.59it/s]


Epoch     : 7
 Loss     : 0.5216     |     AUC:         0.8303
 Val_Loss : 0.4979     |     Val_AUC:     0.8457


100%|██████████| 89/89 [00:08<00:00, 10.52it/s]
100%|██████████| 74/74 [00:02<00:00, 26.23it/s]


Epoch     : 8
 Loss     : 0.5347     |     AUC:         0.8229
 Val_Loss : 0.5132     |     Val_AUC:     0.8526


100%|██████████| 89/89 [00:08<00:00, 10.45it/s]
100%|██████████| 74/74 [00:02<00:00, 26.47it/s]


Epoch     : 9
 Loss     : 0.5421     |     AUC:         0.8175
 Val_Loss : 0.5111     |     Val_AUC:     0.8345


100%|██████████| 89/89 [00:08<00:00, 10.40it/s]
100%|██████████| 74/74 [00:02<00:00, 26.02it/s]


Epoch     : 10
 Loss     : 0.5198     |     AUC:         0.8326
 Val_Loss : 0.5015     |     Val_AUC:     0.8316


 55%|█████▌    | 49/89 [00:04<00:03, 10.43it/s]


KeyboardInterrupt: 

In [9]:
# Test set.
test_loss, test_auc = train_link_predictor(
        model, predictor, "drug", "dz", edge_types=edge_types, 
        loader=test_loader,  optimizer=optimizer, device=device, train=False)
    
print(f" Test Loss     : {test_loss:<10.4f} |     Test AUC:     {test_auc:>10.4f}")

100%|██████████| 74/74 [00:02<00:00, 25.27it/s]

 Test Loss     : 0.4489     |     Test AUC:         0.8682





In [10]:
data = T.ToUndirected()(data)

# Split the data into train, validation, and test sets.
total_data, _, _ = split_data(
        data, val_size=0.0, test_size=0.0, supervision_ratio=0.3,
        edge_types=edge_types, neg_sampling_ratio=1.0,
        rev_edge_types=rev_edge_types)

total_data_loader = create_minibatches(
        total_data, edge_types=edge_types, #num_khop_neighbors=[50,25,10], 
        neg_sampling_ratio=1.0, batch_size=16)

# Retrain with the entire dataset
for epoch in range(1, 10+1):
    
    # Training with entire dataset...
    loss, auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=total_data_loader, optimizer=optimizer, device=device, train=True)
    
    print(f"Epoch     : {epoch}")
    print(f" Loss     : {loss:<10.4f} |     AUC:     {auc:>10.4f}")

100%|██████████| 112/112 [00:11<00:00,  9.35it/s]


Epoch     : 1
 Loss     : 0.4988     |     AUC:         0.8433


100%|██████████| 112/112 [00:11<00:00,  9.48it/s]


Epoch     : 2
 Loss     : 0.5002     |     AUC:         0.8414


100%|██████████| 112/112 [00:11<00:00,  9.45it/s]


Epoch     : 3
 Loss     : 0.5081     |     AUC:         0.8398


100%|██████████| 112/112 [00:11<00:00,  9.37it/s]


Epoch     : 4
 Loss     : 0.4955     |     AUC:         0.8482


100%|██████████| 112/112 [00:11<00:00,  9.49it/s]


Epoch     : 5
 Loss     : 0.4960     |     AUC:         0.8464


100%|██████████| 112/112 [00:11<00:00,  9.48it/s]


Epoch     : 6
 Loss     : 0.4970     |     AUC:         0.8456


100%|██████████| 112/112 [00:11<00:00,  9.46it/s]


Epoch     : 7
 Loss     : 0.5068     |     AUC:         0.8390


100%|██████████| 112/112 [00:11<00:00,  9.46it/s]


Epoch     : 8
 Loss     : 0.4846     |     AUC:         0.8525


100%|██████████| 112/112 [00:12<00:00,  9.32it/s]


Epoch     : 9
 Loss     : 0.4981     |     AUC:         0.8479


100%|██████████| 112/112 [00:11<00:00,  9.42it/s]


Epoch     : 10
 Loss     : 0.4900     |     AUC:         0.8468


In [11]:
# Add hiv information to the msi (HIV: id = C0019693, gid = 840).
msi = create_msi("msi-hiv", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data", 
                 extra_prot_data="../data/hiv/3_hivProt_humanProt_map.tsv",
                 extra_drug_data="../data/hiv/1b_drugs_prot_map_total.tsv", 
                 extra_dz_data="../data/hiv/2b_hiv_prot_map_total.tsv",
                 extra_drug_dz_links="../data/hiv/4b_drug_hiv_map_total.tsv")

hiv_data = msi.hetero_data
#hiv_data = T.ToUndirected()(hiv_data)

hiv_id = "C0019693"
hiv_gid = 840




In [12]:
# Send models and data back to cpu.
hiv_data.to("cpu")
model.to("cpu")
predictor.to("cpu")

DotProductLinkPredictor(
  (lin1): Linear(32, 32, bias=True)
  (lin2): Linear(32, 1, bias=True)
)

In [13]:
#torch.cuda.empty_cache()
#print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
#print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
#print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
#print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [14]:
#z = get_node_embeddings(model, hiv_data)
model.eval()
predictor.eval()
z = model(hiv_data.x_dict, hiv_data.edge_index_dict)
z_drugs = z["drug"]
z_dz = z["dz"]
hiv_emb = z_dz[hiv_gid]


In [15]:
# Predict the drug links to HIV.
drug_edge_preds = []
dz_edge_preds = []
for drug_emb in z_drugs:
    drug_edge_pred = predictor(drug_emb, hiv_emb)
    drug_edge_pred = drug_edge_pred.detach().numpy()
    drug_edge_preds.append(drug_edge_pred)
drug_sorted_edges = np.sort(drug_edge_preds)[::-1]
drug_sorted_ind = np.argsort(drug_edge_preds)[::-1]

# Find dz links to HIV (note: the model was not trained to do this.)
for dz_emb in z_dz:
    dz_edge_pred = predictor(dz_emb, hiv_emb)
    dz_edge_pred = dz_edge_pred.detach().numpy()
    dz_edge_preds.append(dz_edge_pred)
dz_sorted_edges = np.sort(dz_edge_preds)[::-1]
dz_sorted_ind = np.argsort(dz_edge_preds)[::-1]




In [16]:
import pandas as pd
id_list = []
drug_list = []
sim_list = []
for ind, edge in zip(drug_sorted_ind, drug_sorted_edges):
    id = [i for i in msi.mappings["drug_id_gid_map"] if msi.mappings["drug_id_gid_map"][i] == ind][0]
    drug = msi.mappings["drug_id_name_map"][id]
    id_list.append(id)
    drug_list.append(drug)
    sim_list.append(edge)
    print(f"Drug: {drug:<30}    |    HIV edge pred: {edge:>10.4}")

res_dict = {"id":id_list, "name" : drug_list, "similarity" : sim_list}
res_df = pd.DataFrame(res_dict)

print("\n\n\n")

for ind, edge in zip(dz_sorted_ind, dz_sorted_edges):
    id = [i for i in msi.mappings["dz_id_gid_map"] if msi.mappings["dz_id_gid_map"][i] == ind][0]
    dz = msi.mappings["dz_id_name_map"][id]
    print(f"DZ: {dz:<30}    |    HIV edge pred: {edge:>10.4}")

res_dict = {"id":id_list, "name" : drug_list, "similarity" : sim_list}
res_df = pd.DataFrame(res_dict)


Drug: nan                               |    HIV edge pred:     0.1607
Drug: cholic-acid                       |    HIV edge pred:     0.1285
Drug: dapiprazole                       |    HIV edge pred:     0.1263
Drug: eflornithine                      |    HIV edge pred:      0.104
Drug: bunazosin                         |    HIV edge pred:    0.09928
Drug: nan                               |    HIV edge pred:    0.09429
Drug: acetazolamide                     |    HIV edge pred:    0.08902
Drug: doluregravir                      |    HIV edge pred:    0.07937
Drug: nan                               |    HIV edge pred:    0.07474
Drug: trientine                         |    HIV edge pred:    0.07449
Drug: Vitamin E                         |    HIV edge pred:    0.06447
Drug: thiamine                          |    HIV edge pred:    0.06295
Drug: penicillin-v-potassium            |    HIV edge pred:    0.06143
Drug: acyclovir                         |    HIV edge pred:    0.05874
Drug: 

In [23]:
# Write results to file
res_df.to_csv("../results/res3_val_0.467_.867_test_0.436_.874_total.tsv", sep="\t", index=False)