In [1]:
import numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from transformers.pipelines.pt_utils import KeyDataset
from peft import PeftModel, PeftConfig
from dotenv import dotenv_values
import torch
from tqdm.auto import tqdm

from utils import DataPreprocessor, DatasetFormatConverter
from src.billm.modeling_mistral import MistralForTokenClassification

# Parse the secrets file once and reuse the mapping, instead of re-reading
# ".env.base" from disk for every key. A missing key still raises KeyError.
_env = dotenv_values(".env.base")
WANDB_KEY = _env['WANDB_KEY']
LLAMA_TOKEN = _env['LLAMA_TOKEN']
HF_TOKEN = _env['HF_TOKEN']

# PEFT adapter repository; its config records which base checkpoint the
# adapters were trained on, so the base model id is derived from it rather
# than hard-coded.
adapters = "ferrazzipietro/LS_Mistral-7B-v0.1_adapters_en.layer1_NoQuant_16_32_0.01_2_0.0002"
peft_config = PeftConfig.from_pretrained(adapters)
BASE_MODEL_CHECKPOINT = peft_config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_CHECKPOINT, token=HF_TOKEN)
# Use the end-of-sequence token for padding so batched inputs can be padded.
tokenizer.pad_token = tokenizer.eos_token
# seqeval = evaluate.load("seqeval")
# Hub id of the dataset and the name of the split ("layer") used in this notebook.
DATASET_CHEKPOINT="ferrazzipietro/e3c-sentences" 
TRAIN_LAYER="en.layer1"
# Project helper, built from the base checkpoint name and the tokenizer; it is
# used at the end of this cell to produce the train/val/test split.
preprocessor = DataPreprocessor(BASE_MODEL_CHECKPOINT, 
                                tokenizer)
dataset = load_dataset(DATASET_CHEKPOINT)  # pass download_mode="force_redownload" to bypass the local cache
# Keep only the selected layer, then shuffle with a fixed seed for reproducibility.
dataset = dataset[TRAIN_LAYER]
dataset = dataset.shuffle(seed=1234)  
# Project helper: apply() must run before the converted dataset and the label
# mappings are read back from the converter.
# NOTE(review): presumably apply() converts the raw annotations into per-token
# labels — confirm against utils.DatasetFormatConverter.
dataset_format_converter = DatasetFormatConverter(dataset)
dataset_format_converter.apply()
ds = dataset_format_converter.dataset
label2id = dataset_format_converter.label2id
id2label = dataset_format_converter.get_id2label()
label_list = dataset_format_converter.get_label_list()
# The converter needs the tokenizer and a maximum sequence length (256) before
# tokenize_and_align_labels can be mapped over the dataset.
dataset_format_converter.set_tokenizer(tokenizer)
dataset_format_converter.set_max_seq_length(256)
# Batched map: the converter method receives a batch of examples per call.
# (Earlier variant kept for reference: dataset_format_converter.dataset.map(tokenize_and_align_labels, batched=True))
tokenized_ds = ds.map(lambda x: dataset_format_converter.tokenize_and_align_labels(x), batched=True)
# Split the tokenized layer into train / validation / test subsets via the project helper.
train_data, val_data, test_data = preprocessor.split_layer_into_train_val_test_(tokenized_ds, TRAIN_LAYER)

# 4-bit NF4 quantization with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model with a token-classification head sized to the
# label set, attach the trained PEFT adapters, then merge the adapter weights
# into the base weights and drop the PEFT wrapper so `model` is a plain
# MistralForTokenClassification again.
model = PeftModel.from_pretrained(
    MistralForTokenClassification.from_pretrained(
        BASE_MODEL_CHECKPOINT,  # same value as peft_config.base_model_name_or_path
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        token=HF_TOKEN,
        cache_dir='/data/disk1/share/pferrazzi/.cache',
        device_map='auto',
        quantization_config=bnb_config,
    ),
    adapters,
    token=HF_TOKEN,
).merge_and_unload()


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

# Tunables for this smoke test: how many sentences go through the model per
# forward pass, and how many training sentences to preview.
INFERENCE_BATCH_SIZE = 12
N_PREVIEW_SENTENCES = 24

# Token-classification pipeline around the merged model. "simple" aggregation
# groups adjacent tokens that share the same predicted entity.
# NOTE(review): the recorded output shows entity groups named 'B' and 'I'; if
# the labels carry no "B-"/"I-" type suffix, aggregation may treat them as two
# unrelated entity types instead of merging spans — confirm the label scheme.
token_classifier = pipeline("token-classification", model=model,
                            tokenizer=tokenizer,
                            aggregation_strategy="simple",
                            batch_size=INFERENCE_BATCH_SIZE)

# Clamp the preview size so a split smaller than N_PREVIEW_SENTENCES does not
# make select() fail with an out-of-range index.
preview_sentences = KeyDataset(
    train_data.select(range(min(N_PREVIEW_SENTENCES, len(train_data)))),
    "sentence",
)

# Give tqdm an explicit total: the pipeline iterator yields one result per
# sentence, but its own length is expressed in batches, which would make the
# progress bar misleading.
predictions = list(tqdm(token_classifier(preview_sentences),
                        total=len(preview_sentences)))
# Alias kept because later cells read the results through the name `l`.
l = predictions

The model 'MistralForTokenClassification' is not supported for token-classification. Supported models are ['AlbertForTokenClassification', 'BertForTokenClassification', 'BigBirdForTokenClassification', 'BioGptForTokenClassification', 'BloomForTokenClassification', 'BrosForTokenClassification', 'CamembertForTokenClassification', 'CanineForTokenClassification', 'ConvBertForTokenClassification', 'Data2VecTextForTokenClassification', 'DebertaForTokenClassification', 'DebertaV2ForTokenClassification', 'DistilBertForTokenClassification', 'ElectraForTokenClassification', 'ErnieForTokenClassification', 'ErnieMForTokenClassification', 'EsmForTokenClassification', 'FalconForTokenClassification', 'FlaubertForTokenClassification', 'FNetForTokenClassification', 'FunnelForTokenClassification', 'GPT2ForTokenClassification', 'GPT2ForTokenClassification', 'GPTBigCodeForTokenClassification', 'GPTNeoForTokenClassification', 'GPTNeoXForTokenClassification', 'IBertForTokenClassification', 'LayoutLMForToken

In [6]:
# Display the collected predictions: one list per input sentence, each holding
# dicts with 'entity_group', 'score', 'word', 'start' and 'end'.
l

[[{'entity_group': 'I',
   'score': 0.681,
   'word': 'child',
   'start': 5,
   'end': 11},
  {'entity_group': 'B',
   'score': 0.947,
   'word': 'acquired walking',
   'start': 11,
   'end': 28},
  {'entity_group': 'B',
   'score': 0.711,
   'word': 'the age',
   'start': 31,
   'end': 39},
  {'entity_group': 'B', 'score': 0.7197, 'word': '14', 'start': 43, 'end': 45},
  {'entity_group': 'I',
   'score': 0.659,
   'word': 'months',
   'start': 45,
   'end': 52}],
 [{'entity_group': 'B', 'score': 0.648, 'word': 'dom', 'start': 5, 'end': 8},
  {'entity_group': 'B', 'score': 0.928, 'word': 'ul', 'start': 12, 'end': 15},
  {'entity_group': 'B',
   'score': 0.9463,
   'word': 'examination',
   'start': 23,
   'end': 35},
  {'entity_group': 'B',
   'score': 0.8286,
   'word': 'reported',
   'start': 39,
   'end': 48},
  {'entity_group': 'B',
   'score': 0.6094,
   'word': 'normal',
   'start': 51,
   'end': 58}],
 [{'entity_group': 'B',
   'score': 0.943,
   'word': 'pressure',
   'start':