In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import DatasetDict
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import numpy as np
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
import tqdm
import accelerate

# Line Labelling Seq2Seq

To exploit the sequential structure of the line labels, I will build an NER-style model on top of the embedded sentences. It is essentially a token-level classification in which each token represents one line of the report.

In [None]:
# Load the cleaned line-labelling dataset (one entry per split).
dataset = DatasetDict.load_from_disk(paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset')

# Number of distinct aggregated line classes found in the training split.
num_labels = len(set(dataset['train']['class_agg']))

# Outputs of the previous (fine-tuned medBERT) model for every split; each file is
# indexed below via the keys 'embeddings' and 'labels'.
# map_location="cpu" makes the files loadable regardless of the device they were
# saved from; the batches are moved to the training device by the prepared dataloaders.
# NOTE: torch.load unpickles the file - only load result files produced by this project.
train_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-train_output.pt', map_location="cpu")
val_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-val_output.pt', map_location="cpu")
test_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-test_output.pt', map_location="cpu")

In [None]:
# Preprocess data for token classification
def get_rid_dict(dataset, split:str = "train"):
    """
    Groups the row positions of one dataset split by report id (rid).

    Args:
        dataset (datasets.DatasetDict): Mapping from split name to a table that exposes a 'rid' column.
        split (str, optional): Name of the split to index. Defaults to "train".

    Returns:
        dict: Maps every rid to the list of row positions at which it occurs. Positions are in
            ascending order and the rids appear in order of first occurrence.
    """
    index_by_rid = {}
    for position, report_id in enumerate(dataset[split]['rid']):
        # setdefault creates the list the first time a rid is seen
        index_by_rid.setdefault(report_id, []).append(position)
    return index_by_rid

def preprocess_data(data:dict, rid_dict:dict):
    """
    Turns line-level model outputs into one padded sequence per report.

    Args:
        data (dict of torch.Tensor): Outputs of the previous model. The keys 'embeddings' (indexed as
            [line, 0, :], i.e. the first position of every line) and 'labels' (one row of class scores
            per line) are read.
        rid_dict (dict): Maps every rid to the row positions of its lines. Used to group text lines.

    Returns:
        dict: Maps a running report number (0, 1, ...) to a dictionary with
            "embeddings" (line embeddings, zero-padded along the line axis),
            "ner_tag" (class index per line as LongTensor, padded with -100) and
            "attention_mask" (1 for real lines, 0 for padding).
    """
    # Every report is padded to the length of the longest report of this split
    longest = max(len(rows) for rows in rid_dict.values())
    print("Max sentences: ", longest)

    reports = {}
    for report_no, rows in enumerate(rid_dict.values()):
        missing = longest - len(rows)

        # First-position embedding of every line; append zero rows up to the common length
        line_embeddings = torch.nn.functional.pad(
            data["embeddings"][rows, 0, :], (0, 0, 0, missing), "constant", 0
        )

        # Class index per line, followed by the padding label -100
        tags = torch.LongTensor(data["labels"][rows].argmax(dim=1).tolist() + [-100] * missing)

        # Real lines -> 1, padded positions -> 0
        mask = (tags != -100).to(torch.get_default_dtype())

        # Store independent copies without autograd history
        reports[report_no] = {
            "embeddings": line_embeddings.clone().detach(),
            "ner_tag": tags.clone().detach(),
            "attention_mask": mask.clone().detach(),
        }

    return reports

In [None]:
# Model
# The pretrained weights are read from a local directory below paths.MODEL_PATH.
# (Hub checkpoint name kept for reference: "xlm-roberta-large")

# Token-classification model with one output class per line label (num_labels).
model = AutoModelForTokenClassification.from_pretrained(paths.MODEL_PATH/'roberta-xlm-large', num_labels=num_labels)
# No tokenizer is loaded: the model is later called with `inputs_embeds`
# (pre-computed line embeddings) instead of token ids.

In [None]:
# Row positions grouped by report id, one dictionary per split
train_rid_dict, val_rid_dict, test_rid_dict = (
    get_rid_dict(dataset, split=split_name)
    for split_name in ("train", "val", "test")
)

# One padded sequence of line embeddings / labels per report, again per split
train_data_dict, val_data_dict, test_data_dict = (
    preprocess_data(split_data, split_rid_dict)
    for split_data, split_rid_dict in (
        (train_data, train_rid_dict),
        (val_data, val_rid_dict),
        (test_data, test_rid_dict),
    )
)

In [None]:
class RobertaXMLModel(torch.nn.Module):
    """
    Token classifier that runs a pretrained Roberta model on pre-computed embeddings.

    A linear layer followed by a ReLU projects the incoming embeddings to the hidden size of the
    wrapped model; the projected sequence is then passed to the model as `inputs_embeds`.
    """
    def __init__(self, model, input_dim:int, device:torch.device):
        """
        Args:
            model: Pretrained token-classification model; must expose `config.hidden_size`.
            input_dim (int): Size of the incoming embeddings.
            device (torch.device): Device the wrapped model and the projection are moved to.
        """
        super().__init__()
        self.device = device

        # Wrapped model first, then the projection (same registration order as before)
        self.model = model
        self.model.to(device)

        hidden_size = self.model.config.hidden_size
        self.embed_mapping = torch.nn.Linear(input_dim, hidden_size).to(device)
        self.activation = torch.nn.ReLU()  # parameter-free, nothing to move to the device

    def forward(self, inputs_embeds, attention_mask, labels=None, output_hidden_states=False):
        """
        Projects the embeddings to the hidden size of the wrapped model and runs the model on them.

        Returns:
            Whatever the wrapped model returns for the given arguments.
        """
        # Linear projection + ReLU: incoming embedding size -> model hidden size
        projected = self.embed_mapping(inputs_embeds)
        projected = self.activation(projected)

        return self.model(
            inputs_embeds=projected,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=output_hidden_states,
        )
        

In [None]:
# Accelerator (fp16 mixed precision, gradient accumulation over 2 steps) and the device it selects
accelerator = accelerate.Accelerator(gradient_accumulation_steps=2, mixed_precision='fp16')
device = accelerator.device

# Wrapper model: projects the incoming embeddings (size of their last axis) to the Roberta hidden size
model_custom = RobertaXMLModel(
    model=model,
    input_dim=train_data["embeddings"].size(2),
    device=device,
)

In [None]:
# Training Arguments
# batch_size counts dataset items per batch; the datasets built below hold one padded report per item.
batch_size = 4
learning_rate = 5e-6
# Number of passes over the training dataloader.
epochs = 32
weight_decay = 0.01
# Adam optimises every parameter of the wrapper: the Roberta model and the linear input projection.
optimizer = Adam(model_custom.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Create dataset
class TokenLineLabelDataSet(Dataset):
    """
    Map-style dataset around a dictionary of samples.

    `data` maps an integer position to one sample; the sample is returned unchanged.
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        # Number of stored samples
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

# Wrap the pre-processed reports of every split in a dataset
train_dataset, val_dataset, test_dataset = (
    TokenLineLabelDataSet(reports)
    for reports in (train_data_dict, val_data_dict, test_data_dict)
)

# Dataloaders: only the training batches are shuffled, evaluation keeps the report order
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader, test_dataloader = (
    DataLoader(eval_dataset, shuffle=False, batch_size=batch_size)
    for eval_dataset in (val_dataset, test_dataset)
)

# Hand model, optimizer and the train/val dataloaders to the accelerator
# (the test dataloader is not prepared here)
model_custom, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    model_custom, optimizer, train_dataloader, val_dataloader
)

In [None]:
# Training
# Seed
torch.manual_seed(42)

# Best validation metrics seen so far. They are initialised once, before the epoch
# loop, so that the very first epoch already writes both checkpoints.
min_val_loss = float("inf")
min_val_acc = float("-inf")  # despite the name: highest validation accuracy so far

for epoch in range(epochs):

    # ---- Training ----
    print("Epoch: ", epoch)
    model_custom.train()
    total_loss_train = 0
    correct_train = 0
    lines_train = 0

    for batch in tqdm.tqdm(train_dataloader):
        inputs_embeds = batch["embeddings"]
        attention_mask = batch["attention_mask"]
        labels = batch["ner_tag"]

        # The accelerator is configured with gradient accumulation. Inside `accumulate`,
        # step() and zero_grad() only take effect on synchronisation steps, so the
        # gradients of the intermediate batches are kept. zero_grad() therefore has to
        # come after step(), not before the forward pass.
        with accelerator.accumulate(model_custom):
            outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        # Metrics on real lines only (padded positions carry the label -100)
        valid = labels != -100
        predictions = outputs.logits[valid].argmax(dim=1)
        correct_train += (predictions == labels[valid]).sum().item()
        lines_train += valid.sum().item()
        total_loss_train += loss.item()

    # ---- Validation ----
    model_custom.eval()
    total_loss_val = 0
    correct_val = 0
    lines_val = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader):
            inputs_embeds = batch["embeddings"]
            attention_mask = batch["attention_mask"]
            labels = batch["ner_tag"]

            outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

            valid = labels != -100
            predictions = outputs.logits[valid].argmax(dim=1)
            correct_val += (predictions == labels[valid]).sum().item()
            lines_val += valid.sum().item()
            total_loss_val += outputs.loss.item()

    # Loss: mean over batches. Accuracy: share of correctly labelled lines over the
    # whole split, so every line has the same weight regardless of its batch.
    avg_train_loss = total_loss_train / len(train_dataloader)
    avg_train_acc = correct_train / max(lines_train, 1)

    avg_val_loss = total_loss_val / len(val_dataloader)
    avg_val_acc = correct_val / max(lines_val, 1)

    print(
        f'Epochs: {epoch + 1} | Loss: {avg_train_loss:.3f} | Accuracy: {avg_train_acc:.3f} | Val_Loss: {avg_val_loss:.3f} | Accuracy: {avg_val_acc:.3f}')

    # Save the model whenever the validation loss / accuracy improves
    if avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        torch.save(model_custom.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-roberta-best-loss.pt')
    if avg_val_acc > min_val_acc:
        min_val_acc = avg_val_acc
        torch.save(model_custom.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-roberta-best-accuracy.pt')

In [None]:
# Restore the weights that the training loop saved for the lowest validation loss.
# NOTE(review): torch.load unpickles the file - only load checkpoints written by this notebook.
model_custom.load_state_dict(torch.load(paths.MODEL_PATH/'line-label-seq2seq-roberta-best-loss.pt'))

In [None]:
# Predictions on test set
# NOTE: this cell re-binds `test_dataloader` to its prepared version; run it once per session.
test_dataloader = accelerator.prepare(test_dataloader)
model_custom.eval()

# Per-batch results, restricted to real (non-padded) lines and moved to the CPU
test_labels = []
test_logits = []
test_embeddings = []

with torch.no_grad():
    for batch in tqdm.tqdm(test_dataloader):
        inputs_embeds = batch["embeddings"]
        attention_mask = batch["attention_mask"]
        labels = batch["ner_tag"]

        outputs = model_custom(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True)

        # Padded positions carry the label -100; compute the mask once and reuse it
        valid = labels != -100

        test_labels.append(labels[valid].to("cpu"))
        test_logits.append(outputs.logits[valid].to("cpu"))
        # Last entry of hidden_states = output of the final layer of the wrapped model
        test_embeddings.append(outputs.hidden_states[-1][valid].to("cpu"))

In [None]:
# For comparability with the other models, the results are stored as a dictionary with
#   "embeddings": (n_test_lines, 1, hidden_dim)  last hidden state of every line,
#   "logits":     (n_test_lines, n_classes),
#   "labels":     (n_test_lines, n_classes)      one-hot encoded true labels.

# Labels: concatenate the per-batch class indices, then one-hot encode them
test_labels = torch.cat(test_labels, dim=0)
test_labels = torch.nn.functional.one_hot(test_labels, num_classes=num_labels).float()

# Logits: one row per line
test_logits = torch.cat(test_logits, dim=0)

# Embeddings: one vector per line; the result format expects a sequence axis,
# so a singleton axis is inserted at dim 1
test_embeddings = torch.cat(test_embeddings, dim=0)
test_embeddings = test_embeddings.unsqueeze(1)

# Save results
results = {
    "embeddings": test_embeddings,
    "logits": test_logits,
    "labels": test_labels,
}
torch.save(results, paths.RESULTS_PATH/'line_labelling/RoBERTA-seq2seq-finetuned-test-loss.pt')