In [1]:
# Reload edited project modules (e.g. the HINT package) automatically before
# each cell executes, so source changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
import warnings

import torch

from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst
from HINT.molecule_encode import MPNN, ADMET
from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict
from HINT.protocol_embedding import ProtocolEmbedding
from HINT.simple import TrialModel, Trainer

In [3]:
# Silence library warnings. Note that this hides *all* warnings, deprecations included.
warnings.filterwarnings("ignore")

# Seed every RNG the notebook may touch so that Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# Fall back to CPU when no GPU is available. A hard-coded "cuda:0" would only
# fail later, at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Experiment identity and the location of the per-phase CSV splits.
phase = 'phase_II'
model_name = 'improved_model'
datafolder = "data"

# Each split lives at <datafolder>/<phase>_<split>.csv.
train_file, valid_file, test_file = (
    os.path.join(datafolder, f"{phase}_{split_name}.csv")
    for split_name in ("train", "valid", "test")
)

In [5]:
# Pretrain the ADMET auxiliary model, or reload it from a cached checkpoint.
# Its MPNN molecule encoder is reused by the trial model further down.
mpnn_model = MPNN(mpnn_hidden_size = 50, mpnn_depth=3, device = device)
admet_model_path = "checkpoints/admet_model.ckpt"
if not os.path.exists(admet_model_path):
    admet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)
    # Presumably each entry is a (train_loader, test_loader) pair — TODO confirm.
    admet_trainloader_lst = [loaders[0] for loaders in admet_dataloader_lst]
    admet_testloader_lst = [loaders[1] for loaders in admet_dataloader_lst]
    admet_model = ADMET(molecule_encoder = mpnn_model,
                        highway_num=2,
                        device = device,
                        epoch=3,
                        lr=5e-4,
                        weight_decay=0,
                        save_name = 'admet_')
    admet_model.train(admet_trainloader_lst, admet_testloader_lst)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(admet_model_path) or ".", exist_ok=True)
    torch.save(admet_model, admet_model_path)
else:
    # NOTE(review): the checkpoint is a pickled model object. Unpickling it needs
    # the class's original module path to still be importable. A
    # ModuleNotFoundError here (e.g. for 'HINT.module') most likely means the
    # cached checkpoint is stale; delete it to retrain. Pickle can run arbitrary
    # code, so only load checkpoints you trust.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    admet_model = torch.load(admet_model_path, map_location=device)
    admet_model = admet_model.to(device)
    admet_model.set_device(device)

ModuleNotFoundError: No module named 'HINT.module'

In [None]:
# Build one loader per split, in train/valid/test order. Only the training
# split is shuffled; the evaluation splits keep file order.
train_loader, valid_loader, test_loader = (
    csv_three_feature_2_dataloader(csv_path, shuffle=is_train_split, batch_size=32)
    for csv_path, is_train_split in (
        (train_file, True),
        (valid_file, False),
        (test_file, False),
    )
)

In [None]:
# Disease encoder: GRAM over ICD codes, using the ancestor mapping built here.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(
    embedding_dim=50,
    icdcode2ancestor=icdcode2ancestor_dict,
    device=device,
)

# Protocol encoder. Its output width (50) matches the other encoders' sizes.
protocol_model = ProtocolEmbedding(
    output_dim=50,
    highway_num=3,
    device=device,
)

In [None]:
# Destination of the trained trial model (written by the final cell).
hint_model_path = f"checkpoints/{model_name}.ckpt"

# The three modality encoders: the ADMET-pretrained molecule encoder, the
# GRAM disease encoder and the protocol encoder.
trial_encoders = dict(
    molecule_encoder=admet_model.molecule_encoder,
    disease_encoder=gram_model,
    protocol_encoder=protocol_model,
)
model = TrialModel(
    **trial_encoders,
    embedding_size=50,
    num_ffn_layers=2,
    num_pred_layers=3,
    name=model_name,
)

In [None]:
num_epochs = 5
trainer = Trainer(model, lr=0.001, weight_decay=0, device=device)

# All three loaders are handed to `train`. How the valid/test loaders are used
# during training (e.g. per-epoch monitoring) is up to Trainer.train — check there.
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)

In [None]:
# Evaluate on the held-out test split (plain and bootstrapped), then save the
# trained trial model.
test_results = trainer.test(test_loader)
bootstrap_results = trainer.bootstrap_test(test_loader)

# torch.save does not create missing parent directories, e.g. on a fresh clone
# with no `checkpoints/` folder. Without this, the save would crash only after
# the evaluation above has already run.
os.makedirs(os.path.dirname(hint_model_path) or ".", exist_ok=True)
torch.save(model, hint_model_path)