In [1]:
import random
import torch

# Index of the saved DPO checkpoint ('model/dpo_<checkpoint>.model') to load;
# -1 means start from the base 'model/gpt2-large' weights instead.
checkpoint = 110
# Device every model and every tokenized batch is moved to.
device = 'cuda'
# Reduced precision the policy model is temporarily cast to while it generates
# the `rejected` completions in remake_dataset().
dtype = torch.float16
# Evaluation-only switch: when True the train split is cut to 200 rows, the
# reference model is not loaded and the training loop at the end is skipped.
only_test = True

from transformers import AutoTokenizer

# Load the GPT-2 tokenizer from the local 'tokenizer/gpt2-large' directory.
tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2-large')
# Batched tokenization with padding=True needs a pad token; token id 0 is used
# here, which the repr printed by this cell shows as '!'.
# NOTE(review): id 0 is an ordinary vocabulary entry, so a literal '!' in text
# cannot be told apart from padding — confirm this is acceptable for SQL output.
tokenizer.pad_token_id = 0

# Display the tokenizer configuration as the cell output.
tokenizer

  from .autonotebook import tqdm as notebook_tqdm
Using sep_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.


GPT2TokenizerFast(name_or_path='tokenizer/gpt2-large', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [2]:
from datasets import load_from_disk

# Load the locally saved copy of the dataset and keep only its 'train' split;
# the rows are read below through their 'context', 'question' and 'answer' fields.
dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']


def f(data):
    """Fold a row's schema context and question into a single prompt string.

    Args:
        data: one dataset row with 'context', 'question' and 'answer' entries.

    Returns:
        A dict whose 'question' is the assembled prompt (ending in 'answer:' so
        the answer can be appended directly after it) and whose 'answer' is the
        row's answer, unchanged.
    """
    template = 'context:%s question:%s answer:'
    prompt = template % (data['context'], data['question'])
    return {'question': prompt, 'answer': data['answer']}


# Apply the prompt builder to every row; the raw 'context' column is dropped
# because its text is now embedded in the rewritten 'question' field.
dataset = dataset.map(f, remove_columns=['context'])


def f(data):
    """Keep only rows whose prompt and answer token counts fall inside fixed windows.

    Args:
        data: one dataset row with 'question' (the assembled prompt) and 'answer'.

    Returns:
        True when the prompt encodes to 25-65 tokens and the answer to 10-35
        tokens (both bounds inclusive), otherwise False.
    """
    prompt_tokens = len(tokenizer.encode(data['question']))
    answer_tokens = len(tokenizer.encode(data['answer']))

    if prompt_tokens < 25 or prompt_tokens > 65:
        return False
    return not (answer_tokens < 10 or answer_tokens > 35)


# Drop every row rejected by the token-length filter defined just above.
dataset = dataset.filter(f)


def f(data):
    """Convert a row into the prompt / chosen / rejected preference layout.

    Args:
        data: one dataset row with 'question' (the prompt) and 'answer'.

    Returns:
        A dict with 'prompt' and 'chosen' copied from the row and 'rejected'
        set to the empty string as a placeholder.
    """
    prompt = data['question']
    reference_answer = data['answer']
    return {'prompt': prompt, 'chosen': reference_answer, 'rejected': ''}


# Switch to the prompt / chosen / rejected schema and drop the source columns.
dataset = dataset.map(f, remove_columns=['question', 'answer'])
# Hold out 1500 rows for evaluation. The seed is fixed so that every kernel
# restart produces the same split: with an unseeded shuffle, a checkpoint
# trained in an earlier session could be evaluated on rows it was trained on.
dataset = dataset.train_test_split(test_size=1500, seed=0)

# In evaluation-only mode the train split is not used for training, so keep
# just 200 rows to make regenerating the `rejected` column cheap.
if only_test:
    dataset['train'] = dataset['train'].select(range(200))

# Display the split sizes and one example row as the cell output.
dataset, dataset['train'][0]

(DatasetDict({
     train: Dataset({
         features: ['prompt', 'chosen', 'rejected'],
         num_rows: 200
     })
     test: Dataset({
         features: ['prompt', 'chosen', 'rejected'],
         num_rows: 1500
     })
 }),
 {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',
  'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = "Beats International"',
  'rejected': ''})

In [3]:
from transformers import AutoModelForCausalLM

# Resume from the configured DPO checkpoint; checkpoint == -1 selects the base
# GPT-2 weights instead.
path = ('model/gpt2-large'
        if checkpoint == -1 else 'model/dpo_%d.model' % checkpoint)

# Policy model: the one that generates completions and gets trained.
model_dpo = AutoModelForCausalLM.from_pretrained(path).to(device)
# The reference model is only passed to the DPO trainer, so it is not loaded
# in evaluation-only mode.
if not only_test:
    model_dpo_ref = AutoModelForCausalLM.from_pretrained(path).to(device)

# Display which weights were loaded as the cell output.
path

'model/dpo_110.model'

In [4]:
# Regenerate the `rejected` field of the dataset
def remake_dataset():
    """Overwrite `rejected` in every split with the policy model's own answers.

    Each prompt is completed by `model_dpo` (at most 35 new tokens). The
    completion is cut after the first end-of-sequence token, decoded, and
    stored as `rejected`. When the decoded text equals `chosen` exactly, the
    empty string is stored instead.

    Side effects:
        Rebinds the global `dataset`. While generating, the tokenizer pads on
        the left and `model_dpo` is cast to the global `dtype`; both are
        restored (right padding, float32) even if generation raises.
    """
    global dataset

    max_new_tokens = 35

    def f(data):
        # `data` is a batch: every field is a list with one entry per row.
        token = tokenizer(data['prompt'],
                          return_tensors='pt',
                          padding=True,
                          truncation=True).to(device)

        # top_k=1 restricts sampling to the single most likely token, so
        # do_sample=True should be effectively greedy here — TODO confirm.
        out = model_dpo.generate(**token,
                                 min_length=-1,
                                 top_k=1,
                                 top_p=1.0,
                                 do_sample=True,
                                 pad_token_id=tokenizer.pad_token_id,
                                 max_new_tokens=max_new_tokens,
                                 eos_token_id=tokenizer.eos_token_id)

        # Prompts are left-padded to one common length, so the generated part
        # of every row starts at the same column.
        prompt_len = token['input_ids'].shape[1]
        completions = out[:, prompt_len:].tolist()

        rejected_column = []
        for ids, chosen in zip(completions, data['chosen']):
            # Keep everything up to and including the first EOS token; what
            # follows it is padding.
            if tokenizer.eos_token_id in ids:
                ids = ids[:ids.index(tokenizer.eos_token_id) + 1]

            text = tokenizer.decode(ids, skip_special_tokens=True)

            # An exact match is no use as a negative example; '' marks it.
            rejected_column.append('' if text == chosen else text)

        # Returning only the changed column leaves `prompt`/`chosen` untouched.
        return {'rejected': rejected_column}

    tokenizer.padding_side = 'left'
    # NOTE(review): casting the weights to half precision and back to float32
    # does not recover the discarded bits, so the policy model stays rounded
    # after this call — confirm that is acceptable before the next training
    # round, or generate under autocast instead.
    model_dpo.to(dtype)
    try:
        dataset = dataset.map(f, batched=True, batch_size=128, num_proc=1)
    finally:
        # Always undo the temporary generation settings so a failed run does
        # not leave the tokenizer left-padded or the model in half precision.
        tokenizer.padding_side = 'right'
        model_dpo.to(torch.float32)


# Regenerate `rejected` when running in evaluation-only mode or when a trained
# checkpoint was loaded; with the base model in training mode the placeholder
# empty strings are kept.
if only_test or checkpoint != -1:
    remake_dataset()

# Display the split sizes and one example row as the cell output.
dataset, dataset['train'][0]

Map: 100%|██████████| 200/200 [00:02<00:00, 68.13 examples/s]
Map: 100%|██████████| 1500/1500 [00:17<00:00, 87.49 examples/s]


(DatasetDict({
     train: Dataset({
         features: ['prompt', 'chosen', 'rejected'],
         num_rows: 200
     })
     test: Dataset({
         features: ['prompt', 'chosen', 'rejected'],
         num_rows: 1500
     })
 }),
 {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',
  'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = "Beats International"',
  'rejected': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = "Beatles International"'})

In [5]:
# Reload the models
def reload_model(epoch):
    """Checkpoint the policy model and reload policy and reference from it.

    Args:
        epoch: index used to build the checkpoint path 'model/dpo_<epoch>.model'.

    Side effects:
        Writes the current `model_dpo` to disk, then rebinds the globals
        `model_dpo` and `model_dpo_ref` to two freshly loaded copies of that
        checkpoint on `device`.
    """
    global model_dpo, model_dpo_ref

    checkpoint_dir = 'model/dpo_%d.model' % epoch
    model_dpo.save_pretrained(checkpoint_dir)

    def load_from_checkpoint():
        # Each call builds an independent model instance from the saved weights.
        return AutoModelForCausalLM.from_pretrained(checkpoint_dir).to(device)

    model_dpo = load_from_checkpoint()
    model_dpo_ref = load_from_checkpoint()


# reload_model(0)

In [6]:
from transformers import TrainingArguments, TrainerCallback
from trl import DPOTrainer


def retrain(sample_during_training=False):
    """Run one round of DPO training of `model_dpo` against `model_dpo_ref`.

    Trains on dataset['train'] for 5000 optimizer steps (batch size 4,
    learning rate 1e-5, beta 0.1) with evaluation, reporting and checkpoint
    saving all disabled.

    Args:
        sample_during_training: when True, every 250 steps a random test
            prompt is completed by the policy model and printed next to its
            `chosen` and `rejected` texts. Defaults to False, which only
            prints the step counter.
    """

    class MyCallback(TrainerCallback):
        """Progress reporter invoked by the trainer after each step."""

        def on_step_end(self, args, state, control, **kwargs):
            # Report only on every 250th step.
            if state.global_step % 250 != 0:
                return

            print(state.global_step)
            if not sample_during_training:
                return

            data = random.choice(dataset['test'])
            input_ids = tokenizer.encode(data['prompt'],
                                         return_tensors='pt').to(device)

            # Generate in eval mode so training-only layers such as dropout do
            # not disturb the preview, then put the model back as it was.
            was_training = model_dpo.training
            model_dpo.eval()
            try:
                out = model_dpo.generate(input_ids,
                                         min_length=-1,
                                         top_k=1,
                                         top_p=1.0,
                                         do_sample=True,
                                         pad_token_id=tokenizer.pad_token_id,
                                         max_new_tokens=35,
                                         eos_token_id=tokenizer.eos_token_id)
            finally:
                model_dpo.train(was_training)

            print(tokenizer.decode(out[0]))
            print('=================')
            print(data['chosen'])
            print(data['rejected'])
            print('=================')

    args = TrainingArguments(output_dir='output_dir',
                             learning_rate=1e-5,
                             per_device_train_batch_size=4,
                             max_steps=5000,
                             evaluation_strategy='no',
                             report_to='none',
                             save_strategy='no')

    # The three length limits of 100 cover the longest prompt (65 tokens) plus
    # the longest answer (35 tokens) allowed by the dataset length filter.
    trainer = DPOTrainer(model_dpo,
                         model_dpo_ref,
                         args=args,
                         beta=0.1,
                         train_dataset=dataset['train'],
                         tokenizer=tokenizer,
                         max_length=100,
                         max_target_length=100,
                         max_prompt_length=100,
                         callbacks=[MyCallback()])

    trainer.train()


# retrain()



In [7]:
def test():
    """Print a few test rows and score the regenerated `rejected` answers.

    Eight rows are drawn at random (with replacement) from dataset['test'] and
    printed field by field. A row counts as correct when its `rejected` text is
    the empty string or equals `chosen`.

    The empty string is what remake_dataset() stores when the generated answer
    matched `chosen` exactly; a completion that decodes to nothing is stored
    the same way and is therefore also counted as correct.

    Returns:
        A tuple (exact, normalized): the fraction of correct test rows under
        exact comparison, and under comparison after lower-casing both texts
        and replacing double quotes with single quotes.
    """
    test_split = dataset['test']

    for row in random.choices(test_split, k=8):
        for key, value in row.items():
            print(key, '->', value)
        print('=========')

    def correct(data, lower):
        prediction, reference = data['rejected'], data['chosen']

        # '' is the "matched exactly" sentinel.
        if prediction == '':
            return True

        if not lower:
            return reference == prediction

        def normalize(text):
            return text.lower().replace('"', '\'')

        return normalize(reference) == normalize(prediction)

    def accuracy(lower):
        hits = 0
        for row in test_split:
            hits += correct(row, lower)
        return hits / len(test_split)

    exact = accuracy(False)
    normalized = accuracy(True)
    return exact, normalized


# Evaluate the loaded model; the cell output is the pair
# (exact-match accuracy, case/quote-insensitive accuracy).
test()

prompt -> context:CREATE TABLE table_name_11 (division INTEGER, reg_season VARCHAR) question:Who was the lowest division in the 7th season? answer:
chosen -> SELECT MIN(division) FROM table_name_11 WHERE reg_season = "7th"
rejected -> 
prompt -> context:CREATE TABLE cinema (LOCATION VARCHAR) question:Show each location and the number of cinemas there. answer:
chosen -> SELECT LOCATION, COUNT(*) FROM cinema GROUP BY LOCATION
rejected -> 
prompt -> context:CREATE TABLE Sections (section_name VARCHAR, section_description VARCHAR) question:What are the names and descriptions of all the sections? answer:
chosen -> SELECT section_name, section_description FROM Sections
rejected -> 
prompt -> context:CREATE TABLE table_name_41 (venue VARCHAR, date VARCHAR) question:What is the venue of the game that was played on 23 October 1966? answer:
chosen -> SELECT venue FROM table_name_41 WHERE date = "23 october 1966"
rejected -> 
prompt -> context:CREATE TABLE table_26077092_7 (pick__number INTEGER, 

(0.848, 0.8573333333333333)

In [8]:
# Self-play loop: train on the current preference pairs, checkpoint and reload
# the models, regenerate `rejected` with the new policy, then report accuracy.
if not only_test:
    # Exclusive upper bound on the epoch index; training resumes right after
    # the loaded checkpoint.
    final_epoch = 100
    first_epoch = checkpoint + 1

    # Make an empty range visible instead of silently doing nothing, which is
    # what happens when `checkpoint` is already at or beyond the bound.
    if first_epoch >= final_epoch:
        print('no epochs to run: checkpoint', checkpoint,
              'is at or past the epoch limit', final_epoch)

    for epoch in range(first_epoch, final_epoch):
        retrain()
        reload_model(epoch)
        remake_dataset()

        print('epoch', epoch, 'test:', test())