In [1]:
import os
import numpy as np
import pandas as pd
import torch
import rdkit
from torch_geometric.datasets import MoleculeNet
from gnn import GIN_mol, GAT_mol
from visualization import draw_molecule
from training import train_graph_classifier, stratified_split, sample_data
from torch_geometric.data import Data
from utils.create_dataset import MolDataset

In [2]:
# Load the project's MolDataset rooted at molecule_data/msi_drugs/. It is wrapped in
# `production_loader` further down and scored by `get_predictions`.
# NOTE(review): presumably these are the candidate drugs to rank — confirm against
# utils/create_dataset.py.
production_data = MolDataset(root="molecule_data/msi_drugs/") 

In [3]:
# MoleculeNet "HIV" benchmark, stored under ./molecule_data.
dataset = MoleculeNet(root="molecule_data", name="HIV")

In [4]:
# Show the label tensor `y` of the first graph in the dataset.
print(dataset[0].y)

tensor([[0.]])


In [5]:
# Partition the dataset into train / validation / test subsets using the
# fractions below (80 % / 10 % / 10 %).
# NOTE(review): presumably stratified on the label `y` — confirm in training.py.
SPLIT_FRACTIONS = (0.8, 0.1, 0.1)
train_data, val_data, test_data = stratified_split(dataset, *SPLIT_FRACTIONS)

# Only the training subset is passed through `sample_data`; validation and test
# data are left as returned by the split.
# NOTE(review): the `_down` suffix suggests down-sampling — confirm in training.py.
train_data_down = sample_data(train_data)



In [6]:
# Dataset-level summary: size, node-feature count and number of classes.
dataset_summary = [
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
]
print(*dataset_summary, sep='\n')

# Use the first graph as a representative example.
data = dataset[0]

# Structural statistics of that single graph.
graph_summary = [
    '',
    data,
    '=============================================================',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
]
print(*graph_summary, sep='\n')


Dataset: HIV(41127):
Number of graphs: 41127
Number of features: 9
Number of classes: 2

Data(x=[19, 9], edge_index=[2, 40], edge_attr=[40, 3], smiles='CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)=[O+]2', y=[1, 1])
Number of nodes: 19
Number of edges: 40
Average node degree: 2.11
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [7]:
# Draw an example molecule from the data set. Note that index 5 is used here,
# i.e. the sixth entry, not the first one; its SMILES string is passed to the helper.

draw_molecule(dataset[5].smiles)


In [8]:
from torch_geometric.loader import DataLoader

# One mini-batch size shared by all three loaders.
BATCH_SIZE = 128

# Training and validation batches are drawn in shuffled order.
train_loader = DataLoader(train_data_down, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_data, shuffle=True, batch_size=BATCH_SIZE)

# Production data keeps its stored order.
production_loader = DataLoader(production_data, shuffle=False, batch_size=BATCH_SIZE)

In [9]:
from IPython.display import Javascript
from sklearn.preprocessing import OneHotEncoder

# Cap the height of this cell's output area. `google.colab` is only defined in the
# Colab front end, so this JavaScript snippet has no useful effect elsewhere.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Binary task: matches the "Number of classes: 2" reported for the dataset above.
NUM_CLASSES = 2

model = GAT_mol(hidden_channels=64, num_classes=NUM_CLASSES)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# Fit the label encoder on the complete, fixed label set {0., 1.} (same float32
# column layout as `data.y`). Fitting on the first shuffled training batch made the
# encoder depend on random batch composition: a single-class batch would yield a
# one-column encoding. It also overwrote the notebook-level `data` variable.
ohe = OneHotEncoder().fit(
    torch.arange(NUM_CLASSES, dtype=torch.float32).reshape(-1, 1)
)

# range(1, 60) runs epochs 1..59. Each epoch does one pass over the training loader
# (train=True) followed by one pass over the validation loader (train=False).
for epoch in range(1, 60):
    loss, acc, auc = train_graph_classifier(model, ohe, train_loader, criterion, optimizer, train=True)
    val_loss, val_acc, val_auc = train_graph_classifier(model, ohe, val_loader, criterion, optimizer, train=False)

    print(f"Epoch: {epoch:03d}")
    print(f"Train loss: {loss:.4f}   |   Train acc: {acc:.4f}   |.  Train auc: {auc:.4f}")
    print(f"Val loss: {val_loss:.4f}   |   Val acc: {val_acc:.4f}   |.  Val auc: {val_auc:.4f}")
    

<IPython.core.display.Javascript object>

Epoch: 001
Train loss: 0.6502   |   Train acc: 0.6574   |.  Train auc: 0.4924
Val loss: 0.4066   |   Val acc: 0.9650   |.  Val auc: 0.4958
Epoch: 002
Train loss: 0.6369   |   Train acc: 0.6667   |.  Train auc: 0.5181
Val loss: 0.4113   |   Val acc: 0.9650   |.  Val auc: 0.5523
Epoch: 003
Train loss: 0.6352   |   Train acc: 0.6667   |.  Train auc: 0.5400
Val loss: 0.4676   |   Val acc: 0.9650   |.  Val auc: 0.5705
Epoch: 004
Train loss: 0.6319   |   Train acc: 0.6667   |.  Train auc: 0.5629
Val loss: 0.3670   |   Val acc: 0.9650   |.  Val auc: 0.5704
Epoch: 005
Train loss: 0.6303   |   Train acc: 0.6670   |.  Train auc: 0.5869
Val loss: 0.4395   |   Val acc: 0.9356   |.  Val auc: 0.5694
Epoch: 006
Train loss: 0.6072   |   Train acc: 0.6953   |.  Train auc: 0.6249
Val loss: 0.4156   |   Val acc: 0.9086   |.  Val auc: 0.6178
Epoch: 007
Train loss: 0.5999   |   Train acc: 0.7091   |.  Train auc: 0.6219
Val loss: 0.4320   |   Val acc: 0.8915   |.  Val auc: 0.6331
Epoch: 008
Train loss: 0.59

In [10]:
def get_predictions(model, loader):
    """Score every graph yielded by ``loader`` and rank the names by score.

    The model is switched to evaluation mode (and left in it) and run without
    gradient tracking. Column 1 of the model output is used as each graph's
    positive-class score.
    NOTE(review): this assumes the model already returns per-class probabilities
    rather than raw logits — confirm against the model definition in ``gnn``.

    Args:
        model: Callable invoked as ``model(x, edge_index, batch)`` that returns a
            2-D tensor with at least two columns, one row per graph.
        loader: Iterable of batches exposing ``name`` (one entry per graph),
            ``x``, ``edge_index`` and ``batch``.

    Returns:
        A tuple of ``(name, score)`` pairs ordered from highest to lowest score;
        an empty tuple when ``loader`` yields nothing.
    """
    model.eval()
    all_names = []
    all_pos_prob = []

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for data in loader:
            # Cast into a local so the caller's batch object is left untouched.
            node_features = data.x.to(torch.float32)
            out = model(node_features, data.edge_index, data.batch)
            all_names.extend(data.name)
            # .cpu() keeps this working when the model and data live on a GPU.
            all_pos_prob.extend(out[:, 1].detach().cpu().numpy())

    if not all_names:
        return ()

    all_pos_prob = np.array(all_pos_prob)
    all_names = np.array(all_names)

    # Ascending argsort reversed gives highest score first.
    sort_ind = np.argsort(all_pos_prob)[::-1]
    return tuple(zip(all_names[sort_ind], all_pos_prob[sort_ind]))
    
    

In [11]:
# Score the production set and list every drug, highest score first.
drug_probs = get_predictions(model, production_loader)
for drug_name, pos_prob in drug_probs:
    print(f" Drug: {drug_name:<25} \t|\t\t Prob: {pos_prob:.4f}")


 Drug: Aluminium chloride        	|		 Prob: 1.0000
 Drug: IRON DEXTRAN              	|		 Prob: 1.0000
 Drug: Nitric Oxide              	|		 Prob: 1.0000
 Drug: Magnesium Sulfate         	|		 Prob: 1.0000
 Drug: CHROMIUM                  	|		 Prob: 1.0000
 Drug: Oxygen                    	|		 Prob: 1.0000
 Drug: TETRACHLORODECAOXIDE      	|		 Prob: 1.0000
 Drug: Iron                      	|		 Prob: 1.0000
 Drug: Calcium Chloride          	|		 Prob: 1.0000
 Drug: Kappadione                	|		 Prob: 1.0000
 Drug: Lithium                   	|		 Prob: 1.0000
 Drug: Potassium perchlorate     	|		 Prob: 1.0000
 Drug: sucralfate                	|		 Prob: 1.0000
 Drug: cisplatin                 	|		 Prob: 1.0000
 Drug: potassium chloride        	|		 Prob: 1.0000
 Drug: chlorothiazide            	|		 Prob: 1.0000
 Drug: cupric chloride           	|		 Prob: 1.0000
 Drug: MAGNESIUM                 	|		 Prob: 1.0000
 Drug: zinc                      	|		 Prob: 1.0000
 Drug: Arsenic trioxide        

In [148]:
# Peek at the top-ranked (drug name, score) pair.
# NOTE(review): this cell's execution count (148) is out of sequence, and its saved
# output does not match the first row of the ranking printed by the preceding cell,
# so it comes from a different kernel state. Re-run after Restart & Run All before
# relying on it.
print(drug_probs[0])

('cupric chloride', 0.57972246)
