In [34]:
# Automatically re-import edited modules (e.g. the local HINT package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [213]:
import os
import warnings
from collections import Counter

import torch

from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst
from HINT.molecule_encode import MPNN, ADMET
from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict
from HINT.protocol.model import ProtocolEmbedding
from HINT.simple import TrialModel, Trainer

In [214]:
# Silence all warnings to keep cell output readable.
# NOTE(review): this also hides deprecation/numerical warnings; narrow the filter when debugging.
warnings.filterwarnings("ignore")
# Seed torch's RNG on all devices so weight initialisation (and presumably
# DataLoader shuffling) is repeatable across runs.
torch.manual_seed(0)
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Output directory for figures; exist_ok makes this cell idempotent on re-run.
os.makedirs("figure", exist_ok=True)

In [215]:
# Experiment configuration: which trial phase to model and where its CSV splits live.
phase = 'phase_III'
model_name = 'improved_model'
datafolder = "data"

# Each split is stored as <datafolder>/<phase><suffix>.
train_file, valid_file, test_file = (
    os.path.join(datafolder, phase + suffix)
    for suffix in ('_train.csv', '_valid.csv', '_test.csv')
)

In [216]:
mpnn_model = MPNN(mpnn_hidden_size = 50, mpnn_depth=3, device = device)
admet_model_path = "save_model/admet_model.ckpt"
if not os.path.exists(admet_model_path):
	# No cached checkpoint: pre-train the MPNN molecule encoder on the ADMET
	# tasks, then cache the whole ADMET model object for later runs.
	# torch.save does not create parent directories; create save_model/ up front
	# so a long training run is not lost to a FileNotFoundError at the end.
	os.makedirs(os.path.dirname(admet_model_path), exist_ok=True)
	admet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)
	# Presumably each element is (train_loader, test_loader) for one ADMET task — TODO confirm.
	admet_trainloader_lst = [i[0] for i in admet_dataloader_lst]
	admet_testloader_lst = [i[1] for i in admet_dataloader_lst]
	admet_model = ADMET(molecule_encoder = mpnn_model,
						highway_num=2,
						device = device,
						epoch=3,
						lr=5e-4,
						weight_decay=0,
						save_name = 'admet_')
	admet_model.train(admet_trainloader_lst, admet_testloader_lst)
	torch.save(admet_model, admet_model_path)
else:
	# map_location remaps tensors saved on another device (e.g. cuda:0) onto the
	# current `device`, so a GPU-trained checkpoint also loads on a CPU-only machine.
	# SECURITY: torch.load unpickles an arbitrary Python object; only load trusted files.
	# NOTE(review): torch>=2.6 defaults to weights_only=True, which cannot restore a
	# pickled module; pass weights_only=False there (for this trusted file only).
	admet_model = torch.load(admet_model_path, map_location=device)
	admet_model = admet_model.to(device)
	admet_model.set_device(device)

In [217]:
# Build one loader per split (in train, valid, test order). Only the training
# split is shuffled, so evaluation order stays fixed.
train_loader, valid_loader, test_loader = (
    csv_three_feature_2_dataloader(csv_path, shuffle=do_shuffle, batch_size=32)
    for csv_path, do_shuffle in (
        (train_file, True),
        (valid_file, False),
        (test_file, False),
    )
)

In [218]:
# Disease encoder: GRAM receives the ICD-code -> ancestors mapping.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(
    embedding_dim=50,
    icdcode2ancestor=icdcode2ancestor_dict,
    device=device,
)
# Protocol (eligibility-criteria) encoder with a 50-d output.
# Construction order is kept: model init presumably draws from the seeded torch RNG.
protocol_model = ProtocolEmbedding(
    output_dim=50,
    highway_num=3,
    device=device,
)

In [219]:
hint_model_path = f"save_model/{model_name}.ckpt"

# Assemble the trial-outcome model from the three encoders built above.
# NOTE(review): embedding_size=50 presumably has to agree with the 50-d encoders
# (MPNN hidden size, GRAM embedding_dim, protocol output_dim) — TODO confirm.
model = TrialModel(
    molecule_encoder=admet_model.molecule_encoder,
    disease_encoder=gram_model,
    protocol_encoder=protocol_model,
    embedding_size=50,
    num_ffn_layers=3,
    num_pred_layers=5,
    name=model_name,
)

In [221]:
# Wrap the model with its optimisation settings (weight decay disabled).
trainer = Trainer(
    model,
    lr=1e-3,
    weight_decay=0,
    device=device,
)

In [224]:
num_epochs = 3

# Train with the validation and test loaders also handed to the trainer.
metrics = trainer.train(
    num_epochs,
    train_loader,
    valid_loader,
    test_loader,
)

# Final held-out evaluation, then persist the full model object (pickled via torch.save).
test_results = trainer.test(test_loader)
torch.save(model, hint_model_path)

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [06:17<00:00, 75.41s/it]


--------------------------------------------------
Accuracy: 0.687, TP: 631, FP:131, TN:156, FN:228
F1-Score: 0.779
ROC-AUC: 0.639
PR-AUC: 0.881
--------------------------------------------------


In [223]:
# Bootstrap evaluation on the test split: prints mean/std for PR-AUC, F1, ROC-AUC
# and accuracy, and returns the per-replicate scores as a dict of lists keyed by metric.
# NOTE(review): the stored output of this cell (In [223]) was produced before the
# training cell was last run (In [224]), so it reflects an earlier model; re-run
# after training to get numbers consistent with the test metrics above.
results = trainer.bootstrap_test(test_loader)

PR_AUC - mean: 0.8743, std: 0.0089
F1 - mean: 0.7833, std: 0.0095
ROC_AUC - mean: 0.6487, std: 0.0143
ACCURACY - mean: 0.6930, std: 0.0105


{'pr_auc': [0.8719873063893295,
  0.8844379751263312,
  0.869645628095402,
  0.8610018473091201,
  0.8761112058547952,
  0.8558318311186806,
  0.8791107528957044,
  0.8790084841280343,
  0.8883907318166265,
  0.8819457093437804,
  0.8568579478418812,
  0.8716378382576917,
  0.8684590883135863,
  0.8781764964731984,
  0.882865444358874,
  0.8782530347355606,
  0.8760434580265184,
  0.8698476382256061,
  0.8699937843877229,
  0.8858841839137568],
 'f1': [0.7807288449660283,
  0.7906137184115523,
  0.7912885662431942,
  0.7781269641734758,
  0.7928358208955224,
  0.7665816326530612,
  0.7854103343465046,
  0.7864615384615384,
  0.7844092570036542,
  0.7948717948717949,
  0.7640866873065015,
  0.7839699436443333,
  0.7789605510331872,
  0.7761557177615571,
  0.7951807228915662,
  0.7954971857410881,
  0.7695190505933791,
  0.7845303867403316,
  0.7729591836734694,
  0.7929782082324456],
 'roc_auc': [0.6496544444223147,
  0.649074074074074,
  0.6387132317715819,
  0.6420613509192646,
  0.63

In [19]:
# Distribution of hard labels predicted on the test set.
# NOTE(review): `predict_all` is not defined anywhere in this notebook. It leaked
# from an earlier kernel session (this cell is In [19], long before training at
# In [224]), so the cell fails on Restart & Run All. Its stored output (975
# positives) also contradicts the confusion matrix above (TP + FP = 762), so it
# came from a different model. Assign `predict_all` from the trained model's
# test-set scores before running this cell.
PRED_THRESHOLD = 0.5


def binarize_predictions(scores, threshold=PRED_THRESHOLD):
    """Convert probability scores to hard 0/1 labels.

    Args:
        scores: iterable of numeric scores (e.g. predicted probabilities).
        threshold: decision boundary; a score exactly equal to it maps to 1.

    Returns:
        list of ints (0 or 1), in the same order as ``scores``.
    """
    return [0 if score < threshold else 1 for score in scores]


if "predict_all" not in globals():
    raise NameError(
        "`predict_all` is undefined: compute the test-set prediction scores "
        "before running this cell"
    )
prediction_counts = Counter(binarize_predictions(predict_all))
print(prediction_counts)


Counter({1: 975, 0: 171})
