In [2]:
%cd ../

/home/varadi_kristof/llms-for-trials/src/hint


In [184]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [202]:
import os
import torch
import ast
import warnings;
import pandas as pd
from torch.utils.data import DataLoader, Dataset

from toxicity.model import MultitaskToxicityModel, load_ckp
from trial.model import TrialModel, Trainer as TrialTrainer
from trial.protocol import ProtocolEmbedding
from trial.disease_encoder import GRAM, build_icdcode2ancestor_dict

In [203]:
warnings.filterwarnings("ignore")
torch.manual_seed(0)
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [204]:
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(embedding_dim = 50, icdcode2ancestor = icdcode2ancestor_dict, device=device).to(device)
protocol_model = ProtocolEmbedding(hf_model="emilyalsentzer/Bio_ClinicalBERT", device=device).to(device)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [205]:
data_dir = "./data"
trial_data_dir = f"{data_dir}/trial"
model_dir = "./checkpoints/toxicity"
protocol_embedding_file = f"{data_dir}/protocol_embeddings.pth"
smiles_embedding_dir = "toxicity/smiles_embedding"

smiles_embed_train = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_train.pt")
smiles_embed_valid = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_valid.pt")
smiles_embed_test = torch.load(f"{data_dir}/{smiles_embedding_dir}/smiles_embed_test.pt")

smiles_embeddings = {**smiles_embed_train, **smiles_embed_valid, **smiles_embed_test}

In [206]:
clintox_task = ['CT_TOX']
tox21_tasks = ['NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE', 
               'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
               'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53']

all_tasks = tox21_tasks + clintox_task

first_smiles = next(iter(smiles_embed_train.values()))
input_shape = first_smiles.shape[0]

model = MultitaskToxicityModel(input_shape, all_tasks).to(device)
toxicity_model, _, _, _ = load_ckp(f"{model_dir}/best_model_by_valid.pt", model, None)

In [207]:
model_name = 'mtdnn_multiphase'

train_file = os.path.join(trial_data_dir,'train.csv')
valid_file = os.path.join(trial_data_dir, 'valid.csv')
test_file = os.path.join(trial_data_dir, 'test.csv')

In [208]:
def explode_list(row):
    smiles_list = ast.literal_eval(row)
    return smiles_list

def extract_smiles_embed(smiles_row: str):
    smiles_list = explode_list(smiles_row)
    embeddings = [smiles_embeddings.get(smiles, torch.zeros(input_shape)) for smiles in smiles_list]

    if embeddings:
        summed_embeddings = torch.stack(embeddings).sum(dim=0) # TODO: this is very primitive, some information loss is likely
    else:
        summed_embeddings = torch.zeros(input_shape)
    return summed_embeddings

In [209]:
def extract_icd(text):
    text = text[2:-2]
    lst_lst = []
    for i in text.split('", "'):
        i = i[1:-1]
        lst_lst.append([j.strip()[1:-1] for j in i.split(',')])
    return lst_lst 

In [210]:
class TrialDataset(Dataset):
    def __init__(self, dataframe, phase_categories):
        self.dataframe = dataframe
        phase_dummies = pd.get_dummies(self.dataframe['phase']).reindex(columns=phase_categories, fill_value=0)
        self.dataframe.drop('phase', axis=1, inplace=True)
        self.dataframe = pd.concat([self.dataframe, phase_dummies], axis=1)
        self.phase_columns = phase_categories
    
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        data = self.dataframe.iloc[idx]
        phase_data = torch.tensor(data[self.phase_columns].values.astype(float))
        return {
            "nctids": data['nctid'],
            "labels": data['label'],
            "smiless": data['smiless'],
            "criteria": data['criteria'],
            "icdcodes": data['icdcodes'],
            "phase": phase_data
        }

def trial_collate_fn(batch):
    batch_inputs = {key: [d[key] for d in batch] for key in batch[0]}
    
    batch_inputs["smiless"] = torch.stack([extract_smiles_embed(smiles) for smiles in batch_inputs["smiless"]])
    batch_inputs["icdcodes"] = [extract_icd(icd) for icd in batch_inputs["icdcodes"]]
    batch_inputs["criteria"] = [protocol_model.tokenizer(criteria, padding=True) for criteria in batch_inputs["criteria"]]
    batch_inputs["nctids"] = batch_inputs["nctids"]
    batch_inputs["labels"] = torch.tensor(batch_inputs["labels"])
    batch_inputs["phase"] = torch.stack(batch_inputs["phase"]).float()
    
    return (batch_inputs["nctids"], batch_inputs["labels"], batch_inputs["smiless"], batch_inputs["icdcodes"], batch_inputs["criteria"], batch_inputs["phase"])

In [211]:
def prepare_trial_df(df):
    df[["criteria"]] = df[["criteria"]].fillna(value="")
    df[["smiless"]] = df[["smiless"]].fillna(value="[]")
    df[["icdcodes"]] = df[["icdcodes"]].fillna(value="[]")
    return df

In [224]:
train_df = prepare_trial_df(pd.read_csv(train_file))
valid_df = prepare_trial_df(pd.read_csv(valid_file))
test_df = prepare_trial_df(pd.read_csv(test_file))

multiphase_df = pd.concat([train_df, valid_df, test_df])
all_phase_categories = pd.concat([train_df['phase'], valid_df['phase'], test_df['phase']]).unique()

train_dataset = TrialDataset(train_df, all_phase_categories)
valid_dataset = TrialDataset(valid_df, all_phase_categories)
test_dataset = TrialDataset(test_df, all_phase_categories)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=trial_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=trial_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=trial_collate_fn)

In [225]:
from tqdm import tqdm

if not os.path.exists(protocol_embedding_file):
    embs = {}
    multiphase_dataset = TrialDataset(multiphase_df, all_phase_categories)
    multiphase_dataloader = DataLoader(multiphase_dataset, batch_size=32, shuffle=True, collate_fn=trial_collate_fn)
    
    for nctids, labels, smiles, icdcodes, criteria, phase in tqdm(multiphase_dataloader):
        criteria_emb = protocol_model(criteria)
        for c, emb in zip(criteria, criteria_emb):
            embs[c] = emb
            
    torch.save(embs, protocol_embedding_file)

  0%|                                                                                           | 0/195 [00:09<?, ?it/s]


TypeError: unhashable type: 'BatchEncoding'

In [196]:
phase_dim = train_dataset[0]["phase"].shape[-1]

In [197]:
hint_model_path = f"./checkpoints/{model_name}.ckpt"

model = TrialModel(
             toxicity_encoder = toxicity_model, 
             disease_encoder = gram_model, 
             protocol_embedding_path = protocol_embedding_file,
             protocol_embedding_size = protocol_model.embedding_size,
             embedding_size = 50, 
             num_ffn_layers=2,
             num_pred_layers=3,
             phase_dim=phase_dim,
             name=model_name,
             device=device)

In [198]:
trainer = TrialTrainer(model, lr=1e-3, weight_decay=0, device=device)
num_epochs = 5
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)

100%|█████████████████████████████████████████████████████████████████████████████████| 255/255 [21:00<00:00,  4.94s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 255/255 [17:35<00:00,  4.14s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 255/255 [17:31<00:00,  4.13s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 255/255 [17:31<00:00,  4.12s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 255/255 [17:40<00:00,  4.16s/it]


In [201]:
metrics

{'tp': 1141,
 'fp': 355,
 'tn': 948,
 'fn': 983,
 'f1': 0.6303867403314918,
 'pr_auc': 0.793367154984866,
 'roc_auc': 0.6323728885824832}

In [200]:
test_results = trainer.test(test_loader)
bootstrap_results = trainer.bootstrap_test(test_loader)
torch.save(model, hint_model_path)

KeyboardInterrupt: 