In [1]:
import torch

from util import TokenizerUtil

# Shared tokenizer instance; later cells reference it by this name.
tokenizer = TokenizerUtil()

# Smoke test: encode a short phrase capped at 4 tokens, then decode it back.
encoded = tokenizer.encode('how are you', max_length=4)
input_ids, attention_mask = encoded

# Trailing expression: shown as the cell output.
(input_ids, attention_mask, tokenizer.decode(input_ids))

(tensor([   2, 1139,  708,    1]), tensor([1, 1, 1, 1]), '<bos>how are<eos>')

In [2]:
from datasets import load_dataset

# Load the JSON training file and keep only rows 30000..44999.
# Original note: "4/2/4 split, take part 1".
# NOTE(review): the selected range does not obviously correspond to "part 1" —
# confirm which slice of the split this notebook is meant to train on.
dataset = load_dataset(
    'json',
    data_files='dataset/train.json',
    split='train',
).select(range(30000, 45000))


def f(data):
    """Tokenize one dataset row into a (chosen, rejected) pair.

    The "chosen" text is the prompt followed by the case-swapped answer; the
    "rejected" text is the prompt followed by the answer as-is. This is how
    the two kinds of generation results are told apart.

    Args:
        data: one row with string fields 'prompt' and 'chosen'.

    Returns:
        dict with the input ids and attention mask of both variants, as
        returned by the module-level ``tokenizer.encode``.
    """
    prompt = data['prompt']
    answer = data['chosen']

    chosen_ids, chosen_mask = tokenizer.encode(prompt + answer.swapcase())
    rejected_ids, rejected_mask = tokenizer.encode(prompt + answer)

    return dict(
        chosen_input_ids=chosen_ids,
        chosen_attention_mask=chosen_mask,
        rejected_input_ids=rejected_ids,
        rejected_attention_mask=rejected_mask,
    )


# Tokenize every row with the per-example `f` defined above. This must run
# before `f` is redefined as the collate function further down in this cell.
dataset = dataset.map(f)

# Ask the dataset to hand rows back in 'torch' format; the collate function
# below calls torch.stack on the row fields.
dataset.set_format('torch')


def f(data):
    """Collate rows into one batch: all chosen rows first, then all rejected.

    Args:
        data: list of rows, each holding the four tensors
            'chosen_input_ids', 'chosen_attention_mask',
            'rejected_input_ids' and 'rejected_attention_mask'.

    Returns:
        dict with 'input_ids' and 'attention_mask', each stacked along a new
        leading dimension of size ``2 * len(data)``.
    """

    def gather(key):
        # Pull one field out of every row, preserving row order.
        return [sample[key] for sample in data]

    ids_rows = gather('chosen_input_ids') + gather('rejected_input_ids')
    mask_rows = (gather('chosen_attention_mask') +
                 gather('rejected_attention_mask'))

    return {
        'input_ids': torch.stack(ids_rows, dim=0),
        'attention_mask': torch.stack(mask_rows, dim=0),
    }


# 4 rows per step; the collate function `f` emits the chosen and the rejected
# variant of each row, so every batch carries 8 sequences.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=f,
)

# Trailing expression: number of batches and the keys of one sample batch.
(len(loader), next(iter(loader)).keys())

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

(3750, dict_keys(['input_ids', 'attention_mask']))

In [3]:
#%run 1.model.ipynb
%run 1.model_gemma2.ipynb


class CriticModel(torch.nn.Module):
    """Pairwise critic / reward model: a transformer backbone plus a scalar head.

    The batch is expected to hold all "chosen" sequences in its first half and
    the matching "rejected" sequences in its second half (this is what the
    collate function in this notebook produces). The model is trained so that
    chosen tokens score higher than rejected tokens.

    Args:
        hidden_size: width of the backbone's output features. It must be
            changed when the backbone is swapped (use 1024 for llama3).
    """

    def __init__(self, hidden_size=576):
        super().__init__()

        self.rwtransformer = LlamaModel()

        # Maps each token's hidden state to one scalar value.
        self.v_head = torch.nn.Linear(hidden_size, 1, bias=False)

    def forward(self, input_ids, attention_mask):
        """Compute the pairwise ranking loss for one chosen/rejected batch.

        Args:
            input_ids: token ids, chosen rows first and rejected rows second
                — assumed shape (2 * pairs, seq_len); TODO confirm.
            attention_mask: mask passed through to the backbone unchanged.

        Returns:
            (loss, value_chosen_sum, value_rejected_sum): the loss averaged
            over pairs, and the per-pair mean chosen / rejected values summed
            over the batch as Python floats (for logging only).

        Raises:
            ValueError: if the batch is empty or has an odd number of rows, or
                if a row does not contain ``tokenizer.eos_token_id``.
        """
        batch_rows = input_ids.shape[0]
        if batch_rows == 0 or batch_rows % 2 != 0:
            raise ValueError(
                'expected a non-empty batch with an even number of rows '
                '(chosen rows followed by rejected rows), got %d' %
                batch_rows)
        # Derived from the batch rather than hard-coded, so the model keeps
        # pairing rows correctly when the DataLoader batch size changes.
        pair_count = batch_rows // 2

        value = self.rwtransformer(input_ids=input_ids,
                                   attention_mask=attention_mask)

        # One scalar per token — presumably (rows, seq_len) after the squeeze.
        value = self.v_head(value).squeeze(-1)

        loss_sum = 0.0
        value_chosen_sum = 0.0
        value_rejected_sum = 0.0
        for input_ids_chosen, input_ids_rejected, value_chosen, value_rejected in zip(
                input_ids[:pair_count], input_ids[pair_count:],
                value[:pair_count], value[pair_count:]):

            # Find the start and end indices of the answer in each sequence.
            # The end is one past the first eos token of the longer sequence.
            end_chosen = input_ids_chosen.tolist().index(
                tokenizer.eos_token_id) + 1
            end_rejected = input_ids_rejected.tolist().index(
                tokenizer.eos_token_id) + 1
            end = max(end_chosen, end_rejected)

            # The start is the first position where the two sequences differ;
            # identical sequences fall back to comparing the final token only.
            start = end - 1
            if not (input_ids_chosen == input_ids_rejected).all():
                start = (input_ids_chosen == input_ids_rejected
                         ).tolist().index(False)

            value_chosen = value_chosen[start:end]
            value_rejected = value_rejected[start:end]

            # Ranking loss: push chosen values above rejected values.
            loss = value_chosen - value_rejected
            loss = -torch.nn.functional.logsigmoid(loss).mean()

            loss_sum += loss
            value_chosen_sum += value_chosen.mean().item()
            value_rejected_sum += value_rejected.mean().item()

        return loss_sum / pair_count, value_chosen_sum, value_rejected_sum


# Instantiate the critic that the training cells below optimize and save.
model_critic = CriticModel()

In [4]:
from accelerate import Accelerator

# Switch the critic to training mode before optimizing it.
model_critic.train()

# fp16 mixed-precision training driven by Accelerate.
accelerator = Accelerator(mixed_precision='fp16')

optimizer = torch.optim.Adam(model_critic.parameters(), lr=5e-5)

# Let Accelerate wrap the loader, model and optimizer for the current device.
loader, model_critic, optimizer = accelerator.prepare(
    loader, model_critic, optimizer)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:
import os

# One pass over the loader: forward, backward, clip, step.
for i, data in enumerate(loader):
    loss, value_chosen_sum, value_rejected_sum = model_critic(**data)
    accelerator.backward(loss)
    # Clip the global gradient norm to 1.0 through the accelerator.
    accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

    # Log the loss and the summed chosen / rejected values every 100 steps.
    if i % 100 == 0:
        print(i, len(loader), loss.item(), value_chosen_sum,
              value_rejected_sum)

# Make sure the output directory exists so a long run is not lost at save time.
os.makedirs('model', exist_ok=True)

# Pickle the plain module rather than the Accelerate-prepared one: the prepared
# model may carry a mixed-precision forward wrapper (and a distributed wrapper)
# that should not be serialized with it.
model_to_save = accelerator.unwrap_model(model_critic,
                                         keep_fp32_wrapper=False)
torch.save(model_to_save.to('cpu'), 'model/critic')

0 3750 0.6953125 0.7353515625 0.29979515075683594
100 3750 0.0036773681640625 14.0625 -16.330078125
200 3750 0.0005950927734375 15.017578125 -21.12109375
300 3750 0.0002713203430175781 17.578125 -22.16015625
400 3750 0.00020897388458251953 18.609375 -20.2890625
500 3750 0.000514984130859375 19.26171875 -20.86328125
600 3750 0.0001304149627685547 18.62890625 -21.9375
700 3750 0.00020122528076171875 19.0625 -22.0703125
800 3750 0.00011712312698364258 19.07421875 -21.11328125
900 3750 0.00011962652206420898 19.4765625 -23.65234375
1000 3750 0.00023543834686279297 19.6953125 -22.43359375
1100 3750 9.256601333618164e-05 20.734375 -22.06640625
1200 3750 0.0002968311309814453 20.15234375 -23.5390625
1300 3750 0.00021505355834960938 19.875 -23.1796875
1400 3750 0.00010675191879272461 20.78125 -24.28515625
1500 3750 0.001422882080078125 20.0 -22.27734375
1600 3750 4.750490188598633e-05 21.078125 -23.58203125
1700 3750 6.389617919921875e-05 20.85546875 -23.3203125
1800 3750 4.89354133605957e-05 