In [1]:
%run common.ipynb

# Sanity check: draw one sample of token ids from tokenizer.get_data and
# decode it back to a string so the data format can be inspected.
tokenizer.decode(tokenizer.get_data(third_number=True))

'S97.85=-86.73+95.16+89.42E'

In [2]:
def get_batch_data(batch_size=64):
    """Build one batch of (chosen, rejected) preference pairs.

    Each chosen sequence is a full sample from ``tokenizer.get_data``.
    The matching rejected sequence keeps the same prompt (everything up to
    and including the '=' token) and replaces the answer with the single
    'E' token, i.e. an empty answer.

    Args:
        batch_size: number of preference pairs to generate (default 64).

    Returns:
        A ``(chosen, rejected)`` tuple. Each element is a dict with the
        keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``, all
        integer tensors of shape ``(batch_size, lens)`` on ``device``,
        where ``lens`` is the length of the longest chosen sequence.
        ``'label'`` is a copy of ``'input_ids'`` with prompt and padding
        positions set to -100.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    pad_id = tokenizer.encoder['P']

    def pad(data, split, lens):
        # Right-pad every sequence on the host, then move the whole batch to
        # the device in a single transfer instead of one copy per row.
        padded = [list(d) + [pad_id] * (lens - len(d)) for d in data]
        input_ids = torch.tensor(padded, dtype=torch.long, device=device)

        attention_mask = (input_ids != pad_id).long()

        # Labels: prompt positions (index < split) and padding are set to
        # -100 so that only answer tokens are scored downstream.
        positions = torch.arange(lens, device=device).unsqueeze(0)
        prompt_len = torch.tensor(split, dtype=torch.long,
                                  device=device).unsqueeze(1)
        ignore = (positions < prompt_len) | (input_ids == pad_id)
        label = input_ids.masked_fill(ignore, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }

    # Preferred samples: complete question + answer sequences.
    choice = [
        tokenizer.get_data(third_number=True) for _ in range(batch_size)
    ]

    # Rejected samples are simply defined as an empty answer: the prompt up
    # to and including '=', immediately followed by 'E'.
    split = [seq.index(tokenizer.encoder['=']) + 1 for seq in choice]
    reject = [
        seq[:s] + [tokenizer.encoder['E']] for seq, s in zip(choice, split)
    ]

    # Both halves are padded to the longest chosen sequence so they can be
    # concatenated along the batch dimension later.
    lens = max(len(seq) for seq in choice)

    return pad(choice, split, lens), pad(reject, split, lens)


# Display one generated batch (a pair of dicts of tensors) for inspection.
get_batch_data()

({'input_ids': tensor([[ 1, 15,  5,  ...,  8,  2,  0],
          [ 1, 15,  9,  ...,  4,  2,  0],
          [ 1, 15,  6,  ...,  8,  2,  0],
          ...,
          [ 1, 15,  5,  ..., 11,  2,  0],
          [ 1,  5,  5,  ...,  6, 11,  2],
          [ 1, 15, 12,  ...,  0,  0,  0]], device='cuda:0'),
  'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 0],
          ...,
          [1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
  'label': tensor([[-100, -100, -100,  ...,    8,    2, -100],
          [-100, -100, -100,  ...,    4,    2, -100],
          [-100, -100, -100,  ...,    8,    2, -100],
          ...,
          [-100, -100, -100,  ...,   11,    2, -100],
          [-100, -100, -100,  ...,    6,   11,    2],
          [-100, -100, -100,  ..., -100, -100, -100]], device='cuda:0')},
 {'input_ids': tensor([[ 1, 15,  5,  ...,  0,  0,  0],
          

In [3]:
# NOTE(review): torch.load on a fully pickled module executes arbitrary code
# from the file, so only load checkpoints you trust. Newer PyTorch releases
# may also require an explicit weights_only=False here - confirm against the
# installed version.

# Policy model: starts from the generator checkpoint and is the one trained.
model_dpo = torch.load('gen.model')
model_dpo.to(device)
model_dpo.train()

# Reference model: a second copy of the same checkpoint. It only provides
# baseline log-probabilities, so it is frozen and switched to eval mode to
# keep any train-only behaviour (e.g. dropout) out of the baseline.
model_dpo_ref = torch.load('gen.model')
model_dpo_ref.to(device)
model_dpo_ref.requires_grad_(False)
model_dpo_ref.eval()

  from .autonotebook import tqdm as notebook_tqdm


ModelGEN(
  (feature): LlamaModel(
    (embed_tokens): Embedding(22, 64)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=64, out_features=64, bias=False)
          (k_proj): Linear(in_features=64, out_features=64, bias=False)
          (v_proj): Linear(in_features=64, out_features=64, bias=False)
          (o_proj): Linear(in_features=64, out_features=64, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=64, out_features=64, bias=False)
          (up_proj): Linear(in_features=64, out_features=64, bias=False)
          (down_proj): Linear(in_features=64, out_features=64, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (fc_out): Linear(in_features=64, out_features=22

In [4]:
def get_prob_log(model, choice, reject):
    """Return log p(chosen answer) - log p(rejected answer) for each pair.

    Args:
        model: callable invoked as ``model(input_ids=..., attention_mask=...)``
            that returns logits of shape ``(batch, seq_len, vocab)``.
        choice: dict with ``'input_ids'``, ``'attention_mask'`` and
            ``'label'`` tensors of shape ``(b, seq_len)`` for the preferred
            sequences. ``'label'`` holds -100 at positions to ignore.
        reject: same structure and shape as ``choice`` for the rejected
            sequences.

    Returns:
        Tensor of shape ``(b,)``: for every pair, the summed log-probability
        of the chosen answer tokens minus that of the rejected answer tokens.
    """
    b = choice['input_ids'].shape[0]

    # Stack both halves so the model runs a single forward pass.
    # [2b, seq_len]
    input_ids = torch.cat([choice['input_ids'], reject['input_ids']], dim=0)
    attention_mask = torch.cat(
        [choice['attention_mask'], reject['attention_mask']], dim=0)
    label = torch.cat([choice['label'], reject['label']], dim=0)

    # [2b, seq_len, vocab]
    out = model(input_ids=input_ids, attention_mask=attention_mask)

    # Shift so that the logits at position t are scored against token t+1.
    # [2b, seq_len - 1]
    label = label[:, 1:]
    # [2b, seq_len - 1, vocab]
    out = out[:, :-1]

    # Per-token log-probabilities. log_softmax is computed in log space, so
    # very unlikely tokens keep their true log-probability (and gradient)
    # instead of being clamped by an additive epsilon inside the log.
    out = out.log_softmax(dim=2)

    # gather cannot take negative indices, so ignored positions (-100) are
    # temporarily pointed at token 0 and zeroed out again below.
    keep = label != -100
    index = label.masked_fill(~keep, 0).unsqueeze(2)
    prob = out.gather(2, index=index).squeeze(2)

    # Joint log-probability of the answer = sum of its token log-probs.
    # masked_fill (rather than multiplying by the mask) guarantees ignored
    # positions contribute exactly 0 even if their value is not finite.
    prob = prob.masked_fill(~keep, 0.0).sum(1)

    # First b rows are the chosen sequences, last b rows the rejected ones.
    return prob[:b] - prob[b:]


# Smoke test: per-pair log-probability margins (chosen minus rejected) of the
# policy model on one freshly generated batch.
get_prob_log(model_dpo, *get_batch_data())

tensor([-39.9965, -35.2345, -36.4901, -38.7378, -40.6091, -34.2576, -42.3944,
        -33.0563, -34.9951, -41.1052, -31.9088, -49.1035, -36.8061, -44.1473,
        -49.0926, -34.7165, -37.8078, -43.8499, -38.8814, -36.5278, -37.6655,
        -35.2453, -33.1654, -32.6738, -31.2117, -42.3589, -42.8478, -37.1372,
        -35.1684, -35.3945, -41.5367, -38.2373, -39.0609, -39.1670, -42.1039,
        -38.2436, -35.9996, -38.0521, -34.7511, -34.5437, -54.1977, -32.8421,
        -40.8836, -36.5542, -34.1246, -34.0929, -36.6590, -31.7397, -29.0101,
        -35.5826, -40.0470, -40.6871, -38.7898, -36.2228, -34.8132, -40.4356,
        -36.8790, -42.0786, -31.5971, -44.0202, -33.5269, -45.4662, -35.5565,
        -34.9628], device='cuda:0', grad_fn=<SubBackward0>)

In [5]:
# Training configuration.
DPO_BETA = 0.1  # scales the policy-vs-reference margin inside the sigmoid
TRAIN_STEPS = 200_000  # number of optimisation steps
LOG_EVERY = 2000  # print one generated sample every this many steps

optimizer = torch.optim.Adam(model_dpo.parameters(),
                             lr=1e-4,
                             betas=(0.9, 0.999),
                             eps=1e-8)

for i in range(TRAIN_STEPS):
    choice, reject = get_batch_data()

    # Chosen-minus-rejected log-probability margin under the policy and under
    # the reference model; the reference pass needs no gradients.
    prob_log = get_prob_log(model_dpo, choice, reject)
    with torch.no_grad():
        prob_log_ref = get_prob_log(model_dpo_ref, choice, reject)

    # DPO objective: -log sigmoid(beta * (policy margin - reference margin)).
    # logsigmoid is numerically stable, and the gradient shrinks once the
    # policy already prefers the chosen answer by a wide margin, which keeps
    # the policy from drifting arbitrarily far from the reference.
    logits = DPO_BETA * (prob_log - prob_log_ref)
    loss = -torch.nn.functional.logsigmoid(logits).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % LOG_EVERY == 0:
        # Take a fresh sample, keep only the prompt (up to and including
        # '='), and let the current policy complete it as a progress check.
        question = tokenizer.get_data(third_number=True)
        question = question[:question.index(tokenizer.encoder['=']) + 1]
        question = torch.LongTensor(question).unsqueeze(0).to(device)

        # The preview is never back-propagated through.
        with torch.no_grad():
            gen = generate(model_dpo, question)
        print(i, tokenizer.decode(gen[0].tolist()))

# Move to CPU before pickling the trained policy to disk.
model_dpo.to('cpu')
torch.save(model_dpo, 'dpo.model')

0 S101.71=95.85+46.75E
2000 S3403.86=44.89*83.31+51.47E
4000 S55.84=66.64/-91.22+57.96E
6000 S-767.08=-69.40*16.22+57.81E
8000 S162.65=80.69--83.36+39.43E
10000 S16.15=-33.63+-57.77+90.64E
12000 S-68.70=58.57/31.96+-60.09E
14000 S-51.03=-45.60+47.22+-49.87E
16000 S-51.65=-96.99/88.35+-50.66E
18000 S30.26=-85.70/44.87+32.50E
20000 S2164.24=47.01*51.23+-71.68E
22000 S1113.01=-95.91*-11.21+66.08E
24000 S1202.65=68.13*17.33+81.90E
26000 S-33.20=-52.21*4.71+32.87E
28000 S-5469.87=-63.48*78.04+93.87E
30000 S-25.27=66.50/-4.51+-12.73E
32000 S31.61=-77.76/24.99+34.24E
34000 S-2575.25=-43.91*56.71+90.36E
36000 S100.97=89.83+-1.77+13.94E
38000 S-314.49=-27.89*21.38+-85.69E
40000 S-133.87=-64.75+-43.61+-21.59E
42000 S-54.38=-44.74+-60.99+54.11E
44000 S-100.67=-81.50--52.17+-79.47E
46000 S54.37=4.65--96.74+-44.74E
48000 S-87.06=8.80--36.98+-99.28E
50000 S-127.15=4.63+-96.55+-46.92E
52000 S29.78=-32.81+-12.72+71.03E
54000 S47.74=-37.14/13.68+50.34E
56000 S-451.05=-57.42*6.12+-30.62E
58000 S-214.12=