In [1]:
import numpy as np
import torch
from load_msi import create_msi
from visualization import (
        generate_nx_graph, generate_khop_subgraph, generate_pyvis_graph, 
        visualize_tsne_embeddings, visualize_pca_embeddings)
from training import split_data, create_minibatches, train_link_predictor
from torch_geometric.utils import to_networkx, degree
from torch_geometric.nn import to_hetero

from gnn import GAT, GIN, DotProductLinkPredictor
import torch_geometric.transforms as T
print(torch.__version__)

1.13.1+cu117


In [2]:
msi = create_msi("msi", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data")
data = msi.hetero_data
data = T.ToUndirected()(data)

In [3]:
# Generate and visualize graphs (HIV: id = C0019693, gid = 840)
##G, G_undirected = generate_nx_graph(msi.data)
##subG = generate_khop_subgraph(msi, G, 840, 1)
##py_graph = generate_pyvis_graph(subG)
##py_graph.show("test.html")

In [4]:
# Set the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [5]:
edge_type = ("dz", "implicates", "prot")
id_map = msi.mappings["dz_id_gid_map"]
name_map = msi.mappings["dz_id_name_map"]



degrees = degree(data[edge_type].edge_index[0]).numpy()
sorted_gids = np.argsort(degrees)[::-1]
sorted_degs = np.sort(degrees)[::-1]


#for gid, deg in zip(sorted_gids, sorted_degs):
#    id = list(id_map.keys())[list(id_map.values()).index(gid)]
#    name = name_map[id]
#    print(f"Disease: {name:<45}   |   1-hop neighbors: {deg:>5}   |   id: {id}")


In [6]:
# Prepare data for GNN.
edge_types = ("drug", "treats", "dz")
rev_edge_types = ("dz", "rev_treats", "drug")

# Split the data into train, validation, and test sets.
train_data, val_data, test_data = split_data(
        data, val_size=0.1, test_size=0.1, supervision_ratio=0.3,
        edge_types=edge_types, neg_sampling_ratio=1.0,
        rev_edge_types=rev_edge_types)




# Create minibatches.
train_loader = create_minibatches(
        train_data, edge_types=edge_types, num_khop_neighbors=[50,25,10, 10], neg_sampling_ratio=1.0, batch_size=16)
val_loader = create_minibatches(
        val_data, edge_types=edge_types, num_khop_neighbors = [50,25,10, 10], batch_size=16)



In [7]:
# Instantiate the model and generate some initial node embeddings prior to training.
model = GAT(64, 16)
model = to_hetero(model, data.metadata(), aggr='sum').to(device)
predictor = DotProductLinkPredictor(32, 32, 1).to(device)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
#model.reset_parameters()
predictor.reset_parameters()

# Train and validate the model for a given number of epochs.
for epoch in range(1, 30+1):
    
    # Training..
    loss, auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=train_loader, optimizer=optimizer, device=device, train=True)
    
    # Validation.
    val_loss, val_auc = train_link_predictor(
            model, predictor, "drug", "dz", edge_types=edge_types, 
            loader=val_loader,  optimizer=optimizer, device=device, train=False)
    
    print(f"Epoch     : {epoch}")
    print(f" Loss     : {loss:<10.4f} |     AUC:     {auc:>10.4f}")
    print(f" Val_Loss : {val_loss:<10.4f} |     Val_AUC: {val_auc:>10.4f}")

100%|██████████| 89/89 [00:40<00:00,  2.18it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Epoch     : 1
 Loss     : 0.8193     |     AUC:         0.5588
 Val_Loss : 0.6444     |     Val_AUC:     0.6831


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.90it/s]


Epoch     : 2
 Loss     : 0.6480     |     AUC:         0.7032
 Val_Loss : 0.5770     |     Val_AUC:     0.7650


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 3
 Loss     : 0.6290     |     AUC:         0.7331
 Val_Loss : 0.5809     |     Val_AUC:     0.7690


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 4
 Loss     : 0.5950     |     AUC:         0.7608
 Val_Loss : 0.5376     |     Val_AUC:     0.8097


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:14<00:00,  4.98it/s]


Epoch     : 5
 Loss     : 0.5660     |     AUC:         0.7818
 Val_Loss : 0.5714     |     Val_AUC:     0.7669


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.88it/s]


Epoch     : 6
 Loss     : 0.5678     |     AUC:         0.7767
 Val_Loss : 0.5434     |     Val_AUC:     0.8041


100%|██████████| 89/89 [00:40<00:00,  2.22it/s]
100%|██████████| 74/74 [00:15<00:00,  4.93it/s]


Epoch     : 7
 Loss     : 0.5570     |     AUC:         0.7780
 Val_Loss : 0.5108     |     Val_AUC:     0.8306


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:14<00:00,  4.94it/s]


Epoch     : 8
 Loss     : 0.5317     |     AUC:         0.8009
 Val_Loss : 0.5360     |     Val_AUC:     0.7924


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Epoch     : 9
 Loss     : 0.5247     |     AUC:         0.8039
 Val_Loss : 0.6323     |     Val_AUC:     0.8058


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:14<00:00,  4.95it/s]


Epoch     : 10
 Loss     : 0.5255     |     AUC:         0.8077
 Val_Loss : 0.5047     |     Val_AUC:     0.8291


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:14<00:00,  4.96it/s]


Epoch     : 11
 Loss     : 0.5272     |     AUC:         0.8047
 Val_Loss : 0.5295     |     Val_AUC:     0.8330


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 12
 Loss     : 0.5245     |     AUC:         0.8029
 Val_Loss : 0.5390     |     Val_AUC:     0.8250


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:14<00:00,  4.93it/s]


Epoch     : 13
 Loss     : 0.5252     |     AUC:         0.8056
 Val_Loss : 0.5328     |     Val_AUC:     0.7975


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Epoch     : 14
 Loss     : 0.5247     |     AUC:         0.8037
 Val_Loss : 0.5185     |     Val_AUC:     0.8061


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.92it/s]


Epoch     : 15
 Loss     : 0.5206     |     AUC:         0.8042
 Val_Loss : 0.4999     |     Val_AUC:     0.8315


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Epoch     : 16
 Loss     : 0.5257     |     AUC:         0.8070
 Val_Loss : 0.5135     |     Val_AUC:     0.8229


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.92it/s]


Epoch     : 17
 Loss     : 0.5341     |     AUC:         0.7931
 Val_Loss : 0.5084     |     Val_AUC:     0.8310


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Epoch     : 18
 Loss     : 0.5190     |     AUC:         0.8059
 Val_Loss : 0.4854     |     Val_AUC:     0.8420


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:14<00:00,  4.93it/s]


Epoch     : 19
 Loss     : 0.5127     |     AUC:         0.8207
 Val_Loss : 0.5101     |     Val_AUC:     0.8347


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.83it/s]


Epoch     : 20
 Loss     : 0.5149     |     AUC:         0.8106
 Val_Loss : 0.5017     |     Val_AUC:     0.8414


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.90it/s]


Epoch     : 21
 Loss     : 0.4954     |     AUC:         0.8245
 Val_Loss : 0.5270     |     Val_AUC:     0.8315


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.92it/s]


Epoch     : 22
 Loss     : 0.5136     |     AUC:         0.8062
 Val_Loss : 0.5216     |     Val_AUC:     0.8289


100%|██████████| 89/89 [00:40<00:00,  2.22it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 23
 Loss     : 0.5290     |     AUC:         0.8058
 Val_Loss : 0.5362     |     Val_AUC:     0.8401


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 24
 Loss     : 0.5052     |     AUC:         0.8170
 Val_Loss : 0.5359     |     Val_AUC:     0.8266


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.90it/s]


Epoch     : 25
 Loss     : 0.4923     |     AUC:         0.8265
 Val_Loss : 0.4897     |     Val_AUC:     0.8465


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.87it/s]


Epoch     : 26
 Loss     : 0.4972     |     AUC:         0.8222
 Val_Loss : 0.4991     |     Val_AUC:     0.8448


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:15<00:00,  4.91it/s]


Epoch     : 27
 Loss     : 0.5018     |     AUC:         0.8216
 Val_Loss : 0.5075     |     Val_AUC:     0.8314


100%|██████████| 89/89 [00:39<00:00,  2.24it/s]
100%|██████████| 74/74 [00:14<00:00,  4.96it/s]


Epoch     : 28
 Loss     : 0.4871     |     AUC:         0.8324
 Val_Loss : 0.4994     |     Val_AUC:     0.8380


100%|██████████| 89/89 [00:40<00:00,  2.22it/s]
100%|██████████| 74/74 [00:14<00:00,  4.95it/s]


Epoch     : 29
 Loss     : 0.5013     |     AUC:         0.8149
 Val_Loss : 0.5003     |     Val_AUC:     0.8271


100%|██████████| 89/89 [00:39<00:00,  2.23it/s]
100%|██████████| 74/74 [00:15<00:00,  4.89it/s]

Epoch     : 30
 Loss     : 0.4909     |     AUC:         0.8266
 Val_Loss : 0.5190     |     Val_AUC:     0.8266





In [39]:
msi = create_msi("msi-hiv", 
                 num_features = 1,
                 data_dir="../multiscale_interactome/data", 
                 extra_dz_data="../data/hiv/hiv_protein30.tsv")

hiv_data = msi.hetero_data
hiv_data = T.ToUndirected()(hiv_data)
#(HIV: id = C0019693, gid = 840)
hiv_id = "C0019693"
hiv_gid = 840


In [40]:
data_loader = create_minibatches(
        hiv_data, edge_types=edge_types, num_khop_neighbors = [50,25,10, 10], batch_size=16)

AttributeError: 'EdgeStorage' object has no attribute 'edge_label_index'

In [35]:
def get_node_embeddings(model, data):
    model.eval()
    model.to("cpu")
    data.to("cpu")
    z = model(data.x_dict, data.edge_index_dict)
    return z

In [36]:
torch.cuda.empty_cache()
#print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
#print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
#print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
#print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [37]:
z = get_node_embeddings(model, data)
z_drugs = z["drug"]
hiv_emb = z["dz"][hiv_gid]

RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 1187117568 bytes. Error code 12 (Cannot allocate memory)

In [16]:
from numpy import dot
from numpy.linalg import norm

#hiv_emb = hiv_emb.cpu().detach().numpy()
edge_preds = []
for drug_emb in z_drugs:
    #drug_emb = drug_emb.cpu().detach().numpy()
    #edge_pred = (drug_emb * hiv_emb).sum(dim=-1).cpu().detach().numpy()
    #edge_pred = dot(drug_emb, hiv_emb)/(norm(drug_emb)*norm(hiv_emb))
    edge_pred = predictor(drug_emb, hiv_emb)
    edge_preds.append(edge_pred)
sorted_edges = np.sort(edge_preds)[::-1]
sorted_ind = np.argsort(edge_preds)[::-1]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_mm)

In [None]:
for ind, edge in zip(sorted_ind, sorted_edges):
    id = [i for i in msi.mappings["drug_id_gid_map"] if msi.mappings["drug_id_gid_map"][i] == ind][0]
    drug = msi.mappings["drug_id_name_map"][id]
    
    print(f"Drug: {drug:<30}    |    HIV edge pred: {edge:>10.4}")