In [1]:
import torch

from typing import Tuple
from tqdm.notebook import tqdm
from peft import AutoPeftModelForCausalLM
from transformers import (
    BitsAndBytesConfig,
    pipeline,
    AutoTokenizer
)
from datasets import load_dataset


from reward import get_reward

# Saved PEFT adapter directories: the SFT-only adapter and the SFT-then-DPO adapter.
SFT_ADAPTER_DIRECTORY, DPO_ADAPTER_DIRECTORY = (
    "./open_llama_3b_v2_sft/",
    "./open_llama_3b_v2_sft_plus_dpo/",
)

2023-12-14 04:14:23.702718: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-14 04:14:23.752957: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Base checkpoint on the Hugging Face Hub; the tokenizer below is loaded from it.
base_model = "openlm-research/open_llama_3b_v2"

In [3]:
# Load the tokenizer that matches the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# Reuse EOS as the padding token (presumably the tokenizer ships without a pad token — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so that, for this decoder-only model, generated tokens follow the prompt directly.
tokenizer.padding_side = "left"

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
def load_model(adapter_dir):
    """Load a saved PEFT adapter on a 4-bit NF4 base model and merge it for inference.

    Used for both the SFT adapter and the SFT+DPO adapter.

    Args:
        adapter_dir: directory holding a saved PEFT adapter, e.g.
            SFT_ADAPTER_DIRECTORY or DPO_ADAPTER_DIRECTORY.

    Returns:
        The model with the adapter merged into its base weights, in eval mode.
    """
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )

    # This loader only serves evaluation, so load the adapter frozen rather than
    # marking its weights trainable.
    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir,
        quantization_config=quant_config,
        trust_remote_code=True,
        is_trainable=False,
    )

    # use_cache=False is a training-time setting (e.g. for gradient checkpointing);
    # keep the KV cache enabled since this model is used for generation.
    model.config.use_cache = True
    # pretraining_tp=1 selects the plain (unsliced) linear projections in Llama-style configs.
    model.config.pretraining_tp = 1

    # NOTE(review): merging into 4-bit quantized layers may add rounding error
    # relative to running the unmerged adapter — confirm this is acceptable.
    return model.merge_and_unload().eval()

In [5]:
# sft_model = load_model(SFT_ADAPTER_DIRECTORY)

In [6]:
dataset_name = "samlhuillier/sql-create-context-spider-intersect"
response_template = "\n-- Answer:\n"


def format_prompt(example) -> Tuple[str, str]:
    """Build the ``(prompt, reference_answer)`` pair for one dataset row.

    The prompt is the row's ``context`` (schema DDL), a space, a
    ``-- Question:`` line, and finally ``response_template``, after which the
    model is expected to write its SQL.
    """
    prompt = "{} \n-- Question: {}{}".format(
        example['context'],
        example['question'],
        response_template,
    )
    reference = example['answer']
    return prompt, reference

In [7]:
# Validation split, with a "query" column holding the formatted prompt for each row.
raw_validation = load_dataset(dataset_name, split="validation")
dataset = raw_validation.map(lambda row: {"query": format_prompt(row)[0]})

In [8]:
# Inspect the first row, including the added "query" prompt column.
dataset[0]

{'answer': 'SELECT count(*) FROM singer',
 'question': 'How many singers do we have?',
 'context': 'CREATE TABLE singer (Id VARCHAR)',
 'db_id': 'concert_singer',
 'query': 'CREATE TABLE singer (Id VARCHAR) \n-- Question: How many singers do we have?\n-- Answer:\n'}

In [9]:
EVAL_LOG_PATH = "eval_out.txt"


def log(txt, path=EVAL_LOG_PATH):
    """Echo ``txt`` to stdout and append it, newline-terminated, to a transcript file.

    The file is opened in append mode. Reopening it with "w+" on every call
    would truncate it each time, leaving only the last message on disk.

    Args:
        txt: message to record.
        path: transcript file to append to (defaults to EVAL_LOG_PATH).
    """
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"{txt}\n")
    print(txt)


def get_eval_rewards(model, max_new_tokens=120, eval_dataset=None, log_path=EVAL_LOG_PATH):
    """Generate a SQL answer for every eval row, score it, and log a transcript.

    Args:
        model: causal LM to evaluate (e.g. the result of ``load_model``).
        max_new_tokens: generation budget per query.
        eval_dataset: rows with "query", "db_id" and "answer" fields; defaults
            to the notebook-level ``dataset``.
        log_path: transcript file; it is emptied at the start of each run.

    Returns:
        List of rewards from ``get_reward``, one per row, in dataset order.
    """
    rows = dataset if eval_dataset is None else eval_dataset
    generator = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
    )

    # Start each run with an empty transcript; `log` appends to it.
    open(log_path, "w", encoding="utf-8").close()

    rewards = []
    for row in tqdm(rows):
        # return_full_text=False makes the pipeline return only the continuation.
        # Stripping the prompt with str.replace could silently fail if detokenization
        # does not reproduce the prompt exactly, and would also remove any repeat
        # of the prompt inside the completion.
        out = generator(row["query"], return_full_text=False)
        completion = out[0]["generated_text"]

        # The submission is the first line of the completion.
        model_submission = completion.split("\n", 1)[0]
        rew = get_reward(row["db_id"], model_submission, row["answer"])
        rewards.append(rew)

        log(row["query"], log_path)
        log(model_submission, log_path)
        log("-------------------------------------------------", log_path)
        log(f"-- Got reward of {rew} against solution:", log_path)
        log(row['answer'], log_path)
        log("-------------------------------------------------", log_path)

    return rewards

In [10]:
# Score the SFT+DPO adapter on the validation split, keeping the per-row rewards
# so they can be summarised after the run instead of being discarded.
dpo_rewards = get_eval_rewards(load_model(DPO_ADAPTER_DIRECTORY))



  0%|          | 0/568 [00:00<?, ?it/s]



CREATE TABLE singer (Id VARCHAR) 
-- Question: How many singers do we have?
-- Answer:

SELECT COUNT(*) FROM singer;
-------------------------------------------------
-- Got reward of 2.0 against solution:
SELECT count(*) FROM singer
-------------------------------------------------
CREATE TABLE singer (name VARCHAR, country VARCHAR, age VARCHAR) 
-- Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
-- Answer:

SELECT name, country, age FROM singer ORDER BY age DESC
-------------------------------------------------
-- Got reward of 2.0 against solution:
SELECT name ,  country ,  age FROM singer ORDER BY age DESC
-------------------------------------------------
CREATE TABLE singer (age INTEGER, country VARCHAR) 
-- Question: What is the average, minimum, and maximum age of all singers from France?
-- Answer:

SELECT AVG(age) FROM singer WHERE country='France'
-------------------------------------------------
-- Got reward of 1.00000999990



CREATE TABLE stadium (name VARCHAR, capacity VARCHAR, average VARCHAR) 
-- Question: What is the name and capacity for the stadium with highest average attendance?
-- Answer:

-- SELECT name, capacity, AVG(attendance) FROM stadium GROUP BY name, capacity ORDER BY AVG(attendance) DESC LIMIT 1
-------------------------------------------------
-- Got reward of 1.5000049999500005 against solution:
SELECT name ,  capacity FROM stadium ORDER BY average DESC LIMIT 1
-------------------------------------------------
CREATE TABLE concert (YEAR VARCHAR) 
-- Question: How many concerts are there in year 2014 or 2015?
-- Answer:

SELECT COUNT(*) FROM concert WHERE YEAR = '2014' OR YEAR = '2015'
-------------------------------------------------
-- Got reward of 2.0 against solution:
SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
-------------------------------------------------
CREATE TABLE stadium (name VARCHAR, stadium_id VARCHAR); CREATE TABLE concert (stadium_id VARCHAR) 
-- 

KeyboardInterrupt: 

In [None]:
# Show the formatted prompt of the second validation row.
print(dataset[1]["query"])

In [None]:
# `p` only exists inside get_eval_rewards, so referencing it here raises NameError
# on a fresh kernel. Build a dedicated pipeline for spot-checking single prompts.
# NOTE(review): this loads the DPO model again; free earlier models first if GPU memory is tight.
spot_check_pipe = pipeline(
    task="text-generation",
    model=load_model(DPO_ADAPTER_DIRECTORY),
    tokenizer=tokenizer,
    max_new_tokens=120,
)
print(spot_check_pipe(dataset[3]["query"])[0])

In [None]:
# Longest reference answer measured in characters (not tokens); presumably a rough
# sanity check against the generation budget — TODO confirm intent.
max(len(answer) for answer in dataset["answer"])