In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import DatasetDict, Dataset
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
import tqdm
import json

# Line Labelling Seq2Seq

To exploit the sequential structure of line labels, I will build a NER model on top of the embedded sentences. It will basically be a token level classification but each token represents a line in the report.

In [2]:
# Load dataset
dataset = DatasetDict.load_from_disk(paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset')

# Num Labels
num_labels = len(set(dataset['train']['class_agg']))

# train output of previous model
train_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-train_output.pt')
val_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-val_output.pt')
test_data = torch.load(paths.RESULTS_PATH/'line_labelling/medBERT-finetuned-test_output.pt')

In [3]:
# Preprocess data for token classification
def get_rid_dict(dataset, split:str = "train"):
    """
    Returns a dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values.

    Args:
        dataset (datasets.DatasetDict): DatasetDict containing the train, validation and test datasets.
        split (str, optional): Split of the data. Defaults to "train".

    Returns:
        dict: Dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values.
    """
    rids = dataset[split]['rid']
    rids_index = {}
    for i, rid in enumerate(rids):
        if rid not in rids_index:
            rids_index[rid] = [i]
        else:
            rids_index[rid].append(i)
    return rids_index

def preprocess_data(data:dict, rid_dict:dict, model: torch.nn.Module, tokenizer):
    """
    Preprocesses data for token classification.

    Args:
        data (dict of torch.Tensor): Embeddings of the reports. Contains ['embeddings', 'logits', 'labels'] in torch.Tensor format.
        rid_dict (dict): Dictionary containing the rids of the reports as keys and the corresponding index in the dataset as values. Used to group text lines.      
        model: Model used for token classification. Contains pretrained model and tokenizer.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for token classification.
    """

    # Hidden dimension of the model
    hidden_dim = model.config.hidden_size
    assert hidden_dim == data["embeddings"].shape[2], "Hidden dimension of the model and the embeddings do not match."

    # CLS and SEP token embeddings
    cls_embeds = model.get_input_embeddings()(torch.tensor([tokenizer.cls_token_id], device="cuda")).to("cpu")
    sep_embeds = model.get_input_embeddings()(torch.tensor([tokenizer.sep_token_id], device="cuda")).to("cpu")

    # Extracting embeddings for each report and concatenating them
    data_dict = {}

    # Need to pad input_embeddings to have the same number of sentences for each report
    # Longest report + 2 (CLS and SEP)
    max_sentences = max([len(rid_dict[rid]) for rid in rid_dict.keys()]) + 2
    print("Max sentences: ", max_sentences)

    for idx, rid in enumerate(rid_dict.keys()):
        data_dict[idx] = {}

        # Embeddings (note: by indexing we already get a tensor of shape (num_sentences, hidden_dim) with num_sentences the number of sentences in the report))
        data_dict[idx]["embeddings"] = data["embeddings"][rid_dict[rid], 0, :]

        # Append CLS and SEP embeddings at the beginning and at the end
        data_dict[idx]["embeddings"] = torch.cat((cls_embeds, data_dict[idx]["embeddings"], sep_embeds), dim=0)

        # Pad to max_sentences
        data_dict[idx]["embeddings"] = torch.nn.functional.pad(data_dict[idx]["embeddings"], (0, 0, 0, max_sentences - data_dict[idx]["embeddings"].shape[0]), "constant", 0)
        
        # NER Tag with padding
        data_dict[idx]["ner_tag"] = [torch.argmax(label, dim = 0).item() for label in data["labels"][rid_dict[rid]]]
        data_dict[idx]["ner_tag"] = [-100] + data_dict[idx]["ner_tag"] + [-100] + [-100]*(max_sentences - len(data_dict[idx]["ner_tag"]) - 2)
        data_dict[idx]["ner_tag"] = torch.LongTensor(data_dict[idx]["ner_tag"])

        # Attention mask
        data_dict[idx]["attention_mask"] = torch.ones(max_sentences)
        data_dict[idx]["attention_mask"][data_dict[idx]["ner_tag"] == -100] = 0

        # Token type ids
        data_dict[idx]["token_type_ids"] = torch.zeros(max_sentences, dtype=torch.long)

        # Clone and detach all tensors
        for key in data_dict[idx].keys():
            data_dict[idx][key] = data_dict[idx][key].clone().detach()

    return data_dict

In [4]:
# Model and tokenizer
model = AutoModelForTokenClassification.from_pretrained(paths.MODEL_PATH/'medbert-diag-label')
tokenizer = AutoTokenizer.from_pretrained(paths.MODEL_PATH/'medbert-512')

# Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [5]:
# Get rid_dict
train_rid_dict = get_rid_dict(dataset, split="train")
val_rid_dict = get_rid_dict(dataset, split="val")
test_rid_dict = get_rid_dict(dataset, split="test")

# Preprocess data
train_data_dict = preprocess_data(train_data, train_rid_dict, model, tokenizer)
val_data_dict = preprocess_data(val_data, val_rid_dict, model, tokenizer)
test_data_dict = preprocess_data(test_data, test_rid_dict, model, tokenizer)

Max sentences:  59
Max sentences:  37
Max sentences:  43


In [6]:
# Create dataset
class TokenLineLabelDataSet(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data.keys())

    def __getitem__(self, idx):
        return self.data[idx]

# Training, validation and test dataset
train_dataset = TokenLineLabelDataSet(train_data_dict)
val_dataset = TokenLineLabelDataSet(val_data_dict)
test_dataset = TokenLineLabelDataSet(test_data_dict)

# Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Training Arguments
batch_size = 8
learning_rate = 5e-6
epochs = 32
weight_decay = 0.01
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [14]:
# Training
# Seed
torch.manual_seed(42)

for epoch in range(epochs):

    # Training Loop
    print("Epoch: ", epoch)
    total_acc_train = 0
    total_loss_train = 0
    model.train()

    # Best accuracy and loss for best model
    best_acc = 0
    best_loss = 1000

    # Batch loop
    for batch in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()
        inputs_embeds = batch["embeddings"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        labels = batch["ner_tag"].to(device)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
        
        # Accuracy and loss cleaned
        logits = outputs.logits
        loss = outputs.loss
        logits_clean = logits[labels != -100]
        label_clean = labels[labels != -100]

        predictions = logits_clean.argmax(dim=1)
        acc = (predictions == label_clean).float().mean()

        total_acc_train += acc.item()
        total_loss_train += loss.item()
    
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    total_acc_val = 0
    total_loss_val = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader):
            inputs_embeds = batch["embeddings"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["ner_tag"].to(device)
            outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
            

            # Accuracy and loss cleaned
            logits = outputs.logits
            loss = outputs.loss

            logits_clean = logits[labels != -100]
            label_clean = labels[labels != -100]

            predictions = logits_clean.argmax(dim=1)
            acc = (predictions == label_clean).float().mean()

            total_acc_val += acc.item()
            total_loss_val += loss.item()

        val_accuracy = total_acc_val / len(val_dataloader)
        val_loss = total_loss_val / len(val_dataloader)

        avg_train_loss = total_loss_train / len(train_dataloader)
        avg_train_acc = total_acc_train / len(train_dataloader)

        avg_val_loss = total_loss_val / len(val_dataloader)
        avg_val_acc = total_acc_val / len(val_dataloader)

        print(
            f'Epochs: {epoch + 1} | Loss: {avg_train_loss:.3f} | Accuracy: {avg_train_acc:.3f} | Val_Loss: {avg_val_loss:.3f} | Accuracy: {avg_val_acc:.3f}')

        # Save model if validation loss is lower than previous validation loss
        if epoch == 0:
            min_val_loss = val_loss
        elif val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(model.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-medbert-min-loss.pt')
        elif val_accuracy < min_val_acc:
            min_val_acc = val_accuracy
            torch.save(model.state_dict(), paths.MODEL_PATH/'line-label-seq2seq-medbert-best-accuracy.pt')

Epoch:  0


  0%|          | 0/6 [00:00<?, ?it/s]

100%|██████████| 6/6 [00:01<00:00,  4.21it/s]
100%|██████████| 2/2 [00:00<00:00,  8.90it/s]


Epochs: 1 | Loss: 1.997 | Accuracy: 0.219 | Val_Loss: 1.843 | Accuracy: 0.384
Epoch:  1


100%|██████████| 6/6 [00:01<00:00,  5.65it/s]
100%|██████████| 2/2 [00:00<00:00, 13.15it/s]


Epochs: 2 | Loss: 1.701 | Accuracy: 0.440 | Val_Loss: 1.544 | Accuracy: 0.551
Epoch:  2


100%|██████████| 6/6 [00:01<00:00,  5.54it/s]
100%|██████████| 2/2 [00:00<00:00, 12.51it/s]


Epochs: 3 | Loss: 1.378 | Accuracy: 0.678 | Val_Loss: 1.266 | Accuracy: 0.727
Epoch:  3


100%|██████████| 6/6 [00:01<00:00,  5.61it/s]
100%|██████████| 2/2 [00:00<00:00, 10.66it/s]


Epochs: 4 | Loss: 1.084 | Accuracy: 0.784 | Val_Loss: 1.082 | Accuracy: 0.743
Epoch:  4


100%|██████████| 6/6 [00:01<00:00,  5.49it/s]
100%|██████████| 2/2 [00:00<00:00, 13.14it/s]


Epochs: 5 | Loss: 0.820 | Accuracy: 0.837 | Val_Loss: 0.968 | Accuracy: 0.774
Epoch:  5


100%|██████████| 6/6 [00:01<00:00,  5.42it/s]
100%|██████████| 2/2 [00:00<00:00, 13.92it/s]


Epochs: 6 | Loss: 0.650 | Accuracy: 0.881 | Val_Loss: 0.906 | Accuracy: 0.782
Epoch:  6


100%|██████████| 6/6 [00:01<00:00,  5.55it/s]
100%|██████████| 2/2 [00:00<00:00, 13.35it/s]


Epochs: 7 | Loss: 0.525 | Accuracy: 0.910 | Val_Loss: 0.893 | Accuracy: 0.790
Epoch:  7


100%|██████████| 6/6 [00:01<00:00,  5.62it/s]
100%|██████████| 2/2 [00:00<00:00, 11.19it/s]


Epochs: 8 | Loss: 0.426 | Accuracy: 0.935 | Val_Loss: 0.892 | Accuracy: 0.811
Epoch:  8


100%|██████████| 6/6 [00:01<00:00,  5.55it/s]
100%|██████████| 2/2 [00:00<00:00, 14.91it/s]


Epochs: 9 | Loss: 0.350 | Accuracy: 0.947 | Val_Loss: 0.883 | Accuracy: 0.811
Epoch:  9


100%|██████████| 6/6 [00:01<00:00,  5.67it/s]
100%|██████████| 2/2 [00:00<00:00, 10.63it/s]


Epochs: 10 | Loss: 0.306 | Accuracy: 0.945 | Val_Loss: 0.859 | Accuracy: 0.811
Epoch:  10


100%|██████████| 6/6 [00:01<00:00,  5.55it/s]
100%|██████████| 2/2 [00:00<00:00, 14.47it/s]


Epochs: 11 | Loss: 0.266 | Accuracy: 0.961 | Val_Loss: 0.873 | Accuracy: 0.811
Epoch:  11


100%|██████████| 6/6 [00:01<00:00,  5.60it/s]
100%|██████████| 2/2 [00:00<00:00, 11.43it/s]


Epochs: 12 | Loss: 0.239 | Accuracy: 0.959 | Val_Loss: 0.899 | Accuracy: 0.811
Epoch:  12


100%|██████████| 6/6 [00:01<00:00,  5.48it/s]
100%|██████████| 2/2 [00:00<00:00, 14.11it/s]


Epochs: 13 | Loss: 0.210 | Accuracy: 0.963 | Val_Loss: 0.902 | Accuracy: 0.811
Epoch:  13


100%|██████████| 6/6 [00:01<00:00,  5.47it/s]
100%|██████████| 2/2 [00:00<00:00, 14.51it/s]


Epochs: 14 | Loss: 0.183 | Accuracy: 0.970 | Val_Loss: 0.905 | Accuracy: 0.811
Epoch:  14


100%|██████████| 6/6 [00:01<00:00,  5.44it/s]
100%|██████████| 2/2 [00:00<00:00, 14.44it/s]


Epochs: 15 | Loss: 0.182 | Accuracy: 0.963 | Val_Loss: 0.916 | Accuracy: 0.811
Epoch:  15


100%|██████████| 6/6 [00:01<00:00,  5.43it/s]
100%|██████████| 2/2 [00:00<00:00, 12.06it/s]


Epochs: 16 | Loss: 0.175 | Accuracy: 0.965 | Val_Loss: 0.927 | Accuracy: 0.811
Epoch:  16


100%|██████████| 6/6 [00:01<00:00,  5.47it/s]
100%|██████████| 2/2 [00:00<00:00, 12.76it/s]


Epochs: 17 | Loss: 0.165 | Accuracy: 0.965 | Val_Loss: 0.948 | Accuracy: 0.811
Epoch:  17


100%|██████████| 6/6 [00:01<00:00,  5.36it/s]
100%|██████████| 2/2 [00:00<00:00, 11.77it/s]


Epochs: 18 | Loss: 0.155 | Accuracy: 0.972 | Val_Loss: 0.948 | Accuracy: 0.811
Epoch:  18


100%|██████████| 6/6 [00:01<00:00,  5.49it/s]
100%|██████████| 2/2 [00:00<00:00, 13.18it/s]


Epochs: 19 | Loss: 0.141 | Accuracy: 0.974 | Val_Loss: 0.941 | Accuracy: 0.809
Epoch:  19


100%|██████████| 6/6 [00:01<00:00,  5.41it/s]
100%|██████████| 2/2 [00:00<00:00, 12.39it/s]


Epochs: 20 | Loss: 0.136 | Accuracy: 0.973 | Val_Loss: 0.947 | Accuracy: 0.809
Epoch:  20


100%|██████████| 6/6 [00:01<00:00,  5.54it/s]
100%|██████████| 2/2 [00:00<00:00, 10.23it/s]


Epochs: 21 | Loss: 0.126 | Accuracy: 0.977 | Val_Loss: 0.977 | Accuracy: 0.809
Epoch:  21


100%|██████████| 6/6 [00:01<00:00,  5.42it/s]
100%|██████████| 2/2 [00:00<00:00, 10.46it/s]


Epochs: 22 | Loss: 0.122 | Accuracy: 0.976 | Val_Loss: 1.005 | Accuracy: 0.809
Epoch:  22


100%|██████████| 6/6 [00:01<00:00,  5.37it/s]
100%|██████████| 2/2 [00:00<00:00, 13.04it/s]


Epochs: 23 | Loss: 0.119 | Accuracy: 0.976 | Val_Loss: 0.999 | Accuracy: 0.809
Epoch:  23


100%|██████████| 6/6 [00:01<00:00,  5.50it/s]
100%|██████████| 2/2 [00:00<00:00, 13.74it/s]


Epochs: 24 | Loss: 0.115 | Accuracy: 0.976 | Val_Loss: 1.021 | Accuracy: 0.811
Epoch:  24


100%|██████████| 6/6 [00:01<00:00,  5.51it/s]
100%|██████████| 2/2 [00:00<00:00,  9.76it/s]


Epochs: 25 | Loss: 0.112 | Accuracy: 0.980 | Val_Loss: 1.016 | Accuracy: 0.814
Epoch:  25


100%|██████████| 6/6 [00:01<00:00,  5.37it/s]
100%|██████████| 2/2 [00:00<00:00, 11.42it/s]


Epochs: 26 | Loss: 0.103 | Accuracy: 0.977 | Val_Loss: 1.019 | Accuracy: 0.814
Epoch:  26


100%|██████████| 6/6 [00:01<00:00,  5.50it/s]
100%|██████████| 2/2 [00:00<00:00, 10.62it/s]


Epochs: 27 | Loss: 0.104 | Accuracy: 0.981 | Val_Loss: 1.032 | Accuracy: 0.814
Epoch:  27


100%|██████████| 6/6 [00:01<00:00,  5.40it/s]
100%|██████████| 2/2 [00:00<00:00, 10.56it/s]


Epochs: 28 | Loss: 0.094 | Accuracy: 0.981 | Val_Loss: 1.028 | Accuracy: 0.814
Epoch:  28


100%|██████████| 6/6 [00:01<00:00,  5.50it/s]
100%|██████████| 2/2 [00:00<00:00, 12.25it/s]


Epochs: 29 | Loss: 0.091 | Accuracy: 0.981 | Val_Loss: 1.039 | Accuracy: 0.814
Epoch:  29


100%|██████████| 6/6 [00:01<00:00,  5.39it/s]
100%|██████████| 2/2 [00:00<00:00, 11.68it/s]


Epochs: 30 | Loss: 0.105 | Accuracy: 0.974 | Val_Loss: 1.040 | Accuracy: 0.811
Epoch:  30


100%|██████████| 6/6 [00:01<00:00,  5.18it/s]
100%|██████████| 2/2 [00:00<00:00, 13.45it/s]


Epochs: 31 | Loss: 0.095 | Accuracy: 0.981 | Val_Loss: 1.026 | Accuracy: 0.811
Epoch:  31


100%|██████████| 6/6 [00:01<00:00,  5.48it/s]
100%|██████████| 2/2 [00:00<00:00, 11.52it/s]

Epochs: 32 | Loss: 0.097 | Accuracy: 0.979 | Val_Loss: 1.020 | Accuracy: 0.811





In [7]:
# Load best model
model.load_state_dict(torch.load(paths.MODEL_PATH/'line-label-seq2seq.pt'))

<All keys matched successfully>

In [26]:
# Preparing Results
model.eval()
test_labels = []
test_logits = []
test_embeddings = []
with torch.no_grad():
    test_loss = 0
    for batch in tqdm.tqdm(test_dataloader):
        inputs_embeds = batch["embeddings"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        labels = batch["ner_tag"].to(device)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, output_hidden_states=True)
        
        # Accuracy and loss cleaned
        logits_clean = outputs.logits[labels != -100].to("cpu")
        label_clean = labels[labels != -100].to("cpu")
        embeddings_clean = outputs.hidden_states[-1][labels != -100].to("cpu")

        test_labels.append(label_clean)
        test_logits.append(logits_clean)
        test_embeddings.append(embeddings_clean)
            

100%|██████████| 2/2 [00:01<00:00,  1.90it/s]


In [27]:
# For comparability: saving results in a dictionary with keys "embeddings" torch.tensor(n_test_samples, seq_length, n_classes), 
# "logits" torch.tensor(n_test_samples, n_classes) and true OHE "labels" (n_test_samples, n_classes).

# test_labels need to be OHE
test_labels = torch.cat(test_labels, dim=0)
test_labels = torch.zeros((test_labels.shape[0], num_labels)).scatter_(-1, test_labels.unsqueeze(-1), 1)

# Test Logits
test_logits = torch.cat(test_logits, dim=0)

# Test Embeddings consist of 1 token, results expects sequence, thus unsqueeze at dim 1
test_embeddings = torch.cat(test_embeddings, dim=0).unsqueeze(1)

# Save results
results = {"embeddings": test_embeddings, "logits": test_logits, "labels": test_labels}
torch.save(results, paths.RESULTS_PATH/'line_labelling/medbert-token-finetuned-test_output.pt')