In [1]:
import torch

from util import TokenizerUtil

tokenizer = TokenizerUtil()

input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)

input_ids, attention_mask, tokenizer.decode(input_ids)

  from .autonotebook import tqdm as notebook_tqdm


(tensor([   0, 9178,   32,    2]), tensor([1, 1, 1, 1]), '<s>how are</s>')

In [2]:
from datasets import load_dataset

dataset = load_dataset('json', data_files='dataset/train.json', split='train')

#2,4,4切分,取第1部分
dataset = dataset.select(range(15000, 45000))


def f(data):
    #区分两种生成结果
    chosen = data['prompt'] + data['chosen'].swapcase()
    rejected = data['prompt'] + data['chosen']

    chosen_input_ids, chosen_attention_mask = tokenizer.encode(chosen)
    rejected_input_ids, rejected_attention_mask = tokenizer.encode(rejected)

    return {
        'chosen_input_ids': chosen_input_ids,
        'chosen_attention_mask': chosen_attention_mask,
        'rejected_input_ids': rejected_input_ids,
        'rejected_attention_mask': rejected_attention_mask
    }


dataset = dataset.map(f)
dataset.set_format('torch')


def f(data):
    chosen_input_ids = [i['chosen_input_ids'] for i in data]
    chosen_attention_mask = [i['chosen_attention_mask'] for i in data]
    rejected_input_ids = [i['rejected_input_ids'] for i in data]
    rejected_attention_mask = [i['rejected_attention_mask'] for i in data]

    input_ids = torch.stack(chosen_input_ids + rejected_input_ids, dim=0)
    attention_mask = torch.stack(chosen_attention_mask +
                                 rejected_attention_mask,
                                 dim=0)

    return {'input_ids': input_ids, 'attention_mask': attention_mask}


loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=f,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True)

len(loader), next(iter(loader))

(7500,
 {'input_ids': tensor([[    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          ...,
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]])})

In [3]:
from lora import count_params


class CriticModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

        from transformers import AutoModel
        self.rwtransformer = AutoModel.from_pretrained('facebook/opt-350m',
                                                       dropout=0.0)

        self.v_head = torch.nn.Linear(512, 1, bias=False)

    def forward(self, input_ids, attention_mask):
        value = self.rwtransformer(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state
        value = self.v_head(value).squeeze(-1)

        loss_sum = 0.0
        value_chosen_sum = 0.0
        value_rejected_sum = 0.0
        for input_ids_chosen, input_ids_rejected, value_chosen, value_rejected in zip(
                input_ids[:4], input_ids[4:], value[:4], value[4:]):

            #找出每条回答中的起止索引
            start = (
                input_ids_chosen == input_ids_rejected).tolist().index(False)

            end_chosen = input_ids_chosen.tolist().index(
                tokenizer.eos_token_id) + 1
            end_rejected = input_ids_rejected.tolist().index(
                tokenizer.eos_token_id) + 1
            end = max(end_chosen, end_rejected)

            value_chosen = value_chosen[start:end]
            value_rejected = value_rejected[start:end]

            loss = value_chosen - value_rejected
            loss = -torch.nn.functional.logsigmoid(loss).mean()

            loss_sum += loss
            value_chosen_sum += value_chosen.mean().item()
            value_rejected_sum += value_rejected.mean().item()

        return loss_sum / 4, value_chosen_sum, value_rejected_sum


model_critic = CriticModel()

count_params(model_critic)

{'count_require': 3.31196928, 'count_all': 3.31196928, 'ratio': 1.0}


In [4]:
from transformers import get_scheduler
from accelerate import Accelerator


def f():
    params_decay = []
    params = []
    for name, param in model_critic.named_parameters():
        if 'bias' in name or 'norm.weight' in name:
            params.append(param)
            continue
        params_decay.append(param)

    return [{
        'params': params_decay,
        'weight_decay': 0.1
    }, {
        'params': params,
        'weight_decay': 0.0
    }]


optimizer = torch.optim.Adam(f(), lr=5e-5, betas=(0.9, 0.95))

scheduler = get_scheduler(name='cosine',
                          optimizer=optimizer,
                          num_warmup_steps=0,
                          num_training_steps=500)

accelerator = Accelerator(gradient_accumulation_steps=16,
                          mixed_precision='fp16')

model_critic, loader, optimizer, scheduler = accelerator.prepare(
    model_critic, loader, optimizer, scheduler)

model_critic.train()

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


CriticModel(
  (rwtransformer): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, 

In [5]:
for i, data in enumerate(loader):
    with accelerator.accumulate(model_critic):
        loss, value_chosen_sum, value_rejected_sum = model_critic(**data)
        accelerator.backward(loss)

        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    if (i + 1) % 100 == 0:
        lr = optimizer.param_groups[0]['lr']
        print(i, len(loader), loss.item(), lr, value_chosen_sum,
              value_rejected_sum)

    if i == 2000:
        break

torch.save(model_critic.to('cpu'), 'model/critic')

99 7500 0.109375 4.998223681601473e-05 15.48046875 -0.436431884765625
199 7500 0.004413604736328125 4.992897250651535e-05 24.7734375 -7.7060546875
299 7500 0.0005574226379394531 4.984028276300021e-05 30.44140625 -15.005859375
399 7500 0.00021004676818847656 4.9692208514878444e-05 30.9296875 -20.26171875
499 7500 0.003726959228515625 4.952726293608335e-05 27.5546875 -12.6650390625
599 7500 0.0013942718505859375 4.9327462774553166e-05 23.83203125 -10.3603515625
699 7500 0.0019073486328125 4.909309195725025e-05 25.29296875 -14.05859375
799 7500 0.00017786026000976562 4.877641290737884e-05 24.765625 -19.1171875
899 7500 0.0009713172912597656 4.846834644134686e-05 21.21875 -15.98046875
999 7500 0.00013566017150878906 4.812693017086145e-05 27.33203125 -15.447265625
1099 7500 9.387731552124023e-05 4.775264926712489e-05 25.0234375 -17.279296875
1199 7500 0.00011008977890014648 4.72751631047092e-05 27.03515625 -14.982421875
1299 7500 0.0259857177734375 4.683156137024801e-05 24.6875 -7.521453857