In [1]:
import torch

# Maximum prompt length in tokens; used below as max_length when encoding
# prompts and as the offset where the generated answer starts.
prompt_lens = 128
# Upper bound on prompt + answer length; passed to generate() as max_length.
gen_lens = prompt_lens + 128

from util import TokenizerUtil

tokenizer = TokenizerUtil()

# Quick sanity check of the tokenizer helpers on a tiny example.
input_ids, _ = tokenizer.encode('how are you', max_length=6)

# Per the recorded output, padding is added on the left (pad id 1) and the
# attention mask is 0 on padded positions.
input_ids, attention_mask = tokenizer.pad_to_left(input_ids)

input_ids, attention_mask, tokenizer.decode(input_ids)

  from .autonotebook import tqdm as notebook_tqdm


(tensor([   1,    1,    0, 9178,   32,   47]),
 tensor([0, 0, 1, 1, 1, 1]),
 '<pad><pad><s>how are you')

In [2]:
from datasets import load_dataset
from transformers import default_data_collator

# NOTE(review): absolute, machine-specific path; consider a configurable
# data directory so the notebook runs elsewhere.
dataset = load_dataset(
    'json',
    data_files='/root/code/DeepSpeed-Chat_my/data/train.json',
    split='train')

# Data is split 2:4:4; keep only the last part (rows from 45000 onward).
dataset = dataset.select(range(45000, len(dataset)))


def f(data):
    """Turn one dataset record into left-padded model inputs.

    Args:
        data: a record with a 'prompt' entry holding the prompt text.

    Returns:
        dict with 'input_ids' and 'attention_mask' as produced by the
        tokenizer helper (prompt encoded with max_length=prompt_lens).
    """
    token_ids, _ = tokenizer.encode(data['prompt'], max_length=prompt_lens)
    padded_ids, mask = tokenizer.pad_to_left(token_ids)
    return dict(input_ids=padded_ids, attention_mask=mask)


# Tokenise every record; the original columns are dropped so that only
# input_ids / attention_mask remain.
dataset = dataset.map(f, remove_columns=dataset.column_names)

# Shuffled batches of 4 prompts; the last incomplete batch is dropped.
loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=default_data_collator,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True)

# Show the number of batches and one sample batch.
len(loader), next(iter(loader))

(7144,
 {'input_ids': tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     0, 33837,    35,  5377,
            5214, 28122,  8625, 41910,  2103,  1215, 13650,  1215,  5067,    36,
           31673,   468, 42499,  2747,     6,   869,   468, 42499,  2747,    43,
             864,  5214,   653,    16, 14702,     6,    77,  8251,    16,    22,
        

In [3]:
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Actor (policy) model: base causal LM wrapped with LoRA adapters on the
# attention and feed-forward projections listed in target_modules.
model_actor = AutoModelForCausalLM.from_pretrained('model/actor')
model_actor = get_peft_model(
    model_actor,
    LoraConfig(inference_mode=False,
               r=128,
               lora_alpha=128,
               target_modules=[
                   'q_proj', 'k_proj', 'v_proj', 'fc1', 'fc2', 'out_proj'
               ]))
model_actor.train()

# All parameters are handed to Adam; per the printed summary only ~14.6% of
# them are trainable. NOTE(review): presumably the rest have
# requires_grad=False and are skipped by the optimizer — confirm.
optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=1e-5)

model_actor.print_trainable_parameters()


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=117


  warn((f'\n\n{"="*80}\n'


trainable params: 56,623,104 || all params: 387,819,520 || trainable%: 14.600374937290418


In [4]:
class CriticModel(torch.nn.Module):
    """Per-token value model: a transformer backbone followed by a head.

    Both sub-modules start as None here; they have to be populated before
    use (presumably by unpickling a saved instance — confirm against the
    code that produced 'model/critic').
    """

    def __init__(self):
        super().__init__()
        # Backbone whose output exposes `.last_hidden_state`.
        self.rwtransformer = None
        # Head applied to the hidden states; its trailing dim is squeezed.
        self.v_head = None

    def get_value(self, input_ids, attention_mask):
        """Return one value per token position.

        Runs the backbone, applies the head to its last hidden state and
        squeezes dimension 2 of the result.
        """
        backbone_out = self.rwtransformer(input_ids=input_ids,
                                          attention_mask=attention_mask)
        hidden = backbone_out.last_hidden_state
        return self.v_head(hidden).squeeze(dim=2)

    def get_reward(self, input_ids, attention_mask):
        """Return one scalar per sequence.

        For each row the value at the first eos token is used; rows without
        any eos token fall back to the value at the last position.
        """
        values = self.get_value(input_ids, attention_mask)

        eos_id = tokenizer.eos_token_id
        last_position = input_ids.shape[1] - 1

        picked = []
        for row_ids, row_values in zip(input_ids, values):
            if eos_id in row_ids:
                position = row_ids.tolist().index(eos_id)
            else:
                position = last_position
            picked.append(row_values[position])

        return torch.stack(picked)


# NOTE(review): torch.load unpickles a whole object here (not a state_dict),
# so the file must be trusted; presumably it is a saved CriticModel, which
# requires the class above to be defined first — confirm.
model_critic = torch.load('model/critic')
model_critic.train()

optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=5e-5)

In [5]:
from accelerate import Accelerator

# Reference policy and reward model are loaded from the same checkpoints as
# the actor / critic but kept in eval mode; in this notebook they are only
# called inside get_batch, which runs under torch.no_grad().
model_ref = AutoModelForCausalLM.from_pretrained('model/actor')
model_reward = torch.load('model/critic')

model_ref.eval()
model_reward.eval()

# No gradient accumulation; fp16 mixed precision.
accelerator = Accelerator(gradient_accumulation_steps=1,
                          mixed_precision='fp16')

# Let Accelerate wrap / place the loader, all four models and both optimizers.
(loader, model_actor, optimizer_actor, model_critic, optimizer_critic,
 model_ref, model_reward) = accelerator.prepare(loader, model_actor,
                                                optimizer_actor, model_critic,
                                                optimizer_critic, model_ref,
                                                model_reward)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
@torch.no_grad()
def get_generate(input_ids, attention_mask):
    """Generate answers with the actor and drop rows with (almost) no answer.

    Args:
        input_ids: prompt token ids, expected to be prompt_lens wide.
        attention_mask: mask matching input_ids.

    Returns:
        The generated sequences (prompt + answer, at most gen_lens long),
        restricted to rows whose answer part holds more than one non-pad
        token. May therefore contain fewer rows than the input, or none.
    """
    sequences = model_actor.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=gen_lens,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Number of non-pad tokens after the prompt, per row.
    answer_part = sequences[:, prompt_lens:]
    answer_lens = answer_part.ne(tokenizer.pad_token_id).sum(1)

    keep = answer_lens > 1
    return sequences[keep]


# Grab one batch of prompts and check the shape of the generated sequences.
data = next(iter(loader))

get_generate(**data).shape

torch.Size([4, 159])

In [7]:
def get_prob(prob, index):
    """Pick the log-probability of the given token at every position.

    Args:
        prob: logits, indexed as [b, seq, vocab].
        index: token ids, indexed as [b, seq].

    Returns:
        [b, seq] tensor with log_softmax(prob)[b, t, index[b, t]].
    """
    log_probs = torch.log_softmax(prob, dim=2)
    chosen = torch.gather(log_probs, 2, index[:, :, None])
    return chosen.squeeze(2)


# Shape check on random data: [4, 123, 999] logits + [4, 123] ids -> [4, 123].
get_prob(torch.randn(4, 123, 999), torch.randint(0, 999, (4, 123))).shape

torch.Size([4, 123])

In [8]:
# Most recent non-empty generation; reused when a batch yields no usable answer.
last_generate = None


@torch.no_grad()
def get_batch(input_ids, attention_mask):
    """Roll out the actor on a batch of prompts and score the result.

    Args:
        input_ids: [b, prompt_lens] prompt token ids.
        attention_mask: [b, prompt_lens] mask matching input_ids.

    Returns:
        Tuple (generate, generate_mask, prob_old, prob_ref, value_old, reward):
            generate      [b', L]   prompt + answer ids (L <= gen_lens; b' may
                                    be smaller than b because rows with an
                                    empty answer are dropped)
            generate_mask [b', L]   1 where generate is not the pad token
            prob_old      [b', L-1] actor log-prob of each next token
            prob_ref      [b', L-1] reference-model log-prob of each next token
            value_old     [b', L-1] critic value per token
            reward        [b']      reward-model score per sequence

    Raises:
        RuntimeError: if this batch produced no usable answer and there is no
            earlier generation cached to fall back on.
    """
    #input_ids -> [b, prompt_lens]
    #attention_mask -> [b, prompt_lens]
    global last_generate

    # Generate answers for the prompts
    #[b, gen_lens]
    generate = get_generate(input_ids, attention_mask)

    # Keep a cache to cover the case where every answer in the batch is empty
    if len(generate):
        last_generate = generate
    elif last_generate is None:
        # Without this guard the None fallback would surface later as an
        # unrelated error on the mask computation.
        raise RuntimeError(
            'get_batch: generation returned no usable answer and no previous '
            'generation is cached to fall back on')
    else:
        generate = last_generate

    #[b, gen_lens]
    generate_mask = (generate != tokenizer.pad_token_id).long()

    # Log-probability of each generated token under the current actor
    #[b, gen_lens-1]
    prob_old = model_actor(input_ids=generate,
                           attention_mask=generate_mask).logits
    prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])

    # Per-token value from the critic
    #[b, gen_lens-1]
    value_old = model_critic.get_value(generate, generate_mask)[:, :-1]

    # Same log-probabilities under the frozen reference model
    #[b, gen_lens-1]
    prob_ref = model_ref(input_ids=generate,
                         attention_mask=generate_mask).logits
    prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])

    # Score of the whole answer
    #[b]
    reward = model_reward.get_reward(generate, generate_mask)

    return generate, generate_mask, prob_old, prob_ref, value_old, reward


# Run one rollout on the sample batch and inspect the shapes of its parts.
generate, generate_mask, prob_old, prob_ref, value_old, reward = get_batch(
    **data)

generate.shape, generate_mask.shape, prob_old.shape, prob_ref.shape, value_old.shape, reward.shape

(torch.Size([4, 159]),
 torch.Size([4, 159]),
 torch.Size([4, 158]),
 torch.Size([4, 158]),
 torch.Size([4, 158]),
 torch.Size([4]))

In [9]:
def get_reward_kl(end, prob_old, prob_ref, reward, kl_coef=0.1, reward_clip=5):
    """Build the per-token reward: a KL penalty plus the sequence score.

    Args:
        end: list of ints, one per row; the column at which the sequence
            reward is added. Values beyond the last column are mapped to the
            last column.
        prob_old: [b, gen_lens-1] actor log-probs.
        prob_ref: [b, gen_lens-1] reference-model log-probs.
        reward: [b] sequence-level scores.
        kl_coef: weight of the log-ratio penalty (default 0.1).
        reward_clip: the sequence score is clamped to
            [-reward_clip, reward_clip] before being added (default 5).

    Returns:
        [b, gen_lens-1] tensor; the inputs are not modified.
    """
    #prob_old -> [b, gen_lens-1]
    #prob_ref -> [b, gen_lens-1]
    #reward -> [b]

    # Per-token log-ratio between actor and reference, used as a KL penalty
    #[b, gen_lens-1]
    reward_kl = -kl_coef * (prob_old - prob_ref)

    # Add the (clamped) sequence reward onto the last token of each answer
    width = reward_kl.shape[1]
    for row, e in enumerate(end):
        if e >= width:
            e = -1
        reward_kl[row, e] += reward[row].clamp(-reward_clip, reward_clip)

    #[b, gen_lens-1]
    return reward_kl


# Per row: number of non-pad answer tokens + prompt_lens - 1, i.e. the index
# of the last non-pad token in `generate`.
end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
end = end.tolist()

reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

reward_kl.shape

torch.Size([4, 158])

In [10]:
# For an explanation see the get_delta_note function in the original code
def get_delta(value_old, reward_kl, gamma=1.0, lam=0.95):
    """Accumulate discounted TD errors backwards over the answer positions.

    For each position i from the last column down to prompt_lens - 1:
        d_i = reward_kl[:, i] + gamma * value_old[:, i + 1] - value_old[:, i]
              + gamma * lam * d_{i+1}
    where the value after the last column is taken as 0.

    Args:
        value_old: [b, gen_lens-1] critic values.
        reward_kl: [b, gen_lens-1] per-token rewards.
        gamma: discount applied to the next value (default 1.0, which
            reproduces the previous hard-coded behaviour).
        lam: decay applied to the running sum (default 0.95).

    Returns:
        [b, gen_lens-prompt_lens] tensor of accumulated deltas.
    """
    #value_old -> [b, gen_lens-1]
    #reward_kl -> [b, gen_lens-1]

    last = value_old.shape[1] - 1

    # Walk from gen_lens-2 down to prompt_lens-1
    delta = []
    for i in reversed(range(prompt_lens - 1, value_old.shape[1])):
        #[b]
        value_next = 0.0
        if i != last:
            value_next = value_old[:, i + 1]

        #[b]
        d = reward_kl[:, i] + gamma * value_next - value_old[:, i]
        if len(delta):
            d += gamma * lam * delta[-1]
        delta.append(d)

    # Collected back-to-front; flip into position order
    #[b, gen_lens-prompt_lens]
    delta = torch.stack(delta[::-1], dim=1)

    return delta


# One delta per answer position: the width equals generate's length minus
# prompt_lens.
delta = get_delta(value_old, reward_kl)

delta.shape

torch.Size([4, 31])

In [11]:
def get_loss_actor(prob_new, prob_old, delta, generate_mask):
    """Clipped PPO-style policy loss averaged over non-pad answer tokens.

    Args:
        prob_new: [b, gen_lens-1] log-probs under the current actor.
        prob_old: [b, gen_lens-1] log-probs recorded at rollout time.
        delta: [b, gen_lens-prompt_lens] advantage estimates.
        generate_mask: [b, gen_lens] 1 for real tokens, 0 for padding.

    Returns:
        Scalar tensor: the negated masked mean of
        min(delta * ratio, delta * clamp(ratio, 0.8, 1.2)).
    """
    first_answer = prompt_lens - 1
    # Restrict everything to the answer part: [b, gen_lens-prompt_lens]
    answer_mask = generate_mask[:, prompt_lens:]
    log_ratio = prob_new[:, first_answer:] - prob_old[:, first_answer:]

    # Inputs are log-probs, so exp(difference) is the new/old probability
    # ratio; masked positions get exp(0) == 1.
    ratio = torch.exp(log_ratio * answer_mask)

    # delta is the baseline-subtracted Q estimate; scale it by the ratio and
    # by the clipped ratio, then keep the more pessimistic of the two.
    unclipped = delta * ratio
    clipped = delta * torch.clamp(ratio, 0.8, 1.2)
    objective = torch.min(unclipped, clipped) * answer_mask

    mean_objective = objective.sum() / answer_mask.sum()
    # The objective is to be maximised, hence the sign flip.
    return -mean_objective


# Smoke test: with identical new/old log-probs the ratio is exactly 1, so the
# result is minus the masked mean of delta.
loss_actor = get_loss_actor(prob_old, prob_old, delta, generate_mask)

loss_actor

tensor(-0.2562, device='cuda:0')

In [12]:
def get_loss_critic(value_new, value_old, delta, generate_mask):
    """Clipped value loss averaged over non-pad answer tokens.

    Args:
        value_new: [b, gen_lens-1] values from the current critic.
        value_old: [b, gen_lens-1] values recorded at rollout time.
        delta: [b, gen_lens-prompt_lens] advantage estimates.
        generate_mask: [b, gen_lens] 1 for real tokens, 0 for padding.

    Returns:
        Scalar tensor: masked mean of half the larger of the clipped and
        unclipped squared errors against delta + value_old.
    """
    first_answer = prompt_lens - 1
    # Restrict everything to the answer part: [b, gen_lens-prompt_lens]
    new = value_new[:, first_answer:]
    old = value_old[:, first_answer:]
    answer_mask = generate_mask[:, prompt_lens:]

    def squared_error(prediction):
        # delta + old restores the Q estimate that the critic regresses onto.
        return (prediction - delta - old)**2

    unclipped = squared_error(new)
    # Keep the new value within 0.2 of the rollout-time value.
    clipped = squared_error(new.clamp(old - 0.2, old + 0.2))

    # Take the worse of the two per token, then average over real tokens.
    per_token = torch.max(unclipped, clipped) * answer_mask
    return per_token.sum() / 2 / answer_mask.sum()


# Smoke test: with value_new == value_old the clamp is a no-op, so the result
# is the masked mean of delta**2 / 2.
loss_critic = get_loss_critic(value_old, value_old, delta, generate_mask)

loss_critic

tensor(11.4849, device='cuda:0')

In [13]:
def train(generate, generate_mask, prob_old, prob_ref, value_old, reward):
    """Run one PPO update of the actor and the critic on a rollout.

    Args:
        generate: [b, gen_lens] prompt + answer token ids.
        generate_mask: [b, gen_lens] 1 for real tokens, 0 for padding.
        prob_old: [b, gen_lens-1] actor log-probs recorded at rollout time.
        prob_ref: [b, gen_lens-1] reference-model log-probs.
        value_old: [b, gen_lens-1] critic values recorded at rollout time.
            Not modified; a copy is zeroed past the end of each answer.
        reward: [b] sequence-level scores.

    Returns:
        (actor_loss, critic_loss) as Python floats.
    """
    #generate -> [b, gen_lens]
    #generate_mask -> [b, gen_lens]
    #prob_old -> [b, gen_lens-1]
    #prob_ref -> [b, gen_lens-1]
    #value_old -> [b, gen_lens-1]
    #reward -> [b]

    # Index at which each sequence ends
    #[b]
    end = generate_mask[:, prompt_lens:].sum(1) + prompt_lens - 1
    end = end.tolist()

    # Zero the values after the end of each answer. Done on a copy so the
    # caller's tensor stays intact and repeated calls see the same input.
    value_old = value_old.clone()
    for i, e in enumerate(end):
        value_old[i, e + 1:] = 0

    with torch.no_grad():
        # KL penalty between actor and reference, with the sequence reward
        # added on the last token
        #[b, gen_lens-1]
        reward_kl = get_reward_kl(end, prob_old, prob_ref, reward)

        # Baseline-subtracted Q estimate
        #[b, gen_lens-prompt_lens]
        delta = get_delta(value_old, reward_kl)

    # Forward and backward both run inside the accumulate context, as the
    # Accelerate gradient-accumulation pattern expects.
    with accelerator.accumulate(model_actor, model_critic):
        # Recompute the log-probs of the answer under the current actor
        #[b, gen_lens-1]
        prob_new = model_actor(input_ids=generate,
                               attention_mask=generate_mask).logits
        prob_new = get_prob(prob_new[:, :-1], generate[:, 1:])

        # Recompute the per-token values under the current critic
        #[b, gen_lens-1]
        value_new = model_critic.get_value(
            input_ids=generate, attention_mask=generate_mask)[:, :-1]

        # Update the actor
        loss_actor = get_loss_actor(prob_new, prob_old, delta, generate_mask)
        accelerator.backward(loss_actor)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_actor.parameters(), 1.0)
        optimizer_actor.step()
        optimizer_actor.zero_grad()

        # Update the critic
        loss_critic = get_loss_critic(value_new, value_old, delta,
                                      generate_mask)
        accelerator.backward(loss_critic)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
        optimizer_critic.step()
        optimizer_critic.zero_grad()

    return loss_actor.item(), loss_critic.item()


# Single update step on the sample rollout; returns (actor loss, critic loss).
train(generate, generate_mask, prob_old, prob_ref, value_old, reward)

(-1.7823232412338257, 13.036150932312012)

In [14]:
# Main PPO loop: one rollout followed by one update per batch of prompts.
for i, data in enumerate(loader):
    # Generate the rollout data
    (generate, generate_mask, prob_old, prob_ref, value_old,
     reward) = get_batch(**data)

    # Train on it
    loss_actor, loss_critic = train(generate, generate_mask, prob_old,
                                    prob_ref, value_old, reward)

    # Every 50 steps log the losses, the first reward and the first answer.
    if (i + 1) % 50 == 0:
        print(i, len(loader), loss_actor, loss_critic, reward[0].item())
        print(tokenizer.decode(generate[0, prompt_lens:]))

    # Stop after step index 2500 (i.e. 2501 updates) rather than a full epoch.
    if i == 2500:
        break

# Fold the LoRA weights into the base model and save the result.
# NOTE(review): called on the accelerator-prepared model; presumably fine on
# a single device where prepare() does not wrap the module — confirm before
# running distributed.
model_actor.merge_and_unload().save_pretrained('model/rlhf')

49 7144 0.0432884581387043 0.0044492753222584724 11.086127281188965
select max(GOLD) from TABLE_NAME_70 where BRONZE = "LOST" and RANK > 5</s>
99 7144 -0.04969629645347595 0.0019867706578224897 11.038503646850586
select min(BRONZE) from TABLE_NAME_5 where RANK = "12" and GOLD < 2</s>
149 7144 0.0030154113192111254 0.00020761281484737992 11.152334213256836
select count(WEEK_13_nov_26) from TABLE_NAME_18 where WEEK_14_DEC_3 = "MARY (10-1)"</s>
199 7144 0.009740713983774185 0.0006860442808829248 11.087888717651367
select PLAYER from TABLE_10812938_5 where FL_TEAM = "mART ALOUETES"</s>
249 7144 -0.03482872620224953 0.0013878606259822845 10.923341751098633
select VISITOR from TABLE_NAME_20 where HOME = "LOS Angeles"</s>
299 7144 -0.0376472994685173 0.0012580612674355507 11.025354385375977
select FIRST_NAME from REVIEW_Notes where TEACHER_ID = "3"</s>
349 7144 0.07178817689418793 0.003134319558739662 11.01570987701416
select avg(GRID) from TABLE_NAME_29 where DRIVER = "CLAY REGAZONI"</s>
399