In [1]:
# Move the kernel's working directory up one level so the relative paths and
# package imports used below resolve.
# NOTE(review): a relative `%cd` is not idempotent — re-running this cell climbs
# one more directory each time.
%cd ../

/home/varadi_kristof/llms-for-trials/src/hint


In [25]:
# Enable IPython's autoreload extension (mode 2) so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
import torch
import warnings
import numpy
import random

# Global reproducibility setup: one seed shared by every RNG the notebook touches.
seed = 42
warnings.filterwarnings("ignore")  # suppress all Python warnings for the rest of the session

# Seed the main-process RNGs too. `seed_worker` only runs inside DataLoader worker
# processes, so with the default num_workers=0 Python's and NumPy's global RNGs
# would otherwise stay unseeded.
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's current initial seed.

    Passed to the DataLoaders as ``worker_init_fn``. ``worker_id`` is accepted
    to match that callback signature but is not used. The torch seed is reduced
    modulo 2**32 before being applied.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (numpy.random.seed, random.seed):
        apply_seed(derived_seed)

# Dedicated generator passed to the DataLoaders below, seeded with `seed`.
# `manual_seed` returns the generator itself, so chaining it into the assignment
# also stops the cell from echoing the Generator repr as its output.
g = torch.Generator().manual_seed(seed)

<torch._C.Generator at 0x7febc6a6dbf0>

In [58]:
import os
import ast
import pandas as pd
from torch.utils.data import DataLoader, Dataset

from toxicity.model import MultitaskToxicityModel, load_ckp
from trial.model import TrialModel, Trainer as TrialTrainer
from trial.protocol import ProtocolEmbedding
from trial.disease_encoder import GRAM, build_icdcode2ancestor_dict

In [28]:
# Disease encoder: a GRAM model with 50-dimensional embeddings, built on the
# ICD-code -> ancestor mapping.
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(
    embedding_dim=50,
    icdcode2ancestor=icdcode2ancestor_dict,
    device=device,
).to(device)

# Protocol encoder backed by the Bio_ClinicalBERT checkpoint; it is applied to
# the trials' criteria text further down.
protocol_model = ProtocolEmbedding(
    hf_model="emilyalsentzer/Bio_ClinicalBERT",
    device=device,
).to(device)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [29]:
# Filesystem layout, relative to the current working directory.
data_dir = "./data"
trial_data_dir = f"{data_dir}/trial"
model_dir = "./checkpoints/toxicity"
protocol_embedding_file = f"{data_dir}/protocol_embeddings.pth"
smiles_embedding_dir = "toxicity/smiles_embedding"

# One saved SMILES -> embedding mapping per split.
smiles_embed_train = torch.load(
    f"{data_dir}/{smiles_embedding_dir}/smiles_embed_train.pt"
)
smiles_embed_valid = torch.load(
    f"{data_dir}/{smiles_embedding_dir}/smiles_embed_valid.pt"
)
smiles_embed_test = torch.load(
    f"{data_dir}/{smiles_embedding_dir}/smiles_embed_test.pt"
)

# Merge into a single lookup; on duplicate keys the later split (valid, then test) wins.
smiles_embeddings = dict(smiles_embed_train)
smiles_embeddings.update(smiles_embed_valid)
smiles_embeddings.update(smiles_embed_test)

In [30]:
# Prediction tasks of the toxicity model: twelve Tox21 tasks followed by the
# single ClinTox task.
clintox_task = ['CT_TOX']
tox21_tasks = [
    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',
    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',
]

all_tasks = tox21_tasks + clintox_task

# Model input width is the first dimension of one stored SMILES embedding
# (any entry will do; the first one is used).
first_smiles = next(iter(smiles_embed_train.values()))
input_shape = first_smiles.shape[0]

# Build the multitask network and restore the checkpoint stored as
# "best_model_by_valid.pt". Only the first of the four values returned by
# `load_ckp` is kept.
model = MultitaskToxicityModel(input_shape, all_tasks).to(device)
toxicity_model, _, _, _ = load_ckp(
    f"{model_dir}/best_model_by_valid.pt",
    model,
    None,
)

In [31]:
# Name used for the TrialModel and for its checkpoint file.
model_name = 'mtdnn_multiphase_small'

# CSV files holding the pre-split trial data.
train_file, valid_file, test_file = (
    os.path.join(trial_data_dir, csv_name)
    for csv_name in ('train.csv', 'valid.csv', 'test.csv')
)

In [32]:
def explode_list(row):
    """Parse the string form of a Python literal (e.g. ``"['a', 'b']"``) into the object it denotes."""
    return ast.literal_eval(row)

def extract_smiles_embed(smiles_row: str):
    """Return one embedding tensor for a trial's serialised SMILES list.

    ``smiles_row`` is the string form of a list of SMILES strings. Every entry
    is looked up in ``smiles_embeddings``; unknown entries map to a zero vector
    of length ``input_shape``. The embedding of the first entry is returned, or
    a zero vector when the list is empty.

    NOTE(review): embeddings of the second and later drugs are computed but
    discarded — confirm that using only the first drug is intended.
    """
    looked_up = []
    for smiles in explode_list(smiles_row):
        looked_up.append(smiles_embeddings.get(smiles, torch.zeros(input_shape)))

    return looked_up[0] if looked_up else torch.zeros(input_shape)

In [33]:
def extract_icd(text):
    """Parse a serialised ICD-code cell into a list of code lists.

    The expected shape is a bracketed, double-quoted list of bracketed,
    single-quoted code lists, e.g. ``["['C50.9', 'C50']", "['D05']"]``, which
    yields ``[['C50.9', 'C50'], ['D05']]``. Parsing is purely positional: the
    outer two characters on each side are dropped, groups are split on
    ``", "``, and each code loses its surrounding whitespace and one character
    on each side.

    Note that the placeholder ``"[]"`` yields ``[['']]``, not an empty list.
    """
    inner = text[2:-2]
    return [
        [code.strip()[1:-1] for code in group[1:-1].split(',')]
        for group in inner.split('", "')
    ]

In [34]:
def prepare_trial_df(df):
    """Return a copy of ``df`` with missing text cells replaced by parseable defaults.

    Missing ``criteria`` becomes ``""``; missing ``smiless`` and ``icdcodes``
    become the string ``"[]"`` so the downstream parsers always receive a
    string. The input frame is left untouched (the previous version filled it
    in place), which keeps re-running a cell on the same frame safe.

    Raises:
        KeyError: if any of the three columns is absent from ``df``.
    """
    defaults = {"criteria": "", "smiless": "[]", "icdcodes": "[]"}

    # DataFrame.fillna silently ignores dict keys that are not columns, so check
    # explicitly to keep failing loudly on a malformed frame.
    missing = [column for column in defaults if column not in df.columns]
    if missing:
        raise KeyError(f"prepare_trial_df: missing required column(s): {missing}")

    return df.fillna(value=defaults)

In [35]:
# Read and clean the three pre-split trial CSVs.
train_df, valid_df, test_df = (
    prepare_trial_df(pd.read_csv(csv_path))
    for csv_path in (train_file, valid_file, test_file)
)

# Pool every trial into one frame. The splits' own indices are kept (no
# ignore_index), so index labels may repeat across splits.
multiphase_df = pd.concat([train_df, valid_df, test_df])

# Distinct phase values across all splits, in order of first appearance.
all_phase_categories = pd.concat(
    [frame['phase'] for frame in (train_df, valid_df, test_df)]
).unique()

In [36]:
from tqdm import tqdm

# Per-trial protocol embeddings are cached on disk and only rebuilt when the
# cache file is absent.
# NOTE(review): the rebuild branch instantiates `TrialDataset`, which this
# notebook defines in a later cell — on a fresh kernel with no cache file, that
# cell must be run first.
if not os.path.exists(protocol_embedding_file):

    def embedding_collate_fn(batch):
        """Collate dataset samples for the embedding pass.

        Same tuple layout as the training collate function, except that
        ``criteria`` holds the tokenizer output for each trial's criteria text
        rather than a cached embedding.
        """
        batch_inputs = {key: [d[key] for d in batch] for key in batch[0]}

        batch_inputs["smiless"] = torch.stack([extract_smiles_embed(smiles) for smiles in batch_inputs["smiless"]])
        batch_inputs["icdcodes"] = [extract_icd(icd) for icd in batch_inputs["icdcodes"]]
        batch_inputs["criteria"] = [protocol_model.tokenizer(criteria, padding=True) for criteria in batch_inputs["criteria"]]
        batch_inputs["labels"] = torch.tensor(batch_inputs["labels"])
        batch_inputs["phase"] = torch.stack(batch_inputs["phase"]).float()

        return (batch_inputs["nctids"], batch_inputs["labels"], batch_inputs["smiless"], batch_inputs["icdcodes"], batch_inputs["criteria"], batch_inputs["phase"])

    protocol_embeddings = {}
    multiphase_dataset = TrialDataset(multiphase_df, all_phase_categories)
    multiphase_dataloader = DataLoader(multiphase_dataset, batch_size=64, shuffle=False, collate_fn=embedding_collate_fn)

    # Inference only: run without autograd and detach the result so the cached
    # tensors are plain data. If the encoder's parameters require grad, each
    # stored embedding would otherwise keep its batch's autograd graph alive.
    with torch.no_grad():
        for nctids, labels, smiles, icdcodes, criteria, phase in tqdm(multiphase_dataloader):
            # Average the encoder output over dim 1 to get one vector per trial.
            criteria_embs = protocol_model(criteria).mean(dim=1).detach().cpu()
            for nctid, emb in zip(nctids, criteria_embs):
                protocol_embeddings[nctid] = emb

    torch.save(protocol_embeddings, protocol_embedding_file)

else:
    protocol_embeddings = torch.load(protocol_embedding_file)

In [37]:
class TrialDataset(Dataset):
    """Row-wise view of a trial DataFrame with a one-hot encoded ``phase``.

    ``phase_categories`` fixes the set and order of the one-hot columns, so
    datasets built from different splits share the same phase layout;
    categories absent from ``dataframe`` become all-zero columns.
    """

    def __init__(self, dataframe, phase_categories):
        self.phase_categories = phase_categories
        self.phase_columns = phase_categories

        # One indicator column per requested category, appended to the frame.
        one_hot = pd.get_dummies(dataframe['phase']).reindex(columns=phase_categories, fill_value=0)
        self.dataframe = pd.concat([dataframe, one_hot], axis=1)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the ``idx``-th trial (positional) as a dict of raw fields plus a phase tensor."""
        row = self.dataframe.iloc[idx]
        phase_vector = torch.tensor(row[self.phase_columns].values.astype(float))

        sample = {}
        sample["nctids"] = row['nctid']
        sample["labels"] = row['label']
        sample["smiless"] = row['smiless']
        sample["criteria"] = row['criteria']
        sample["icdcodes"] = row['icdcodes']
        sample["phase"] = phase_vector
        return sample


def trial_collate_fn(batch):
    """Collate ``TrialDataset`` samples into one batch tuple.

    Returns ``(nctids, labels, smiless, icdcodes, criteria, phase)`` where
    ``labels`` is a tensor, ``smiless`` is the stacked per-trial SMILES
    embeddings, ``icdcodes`` is a list of parsed code lists, ``criteria`` is
    the stacked cached protocol embeddings cast to float, and ``phase`` is the
    stacked one-hot phase vectors cast to float.

    Raises:
        KeyError: if a trial in the batch has no entry in
            ``protocol_embeddings``. Previously the lookup returned ``None``
            and ``torch.stack`` failed with an unrelated-looking TypeError.
    """
    batch_inputs = {key: [d[key] for d in batch] for key in batch[0]}
    nctids = batch_inputs["nctids"]

    batch_inputs["smiless"] = torch.stack([extract_smiles_embed(smiles) for smiles in batch_inputs["smiless"]])
    batch_inputs["icdcodes"] = [extract_icd(icd) for icd in batch_inputs["icdcodes"]]

    # Fail with an actionable message when the embedding cache does not cover the batch.
    missing = [nctid for nctid in nctids if nctid not in protocol_embeddings]
    if missing:
        raise KeyError(
            f"No cached protocol embedding for trial(s) {missing}; "
            f"the protocol embedding cache may be stale and need regenerating."
        )
    batch_inputs["criteria"] = torch.stack([protocol_embeddings[nctid] for nctid in nctids]).float()

    batch_inputs["labels"] = torch.tensor(batch_inputs["labels"])
    batch_inputs["phase"] = torch.stack(batch_inputs["phase"]).float()

    return (nctids, batch_inputs["labels"], batch_inputs["smiless"], batch_inputs["icdcodes"], batch_inputs["criteria"], batch_inputs["phase"])

In [38]:
from sklearn.model_selection import train_test_split

# Re-split the pooled trials: 20% held out for test, then 25% of the remaining
# 80% for validation (roughly 60/20/20 overall).
# NOTE: this rebinds train_df / valid_df / test_df, replacing the frames that
# were loaded from the CSV files.
train_val_df, test_df = train_test_split(multiphase_df, test_size=0.2, random_state=42)
train_df, valid_df = train_test_split(train_val_df, test_size=0.25, random_state=42)

train_dataset, valid_dataset, test_dataset = (
    TrialDataset(split_df, all_phase_categories)
    for split_df in (train_df, valid_df, test_df)
)

# Settings shared by all three loaders; only the training loader shuffles.
loader_options = dict(
    batch_size=32,
    collate_fn=trial_collate_fn,
    worker_init_fn=seed_worker,
    generator=g,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [39]:
# Width of the one-hot phase vector, read off the first training sample.
phase_dim = train_dataset[0]["phase"].size(-1)

In [62]:
# Where the trained trial model will be written.
hint_model_path = f"./checkpoints/{model_name}.ckpt"

# Assemble the trial model from the restored toxicity encoder and the GRAM
# disease encoder. Note that `model` is rebound here; it previously named the
# toxicity network.
model = TrialModel(
    # encoders
    toxicity_encoder=toxicity_model,
    disease_encoder=gram_model,
    # sizes
    protocol_embedding_size=protocol_model.embedding_size,
    embedding_size=100,
    phase_dim=phase_dim,
    # network depth / regularisation
    num_ffn_layers=2,
    num_pred_layers=3,
    dropout=0.0,
    # bookkeeping
    name=model_name,
    device=device,
)

In [None]:
# Train the trial model for a fixed number of epochs, then save and evaluate it.
trainer = TrialTrainer(model, lr=1e-3, weight_decay=0, device=device)
num_epochs = 20
# NOTE(review): the test loader is passed to train() as well — confirm it is only
# used for reporting and not for model selection.
metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)
# Saves the whole module object (not just a state_dict) to the checkpoint path.
torch.save(model, hint_model_path)

print("Test results\n\n")
# Evaluate on the held-out split; the phase categories are passed along,
# presumably for per-phase reporting — TODO confirm against Trainer.test.
test_results = trainer.test(test_loader, all_phase_categories)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train_loss,█▃▂▂▁
valid_loss,▁█▂▄▁

0,1
train_loss,0.54669
valid_loss,0.62324


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113153977526559, max=1.0…

100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:28<00:00,  8.15it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:25<00:00,  9.04it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 11.05it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:19<00:00, 11.76it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 10.73it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:18<00:00, 12.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:22<00:00, 10.23it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 10.90it/s]


In [None]:
# Bootstrap evaluation on the same test loader, using the trainer created in the
# training cell.
print("Bootstrap test results\n")
bootstrap_results = trainer.bootstrap_test(test_loader, all_phase_categories)

In [None]:
# Sanity check: (rows, columns) of the pooled trial frame.
multiphase_df.shape