In [14]:
import torch
import pickle
import os

from typing import Tuple
from tqdm.notebook import tqdm
from peft import AutoPeftModelForCausalLM
from transformers import (
    BitsAndBytesConfig,
    pipeline,
    AutoTokenizer
)
from datasets import load_dataset


from reward import get_reward

ADAPTERS = [{
        "dir" : "./open_llama_3b_v2_sft/",
        "moniker" : "pre"
    }, {
        "dir" : "./open_llama_3b_v2_sft_full/",
        "moniker" : "sft"
    }, {
        "dir" : "./open_llama_3b_v2_sft_plus_dpo/",
        "moniker" : "dpo",
    }, {
        "dir" : "./open_llama_3b_v2_sft_plus_dpo_r4",
        "moniker" : "dpo_r4"
}]

In [2]:
dataset_name = "samlhuillier/sql-create-context-spider-intersect"
response_template = "\n-- Answer:\n"
def format_prompt(example) -> Tuple[str, str]:
    return f"{example['context']} \n-- Question: {example['question']}{response_template}", example['answer']

In [3]:
dataset = load_dataset(dataset_name, split="validation").map(lambda example: {"query" : format_prompt(example)[0]})

In [4]:
dataset[0]

{'answer': 'SELECT count(*) FROM singer',
 'db_id': 'concert_singer',
 'context': 'CREATE TABLE singer (Id VARCHAR)',
 'question': 'How many singers do we have?',
 'query': 'CREATE TABLE singer (Id VARCHAR) \n-- Question: How many singers do we have?\n-- Answer:\n'}

In [5]:
# Model from Hugging Face hub
base_model = "openlm-research/open_llama_3b_v2"

In [6]:
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [7]:
def load_model(adapter_dir):
    # Load model
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )
    
    dpo_model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir,
        quantization_config=quant_config,
        trust_remote_code=True,
        is_trainable=True,
    )
    
    dpo_model.config.use_cache = False
    dpo_model.config.pretraining_tp = 1
    
    return dpo_model.merge_and_unload().eval()

In [8]:
def get_logger(quiet=True, outfile="eval_out.txt"):
    def log(txt):
        with open(outfile, "a", encoding="utf-8") as f:
            f.write(f"{txt}\n")
        if not quiet:
            print(txt)
    return log

In [9]:
def get_eval_rewards(model, quiet=True, sample_size=-1, logfile="eval_out.txt"):
    p = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=120)
    
    rewards = []
    if sample_size >= 0:
        inds = torch.randperm(len(dataset))[:sample_size]
        ds = dataset.select(inds)
    else:
        ds = dataset

    log = get_logger(quiet, logfile)
    for row in tqdm(ds):
        out = p(row["query"])
        response = out[0]["generated_text"]
        
        # remove our query, split by newlines
        response_lines = response.replace(row["query"], "").split("\n")
        model_submission = response_lines[0]
        rew = get_reward(row["db_id"], model_submission, row["answer"])
        rewards.append(rew)
        
        log(row["query"])
        log(model_submission)
        log("-------------------------------------------------")
        log(f"-- Got reward of {rew} against solution:")
        log(row['answer'])
        log("-------------------------------------------------")
    
    return rewards

In [10]:
for adapter in ADAPTERS:
    moniker = adapter['moniker']
    print("Evaluating", moniker, "...")
    rewards = get_eval_rewards(load_model(adapter["dir"]), 
                                quiet=True, sample_size=50, 
                                logfile=f"{moniker}_eval_out.txt")

    with open(f"{adapter['moniker']}_eval_returns.pkl", "wb") as f:
        pickle.dump(rewards, f)

Evaluating pre ...




  0%|          | 0/50 [00:00<?, ?it/s]



Evaluating sft ...


  0%|          | 0/50 [00:00<?, ?it/s]

Evaluating dpo ...


  0%|          | 0/50 [00:00<?, ?it/s]

Evaluating dpo_r4 ...


  0%|          | 0/50 [00:00<?, ?it/s]

In [11]:
def load_reward(adapter) -> list:
    with open(f"{adapter['moniker']}_eval_returns.pkl", "rb") as f:
        return pickle.load(f)

In [17]:
for adapter in ADAPTERS:
    rews = torch.tensor(load_reward(adapter))
    print(f"Adapter {adapter['moniker']} got reward", rews.mean(), "+-", rews.std())

Adapter pre got reward tensor(0.5911, dtype=torch.float64) +- tensor(0.2790, dtype=torch.float64)
Adapter sft got reward tensor(1.4085, dtype=torch.float64) +- tensor(0.4250, dtype=torch.float64)
Adapter dpo got reward tensor(1.4580, dtype=torch.float64) +- tensor(0.4037, dtype=torch.float64)
Adapter dpo_r4 got reward tensor(1.4815, dtype=torch.float64) +- tensor(0.4150, dtype=torch.float64)
