In [1]:
import torch

# Prompt token budget: prompts are encoded with max_length=prompt_lens and
# left-padded, and generated answers are read from column prompt_lens onward
# (presumably prompts are padded to exactly this length — TODO confirm).
prompt_lens = 128
# Total length budget (prompt + answer) passed to generation as max_length.
gen_lens = prompt_lens + 128

from util import TokenizerUtil

tokenizer = TokenizerUtil()

# Quick check of the tokenizer helpers on a toy prompt.
input_ids, _ = tokenizer.encode('how are you', max_length=6)

# Left padding: the displayed output shows the pad id in front of the BOS
# token, with attention_mask 0 on those pad positions.
input_ids, attention_mask = tokenizer.pad_to_left(input_ids)

input_ids, attention_mask, tokenizer.decode(input_ids)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(tensor([128002, 128002, 128000,   5269,    527,    499]),
 tensor([0, 0, 1, 1, 1, 1]),
 '<|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>how are you')

In [2]:
from datasets import load_dataset
from transformers import default_data_collator

dataset = load_dataset('json', data_files='dataset/train.json', split='train')

# The data is split 2:4:4; keep only the last part (rows 45000 onward) for this
# stage. Presumably the earlier parts were used by previous training stages —
# TODO confirm.
dataset = dataset.select(range(45000, len(dataset)))


def f(data):
    """Turn one dataset row into left-padded model inputs.

    Encodes ``data['prompt']`` with ``max_length=prompt_lens``, left-pads the
    ids, and returns ``{'input_ids': ..., 'attention_mask': ...}``.
    """
    prompt_ids, _unused = tokenizer.encode(data['prompt'],
                                           max_length=prompt_lens)
    padded_ids, padded_mask = tokenizer.pad_to_left(prompt_ids)
    return dict(input_ids=padded_ids, attention_mask=padded_mask)


# Tokenize every prompt with f; the original columns are removed so only
# input_ids / attention_mask remain.
dataset = dataset.map(f, remove_columns=dataset.column_names)

# Batches of 4 prompts; drop_last keeps every batch the same size.
# NOTE(review): no random seed is set in the visible notebook, so the shuffle
# order (and therefore training) differs between runs.
loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=default_data_collator,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True)

len(loader), next(iter(loader)).keys()

(7144, dict_keys(['input_ids', 'attention_mask']))

In [3]:
%run 1.model.ipynb

model_actor = torch.load('model/actor')
model_actor.train()

optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=2e-6)

In [4]:
class CriticModel(torch.nn.Module):
    """Value / reward model: a transformer backbone plus a per-position value head.

    The submodules are placeholders here; real instances are produced by
    ``torch.load`` of a fully pickled object, and pickle resolves this class by
    name, so it must be defined before loading.
    """

    def __init__(self):
        super().__init__()
        # Backbone; its output is fed straight into v_head (presumably hidden
        # states of shape [b, seq, hidden] — TODO confirm).
        self.rwtransformer = None
        # Head whose output has a size-1 dim 2 that get_value squeezes away.
        self.v_head = None

    def get_value(self, input_ids, attention_mask):
        """Return one value per position, shape [b, seq]."""
        value = self.rwtransformer(input_ids=input_ids,
                                   attention_mask=attention_mask)
        return self.v_head(value).squeeze(2)

    def get_reward(self, input_ids, attention_mask):
        """Return one scalar per sequence, shape [b].

        The score is the value at the first EOS token of each row, or at the
        last position when the row contains no EOS.
        """
        value = self.get_value(input_ids, attention_mask)

        # Vectorized lookup of each row's first EOS. argmax returns the index
        # of the first maximum, i.e. the first True of the 0/1 mask; this
        # avoids a per-row Python loop with a host sync (.tolist()) per row.
        is_eos = input_ids == tokenizer.eos_token_id
        first_eos = is_eos.long().argmax(dim=1)
        last = torch.full_like(first_eos, input_ids.shape[1] - 1)
        end = torch.where(is_eos.any(dim=1), first_eos, last)

        end = end.to(value.device)
        return value.gather(1, end.unsqueeze(1)).squeeze(1)


# The critic checkpoint is presumably a fully pickled CriticModel (the class
# above must be defined first so pickle can resolve it). The same file is
# loaded again later as the reward model.
# torch>=2.6 defaults torch.load to weights_only=True, which rejects
# full-module pickles, so opt out explicitly. Unpickling can run arbitrary
# code: only load trusted files.
model_critic = torch.load('model/critic', weights_only=False)
# Value function, trained alongside the actor.
model_critic.train()

optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=5e-5)

In [5]:
from accelerate import Accelerator

# Scoring-only copies used in get_batch: a reference policy (for the KL
# penalty) and the reward model. They are loaded from the same files as the
# actor and the critic but get no optimizer, so they are never updated.
# torch>=2.6 defaults torch.load to weights_only=True, which rejects
# full-module pickles, so opt out explicitly. Unpickling can run arbitrary
# code: only load trusted files.
model_ref = torch.load('model/actor', weights_only=False)
model_reward = torch.load('model/critic', weights_only=False)

model_ref.eval()
model_reward.eval()

# train() enters accelerator.accumulate once per call; with
# gradient_accumulation_steps=8 the optimizers effectively step every 8 calls.
# Forward passes run in fp16 mixed precision.
accelerator = Accelerator(gradient_accumulation_steps=8,
                          mixed_precision='fp16')

# NOTE(review): later cells call model_critic.get_value and
# model_reward.get_reward directly on the prepared objects; a DDP-wrapped
# module would not expose those methods, so this presumably only runs
# single-process — confirm before launching multi-GPU.
(loader, model_actor, optimizer_actor, model_critic, optimizer_critic,
 model_ref, model_reward) = accelerator.prepare(loader, model_actor,
                                                optimizer_actor, model_critic,
                                                optimizer_critic, model_ref,
                                                model_reward)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
from util import get_generate as get_generate_util


def get_generate(input_ids):
    """Generate answers for a batch of left-padded prompts with the actor.

    Calls util.get_generate with the actor, EOS/pad ids and
    ``max_length=gen_lens``. Rows whose answer part (everything from column
    ``prompt_lens`` on) has at most one non-pad token are dropped, so the
    result may have fewer rows than ``input_ids`` — possibly none.
    """
    sequences = get_generate_util(model_actor,
                                  input_ids,
                                  tokenizer.eos_token_id,
                                  tokenizer.pad_token_id,
                                  max_length=gen_lens)
    answer_part = sequences[:, prompt_lens:]
    n_answer_tokens = answer_part.ne(tokenizer.pad_token_id).sum(dim=1)
    keep = n_answer_tokens.gt(1)
    return sequences[keep]


data = next(iter(loader))

for i in get_generate(data['input_ids']):
    print(tokenizer.decode(i[prompt_lens:]))
    print('================')

select min(YARDS) from TABLE_NAME_30 where AVG = 7 and LONG < 7<|end_of_text|>
select sum(RANK__TIM_) from TABLE_NAME_2 where RANK__M_ > 2.6 and RANK__M_ = 5 and RANK__NIGHT_ > 5<|end_of_text|>
SELECT final_venue FROM table_name_23 WHERE number_of_dances = 82<|end_of_text|>
SELECT home_captain FROM table_name_2 WHERE arena_away = "vfl park"<|end_of_text|>


In [7]:
def get_prob(prob, index):
    """Log-probability of the chosen token at every position.

    Args:
        prob: scores of shape [b, seq, vocab], treated as logits
            (log-softmax is applied over dim 2).
        index: token ids to pick, shape [b, seq].

    Returns:
        [b, seq] tensor whose [b, t] entry is log_softmax(prob)[b, t, index[b, t]].
    """
    log_probs = torch.log_softmax(prob, dim=2)
    picked = torch.gather(log_probs, 2, index[:, :, None])
    return picked[:, :, 0]


# Shape check on random data: [b, seq, vocab] scores + [b, seq] ids -> [b, seq].
get_prob(torch.randn(4, 123, 999), torch.randint(0, 999, (4, 123))).shape

torch.Size([4, 123])

In [8]:
# Most recent non-empty generation; reused by get_batch when every answer in a
# batch comes back empty.
last_generate = None
# How many times get_batch had to fall back to last_generate.
cache_count = 0


@torch.no_grad()
def get_batch(input_ids, attention_mask):
    """Roll out one PPO batch and score it with the actor, critic, reference and reward models.

    Args:
        input_ids: left-padded prompts, [b, prompt_lens].
        attention_mask: prompt mask, [b, prompt_lens]. Unused — the mask is
            recomputed from the generated sequences — but kept so the function
            can be called as ``get_batch(**batch)``.

    Returns:
        (generate, generate_mask, prob_old, prob_ref, value_old, reward) with
        shapes [b', seq], [b', seq], [b', seq-1], [b', seq-1], [b', seq-1], [b'].
        b' may be smaller than b because get_generate drops empty answers.

    Raises:
        RuntimeError: if every answer is empty and no earlier generation has
            been cached to fall back on.
    """
    global last_generate
    global cache_count

    # Generate answers for the prompts; rows with empty answers are dropped.
    generate = get_generate(input_ids)

    # Cache the last non-empty generation so that a batch in which every
    # answer is empty can still be trained on.
    if len(generate):
        last_generate = generate
    else:
        if last_generate is None:
            raise RuntimeError(
                'get_batch: every generated answer is empty and there is no '
                'cached generation to fall back on')
        generate = last_generate
        cache_count += 1

    # 1 for every non-pad token (prompt and answer), 0 for padding.
    generate_mask = (generate != tokenizer.pad_token_id).long()

    # Log-probability of each next token under the current actor. The
    # second model output is presumably logits (get_prob log-softmaxes it);
    # position t predicts token t + 1, hence the one-step shift.
    _, prob_old = model_actor(input_ids=generate, attention_mask=generate_mask)
    prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])

    # Critic value per position, aligned with prob_old.
    value_old = model_critic.get_value(generate, generate_mask)[:, :-1]

    # Same log-probabilities under the reference model (for the KL penalty).
    _, prob_ref = model_ref(input_ids=generate, attention_mask=generate_mask)
    prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])

    # One scalar score per sequence from the reward model.
    reward = model_reward.get_reward(generate, generate_mask)

    return generate, generate_mask, prob_old, prob_ref, value_old, reward


# Smoke test on the batch sampled above. The shapes show the one-token shift
# between sequences [b, seq] and per-token tensors [b, seq-1].
generate, generate_mask, prob_old, prob_ref, value_old, reward = get_batch(
    **data)

generate.shape, generate_mask.shape, prob_old.shape, prob_ref.shape, value_old.shape, reward.shape

(torch.Size([4, 171]),
 torch.Size([4, 171]),
 torch.Size([4, 170]),
 torch.Size([4, 170]),
 torch.Size([4, 170]),
 torch.Size([4]))

In [9]:
def get_reward_kl(end, prob_old, prob_ref, reward, kl_coef=0.1,
                  clip_reward=5.0):
    """Per-token PPO reward: a KL penalty plus the clipped sequence reward.

    Args:
        end: list of ints, one per row: the column where that row's sequence
            reward is added. Columns past the last one fall back to the last
            column (with ``end`` computed as in train(), that happens for a
            row whose answer has no trailing padding).
        prob_old: actor log-probs, [b, seq-1].
        prob_ref: reference-model log-probs, [b, seq-1].
        reward: sequence rewards, [b].
        kl_coef: weight of the per-token log-ratio penalty (default 0.1).
        clip_reward: rewards are clamped to [-clip_reward, clip_reward]
            (default 5.0).

    Returns:
        A new [b, seq-1] tensor; the inputs are not modified.
    """
    # log p_old - log p_ref per token is a sample estimate of
    # KL(actor || reference); penalizing it keeps the actor near the reference.
    reward_kl = -kl_coef * (prob_old - prob_ref)

    # Add the (clipped) sequence reward on each row's end column.
    last = reward_kl.shape[1] - 1
    for i, e in enumerate(end):
        reward_kl[i, min(e, last)] += reward[i].clamp(-clip_reward,
                                                      clip_reward)

    return reward_kl


# Column (in the per-token tensors) where each row's sequence reward is added:
# prompt_lens + number of non-pad answer tokens - 1.
end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
end = end.tolist()

reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

reward_kl.shape

torch.Size([4, 170])

In [10]:
# For a step-by-step explanation see get_delta_note in the original version of
# this code.
def get_delta(value_old, reward_kl, prompt_len=None, gamma=1.0, lam=0.95):
    """Generalized advantage estimation (GAE) over the answer columns.

    Walks backwards over columns prompt_len-1 .. seq-2 of the per-token
    tensors (column prompt_len-1 scores the first answer token):

        d_t     = r_t + gamma * V_{t+1} - V_t      (V past the last column is 0)
        delta_t = d_t + gamma * lam * delta_{t+1}

    Args:
        value_old: critic values, [b, seq-1].
        reward_kl: per-token rewards from get_reward_kl, [b, seq-1].
        prompt_len: prompt length; defaults to the global prompt_lens.
        gamma: discount factor (default 1.0, i.e. undiscounted).
        lam: GAE lambda (default 0.95).

    Returns:
        Advantages, [b, seq-prompt_len]; shape [b, 0] when there are no answer
        columns.
    """
    if prompt_len is None:
        prompt_len = prompt_lens

    length = value_old.shape[1]
    delta = []
    for i in reversed(range(prompt_len - 1, length)):
        # [b]
        value_next = value_old[:, i + 1] if i < length - 1 else 0.0
        d = reward_kl[:, i] + gamma * value_next - value_old[:, i]
        if delta:
            d = d + gamma * lam * delta[-1]
        delta.append(d)

    if not delta:
        return value_old.new_zeros(value_old.shape[0], 0)

    # Collected back to front; restore left-to-right order.
    return torch.stack(delta[::-1], dim=1)


# Smoke test: advantages cover only the answer part of the sequence.
delta = get_delta(value_old, reward_kl)

delta.shape

torch.Size([4, 43])

In [11]:
def get_loss_actor(prob_new, prob_old, delta, generate_mask, clip=0.2,
                   prompt_len=None):
    """PPO clipped policy loss over the answer tokens.

    Args:
        prob_new: log-probs under the actor being trained, [b, seq-1].
        prob_old: log-probs recorded at rollout time, [b, seq-1].
        delta: advantages from get_delta, [b, seq-prompt_len].
        generate_mask: 1 for non-pad tokens, [b, seq].
        clip: the ratio is clipped to [1 - clip, 1 + clip] (default 0.2).
        prompt_len: prompt length; defaults to the global prompt_lens.

    Returns:
        Scalar loss (negated surrogate, averaged over unmasked answer tokens).
        NaN if no answer token is unmasked.
    """
    if prompt_len is None:
        prompt_len = prompt_lens

    # Keep only the answer part: column prompt_len-1 of the per-token tensors
    # scores the token at position prompt_len of the sequence.
    prob_new = prob_new[:, prompt_len - 1:]
    prob_old = prob_old[:, prompt_len - 1:]
    generate_mask = generate_mask[:, prompt_len:]

    # These are log-probs, so exp(difference) is the new/old probability
    # ratio. Masked positions get log-ratio 0, i.e. ratio 1.
    ratio = ((prob_new - prob_old) * generate_mask).exp()

    # Clipped surrogate: scale the advantages by the ratio, but clip the ratio
    # so one update cannot move the policy far from the rollout policy.
    loss1 = delta * ratio
    loss2 = delta * ratio.clamp(1 - clip, 1 + clip)
    loss = torch.min(loss1, loss2) * generate_mask
    loss = loss.sum() / generate_mask.sum()
    # Maximizing the surrogate == minimizing its negative.
    return -loss


# Smoke test: with prob_new == prob_old the ratio is 1, so the loss is minus
# the masked mean advantage.
loss_actor = get_loss_actor(prob_old, prob_old, delta, generate_mask)

loss_actor

tensor(1.3177, device='cuda:0')

In [12]:
def get_loss_critic(value_new, value_old, delta, generate_mask, clip=0.2,
                    prompt_len=None):
    """Clipped value loss over the answer tokens.

    The regression target is the return ``delta + value_old`` (advantage plus
    baseline). value_new is also evaluated after clipping it to
    ``value_old +/- clip``, and the larger of the two squared errors is used
    per token.

    Args:
        value_new: values from the critic being trained, [b, seq-1].
        value_old: values recorded at rollout time, [b, seq-1].
        delta: advantages from get_delta, [b, seq-prompt_len].
        generate_mask: 1 for non-pad tokens, [b, seq].
        clip: half-width of the value clipping window (default 0.2).
        prompt_len: prompt length; defaults to the global prompt_lens.

    Returns:
        Scalar loss: half the masked mean of the per-token maxima.
    """
    if prompt_len is None:
        prompt_len = prompt_lens

    # Keep only the answer part (see get_loss_actor for the column shift).
    value_new = value_new[:, prompt_len - 1:]
    value_old = value_old[:, prompt_len - 1:]
    generate_mask = generate_mask[:, prompt_len:]

    # Squared error against the return, with and without clipping; taking the
    # max keeps the critic from moving far from value_old in one update.
    loss1 = (value_new - delta - value_old)**2
    value_clipped = value_new.clamp(value_old - clip, value_old + clip)
    loss2 = (value_clipped - delta - value_old)**2

    loss = torch.max(loss1, loss2) * generate_mask
    loss = loss.sum() / 2 / generate_mask.sum()

    return loss


# Smoke test: with value_new == value_old clipping has no effect, so the loss
# is (approximately) half the masked mean of delta**2.
loss_critic = get_loss_critic(value_old, value_old, delta, generate_mask)

loss_critic

tensor(4.3144, device='cuda:0')

In [13]:
def train(generate, generate_mask, prob_old, prob_ref, value_old, reward):
    """Run one PPO update of the actor and the critic on a rollout batch.

    Args (as returned by get_batch):
        generate: sequences, [b, seq].
        generate_mask: 1 for non-pad tokens, [b, seq].
        prob_old: actor log-probs at rollout time, [b, seq-1].
        prob_ref: reference-model log-probs, [b, seq-1].
        value_old: critic values at rollout time, [b, seq-1]. Modified in
            place: entries after each row's end column are zeroed.
        reward: sequence rewards, [b].

    Returns:
        (actor loss, critic loss) as Python floats.
    """
    # Column (in the per-token tensors) where each row's reward is placed.
    end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
    end = end.tolist()

    # Zero the values after the end so they do not leak into the advantages.
    for i, e in enumerate(end):
        value_old[i, e + 1:] = 0

    with torch.no_grad():
        # Per-token reward: KL penalty plus the sequence reward at the end.
        reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

        # Advantages for the answer columns.
        delta = get_delta(value_old, reward_kl)

    # The forward passes must run inside accumulate(): under DDP it enters
    # no_sync() on non-final micro-steps, and no_sync only suppresses the
    # gradient all-reduce when the forward pass is inside it as well.
    with accelerator.accumulate(model_actor, model_critic):
        # Log-probs of the sampled tokens under the current actor.
        _, prob_new = model_actor(input_ids=generate,
                                  attention_mask=generate_mask)
        prob_new = get_prob(prob_new[:, :-1], generate[:, 1:])

        # Current critic values, aligned with prob_new.
        value_new = model_critic.get_value(input_ids=generate,
                                           attention_mask=generate_mask)[:, :-1]

        # Actor update.
        loss_actor = get_loss_actor(prob_new, prob_old, delta, generate_mask)
        accelerator.backward(loss_actor)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_actor.parameters(), 1.0)
        optimizer_actor.step()
        optimizer_actor.zero_grad()

        # Critic update.
        loss_critic = get_loss_critic(value_new, value_old, delta,
                                      generate_mask)
        accelerator.backward(loss_critic)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
        optimizer_critic.step()
        optimizer_critic.zero_grad()

    return loss_actor.item(), loss_critic.item()


# One training step on the smoke-test batch. Note: train() zeroes part of
# value_old in place.
train(generate, generate_mask, prob_old, prob_ref, value_old, reward)

(-0.7142879962921143, 9.288569450378418)

In [14]:
for i, data in enumerate(loader):
    # Roll out: generate answers and score them.
    (generate, generate_mask, prob_old, prob_ref, value_old,
     reward) = get_batch(**data)

    # One PPO update of actor and critic.
    loss_actor, loss_critic = train(generate, generate_mask, prob_old,
                                    prob_ref, value_old, reward)

    if (i + 1) % 100 == 0:
        print(i, len(loader), loss_actor, loss_critic, reward[0].item(),
              cache_count)

        # Show the first sample without its left padding: decode from just
        # after the BOS token.
        start = generate[0].tolist().index(tokenizer.bos_token_id) + 1
        print(tokenizer.decode(generate[0, start:]))

# Save the bare actor. accelerator.prepare() may have wrapped it (DDP, and an
# fp16 autocast forward wrapper), and those wrappers should not end up in the
# pickle. Only the main process writes the file.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    model_to_save = accelerator.unwrap_model(model_actor,
                                             keep_fp32_wrapper=False)
    torch.save(model_to_save.to('cpu'), 'model/rlhf')

99 7144 0.2392578125 0.7209304571151733 5.492685317993164 0
Human: context= CREATE TABLE table_28232443_1 (producer__s_ VARCHAR, song__s_ VARCHAR) question= how many producers are responsible for the song 'calling out to marlboro? Assistant:select count(PRODUCTION__PENGLISH_) from TABLE_28215743_1 where SONG__S_ = "lUIS sANDRIA"<|end_of_text|>
199 7144 0.1526360660791397 0.6175141334533691 8.53221607208252 0
Human: context= CREATE TABLE table_26041144_10 (innings INTEGER, average VARCHAR) question= How many innings are there when the average is 32.3? Assistant:select max(INNSY) from TABLE_26041144_10 where AVERAGE = "32"<|end_of_text|>
299 7144 0.10185553878545761 0.02755236066877842 8.451120376586914 0
Human: context= CREATE TABLE table_name_39 (blank_ends INTEGER, nation VARCHAR, ends_lost VARCHAR) question= What is the sum of Blank Ends for Denmark when less than 42 is the ends lost? Assistant:select sum(BYES) from TABLE_NAME_39 where NATION = "NEW ZEAL" and ENDS_LOST < 42<|end_of_t

2699 7144 0.021513676270842552 0.01845403015613556 7.407927989959717 0
Human: context= CREATE TABLE table_name_27 (result VARCHAR, week VARCHAR, date VARCHAR) question= What was the Result after the Week 4 on November 3, 1968? Assistant:select RESULT from TABLE_NAME_27 where WEEK > 4 and DATE = "NOVEMBER 3, 1968"<|end_of_text|>
2799 7144 -0.0379653200507164 0.0025692156050354242 6.735535144805908 0
Human: context= CREATE TABLE table_18161217_2 (launch_date VARCHAR, cospar_id VARCHAR) question= What was the launch date the satellite with cospar ID is 2008-033A? Assistant:select LAUNCH_DATE from TABLE_18173917_2 where VACED = "qUARY 2008"<|end_of_text|>
2899 7144 0.08972104638814926 0.005396728403866291 7.484698295593262 0
Human: context= CREATE TABLE table_1341884_20 (candidates VARCHAR, first_elected VARCHAR) question= Who were all the candidates when the first elected year was 1961? Assistant:select CANDIDATES from TABLE_1341884_20 where FIRST_ELECTED = 1961<|end_of_text|>
2999 7144 0

5299 7144 -0.011420010589063168 0.0006772717460989952 8.796844482421875 0
Human: context= CREATE TABLE table_name_67 (founded INTEGER, institution VARCHAR) question= What is the largest Founded with an Institution of cloud county community college? Assistant:select max(FOUNDED) from TABLE_NAME_67 where INSTITUTION = "SOUTHERN MISSISSIPPI STATE"<|end_of_text|>
5399 7144 0.03705424815416336 0.0014625918120145798 9.323482513427734 0
Human: context= CREATE TABLE table_name_68 (record VARCHAR, high_rebounds VARCHAR) question= What is Record, when High Rebounds is "Tyson Chandler (6)"? Assistant:select RECORD from TABLE_NAME_68 where HIGH_REBOUNDS = "LUIS SCOLA (6)"<|end_of_text|>
5499 7144 0.040390919893980026 0.0015183702344074845 10.23711109161377 0
Human: context= CREATE TABLE table_name_99 (goal_difference INTEGER, played INTEGER) question= How many goal differences have Played larger than 44? Assistant:select sum(GOAL_DIFFERENCE) from TABLE_NAME_99 where PLAYED > 44<|end_of_text|>
5599