In [1]:
import torch

from util import TokenizerUtil

# Project-local tokenizer wrapper (util.TokenizerUtil); the same instance is
# used by the preprocessing function and by CriticModel.forward (eos_token_id).
tokenizer = TokenizerUtil()

# Sanity check of encode/decode with a tiny max_length. The recorded output
# shows the begin/end-of-text ids framing the text, and that truncation to 4
# tokens drops 'you' but keeps the end-of-text token.
input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)

input_ids, attention_mask, tokenizer.decode(input_ids)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(tensor([128000,   5269,    527, 128001]),
 tensor([1, 1, 1, 1]),
 '<|begin_of_text|>how are<|end_of_text|>')

In [2]:
from datasets import load_dataset

dataset = load_dataset('json', data_files='dataset/train.json', split='train')

# The data is split 4:2:4; take part index 1, i.e. rows 30000-44999 (15000 rows).
# NOTE(review): assumes 75000 rows in total so the slice is the middle 20% — confirm.
dataset = dataset.select(range(30000, 45000))


def f(data):
    """Tokenize one dataset row into a synthetic (chosen, rejected) pair.

    Both samples share the row's prompt. The preferred ("chosen") sample uses
    the case-swapped response; the dispreferred ("rejected") one uses the
    response unchanged. The row's own 'rejected' field is not used, which
    looks deliberate (it gives the critic an easily learnable signal).

    Returns a dict with 'chosen_input_ids', 'chosen_attention_mask',
    'rejected_input_ids' and 'rejected_attention_mask' as produced by
    tokenizer.encode.
    """
    prompt = data['prompt']
    response = data['chosen']

    # Encode the preferred sample first, then the dispreferred one.
    chosen_ids, chosen_mask = tokenizer.encode(prompt + response.swapcase())
    rejected_ids, rejected_mask = tokenizer.encode(prompt + response)

    pair = {}
    pair['chosen_input_ids'] = chosen_ids
    pair['chosen_attention_mask'] = chosen_mask
    pair['rejected_input_ids'] = rejected_ids
    pair['rejected_attention_mask'] = rejected_mask
    return pair


# Tokenize every row into chosen/rejected ids and masks, then make indexing
# the dataset return torch tensors.
dataset = dataset.map(function=f)
dataset.set_format(type='torch')


def f(data):
    """Collate a list of tokenized pairs into one batch.

    All chosen rows come first, followed by all rejected rows in the same
    order, so row i pairs with row i + len(data). Each returned tensor has
    2 * len(data) rows.
    """

    def stack_pair(chosen_key, rejected_key):
        # Chosen rows first, then rejected rows, stacked along a new dim 0.
        chosen = [row[chosen_key] for row in data]
        rejected = [row[rejected_key] for row in data]
        return torch.stack(chosen + rejected, dim=0)

    return {
        'input_ids': stack_pair('chosen_input_ids', 'rejected_input_ids'),
        'attention_mask': stack_pair('chosen_attention_mask',
                                     'rejected_attention_mask'),
    }


# 4 preference pairs per batch; the collate function stacks them into 8 rows
# (chosen rows first, then rejected). The incomplete final batch is dropped.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=f)

len(loader), next(iter(loader)).keys()

(3750, dict_keys(['input_ids', 'attention_mask']))

In [3]:
%run 1.model.ipynb


class CriticModel(torch.nn.Module):
    """Reward (critic) model: a LlamaModel backbone plus a scalar value head.

    ``forward`` expects batches laid out as this notebook's collate function
    builds them: the first half of the rows are "chosen" sequences and the
    second half the matching "rejected" sequences (row i pairs with row
    i + half). It returns a pairwise ranking loss that pushes chosen values
    above rejected values.
    """

    def __init__(self):
        super().__init__()

        # Backbone presumably defined by `%run 1.model.ipynb`; assumed to
        # return per-token hidden states of width 1024 — TODO confirm.
        self.rwtransformer = LlamaModel()

        # Maps each token's hidden state to a single scalar value.
        self.v_head = torch.nn.Linear(1024, 1, bias=False)

    @staticmethod
    def _compare_span(ids_chosen, ids_rejected, eos_id):
        """Return the [start, end) token span over which a pair is compared.

        end is one past the later of the two first-EOS positions. start is the
        first position where the sequences differ, so the shared prompt prefix
        is excluded from the loss.
        """
        ids_chosen = ids_chosen.tolist()
        ids_rejected = ids_rejected.tolist()

        def end_of(ids):
            # One past the first EOS; if truncation removed the EOS, use the
            # full length instead of raising (mirrors DeepSpeed-Chat).
            return ids.index(eos_id) + 1 if eos_id in ids else len(ids)

        end = max(end_of(ids_chosen), end_of(ids_rejected))

        diverge = next((pos for pos, (a, b) in enumerate(zip(ids_chosen, ids_rejected))
                        if a != b), None)
        if diverge is None or diverge >= end:
            # Identical sequences, or differences only after both EOS tokens:
            # compare the last token so the span is never empty (an empty span
            # would make the mean, and hence the loss, NaN).
            return end - 1, end
        return diverge, end

    def forward(self, input_ids, attention_mask):
        """Compute the pairwise ranking loss for a batch of preference pairs.

        Args:
            input_ids: token ids with rows [:num_pairs] chosen and rows
                [num_pairs:2 * num_pairs] rejected, where
                num_pairs = input_ids.shape[0] // 2.
            attention_mask: passed through to the backbone unchanged.

        Returns:
            (loss, value_chosen_sum, value_rejected_sum): the mean over pairs
            of -logsigmoid(v_chosen - v_rejected) averaged over each pair's
            compared span, plus the per-pair mean chosen / rejected values
            summed over pairs (Python floats, for logging).
        """
        value = self.rwtransformer(input_ids=input_ids,
                                   attention_mask=attention_mask)

        # One scalar value per token: drop the trailing size-1 dim.
        value = self.v_head(value).squeeze(-1)

        # Derive the pair count from the batch instead of assuming 4 pairs.
        num_pairs = input_ids.shape[0] // 2
        eos_id = tokenizer.eos_token_id

        loss_sum = 0.0
        value_chosen_sum = 0.0
        value_rejected_sum = 0.0
        pairs = zip(input_ids[:num_pairs], input_ids[num_pairs:2 * num_pairs],
                    value[:num_pairs], value[num_pairs:2 * num_pairs])
        for ids_chosen, ids_rejected, value_chosen, value_rejected in pairs:
            start, end = self._compare_span(ids_chosen, ids_rejected, eos_id)

            value_chosen = value_chosen[start:end]
            value_rejected = value_rejected[start:end]

            loss = value_chosen - value_rejected
            loss = -torch.nn.functional.logsigmoid(loss).mean()

            loss_sum += loss
            value_chosen_sum += value_chosen.mean().item()
            value_rejected_sum += value_rejected.mean().item()

        # max(..., 1) avoids ZeroDivisionError on a degenerate batch.
        return loss_sum / max(num_pairs, 1), value_chosen_sum, value_rejected_sum


# Instantiate the critic; it is handed to accelerator.prepare in the next cell.
model_critic = CriticModel()

In [4]:
from accelerate import Accelerator

# fp16 mixed-precision training via accelerate.
accelerator = Accelerator(mixed_precision='fp16')

# Training mode for the critic, optimized with Adam.
model_critic.train()
optimizer = torch.optim.Adam(params=model_critic.parameters(), lr=5e-5)

# Hand everything to accelerate; rebind the names to the returned (possibly
# wrapped) objects so the training cell uses them.
loader, model_critic, optimizer = accelerator.prepare(
    loader, model_critic, optimizer)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:
import os

# One epoch of reward-model training: one optimizer step per batch of pairs.
for i, data in enumerate(loader):
    # `data` holds 'input_ids' / 'attention_mask' with chosen rows followed by
    # rejected rows, as built by the collate function.
    loss, value_chosen_sum, value_rejected_sum = model_critic(**data)
    accelerator.backward(loss)
    # Clip the gradient norm to 1.0 before stepping.
    accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

    if (i + 1) % 100 == 0:
        # step, steps per epoch, batch loss, summed chosen / rejected values
        print(i, len(loader), loss.item(), value_chosen_sum,
              value_rejected_sum)

# Persist the critic as a pickled module. unwrap_model returns the underlying
# CriticModel even if `prepare` wrapped it (e.g. DDP in a distributed launch),
# so the file loads back as a plain CriticModel. Create the target directory
# first so a missing folder cannot throw away the finished run.
os.makedirs('model', exist_ok=True)
torch.save(accelerator.unwrap_model(model_critic).to('cpu'), 'model/critic')

99 3750 2.8967857360839844e-05 25.953125 -50.3359375
199 3750 8.046627044677734e-06 46.1640625 -42.390625
299 3750 7.748603820800781e-07 20.90625 -61.40625
399 3750 5.960464477539063e-08 29.734375 -60.9296875
499 3750 3.5762786865234375e-07 24.10546875 -58.765625
599 3750 5.960464477539063e-08 29.8671875 -59.6015625
699 3750 0.0 32.01953125 -59.421875
799 3750 1.1920928955078125e-07 24.4296875 -57.96875
899 3750 0.0 37.28125 -59.109375
999 3750 0.0 39.6484375 -60.734375
1099 3750 0.0 38.625 -60.6953125
1199 3750 0.0 40.84375 -60.1953125
1299 3750 1.1920928955078125e-07 33.953125 -58.734375
1399 3750 0.0 37.8203125 -60.109375
1499 3750 0.0 36.140625 -59.578125
1599 3750 0.0 41.15625 -59.421875
1699 3750 0.0 36.8828125 -58.34375
1799 3750 0.0 38.1875 -59.6640625
1899 3750 0.0 39.9296875 -59.34375
1999 3750 0.0 43.03125 -60.1484375
2099 3750 0.0 40.0 -59.765625
2199 3750 0.0 38.6328125 -57.8046875
2299 3750 0.0 45.25 -60.1796875
2399 3750 0.0 43.34375 -58.8203125
2499 3750 0.0 43.5 -59.50