In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
from datasets import DatasetDict, Sequence, Value, Features, Dataset, concatenate_datasets
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import numpy as np
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
from src.utils import model_output
import tqdm
import json

# Line Labelling Seq2Seq

To exploit the sequential structure of the line labels, I will build an NER-style model on top of the embedded lines. This is essentially token-level classification in which each "token" is the embedding of one line of the report, so the label of a line can depend on the lines around it.

In [None]:
# Line-level dataset written by the preprocessing step; its 'rid' column is used
# further down to group lines into reports.
dataset = DatasetDict.load_from_disk(
    paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset'
)

# Number of distinct aggregated classes present in the training split
num_labels = len({*dataset['train']['class_agg']})

# Saved outputs of the previous (fine-tuned medBERT) model, one file per split
train_data = torch.load(
    paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-train_output.pt'
)
val_data = torch.load(
    paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-val_output.pt'
)
test_data = torch.load(
    paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-test_output.pt'
)

In [None]:
# Preprocess data for token classification
def get_rid_dict(dataset, split:str = "train"):
    """
    Groups the row positions of one split by report id (rid).

    Args:
        dataset (datasets.DatasetDict): DatasetDict containing the train, validation and test datasets.
        split (str, optional): Name of the split to index. Defaults to "train".

    Returns:
        dict: Maps every rid found in the split to the list of row indices carrying that rid,
            in order of appearance.
    """
    positions_by_rid = {}
    for position, report_id in enumerate(dataset[split]['rid']):
        # setdefault creates the list the first time a rid is seen
        positions_by_rid.setdefault(report_id, []).append(position)
    return positions_by_rid

def preprocess_data(data:dict, rid_dict:dict, model: torch.nn.Module, tokenizer):
    """
    Preprocesses data for token classification.

    Every report becomes one fixed-length sequence in which each position holds the embedding
    of one line, framed by the model's CLS and SEP input embeddings and zero-padded to the
    length of the longest report.

    Args:
        data (dict of torch.Tensor): Embeddings of the reports. Contains ['embeddings', 'logits', 'labels'] in torch.Tensor format.
            Only "embeddings" (read as [line, 0, :], last dimension == model hidden size) and
            "labels" (argmax per line gives the class index) are used here.
        rid_dict (dict): Dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values. Used to group text lines.
        model: Model used for token classification. Only its config.hidden_size and its input embedding layer are used.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for token classification. Supplies cls_token_id and sep_token_id.

    Returns:
        dict: Maps a running report index (0, 1, ... in rid_dict order) to a dict with
            "embeddings" (max_sentences, hidden_dim), "ner_tag" (LongTensor of length max_sentences,
            -100 on CLS, SEP and padding), "attention_mask" (length max_sentences, 0 wherever
            ner_tag is -100) and "token_type_ids" (all zeros, long). An empty rid_dict gives {}.
    """

    # Hidden dimension of the model
    hidden_dim = model.config.hidden_size
    assert hidden_dim == data["embeddings"].shape[2], "Hidden dimension of the model and the embeddings do not match."

    # CLS and SEP token embeddings.
    # The ids are created on the device the embedding layer actually lives on, so this also
    # works on CPU-only machines; no_grad keeps the lookup out of the autograd graph.
    embedding_layer = model.get_input_embeddings()
    embedding_device = next(embedding_layer.parameters()).device
    with torch.no_grad():
        special_ids = torch.tensor([tokenizer.cls_token_id, tokenizer.sep_token_id], device=embedding_device)
        special_embeds = embedding_layer(special_ids).to("cpu")
    cls_embeds, sep_embeds = special_embeds[:1], special_embeds[1:]

    # Extracting embeddings for each report and concatenating them
    data_dict = {}

    # Need to pad input_embeddings to have the same number of sentences for each report
    # Longest report + 2 (CLS and SEP)
    max_sentences = max((len(line_indices) for line_indices in rid_dict.values()), default=0) + 2
    print("Max sentences: ", max_sentences)

    for idx, line_indices in enumerate(rid_dict.values()):
        # Number of padding positions after CLS + lines + SEP
        n_pad = max_sentences - len(line_indices) - 2

        # Indexing with a list of rows yields a (num_sentences, hidden_dim) tensor
        line_embeds = data["embeddings"][line_indices, 0, :]

        # CLS at the beginning, SEP at the end, then zero rows up to max_sentences
        embeddings = torch.cat((cls_embeds, line_embeds, sep_embeds), dim=0)
        embeddings = torch.nn.functional.pad(embeddings, (0, 0, 0, n_pad), "constant", 0)

        # NER tag with padding: -100 marks CLS, SEP and padding positions
        line_tags = [torch.argmax(label, dim=0).item() for label in data["labels"][line_indices]]
        ner_tag = torch.LongTensor([-100] + line_tags + [-100] + [-100] * n_pad)

        # Attention mask
        # NOTE(review): the mask is derived from ner_tag, so CLS and SEP are masked out
        # together with the padding - confirm that this is intended.
        attention_mask = torch.ones(max_sentences)
        attention_mask[ner_tag == -100] = 0

        # Token type ids
        token_type_ids = torch.zeros(max_sentences, dtype=torch.long)

        # Clone and detach all tensors so samples share no storage with the inputs
        data_dict[idx] = {
            "embeddings": embeddings.clone().detach(),
            "ner_tag": ner_tag.clone().detach(),
            "attention_mask": attention_mask.clone().detach(),
            "token_type_ids": token_type_ids.clone().detach(),
        }

    return data_dict

In [None]:
# Tokenizer and token-classification model
# NOTE(review): num_labels computed in the data-loading cell is not passed to from_pretrained,
# so the classifier head size comes from the saved 'medbert-diag-label' config - confirm it matches.
tokenizer = AutoTokenizer.from_pretrained(paths.MODEL_PATH/'medbert-512')
model = AutoModelForTokenClassification.from_pretrained(paths.MODEL_PATH/'medbert-diag-label')

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

In [None]:
# Row indices grouped by report id, one mapping per split
train_rid_dict, val_rid_dict, test_rid_dict = (
    get_rid_dict(dataset, split=split_name)
    for split_name in ("train", "val", "test")
)

# Per-report padded sequences for each split (train, then val, then test)
train_data_dict, val_data_dict, test_data_dict = (
    preprocess_data(split_data, split_rid_dict, model, tokenizer)
    for split_data, split_rid_dict in (
        (train_data, train_rid_dict),
        (val_data, val_rid_dict),
        (test_data, test_rid_dict),
    )
)

In [None]:
# Create dataset
class TokenLineLabelDataSet(Dataset):
    """Map-style dataset that serves one preprocessed report per index.

    `data` is a mapping from integer index to a sample; items are returned exactly as stored.
    """

    def __init__(self, data):
        # Keep a reference to the mapping; nothing is copied
        self.data = data

    def __len__(self):
        # Number of samples == number of keys in the mapping
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

# Training Arguments
# Defined before the DataLoaders so that batch_size is actually used by them
# instead of being shadowed by a hard-coded literal.
batch_size = 8
learning_rate = 5e-6
epochs = 32
weight_decay = 0.01

# Training, validation and test dataset
train_dataset = TokenLineLabelDataSet(train_data_dict)
val_dataset = TokenLineLabelDataSet(val_data_dict)
test_dataset = TokenLineLabelDataSet(test_data_dict)

# Dataloader (only the training split is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Optimizer over all model parameters
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
# Training
# Seed
torch.manual_seed(42)


def _run_batch(model, batch, device):
    """Runs one batch through the model.

    Returns:
        tuple: (loss tensor, number of correctly predicted lines, number of labelled lines).
            Positions whose label is -100 (CLS, SEP, padding) are excluded from both counts.
    """
    labels = batch["ner_tag"].to(device)
    outputs = model(
        inputs_embeds=batch["embeddings"].to(device),
        attention_mask=batch["attention_mask"].to(device),
        token_type_ids=batch["token_type_ids"].to(device),
        labels=labels,
    )
    labelled = labels != -100
    predictions = outputs.logits[labelled].argmax(dim=1)
    n_correct = (predictions == labels[labelled]).sum().item()
    return outputs.loss, n_correct, labelled.sum().item()


# Best validation loss so far. Starting at +inf means the first epoch is always
# checkpointed, so a checkpoint from this run exists even if the loss never improves.
min_val_loss = float('inf')

for epoch in range(epochs):

    # Training Loop
    print("Epoch: ", epoch)
    model.train()
    total_loss_train = 0
    correct_train = 0
    labelled_train = 0

    # Batch loop
    for batch in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()
        loss, n_correct, n_labelled = _run_batch(model, batch, device)
        loss.backward()
        optimizer.step()

        total_loss_train += loss.item()
        correct_train += n_correct
        labelled_train += n_labelled

    # Validation
    model.eval()
    total_loss_val = 0
    correct_val = 0
    labelled_val = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader):
            loss, n_correct, n_labelled = _run_batch(model, batch, device)

            total_loss_val += loss.item()
            correct_val += n_correct
            labelled_val += n_labelled

    # Loss is the mean of the per-batch losses; accuracy is counted per labelled line so
    # that batches with fewer labelled lines are not over-weighted.
    avg_train_loss = total_loss_train / len(train_dataloader)
    avg_train_acc = correct_train / labelled_train

    avg_val_loss = total_loss_val / len(val_dataloader)
    avg_val_acc = correct_val / labelled_val

    print(
        f'Epochs: {epoch + 1} | Loss: {avg_train_loss:.3f} | Accuracy: {avg_train_acc:.3f} | Val_Loss: {avg_val_loss:.3f} | Accuracy: {avg_val_acc:.3f}')

    # Save model whenever the validation loss improves on the best seen so far
    if avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        torch.save(model.state_dict(), paths.MODEL_PATH/'line-label-seq2seq.pt')

In [None]:
# Load best model
# map_location=device lets a checkpoint that was written from a GPU be restored
# on whatever device this session selected (including a CPU-only machine).
model.load_state_dict(torch.load(paths.MODEL_PATH/'line-label-seq2seq.pt', map_location=device))

In [None]:
# Predictions on test set
# Collects, per batch, the gold labels and predicted classes of all labelled lines
# (positions with label -100, i.e. CLS, SEP and padding, are dropped) and sums the batch losses.
model.eval()
test_labels = []
test_predictions = []
test_loss = 0
with torch.no_grad():
    for batch in tqdm.tqdm(test_dataloader):
        inputs_embeds = batch["embeddings"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        labels = batch["ner_tag"].to(device)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)

        # Keep only the positions that carry a real label
        labelled = labels != -100
        label_clean = labels[labelled]
        predictions = outputs.logits[labelled].argmax(dim=1)

        test_labels.append(label_clean)
        test_predictions.append(predictions)

        test_loss += outputs.loss.item()

# Mean of the per-batch losses
test_loss = test_loss / len(test_dataloader)
print("Test loss: ", test_loss)
            

In [None]:
from sklearn.metrics import f1_score, recall_score, precision_score
# Weighted F1 score, recall and precision over all labelled test lines
test_labels = torch.cat(test_labels, dim=0)
test_predictions = torch.cat(test_predictions, dim=0)

# Move both tensors to the CPU once and reuse them for every metric
labels_cpu = test_labels.cpu()
predictions_cpu = test_predictions.cpu()

for metric_title, metric_fn in (
    ("F1 score: ", f1_score),
    ("Recall: ", recall_score),
    ("Precision: ", precision_score),
):
    print(metric_title, metric_fn(labels_cpu, predictions_cpu, average='weighted'))