In [1]:
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from transformers.pipelines.pt_utils import KeyDataset
from peft import PeftModel, PeftConfig
from dotenv import dotenv_values
import torch
from tqdm.auto import tqdm

from utils import DataPreprocessor
from utils import data_format_converter

from utils import DataPreprocessor

# Read the local secrets file once and pull every credential from it
# (credentials stay out of the notebook itself).
env = dotenv_values(".env.base")
WANDB_KEY = env['WANDB_KEY']
LLAMA_TOKEN = env['LLAMA_TOKEN']
HF_TOKEN = env['HF_TOKEN']
HF_TOKEN_WRITE = env['HF_TOKEN_WRITE']

# Fine-tuned adapters; the base model checkpoint is taken from their PEFT config.
adapters = "ferrazzipietro/LS_Mistral-7B-v0.1_adapters_en.layer1_NoQuant_16_32_0.01_2_0.0002_nEpochs3"
peft_config = PeftConfig.from_pretrained(adapters, token=HF_TOKEN)
BASE_MODEL_CHECKPOINT = peft_config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_CHECKPOINT, token=HF_TOKEN)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

DATASET_CHEKPOINT = "ferrazzipietro/e3c-sentences"
TRAIN_LAYER = "en.layer1"
preprocessor = DataPreprocessor(BASE_MODEL_CHECKPOINT, tokenizer)

# Add download_mode="force_redownload" here to bypass the local cache.
dataset = load_dataset(DATASET_CHEKPOINT)
dataset = dataset[TRAIN_LAYER]
dataset = dataset.shuffle(seed=1234)

# Convert entity offsets into word-level "tokens"/"ner_tags" columns.
dataset_format_converter_obj = data_format_converter.DatasetFormatConverter(dataset)
dataset_format_converter_obj.apply()
ds = dataset_format_converter_obj.dataset
label2id = dataset_format_converter_obj.label2id
id2label = dataset_format_converter_obj.get_id2label()
label_list = dataset_format_converter_obj.get_label_list()
dataset_format_converter_obj.set_tokenizer(tokenizer)
dataset_format_converter_obj.set_max_seq_length(256)
# Tokenize and propagate word labels to sub-word tokens, batch by batch.
tokenized_ds = ds.map(dataset_format_converter_obj.tokenize_and_align_labels, batched=True)
_, val_data, _ = preprocessor.split_layer_into_train_val_test_(tokenized_ds, TRAIN_LAYER)


  from .autonotebook import tqdm as notebook_tqdm
Map: 100%|██████████| 1520/1520 [00:00<00:00, 2886.61 examples/s]
Map:   0%|          | 0/1520 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 1520/1520 [00:00<00:00, 13592.09 examples/s]

tokenized_inputs:  {'input_ids': [[415, 2903, 302, 272, 367, 5728, 654, 5278, 354, 399, 3384, 2948, 298, 318, 1087, 28735, 28733, 7170, 28790, 28733, 28750, 28723], [6723, 1502, 14178, 7312, 438, 272, 3595, 302, 28705, 28740, 28781, 3370, 28723], [1094, 534, 2920, 1475, 9271, 12423, 687, 20976, 403, 5745, 390, 4123, 28723], [21127, 277, 4475, 8289, 302, 272, 16594, 28765, 6642, 369, 23096, 8894, 654, 5278, 354, 264, 9194, 302, 1581, 6752, 1716, 404, 325, 5072, 28731, 325, 5072, 28770, 28781, 28725, 8204, 28740, 28774, 28725, 8204, 28740, 28734, 28725, 8204, 28750, 28750, 304, 19966, 5278, 354, 8204, 28781, 28782, 28731, 8735, 288, 334, 7016, 9237, 3000, 678, 628, 305, 1082, 721, 806, 23096, 297, 2088, 434, 352, 28723], [2584, 4242, 6403, 403, 28705, 28740, 28734, 28734, 28748, 28787, 28734, 6020, 28769, 28721, 395, 264, 20984, 4338, 302, 28705, 28774, 28783, 347, 1449, 28748, 1240, 28725, 10840, 361, 5377, 4338, 684, 28705, 28740, 28784, 28748, 1240, 304, 21210, 7641, 302, 28705, 28770




In [3]:
import string

class DatasetFormatConverter():
    """
    Convert a sentence dataset with character-offset entities into word-level
    and token-level label id sequences using a B/I/O scheme.

    Each example is indexed as ``example["sentence"]`` (a string) and
    ``example["entities"]`` (a list of dicts with ``"offsets"`` and ``"type"``).
    Offsets are treated as ``[start, end)`` character positions in the sentence.
    NOTE(review): the schema is inferred from how the methods index examples;
    confirm against the dataset card.

    Entity types are used only for the intermediate character labels; the
    produced word/token labels collapse every type onto the ids in ``label2id``.
    """
    def __init__(self, dataset):
        self.dataset = dataset
        self.label2id = { "O": 0, "B": 1, "I": 2}

    def get_id2label(self):
        """Return the inverse of ``label2id`` (label id -> label string)."""
        id2label = {v: k for k, v in self.label2id.items()}
        return id2label

    def get_label2id(self):
        """Return the label string -> label id mapping."""
        return self.label2id

    def get_label_list(self):
        """Return the label strings in ``label2id`` insertion order."""
        return list(self.label2id.keys())

    def _reformat_entities_dict(self, enitities_dicts_list):
        """Map each entity dict to a single-entry ``{text: offsets}`` dict."""
        return [{item.get('text') : item.get('offsets')} for item in enitities_dicts_list]

    def _generate_char_based_labels_list(self, example):
        """
        Build one label string per sentence character.

        The first character of each entity is ``"B-<type>"``, the remaining
        characters up to (excluding) ``end`` are ``"I-<type>"``, and everything
        else is ``"O"``.
        """
        labels = ["O"] * len(example["sentence"])
        for entity in example['entities']:
            start = entity["offsets"][0]
            end = entity["offsets"][1]
            entity_type = entity["type"]
            labels[start] = f"B-{entity_type}"
            for i in range(start+1, end):
                labels[i] = f"I-{entity_type}"
        return labels

    def _contains_punctuation(self, word):
        """True if any character of ``word`` is in ``string.punctuation``."""
        return any(char in string.punctuation for char in word)

    def _is_only_punctuation(self, word):
        """True if every character of ``word`` is in ``string.punctuation``."""
        return all(char in string.punctuation for char in word)

    def _remove_punctuation_and_count(self, text, punctuation_to_remove = '!"#&\'(),-./:;<=>?@[\\]^_`|'):
        """
        Remove punctuation from the beginning and end of the text and count how many characters were removed.

        Returns ``(stripped_text, n_removed_at_start, n_removed_at_end)``.
        """
        count_beginning = len(text) - len(text.lstrip(punctuation_to_remove))
        count_end = len(text) - len(text.rstrip(punctuation_to_remove))
        word_no_punct = text.strip(punctuation_to_remove)
        return word_no_punct, count_beginning, count_end

    def _entities_from_dict_to_labels_list(self, example, word_level=True, token_level=False, tokenizer=None):
        """
        Add whitespace-split words and their word-level label ids to ``example``.

        Leading and trailing punctuation is stripped from each word. The word is
        then labelled with the B or I id when its first remaining character
        carries a ``B-``/``I-`` character label and every following character
        carries an ``I-`` label. Otherwise it is labelled O. Words made only of
        punctuation are always labelled O.

        Returns the same ``example`` with ``"words"`` and ``"word_level_labels"``
        keys added.

        Raises:
            ValueError: if ``word_level``/``token_level``/``tokenizer`` are inconsistent.
            NotImplementedError: if ``token_level`` is True.
        """
        if word_level and token_level:
            raise ValueError("Only one of word_level and token_level can be True")
        if not word_level and not token_level:
            raise ValueError("One of word_level and token_level must be True")
        if token_level and tokenizer is None:
            raise ValueError("tokenizer must be provided if token_level is True")
        if word_level:
            words = example["sentence"].split()
        elif token_level:
            raise NotImplementedError
        sentence = example["sentence"]
        labels = [self.label2id["O"]] * len(words)
        chars_based_labels = self._generate_char_based_labels_list(example)
        search_from = 0
        for i, word in enumerate(words):
            # Locate the word's real offset. str.split() collapses whitespace runs
            # and drops leading whitespace, so assuming a single separator
            # (len(word) + 1) drifts away from the entity character offsets.
            word_starting_position = sentence.find(word, search_from)
            search_from = word_starting_position + len(word)
            if self._is_only_punctuation(word):
                continue
            if self._contains_punctuation(word):
                _, count_beginning, count_end = self._remove_punctuation_and_count(word)
            else:
                count_beginning, count_end = 0, 0
            start_word = word_starting_position + count_beginning
            end_word = word_starting_position + len(word) - count_end
            chars_labels_of_this_word = chars_based_labels[start_word:end_word]
            first_label = chars_labels_of_this_word[0]
            if first_label.startswith(("B-", "I-")) \
                and all(label.startswith("I-") for label in chars_labels_of_this_word[1:]):
                # first_label[0] is "B" or "I", which indexes label2id directly.
                labels[i] = self.label2id.get(first_label[0], -1)
        example['words'] = words
        example['word_level_labels'] = labels
        return example

    def apply(self):
        """Add word-level columns to ``self.dataset``, exposed as "tokens" and "ner_tags"."""
        self.dataset = self.dataset.map(self._entities_from_dict_to_labels_list)
        self.dataset = self.dataset.rename_column("word_level_labels", "ner_tags")
        self.dataset = self.dataset.rename_column("words", "tokens")

    def set_tokenizer(self, tokenizer):
        """Set the tokenizer used by ``tokenize_and_align_labels``."""
        self.tokenizer = tokenizer

    def set_max_seq_length(self, max_seq_length):
        """Set the truncation length used by ``tokenize_and_align_labels``."""
        self.max_seq_length = max_seq_length

    def tokenize_and_align_labels(self, examples):
        """
        Tokenize a batch of pre-split words and propagate word labels to tokens.

        Meant for ``Dataset.map(..., batched=True)``. ``examples["tokens"]`` is a
        list of word lists, and ``examples["ner_tags"]`` holds the matching
        word-level label ids. Per token:
          * no word id (``word_ids`` yields None) -> -100;
          * token of an O word -> O;
          * first token of a B/I word -> the word's label;
          * later tokens of the same word -> I.

        NOTE(review): tokens split off a word's trailing punctuation (e.g. the
        "." of "SARS-CoV-2.") also get I, even though the word-level label
        ignored that punctuation.

        Truncates to ``self.max_seq_length`` when ``set_max_seq_length`` was
        called. Requires ``set_tokenizer`` first. Returns the tokenizer output
        with an added ``"labels"`` entry.
        """
        # Without max_length, truncation=True falls back to the tokenizer's own
        # limit (none for this model), so the configured length was ignored.
        max_length = getattr(self, "max_seq_length", None)
        tokenized_inputs = self.tokenizer(
            examples["tokens"],
            truncation=True,
            max_length=max_length,
            is_split_into_words=True,
            add_special_tokens=False,
        )
        outside_id = self.label2id['O']
        inside_id = self.label2id['I']
        labels = []
        for i, words_label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)  # token -> word index
            label_ids = []
            for k, word_idx in enumerate(word_ids):
                if word_idx is None:
                    label_ids.append(-100)
                elif words_label[word_idx] == outside_id:
                    label_ids.append(outside_id)
                elif k > 0 and word_ids[k - 1] == word_idx:
                    # Continuation sub-token of an entity word.
                    label_ids.append(inside_id)
                else:
                    label_ids.append(words_label[word_idx])
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
        

In [31]:
# Re-run the conversion with the DatasetFormatConverter class defined in this
# notebook (rather than utils.data_format_converter's version).
dataset_format_converter_obj = DatasetFormatConverter(dataset)
dataset_format_converter_obj.apply()

# Converted dataset and label vocabularies.
ds = dataset_format_converter_obj.dataset
label2id = dataset_format_converter_obj.get_label2id()
id2label = dataset_format_converter_obj.get_id2label()
label_list = dataset_format_converter_obj.get_label_list()

# Token-level alignment, then keep only the validation split.
dataset_format_converter_obj.set_tokenizer(tokenizer)
dataset_format_converter_obj.set_max_seq_length(256)
tokenized_ds = ds.map(dataset_format_converter_obj.tokenize_and_align_labels, batched=True)
_, val_data, _ = preprocessor.split_layer_into_train_val_test_(tokenized_ds, TRAIN_LAYER)


In [12]:
# Show the columns of the converted dataset; apply() renamed the word-level
# outputs to "tokens" and "ner_tags".
ds[0].keys()

dict_keys(['sentence', 'entities', 'original_text', 'original_id', 'tokens', 'ner_tags'])

In [32]:
# Sanity-check the first validation example: ids, attention mask and labels
# must line up one-to-one, and each label is shown next to its token.
row = val_data[0]
print(len(row['input_ids']))
print(len(row['attention_mask']))
print(len(row['labels']))
assert len(row['input_ids']) == len(row['attention_mask']) == len(row['labels'])
# Decode the stored ids instead of re-tokenizing row["tokens"]. A re-tokenization
# can disagree with the stored (possibly truncated) encoding, and zip() would
# then silently misalign or drop pairs.
tokens = tokenizer.convert_ids_to_tokens(row['input_ids'])
print([(i, l, t) for i, l, t in zip(row['input_ids'], row['labels'], tokens)])

22
22
22
[(415, 0, '▁The'), (2903, 0, '▁results'), (302, 0, '▁of'), (272, 0, '▁the'), (367, 1, '▁P'), (5728, 2, 'CR'), (654, 0, '▁were'), (5278, 1, '▁positive'), (354, 0, '▁for'), (399, 0, '▁R'), (3384, 0, 'NA'), (2948, 0, '▁specific'), (298, 0, '▁to'), (318, 1, '▁S'), (1087, 2, 'AR'), (28735, 2, 'S'), (28733, 2, '-'), (7170, 2, 'Co'), (28790, 2, 'V'), (28733, 2, '-'), (28750, 2, '2'), (28723, 2, '.')]
