In [1]:
# Step one directory up; the recorded output below shows the resulting working directory.
# Presumably needed so the relative "./data" / "./checkpoints" paths and the
# project-local imports used further down resolve — TODO confirm.
# NOTE(review): the target is relative, so re-running this cell without restarting
# the kernel moves up one more level each time.
%cd ../

/home/varadi_kristof/llms-for-trials/src/hint


In [2]:
# Load IPython's autoreload extension. Mode 2 reloads all imported modules before
# each cell runs, so edits to the project's .py files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [5]:
import os
import torch
import ast
import warnings;
import pandas as pd

from toxicity.model import MultitaskToxicityModel, load_ckp
from trial.model import TrialModel, Trainer as TrialTrainer
from trial.protocol import ProtocolEmbedding
from trial.disease_encoder import GRAM, build_icdcode2ancestor_dict

In [6]:
# Global run setup: silence every Python warning, pin the torch RNG seed,
# and select the first CUDA device when one is available (CPU otherwise).
warnings.filterwarnings("ignore")
torch.manual_seed(0)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [7]:
# Locations, relative to the current working directory, of the data folder,
# the toxicity checkpoints and the cached SMILES embeddings.
data_dir = "./data"
model_dir = "./checkpoints/toxicity"
smiles_embedding_dir = "toxicity/smiles_embedding"

# One serialized mapping per data split (used below as key -> embedding lookups).
smiles_embed_train = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_train.pt")
smiles_embed_valid = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_valid.pt")
smiles_embed_test = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_test.pt")

# Merge the three splits into a single lookup table. When a key occurs in more
# than one split, the later split wins (train < valid < test).
smiles_embeddings = dict(smiles_embed_train)
smiles_embeddings.update(smiles_embed_valid)
smiles_embeddings.update(smiles_embed_test)

In [8]:
# Task labels handed to the multitask toxicity model: twelve Tox21 labels
# followed by the single ClinTox label. The order is kept exactly as written.
clintox_task = ['CT_TOX']
tox21_tasks = [
    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',
    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',
]
all_tasks = [*tox21_tasks, *clintox_task]

# The model's input width is read from the first cached training embedding.
first_smiles = next(iter(smiles_embed_train.values()))
input_shape = first_smiles.shape[0]

# Build the network on the selected device, then restore the saved checkpoint.
# load_ckp is unpacked into four values; only the first one is kept.
model = MultitaskToxicityModel(input_shape, all_tasks).to(device)
toxicity_model, _, _, _ = load_ckp(
    f"{model_dir}/best_model_by_valid.pt",
    model,
    None,
)

In [9]:
# Which trial phase to model and the name under which the model is stored.
phase = 'phase_III'
model_name = 'mtdnn_model'

# Paths of the train / valid / test CSV files for the chosen phase,
# all located in the "trial" sub-folder of data_dir.
train_file, valid_file, test_file = (
    os.path.join(f"{data_dir}/trial", phase + suffix)
    for suffix in ('_train.csv', '_valid.csv', '_test.csv')
)

In [10]:
# Read the three split CSVs into DataFrames, in train / valid / test order.
train_df, valid_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_file, valid_file, test_file)
)

In [11]:
def explode_list(row):
    """Parse the textual form of a Python literal and return the resulting object.

    Args:
        row: string holding a Python literal, e.g. ``"['a', 'b']"``.

    Returns:
        Whatever ``ast.literal_eval`` produces for ``row`` (a list for the
        example above).
    """
    return ast.literal_eval(row)

def extract_smiles_embed(smiles_row: str):
    """Combine the cached embeddings of every SMILES listed in ``smiles_row``.

    Args:
        smiles_row: string form of a list of SMILES, parsed with ``explode_list``.

    Returns:
        The element-wise sum of the looked-up embeddings. A SMILES missing from
        ``smiles_embeddings`` contributes a zero vector of length ``input_shape``;
        an empty list yields that zero vector itself.
    """
    vectors = []
    for smiles in explode_list(smiles_row):
        # Unknown molecules fall back to an all-zero embedding.
        vectors.append(smiles_embeddings.get(smiles, torch.zeros(input_shape)))

    if not vectors:
        return torch.zeros(input_shape)

    # TODO: summing is a very primitive pooling; some information loss is likely.
    return torch.stack(vectors).sum(dim=0)

In [13]:
# Disease encoder: GRAM built on top of the ICD-code -> ancestors mapping.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(
    embedding_dim=50,
    icdcode2ancestor=icdcode2ancestor_dict,
).to(device)

# Protocol encoder backed by the Bio_ClinicalBERT checkpoint from the Hugging Face hub.
protocol_model = ProtocolEmbedding(
    output_dim=50,
    hidden_dim=50,
    num_layers=3,
    hf_model="emilyalsentzer/Bio_ClinicalBERT",
).to(device)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def trial_collate_fn(inputs):
    """Replace the raw "smiless" and "criteria" entries of ``inputs`` with model-ready features.

    The previous body called ``extract_smiless_embed``, a name that is not
    defined anywhere in this notebook, so the function raised ``NameError`` on
    first use. It now relies on the existing per-row helper
    ``extract_smiles_embed``.

    Args:
        inputs: mapping with "smiless" and "criteria" entries. "smiless" is
            either one stringified list of SMILES or a sequence of such strings
            — assumption, TODO confirm against whatever feeds this function.

    Returns:
        The same ``inputs`` object, mutated in place: "smiless" holds the summed
        embedding (stacked along a new first dimension when a sequence of rows
        was given) and "criteria" holds the output of ``protocol_model.tokenizer``.
    """
    smiless = inputs["smiless"]
    if isinstance(smiless, str):
        # Single row: one stringified SMILES list -> one embedding vector.
        inputs["smiless"] = extract_smiles_embed(smiless)
    else:
        # Several rows: embed each one, then stack them into a single tensor.
        inputs["smiless"] = torch.stack([extract_smiles_embed(row) for row in smiless])
    inputs["criteria"] = protocol_model.tokenizer(inputs["criteria"])

    return inputs

In [14]:
# Destination file for the trained trial model.
hint_model_path = f"./checkpoints/{model_name}.ckpt"

# Assemble the trial model from the three encoders built earlier.
# Note: this rebinds `model`, which previously referred to the toxicity network;
# that network stays reachable through `toxicity_model`.
model = TrialModel(
    toxicity_encoder=toxicity_model,
    disease_encoder=gram_model,
    protocol_encoder=protocol_model,
    embedding_size=50,
    num_ffn_layers=2,
    num_pred_layers=3,
    name=model_name,
)

In [15]:
# Wrap the trial model in its trainer and run the training loop for num_epochs epochs.
trainer = TrialTrainer(model, lr=1e-3, weight_decay=0, device=device)
num_epochs = 5
# NOTE(review): train_loader, valid_loader and test_loader are not defined in any
# visible cell of this notebook, so this call fails with NameError (see the recorded
# output below). They have to be created before this cell can run.
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)

NameError: name 'train_loader' is not defined

In [None]:
# Evaluate on the test loader: once directly and once via the trainer's bootstrap routine.
test_results = trainer.test(test_loader)
bootstrap_results = trainer.bootstrap_test(test_loader)

# Persist the trial model; the whole model object is serialized, not just a state_dict.
torch.save(obj=model, f=hint_model_path)