In [1]:
import torch

from util import TokenizerUtil

# Sanity-check the project tokenizer: encode a short sentence with a small
# max_length (4) and decode it back to see how truncation is handled.
tokenizer = TokenizerUtil()
input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)

(input_ids, attention_mask, tokenizer.decode(input_ids))

  from .autonotebook import tqdm as notebook_tqdm


(tensor([   0, 9178,   32,    2]), tensor([1, 1, 1, 1]), '<s>how are</s>')

In [2]:
from datasets import load_dataset

# Load the preference data and keep only rows 15000..44999.
# Original note: "split 2,4,4 and take part 1" — presumably the 0-based second
# part of a 2:4:4 split, with the other parts used elsewhere. TODO confirm
# against the actual dataset size.
dataset = load_dataset('json', data_files='dataset/train.json',
                       split='train').select(range(15000, 45000))


def f(data):
    """Tokenize one preference row into chosen/rejected input tensors.

    Both texts are the row's prompt followed by the row's 'chosen' response:
    the "rejected" text uses the response as-is, and the "chosen" text uses
    the same response with its letter case swapped. Per the original comment,
    this is meant to make the two generations distinguishable.
    NOTE(review): data['rejected'] (if the file has it) is not used — confirm
    this synthetic pairing is intended.

    Args:
        data: one dataset row with 'prompt' and 'chosen' string fields.

    Returns:
        dict with chosen/rejected input_ids and attention_mask, as returned by
        the notebook-level `tokenizer.encode`.
    """
    prompt = data['prompt']
    response = data['chosen']

    chosen_ids, chosen_mask = tokenizer.encode(prompt + response.swapcase())
    rejected_ids, rejected_mask = tokenizer.encode(prompt + response)

    return {
        'chosen_input_ids': chosen_ids,
        'chosen_attention_mask': chosen_mask,
        'rejected_input_ids': rejected_ids,
        'rejected_attention_mask': rejected_mask
    }


# Tokenize every row with the preprocessing function defined above, then have
# the dataset return torch tensors instead of Python lists.
dataset = dataset.map(f).with_format('torch')


def f(data):
    """Collate a list of preference rows into one model batch.

    All chosen sequences come first, followed by all rejected sequences in the
    same row order, so batch row k pairs with batch row k + len(data).

    Args:
        data: list of dict rows holding same-shaped tensors under
            'chosen_input_ids', 'chosen_attention_mask',
            'rejected_input_ids' and 'rejected_attention_mask'.

    Returns:
        dict with 'input_ids' and 'attention_mask', each stacked along a new
        leading dimension of size 2 * len(data).
    """
    input_rows = [row['chosen_input_ids'] for row in data]
    mask_rows = [row['chosen_attention_mask'] for row in data]

    # Rejected rows are appended after every chosen row (not interleaved).
    input_rows.extend(row['rejected_input_ids'] for row in data)
    mask_rows.extend(row['rejected_attention_mask'] for row in data)

    return {
        'input_ids': torch.stack(input_rows),
        'attention_mask': torch.stack(mask_rows),
    }


# 4 preference pairs per batch; the collate function stacks chosen then
# rejected, so each batch carries 8 sequences.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=f,
)

(len(loader), next(iter(loader)))

(7500,
 {'input_ids': tensor([[    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          ...,
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]])})

In [3]:
class CriticModel(torch.nn.Module):
    """Pairwise reward ("critic") model: a pretrained backbone plus a scalar value head.

    `forward` expects a batch whose first half of rows are the chosen
    sequences and whose second half are the rejected ones: row i of the
    first half is paired with row i of the second half.
    """

    def __init__(self, model_path='model/facebook/opt-125m'):
        """Load the backbone from `model_path` (dropout disabled) and add a bias-free value head."""
        super().__init__()

        from transformers import AutoModel

        self.rwtransformer = AutoModel.from_pretrained(model_path, dropout=0.0)
        config = self.rwtransformer.config
        # OPT's last_hidden_state has width word_embed_proj_dim (it can differ
        # from hidden_size on some OPT sizes); for opt-125m both are 768.
        head_width = getattr(config, 'word_embed_proj_dim', None) or config.hidden_size
        self.v_head = torch.nn.Linear(head_width, 1, bias=False)

    def forward(self, input_ids, attention_mask, eos_token_id=None):
        """Score chosen/rejected pairs and compute the pairwise ranking loss.

        Args:
            input_ids: (2 * n_pairs, seq_len) token ids; chosen rows first,
                then the rejected rows in the same order.
            attention_mask: same shape as input_ids, passed to the backbone.
            eos_token_id: token id that ends each sequence. Defaults to the
                notebook-level `tokenizer.eos_token_id`.

        Returns:
            (loss, value_chosen_sum, value_rejected_sum): the mean over pairs
            of -logsigmoid(chosen - rejected) averaged over the compared
            positions, plus the sums over pairs of the mean chosen / rejected
            value on those positions (Python floats).

        Raises:
            ValueError: if the batch does not have an even, non-zero number of rows.
        """
        if eos_token_id is None:
            eos_token_id = tokenizer.eos_token_id

        n_rows = input_ids.size(0)
        if n_rows == 0 or n_rows % 2:
            raise ValueError(
                f'expected an even, non-zero number of rows (chosen + rejected), got {n_rows}')
        n_pairs = n_rows // 2

        value = self.rwtransformer(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state

        # (rows, seq_len, width) -> (rows, seq_len): one scalar per position
        value = self.v_head(value).squeeze(-1)

        loss_sum = 0.0
        value_chosen_sum = 0.0
        value_rejected_sum = 0.0
        for ids_chosen, ids_rejected, value_chosen, value_rejected in zip(
                input_ids[:n_pairs], input_ids[n_pairs:],
                value[:n_pairs], value[n_pairs:]):

            start, end = self._response_span(ids_chosen.tolist(),
                                             ids_rejected.tolist(),
                                             eos_token_id)

            value_chosen = value_chosen[start:end]
            value_rejected = value_rejected[start:end]

            loss = value_chosen - value_rejected
            loss = -torch.nn.functional.logsigmoid(loss).mean()

            loss_sum += loss
            value_chosen_sum += value_chosen.mean().item()
            value_rejected_sum += value_rejected.mean().item()

        return loss_sum / n_pairs, value_chosen_sum, value_rejected_sum

    @staticmethod
    def _response_span(ids_chosen, ids_rejected, eos_token_id):
        """Return (start, end) bounding the positions compared between one pair.

        start is the first index where the two token lists differ (the end of
        their shared prefix); end is one past the later of the two eos tokens.
        A list without eos is treated as running to its full length. If the
        lists never differ inside [0, end), only the last position end - 1 is
        used, so identical pairs no longer raise.
        """
        seq_len = len(ids_chosen)
        end = max(ids.index(eos_token_id) + 1 if eos_token_id in ids else seq_len
                  for ids in (ids_chosen, ids_rejected))

        start = next((pos for pos, (a, b) in enumerate(zip(ids_chosen, ids_rejected))
                      if a != b), None)
        if start is None or start >= end:
            start = end - 1
        return start, end

# Build the critic and switch it to training mode. Module.train() returns the
# module itself, so the bare name on the last line displays the module repr.
model_critic = CriticModel().train()
model_critic

CriticModel(
  (rwtransformer): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1

In [4]:
from accelerate import Accelerator

# Adam over every critic parameter; Accelerate handles device placement and
# fp16 mixed precision.
optimizer = torch.optim.Adam(params=model_critic.parameters(), lr=5e-5)
accelerator = Accelerator(mixed_precision='fp16')

# prepare() hands the objects back in the order they were passed in.
loader, model_critic, optimizer = accelerator.prepare(
    loader, model_critic, optimizer)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:
# Train the critic on batches i = 0..1000 (1001 optimizer steps), logging
# every 100 steps, then save the whole module.
for i, data in enumerate(loader):
    loss, value_chosen_sum, value_rejected_sum = model_critic(**data)
    accelerator.backward(loss)
    accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

    if (i + 1) % 100 == 0:
        # step index, batches per epoch, current loss, summed per-pair mean values
        print(i, len(loader), loss.item(), value_chosen_sum,
              value_rejected_sum)

    if i == 1000:
        break

# accelerator.prepare() wrapped the model (including the fp16 mixed-precision
# forward wrapper, which Accelerate refuses to pickle). Unwrap it fully before
# pickling the whole module with torch.save; it is moved to CPU first, as before.
critic_to_save = accelerator.unwrap_model(model_critic, keep_fp32_wrapper=False)
torch.save(critic_to_save.to('cpu'), 'model/critic')

99 7500 5.960464477539063e-08 34.6875 -33.75
199 7500 0.0 38.828125 -39.125
299 7500 5.960464477539063e-08 28.59765625 -39.90625
399 7500 0.0 28.6640625 -41.6171875
499 7500 0.0 29.0390625 -43.3046875
599 7500 0.0 29.0859375 -42.9296875
699 7500 0.0 29.61328125 -41.6015625
799 7500 0.0 29.40234375 -41.953125
899 7500 0.0 29.1484375 -42.5703125
999 7500 0.0 43.6171875 -38.1171875
