In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer
from datasets import DatasetDict, load_dataset, Dataset, concatenate_datasets
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import numpy as np
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
from src.utils import (load_model_and_tokenizer, 
                       load_line_label_token_data, 
                       line_label_token_label2id, 
                       line_label_token_id2label, 
                       tokenize_and_align_labels,
)
import tqdm
import accelerate
import pandas as pd
import itertools
import evaluate
from collections import Counter

In [2]:
# Load the line-label token dataset through the project helper imported from src.utils.
# Later cells index the result by split name (e.g. "test").
dataset_token = load_line_label_token_data()

In [4]:
# Load the model/tokenizer pair by name for the "token" task type.
# `tokenizer` is also read as a module-level global by get_results_from_token_preds.
model, tokenizer = load_model_and_tokenizer(model_name="line-label-token_medbert-512_finetuned_512",
                                            task_type="token",
                                            )

Tokenizer pad token ID: 0
Tokenizer special tokens: {'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[BRK]']}
Model pad token ID: 0


In [5]:
# Run tokenize_and_align_labels over the dataset in batched mode, handing it the
# tokenizer as a keyword argument. The result must expose "input_ids" and
# "line_label" for get_results_from_token_preds.
encoded_dataset = dataset_token.map(tokenize_and_align_labels, batched=True, fn_kwargs={"tokenizer": tokenizer})

Map:   0%|          | 0/55 [00:00<?, ? examples/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

In [9]:
# Only the model and a token-classification collator are supplied; no
# TrainingArguments are passed. In the visible cells this trainer is used
# solely for `predict`, not for training.
trainer = Trainer(model=model,
                    data_collator=DataCollatorForTokenClassification(tokenizer),
                    )

In [27]:
# Run inference on the test split. The result unpacks into three values; only
# `predictions` is consumed by later visible cells (it is argmax-ed over axis 2
# inside get_results_from_token_preds).
predictions, labels, metrics = trainer.predict(encoded_dataset["test"])

In [38]:
def majority_vote(current_predictions: list):
    """Pick the most frequent class among the token predictions of one line.

    Every prediction id is looked up in ``line_label_token_id2label`` and the
    first two characters of the label (the ``B-`` / ``I-`` prefix) are cut off,
    so both variants of a class are counted together.

    Args:
        current_predictions (list): predicted label ids for the tokens of a
            single line.

    Returns:
        The prefix-stripped label with the highest count. If several classes
        share the highest count, the one that occurs first in
        ``current_predictions`` wins.

    Raises:
        ValueError: if ``current_predictions`` is empty (raised by ``max``).
    """
    tally = Counter()
    for label_id in current_predictions:
        # Drop the two-character B-/I- prefix before counting.
        tally[line_label_token_id2label[label_id][2:]] += 1

    # A Counter iterates in first-insertion order and ``max`` keeps the first
    # maximal element it meets, so ties resolve to the earliest-seen class.
    return max(tally, key=tally.get)
    

def group_labels_by_text(tokenized_texts, predictions):
    """Collapse token-level predictions into one label per text line.

    Tokens and predictions are walked in lockstep. Predictions are buffered
    until a ``[BRK]`` token is reached; the buffer is then reduced to a single
    label with ``majority_vote`` and emptied for the next line.

    Args:
        tokenized_texts: token strings of one example.
        predictions: predicted label ids, paired position by position with
            ``tokenized_texts``.

    Returns:
        list: one voted label per ``[BRK]`` token, in order of appearance.

    Notes:
        * ``zip`` stops at the shorter input, so surplus predictions (or
          tokens) at the end are ignored.
        * The prediction at the ``[BRK]`` position itself does not take part
          in the vote.
        * Tokens that follow the last ``[BRK]`` never produce a label.
        * Every token other than ``[BRK]`` contributes its prediction; nothing
          here filters out other special tokens.
        * A ``[BRK]`` with no buffered predictions passes an empty list to
          ``majority_vote``.
    """
    voted_labels = []
    pending = []

    for tok, pred in zip(tokenized_texts, predictions):
        if tok != '[BRK]':
            # Ordinary token: remember its prediction for the current line.
            pending.append(pred)
            continue

        # Line break reached: settle the vote for the line and start afresh.
        voted_labels.append(majority_vote(pending))
        pending = []

    return voted_labels

def get_results_from_token_preds(predictions:np.ndarray,
                                 dataset:DatasetDict,
                                 split:str="test"):
    """Turn token-level model outputs into flat lists of line predictions and line labels.

    For every example of ``dataset[split]`` the input ids are converted back to
    tokens (using the module-level ``tokenizer``), the per-token class ids are
    grouped into lines with ``group_labels_by_text`` and the resulting line
    predictions are appended to one flat list. The example's ``line_label``
    entries are appended to a second flat list, cut to the number of predicted
    lines so that lines lost to truncation do not shift the alignment.

    Args:
        predictions (np.ndarray): shape (n_samples, max_len, n_labels); the
            class id per token is taken as the argmax over the last axis.
        dataset (DatasetDict): must contain "input_ids" and "line_label" for
            the requested split. "line_label" holds one label per line.
        split (str, optional): split the predictions were computed on.
            Defaults to "test".

    Returns:
        tuple: ``(preds, labs)`` - predicted line labels and reference line
        labels, concatenated over all examples of the split.
    """
    token_class_ids = np.argmax(predictions, axis=2)
    # Resolve the split once; both tokens and labels must come from the same split.
    split_data = dataset[split]

    preds, labs = [], []
    for i in range(len(split_data)):
        example = split_data[i]
        tokens = tokenizer.convert_ids_to_tokens(example["input_ids"])
        line_preds = group_labels_by_text(tokens, token_class_ids[i])
        preds.extend(line_preds)
        # Fix: labels were previously always read from dataset["test"], ignoring
        # `split`. Only keep as many labels as lines were predicted (truncation).
        labs.extend(example["line_label"][:len(line_preds)])

    return preds, labs

In [39]:
# Reduce the token-level predictions of the test split to per-line predictions
# and collect the matching line labels from the dataset.
preds, labs = get_results_from_token_preds(predictions, encoded_dataset, split="test")

In [37]:
# Number of examples in the encoded test split.
len(encoded_dataset["test"])

15

In [42]:
# NOTE(review): `args` is not referenced by any other visible cell - looks like
# leftover scratch; confirm before removing.
args = ["bli", "bla"]


In [46]:
# Inspect the PEFT configuration stored under the given model directory; as the
# last expression of the cell, the loaded config object is what gets displayed.
# NOTE(review): this import sits mid-notebook rather than in the import cell.
from peft import PeftConfig, PeftModel
PeftConfig.from_pretrained(paths.MODEL_PATH/"line-label_medbert-512_4bit_LORA_token_finetuned")

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='/mnt/c/Users/marc_/OneDrive/ETH/MSC_Thesis/inf-extr/resources/models/medbert-512', revision=None, task_type='TOKEN_CLS', inference_mode=True, r=16, target_modules={'value', 'query'}, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={})

In [40]:
# Per-class precision / recall / F1 of the line-level predictions, with `labs`
# passed as the reference labels and `preds` as the predictions.
# NOTE(review): this import sits mid-notebook rather than in the import cell.
from sklearn.metrics import classification_report
print(classification_report(labs, preds))

                precision    recall  f1-score   support

            dm       0.87      1.00      0.93        13
          head       1.00      0.98      0.99        44
    his_sym_cu       0.99      0.99      0.99        77
     labr_labo       1.00      1.00      1.00        32
         medms       1.00      1.00      1.00        16
medo_unk_do_so       1.00      0.55      0.71        11
            mr       0.96      1.00      0.98        44
         to_tr       0.86      1.00      0.92        12

      accuracy                           0.97       249
     macro avg       0.96      0.94      0.94       249
  weighted avg       0.97      0.97      0.97       249

