In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import DatasetDict
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import numpy as np
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
import tqdm
import accelerate

# Line Labelling Seq2Seq

To exploit the sequential structure of line labels, I will build a NER model on top of the embedded sentences. It will basically be a token level classification but each token represents a line in the report.

In [2]:
# Load dataset
dataset = DatasetDict.load_from_disk(paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset')

# Num Labels
num_labels = len(set(dataset['train']['class_agg']))

# train output of previous model
train_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-train_output.pt')
val_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-val_output.pt')
test_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-test_output.pt')

In [3]:
# Preprocess data for token classification
def get_rid_dict(dataset, split:str = "train"):
    """
    Returns a dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values.

    Args:
        dataset (datasets.DatasetDict): DatasetDict containing the train, validation and test datasets.
        split (str, optional): Split of the data. Defaults to "train".

    Returns:
        dict: Dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values.
    """
    rids = dataset[split]['rid']
    rids_index = {}
    for i, rid in enumerate(rids):
        if rid not in rids_index:
            rids_index[rid] = [i]
        else:
            rids_index[rid].append(i)
    return rids_index

def preprocess_data(data:dict, rid_dict:dict):
    """
    Preprocesses data for token classification.

    Args:
        data (dict of torch.Tensor): Embeddings of the reports. Contains ['embeddings', 'logits', 'labels'] in torch.Tensor format.
        rid_dict (dict): Dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values. Used to group text lines.      
    """

    # Extracting embeddings for each report and concatenating them
    data_dict = {}

    max_sentences = max([len(rid_dict[rid]) for rid in rid_dict.keys()])
    print("Max sentences: ", max_sentences)

    for idx, rid in enumerate(rid_dict.keys()):
        data_dict[idx] = {}

        # Embeddings (must pad to max_sentences in dim 2)
        data_dict[idx]["embeddings"] = data["embeddings"][rid_dict[rid], 0, :]
        data_dict[idx]["embeddings"] = torch.nn.functional.pad(data_dict[idx]["embeddings"], (0, 0, 0, max_sentences - data_dict[idx]["embeddings"].shape[0]), "constant", 0)

        # NER Tag with padding
        data_dict[idx]["ner_tag"] = [torch.argmax(label, dim = 0).item() for label in data["labels"][rid_dict[rid]]]
        data_dict[idx]["ner_tag"] = data_dict[idx]["ner_tag"] + [-100]*(max_sentences - len(data_dict[idx]["ner_tag"]))
        data_dict[idx]["ner_tag"] = torch.LongTensor(data_dict[idx]["ner_tag"])

        # Attention mask
        data_dict[idx]["attention_mask"] = torch.ones(max_sentences)
        data_dict[idx]["attention_mask"][data_dict[idx]["ner_tag"] == -100] = 0

        # Clone and detach all tensors
        for key in data_dict[idx].keys():
            data_dict[idx][key] = data_dict[idx][key].clone().detach()

    return data_dict

In [4]:
# Model and tokenizer
#checkpoint = "xlm-roberta-large"

# Roberta Model
model = AutoModelForTokenClassification.from_pretrained(paths.MODEL_PATH/'roberta-xlm-large', num_labels=num_labels)
#tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [5]:
# Get rid_dict
train_rid_dict = get_rid_dict(dataset, split="train")
val_rid_dict = get_rid_dict(dataset, split="val")
test_rid_dict = get_rid_dict(dataset, split="test")

# Preprocess data
train_data_dict = preprocess_data(train_data, train_rid_dict)
val_data_dict = preprocess_data(val_data, val_rid_dict)
test_data_dict = preprocess_data(test_data, test_rid_dict)

Max sentences:  57
Max sentences:  35
Max sentences:  41


In [6]:
class RobertaXMLModel(torch.nn.Module):
    """
    Model for token classification using a pretrained Roberta model.
    """
    def __init__(self, model, input_dim:int, device:torch.device):
        super(RobertaXMLModel, self).__init__()
        self.device = device
        self.model = model
        self.model.to(self.device)
        self.embed_mapping = torch.nn.Linear(input_dim, self.model.config.hidden_size).to(self.device)
        self.activation = torch.nn.ReLU().to(self.device)

    def forward(self, inputs_embeds, attention_mask, labels=None, output_hidden_states=False):
        """
        Forward pass of the model. Must map from hidden BERT dimension to the dimension of the Roberta Embeddings first,
        then replace the CLS and SEP embeddings with the embeddings of the Roberta model.
        """
        
        # Map to roberta hidden dimension
        inputs_embeds = self.activation(self.embed_mapping(inputs_embeds))

        # Forward pass
        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states = output_hidden_states)
        return outputs
        

In [7]:
# Model
accelerator = accelerate.Accelerator(mixed_precision='fp16', gradient_accumulation_steps=2)
device = accelerator.device
model_custom = RobertaXMLModel(model=model, input_dim=train_data["embeddings"].shape[2], device=device)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
# Training Arguments
batch_size = 4
learning_rate = 5e-6
epochs = 32
weight_decay = 0.01
optimizer = Adam(model_custom.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Create dataset
class TokenLineLabelDataSet(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data.keys())

    def __getitem__(self, idx):
        return self.data[idx]

# Training, validation and test dataset
train_dataset = TokenLineLabelDataSet(train_data_dict)
val_dataset = TokenLineLabelDataSet(val_data_dict)
test_dataset = TokenLineLabelDataSet(test_data_dict)

# Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Accelerator
model_custom, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model_custom, optimizer, train_dataloader, val_dataloader)

In [9]:
# Training
# Seed
torch.manual_seed(42)

for epoch in range(epochs):

    # Training Loop
    print("Epoch: ", epoch)
    total_acc_train = 0
    total_loss_train = 0
    model.train()

    # Best accuracy and loss for best model
    best_acc = 0
    best_loss = 1000

    # Batch loop
    for batch in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()
        inputs_embeds = batch["embeddings"]
        attention_mask = batch["attention_mask"]
        labels = batch["ner_tag"]

        outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
        
        # Accuracy and loss cleaned
        logits = outputs.logits
        loss = outputs.loss
        logits_clean = logits[labels != -100]
        label_clean = labels[labels != -100]

        predictions = logits_clean.argmax(dim=1)
        acc = (predictions == label_clean).float().mean()

        total_acc_train += acc.item()
        total_loss_train += loss.item()
    
        accelerator.backward(loss)
        optimizer.step()


    # Validation
    model.eval()
    total_acc_val = 0
    total_loss_val = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader):
            inputs_embeds = batch["embeddings"]
            attention_mask = batch["attention_mask"]
            labels = batch["ner_tag"]
            
            outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
            

            # Accuracy and loss cleaned
            logits = outputs.logits
            loss = outputs.loss

            logits_clean = logits[labels != -100]
            label_clean = labels[labels != -100]

            predictions = logits_clean.argmax(dim=1)
            acc = (predictions == label_clean).float().mean()

            total_acc_val += acc.item()
            total_loss_val += loss.item()

        val_accuracy = total_acc_val / len(val_dataloader)
        val_loss = total_loss_val / len(val_dataloader)

        avg_train_loss = total_loss_train / len(train_dataloader)
        avg_train_acc = total_acc_train / len(train_dataloader)

        avg_val_loss = total_loss_val / len(val_dataloader)
        avg_val_acc = total_acc_val / len(val_dataloader)

        print(
            f'Epochs: {epoch + 1} | Loss: {avg_train_loss:.3f} | Accuracy: {avg_train_acc:.3f} | Val_Loss: {avg_val_loss:.3f} | Accuracy: {avg_val_acc:.3f}')


        # Save model if validation loss is lower than previous validation loss
        if epoch == 0:
            min_val_loss = val_loss
            min_val_acc = val_accuracy
        elif val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(model_custom.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-roberta-best-loss.pt')
        elif val_accuracy < min_val_acc:
            min_val_acc = val_accuracy
            torch.save(model_custom.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-roberta-best-accuracy.pt')

Epoch:  0


100%|██████████| 12/12 [00:01<00:00,  6.21it/s]
100%|██████████| 3/3 [00:00<00:00, 28.31it/s]


Epochs: 1 | Loss: 2.236 | Accuracy: 0.151 | Val_Loss: 2.101 | Accuracy: 0.247
Epoch:  1


100%|██████████| 12/12 [00:01<00:00,  8.69it/s]
100%|██████████| 3/3 [00:00<00:00, 28.28it/s]


Epochs: 2 | Loss: 2.137 | Accuracy: 0.192 | Val_Loss: 2.068 | Accuracy: 0.250
Epoch:  2


100%|██████████| 12/12 [00:01<00:00,  8.05it/s]
100%|██████████| 3/3 [00:00<00:00, 28.75it/s]


Epochs: 3 | Loss: 2.068 | Accuracy: 0.244 | Val_Loss: 1.948 | Accuracy: 0.355
Epoch:  3


100%|██████████| 12/12 [00:01<00:00,  8.37it/s]
100%|██████████| 3/3 [00:00<00:00, 26.30it/s]


Epochs: 4 | Loss: 1.952 | Accuracy: 0.343 | Val_Loss: 1.798 | Accuracy: 0.428
Epoch:  4


100%|██████████| 12/12 [00:01<00:00,  8.37it/s]
100%|██████████| 3/3 [00:00<00:00, 28.16it/s]


Epochs: 5 | Loss: 1.677 | Accuracy: 0.497 | Val_Loss: 1.465 | Accuracy: 0.630
Epoch:  5


100%|██████████| 12/12 [00:01<00:00,  8.87it/s]
100%|██████████| 3/3 [00:00<00:00, 29.38it/s]


Epochs: 6 | Loss: 1.362 | Accuracy: 0.591 | Val_Loss: 1.206 | Accuracy: 0.714
Epoch:  6


100%|██████████| 12/12 [00:01<00:00,  8.90it/s]
100%|██████████| 3/3 [00:00<00:00, 28.98it/s]


Epochs: 7 | Loss: 1.112 | Accuracy: 0.711 | Val_Loss: 1.076 | Accuracy: 0.720
Epoch:  7


100%|██████████| 12/12 [00:01<00:00,  8.84it/s]
100%|██████████| 3/3 [00:00<00:00, 29.06it/s]


Epochs: 8 | Loss: 0.908 | Accuracy: 0.779 | Val_Loss: 1.000 | Accuracy: 0.735
Epoch:  8


100%|██████████| 12/12 [00:01<00:00,  8.84it/s]
100%|██████████| 3/3 [00:00<00:00, 29.43it/s]


Epochs: 9 | Loss: 0.705 | Accuracy: 0.834 | Val_Loss: 0.860 | Accuracy: 0.750
Epoch:  9


100%|██████████| 12/12 [00:01<00:00,  8.91it/s]
100%|██████████| 3/3 [00:00<00:00, 29.09it/s]


Epochs: 10 | Loss: 0.558 | Accuracy: 0.863 | Val_Loss: 0.907 | Accuracy: 0.805
Epoch:  10


100%|██████████| 12/12 [00:01<00:00,  8.93it/s]
100%|██████████| 3/3 [00:00<00:00, 29.46it/s]


Epochs: 11 | Loss: 0.478 | Accuracy: 0.895 | Val_Loss: 0.941 | Accuracy: 0.805
Epoch:  11


100%|██████████| 12/12 [00:01<00:00,  8.84it/s]
100%|██████████| 3/3 [00:00<00:00, 29.16it/s]


Epochs: 12 | Loss: 0.446 | Accuracy: 0.896 | Val_Loss: 0.939 | Accuracy: 0.780
Epoch:  12


100%|██████████| 12/12 [00:01<00:00,  8.92it/s]
100%|██████████| 3/3 [00:00<00:00, 29.16it/s]


Epochs: 13 | Loss: 0.485 | Accuracy: 0.895 | Val_Loss: 0.962 | Accuracy: 0.779
Epoch:  13


100%|██████████| 12/12 [00:01<00:00,  8.89it/s]
100%|██████████| 3/3 [00:00<00:00, 29.47it/s]


Epochs: 14 | Loss: 0.427 | Accuracy: 0.900 | Val_Loss: 1.015 | Accuracy: 0.799
Epoch:  14


100%|██████████| 12/12 [00:01<00:00,  8.95it/s]
100%|██████████| 3/3 [00:00<00:00, 29.54it/s]


Epochs: 15 | Loss: 0.404 | Accuracy: 0.914 | Val_Loss: 0.924 | Accuracy: 0.797
Epoch:  15


100%|██████████| 12/12 [00:01<00:00,  9.31it/s]
100%|██████████| 3/3 [00:00<00:00, 28.80it/s]


Epochs: 16 | Loss: 0.400 | Accuracy: 0.918 | Val_Loss: 0.934 | Accuracy: 0.812
Epoch:  16


100%|██████████| 12/12 [00:01<00:00,  8.82it/s]
100%|██████████| 3/3 [00:00<00:00, 28.88it/s]


Epochs: 17 | Loss: 0.271 | Accuracy: 0.953 | Val_Loss: 0.984 | Accuracy: 0.817
Epoch:  17


100%|██████████| 12/12 [00:01<00:00,  8.85it/s]
100%|██████████| 3/3 [00:00<00:00, 29.42it/s]


Epochs: 18 | Loss: 0.219 | Accuracy: 0.959 | Val_Loss: 1.036 | Accuracy: 0.813
Epoch:  18


100%|██████████| 12/12 [00:01<00:00,  8.96it/s]
100%|██████████| 3/3 [00:00<00:00, 29.35it/s]


Epochs: 19 | Loss: 0.208 | Accuracy: 0.954 | Val_Loss: 1.090 | Accuracy: 0.817
Epoch:  19


100%|██████████| 12/12 [00:01<00:00,  8.89it/s]
100%|██████████| 3/3 [00:00<00:00, 28.11it/s]


Epochs: 20 | Loss: 0.196 | Accuracy: 0.963 | Val_Loss: 1.058 | Accuracy: 0.813
Epoch:  20


100%|██████████| 12/12 [00:01<00:00,  8.06it/s]
100%|██████████| 3/3 [00:00<00:00, 30.08it/s]


Epochs: 21 | Loss: 0.203 | Accuracy: 0.962 | Val_Loss: 1.010 | Accuracy: 0.813
Epoch:  21


100%|██████████| 12/12 [00:01<00:00,  8.94it/s]
100%|██████████| 3/3 [00:00<00:00, 30.18it/s]


Epochs: 22 | Loss: 0.172 | Accuracy: 0.962 | Val_Loss: 1.067 | Accuracy: 0.813
Epoch:  22


100%|██████████| 12/12 [00:01<00:00,  8.96it/s]
100%|██████████| 3/3 [00:00<00:00, 27.87it/s]


Epochs: 23 | Loss: 0.175 | Accuracy: 0.969 | Val_Loss: 1.079 | Accuracy: 0.813
Epoch:  23


100%|██████████| 12/12 [00:01<00:00,  8.50it/s]
100%|██████████| 3/3 [00:00<00:00, 28.18it/s]


Epochs: 24 | Loss: 0.144 | Accuracy: 0.972 | Val_Loss: 1.095 | Accuracy: 0.808
Epoch:  24


100%|██████████| 12/12 [00:01<00:00,  8.44it/s]
100%|██████████| 3/3 [00:00<00:00, 28.22it/s]


Epochs: 25 | Loss: 0.135 | Accuracy: 0.974 | Val_Loss: 1.136 | Accuracy: 0.813
Epoch:  25


100%|██████████| 12/12 [00:01<00:00,  8.49it/s]
100%|██████████| 3/3 [00:00<00:00, 28.58it/s]


Epochs: 26 | Loss: 0.142 | Accuracy: 0.974 | Val_Loss: 1.155 | Accuracy: 0.813
Epoch:  26


100%|██████████| 12/12 [00:01<00:00,  8.57it/s]
100%|██████████| 3/3 [00:00<00:00, 28.61it/s]


Epochs: 27 | Loss: 0.138 | Accuracy: 0.977 | Val_Loss: 1.141 | Accuracy: 0.813
Epoch:  27


100%|██████████| 12/12 [00:01<00:00,  8.52it/s]
100%|██████████| 3/3 [00:00<00:00, 28.85it/s]


Epochs: 28 | Loss: 0.117 | Accuracy: 0.979 | Val_Loss: 1.139 | Accuracy: 0.817
Epoch:  28


100%|██████████| 12/12 [00:01<00:00,  8.66it/s]
100%|██████████| 3/3 [00:00<00:00, 28.34it/s]


Epochs: 29 | Loss: 0.121 | Accuracy: 0.978 | Val_Loss: 1.183 | Accuracy: 0.813
Epoch:  29


100%|██████████| 12/12 [00:01<00:00,  7.88it/s]
100%|██████████| 3/3 [00:00<00:00, 27.51it/s]


Epochs: 30 | Loss: 0.139 | Accuracy: 0.978 | Val_Loss: 1.157 | Accuracy: 0.813
Epoch:  30


100%|██████████| 12/12 [00:01<00:00,  8.25it/s]
100%|██████████| 3/3 [00:00<00:00, 24.93it/s]


Epochs: 31 | Loss: 0.124 | Accuracy: 0.973 | Val_Loss: 1.191 | Accuracy: 0.813
Epoch:  31


100%|██████████| 12/12 [00:01<00:00,  7.85it/s]
100%|██████████| 3/3 [00:00<00:00, 28.74it/s]

Epochs: 32 | Loss: 0.109 | Accuracy: 0.978 | Val_Loss: 1.181 | Accuracy: 0.813





In [10]:
# Predictions on test set
test_dataloader = accelerator.prepare(test_dataloader)
model_custom.eval()
test_labels = []
test_predictions = []
logits_list = []

with torch.no_grad():
    test_loss = 0
    for batch in tqdm.tqdm(test_dataloader):
        inputs_embeds = batch["embeddings"]
        attention_mask = batch["attention_mask"]
        labels = batch["ner_tag"]
            
        outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
        
        # Accuracy and loss cleaned
        logits_batch = outputs.logits
        loss = outputs.loss

        logits = logits_batch[labels != -100]
        label_clean = labels[labels != -100]

        predictions = logits.argmax(dim=1)
        acc = (predictions == label_clean).float().mean()

        test_labels.append(label_clean)
        test_predictions.append(predictions)
        logits_list.append(logits)

        test_loss += loss.item()

test_loss = test_loss / len(test_dataloader)
print("Test loss: ", test_loss)      
            

100%|██████████| 4/4 [00:00<00:00, 46.26it/s]

Test loss:  0.4901980496942997





In [11]:
# Save Results
results = {
    "test_labels": test_labels,
    "test_predictions": test_predictions,
    "logits": logits_list
}

torch.save(results, paths.RESULTS_PATH/'line_labelling/RoBERTA-seq2seq-finetuned-test.pt')

In [12]:
from sklearn.metrics import f1_score, recall_score, precision_score
# F1 score, recall and precision
test_labels = torch.cat(test_labels, dim=0)
test_predictions = torch.cat(test_predictions, dim=0)

print("F1 score: ", f1_score(test_labels.cpu(), test_predictions.cpu(), average='weighted'))
print("Recall: ", recall_score(test_labels.cpu(), test_predictions.cpu(), average='weighted'))
print("Precision: ", precision_score(test_labels.cpu(), test_predictions.cpu(), average='weighted'))

F1 score:  0.9137968328414687
Recall:  0.9159420289855073
Precision:  0.9165121245709806
