In [2]:
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel
# from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments
from transformers import BertModel, BertConfig
from datasets import load_dataset
from transformers import AutoTokenizer,  EarlyStoppingCallback
# from torchcrf import CRF
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import ast
from transformers import DataCollatorForTokenClassification
import matplotlib.pyplot as plt
import json
import os
import optuna
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from scipy.special import softmax

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [3]:
# label2id = {'O': 0, 'I-Treatment': 1, 'I-Test': 2, 'I-Problem': 3, 'I-Background': 4, 'I-Other': 5}
# id2label = {0: 'O', 1: 'I-Treatment', 2: 'I-Test', 3: 'I-Problem', 4: 'I-Background', 5: 'I-Other'}

label2id = {'O': 0, 'I': 1}
id2label = {0: 'O', 1: 'I'}

In [4]:
def transform(example_batch):
    example_batch['sentence'] = ast.literal_eval(example_batch['sentence'])
    example_batch['tag'] = ast.literal_eval(example_batch['tag'])
    example_batch['tag'] = [label2id[label] for label in example_batch['tag']]
    return example_batch

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            new_labels.append(label)

    return new_labels

def tokenize_and_align_labels(examples, tokenizer):
    tokenized_inputs = tokenizer(
        examples["sentence"], truncation=True, is_split_into_words=True
    )
    
    all_labels = examples["tag"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

def tokenize_and_align_labels_wrapper(examples, tokenizer):
    return tokenize_and_align_labels(examples, tokenizer)

def untokenize_labels_predictions(word_ids, true_labels, predictions):
    untokenized_true_labels = []
    untokenized_predictions = []

    for sublist_word_ids, sublist_true_labels, sublist_predictions in zip(word_ids, true_labels, predictions):
        current_labels = []
        current_predictions = []
        last_word_id = None

        for word_id, label, prediction in zip(sublist_word_ids[1:-1], sublist_true_labels, sublist_predictions):
            # Skip if this word_id is the same as the last one (it's a subword)
            if word_id == last_word_id:
                continue

            current_labels.append(label)
            current_predictions.append(prediction)
            last_word_id = word_id

        untokenized_true_labels.append(current_labels)
        untokenized_predictions.append(current_predictions)

    return untokenized_true_labels, untokenized_predictions

class TrainingMonitor:
    def __init__(self, word_ids):
        self.best_f1 = 0
        self.best_confusion_matrix = None
        self.eval_dataset_name = "phee"
        self.word_ids = word_ids
    
    def set_eval_dataset_name(self, eval_dataset_name):
        self.eval_dataset_name = eval_dataset_name

    def set_word_ids(self, word_ids):
        self.word_ids = word_ids

    def compute_metrics_factory(self, fold_no, model_name, binary_classification=False):
        # Define the actual compute_metrics function
        def compute_metrics(eval_preds):
            logits, labels = eval_preds
            predictions = np.argmax(logits, axis=-1)

            # Remove ignored index (special tokens) and convert to labels
            true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
            true_predictions = [
                [id2label[p] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]

            untokenized_true_labels, untokenized_predictions = untokenize_labels_predictions(self.word_ids, true_labels, true_predictions)

            unflat_true = [label for seq in untokenized_true_labels for label in seq]
            unflat_pred = [label for seq in untokenized_predictions for label in seq]
            unreport = classification_report(y_pred=unflat_pred, y_true=unflat_true, output_dict=True)

            if binary_classification:
                scores = softmax(logits, axis=-1)

                y_true = [[0 if l == 0 else 1 for l in label if l != -100] for label in labels]
                y_scores_list = [[sum(s[1:]) for (s, l) in zip(score, label) if l != -100]
                    for score, label in zip(scores, labels)] # list of lists
                y_true = [item for sublist in y_true for item in sublist]
                y_scores = [item for sublist in y_scores_list for item in sublist]

                # Calculate ROC Curve and AUC
                fpr, tpr, thresholds = roc_curve(y_true, y_scores)
                roc_auc = auc(fpr, tpr)

                # Calculate Precision-Recall Curve and AUC
                precision, recall, _ = precision_recall_curve(y_true, y_scores)
                pr_auc = auc(recall, precision)

                # Plot ROC Curve
                plt.figure(figsize=(8, 8))
                plt.plot(fpr, tpr, color='blue', lw=2, marker='.', label='ROC curve (area = %0.2f)' % roc_auc)
                plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
                
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('ROC Curve')
                plt.legend(loc="lower right")
                plt.savefig(f'analysis/{model_name}/{self.eval_dataset_name}/graphs/fold{fold_no}/binary_dataset_roc-curve.png')
                plt.close()

                # Plot Precision-Recall Curve
                plt.figure(figsize=(8, 8))
                plt.plot(recall, precision, color='blue', lw=2, label='Precision-Recall curve (area = %0.2f)' % pr_auc)
            
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                plt.title('Precision-Recall Curve')
                plt.legend(loc="lower left")

                plt.tight_layout()
                plt.savefig(f'analysis/{model_name}/{self.eval_dataset_name}/graphs/fold{fold_no}/binary_dataset_precision-recall-curve.png')
                plt.close()


                cm = confusion_matrix(y_pred=unflat_pred, y_true=unflat_true)
                disp = ConfusionMatrixDisplay(cm, display_labels=np.array(['I', 'O']))
                fig, ax = plt.subplots(figsize=(8, 8))
                disp.plot(ax=ax)
                
                # Save the figure to an image file
                plt.savefig(f'analysis/{model_name}/{self.eval_dataset_name}/graphs/fold{fold_no}/binary_confusion_matrix.png')
                plt.close()

                un_report_df = pd.DataFrame(unreport).round(3).T
                with open(f"analysis/{model_name}/{self.eval_dataset_name}/reports/fold{fold_no}/binary_dataset_classification_report.json", "w") as f:
                    json.dump(un_report_df.to_dict(), f, indent=4)

                return {
                    "precision": unreport['macro avg']['precision'],
                    "recall": unreport['macro avg']['recall'],
                    "f1": unreport['macro avg']['f1-score'],
                    "accuracy": unreport['accuracy'],
                }
               
            else:
                unreport['macro_wo_O'] = {'precision': (unreport['I-Background']['precision'] + unreport['I-Other']['precision'] + unreport['I-Problem']['precision'] + unreport['I-Test']['precision'] + unreport['I-Treatment']['precision']) / 5,
                'recall': (unreport['I-Background']['recall'] + unreport['I-Other']['recall'] + unreport['I-Problem']['recall'] + unreport['I-Test']['recall'] + unreport['I-Treatment']['recall']) / 5,
                'f1-score': (unreport['I-Background']['f1-score'] + unreport['I-Other']['f1-score'] + unreport['I-Problem']['f1-score'] + unreport['I-Test']['f1-score'] + unreport['I-Treatment']['f1-score']) / 5,
                'support': (unreport['I-Background']['support'] + unreport['I-Other']['support'] + unreport['I-Problem']['support'] + unreport['I-Test']['support'] + unreport['I-Treatment']['support'])}
                
                un_report_df = pd.DataFrame(unreport).round(3).T

                binary_predictions = ['0' if label == 'O' else '1' for label in unflat_pred]
                binary_labels = ['0' if label == 'O' else '1' for label in unflat_true]

                # Generate a classification report
                binary_classification_report = classification_report(y_true=binary_labels, y_pred=binary_predictions, target_names=['O', 'I'], digits=3, output_dict=True)

           
           
                # new_f1_score = unreport['macro_wo_O']['f1-score']
                # if self.best_f1 < new_f1_score:
                #     self.best_f1 = new_f1_score
                cm = confusion_matrix(y_pred=unflat_pred, y_true=unflat_true)
                disp = ConfusionMatrixDisplay(cm, display_labels=np.array(['I-Background','I-Other', 'I-Problem', 'I-Test', 'I-Treatment', 'O']))
                fig, ax = plt.subplots(figsize=(8, 8))
                disp.plot(ax=ax)
                
                # Save the figure to an image file
                plt.savefig(f'analysis/{model_name}/{self.eval_dataset_name}/graphs/fold{fold_no}/confusion_matrix.png')
                plt.close()
                
                with open(f"analysis/{model_name}/{self.eval_dataset_name}/reports/fold{fold_no}/multiclass_classification_report.json", "w") as f:
                    json.dump(un_report_df.to_dict(), f, indent=4)
                
                # multiclass is mapped to binary 
                with open(f"analysis/{model_name}/{self.eval_dataset_name}/reports/fold{fold_no}/binary_classification_report.json", "w") as f:
                    json.dump(binary_classification_report, f, indent=4)
                

                return {
                    "precision": unreport['macro_wo_O']['precision'],
                    "recall": unreport['macro_wo_O']['recall'],
                    "f1": unreport['macro_wo_O']['f1-score'],
                    "accuracy": unreport['accuracy'],
                }
        return compute_metrics


In [5]:
class CustomBertLSTMModel(BertModel):  # (832x50 and 768x13)
    def __init__(self, config, lstm_hidden_size, lstm_num_layers, lstm_dropout):
        super().__init__(config)   
        self.lstm = nn.LSTM(input_size=config.hidden_size, hidden_size=lstm_hidden_size, num_layers=lstm_num_layers, dropout=lstm_dropout if lstm_num_layers > 1 else 0, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(lstm_hidden_size*2, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = super().forward(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        lstm_out, _ = self.lstm(sequence_output)
        logits = self.classifier(lstm_out)

        # Calculate the loss if labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        else:
            loss = None

        return (loss, logits) if loss is not None else logits
    
def model_init(trial, best_model_config=None):
    # Define hyperparameters using the trial object
    print(trial)
    lstm_hidden_size = trial.suggest_categorical('lstm_hidden_size', [50, 100, 200]) if trial is not None else best_model_config['lstm_hidden_size']
    lstm_num_layers = trial.suggest_categorical('lstm_num_layers', [1, 2, 3]) if trial is not None else best_model_config['lstm_num_layers']
    lstm_dropout = trial.suggest_float('lstm_dropout', 0.1, 0.5) if trial is not None else best_model_config['lstm_dropout']
    
    config = BertConfig.from_pretrained('dmis-lab/biobert-v1.1', num_labels=2)
    model = CustomBertLSTMModel(config, lstm_hidden_size=lstm_hidden_size, lstm_num_layers=lstm_num_layers, lstm_dropout=lstm_dropout)
    return model

In [6]:

def objective(trial=None, best_model_config=None, binary_classification=True):
    dataset_name = 'phee'
    file_name_prefix = 'binary_' if binary_classification else ""
    print(file_name_prefix)
    for i in range(1):
        if trial is not None:
            # hyperparameter tuning mode
            data = load_dataset('csv', data_files={'train': f'data/processed/{dataset_name}/fold{i}/{file_name_prefix}train.csv'})
            dataset = data['train'].train_test_split(test_size=0.15, seed=42) # 85% training, 15% validation
        else:
            # best model training and evaluation mode
            dataset = load_dataset('csv', data_files={'train': f'data/processed/{dataset_name}/fold{i}/{file_name_prefix}train.csv', 'test': f'data/processed/{dataset_name}/fold{i}/{file_name_prefix}test.csv'})

        for file_type in ['train', 'test']:
            dataset[file_type] = dataset[file_type].shuffle(seed=42).select(range(int(0.1 * len(dataset[file_type]))))  # Select 10%
            dataset[file_type] = dataset[file_type].map(transform)

        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-v1.1')
        data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

        tokenized_datasets = dataset.map(
        lambda examples: tokenize_and_align_labels_wrapper(examples, tokenizer),
        batched=True,
        remove_columns=dataset['train'].column_names)

        # word ids of the tokenized test set to find the original words
        test_sentences = dataset['test']['sentence']
        test_word_ids = []
        for sentence in test_sentences:
            test_word_ids.append(tokenizer(sentence, truncation=True, is_split_into_words=True).word_ids())


        lstm_model = model_init(trial, best_model_config)


        # Training arguments
        lstm_training_args = TrainingArguments(
            output_dir='./models/lstm',
            num_train_epochs=2,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            logging_strategy="epoch",
            evaluation_strategy="epoch",
            save_strategy="epoch",
            overwrite_output_dir=True,
            learning_rate=trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True) if trial is not None else best_model_config['learning_rate'],
            metric_for_best_model="f1",
            load_best_model_at_end=True,
            report_to="none"
        )

        early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.01)
        monitor = TrainingMonitor(test_word_ids)
        # Trainer
        lstm_trainer = Trainer(
            model=lstm_model,
            args=lstm_training_args,
            train_dataset=tokenized_datasets['train'],
            eval_dataset=tokenized_datasets['test'],
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=monitor.compute_metrics_factory(fold_no=i, model_name='lstm', binary_classification=binary_classification),
            callbacks=[early_stopping_callback]
        )

        # Train the model
        lstm_trainer.train()

        if trial is not None:
            eval_result = lstm_trainer.evaluate()
            print(eval_result)
            metric  = eval_result["eval_f1"]
            return metric
        else:
            for eval_dataset_name in ['mtsamples', 'doc-patient']:
                print(f"Evaluating on {eval_dataset_name}")
                eval_dataset = load_dataset('csv', data_files=f'data/processed/{eval_dataset_name}/{file_name_prefix}final_eval.csv')
                eval_dataset = eval_dataset.map(transform)
                
                # word ids of the tokenized mtsamples to find the original words
                eval_sentences = eval_dataset['train']['sentence']
                eval_word_ids = []
                for sentence in eval_sentences:
                    eval_word_ids.append(tokenizer(sentence, truncation=True, is_split_into_words=True).word_ids())

                eval_dataset = eval_dataset.map(lambda examples: tokenize_and_align_labels_wrapper(examples, tokenizer), batched=True, remove_columns=eval_dataset['train'].column_names)

                lstm_trainer.eval_dataset = eval_dataset['train']
                print(eval_dataset['train'])
                monitor.set_eval_dataset_name(eval_dataset_name)
                monitor.set_word_ids(eval_word_ids)
                print(lstm_trainer.get_eval_dataloader(eval_dataset))
                eval_result = lstm_trainer.evaluate() # compute metrics saves the confusion matrix and classification report
                print(eval_result)
                metric  = eval_result["eval_f1"]
                print(f"Eval f1 on {eval_dataset_name}: {metric}")
            return metric



In [8]:
# Initialize Optuna study and start the optimization process
study = optuna.create_study(direction="maximize")
study.optimize(lambda trial: objective(trial, binary_classification=True), n_trials=2)  # Adjust n_trials as needed

# Best trial results
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
best_params = trial.params
print(best_params)
with open("models/lstm/best_binary_hyperparameters.json", "w") as outfile:
    json.dump(best_params, outfile)

[I 2024-03-17 18:03:29,810] A new study created in memory with name: no-name-185b6c1b-7017-4ec7-9a70-642ed4dff2ff


binary_


Map:   0%|          | 0/56 [00:00<?, ? examples/s]

<optuna.trial._trial.Trial object at 0x7f339b2cf2f0>


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.5757,0.426433,0.80642,0.811608,0.807977,0.810629
2,0.3944,0.37654,0.816725,0.821401,0.818337,0.821108


[I 2024-03-17 18:03:40,429] Trial 0 finished with value: 0.818336878744657 and parameters: {'lstm_hidden_size': 200, 'lstm_num_layers': 1, 'lstm_dropout': 0.19323491096507872, 'learning_rate': 3.843565531379673e-05}. Best is trial 0 with value: 0.818336878744657.


{'eval_loss': 0.37653985619544983, 'eval_precision': 0.8167245141377713, 'eval_recall': 0.8214010371254188, 'eval_f1': 0.818336878744657, 'eval_accuracy': 0.8211077844311377, 'eval_runtime': 0.2551, 'eval_samples_per_second': 219.515, 'eval_steps_per_second': 54.879, 'epoch': 2.0}
binary_
<optuna.trial._trial.Trial object at 0x7f339b2cf2f0>


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.6176,0.494886,0.764306,0.770593,0.763424,0.76497
2,0.48,0.442091,0.781249,0.787302,0.776332,0.776946


[I 2024-03-17 18:03:50,164] Trial 1 finished with value: 0.7763320816451273 and parameters: {'lstm_hidden_size': 100, 'lstm_num_layers': 1, 'lstm_dropout': 0.2274142137625188, 'learning_rate': 1.4618507760222444e-05}. Best is trial 0 with value: 0.818336878744657.


{'eval_loss': 0.4420906603336334, 'eval_precision': 0.7812488769092543, 'eval_recall': 0.7873020971960901, 'eval_f1': 0.7763320816451273, 'eval_accuracy': 0.7769461077844312, 'eval_runtime': 0.2401, 'eval_samples_per_second': 233.283, 'eval_steps_per_second': 58.321, 'epoch': 2.0}
Number of finished trials:  2
Best trial:
  Value:  0.818336878744657
  Params: 
{'lstm_hidden_size': 200, 'lstm_num_layers': 1, 'lstm_dropout': 0.19323491096507872, 'learning_rate': 3.843565531379673e-05}


## Evaluation

In [17]:
# use the who train data for training with best hyperparameters
# evaluate on test set

best_model_config = json.load(open("models/lstm/best_binary_hyperparameters.json", "r"))
objective(None, best_model_config)  # Train and evaluate the best model
 


binary_
None


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.6152,0.498,0.751641,0.747526,0.710316,0.7104
2,0.4668,0.448761,0.76346,0.773989,0.752017,0.753067


  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


Evaluating on mtsamples


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 101
})
<accelerate.data_loader.DataLoaderShard object at 0x7fb60acb2700>


  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


{'eval_loss': 0.4467838704586029, 'eval_precision': 0.7506518737992509, 'eval_recall': 0.7646669711501866, 'eval_f1': 0.7454590174238691, 'eval_accuracy': 0.74830220713073, 'eval_runtime': 1.6535, 'eval_samples_per_second': 61.084, 'eval_steps_per_second': 7.862, 'epoch': 2.0}
Eval f1 on mtsamples: 0.7454590174238691
Evaluating on doc-patient


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/88 [00:00<?, ? examples/s]

Map:   0%|          | 0/88 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 88
})
<accelerate.data_loader.DataLoaderShard object at 0x7fb5ec9c4460>


  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


{'eval_loss': 0.7126114964485168, 'eval_precision': 0.5961882497166846, 'eval_recall': 0.6520208619640886, 'eval_f1': 0.5286400627512502, 'eval_accuracy': 0.5570578691184424, 'eval_runtime': 1.3663, 'eval_samples_per_second': 64.409, 'eval_steps_per_second': 8.051, 'epoch': 2.0}
Eval f1 on doc-patient: 0.5286400627512502


0.5286400627512502