# Test Notebook

This notebook contains the testing routine. Each model is evaluated on the whole test split, and additionally per category on both the train and test splits, to identify areas for improvement.

In [None]:
# Colab environment setup.
# Unsloth and its companion packages are installed with --no-deps
# (NOTE(review): presumably so pip does not replace Colab's preinstalled torch — confirm);
# xformers and trl are pinned to specific versions.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf huggingface_hub hf_transfer
!pip install --no-deps unsloth
# NOTE(review): transformers and datasets are upgraded to the latest release, unpinned,
# which may drift away from the pinned trl==0.15.2 — pin versions for reproducible runs.
!pip install -U transformers
!pip install -U datasets

In [None]:
import os
from google.colab import userdata, drive

from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from huggingface_hub import login

import wandb
import torch
import gc

In [None]:
# Mount Google Drive. Only required because the parquet files below are read from Drive.
drive.mount(mountpoint='/content/drive')

In [None]:
# Copy the Colab secrets into the environment variables each client expects,
# then authenticate against Weights & Biases and the Hugging Face Hub.
for env_var, secret_name in (('WANDB_API_KEY', 'WB_TOKEN'), ('HF_TOKEN', 'HF_TOKEN')):
    os.environ[env_var] = userdata.get(secret_name)

wandb.login()
login(token = os.environ['HF_TOKEN'])

In [None]:
# Limit reserved-but-unallocated memory (fragmentation) in PyTorch's CUDA caching allocator.
# PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA allocator is initialised. `torch`
# and `unsloth` are imported earlier in this notebook, so if CUDA is already initialised
# at this point the setting is silently ignored. Set it before importing torch (or
# restart the runtime) for it to apply.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_initialized():
    import warnings
    warnings.warn(
        "CUDA was initialised before PYTORCH_CUDA_ALLOC_CONF was set; "
        "expandable_segments will not take effect until the variable is set "
        "before importing torch.",
        RuntimeWarning,
    )

In [None]:
# Short label -> Hugging Face Hub repository of every model under evaluation.
# The label is also used as the W&B run name in the evaluation loops.
model_names = dict(
    base            = 'Qwen/Qwen2.5-3B-Instruct',
    base_dpo        = 'nicomu99/prompt-dpo-Qwen2_5-3B-Instruct-lr1e-05-ws50-cosine-beta0_1',
    base_dpo_shift  = 'nicomu99/prompt-dpo-Qwen2_5-3B-Instruct-lr1e-05-ws50-cosine-beta0_1-dpol0_95-fixed',
    sft             = 'nicomu99/Qwen2.5-3B-persona-SFT',
    sft_dpo         = 'nicomu99/prompt-dpo-Qwen2_5-3B-persona-SFT-lr1e-05-ws50-cosine-beta0_1',
    sft_dpo_shift   = 'nicomu99/prompt-dpo-Qwen2_5-3B-persona-SFT-lr1e-05-ws50-cosine-beta0_1-dpol0_95-fixed',
)


## Load Data

In [None]:
# Both splits are parquet files on the mounted Google Drive.
DATA_DIR = '/content/drive/MyDrive/practical_course2/data'
dataset = load_dataset(
    'parquet',
    data_files = {
        split: f'{DATA_DIR}/agent_{split}.parquet'
        for split in ('train', 'test')
    },
)

def is_valid(example):
    """Return True if the example's first 'chosen' message has non-blank content.

    Args:
        example: A dataset row with a 'chosen' list of message dicts.

    Returns:
        False when 'chosen' is empty/None, or when its first message has missing,
        None, or whitespace-only 'content'; True otherwise. Malformed rows are
        filtered out instead of raising IndexError/AttributeError.
    """
    chosen = example['chosen']
    if not chosen:
        return False
    content = chosen[0].get('content')
    return bool(content and content.strip())

# Drop preference pairs whose chosen response is empty, in both splits.
for split in ('train', 'test'):
    dataset[split] = dataset[split].filter(is_valid)

## Evaluation Test Split

Evaluate each model on the whole test split and report the DPO evaluation metrics to Weights & Biases.

In [None]:
PROJECT_NAME    = 'pr2-test'
EVAL_BATCH_SIZE = 8

# Evaluation-only DPO configuration for the whole-test-split runs. The trainer
# reports its eval metrics to W&B; the batch size comes from EVAL_BATCH_SIZE.
training_args = DPOConfig(
    per_device_eval_batch_size  = EVAL_BATCH_SIZE,
    do_train    = False,
    do_eval     = True,
    report_to   = 'wandb',
    seed        = 42,
    data_seed   = 42,
    output_dir  = 'out',    # To mitigate warning
    run_name    = 'tmp'     # To mitigate warning
)

In [None]:
for model_name, model_repo in model_names.items():
    print(f'=== Evaluating {model_name} ===')
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_repo,
        max_seq_length  = 2048,
        dtype           = None,
        load_in_4bit    = False,
    )

    # One W&B run per model; the trainer reports into it via report_to='wandb'.
    wandb.init(
        project = PROJECT_NAME,
        name    = model_name
    )
    try:
        # NOTE(review): unlike the per-category loop, 'base' is not wrapped in a LoRA
        # adapter here, so DPOTrainer (ref_model=None) may set up its reference model
        # differently for it — confirm both evaluations are comparable.
        trainer = DPOTrainer(
            model               = model,
            ref_model           = None,
            beta                = 0.1,
            args                = training_args,
            # Just a dummy - not really needed
            train_dataset       = dataset['train'].select(range(50)),
            eval_dataset        = dataset['test'],
            processing_class    = tokenizer,
        )
        results = trainer.evaluate()
        print(results)
    finally:
        # Close the run even if evaluation fails, so the next model starts a fresh run.
        wandb.finish()

    # Release this model (still referenced by `model` and `trainer`) before the next
    # checkpoint is loaded, so the previous model is not kept on the GPU meanwhile.
    del model, tokenizer, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

## Evaluation: Per Category Loop

In [None]:
# Evaluation-only configuration for the per-category loop. The trainer's own
# reporting is switched off (report_to=[]); the per-category metrics are logged
# to W&B as tables by log_evaluation instead.
training_args = DPOConfig(
    output_dir                  = 'out',    # only set to silence a warning
    run_name                    = 'tmp',    # only set to silence a warning
    do_train                    = False,
    do_eval                     = True,
    per_device_eval_batch_size  = 8,
    report_to                   = [],
    seed                        = 42,
    data_seed                   = 42,
)

In [None]:
def log_evaluation(model_name, model_path, results):
    """Log per-category evaluation metrics to W&B, one table per split.

    Args:
        model_name: Short model label; used as the W&B run name.
        model_path: Model repository/path; stored in the run config.
        results: Mapping ``split_name -> category -> metrics dict`` (for example the
            dicts returned by ``trainer.evaluate()``).

    Each non-empty split is logged under ``'<split_name>/per_category'`` as a table
    with a ``Category`` column plus one column per metric. The metric columns are the
    union of metric keys over all categories, in first-seen order. A metric missing
    for a category is logged as ``None``, not a misleading ``0.0``. Rows are sorted
    by category so the table order is reproducible. Splits with no categories are
    skipped. The W&B run is always finished, even if logging fails.
    """
    run = wandb.init(
        project = 'dpo-test-per-category',
        name    = model_name,
        config  = {'model_path': model_path},
        reinit  = True
    )
    try:
        for split_name, per_category_results in results.items():
            if not per_category_results:
                # No categories means no rows and no metric keys to infer; skip the split.
                continue

            # dict.fromkeys de-duplicates while preserving first-seen order.
            metric_keys = list(dict.fromkeys(
                key
                for metrics in per_category_results.values()
                for key in metrics
            ))

            table = wandb.Table(columns=['Category'] + metric_keys)
            # key=str keeps sorting safe if a category value is None.
            for category in sorted(per_category_results, key=str):
                metrics = per_category_results[category]
                table.add_data(category, *(metrics.get(k) for k in metric_keys))

            wandb.log({f'{split_name}/per_category': table})
    finally:
        run.finish()

In [None]:
for model_name, model_repo in model_names.items():
    print(f'=== Evaluating {model_name} ===')
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_repo,
        max_seq_length  = 2048,
        dtype           = None,
        load_in_4bit    = False,
    )

    if model_name == 'base':
        # Only the base model is wrapped in a freshly initialised LoRA adapter.
        # NOTE(review): presumably so DPOTrainer (ref_model=None) treats it like the
        # fine-tuned checkpoints when building its reference model — confirm.
        model = FastLanguageModel.get_peft_model(
            model,
            r               = 16,
            target_modules  = ["q_proj", "k_proj", "v_proj", "o_proj",
                               "gate_proj", "up_proj", "down_proj",],
            lora_alpha      = 32,
            lora_dropout    = 0.01,
            bias            = 'none',
            random_state    = 42,
            use_gradient_checkpointing = False,
        )
    model.eval()

    # results[split_name][category] -> metrics dict returned by trainer.evaluate()
    results = {}
    # Bound up front so the cleanup below also works when no category was evaluated.
    trainer = None
    for split_name, split_data in dataset.items():
        print(f'=== Evaluating {split_name} ===')

        results[split_name] = {}
        # Sorted so categories are evaluated and logged in a reproducible order
        # (set iteration order is arbitrary); key=str tolerates a None category.
        for category in sorted(set(split_data['category']), key=str):
            # input_columns passes only the 'category' value to the predicate, so
            # the other columns are not materialised for every row.
            category_subset = split_data.filter(
                lambda value: value == category,
                input_columns = 'category',
            )

            trainer = DPOTrainer(
                model               = model,
                ref_model           = None,
                beta                = 0.1,
                args                = training_args,
                # Just a dummy - not really needed
                train_dataset       = dataset['train'].select(range(10)),
                eval_dataset        = category_subset,
                processing_class    = tokenizer,
            )

            results[split_name][category] = trainer.evaluate()

    log_evaluation(model_name, model_repo, results)

    # Release this model before the next checkpoint is loaded.
    del model, tokenizer, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

## Prompting

In [None]:
from unsloth import FastLanguageModel
from datasets import load_dataset

# Load a single checkpoint for qualitative inspection of its generations.
prompt_model_key = 'base_dpo_shift'
model_repo = model_names[prompt_model_key]
model, tokenizer = FastLanguageModel.from_pretrained(
    model_repo,
    load_in_4bit    = False,
    dtype           = None,
    max_seq_length  = 2048,
)
model.eval()

In [None]:
# NOTE(review): duplicates the "Load Data" cell, presumably so the prompting section
# can run on its own after a runtime restart — consider reusing that cell instead.
DATA_DIR = '/content/drive/MyDrive/practical_course2/data'
dataset = load_dataset(
    'parquet',
    data_files = {
        split: f'{DATA_DIR}/agent_{split}.parquet'
        for split in ('train', 'test')
    },
)

def is_valid(example):
    """Return True if the example's first 'chosen' message has non-blank content.

    Args:
        example: A dataset row with a 'chosen' list of message dicts.

    Returns:
        False when 'chosen' is empty/None, or when its first message has missing,
        None, or whitespace-only 'content'; True otherwise. Malformed rows are
        filtered out instead of raising IndexError/AttributeError.
    """
    chosen = example['chosen']
    if not chosen:
        return False
    content = chosen[0].get('content')
    return bool(content and content.strip())

# Drop preference pairs whose chosen response is empty, in both splits.
for split in ('train', 'test'):
    dataset[split] = dataset[split].filter(is_valid)

In [None]:
# Restrict the test split to a single category for qualitative inspection.
target_category = 'biology'
dataset_split = dataset['test']
filtered_data = dataset_split.filter(lambda row: row['category'] == target_category)

In [None]:
# Generate a completion for one prompt of the selected category and print it next to
# the preferred ('chosen') response.
SAMPLE_INDEX    = 10
MAX_NEW_TOKENS  = 128

sample = filtered_data[SAMPLE_INDEX]

# return_dict=True also returns the attention mask, so generate() does not have to
# infer it; inputs are moved to wherever the model lives instead of a fixed 'cuda'.
inputs = tokenizer.apply_chat_template(
    sample['prompt'],
    add_generation_prompt   = True,
    return_tensors          = 'pt',
    return_dict             = True,
).to(model.device)
input_length = inputs['input_ids'].shape[1]

with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

# generate() returns prompt + continuation; decode only the newly generated tokens
# so the printed completion does not repeat the prompt.
completion = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

print(sample['chosen'])
print(f'Completion: {completion}\n')
