In [1]:
import numpy as np
import torch
from load_msi import create_msi
from visualization import (
        generate_nx_graph, generate_khop_subgraph, generate_pyvis_graph, 
        visualize_tsne_embeddings, visualize_pca_embeddings)
from training import split_data, create_minibatches, train_model
from torch_geometric.utils import to_networkx, degree
from torch_geometric.nn import to_hetero

from gnn import GAT, GIN
# Record the installed PyTorch version in the notebook output.
print(torch.__version__)

1.13.1+cu117


In [4]:
# Build the "msi" dataset from the multiscale_interactome data directory
# and keep a handle on its `hetero_data` graph object.
msi = create_msi(
    "msi",
    num_features=1,
    data_dir="../multiscale_interactome/data",
)
data = msi.hetero_data

In [3]:
# Generate and visualize graphs (HIV: id = C0019693, gid = 840)
##G, G_undirected = generate_nx_graph(msi.data)
##subG = generate_khop_subgraph(msi, G, 840, 1)
##py_graph = generate_pyvis_graph(subG)
##py_graph.show("test.html")

In [5]:
# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

Device: cuda


In [5]:
# Disease -> protein relation, plus the disease lookup tables from `msi.mappings`.
edge_type = ("dz", "implicates", "prot")
id_map, name_map = (
    msi.mappings[key] for key in ("dz_id_gid_map", "dz_id_name_map")
)

# Count how many edges of this type start at each source index
# (row 0 of edge_index holds the source side of every edge).
source_nodes = data[edge_type].edge_index[0]
degrees = degree(source_nodes).numpy()

# Rank source indices from highest to lowest degree; the sorted degree
# values are read back through the same ordering.
sorted_gids = np.argsort(degrees)[::-1]
sorted_degs = degrees[sorted_gids]


#for gid, deg in zip(sorted_gids, sorted_degs):
#    id = list(id_map.keys())[list(id_map.values()).index(gid)]
#    name = name_map[id]
#    print(f"Disease: {name:<45}   |   1-hop neighbors: {deg:>5}   |   id: {id}")


In [6]:
# Prepare data for GNN: the edge type handed to the split/loader helpers
# and its reverse relation.
edge_types = ("drug", "treats", "dz")
rev_edge_types = ("dz", "rev_treats", "drug")

# Split the data into train, validation, and test sets.
split_options = dict(
    val_size=0.1,
    test_size=0.1,
    supervision_ratio=0.3,
    edge_types=edge_types,
    neg_sampling_ratio=2.0,
    rev_edge_types=rev_edge_types,
)
train_data, val_data, test_data = split_data(data, **split_options)


def _make_loader(split):
    """Wrap one split in minibatches; a fresh [50, 25, 10] list is built per call."""
    return create_minibatches(
        split, edge_types=edge_types, num_khop_neighbors=[50, 25, 10])


# Create minibatches (train first, then validation, as before).
train_loader = _make_loader(train_data)
val_loader = _make_loader(val_data)



In [7]:
# Instantiate the model: a GAT(64, 32) converted with `to_hetero` using the
# graph's metadata and 'sum' aggregation.
base_gnn = GAT(64, 32)
model = to_hetero(base_gnn, data.metadata(), aggr='sum')

# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Train for 10 epochs; kept as the cell's final expression so any return
# value is still displayed.
train_options = dict(
    edge_types=edge_types,
    epochs=10,
    train_loader=train_loader,
    val_data=val_data,
    optimizer=optimizer,
    device=device,
)
train_model(model, "drug", "dz", **train_options)

100%|██████████| 12/12 [00:08<00:00,  1.45it/s]


 Epoch: 001
 Loss     : 1.6718     |     AUC:         0.5121
 Val_Loss : 1.0777     |     Val_AUC:     0.5175


100%|██████████| 12/12 [00:08<00:00,  1.49it/s]


 Epoch: 002
 Loss     : 0.7487     |     AUC:         0.5487
 Val_Loss : 0.6986     |     Val_AUC:     0.5304


100%|██████████| 12/12 [00:07<00:00,  1.51it/s]


 Epoch: 003
 Loss     : 0.6762     |     AUC:         0.5513
 Val_Loss : 0.6940     |     Val_AUC:     0.5308


100%|██████████| 12/12 [00:08<00:00,  1.36it/s]


 Epoch: 004
 Loss     : 0.6512     |     AUC:         0.5743
 Val_Loss : 0.7010     |     Val_AUC:     0.4805


100%|██████████| 12/12 [00:07<00:00,  1.53it/s]


 Epoch: 005
 Loss     : 0.6345     |     AUC:         0.5975
 Val_Loss : 0.6812     |     Val_AUC:     0.5242


100%|██████████| 12/12 [00:07<00:00,  1.51it/s]


 Epoch: 006
 Loss     : 0.6338     |     AUC:         0.5842
 Val_Loss : 0.6749     |     Val_AUC:     0.5018


100%|██████████| 12/12 [00:07<00:00,  1.53it/s]


 Epoch: 007
 Loss     : 0.6278     |     AUC:         0.5990
 Val_Loss : 0.6548     |     Val_AUC:     0.5437


100%|██████████| 12/12 [00:08<00:00,  1.41it/s]


 Epoch: 008
 Loss     : 0.6315     |     AUC:         0.5932
 Val_Loss : 0.6683     |     Val_AUC:     0.5268


100%|██████████| 12/12 [00:07<00:00,  1.57it/s]


 Epoch: 009
 Loss     : 0.6181     |     AUC:         0.6172
 Val_Loss : 0.6529     |     Val_AUC:     0.5516


100%|██████████| 12/12 [00:08<00:00,  1.43it/s]


 Epoch: 010
 Loss     : 0.6232     |     AUC:         0.6081
 Val_Loss : 0.6574     |     Val_AUC:     0.5517


In [8]:
#visualize_tsne_embeddings(gae_model, train_dataset, 'Untrained GAE: train set embeddings t-SNE', labeled=True, labels=[40, 190, 230, 1830, 260, 110, 280, 1967])

In [9]:
#visualize_pca_embeddings(gae_model, train_dataset, 'Untrained GAE: train set embeddings PCA', labeled=True, labels=[40, 190, 230, 1830, 260, 110, 280, 1967])

In [10]:
# Print the graph metadata; per the output below this is a pair of
# (node type names, edge type triples).
print(data.metadata())

(['prot', 'drug', 'dz', 'func'], [('drug', 'binds', 'prot'), ('dz', 'implicates', 'prot'), ('prot', 'associates', 'prot'), ('prot', 'partOf', 'func'), ('func', 'partOf', 'func'), ('drug', 'treats', 'dz'), ('prot', 'rev_binds', 'drug'), ('prot', 'rev_implicates', 'dz'), ('func', 'rev_partOf', 'prot'), ('dz', 'rev_treats', 'drug')])


In [11]:
# Rebuild the dataset as "msi-hiv", adding the extra disease data file for HIV.
# NOTE: this rebinds `msi` and `data`, replacing the objects used for training.
msi = create_msi("msi-hiv",
                 num_features = 1,
                 data_dir="../multiscale_interactome/data",
                 extra_dz_data="data/hiv/hiv_protein30.tsv")

data = msi.hetero_data
#(HIV: id = C0019693, previously observed gid = 840)
hiv_id = "C0019693"
# Resolve the HIV node index from the freshly built id -> gid mapping instead
# of trusting a hard-coded number, since the dataset was just rebuilt; 840 is
# kept only as the fallback when the id is absent from the mapping.
hiv_gid = msi.mappings["dz_id_gid_map"].get(hiv_id, 840)


In [12]:
def get_node_embeddings(model, data, target_device=None):
    """Run one inference-only forward pass and return the model's output.

    Args:
        model: Module called as ``model(data.x_dict, data.edge_index_dict)``.
        data: Graph object exposing ``x_dict``, ``edge_index_dict`` and ``to``.
        target_device: Device to run on. Defaults to the notebook-level
            ``device`` when omitted, which matches the previous behaviour.

    Returns:
        Whatever ``model`` returns for the given inputs (indexed by node type
        by the callers in this notebook), computed without autograd tracking.
    """
    run_device = device if target_device is None else target_device

    model.eval()
    model.to(run_device)
    # Use the returned object rather than relying on `to` mutating in place.
    data = data.to(run_device)

    # Inference only: skip building the autograd graph to save memory and
    # return tensors that do not require grad.
    with torch.no_grad():
        z = model(data.x_dict, data.edge_index_dict)
    return z

In [13]:
# Embed the current `data` graph with the trained model; `z` is indexed by
# node type in the cells below.
z = get_node_embeddings(model, data)

In [14]:
# Pick the "dz" embedding at index hiv_gid (the HIV disease node).
hiv_emb = z["dz"][hiv_gid]

In [15]:
# Embeddings of the "drug" nodes; presumably one row per drug gid -- confirm.
z_drugs = z["drug"]

In [16]:
# Score every drug against HIV with a dot product of the two embeddings.
# A single broadcasted multiply-and-sum over the last dimension replaces the
# per-drug Python loop. `.cpu()` is required before `.numpy()` because the
# embeddings live on `device`, which may be CUDA (NumPy cannot read GPU memory).
edge_preds = (z_drugs * hiv_emb).sum(dim=-1).detach().cpu().numpy()


In [17]:
# Order drugs from highest to lowest score: take the descending index
# ordering first, then read the scores back through it.
sorted_ind = np.argsort(edge_preds)[::-1]
sorted_edges = np.asarray(edge_preds)[sorted_ind]

In [18]:
# Show the ranked scores and the matching drug indices, one array per line.
print(sorted_edges, sorted_ind, sep="\n")

[ 0.13641936  0.00232488 -0.07020539 ... -1.9591882  -2.0613048
 -2.147374  ]
[1489 1340    8 ... 1602 1599 1212]


In [19]:
# Invert the drug id -> gid mapping once so each ranked index resolves with a
# dictionary lookup instead of a full scan per drug. `setdefault` keeps the
# first id seen for a gid, matching the original first-match behaviour.
drug_gid_to_id = {}
for drug_id, gid in msi.mappings["drug_id_gid_map"].items():
    drug_gid_to_id.setdefault(gid, drug_id)
drug_names = msi.mappings["drug_id_name_map"]

# Print drugs from highest to lowest predicted HIV edge score.
for ind, edge in zip(sorted_ind, sorted_edges):
    drug = drug_names[drug_gid_to_id[int(ind)]]

    print(f"Drug: {drug:<30}    |    HIV edge pred: {edge:>10.4}")

Drug: Acetohydroxamic Acid              |    HIV edge pred:     0.1364
Drug: linaclotide                       |    HIV edge pred:   0.002325
Drug: deferoxamine-mesylate             |    HIV edge pred:   -0.07021
Drug: montelukast                       |    HIV edge pred:   -0.07853
Drug: benserazide                       |    HIV edge pred:    -0.1162
Drug: voglibose                         |    HIV edge pred:    -0.1486
Drug: homoharringtonine                 |    HIV edge pred:    -0.1548
Drug: Bromazepam                        |    HIV edge pred:    -0.1566
Drug: linezolid                         |    HIV edge pred:    -0.1666
Drug: nan                               |    HIV edge pred:    -0.1719
Drug: nan                               |    HIV edge pred:    -0.1745
Drug: nan                               |    HIV edge pred:    -0.2261
Drug: procarbazine                      |    HIV edge pred:    -0.2386
Drug: quinagolide                       |    HIV edge pred:    -0.2516
Drug: 