In [1]:
import torch

# Length budget shared by the whole notebook: prompts are encoded with
# max_length=prompt_lens, and generation is called with max_length=gen_lens.
prompt_lens = 128
gen_lens = prompt_lens + 128

from util import TokenizerUtil

# Project-local tokenizer wrapper; every later cell uses this single instance.
tokenizer = TokenizerUtil()

# Sanity check on a tiny example: encode, then left-pad.
input_ids, _ = tokenizer.encode('how are you', max_length=6)

# pad_to_left returns the padded ids together with the matching attention mask.
input_ids, attention_mask = tokenizer.pad_to_left(input_ids)

# Displayed as the cell result: padded ids, mask, and the decoded text.
input_ids, attention_mask, tokenizer.decode(input_ids)

(tensor([   0,    0,    2, 1139,  708,  692]),
 tensor([0, 0, 1, 1, 1, 1]),
 '<pad><pad><bos>how are you')

In [2]:
from datasets import load_dataset
from transformers import default_data_collator

# Load the training prompts from a local JSON file.
dataset = load_dataset('json', data_files='dataset/train.json', split='train')

# 2/4/4 split: keep only the last part (rows from index 45000 onward).
dataset = dataset.select(range(45000, len(dataset)))


def f(data):
    """Turn one dataset record into left-padded model inputs.

    Encodes data['prompt'] with max_length=prompt_lens, left-pads the result
    and returns a dict with the keys 'input_ids' and 'attention_mask'.
    """
    encoded, _ = tokenizer.encode(data['prompt'], max_length=prompt_lens)
    padded_ids, padded_mask = tokenizer.pad_to_left(encoded)
    return dict(input_ids=padded_ids, attention_mask=padded_mask)


# Replace every original column with the tokenized, left-padded prompt.
dataset = dataset.map(f, remove_columns=dataset.column_names)

# Shuffled batches of 4 prompts; drop_last discards a final incomplete batch.
loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=default_data_collator,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True)

# Displayed as the cell result: number of batches and the keys of one batch.
len(loader), next(iter(loader)).keys()

(7144, dict_keys(['input_ids', 'attention_mask']))

In [3]:
#%run 1.model.ipynb
# Run the model-definition notebook first; presumably it defines the classes
# that the pickled models below need in order to be loaded -- TODO confirm.
%run 1.model_gemma2.ipynb

# NOTE(review): torch.load on a whole pickled module executes pickle code;
# only load files from a trusted source.
# The actor is the policy being optimised, so it is put in training mode.
model_actor = torch.load('model/actor')
model_actor.train()

optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=2e-6)

In [4]:
class CriticModel(torch.nn.Module):
    """Scores token sequences: a backbone followed by a per-position head.

    The constructor only creates placeholders for the two sub-modules
    (`rwtransformer` and `v_head`); they must be populated before the
    methods below can be called.
    """

    def __init__(self):
        super().__init__()
        # Placeholders only; nothing is built here.
        self.rwtransformer = None
        self.v_head = None

    def get_value(self, input_ids, attention_mask):
        """Return the head output for every position with dim 2 squeezed out."""
        hidden = self.rwtransformer(input_ids=input_ids,
                                    attention_mask=attention_mask)
        scores = self.v_head(hidden)
        return scores.squeeze(2)

    def get_reward(self, input_ids, attention_mask):
        """Return one score per row of input_ids.

        For each row the score is the value at the first position holding
        tokenizer.eos_token_id, or at the last position when the row has no
        such token.
        """
        per_token = self.get_value(input_ids, attention_mask)

        eos_id = tokenizer.eos_token_id
        fallback = input_ids.shape[1] - 1

        picked = []
        for row_ids, row_values in zip(input_ids, per_token):
            ids = row_ids.tolist()
            position = ids.index(eos_id) if eos_id in ids else fallback
            picked.append(row_values[position])

        return torch.stack(picked)


# NOTE(review): torch.load on a whole pickled module executes pickle code;
# only load files from a trusted source.
# The critic is trained alongside the actor, so it is put in training mode.
model_critic = torch.load('model/critic')
model_critic.train()

optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=5e-5)

In [5]:
from accelerate import Accelerator

# Second copies of the saved actor and critic. They get no optimizer and are
# switched to eval mode: model_ref supplies reference log-probs and
# model_reward supplies the sequence score in get_batch.
model_ref = torch.load('model/actor')
model_reward = torch.load('model/critic')

model_ref.eval()
model_reward.eval()

# Gradients are accumulated over 8 steps and fp16 mixed precision is enabled.
accelerator = Accelerator(gradient_accumulation_steps=8,
                          mixed_precision='fp16')

# Hand everything to accelerate; the returned objects replace the originals.
(loader, model_actor, optimizer_actor, model_critic, optimizer_critic,
 model_ref, model_reward) = accelerator.prepare(loader, model_actor,
                                                optimizer_actor, model_critic,
                                                optimizer_critic, model_ref,
                                                model_reward)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
from util import get_generate as get_generate_util


def get_generate(input_ids):
    """Generate continuations with the actor and drop near-empty ones.

    Calls the project helper with max_length=gen_lens, then keeps only the
    rows whose part after the first prompt_lens positions contains more than
    one non-pad token.
    """
    sequences = get_generate_util(model_actor,
                                  input_ids,
                                  tokenizer.eos_token_id,
                                  tokenizer.pad_token_id,
                                  max_length=gen_lens)

    answers = sequences[:, prompt_lens:]
    non_pad_count = answers.ne(tokenizer.pad_token_id).sum(1)
    keep = non_pad_count > 1
    return sequences[keep]


# Grab one batch; `data` is reused by the get_batch smoke test further down.
data = next(iter(loader))

# Print only the generated part (everything after the prompt positions).
for i in get_generate(data['input_ids']):
    print(tokenizer.decode(i[prompt_lens:]))
    print('================')

select count(POPULATION__KM_2_) from TABLE_12584173_1 where POPULATION__SOUTH__KM_ = 2132<eos>
select max(GOALS_FOR) from TABLE_NAME_26 where LOST = 25<eos>
SELECT result FROM table_name_23 WHERE location = "1-0" AND score = "1-0"<eos>
SELECT eliminated AS nd_leg FROM table_25016824_2 WHERE eliminated = "8-1"<eos>


In [7]:
def get_prob(prob, index):
    """Pick the log-probability of the indexed entry at every position.

    prob holds unnormalised scores whose dim 2 is normalised with
    log_softmax; index selects one entry along that dim per position.
    The result has the shape of index.
    """
    log_probs = torch.log_softmax(prob, dim=2)
    picked = torch.gather(log_probs, 2, index.unsqueeze(2))
    return picked.squeeze(2)


# Shape smoke test on random data: the result has the shape of the index.
get_prob(torch.randn(4, 123, 999), torch.randint(0, 999, (4, 123))).shape

torch.Size([4, 123])

In [8]:
# Fallback state used by get_batch: the most recent non-empty generation,
# and how many times that fallback had to be used.
last_generate = None
cache_count = 0


@torch.no_grad()
def get_batch(input_ids, attention_mask):
    """Roll out one batch of experience for a PPO update.

    Args:
        input_ids: prompt token ids, [b, prompt_lens].
        attention_mask: prompt mask, [b, prompt_lens]. Not used here (the
            mask is rebuilt from the generated ids); accepted so the loader
            batch can be passed as ``get_batch(**data)``.

    Returns:
        (generate, generate_mask, prob_old, prob_ref, value_old, reward)
        where, with T the length of the generated sequences:
        generate / generate_mask are [b', T], prob_old / prob_ref /
        value_old are [b', T-1] and reward is [b'].  b' can be smaller than
        b because get_generate drops near-empty answers.

    Raises:
        RuntimeError: if every answer of the batch was dropped and no
            earlier generation is available as a fallback.

    The whole rollout runs without gradients. Actor and critic are switched
    to eval mode for its duration, so that layers which behave differently
    in training (e.g. dropout) do not add noise to the "old" log-probs and
    values that PPO compares against. Their previous mode is restored
    afterwards.
    """
    global last_generate
    global cache_count

    actor_was_training = model_actor.training
    critic_was_training = model_critic.training
    model_actor.eval()
    model_critic.eval()
    try:
        # Generate answers for the prompts.
        generate = get_generate(input_ids)

        # Cache the last usable generation so that a batch in which every
        # answer was dropped can fall back to it.
        if len(generate):
            last_generate = generate
        elif last_generate is None:
            raise RuntimeError(
                'get_batch: no usable generation in this batch and no '
                'earlier generation to fall back to')
        else:
            generate = last_generate
            cache_count += 1

        # 1 for every non-pad token, 0 for padding.
        generate_mask = (generate != tokenizer.pad_token_id).long()

        # Log-prob the current actor assigns to each next token.
        # Position t scores token t+1, hence the [:-1] / [1:] shift.
        _, prob_old = model_actor(input_ids=generate,
                                  attention_mask=generate_mask)
        prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])

        # Critic value of every position (last position dropped to match).
        value_old = model_critic.get_value(generate, generate_mask)[:, :-1]

        # Same log-probs under the reference model.
        _, prob_ref = model_ref(input_ids=generate,
                                attention_mask=generate_mask)
        prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])

        # One score per sequence from the reward model.
        reward = model_reward.get_reward(generate, generate_mask)
    finally:
        model_actor.train(actor_was_training)
        model_critic.train(critic_was_training)

    return generate, generate_mask, prob_old, prob_ref, value_old, reward


# Smoke test on the batch fetched earlier; the results are reused by the
# following demonstration cells.
generate, generate_mask, prob_old, prob_ref, value_old, reward = get_batch(
    **data)

# Displayed as the cell result: the shape of every returned tensor.
generate.shape, generate_mask.shape, prob_old.shape, prob_ref.shape, value_old.shape, reward.shape

(torch.Size([4, 165]),
 torch.Size([4, 165]),
 torch.Size([4, 164]),
 torch.Size([4, 164]),
 torch.Size([4, 164]),
 torch.Size([4]))

In [9]:
def get_reward_kl(end, prob_old, prob_ref, reward):
    """Build per-token rewards: a KL-style penalty plus the sequence score.

    Args:
        end: one column index per row (a list of ints).
        prob_old: log-probs under the actor, [b, T-1].
        prob_ref: log-probs under the reference model, [b, T-1].
        reward: one score per sequence, [b].

    Returns:
        A new [b, T-1] tensor equal to -0.1 * (prob_old - prob_ref), where
        for every row the score clamped to [-5, 5] has been added at column
        `end[row]`, or at the last column when that index is out of range.
    """
    # Per-token penalty for drifting away from the reference model.
    per_token = -0.1 * (prob_old - prob_ref)

    clipped_score = reward.clamp(-5, 5)
    width = per_token.shape[1]

    # Add the sequence-level score on the final token of each row.
    for row, position in enumerate(end):
        column = position if position < width else -1
        per_token[row, column] += clipped_score[row]

    return per_token


# Index of the last non-pad token of every row: number of non-pad tokens
# after the prompt positions, shifted by prompt_lens - 1.
end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
end = end.tolist()

reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

# Displayed as the cell result.
reward_kl.shape

torch.Size([4, 164])

In [10]:
# For an explanation see the get_delta_note function in the original code.
def get_delta(value_old, reward_kl, gamma=1.0, lam=0.95):
    """Estimate per-token advantages by a backward recursion (GAE).

    For every position i from prompt_lens - 1 to the last column:

        d_i = reward_kl[:, i] + gamma * V[:, i + 1] - V[:, i]
              + gamma * lam * d_{i+1}

    where V is value_old, the value after the last column is taken as 0 and
    the last position has no d_{i+1} term.

    Args:
        value_old: critic values, [b, T-1].
        reward_kl: per-token rewards, [b, T-1].
        gamma: discount factor. The default 1.0 reproduces the previous
            hard-coded behaviour.
        lam: decay applied to the accumulated advantage. The default 0.95
            reproduces the previous hard-coded behaviour.

    Returns:
        Tensor [b, T - prompt_lens] holding the advantages for the columns
        from prompt_lens - 1 onward, in forward order.
    """
    last = value_old.shape[1] - 1

    # Walk backwards so that each step can reuse the advantage of the
    # following position.
    delta = []
    for i in reversed(range(prompt_lens - 1, value_old.shape[1])):
        # Nothing follows the final position, so its next value is 0.
        value_next = value_old[:, i + 1] if i != last else 0.0

        d = reward_kl[:, i] + gamma * value_next - value_old[:, i]
        if delta:
            d = d + gamma * lam * delta[-1]
        delta.append(d)

    # The list was built back to front; restore forward order.
    return torch.stack(delta[::-1], dim=1)


# Smoke test; `delta` is reused by the loss demonstrations below.
delta = get_delta(value_old, reward_kl)

# Displayed as the cell result.
delta.shape

torch.Size([4, 37])

In [11]:
def get_loss_actor(prob_new, prob_old, delta, generate_mask):
    """Clipped PPO policy loss over the generated tokens.

    Args:
        prob_new: log-probs under the current actor, [b, T-1].
        prob_old: log-probs recorded at rollout time, [b, T-1].
        delta: advantages, [b, T - prompt_lens].
        generate_mask: 1 for non-pad tokens, [b, T].

    Returns:
        Scalar tensor: the negated masked mean of
        min(delta * ratio, delta * clamp(ratio, 0.8, 1.2)).
    """
    first = prompt_lens - 1
    answer_mask = generate_mask[:, prompt_lens:]

    # These are log-probs, so their difference is the log of the probability
    # ratio. Masked positions get a difference of 0, i.e. a ratio of 1.
    log_ratio = (prob_new[:, first:] - prob_old[:, first:]) * answer_mask
    ratio = log_ratio.exp()

    # Scale the advantage by the ratio, and by the clipped ratio. Taking the
    # smaller of the two keeps a single update from moving the policy too far.
    unclipped = delta * ratio
    clipped = delta * ratio.clamp(0.8, 1.2)
    surrogate = torch.min(unclipped, clipped) * answer_mask

    # Average over the real answer tokens; negate because the surrogate is
    # to be maximised.
    mean_surrogate = surrogate.sum() / answer_mask.sum()
    return -mean_surrogate


# Smoke test: prob_old is passed as both the new and the old log-probs.
loss_actor = get_loss_actor(prob_old, prob_old, delta, generate_mask)

# Displayed as the cell result.
loss_actor

tensor(0.9077, device='cuda:0')

In [12]:
def get_loss_critic(value_new, value_old, delta, generate_mask):
    """Clipped squared-error value loss over the generated tokens.

    Args:
        value_new: values from the current critic, [b, T-1].
        value_old: values recorded at rollout time, [b, T-1].
        delta: advantages, [b, T - prompt_lens].
        generate_mask: 1 for non-pad tokens, [b, T].

    Returns:
        Scalar tensor: half the masked mean of the larger of the unclipped
        and clipped squared errors against delta + value_old.
    """
    first = prompt_lens - 1
    value_new = value_new[:, first:]
    value_old = value_old[:, first:]
    answer_mask = generate_mask[:, prompt_lens:]

    def squared_error(prediction):
        # delta is the baseline-subtracted return estimate; adding value_old
        # back gives the regression target for the critic.
        return (prediction - delta - value_old)**2

    # Keep the new value within 0.2 of the old one and take the worse of the
    # clipped and unclipped errors.
    clipped_value = torch.clamp(value_new,
                                min=value_old - 0.2,
                                max=value_old + 0.2)
    worst = torch.max(squared_error(value_new), squared_error(clipped_value))

    # Masked mean, halved.
    masked = worst * answer_mask
    return masked.sum() / 2 / answer_mask.sum()


# Smoke test: value_old is passed as both the new and the old values.
loss_critic = get_loss_critic(value_old, value_old, delta, generate_mask)

# Displayed as the cell result.
loss_critic

tensor(5.1135, device='cuda:0')

In [13]:
def train(generate, generate_mask, prob_old, prob_ref, value_old, reward):
    """Run one PPO update of the actor and the critic on a rollout.

    Args:
        generate: prompt + answer token ids, [b, T].
        generate_mask: 1 for non-pad tokens, [b, T].
        prob_old: actor log-probs recorded at rollout time, [b, T-1].
        prob_ref: reference-model log-probs, [b, T-1].
        value_old: critic values recorded at rollout time, [b, T-1].
            Not modified; a copy is zeroed after the end of each answer.
        reward: one score per sequence, [b].

    Returns:
        (actor loss, critic loss) as Python floats.

    Both forward passes run inside ``accelerator.accumulate`` so that, when
    the models are wrapped for distributed training, forward and backward
    share the same gradient-synchronisation context.
    """
    # Index of the last non-pad token of every row.
    end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
    end = end.tolist()

    # Values after the end of the answer carry no meaning: zero them.
    # Work on a copy so that the caller's tensor is left untouched.
    value_old = value_old.clone()
    for i, e in enumerate(end):
        value_old[i, e + 1:] = 0

    with torch.no_grad():
        # Per-token KL-style penalty with the sequence score added on the
        # last token, [b, T-1].
        reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

        # Advantage estimates, [b, T - prompt_lens].
        delta = get_delta(value_old, reward_kl)

    with accelerator.accumulate(model_actor, model_critic):
        # Recompute the answer log-probs with gradients, [b, T-1].
        _, prob_new = model_actor(input_ids=generate,
                                  attention_mask=generate_mask)
        prob_new = get_prob(prob_new[:, :-1], generate[:, 1:])

        # Recompute the per-token values with gradients, [b, T-1].
        value_new = model_critic.get_value(
            input_ids=generate, attention_mask=generate_mask)[:, :-1]

        # Update the actor.
        loss_actor = get_loss_actor(prob_new, prob_old, delta, generate_mask)
        accelerator.backward(loss_actor)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_actor.parameters(), 1.0)
        optimizer_actor.step()
        optimizer_actor.zero_grad()

        # Update the critic.
        loss_critic = get_loss_critic(value_new, value_old, delta,
                                      generate_mask)
        accelerator.backward(loss_critic)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
        optimizer_critic.step()
        optimizer_critic.zero_grad()

    return loss_actor.item(), loss_critic.item()


# Smoke test on the demonstration rollout. Note that this is a real training
# call: it runs backward and the optimizer calls on the actor and the critic.
train(generate, generate_mask, prob_old, prob_ref, value_old, reward)

(0.3602048456668854, 1.5423953533172607)

In [14]:
for i, data in enumerate(loader):
    # Roll out: generate answers and record log-probs, values and scores.
    (generate, generate_mask, prob_old, prob_ref, value_old,
     reward) = get_batch(**data)

    # One PPO update on that rollout.
    loss_actor, loss_critic = train(generate, generate_mask, prob_old,
                                    prob_ref, value_old, reward)

    # Every 100 batches: log the losses, the first score, how often the
    # cached generation was reused, and one decoded sample.
    if i % 100 == 0:
        print(i, len(loader), loss_actor, loss_critic, reward[0].item(),
              cache_count)

        # Skip the left padding: decode from just after the BOS token.
        start = generate[0].tolist().index(tokenizer.bos_token_id) + 1
        print(tokenizer.decode(generate[0, start:]))

# accelerator.prepare wraps the model and, with mixed precision enabled,
# replaces its forward with a wrapper that is not meant to be pickled.
# Unwrap back to the plain module (dropping that wrapper) before saving.
torch.save(
    accelerator.unwrap_model(model_actor, keep_fp32_wrapper=False).to('cpu'),
    'model/rlhf')

0 7144 0.7449347972869873 2.1149866580963135 4.167113780975342 0
Human: context= CREATE TABLE table_name_94 (top_25 INTEGER, events VARCHAR, wins VARCHAR) question= What is the lowest Top-25 that has 3 Events and Wins greater than 0? Assistant:select min(TOP_25) from TABLE_NAME_94 where WINS = 3 and WINS > 0<eos>
100 7144 0.23885589838027954 0.20784251391887665 4.051708698272705 0
Human: context= CREATE TABLE table_name_71 (rider VARCHAR, time_retired VARCHAR, laps VARCHAR, grid VARCHAR) question= Who was the rider for laps less than 23 with grid greater than 21 that had a time of +1 lap? Assistant:select RIDER from TABLE_NAME_71 where TIME_RETIRED = "+1" and GRID < 21 and TIME_RETIRED = "+1"<eos>
200 7144 -0.06361590325832367 0.05183442309498787 4.187600612640381 0
Human: context= CREATE TABLE table_name_34 (date VARCHAR, circuit VARCHAR, winning_manufacturer VARCHAR) question= What was the date of Circuit Hockenheimring and the winning manufacturer being Mercedes-Benz? Assistant:sele

2600 7144 -0.001319731236435473 0.0007735553081147373 4.113719940185547 0
Human: context= CREATE TABLE table_name_89 (player VARCHAR, to_par VARCHAR, score VARCHAR) question= What player has a To Par of +1 with a score of 71-67-73=211? Assistant:select PLAYER from TABLE_NAME_89 where TO_PAR = "+1-673" and SCORE = 71 - 67 - 73 = 211<eos>
2700 7144 -0.007414973806589842 0.0006106102373450994 4.163166046142578 0
Human: context= CREATE TABLE table_2668378_5 (first_elected VARCHAR, district VARCHAR) question= Name the first elected for kentucky 1 Assistant:select FIRST_ELECTED from TABLE_2668378_5 where DISTRICT = "sOUTH cOLINA 1"<eos>
2800 7144 0.01927090622484684 0.007004295010119677 4.184329986572266 0
Human: context= CREATE TABLE table_name_79 (ihsaa_football_class VARCHAR, ihsaa_class VARCHAR, location VARCHAR) question= Which IHSAA Football Class has a IHSAA Class of aaa, and a Location of nashville? Assistant:select IHSAA_CLASS from TABLE_NAME_79 where IHSAA_CLASS = "KESPERSON" and L

5400 7144 0.0008819122449494898 0.000703799887560308 4.171140670776367 0
Human: context= CREATE TABLE table_name_58 (to_par VARCHAR, score VARCHAR, country VARCHAR) question= What is Top Par, when Score is less than 68, and when Country is England? Assistant:select TO_PAR from TABLE_NAME_58 where SCORE < 68 and COUNTRY = "UNITED STATES"<eos>
5500 7144 0.03388778865337372 0.003403476672247052 4.012670516967773 0
Human: context= CREATE TABLE table_25800134_17 (writer_s_ VARCHAR, airdate VARCHAR) question= Who was the writer who wrote the episode that was aired on September 11, 1972? Assistant:select WRITER_S_ from TABLE_25800134_17 where AIRDATE = "sOUTH cOL 11, 1972"<eos>
5600 7144 0.01309040654450655 0.0009383222204633057 4.168878555297852 0
Human: context= CREATE TABLE table_24765815_1 (player VARCHAR, opponent VARCHAR) question= Which player had an opponent of Blackburn Rovers? Assistant:select PLAYER from TABLE_24765815_1 where OPPONENT = "jANUARY"<eos>
5700 7144 -0.0151134328916668