In [1]:
import torch
import random
import os

# Route Hugging Face Hub downloads through a mirror.
# NOTE(review): huggingface_hub may read HF_ENDPOINT when it is first
# imported, so this assignment is deliberately placed above the
# `transformers` import below — keep it there.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# Device string used later when moving tokenized prompts to the model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoTokenizer

# GPT-2 BPE tokenizer downloaded from the Hub (via the mirror above).
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Assign a pad token id, presumably because the stock GPT-2 tokenizer defines
# none. As the displayed repr shows, id 0 is the ordinary token '!'.
# NOTE(review): this is only safe if every padded position is masked out
# (attention mask / ignored labels) wherever this tokenizer pads; using
# tokenizer.eos_token_id is the more common choice for GPT-2 — confirm.
tokenizer.pad_token_id = 0

tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [2]:
from datasets import load_dataset

# Text-to-SQL examples: each row has a CREATE TABLE `context`, a natural
# language `question` and the reference SQL `answer`. Only 'train' is used.
dataset = load_dataset('b-mc2/sql-create-context', split='train')


def f(data):
    """Map one sql-create-context row to a DPO record.

    The prompt has the form
    ``'context: <context> question: <question> answer: '`` (note the
    trailing space). ``chosen`` is the reference SQL taken from
    ``data['answer']``, and ``rejected`` is always the empty string.

    NOTE(review): because ``rejected`` is always empty, DPO has no real
    negative completion to contrast against. Confirm this is intended.
    NOTE(review): the trailing space ends the prompt on a standalone space
    token for GPT-2 BPE. Check that this matches how the trainer and the
    sampling code tokenize the prompt.
    """
    parts = (
        'context: ', data['context'],
        ' question: ', data['question'],
        ' answer: ',
    )
    return dict(prompt=''.join(parts), chosen=data['answer'], rejected='')


# Convert every row with `f` and drop the raw columns, leaving only the
# 'prompt' / 'chosen' / 'rejected' columns passed to DPOTrainer below.
dataset = dataset.map(f, remove_columns=dataset.column_names)

# Show the resulting schema/size and the first converted example.
dataset, dataset[0]

(Dataset({
     features: ['prompt', 'chosen', 'rejected'],
     num_rows: 78577
 }),
 {'prompt': 'context: CREATE TABLE head (age INTEGER) question: How many heads of the departments are older than 56 ? answer: ',
  'chosen': 'SELECT COUNT(*) FROM head WHERE age > 56',
  'rejected': ''})

In [3]:
from transformers import AutoModelForCausalLM

# Policy to optimize (`model_actor`) and the reference policy for the DPO
# loss (`model_actor_ref`). Both are loaded from the same local checkpoint.
# Place them on `device` (chosen in the first cell) instead of a hard-coded
# 'cuda', so this cell no longer fails on CPU-only machines.
model_actor = AutoModelForCausalLM.from_pretrained('model/actor').to(device)
model_actor_ref = AutoModelForCausalLM.from_pretrained('model/actor').to(device)

In [4]:
from transformers import TrainerCallback
from trl import DPOConfig, DPOTrainer


class MyCallback(TrainerCallback):
    """Periodically print a sampled completion to eyeball training progress.

    Every ``every_n_steps`` global steps, a random training example is drawn.
    Its prompt is sampled from the policy, and the step number, the
    reference answer and the decoded generation (prompt included) are
    printed.

    Args:
        every_n_steps: sampling interval in global steps. The default of 1000
            matches the previously hard-coded value. Trainer instantiates a
            callback class passed uninstantiated with no arguments, so
            ``callbacks=[MyCallback]`` keeps working.
        max_new_tokens: number of tokens to sample after the prompt.
            Equivalent to the former ``max_length=prompt_len + 50``.
    """

    def __init__(self, every_n_steps=1000, max_new_tokens=50):
        super().__init__()
        self.every_n_steps = every_n_steps
        self.max_new_tokens = max_new_tokens

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.every_n_steps != 0:
            return

        print(state.global_step)

        # Prefer the model the Trainer passes to callbacks. Fall back to the
        # notebook-level policy that was used originally.
        model = kwargs.get('model')
        if model is None:
            model = model_actor

        data = random.choice(dataset)
        # NOTE(review): the prompt ends with a trailing space ('answer: ').
        # Depending on the TRL version, DPOTrainer may tokenize prompt+answer
        # jointly and drop that lone space token from the training prompt, so
        # this sampling prefix may differ from the training-time one — confirm.
        prompt = tokenizer(data['prompt'],
                           return_tensors='pt').input_ids.to(device)

        # The model is presumably in training mode inside the training loop.
        # Switch to eval so any dropout the model has does not perturb the
        # samples, and restore the previous mode even if generation raises.
        was_training = model.training
        model.eval()
        try:
            gen = model.generate(input_ids=prompt,
                                 max_new_tokens=self.max_new_tokens,
                                 pad_token_id=tokenizer.pad_token_id,
                                 eos_token_id=tokenizer.eos_token_id,
                                 top_k=0,  # int 0 disables top-k filtering
                                 top_p=1.0,
                                 do_sample=True)
        finally:
            model.train(was_training)

        print(data['chosen'])
        print(tokenizer.decode(gen[0]))


# Sequence budgets. max_prompt_length must be strictly smaller than
# max_length. When prompt + answer exceeds max_length, TRL's DPOTrainer first
# truncates the prompt to max_prompt_length. If the pair is still too long,
# it cuts the answers to max_length - max_prompt_length tokens.
# The previous 100 / 100 setting left a zero-token answer budget, silently
# emptying the SQL answer of every over-long example.
# NOTE(review): truncation details vary across TRL versions — confirm against
# the installed one.
# Here 100 prompt tokens are kept and 60 tokens are reserved for the answer.
MAX_PROMPT_LENGTH = 100
MAX_ANSWER_LENGTH = 60

args = DPOConfig(output_dir='output_dir',
                 loss_type='sigmoid',
                 beta=0.1,
                 per_device_train_batch_size=8,
                 max_steps=80000,
                 learning_rate=1e-5,
                 optim='rmsprop',
                 max_length=MAX_PROMPT_LENGTH + MAX_ANSWER_LENGTH,
                 max_prompt_length=MAX_PROMPT_LENGTH,
                 # Only consulted for encoder-decoder models; GPT-2 ignores it.
                 max_target_length=100,
                 eval_strategy='no',
                 report_to='none',
                 save_strategy='no',
                 remove_unused_columns=False)

trainer = DPOTrainer(model=model_actor,
                     ref_model=model_actor_ref,
                     args=args,
                     train_dataset=dataset,
                     tokenizer=tokenizer,
                     callbacks=[MyCallback])

trainer.train()

model_actor.save_pretrained('model/trl')
# Save the tokenizer as well (including the custom pad_token_id set in the
# first cell) so 'model/trl' can be reloaded on its own.
tokenizer.save_pretrained('model/trl')



Tokenizing train dataset:   0%|          | 0/78577 [00:00<?, ? examples/s]

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
500,0.0476
1000,0.045
1500,0.0439
2000,0.0401
2500,0.0432
3000,0.0442
3500,0.0477
4000,0.0464
4500,0.0422
5000,0.0429


1000
SELECT pick FROM table_name_15 WHERE school = "lamar high school"
context: CREATE TABLE table_name_15 (pick VARCHAR, school VARCHAR) question: Which pick was from Lamar High School? answer:  minimal_pick FROM table_name_15 WHERE school = "lamar high school"<|endoftext|>
2000
SELECT T1.amenity_name FROM dorm_amenity AS T1 JOIN has_amenity AS T2 ON T1.amenid = T2.amenid GROUP BY T2.amenid ORDER BY COUNT(*) DESC LIMIT 1
context: CREATE TABLE dorm_amenity (amenity_name VARCHAR, amenid VARCHAR); CREATE TABLE has_amenity (amenid VARCHAR) question: Find the name of amenity that is most common in all dorms. answer:  minimal_population FROM dorm_amenity AS tIN t3 JOIN has_amenity AS t2 ON t3.amenid = t2.amenid GROUP BY t2.amenid HAVING COUNT(*) > 1
3000
SELECT COUNT(high_rebounds) FROM table_23248940_10 WHERE record = "34-32"
context: CREATE TABLE table_23248940_10 (high_rebounds VARCHAR, record VARCHAR) question: How many people had high rebounds during the game with a record of 34-32? an

25000
SELECT loss FROM table_name_23 WHERE date = "june 13"
context: CREATE TABLE table_name_23 (loss VARCHAR, date VARCHAR) question: Which loss was on June 13? answer:  minimal_loss FROM table_name_23 WHERE date = "june 13"<|endoftext|>
26000
SELECT may_2009 FROM table_23680576_2 WHERE jul_2009 = "7.2%"
context: CREATE TABLE table_23680576_2 (may_2009 VARCHAR, jul_2009 VARCHAR) question: If the polling average in July 2009 was 7.2%, what was it in May 2009? answer:  minimal_poll_average FROM table_23680576_2 WHERE jul_2009 = "7.2%"<|endoftext|>
27000
SELECT COUNT(*) FROM debate
context: CREATE TABLE debate (Id VARCHAR) question: How many debates are there? answer:  minimal_event_count FROM debate WHERE minimal_event_count = ''<|endoftext|>
28000
SELECT proto_semitic FROM table_26919_7 WHERE english = "house"
context: CREATE TABLE table_26919_7 (proto_semitic VARCHAR, english VARCHAR) question: If in English it's house, what is it in proto-semitic? answer:  minimal_semitic FROM table_

51000
SELECT genre FROM table_name_72 WHERE publisher = "atari" AND year = 1991 AND developer = "nufx"
context: CREATE TABLE table_name_72 (genre VARCHAR, developer VARCHAR, publisher VARCHAR, year VARCHAR) question: Which Genre was published by Atari and developed by NuFX in 1991? answer:  minimal_genre FROM table_name_72 WHERE publisher = "al Atari" AND year = 1991 AND developer = "nufx"<|endoftext|>
52000
SELECT 1880 FROM table_name_19 WHERE 1860 = "n/a" AND 1910 = 494
context: CREATE TABLE table_name_19 (Id VARCHAR) question: What is the 1880 figure when 1860 is N/A and 1910 is 494? answer:  minimal_inaudition FROM table_name_19 WHERE 1860 = "n/a" AND 1910 = "494"<|endoftext|>
53000
SELECT races FROM table_name_99 WHERE points = "10.5"
context: CREATE TABLE table_name_99 (races VARCHAR, points VARCHAR) question: How many races had 10.5 points? answer:  minimal_race_monotony FROM table_name_99 WHERE points = 10.5<|endoftext|>
54000
SELECT MAX(no) FROM table_1705429_1 WHERE constitue

77000
SELECT conference_joined FROM table_name_40 WHERE year_joined = 1954 AND mascot = "beavers"
context: CREATE TABLE table_name_40 (conference_joined VARCHAR, year_joined VARCHAR, mascot VARCHAR) question: Which conference joined in 1954 with a beavers mascot? answer:  minimal_conference_joined FROM table_name_40 WHERE year_joined = 1954 AND mascot = "beavers"<|endoftext|>
78000
SELECT year FROM table_13012165_1 WHERE maryland = "Railroaders LL Brunswick"
context: CREATE TABLE table_13012165_1 (year VARCHAR, maryland VARCHAR) question: Which year did Maryland hold the title with Railroaders LL Brunswick? answer:  minimal_year FROM table_13012165_1 WHERE maryland = "Railroaders LL Brunswick"<|endoftext|>
79000
SELECT home_team AS score FROM table_name_63 WHERE venue = "windy hill"
context: CREATE TABLE table_name_63 (home_team VARCHAR, venue VARCHAR) question: What was the Home Team Score for the Windy Hill Venue? answer:  minimal_home_team_score FROM table_name_63 WHERE venue = "win