In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import torch
from tqdm import tqdm
import re
from functools import wraps
import random

In [None]:
import os

from huggingface_hub.hf_api import HfFolder

# Read the Hugging Face token from the environment (e.g. a notebook secret exported as
# HF_TOKEN) instead of hard-coding it here. Saving the literal "HF-TOKEN" placeholder would
# replace any previously saved, valid token with an invalid one.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    HfFolder.save_token(hf_token)
else:
    print("HF_TOKEN is not set; keeping any previously saved Hugging Face token.")

In [None]:
"""
from huggingface_hub import notebook_login
notebook_login()
"""

In [6]:
"""
Create a dedicated logger, because logging in Kaggle notebooks is messy: https://www.kaggle.com/code/residentmario/notes-on-python-logging/code
"""

import logging

class LoggerManager:
    """Own a logger with an attached console handler and a detached file handler.

    ``write`` sends messages to the console by default; wrap calls in
    ``LogToFile(manager)`` to send them to ``file_name`` only.

    Every instance created with the same ``logger_name`` shares one ``logging.Logger``.
    A new manager (e.g. from re-running the notebook cell) therefore first detaches and
    closes the handlers a previous one left on that logger. Otherwise the old console
    handler would keep printing next to the new one.

    Args:
        file_name: path of the log file. It is opened in "w" mode, so it is truncated here.
        logger_name: name passed to ``logging.getLogger`` (defaults to the module name).
    """

    def __init__(self, file_name, logger_name=__name__):
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(logging.INFO)
        # Do not forward records to ancestor loggers; only the handlers below emit them.
        self.logger.propagate = False

        # Drop handlers left by an earlier instance so output is not duplicated.
        for stale_handler in list(self.logger.handlers):
            self.logger.removeHandler(stale_handler)
            stale_handler.close()

        formatter = logging.Formatter('%(message)s')

        self.console_handler = logging.StreamHandler()
        self.console_handler.setLevel(logging.INFO)
        self.console_handler.setFormatter(formatter)
        self.logger.addHandler(self.console_handler)

        # Created but not attached: LogToFile attaches it temporarily.
        self.file_handler = logging.FileHandler(file_name, mode="w", encoding="utf-8")
        self.file_handler.setLevel(logging.INFO)
        self.file_handler.setFormatter(formatter)

    def write(self, string):
        """Log ``string`` at INFO level through whichever handlers are currently attached."""
        self.logger.info(string)

    def close(self):
        """Detach both handlers from the logger and close them (closing the log file)."""
        for handler in (self.console_handler, self.file_handler):
            self.logger.removeHandler(handler)
            handler.close()


class LogToFile:
    """Context manager that temporarily writes a LoggerManager's output to its file only.

    On entry the console handler is detached and the file handler attached. On exit the
    handler set that existed before entry is restored. Nested uses (or a use while the
    file handler is already attached) therefore leave the logger exactly as they found it.

    Args:
        logger_manager: object exposing ``logger``, ``console_handler`` and ``file_handler``
            (a ``LoggerManager``).
    """

    def __init__(self, logger_manager):
        self.logger_manager = logger_manager
        self.logger = logger_manager.logger
        self.console_handler = logger_manager.console_handler
        self.file_handler = logger_manager.file_handler
        # Handler attachment state captured by __enter__ and restored by __exit__.
        self._console_was_attached = False
        self._file_was_attached = False

    def __enter__(self):
        current_handlers = self.logger.handlers
        self._console_was_attached = self.console_handler in current_handlers
        self._file_was_attached = self.file_handler in current_handlers
        if self._console_was_attached:
            self.logger.removeHandler(self.console_handler)
        if not self._file_was_attached:
            self.logger.addHandler(self.file_handler)
        return self.logger_manager

    def __exit__(self, exc_type, exc_value, traceback):
        # Undo only what __enter__ changed; exceptions are not suppressed.
        if not self._file_was_attached:
            self.logger.removeHandler(self.file_handler)
        if self._console_was_attached:
            self.logger.addHandler(self.console_handler)

""" Run the following tests:
logger_manager = LoggerManager("log.txt")

logger_manager.write("Test message on console.")  # Write only on console

with LogToFile(logger_manager):
    logger_manager.write("Test message on file.")  # Write only on file

logger_manager.write("Back to console.")  # Write only on console

# Check file content with
!cat log.txt
"""

' Run the following tests:\nlogger_manager = LoggerManager("log.txt")\n\nlogger_manager.write("Test message on console.")  # Write only on console\n\nwith LogToFile(logger_manager):\n    logger_manager.write("Test message on file.")  # Write only on file\n\nlogger_manager.write("Back to console.")  # Write only on console\n\n# Check file content with\n!cat log.txt\n'

In [7]:
# Shared logger used by the evaluation cells below. The log file is opened in "w" mode,
# so re-running this cell truncates file.log.
logger_manager = LoggerManager("file.log")

In [8]:
class Model:
    """
    Lazily-loaded Hugging Face causal LM with its tokenizer and a text-generation pipeline.

    examples of models:
    microsoft/phi-2
    TinyLlama/TinyLlama-1.1B-Chat-v0.6
    google/gemma-2-2b-it

    Args:
        model_name: Hub id (or local path) passed to ``from_pretrained``.
        load_on_init: load the model and tokenizer now instead of on first use.
        max_new_tokens: generation budget given to the pipeline (default 256).
        temperature: given to the pipeline (default 0.1). NOTE(review): transformers only
            uses temperature when sampling is enabled (``do_sample=True``, left here to the
            model's generation config). Under greedy decoding it is ignored and a warning
            is printed — confirm which behaviour is intended.
    """

    def __init__(self, model_name, load_on_init = False, max_new_tokens=256, temperature=0.1):
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # Filled lazily by the getters below; None means "not loaded yet".
        self.model = None
        self.tokenizer = None
        self.pipe = None
        if load_on_init:
            self.get_model()
            self.get_tokenizer()

    def get_model_name(self):
        """Return the model id this wrapper was created with."""
        return self.model_name

    def format_prompt(self, prompt):
        """Hook for model-specific prompt changes; the base class returns ``prompt`` unchanged.

        Subclasses may override it, e.g. a DeepSeek wrapper appending a ``<think>`` tag
        at the end of the prompt.
        """
        return prompt

    def get_model(self):
        """Load (once) and return the model in float16 with ``device_map="auto"``."""
        # `is None` rather than truthiness: an object defining __len__ can be falsy.
        if self.model is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        return self.model

    def get_tokenizer(self):
        """Load (once) and return the tokenizer for ``model_name``."""
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        return self.tokenizer

    def get_pipeline(self):
        """Build (once) and return a "text-generation" pipeline over the model and tokenizer."""
        if self.pipe is None:
            self.pipe = pipeline(
                "text-generation",
                model=self.get_model(),
                tokenizer=self.get_tokenizer(),
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
            )
        return self.pipe



class Dataset():
    """Base class for a benchmark evaluated with a text-generation ``Model``.

    Subclasses are expected to provide:
      * ``get_dataset()`` with no arguments (typically ``super().get_dataset(name)``),
      * ``get_dataset_name()``,
      * ``format_prompt(row)`` returning the prompt text for one row,
      * ``get_true_answer(row)`` returning the reference answer appended to few-shot examples,
      * ``is_correct(answer, row)`` returning ``(is_correct, is_rejected)``.

    Args:
        dataset_fraction: fraction of the split kept after shuffling. ``None`` keeps the
            data as loaded, without shuffling.
        split: split passed to ``load_dataset``. A falsy value loads without a split.
        seed: seed for the subsampling shuffle. ``None`` (default) gives a different
            order/subset on every load.
    """

    def __init__(self, dataset_fraction = 1, split="validation", seed=None):
        self.dataset = None
        self.dataset_fraction = dataset_fraction
        self.split = split
        self.seed = seed

    def get_dataset(self, dataset_name):
        """Load (once) and return ``dataset_name``, shuffled and cut to ``dataset_fraction``."""
        # `is not None`: a loaded dataset with 0 rows is falsy and would be reloaded on every call.
        if self.dataset is not None:
            return self.dataset
        if self.split:
            dataset = load_dataset(dataset_name, split=self.split, trust_remote_code=True)
        else:
            # NOTE(review): without a split load_dataset presumably returns a DatasetDict,
            # which has no `.select`. Subsampling below would then fail unless
            # dataset_fraction is None — confirm.
            dataset = load_dataset(dataset_name, trust_remote_code=True)
        if self.dataset_fraction is not None:
            num_samples = int(len(dataset) * self.dataset_fraction)
            dataset = dataset.shuffle(seed=self.seed).select(range(num_samples))
        self.dataset = dataset
        return self.dataset

    def _sample_shot_rows(self, dataset, row_idx, n_shot):
        """Draw ``n_shot`` rows uniformly, with replacement, from every row except ``row_idx``."""
        if n_shot <= 0:
            return []
        num_other_rows = len(dataset) - 1
        if num_other_rows < 1:
            raise ValueError("n-shot prompting needs at least one row besides the evaluated one")
        shot_rows = []
        for _ in range(n_shot):
            # Pick among the len-1 other indices, then step over row_idx. This is O(1) per
            # draw instead of building a filtered copy of the whole dataset for every row.
            shot_idx = random.randrange(num_other_rows)
            if shot_idx >= row_idx:
                shot_idx += 1
            shot_rows.append(dataset[shot_idx])
        return shot_rows

    def iteration_evaluate_model(self, model, row_idx, row, n_shot, logger_manager = None):
        """Ask ``model`` one question, optionally preceded by ``n_shot`` solved examples.

        Args:
            model: object exposing ``get_pipeline()`` and ``format_prompt(prompt)`` (a ``Model``).
            row_idx: index of ``row`` in the dataset; that row is never used as a shot.
            row: the dataset row to evaluate.
            n_shot: number of few-shot examples, drawn at random with replacement.
            logger_manager: optional ``LoggerManager``; when given, the prompt, answer and
                verdict are written through it.

        Returns:
            ``(is_llm_answer_correct, is_answer_rejected)`` as returned by ``is_correct``.

        Raises:
            ValueError: if ``n_shot > 0`` and the dataset has no row other than ``row``.
        """
        dataset = self.get_dataset()

        shots = "".join(
            self.format_prompt(shot_row) + " " + str(self.get_true_answer(shot_row)) + "\n"
            for shot_row in self._sample_shot_rows(dataset, row_idx, n_shot)
        )
        # Model-specific adjustments (e.g. a trailing <think> tag) apply to the final prompt.
        prompt = model.format_prompt(shots + self.format_prompt(row))

        answer = model.get_pipeline()(prompt, return_full_text=False)[0]['generated_text']

        is_llm_answer_correct, is_answer_rejected = self.is_correct(answer, row)

        if logger_manager is not None:
            logger_manager.write(f"- Prompt:\n {prompt}\n")
            logger_manager.write(f"- Answer:\n {answer}\n")
            logger_manager.write(f"- True Answer: {self.get_true_answer(row)}\n")
            logger_manager.write(f"- Is LLM answer correct? : {is_llm_answer_correct}\n")

        return is_llm_answer_correct, is_answer_rejected

    @staticmethod
    def _file_only(logger_manager):
        """Return a context that routes logging to the log file only; a no-op without a logger."""
        from contextlib import nullcontext
        if logger_manager is None:
            return nullcontext()
        return LogToFile(logger_manager)

    def evaluate_model(self, model, n_shot = 0, logger_manager = None):
        """Evaluate ``model`` on every row of the dataset and return the accuracy in percent.

        When ``logger_manager`` is given, per-row details and the final summary are written
        to its log file only. With ``None`` nothing is logged. Rejected answers (no
        parsable answer line) count as incorrect. An empty dataset yields 0.0.
        """
        # Build the pipeline (loading weights if needed) before the progress bar starts.
        model.get_pipeline()
        dataset = self.get_dataset()
        description = f"Evaluating {model.get_model_name()} on {self.get_dataset_name()} with {n_shot}-shot"

        if logger_manager is not None:
            with self._file_only(logger_manager):
                logger_manager.write(f"{description}\n")

        correct = 0
        rejected = 0
        for idx, example in tqdm(enumerate(dataset), total=len(dataset), desc=description):
            with self._file_only(logger_manager):
                is_answer_correct, is_answer_rejected = self.iteration_evaluate_model(
                    model, idx, example, n_shot, logger_manager
                )
            if is_answer_correct:
                correct += 1
            if is_answer_rejected:
                rejected += 1

        # Guard against an empty dataset (e.g. a tiny dataset_fraction): avoid dividing by zero.
        accuracy = correct / len(dataset) * 100 if len(dataset) else 0.0
        if logger_manager is not None:
            with self._file_only(logger_manager):
                logger_manager.write(f"\nFinal accuracy: {accuracy}")
                logger_manager.write(f"\nNumber of rejected answers: {rejected}")
        return accuracy



class HellaSwag(Dataset):
    """
link: https://huggingface.co/datasets/Rowan/hellaswag
Example(cropped)
{
    "activity_label": "Removing ice from car",
    "ctx": "Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then",
    "ctx_a": "Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.",
    "ctx_b": "then",
    "endings": "[\", the man adds wax to the windshield and cuts it.\", \", a person board a ski lift, while two men supporting the head of the per...",
    "ind": 4,
    "label": "3",
    "source_id": "activitynet~v_-1IBHYS3L-Y",
    "split": "train",
    "split_type": "indomain"
}

Note:
1. The ctx and the endings may contain tags like [header], [title], [step], [substeps], etc. If we don't remove them, the LLM might misinterpret the prompt.
2. label is a string "0"-"3". The question is posed to the LLM as a choice between 4 options lettered A-D (see ANSWER).

    """

    ANSWER = {
        0:'A',
        1:'B',
        2:'C',
        3:'D'
    }

    # Bracketed markup tags such as [header], [title], [step], [substeps].
    _TAG_RE = re.compile(r"\[.*?\]")
    # First "Answer: X" occurrence. It tolerates markdown emphasis (**, __) around "Answer"
    # and after the colon, and a space before the colon. Group 1 is the chosen letter.
    _ANSWER_RE = re.compile(r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])")

    def __init__(self, load_on_init = False, dataset_fraction = 1, split = "validation"):
        super().__init__(dataset_fraction, split)
        self.dataset_name = "hellaswag"
        if load_on_init:
            self.get_dataset()

    def get_dataset(self):
        """Load (once) and return the HellaSwag split configured in ``__init__``."""
        return super().get_dataset(self.dataset_name)

    def get_dataset_name(self):
        return self.dataset_name

    def format_prompt(self,example):
        """
        The example is formatted as

        Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A')
        Context: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then
        Which of the following options is the most plausible continuation?
        A. The man adds wax to the windshield and cuts it.
        B. A person boards a ski lift, while two men support the head of the person.
        C. The woman walks away and the man starts removing the ice from the car.
        D. The man and woman start dancing on the snowy ground.

        Tags are stripped from a copy of the endings; ``example`` is not modified.
        """
        ctx = self._TAG_RE.sub("", example['ctx']).strip()
        endings = [self._TAG_RE.sub("", ending).strip() for ending in example['endings']]
        return f"Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').\n Context: {ctx}\nWhich of the following options is the most plausible continuation?\nA. {endings[0]}\nB. {endings[1]}\nC. {endings[2]}\nD. {endings[3]}"

    def get_true_answer(self, example):
        """Return the reference answer line, e.g. ``'Answer: D'`` for label ``"3"``."""
        return f"Answer: {self.ANSWER[int(example['label'])]}"

    def is_correct(self, model_answer, row):
        """Grade the first "Answer: <letter>" in ``model_answer`` (case-insensitive).

        Returns:
            ``(is_correct, is_rejected)``. ``is_rejected`` is True when no answer line is
            found, in which case ``is_correct`` is False.
        """
        true_answer = self.ANSWER[int(row['label'])].lower()
        match = self._ANSWER_RE.search(model_answer)
        if match is None:
            return False, True
        # Use the captured letter directly. Re-parsing the match with "Answer:(.*)" failed
        # on "**Answer**: B" / "Answer : B" and kept "**" in "Answer:**B".
        return match.group(1).lower() == true_answer, False


class BoolQ(Dataset):
    """
link: https://huggingface.co/datasets/google/boolq
Example (cropped):
{
    "passage": "\"All biomass goes through at least some of these steps: it needs to be grown, collected, dried, fermented, distilled, and burned...",
    "question": "does ethanol take more energy make that produces"
    "answer": false,
}

    """

    # First "Answer: true/false" occurrence (case-insensitive). It tolerates markdown
    # emphasis (**, __) around "Answer" and after the colon. Group 1 is the verdict.
    _ANSWER_RE = re.compile(r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*(true|false)(?![a-zA-Z0-9])")

    def __init__(self, load_on_init = False, dataset_fraction = 1, split = "validation"):
        super().__init__(dataset_fraction, split)
        self.dataset_name = "google/boolq"
        if load_on_init:
            self.get_dataset()

    def get_dataset(self):
        """Load (once) and return the BoolQ split configured in ``__init__``."""
        return super().get_dataset(self.dataset_name)

    def get_dataset_name(self):
        return self.dataset_name

    def format_prompt(self, example):
        """
        Formats the example in this way:

        Answer the following true/false question. The last line of your response should be in the following format: 'Answer: true/false' (e.g. 'Answer: true').
        Passage: All biomass goes through at least some of these steps: it needs to be grown, collected, dried, fermented, distilled, and burned...
        Question: does ethanol take more energy make that produces
        """
        return f"Answer the following true/false question. The last line of your response should be in the following format: 'Answer: true/false' (e.g. 'Answer: true').\nPassage: {example['passage']}\nQuestion: {example['question']}\n"

    def get_true_answer(self, example):
        """Return the reference answer line, e.g. ``'Answer: False'``."""
        return f"Answer: {example['answer']}"

    def is_correct(self, model_answer, row):
        """Grade the first "Answer: true/false" in ``model_answer`` (case-insensitive).

        Returns:
            ``(is_correct, is_rejected)``. ``is_rejected`` is True when no answer line is
            found, in which case ``is_correct`` is False.
        """
        match = self._ANSWER_RE.search(model_answer)
        if match is None:
            return False, True
        return match.group(1).lower() == str(row['answer']).lower(), False


In [9]:
# Load the Gemma-2 2B instruct model and its tokenizer immediately (load_on_init=True).
gemma_2b = Model("google/gemma-2-2b-it", load_on_init = True)

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [10]:
# Load the whole HellaSwag validation split now. dataset_fraction=1 keeps every row,
# in shuffled order.
hellaswag = HellaSwag(load_on_init=True, dataset_fraction = 1, split = "validation")

README.md:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

hellaswag.py:   0%|          | 0.00/4.36k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/2.53k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/47.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/11.8M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

In [11]:
# Pick one random row (without materialising the whole split as a Python list) and
# run a single 0-shot evaluation on it; the cell displays (is_correct, is_rejected).
hellaswag_data = hellaswag.get_dataset()
row_idx = random.randrange(len(hellaswag_data))
row = hellaswag_data[row_idx]
n_shot = 0
hellaswag.iteration_evaluate_model(gemma_2b, row_idx, row, n_shot, logger_manager)

Device set to use cuda:0
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
- Prompt:
 Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
 Context: How to avoid mercury in your skin products  Recognize common products that use mercury.  There may be mercury lurking in many of your skin care products. However, some are more likely to contain it.
Which of the following options is the most plausible continuation?
A. The common types include :  Mineral oil prescription hydrogen peroxide baking soda diogenes (example used are acidophilus and ciprofloxacin) dissolved retinoids  Rinse your face.  Use a mild, baby soap or bathwater.
B. Check any of the following to see if they contain mercury :  Skin creams, especially anti-aging and skin lightening beauty and antiseptic soaps lotions  Read t

(True, False)

In [12]:
# Re-evaluate the row picked in the previous cell, this time with a 1-shot prompt.
# NOTE(review): the commented line would pick a BoolQ row instead, but no `boolq`
# instance is created in the cells shown — create one (e.g. BoolQ(load_on_init=True)) first.
# row_idx, row = random.choice(list(enumerate(boolq.get_dataset())))
n_shot = 1
hellaswag.iteration_evaluate_model(gemma_2b, row_idx, row, n_shot, logger_manager)

- Prompt:
 Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
 Context: How to take care of ladybugs  Purchase ladybugs.  If you have a garden or a greenhouse, you should consider adding more ladybugs to the environment. Ladybugs are very useful for controlling aphids.
Which of the following options is the most plausible continuation?
A. Aphids are tiny insects that can destroy vegetable and flower gardens.  Ladybugs tend to hibernate during winter months.
B. Ladybugs can be sold as small boxes, pots, or bowls.  Ladybugs come in all shapes and sizes.
C. You might try to keep ladybugs away from the garden. You can also buy ladybugs from a local store that sells insect products.
D. Ladybugs can be especially helpful to keep livestock and other animals away as well as washing your vegetable and insect products.  Purchase the ladybugs' natural habitats. Answer: A
Answer the following multipl

(True, False)

In [None]:
# Full 1-shot evaluation over the whole dataset. Per-row details go to file.log only;
# the cell displays the accuracy in percent.
hellaswag.evaluate_model(gemma_2b, 1, logger_manager)

In [None]:
# Show the end of the log: the final accuracy and the number of rejected answers.
!tail -n 3 file.log