In [1]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from dotenv import dotenv_values
from datasets import load_dataset, Dataset
from utils.data_preprocessor import DataPreprocessor
from utils.postprocessor import DataPostprocessor
from utils.evaluator import Evaluator
from config import config

HF_TOKEN = dotenv_values(".env.base")['HF_TOKEN']
# If the dataset is gated/private, make sure you have run huggingface-cli login

load_model = True
if load_model:
    checkpoint_adapter = 'ferrazzipietro/Mistral-7B-Instruct-v0.2_adapters_it.layer1_v0.2_wandblog'
    bnb_config = BitsAndBytesConfig(
        
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    base_model_reload = AutoModelForCausalLM.from_pretrained(
        config.BASE_MODEL_CHECKPOINT, low_cpu_mem_usage=True,
        quantization_config = bnb_config,
        return_dict=True,  load_in_4bit=True, #torch_dtype=torch.float16,
        device_map= "auto")

    adp = 'ferrazzipietro/Mistral-7B-Instruct-v0.2_adapters_it.layer1_v0.2_wandblog'
    merged_model = PeftModel.from_pretrained(base_model_reload, adp, token=HF_TOKEN, device_map="auto")
    
tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_CHECKPOINT, add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

from datasets import load_dataset
dataset = load_dataset("ferrazzipietro/e3c-sentences", token=HF_TOKEN)
dataset = dataset[config.TRAIN_LAYER]
preprocessor = DataPreprocessor()
dataset = preprocessor.preprocess_data_one_layer(dataset)
train_data, val_data, _ = preprocessor.split_layer_into_train_val_test_(dataset, config.TRAIN_LAYER)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.30s/it]
Map: 100%|██████████| 1146/1146 [00:00<00:00, 5533.43 examples/s]
Map: 100%|██████████| 139/139 [00:00<00:00, 6372.51 examples/s]


In [7]:
postprocessor.test_data['sentence']

['Il caso riguarda un ragazzo di 12 anni, ricoverato presso l’UOC di Chirurgia Pediatrica di Treviso per addome acuto.',
 'Il ragazzo manifestava da circa una settimana vomiti ripetuti accompagnati da coliche addominali, inappetenza e vistoso calo ponderale (4 kg circa in una settimana).',
 'Al ricovero il paziente si presentava molto sofferente, astenico, disidratato, apiretico, con addome globoso, trattabile ma dolente alla palpazione profonda elettivamente in fossa iliaca destra; all’ascoltazione si percepiva una peristalsi metallica.',
 'Il ragazzo era quindi sottoposto in urgenza a una laparoscopia esplorativa, subito convertita per impossibilità di acquisire una camera laparoscopica sufficiente con le pressioni usuali, a causa dell’estrema distensione delle anse ileali, riscontrando una matassa ileale diffusamente dilatata e infiammata fino all’ileo terminale.',
 'A livello del medio-ileo si trovava un DM con al suo interno una massa palpabile occludente.',
 'L’intervento si conc

In [13]:
postprocessor = DataPostprocessor(val_data, preprocessor, 2, 'it', tokenizer)
postprocessor.add_inference_prompt_column()
postprocessor.add_ground_truth_column()
postprocessor.add_responses_column(merged_model, tokenizer, 24)


Map: 100%|██████████| 556/556 [00:00<00:00, 8967.25 examples/s]
generating responses:  13%|█▎        | 72/556 [02:14<15:02,  1.87s/it]

[{'entity': 'esami'}, {'entity': 'sul fratello minore'}]
[{'entity': 'LAM'}, {'entity': 'SNC'}, {'entity': 'esordita'}, {'entity': 'paziente'}, {'entity': 'giunta'}]


generating responses:  17%|█▋        | 96/556 [02:58<14:15,  1.86s/it]

[{'entity': 'Dimissione'}, {'entity': 'giornata +56'}]
[{'entity': 'cieco'}, {'entity': 'appariva'}, {'entity': 'particolarmente'}, {'entity': 'mobile'}]


generating responses:  22%|██▏       | 120/556 [03:47<13:56,  1.92s/it]

[{'entity': 'emocolture'}, {'entity': 'urinocolture'}, {'entity': 'massaggio'}, {'entity': 'coltura'}, {'entity': 'negativi'}]


generating responses:  26%|██▌       | 144/556 [04:30<12:53,  1.88s/it]

[{'entity': 'risoluzione'}, {'entity': 'sintomi'}, {'entity': 'ascesso'}, {'entity': 'controllo'}, {'entity': 'sintomi'}, {'entity': '20 giorni'}]
[{'entity': 'anemia'}]


generating responses:  30%|███       | 168/556 [05:20<12:36,  1.95s/it]

[{'entity': 'emocromo'}, {'entity': 'creatininemia'}, {'entity': 'azotemia'}]
[{'entity': 'la paziente'}, {'entity': 'normalizzazione ECGrafica'}]


In [11]:
from utils.data_preprocessor import DataPreprocessor
from config import preprocessing_params
from datasets import Dataset
from tqdm import tqdm
import json
import re

class DataPostprocessor():
    def __init__(self, test_data: Dataset, preprocessor:DataPreprocessor, n_shots_inference:int, language:str, tokenizer) -> None:
        self.test_data = test_data
        self.preprocessor = preprocessor
        self.language = language
        self.tokenizer = tokenizer
        self.few_shots_dict = {'en':{'questions':['We present a case of a 32-year-old woman with a history of gradual enlargement of the anterior neck.',
                                                   'Patient information: a 9-month-old boy presented to the emergency room with a 3-day history of refusal to bear weight on the right lower extremity and febrile peaks of up to 38.5°C for 24 hours.'],
                                        'responses':['[{"entity": "present"}, {"entity": "history"}, {"entity": "enlargement"}]',
                                                     '[{"entity": "presented"}, {"entity": "refusal"}, {"entity": "bear"}, {"entity": "peaks"}]'],
                                        'responses_offset': ['[{"entity": "present", "offset": [3, 10]}, {"entity": "history", "offset": [48, 55]}, {"entity": "enlargement", "offset": [67, 78]}]',
                                                             '[{"entity": "presented", "offset": [39, 48]}, {"entity": "refusal", "offset": [95, 102]}, {"entity": "bear", "offset": [106, 110]}, {"entity": "peaks", "offset": [159, 164]}]']
                                    },
                                'it':{'questions':['In considerazione dell’inefficacia della terapia somministrata, in assenza di ulteriori opzioni terapeutiche standard potenzialmente efficaci e dopo colloquio con i genitori si decide di avviare la paziente a trapianto aploidentico, possibilmente NK allo reattivo, da genitore.',
                                                    'L’esame istologico dimostrava mucosa gastrica atrofica con flogosi cronica, marcato edema ed incremento del connettivo del corion, focale metaplasia intestinale, il tutto sovrastante un tessuto fibromuscolare.'],
                                       'responses':['[{"entity": "inefficacia"}, {"entity": "opzioni"}, {"entity": "colloquio"}, {"entity": "avviare"}, {"entity": "trapianto"}, {"entity": "genitori"}, {"entity": "paziente"}, {"entity": "genitore"}]',
                                                           '[{"entity": "mucosa gastrica atrofica"}, {"entity": "flogosi\r\cronica"}]'],
                                       'responses_offset':['[{"entity": "inefficacia", "offset": [23, 34]}, {"entity": "opzioni", "offset": [88,95]}, {"entity": "colloquio", "offset": [149,158]}, {"entity": "avviare", "offset": [187,194]}, {"entity": "trapianto", "offset": [209,218]}, {"entity": "genitori", "offset": [163,173]}, {"entity": "paziente", "offset": [195,106]}, {"entity": "genitore", "offset": [268,276]}]',
                                                           '[{"entity": "mucosa gastrica atrofica", "offset": [30,54]}, {"entity": "flogosi\r\cronica", "offset": [59,75]}]']}
                                }
        if len(self.few_shots_dict[self.language]['questions']) < n_shots_inference:
            raise ValueError(f'The number of shots for the inference prompt is greater than the number of examples available.')
        if len(self.few_shots_dict[self.language]['responses']) < n_shots_inference:
            raise ValueError(f'The number of shots for the inference prompt is greater than the number of responses available.')
        self.n_shots_inference = n_shots_inference
    
    def _extract_ground_truth(self, prompt:str) -> str:
        out = prompt.split('[/INST]', 1)
        return {'ground_truth': out[1][0:-4].strip()}
        
    def _extract_inference_prompt(self, sentence:str) -> str:
        if self.preprocessor.offset:
            few_shots_responses = self.few_shots_dict[self.language]['responses_offset']
        else:
            few_shots_responses = self.few_shots_dict[self.language]['responses']
        inference_prompt = self.preprocessor._format_prompt(task='inference', 
                                                        input=sentence, 
                                                        instruction_on_response_format=self.preprocessor.instruction_on_response_format,
                                                        offset=self.preprocessor.offset,
                                                        tokenizer=self.tokenizer,
                                                        output='',
                                                        n_shots=self.n_shots_inference,
                                                        list_of_examples=self.few_shots_dict[self.language]['questions'][0:self.n_shots_inference],
                                                        list_of_responses=few_shots_responses)
        return {'inference_prompt': inference_prompt}
    
    def add_inference_prompt_column(self) -> None:
        """
        Add the inferencePrompt and groundTruth columns to the test_data dataframe.
        """
        self.test_data = self.test_data.map(lambda x: self._extract_inference_prompt(x['sentence']))
    
    def add_ground_truth_column(self) -> None:
        """
        Add the groundTruth column to the test_data dataframe.
        """
        self.test_data = self.test_data.map(lambda x: self._extract_ground_truth(x['prompt']))

    def _generate_model_response(self, examples, model, tokenizer, max_new_tokens_factor:float=4) -> str:
        device = "cuda"
        tokenizer.padding_side = "left"
        input_sentences = examples['sentence']
        prompts = examples['inference_prompt']
        input_sentences_tokenized = tokenizer(input_sentences, return_tensors="pt", padding=True)
        max_new_tokens = int(len(max(input_sentences_tokenized, key=len)) * max_new_tokens_factor)

        encodeds = tokenizer(prompts, return_tensors="pt", add_special_tokens=False, padding=True)
        model_inputs = encodeds.to(device)
        generated_ids = model.generate(**model_inputs, do_sample=True, max_new_tokens=max_new_tokens,  pad_token_id=tokenizer.eos_token_id) # max_new_tokens=max_new_tokens,
        decoded = tokenizer.batch_decode(generated_ids)
        decoded = [self._postprocess_model_output(i) for i in decoded]
        return (decoded)
                
    def add_responses_column(self, model, tokenizer, batch_size:int) -> None:
        """
        Adds a column with the response of the model to the actual query.
        
        params:
        model: the model to use to generate the response
        tokenizer: the tokenizer to use to generate the response
        batch_size: the batch size to use to process the examples. Increasing this makes it faster but requires more GPU. Default is 8.
        """
        responses_col = []
        total_rows = len(self.test_data)
        indexes = [i for i in range(len(self.test_data)) if i % batch_size == 0]
        max_index = self.test_data.shape[0]


        with tqdm(total=total_rows, desc="generating responses") as pbar:
            for i, idx in enumerate(indexes[:-1]):
                indici = list(range(idx, indexes[i+1]))
                tmp = self._generate_model_response(self.test_data.select(indici), model, tokenizer)
                responses_col.extend(tmp)
                pbar.update(batch_size)
            indici = list(range(indexes[i+1], max_index))
            tmp = self._generate_model_response(self.test_data.select(indici), model, tokenizer)
            responses_col.extend(tmp)
            pbar.update(batch_size)

        self.test_data = self.test_data.add_column('model_responses', responses_col)
    
    def _postprocess_model_output(self, model_output: str) -> str:
        """
        Postprocess the model output to return a json like formatted string that can be used to compute the F1 score.

        Args:
        model_output (str): the model output as it is returned by the model. The processing of the output is done in the function

        return:
        str: the model response, i.e. the model output without the instruction

        """
        def has_unclosed_square_brackets(s):
            count = 0
            for char in s:
                if char == '[':
                    count += 1
                elif char == ']':
                    count -= 1
                    if count < 0:
                        return True
            return count > 0
        
        model_output = model_output.split('[/INST]')[-1].strip()
        if self._assess_model_output(model_output):
            return model_output
        
        tmp = re.findall(r'\[\{(.+?)\}\]', model_output)
        if len(tmp) != 0:
            tmp = '[{' + tmp[0] + '}]'
            if self._assess_model_output(tmp):
                return tmp
        if has_unclosed_square_brackets(model_output):
            last_bracket_index = model_output.rfind('},') # find the last complete entity
            model_output = model_output[:last_bracket_index+1] + ']' 
            return model_output 
  
        
    def _assess_model_output(self, model_response: str) -> bool:
        """
        Check if the model output is in the right format. If not, return False.
        
        Args:
        model_output (str): the postprocessed model output after beeing passed to _postprocess_model_output()

        return:
        bool: True if the format is correct, False otherwise
        """
        good_format = True
        try :
            res = json.loads(model_response)
            print( res)
        except:
            good_format = False
        return good_format

    