In [1]:
import torch
import random
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = 0
tokenizer.padding_side = 'left'

tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [2]:
from datasets import load_dataset

dataset = load_dataset('b-mc2/sql-create-context', split='train')

dataset, dataset[0]

(Dataset({
     features: ['question', 'context', 'answer'],
     num_rows: 78577
 }),
 {'question': 'How many heads of the departments are older than 56 ?',
  'context': 'CREATE TABLE head (age INTEGER)',
  'answer': 'SELECT COUNT(*) FROM head WHERE age > 56'})

In [3]:
from transformers import AutoModelForCausalLM

model_actor = AutoModelForCausalLM.from_pretrained('model/dpo').to(device)

In [4]:
data = random.choices(dataset, k=12)

question = [
    'context: ' + i['context'] + ' question: ' + i['question'] + ' answer: '
    for i in data
]

answer = [i['answer'] for i in data]

question = tokenizer(question,
                     padding=True,
                     truncation=True,
                     max_length=512,
                     return_tensors='pt').input_ids.to(device)

gen = model_actor.generate(input_ids=question,
                           min_length=-1,
                           max_length=512,
                           pad_token_id=tokenizer.pad_token_id,
                           eos_token_id=tokenizer.eos_token_id,
                           top_k=0.0,
                           top_p=1.0,
                           do_sample=True)

gen = gen[:, question.shape[1]:]

for a, g in zip(answer, gen):
    print(a)
    print(tokenizer.decode(g, skip_special_tokens=True))
    print('==============')

SELECT COUNT(engine) FROM table_20866024_4 WHERE model_designation = "97H00"
SELECT COUNT(engine) FROM table_20866024_4 WHERE model_designation = "97H00"!!!!!!!
SELECT surface FROM table_name_3 WHERE opponent = "richard fromberg"
SELECT surface FROM table_name_3 WHERE opponent = "richard fromberg"!!!!!!!!!!!!!!!!
SELECT home___away FROM table_name_23 WHERE record = "6-12"
SELECT home___away FROM table_name_23 WHERE record = "6-12"!!!!!!!!!!!!!!!
SELECT result FROM table_name_78 WHERE date = "november 24, 1946"
SELECT result FROM table_name_78 WHERE date = "november 24, 1946"!!!!!!!!!!!!!!!
SELECT SUM(silver) FROM table_name_82 WHERE bronze > 2
SELECT SUM(silver) FROM table_name_82 WHERE bronze > 2!!!!!!!!!!!!!!!!!!
SELECT MAX(winners) FROM table_12303563_1 WHERE nation = "BEC Tero Sasana"
SELECT MIN(winners) FROM table_12303563_1 WHERE nation = "Belg Insanity"!!!!!!!!
SELECT MIN(_number) FROM table_2182170_1 WHERE driver_s_ = "Geoffrey Bodine"
SELECT MAX(_number) FROM table_2182170_1 W