In [2]:
# Re-import edited modules (e.g. the local HINT package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import torch
import warnings;

from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst
from HINT.molecule_encode import MPNN, ADMET 
from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict
from HINT.protocol.model import ProtocolEmbedding
from HINT.simple import TrialModel, Trainer

In [4]:
# Global notebook settings: silence library warnings, fix the torch RNG seed,
# choose the compute device, and make sure the figure output directory exists.
warnings.filterwarnings("ignore")  # NOTE(review): this hides *all* warnings, including deprecations
torch.manual_seed(0)
# Fall back to CPU so the notebook still runs without a GPU; a hard-coded "cuda:0"
# fails as soon as a model or tensor is moved to it on a CPU-only machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs("figure", exist_ok=True)

In [5]:
# Experiment configuration: trial phase, run name, and where the split CSVs live.
phase = 'phase_III'
model_name = 'improved_model'
datafolder = "data"
# Each split is stored as "<phase>_<split>.csv" inside the data folder.
train_file, valid_file, test_file = (
    os.path.join(datafolder, f"{phase}_{split}.csv")
    for split in ("train", "valid", "test")
)

In [6]:
# Molecule encoder; handed to ADMET below when no cached checkpoint exists.
mpnn_model = MPNN(mpnn_hidden_size=50, mpnn_depth=3, device=device)
admet_model_path = "save_model/admet_model.ckpt"

if not os.path.exists(admet_model_path):
    # No cached checkpoint: train the ADMET model on the ADMET dataloaders, then cache it.
    admet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)
    # Each entry is presumably a (train_loader, test_loader) pair per ADMET task -- TODO confirm.
    admet_trainloader_lst = [loaders[0] for loaders in admet_dataloader_lst]
    admet_testloader_lst = [loaders[1] for loaders in admet_dataloader_lst]
    admet_model = ADMET(molecule_encoder=mpnn_model,
                        highway_num=2,
                        device=device,
                        epoch=3,
                        lr=5e-4,
                        weight_decay=0,
                        save_name='admet_')
    admet_model.train(admet_trainloader_lst, admet_testloader_lst)
    # torch.save does not create parent directories; without this the freshly
    # trained model is lost to a FileNotFoundError on a clean checkout.
    os.makedirs(os.path.dirname(admet_model_path) or ".", exist_ok=True)
    torch.save(admet_model, admet_model_path)
else:
    # The checkpoint is a whole pickled module, not a state_dict.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # weights_only=False is required on torch >= 2.6 (where it defaults to True)
    # to unpickle a full module.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted checkpoints.
    try:
        admet_model = torch.load(admet_model_path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only parameter.
        admet_model = torch.load(admet_model_path, map_location=device)
    admet_model = admet_model.to(device)
    admet_model.set_device(device)

In [7]:
# Build one loader per split. Only the training split is shuffled, so the
# validation and test splits are evaluated in file order.
train_loader, valid_loader, test_loader = (
    csv_three_feature_2_dataloader(csv_path, shuffle=do_shuffle, batch_size=32)
    for csv_path, do_shuffle in (
        (train_file, True),
        (valid_file, False),
        (test_file, False),
    )
)

In [8]:
# Disease encoder (GRAM over the ICD-code ancestor map) and protocol encoder.
# Keep this construction order: both modules are created after the torch seed is set.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(
    embedding_dim=50,
    icdcode2ancestor=icdcode2ancestor_dict,
    device=device,
)
protocol_model = ProtocolEmbedding(
    output_dim=50,
    highway_num=3,
    device=device,
)

In [12]:
# Where the trained HINT model checkpoint is written.
hint_model_path = f"save_model/{model_name}.ckpt"

# Assemble the trial model from the molecule, disease and protocol encoders built above.
# NOTE(review): the molecule encoder comes from admet_model, not mpnn_model, so when
# admet_model was loaded from disk this is the checkpointed encoder.
model = TrialModel(
    molecule_encoder=admet_model.molecule_encoder,
    disease_encoder=gram_model,
    protocol_encoder=protocol_model,
    embedding_size=50,
    num_ffn_layers=2,
    num_pred_layers=3,
    name=model_name,
)

In [13]:
num_epochs = 3
# Trainer owns the optimisation settings; train() also receives the valid/test loaders.
trainer = Trainer(model, lr=1e-3, weight_decay=0, device=device)
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [05:11<00:00, 103.98s/it]


In [15]:
# Evaluate once on the test split, then report bootstrap mean/std of the same metrics.
test_results = trainer.test(test_loader)
bootstrap_results = trainer.bootstrap_test(test_loader)

# torch.save does not create parent directories; make sure the target folder exists
# so the trained model is not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(hint_model_path) or ".", exist_ok=True)
# Pickles the whole module (not just a state_dict); reloading needs the HINT classes importable.
torch.save(model, hint_model_path)

--------------------------------------------------
Accuracy: 0.684, TP: 614, FP:117, TN:170, FN:245
F1-Score: 0.772
ROC-AUC: 0.654
PR-AUC: 0.884
--------------------------------------------------
PR_AUC - mean: 0.8567, std: 0.0125
F1 - mean: 0.8081, std: 0.0082
ROC_AUC - mean: 0.6946, std: 0.0140
ACCURACY - mean: 0.7171, std: 0.0113


In [19]:
from collections import Counter


def binarize_predictions(scores, threshold=0.5):
    """Map predicted probabilities to hard 0/1 labels.

    Args:
        scores: iterable of predicted probabilities.
        threshold: decision cut-off; scores strictly below it become 0,
            everything else (including NaN) becomes 1.

    Returns:
        list of int labels, one per score, in input order.
    """
    return [0 if score < threshold else 1 for score in scores]


# NOTE(review): `predict_all` is not created by any cell in this notebook, so this
# cell raises NameError under Restart & Run All. The value it picked up from the
# kernel is also stale: the printed counts (975 predicted positives) disagree with
# the confusion matrix printed by trainer.test (TP + FP = 614 + 117 = 731).
# TODO: take the predictions from the same evaluation call that produced the metrics.
predictions = binarize_predictions(predict_all)
prediction_counts = Counter(predictions)

print(prediction_counts)


Counter({1: 975, 0: 171})
