In [1]:
# Automatically reload edited project modules (e.g. the `validation` package)
# before each cell executes, so changes to helper code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# --- Project helpers ---------------------------------------------------------
# NOTE(review): `get_squad_data_questions` imported here is redefined (and
# therefore shadowed) by the SQuAD prompt cell further down.
from validation.prompts import get_squad_data_questions
from validation.runner import run_validation
from validation.data import ModelInfo, RequestParams, save_to_jsonl

# --- Third-party -------------------------------------------------------------
from transformers import AutoTokenizer

# --- Endpoints ---------------------------------------------------------------
# NOTE(review): these two constants are not referenced by the visible cells,
# and the ModelInfo cell hard-codes the opposite mapping (full model on :19002,
# quantized model on :19001). Confirm which server hosts which model.
URL_FULL = "http://58.186.149.100:19001/"
URL_QUANTIZED = "http://58.186.149.100:19002/"

# --- Request / sampling parameters -------------------------------------------
MAX_TOKENS = 3000
TEMPERATURE = 0.99
SEED = 42
LOGPROBS = 4  # forwarded as RequestParams(top_logprobs=...) in the model cell

# NOTE(review): this loads a Llama-3 tokenizer while the configured models are
# Qwen2.5, and `tokenizer` is not used by any visible cell — confirm whether it
# is needed (and which tokenizer) before relying on it.
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")

In [3]:
from datasets import load_dataset
from typing import List


def get_squad_data_questions(splits=("train", "validation"), dataset=None) -> List[str]:
    """Build "Context/Question" prompt strings from the SQuAD dataset.

    NOTE: this definition shadows the ``get_squad_data_questions`` imported
    from ``validation.prompts`` in the setup cell; after this cell runs, the
    notebook uses this local version.

    Args:
        splits: Names of the dataset splits to read, in order. The default
            ("train", "validation") reproduces the original behaviour: all
            training prompts followed by all validation prompts.
        dataset: Optional pre-loaded mapping of split name -> columns (e.g. a
            ``datasets.DatasetDict``) exposing ``'question'`` and
            ``'context'`` columns. When ``None``, the "squad" dataset is
            loaded via ``load_dataset``.

    Returns:
        One prompt per (question, context) pair, in dataset order.
    """
    if dataset is None:
        dataset = load_dataset('squad', keep_in_memory=True)

    prompts: List[str] = []
    for split in splits:
        split_data = dataset[split]
        # Prompt format (including the trailing space after the question) is
        # kept exactly as before so generated prompts do not change.
        prompts.extend(
            f"Context: {context}\nQuestion: {question} "
            for question, context in zip(split_data['question'], split_data['context'])
        )
    return prompts

# Build the prompt list (SQuAD train + validation) with the function defined
# in this cell, which shadows the helper imported from validation.prompts.
prompts = get_squad_data_questions()

In [4]:
# NOTE(review): the URLs below are the reverse of URL_FULL (:19001) and
# URL_QUANTIZED (:19002) from the setup cell. These hard-coded values are the
# ones actually used, so they are left unchanged — confirm which endpoint
# serves which model and keep a single source of truth.
full_model_info = ModelInfo(
    name="Qwen/Qwen2.5-7B-Instruct",
    url="http://58.186.149.100:19002/",
    deploy_params=dict(GPU="1xA100", precision="fp8"),
)

quantized_model_info = ModelInfo(
    name="Qwen/Qwen2.5-7B-Instruct-AWQ",
    url="http://58.186.149.100:19001/",
    deploy_params=dict(GPU="1xA100", precision="int4"),
)

# Request settings shared by every call; values come from the setup cell.
request_params = RequestParams(
    seed=SEED,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    top_logprobs=LOGPROBS,
)

# Both roles currently reference the same ModelInfo object (full_model_info);
# quantized_model_info is defined but not selected in this run.
validation_model_info = full_model_info
inference_model_info = validation_model_info

In [None]:
import os

# Output JSONL file; results are written batch by batch.
DATA_PATH = 'squad-all_qwen25-7B_fp8_val-fp8.jsonl'

batch_size = 500  # prompts passed to each run_validation call
MAX_WORKERS = 50  # forwarded to run_validation as max_workers

# save_to_jsonl is called with append=True, so re-running this cell presumably
# adds duplicate records to an existing file. Warn loudly instead of silently
# mixing runs; delete/rename the file (or slice `prompts`) to start fresh.
if os.path.exists(DATA_PATH):
    print(f"WARNING: {DATA_PATH} already exists; new results will be appended to it.")

total_prompts = len(prompts)
for start_idx in range(0, total_prompts, batch_size):
    prompt_batch = prompts[start_idx:start_idx + batch_size]
    results_batch = run_validation(
        prompt_batch,
        inference_model=inference_model_info,
        validation_model=validation_model_info,
        request_params=request_params,
        max_workers=MAX_WORKERS,
    )
    save_to_jsonl(results_batch, DATA_PATH, append=True)
    # The last batch can be shorter than batch_size; count what was actually
    # sent so the progress never exceeds the total.
    print(f"Processed {start_idx + len(prompt_batch)} from {total_prompts}")