# Evaluation and Finetuning of Summarization Models in Okareo

<a target="_blank" href="https://colab.research.google.com/github/okareo-ai/okareo-python-sdk/blob/main/examples/classification_finetuning_eval_part1.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook is Part 2 of our demo series on evaluating fine-tuned summarization models in Okareo. If you have not already done so, please run Part 1 first. It evaluates the baseline zero-shot model and provides the weights for the fine-tuned model.

In [None]:
# Okareo scenario ID of the test split; copy it from the output of Part 1's last cell.
TEST_SCENARIO_ID: str = "bf732090-785c-42e4-9eb9-25ccf7d51794"

## Load the Test Scenario from Okareo

First, we set up our Okareo client. You will need an API token from [https://app.okareo.com/](https://app.okareo.com/) (note: you will need to register first).

In [None]:
# get Okareo client

from okareo import Okareo

import os

# Prefer the OKAREO_API_KEY environment variable so the key does not have to be
# pasted into the notebook; otherwise fall back to the placeholder below, which
# can still be replaced with a literal key as before.
OKAREO_API_KEY = os.environ.get("OKAREO_API_KEY", "<YOUR_OKAREO_API_KEY>")
okareo = Okareo(OKAREO_API_KEY)

In [None]:
# Fetch the data points of the test scenario identified by TEST_SCENARIO_ID.
# NOTE(review): `sdp` is not referenced again in this notebook; this call
# appears to serve as a check that the scenario ID resolves in Okareo.
sdp = okareo.get_scenario_data_points(TEST_SCENARIO_ID)

### Configure Phi-3 for finetuning

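Now we set up the fine-tuning configuration for [Phi-3.5-mini-instruct](https://huggingface.co/microsoft/Phi-3.5-mini-instruct) using the fine-tuning instruction scenario. No training is run in this notebook. The configuration is used to locate the fine-tuned checkpoint (`output_dir`) and to configure logging.

The setup is mostly boilerplate from [here](https://huggingface.co/microsoft/Phi-3.5-mini-instruct/resolve/main/sample_finetune.py).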

In [None]:
import sys

import datasets
from datasets import load_dataset
from peft import LoraConfig
import torch
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
# Keyword arguments for transformers.TrainingArguments.
training_config = dict(
    # precision / evaluation
    bf16=True,
    do_eval=False,
    # optimiser and schedule
    learning_rate=5e-05,
    lr_scheduler_type="cosine",
    warmup_ratio=0.2,
    # logging
    log_level="info",
    logging_steps=20,
    logging_strategy="steps",
    # run length
    num_train_epochs=1,
    max_steps=-1,
    # checkpointing: output_dir is also where the fine-tuned weights are loaded from
    output_dir="./finetuned_phi3",
    overwrite_output_dir=True,
    # batching
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    remove_unused_columns=True,
    save_steps=100,
    save_total_limit=1,
    seed=0,
    # memory saving during training
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_accumulation_steps=1,
)

# Keyword arguments for peft.LoraConfig.
# (alternative: target_modules="all-linear")
peft_config = dict(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    modules_to_save=None,
)

train_conf = TrainingArguments(**training_config)
peft_conf = LoraConfig(**peft_config)

In [None]:
import logging

###############
# Setup logging
###############
logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Use the log level carried by the TrainingArguments for this logger as well as
# for the `datasets` and `transformers` libraries, so they share one verbosity.
log_level = train_conf.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log a small summary of the process/device setup. training_config enables bf16
# rather than fp16, so report either flag; reporting fp16 alone would always
# show False for this configuration.
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    train_conf.local_rank,
    train_conf.device,
    train_conf.n_gpu,
    train_conf.local_rank != -1,
    bool(train_conf.fp16 or train_conf.bf16),
)
# Lazy %-style arguments: the (large) config reprs are only built if emitted.
logger.info("Training/evaluation parameters %s", train_conf)
logger.info("PEFT parameters %s", peft_conf)

In [None]:
########
# Get finetuned model/tokenizer
########

# 4-bit quantization settings. Passing `load_in_4bit` / `bnb_4bit_compute_dtype`
# directly to `from_pretrained` is deprecated in transformers; the supported
# form is a BitsAndBytesConfig passed via `quantization_config`.
# NOTE(review): compute dtype float16 is kept from the original even though
# torch_dtype is bfloat16; consider bfloat16 for both to avoid dtype casts.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# load the fine-tuned LLM and tokenizer from the checkpoint directory
model_kwargs = dict(
    # use_cache=False is only needed while training with gradient checkpointing;
    # for inference the KV cache avoids recomputing attention at every new token.
    use_cache=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # loading the model with flash-attention support
    device_map="auto",
    quantization_config=bnb_config,
)
ft_model = AutoModelForCausalLM.from_pretrained(
    train_conf.output_dir,
    **model_kwargs,
)

ft_tokenizer = AutoTokenizer.from_pretrained(train_conf.output_dir)
ft_tokenizer.model_max_length = 3072
ft_tokenizer.pad_token = ft_tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
ft_tokenizer.pad_token_id = ft_tokenizer.convert_tokens_to_ids(ft_tokenizer.pad_token)
# Left padding keeps every prompt right-aligned, so generated tokens for all
# rows in a batch start at the same column (required for correct generation).
ft_tokenizer.padding_side = 'left'


In [None]:
from okareo.model_under_test import CustomBatchModel, ModelInvocation

# Summarization instruction that is placed in front of every article in the
# user message. Read as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open('prompts/finetune_summarization.txt', "r", encoding="utf-8") as f:
    SHORT_SYSTEM_PROMPT_TEMPLATE = f.read()

class Phi3SummaryModel(CustomBatchModel):
    """Okareo batch model that summarizes articles with the fine-tuned Phi-3.5 model.

    Each input article is prefixed with SHORT_SYSTEM_PROMPT_TEMPLATE, rendered
    through the tokenizer's chat template and decoded greedily. Uses the
    notebook-level ``ft_model`` and ``ft_tokenizer``.
    """

    def __init__(self, name, batch_size):
        super().__init__(name, batch_size)
        self.tokenizer = ft_tokenizer
        self.model = ft_model

    def _build_prompt(self, input_value):
        """Return the chat-formatted prompt string for a single article."""
        messages = [
            {'role': 'user', 'content': f"{SHORT_SYSTEM_PROMPT_TEMPLATE}\n\nArticle: {input_value}"},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def invoke_batch(self, input_batch):
        """Summarize a batch of scenario rows.

        Args:
            input_batch: list of dicts with 'input_value' (article text) and
                'id' (scenario data point id) keys.

        Returns:
            A list of dicts, one per input row, each with 'id' and
            'model_invocation' (a ModelInvocation whose prediction is only the
            newly generated text and whose metadata is the full decoded output).
        """
        input_values = [row['input_value'] for row in input_batch]
        scenario_ids = [row['id'] for row in input_batch]
        prompts = [self._build_prompt(value) for value in input_values]

        with torch.no_grad():
            input_tokens = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
            )
            # Place inputs wherever the model's first layers live instead of
            # assuming cuda:0 (the model is loaded with device_map="auto").
            device = self.model.device
            input_ids = input_tokens.input_ids.to(device)
            attention_mask = input_tokens.attention_mask.to(device)
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=3072,
                do_sample=False,
            )
            output_ids = outputs.detach().cpu()
            # All prompts in the batch are padded to the same length, so every
            # generated continuation starts at this column. Slicing token ids
            # is exact, unlike trimming the decoded string by the prompt's
            # character length (chat-template tokens and detokenization cleanup
            # make the decoded prompt differ from the raw message).
            prompt_len = input_ids.shape[1]
            decoded_batch = self.tokenizer.batch_decode(
                output_ids.numpy(),
                skip_special_tokens=True,
            )
            generated_batch = self.tokenizer.batch_decode(
                output_ids[:, prompt_len:].numpy(),
                skip_special_tokens=True,
            )

        # Release GPU memory between batches.
        del input_ids, attention_mask, outputs, output_ids, input_tokens
        torch.cuda.empty_cache()

        invocations = []
        for scenario_id, generated, decoded, input_value in zip(
            scenario_ids, generated_batch, decoded_batch, input_values
        ):
            invocations.append({
                'id': scenario_id,
                'model_invocation': ModelInvocation(
                    model_prediction=generated.strip(),
                    model_input=input_value,
                    model_output_metadata=decoded,
                )
            })
        return invocations

In [None]:
from time import time
from okareo_api_client.models.test_run_type import TestRunType

import os

# Passed to run_test as `api_key`; presumably used by the LLM-based checks
# (e.g. fluency_summary) - confirm against the Okareo docs. Prefer the
# environment variable and fall back to the placeholder otherwise.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "<YOUR_OPENAI_API_KEY>")

batch_size = 64
mut_name = "Phi-3.5-mini-instruct (Fine-tuned Summarization)"

print('--- evaluation on test split ---')
print('batch_size | eval_time (s) | app_link')

# The measured time covers both model registration and the test run.
start_time = time()
# Register the model to use in the test run
model_under_test = okareo.register_model(
    name=mut_name,
    model=[
        Phi3SummaryModel(
            name="Phi3SummaryModel(CustomBatchModel)",
            batch_size=batch_size
        )
    ],
    update=True
)

eval_name = "Summarization Run (test split)"
evaluation = model_under_test.run_test(
    name=eval_name,
    api_key=OPENAI_API_KEY,
    scenario=TEST_SCENARIO_ID,
    test_run_type=TestRunType.NL_GENERATION,
    checks=[
        "latency",
        "fluency_summary",
        "character_count",
        "word_count",
        "under_350_characters",
    ],
)
eval_time = time() - start_time
print(f"{batch_size} | {eval_time:3.2f} | {evaluation.app_link}")