In [1]:
from transformers import AutoTokenizer
import random
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2')
tokenizer.pad_token_id = 0

tokenizer

  from .autonotebook import tqdm as notebook_tqdm
Using sep_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.


GPT2TokenizerFast(name_or_path='tokenizer/gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [2]:
from datasets import load_from_disk

dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']


def f(data):
    question = 'context:%s question:%s answer:' % (data['context'],
                                                   data['question'])
    answer = data['answer']

    question = tokenizer.encode(question, add_special_tokens=False)
    answer = tokenizer.encode(answer, add_special_tokens=False)

    return {'question': question, 'answer': answer}


dataset = dataset.map(f, remove_columns=['context'])


def f(data):
    question = len(data['question'])
    answer = len(data['answer'])
    return 25 <= question <= 65 and 10 <= answer <= 35


dataset = dataset.filter(f)

dataset = dataset.train_test_split(test_size=200)

dataset, dataset['train'][0]

(DatasetDict({
     train: Dataset({
         features: ['question', 'answer'],
         num_rows: 71745
     })
     test: Dataset({
         features: ['question', 'answer'],
         num_rows: 200
     })
 }),
 {'question': [22866,
   25,
   43387,
   6158,
   43679,
   3084,
   62,
   3672,
   62,
   2414,
   357,
   354,
   20297,
   569,
   31315,
   1503,
   11,
   614,
   17828,
   7156,
   1137,
   8,
   1808,
   25,
   2061,
   318,
   262,
   24587,
   706,
   10249,
   30,
   3280,
   25],
  'answer': [46506,
   24587,
   16034,
   3084,
   62,
   3672,
   62,
   2414,
   33411,
   614,
   1875,
   10249]})

In [3]:
def get_batch_data():

    def pad(data, split, lens):
        #做个白板
        input_ids = torch.full((len(data), lens),
                               tokenizer.pad_token_id,
                               device=device)

        #往白板里黏贴数据
        for i, d in enumerate(data):
            input_ids[i, :len(d)] = torch.LongTensor(d)

        attention_mask = (input_ids != tokenizer.pad_token_id).long()

        #计算label
        label = input_ids.clone()
        for l, s in zip(label, split):
            #问题和pad的位置是-100
            l[:s] = -100
            l[l == tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }

    sample = random.choices(dataset['train'], k=16)
    question = [i['question'] for i in sample]
    answer = [i['answer'] for i in sample]
    split = [len(i) for i in question]

    #正确的问答
    choice = [
        q + a + [tokenizer.eos_token_id] for q, a in zip(question, answer)
    ]

    #错误的回答简单地定义为空回答就可以了
    reject = [q + [tokenizer.eos_token_id] for q, a in zip(question, answer)]

    #求最大长度
    lens = max([len(i) for i in choice])

    return pad(choice, split, lens), pad(reject, split, lens)


get_batch_data()

({'input_ids': tensor([[22866,    25, 43387,  ...,     0,     0,     0],
          [22866,    25, 43387,  ...,     0,     0,     0],
          [22866,    25, 43387,  ...,     0,     0,     0],
          ...,
          [22866,    25, 43387,  ...,     0,     0,     0],
          [22866,    25, 43387,  ...,     0,     0,     0],
          [22866,    25, 43387,  ...,     0,     0,     0]], device='cuda:0'),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
  'label': tensor([[-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ..., -100, -100, -100],
          ...,
          [-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  .

In [4]:
class ModelDPO(torch.nn.Module):

    def __init__(self):
        super().__init__()
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained('model/gpt2')

        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        out = self.model.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state

        return self.model.lm_head(out)


model_dpo = ModelDPO()
model_dpo_ref = ModelDPO()

In [5]:
@torch.no_grad()
def generate(input_ids):
    lens = input_ids.shape[1]
    while True:
        out = model_dpo(input_ids=input_ids,
                        attention_mask=torch.ones_like(input_ids))
        topk = out[0, -1].topk(1)

        values = topk.values.softmax(0).tolist()
        indices = topk.indices.tolist()
        next_word = random.choices(indices, weights=values)

        next_word = torch.LongTensor(next_word).unsqueeze(0).to('cuda')
        input_ids = torch.cat([input_ids, next_word], dim=1)

        if input_ids.shape[1] - lens >= 35:
            break

        if input_ids[0, -1] == tokenizer.eos_token_id:
            break

    return input_ids


input_ids = dataset['test'][0]['question']
input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)

out = generate(input_ids)

tokenizer.decode(out[0])

'context:CREATE TABLE table_name_8 (pos VARCHAR, date_from VARCHAR) question:What is Pos., when Date From is "28 August 2008"? answer:: "What is, when is a table?"\n\nThe table_name_name__name_8 is a string that contains the name of the table name.'

In [6]:
def get_prob_log(model, choice, reject):
    b = choice['input_ids'].shape[0]

    #合并两部分输入,同时计算以提高效率
    #[b, 21]
    input_ids = torch.cat([choice['input_ids'], reject['input_ids']], dim=0)
    attention_mask = torch.cat(
        [choice['attention_mask'], reject['attention_mask']], dim=0)
    label = torch.cat([choice['label'], reject['label']], dim=0)

    #[b, 21, 28]
    out = model(input_ids=input_ids, attention_mask=attention_mask)

    #偏移以对齐
    #[b, 20]
    label = label[:, 1:]
    #[b, 20, 28]
    out = out[:, :-1]

    #取所有字的预测概率,因为要求联合概率,所以取对数
    out = (out.softmax(2) + 1e-8).log()

    #取预测到label的概率
    #索引不能是负数,所以这里把负数置0
    index = label.clone().unsqueeze(2)
    index[index == -100] = 0
    prob = out.gather(2, index=index).squeeze(2)

    #只取答案部分的loss,筛选后,所有答案的概率对数求和
    prob = (prob * (label != -100)).sum(1)

    #choice和reject的预测概率求差
    return prob[:b] - prob[b:]


get_prob_log(model_dpo, *get_batch_data())

tensor([-58.1640, -61.3286, -39.4775, -71.7320, -62.9812, -67.9002, -35.7873,
        -55.2316, -55.5689, -73.9519, -81.4958, -37.8672, -55.4771, -86.0505,
        -48.4158, -43.4941], device='cuda:0', grad_fn=<SubBackward0>)

In [7]:
optimizer = torch.optim.Adam(model_dpo.parameters(),
                             lr=1e-5,
                             betas=(0.9, 0.999),
                             eps=1e-8)

for i in range(2000):
    choice, reject = get_batch_data()

    #两个模型分别计算概率对数
    prob_log = get_prob_log(model_dpo, choice, reject)
    with torch.no_grad():
        prob_log_ref = get_prob_log(model_dpo_ref, choice, reject)

    #两份概率计算kl散度
    kl = -0.1 * (prob_log - prob_log_ref)

    #以kl散度计算loss
    loss = (kl.sigmoid() + 1e-8).log().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        data = random.choice(dataset['test'])
        input_ids = torch.LongTensor(data['question']).unsqueeze(0).to(device)

        out = generate(input_ids)

        print(i, tokenizer.decode(out[0]))
        print('=========')
        print(tokenizer.decode(data['answer']))
        print('=========')

0 context:CREATE TABLE table_name_50 (character_name VARCHAR, voice_actor__english_1998___pioneer_ VARCHAR) question:what character did Laara sadiq play answer:what character did Laara play answer:what character did Laara sadiq play answer:what character did Laara sadiq play answer:what character did Laara play answer
SELECT character_name FROM table_name_50 WHERE voice_actor__english_1998___pioneer_ = "laara sadiq"
100 context:CREATE TABLE table_name_94 (position VARCHAR, pick VARCHAR) question:What is pick 246's position? answer:SELECT position FROM table_name_94 WHERE position = "pick 246"<|endoftext|>
SELECT position FROM table_name_94 WHERE pick = 246
200 context:CREATE TABLE table_name_78 (division_record VARCHAR, school VARCHAR) question:What is the division record for Woodbridge? answer:SELECT division_record FROM table_name_78 WHERE school = "woodbridge"<|endoftext|>
SELECT division_record FROM table_name_78 WHERE school = "woodbridge"
300 context:CREATE TABLE table_name_20 (n