In [1]:
# Step up one directory so that the relative paths used by later cells
# (e.g. "./data", "./checkpoints") resolve from there. The recorded output
# below shows the directory this landed in on the author's machine.
%cd ../

/home/varadi_kristof/llms-for-trials/src/hint


In [2]:
# Enable IPython's autoreload extension in mode 2: modules are reloaded before
# each cell executes, so edits to imported modules take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [93]:
import os
import torch
import ast
import warnings;
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

from toxicity.model import MultitaskToxicityModel, load_ckp
from trial.model import TrialModel, Trainer as TrialTrainer
from trial.protocol import ProtocolEmbedding
from trial.disease_encoder import GRAM, build_icdcode2ancestor_dict

In [94]:
# Global notebook setup: fix the torch RNG seed, silence all warnings, and
# choose the compute device.
torch.manual_seed(0)
warnings.filterwarnings("ignore")

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [95]:
# Disease encoder: GRAM built from the ICD-code -> ancestors mapping.
# NOTE(review): unlike protocol_model below, gram_model is not explicitly moved
# to `device` in this cell -- confirm that the model/trainer handles placement.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(embedding_dim=50, icdcode2ancestor=icdcode2ancestor_dict)

# Protocol (eligibility criteria) encoder on top of Bio_ClinicalBERT; it is
# given the device as an argument and is also moved to it.
protocol_model = ProtocolEmbedding(
    output_dim=50,
    hidden_dim=50,
    num_layers=3,
    hf_model="emilyalsentzer/Bio_ClinicalBERT",
    device=device,
)
protocol_model = protocol_model.to(device)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [96]:
# Locations of the input data, the toxicity checkpoint, and the precomputed
# SMILES embeddings (the latter is relative to data_dir).
data_dir = "./data"
model_dir = "./checkpoints/toxicity"
smiles_embedding_dir = "toxicity/smiles_embedding"

# Load the per-split SMILES -> embedding dictionaries.
# map_location="cpu" pins every loaded tensor to the CPU regardless of the
# device it was saved from. extract_smiles_embed stacks these tensors together
# with CPU zero vectors (the fallback for unknown SMILES), and torch.stack
# requires all inputs to be on one device. It also lets the files load on a
# machine without a GPU.
smiles_embed_train, smiles_embed_valid, smiles_embed_test = (
    torch.load(
        os.path.join(data_dir, smiles_embedding_dir, f"smiles_embed_{split}.pt"),
        map_location="cpu",
    )
    for split in ("train", "valid", "test")
)

# Single lookup table over all splits; on duplicate keys the later split
# (train < valid < test) wins.
smiles_embeddings = {**smiles_embed_train, **smiles_embed_valid, **smiles_embed_test}

In [97]:
# Task names passed to the multitask toxicity model: the twelve Tox21 names
# followed by the single ClinTox name.
tox21_tasks = [
    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',
    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',
]
clintox_task = ['CT_TOX']
all_tasks = [*tox21_tasks, *clintox_task]

# The model input size is read off the first dimension of an arbitrary
# training embedding.
first_smiles = next(iter(smiles_embed_train.values()))
input_shape = first_smiles.shape[0]

# Build the toxicity network and restore it from the best-by-validation
# checkpoint; only the first of the four values returned by load_ckp is kept
# (presumably the model with restored weights -- see toxicity.model).
# Note: the name `model` is rebound to the TrialModel in a later cell.
model = MultitaskToxicityModel(input_shape, all_tasks).to(device)
toxicity_model, _, _, _ = load_ckp(f"{model_dir}/best_model_by_valid.pt", model, None)

In [98]:
# Trial phase whose CSV splits are used, and the name given to the trial model.
phase = 'phase_III'
model_name = 'mtdnn_model'

# Each split lives at "<data_dir>/trial/<phase>_<split>.csv".
train_file, valid_file, test_file = (
    os.path.join(f"{data_dir}/trial", f"{phase}_{split}.csv")
    for split in ('train', 'valid', 'test')
)

In [99]:
def explode_list(row):
    """Parse a string holding a Python literal (here: a list of SMILES) and return the parsed object.

    Parsing is done with ``ast.literal_eval``, so only literals are accepted.
    """
    return ast.literal_eval(row)

def extract_smiles_embed(smiles_row: str):
    """Return one embedding vector for a trial's stringified list of SMILES.

    Each SMILES is looked up in the module-level ``smiles_embeddings`` table;
    unknown SMILES contribute a zero vector of length ``input_shape``. The
    per-molecule vectors are summed. An empty list yields a zero vector.
    """
    per_molecule = []
    for smiles in explode_list(smiles_row):
        per_molecule.append(smiles_embeddings.get(smiles, torch.zeros(input_shape)))

    if not per_molecule:
        return torch.zeros(input_shape)

    # TODO: this is very primitive, some information loss is likely
    return torch.stack(per_molecule).sum(dim=0)

In [100]:
def extract_icd(text):
    """Parse a stringified list of stringified ICD-code lists into nested lists of codes.

    Example: '["[\\'A00\\', \\'B01\\']", "[\\'C02\\']"]' -> [['A00', 'B01'], ['C02']].
    The parsing is purely positional: the two outer characters on each side are
    dropped, groups are split on '", "', each group loses its surrounding
    brackets, and each comma-separated code loses its surrounding quotes.
    """
    groups = text[2:-2].split('", "')
    return [
        [code.strip()[1:-1] for code in group[1:-1].split(',')]
        for group in groups
    ]

In [101]:
class TrialDataset(Dataset):
    """Dataset view over a trial dataframe, yielding one dict per row.

    Each item maps the keys ``nctids``, ``labels``, ``smiless``, ``criteria``
    and ``icdcodes`` to the row's ``nctid``, ``label``, ``smiless``,
    ``criteria`` and ``icdcodes`` column values, respectively.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Output key -> source column (the first two keys are pluralised).
        source_columns = {
            "nctids": "nctid",
            "labels": "label",
            "smiless": "smiless",
            "criteria": "criteria",
            "icdcodes": "icdcodes",
        }
        return {key: row[column] for key, column in source_columns.items()}

def trial_collate_fn(batch):
    """Collate a list of TrialDataset items into one batch tuple.

    Returns ``(nctids, labels, smiless, icdcodes, criteria)`` where
    ``nctids`` is the list of ids as-is, ``labels`` is a tensor built from the
    label list, ``smiless`` is the stacked per-trial SMILES embeddings,
    ``icdcodes`` is the parsed nested ICD-code lists, and ``criteria`` is a
    list with one tokenizer output per trial (each text is tokenized
    separately with ``padding=True``).
    """
    # Turn the list of per-sample dicts into a dict of per-key lists.
    columns = {key: [sample[key] for sample in batch] for key in batch[0]}

    smiles_embeds = torch.stack([extract_smiles_embed(raw) for raw in columns["smiless"]])
    icd_codes = [extract_icd(raw) for raw in columns["icdcodes"]]
    criteria_tokens = [protocol_model.tokenizer(text, padding=True) for text in columns["criteria"]]
    labels = torch.tensor(columns["labels"])

    return (columns["nctids"], labels, smiles_embeds, icd_codes, criteria_tokens)

In [102]:
# Read the three CSV splits.
train_df, valid_df, test_df = (
    pd.read_csv(path) for path in (train_file, valid_file, test_file)
)

# Wrap each dataframe in a row-wise dataset.
train_dataset, valid_dataset, test_dataset = (
    TrialDataset(frame) for frame in (train_df, valid_df, test_df)
)

# Batches of 32 built by trial_collate_fn; only the training split is shuffled.
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=trial_collate_fn
)
valid_loader = DataLoader(
    valid_dataset, batch_size=32, shuffle=False, collate_fn=trial_collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, collate_fn=trial_collate_fn
)

In [103]:
# Where the trained trial model is written by the final cell.
hint_model_path = f"./checkpoints/{model_name}.ckpt"

# Combine the three encoders into the trial model. embedding_size uses the same
# value (50) as the GRAM embedding_dim and the ProtocolEmbedding output_dim
# set when those encoders were built.
model = TrialModel(
    toxicity_encoder=toxicity_model,
    disease_encoder=gram_model,
    protocol_encoder=protocol_model,
    embedding_size=50,
    num_ffn_layers=2,
    num_pred_layers=3,
    name=model_name,
)

In [107]:
# Train the trial model for num_epochs epochs; the train, validation and test
# loaders are all handed to the trainer, and its return value is kept in `metrics`.
# NOTE(review): in the recorded run this cell stopped at 0/11 iterations with a
# RuntimeError about tensors on both cuda:0 and cpu (raised from an
# index_select). Presumably an index tensor or an embedding weight inside the
# model is not on `device` -- TODO confirm where; the cause is not visible here.
trainer = TrialTrainer(model, lr=1e-3, weight_decay=0, device=device)
num_epochs = 5
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)

  0%|                                                                                            | 0/11 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

In [None]:
# Evaluate on the test loader, run the trainer's bootstrap evaluation on the
# same loader, then persist the model.
# This cell has no execution count, i.e. it was not run in the recorded session.
test_results = trainer.test(test_loader)
bootstrap_results = trainer.bootstrap_test(test_loader)
# torch.save is given the model object itself, so the whole object is pickled
# (not just a state_dict).
torch.save(model, hint_model_path)