In [1]:
import torch

prompt_lens = 128
gen_lens = prompt_lens + 128

from util import TokenizerUtil

tokenizer = TokenizerUtil()

input_ids, _ = tokenizer.encode('how are you', max_length=6)

input_ids, attention_mask = tokenizer.pad_to_left(input_ids)

input_ids, attention_mask, tokenizer.decode(input_ids)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(tensor([128002, 128002, 128000,   5269,    527,    499]),
 tensor([0, 0, 1, 1, 1, 1]),
 '<|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>how are you')

In [2]:
from datasets import load_dataset
from transformers import default_data_collator

dataset = load_dataset('json', data_files='dataset/train.json', split='train')

#2,4,4切分,取最后一部分
dataset = dataset.select(range(45000, len(dataset)))


def f(data):
    input_ids, _ = tokenizer.encode(data['prompt'], max_length=prompt_lens)
    input_ids, attention_mask = tokenizer.pad_to_left(input_ids)

    return {'input_ids': input_ids, 'attention_mask': attention_mask}


dataset = dataset.map(f, remove_columns=dataset.column_names)

loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=default_data_collator,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True)

len(loader), next(iter(loader)).keys()

Map:   0%|          | 0/28577 [00:00<?, ? examples/s]

(7144, dict_keys(['input_ids', 'attention_mask']))

In [3]:
%run 1.model.ipynb

model_actor = torch.load('model/actor')
model_actor.train()

optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=2e-6)

In [4]:
class CriticModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.rwtransformer = None
        self.v_head = None

    def get_value(self, input_ids, attention_mask):
        value = self.rwtransformer(input_ids=input_ids,
                                   attention_mask=attention_mask)
        return self.v_head(value).squeeze(2)

    def get_reward(self, input_ids, attention_mask):
        value = self.get_value(input_ids, attention_mask)

        reward = []
        for i, v in zip(input_ids, value):
            end = input_ids.shape[1] - 1
            if tokenizer.eos_token_id in i:
                end = i.tolist().index(tokenizer.eos_token_id)
            reward.append(v[end])
        reward = torch.stack(reward)

        return reward


model_critic = torch.load('model/critic')
model_critic.train()

optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=5e-5)

In [5]:
from accelerate import Accelerator

model_ref = torch.load('model/actor')
model_reward = torch.load('model/critic')

model_ref.eval()
model_reward.eval()

accelerator = Accelerator(gradient_accumulation_steps=1,
                          mixed_precision='fp16')

(loader, model_actor, optimizer_actor, model_critic, optimizer_critic,
 model_ref, model_reward) = accelerator.prepare(loader, model_actor,
                                                optimizer_actor, model_critic,
                                                optimizer_critic, model_ref,
                                                model_reward)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
from util import get_generate as get_generate_util


def get_generate(input_ids):
    generate = get_generate_util(model_actor,
                                 input_ids,
                                 tokenizer.eos_token_id,
                                 tokenizer.pad_token_id,
                                 max_length=gen_lens)
    lens = (generate[:, prompt_lens:] != tokenizer.pad_token_id).sum(1)
    return generate[lens > 1]


data = next(iter(loader))

for i in get_generate(data['input_ids']):
    print(tokenizer.decode(i[prompt_lens:]))
    print('================')

select count(AVERAGE) from TABLE_NAME_54 where AVG_GOWN < 21. 8 and AVG_GOWN < 9. 8 and AVERAGE < 9.9<|end_of_text|>
SELECT score FROM table_name_31 WHERE competition = "2000 afc asian cup"<|end_of_text|>
SELECT opponent FROM table_name_70 WHERE date = "april 15, 1981"<|end_of_text|>
SELECT format FROM table_name_80 WHERE station = "mike"<|end_of_text|>


In [7]:
def get_prob(prob, index):
    prob = prob.log_softmax(dim=2)
    prob = prob.gather(dim=2, index=index.unsqueeze(2))
    return prob.squeeze(2)


get_prob(torch.randn(4, 123, 999), torch.randint(0, 999, (4, 123))).shape

torch.Size([4, 123])

In [8]:
last_generate = None


@torch.no_grad()
def get_batch(input_ids, attention_mask):
    #input_ids -> [b, prompt_lens]
    #attention_mask -> [b, prompt_lens]
    global last_generate

    #根据问题生成回答
    #[b, gen_lens]
    generate = get_generate(input_ids)

    #制作缓存,防止所有回答为空的情况
    if len(generate):
        last_generate = generate
    else:
        generate = last_generate

    #[b, gen_lens]
    generate_mask = (generate != tokenizer.pad_token_id).long()

    #两个模型分别取回答被预测到的概率
    #[b, gen_lens-1]
    _, prob_old = model_actor(input_ids=generate, attention_mask=generate_mask)
    prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])

    #取每个词的value
    #[b, gen_lens-1]
    value_old = model_critic.get_value(generate, generate_mask)[:, :-1]

    #[b, gen_lens-1]
    _, prob_ref = model_ref(input_ids=generate, attention_mask=generate_mask)
    prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])

    #取回答的分数
    #[b]
    reward = model_reward.get_reward(generate, generate_mask)

    return generate, generate_mask, prob_old, prob_ref, value_old, reward


generate, generate_mask, prob_old, prob_ref, value_old, reward = get_batch(
    **data)

generate.shape, generate_mask.shape, prob_old.shape, prob_ref.shape, value_old.shape, reward.shape

(torch.Size([4, 167]),
 torch.Size([4, 167]),
 torch.Size([4, 166]),
 torch.Size([4, 166]),
 torch.Size([4, 166]),
 torch.Size([4]))

In [9]:
def get_reward_kl(end, prob_old, prob_ref, reward):
    #prob_old -> [b, gen_lens-1]
    #prob_ref -> [b, gen_lens-1]
    #reward -> [b]

    #两份预测概率求kl散度
    #[b, gen_lens-1]
    reward_kl = -0.1 * (prob_old - prob_ref)

    #把原本的reward加在kl散度的最后一个字上
    for i, e in enumerate(end):
        if e >= reward_kl.shape[1]:
            e = -1
        reward_kl[i, e] += reward[i].clamp(-5, 5)

    #[b, gen_lens-1]
    return reward_kl


end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
end = end.tolist()

reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

reward_kl.shape

torch.Size([4, 166])

In [10]:
#解释见原版代码中的get_delta_note函数
def get_delta(value_old, reward_kl):
    #value_old -> [b, gen_lens-1]
    #reward_kl -> [b, gen_lens-1]

    #gen_lens-2 -> prompt_lens-1
    delta = []
    for i in reversed(range(prompt_lens - 1, value_old.shape[1])):
        #[b]
        value_next = 0.0
        if i != value_old.shape[1] - 1:
            value_next = value_old[:, i + 1]

        #[b]
        d = reward_kl[:, i] + value_next - value_old[:, i]
        if len(delta):
            d += 0.95 * delta[-1]
        delta.append(d)

    #[b, gen_lens-prompt_lens]
    delta = torch.stack(delta[::-1], dim=1)

    return delta


delta = get_delta(value_old, reward_kl)

delta.shape

torch.Size([4, 39])

In [11]:
def get_loss_actor(prob_new, prob_old, delta, generate_mask):
    prob_new = prob_new[:, prompt_lens - 1:]
    prob_old = prob_old[:, prompt_lens - 1:]
    generate_mask = generate_mask[:, prompt_lens:]

    #prob_new -> [b, gen_lens-prompt_lens]
    #prob_old -> [b, gen_lens-prompt_lens]
    #delta -> [b, gen_lens-prompt_lens]
    #generate_mask -> [b, gen_lens-prompt_lens]

    #对数概率,求差就是求商,所以这里求的是新旧概率的变化率
    #[b, gen_lens-prompt_lens]
    ratio = ((prob_new - prob_old) * generate_mask).exp()

    #delta是估计出来的去基线Q值,以变化率来缩放Q值
    #最大化Q值,以此来寻找最优的actor
    #裁剪,防止自举
    #[b, gen_lens-prompt_lens]
    loss1 = delta * ratio
    loss2 = delta * ratio.clamp(0.8, 1.2)
    loss = torch.min(loss1, loss2) * generate_mask
    loss = loss.sum() / generate_mask.sum()
    return -loss


loss_actor = get_loss_actor(prob_old, prob_old, delta, generate_mask)

loss_actor

tensor(1.3530, device='cuda:0')

In [12]:
def get_loss_critic(value_new, value_old, delta, generate_mask):
    value_new = value_new[:, prompt_lens - 1:]
    value_old = value_old[:, prompt_lens - 1:]
    generate_mask = generate_mask[:, prompt_lens:]

    #value_new -> [b, gen_lens-prompt_lens]
    #value_old -> [b, gen_lens-prompt_lens]
    #delta -> [b, gen_lens-prompt_lens]
    #generate_mask -> [b, gen_lens-prompt_lens]

    #delta是估计出来的去基线Q值,加上value_old后还原为Q值
    #value_new和Q值求mse loss即可,因为value都是对Q函数的估计
    #裁剪,防止自举
    #[b, gen_lens-prompt_lens]
    loss1 = (value_new - delta - value_old)**2
    value_new = value_new.clamp(value_old - 0.2, value_old + 0.2)
    loss2 = (value_new - delta - value_old)**2

    #求平均
    loss = torch.max(loss1, loss2) * generate_mask
    loss = loss.sum() / 2 / generate_mask.sum()

    return loss


loss_critic = get_loss_critic(value_old, value_old, delta, generate_mask)

loss_critic

tensor(7.4272, device='cuda:0')

In [13]:
def train(generate, generate_mask, prob_old, prob_ref, value_old, reward):
    #generate -> [b, gen_lens]
    #generate_mask -> [b, gen_lens]
    #prob_old -> [b, gen_lens-1]
    #prob_ref -> [b, gen_lens-1]
    #value_old -> [b, gen_lens-1]
    #reward -> [b]

    #求出每句话结束的索引
    #[b]
    end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
    end = end.tolist()

    #结束以后的value归零
    for i, e in enumerate(end):
        value_old[i, e + 1:] = 0

    with torch.no_grad():
        #计算新旧概率的kl散度,再把reward加在最后一个字上
        #[b, gen_lens-1]
        reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

        #估计去基线的Q值
        #[b, gen_lens-prompt_lens]
        delta = get_delta(value_old, reward_kl)

    #重新计算回答被生成的概率
    #[b, gen_lens-1]
    _, prob_new = model_actor(input_ids=generate, attention_mask=generate_mask)
    prob_new = get_prob(prob_new[:, :-1], generate[:, 1:])

    #重新计算每个词的value
    #[b, gen_lens-1]
    value_new = model_critic.get_value(input_ids=generate,
                                       attention_mask=generate_mask)[:, :-1]

    with accelerator.accumulate(model_actor, model_critic):
        #更新actor
        loss_actor = get_loss_actor(prob_new, prob_old, delta, generate_mask)
        accelerator.backward(loss_actor)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_actor.parameters(), 1.0)
        optimizer_actor.step()
        optimizer_actor.zero_grad()

        #更新critic
        loss_critic = get_loss_critic(value_new, value_old, delta,
                                      generate_mask)
        accelerator.backward(loss_critic)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
        optimizer_critic.step()
        optimizer_critic.zero_grad()

        return loss_actor.item(), loss_critic.item()


train(generate, generate_mask, prob_old, prob_ref, value_old, reward)

(-2.1609959602355957, 17.055959701538086)

In [14]:
for i, data in enumerate(loader):
    #生成数据
    (generate, generate_mask, prob_old, prob_ref, value_old,
     reward) = get_batch(**data)

    #训练
    loss_actor, loss_critic = train(generate, generate_mask, prob_old,
                                    prob_ref, value_old, reward)

    if (i + 1) % 50 == 0:
        print(i, len(loader), loss_actor, loss_critic, reward[0].item())
        print(tokenizer.decode(generate[0, prompt_lens:]))

torch.save(model_actor.to('cpu'), 'model/rlhf')

49 7144 -0.10152655839920044 0.007377683650702238 10.897445678710938
select max(SEASON__NUMBER) from TABLE_17311720_1 where DIRECTED_BY = "cHRIS mARLES"<|end_of_text|>
99 7144 0.09240681678056717 0.00518863694742322 11.355960845947266
select CANDIDATES from TABLE_1348_5 where INCUMBENT = "cON rACING"<|end_of_text|>
149 7144 -0.29071927070617676 0.046924393624067307 10.408658981323242
select 2 as ND_LEG from TABLE_NAME_41 where HOME__2_ = "2"<|end_of_text|>
199 7144 0.012671473436057568 0.003083077259361744 11.179220199584961
select TO_PAR from TABLE_NAME_43 where PLAYER = "JANIA PARK"<|end_of_text|>
249 7144 0.03380975127220154 0.00437870854511857 9.90336799621582
select SCOTT from TABLE_1463_1 where ALBUM = "28.7%"<|end_of_text|>
299 7144 -0.10339976102113724 0.00850830227136612 11.192331314086914
select DOCUMENT_ID, CUSTOMER_ID from CUSTOMERSON where PRODUCT_NAME = 'cON'<|end_of_text|>
349 7144 -0.06394381076097488 0.004041989333927631 9.919631004333496
select t1.nAME from sTUid from

2749 7144 0.0821671336889267 0.0038253075908869505 11.752277374267578
select VENUE from TABLE_NAME_57 where HOME_TEAM = "MAY"<|end_of_text|>
2799 7144 0.0790843665599823 0.004348687827587128 8.7584228515625
select GT2_WINNING_TEAM from TABLE_12159168_2 where GT1_WINNING_TEAM = "cONTO"<|end_of_text|>
2849 7144 -0.016536198556423187 0.0019114071037620306 11.290727615356445
select SET_2 from TABLE_NAME_42 where TOTAL = "45–45"<|end_of_text|>
2899 7144 -0.011421989649534225 0.0013112911256030202 12.513911247253418
select sTUid from sTUid where sTUid = 'pID'<|end_of_text|>
2949 7144 -0.020961560308933258 0.00384416151791811 6.734859466552734
select max(PICK__NUMBER) from TABLE_NAME_67 where OVERALL > 5 and POSITION = "JANIA" and OVERALL > 5<|end_of_text|>
2999 7144 0.01602778024971485 0.004078662488609552 10.777851104736328
select COMPETITION from TABLE_NAME_27 where POSITION = "5TH"<|end_of_text|>
3049 7144 -0.029689468443393707 0.0012230310821905732 10.220989227294922
select STADIUM from 

5399 7144 0.05023661628365517 0.001649458659812808 10.833001136779785
select count(YEAR) from TABLE_NAME_52 where LAUNCHED = "$19"<|end_of_text|>
5449 7144 -0.03004116378724575 0.0010667068418115377 9.751044273376465
select RACE from TABLE_NAME_15 where RACE = "4" and RESULT = "4"<|end_of_text|>
5499 7144 0.24664810299873352 0.18246722221374512 9.485268592834473
select TYRES from TABLE_27021023_4 where DENSITY = "1%"<|end_of_text|>
5549 7144 -0.028134746477007866 0.002870990429073572 10.095355987548828
select DURATION from TABLE_25816477_1 where LOCATION = "aMAY, mURNE"<|end_of_text|>
5599 7144 0.014672581106424332 0.0032644160091876984 9.444406509399414
select count(SEATS) from TABLE_NAME_71 where _PERCENTAGE_OF_POPULATION = "8.1%"<|end_of_text|>
5649 7144 1.2899353504180908 4.054957866668701 9.922690391540527
select 2002 from TABLE_NAME_38 where 2006 = "A" and 2006 = "A"<|end_of_text|>
5699 7144 -0.05323629081249237 0.008737629279494286 11.692553520202637
select QUAL from TABLE_NAME_