In [1]:
%run common.ipynb

tokenizer.decode(tokenizer.get_data(third_number=True))

'S61.57=-46.33--51.50+56.40E'

In [2]:
def get_batch_data():

    def pad(data, split, lens):
        #做个白板
        input_ids = torch.full((len(data), lens),
                               tokenizer.encoder['P'],
                               device=device)

        #往白板里黏贴数据
        for i, d in enumerate(data):
            input_ids[i, :len(d)] = torch.LongTensor(d)

        attention_mask = (input_ids != tokenizer.encoder['P']).long()

        #计算label
        label = input_ids.clone()
        for l, s in zip(label, split):
            #问题和pad的位置是-100
            l[:s] = -100
            l[l == tokenizer.encoder['P']] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }

    #正确的问答
    choice = [tokenizer.get_data(third_number=True) for i in range(64)]

    #错误的回答简单地定义为空回答就可以了
    split = [i.index(tokenizer.encoder['=']) + 1 for i in choice]
    reject = [d[:s] for d, s in zip(choice, split)]
    reject = [i + [tokenizer.encoder['E']] for i in reject]

    #求最大长度
    lens = max([len(i) for i in choice])

    return pad(choice, split, lens), pad(reject, split, lens)


get_batch_data()

({'input_ids': tensor([[ 1,  6,  4,  ...,  0,  0,  0],
          [ 1, 15, 13,  ...,  2,  0,  0],
          [ 1, 12,  4,  ...,  0,  0,  0],
          ...,
          [ 1, 15, 13,  ...,  0,  0,  0],
          [ 1, 15,  6,  ...,  0,  0,  0],
          [ 1, 15,  5,  ...,  5,  5,  2]], device='cuda:0'),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 1, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'),
  'label': tensor([[-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ...,    2, -100, -100],
          [-100, -100, -100,  ..., -100, -100, -100],
          ...,
          [-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ..., -100, -100, -100],
          [-100, -100, -100,  ...,    5,    5,    2]], device='cuda:0')},
 {'input_ids': tensor([[ 1,  6,  4,  ...,  0,  0,  0],
          

In [3]:
model_dpo = torch.load('gen.model')
model_dpo.to(device)
model_dpo.train()

model_dpo_ref = torch.load('gen.model')
model_dpo_ref.to(device)
model_dpo_ref.train()

  from .autonotebook import tqdm as notebook_tqdm


ModelGEN(
  (feature): LlamaModel(
    (embed_tokens): Embedding(22, 64)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=64, out_features=64, bias=False)
          (k_proj): Linear(in_features=64, out_features=64, bias=False)
          (v_proj): Linear(in_features=64, out_features=64, bias=False)
          (o_proj): Linear(in_features=64, out_features=64, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=64, out_features=64, bias=False)
          (up_proj): Linear(in_features=64, out_features=64, bias=False)
          (down_proj): Linear(in_features=64, out_features=64, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (fc_out): Linear(in_features=64, out_features=22

In [4]:
def get_prob_log(model, choice, reject):
    b = choice['input_ids'].shape[0]

    #合并两部分输入,同时计算以提高效率
    #[b, 21]
    input_ids = torch.cat([choice['input_ids'], reject['input_ids']], dim=0)
    attention_mask = torch.cat(
        [choice['attention_mask'], reject['attention_mask']], dim=0)
    label = torch.cat([choice['label'], reject['label']], dim=0)

    #[b, 21, 28]
    out = model(input_ids=input_ids, attention_mask=attention_mask)

    #偏移以对齐
    #[b, 20]
    label = label[:, 1:]
    #[b, 20, 28]
    out = out[:, :-1]

    #取所有字的预测概率,因为要求联合概率,所以取对数
    out = (out.softmax(2) + 1e-8).log()

    #取预测到label的概率
    #索引不能是负数,所以这里把负数置0
    index = label.clone().unsqueeze(2)
    index[index == -100] = 0
    prob = out.gather(2, index=index).squeeze(2)

    #只取答案部分的loss,筛选后,所有答案的概率对数求和
    prob = (prob * (label != -100)).sum(1)

    #choice和reject的预测概率求差
    return prob[:b] - prob[b:]


get_prob_log(model_dpo, *get_batch_data())

tensor([-42.1512, -32.2378, -33.3335, -35.5115, -43.3339, -35.5948, -45.4445,
        -41.2210, -35.1233, -34.2614, -37.2273, -35.6421, -44.1526, -36.7085,
        -39.9655, -35.2811, -34.6775, -39.8615, -39.3645, -37.0570, -31.3551,
        -40.6914, -37.8081, -33.7670, -33.3802, -37.0586, -33.6158, -36.1201,
        -35.0879, -37.0642, -39.6770, -42.1228, -39.5020, -41.2665, -38.3211,
        -37.2448, -38.0346, -37.7231, -53.9730, -37.9990, -37.8209, -43.6167,
        -39.8550, -37.9038, -42.0052, -34.4402, -39.1546, -32.8980, -33.8170,
        -39.4172, -50.1610, -36.9279, -52.5554, -48.7900, -35.8651, -47.3491,
        -34.5138, -34.8623, -40.0592, -36.6584, -37.5624, -39.2018, -38.0018,
        -37.3401], device='cuda:0', grad_fn=<SubBackward0>)

In [5]:
optimizer = torch.optim.Adam(model_dpo.parameters(),
                             lr=1e-4,
                             betas=(0.9, 0.999),
                             eps=1e-8)

for i in range(2_0000):
    choice, reject = get_batch_data()

    #两个模型分别计算概率对数
    prob_log = get_prob_log(model_dpo, choice, reject)
    with torch.no_grad():
        prob_log_ref = get_prob_log(model_dpo_ref, choice, reject)

    #两份概率计算kl散度
    kl = -0.1 * (prob_log - prob_log_ref)

    #以kl散度计算loss
    loss = (kl.sigmoid() + 1e-8).log().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 1000 == 0:
        question = tokenizer.get_data(third_number=True)
        question = question[:question.index(tokenizer.encoder['=']) + 1]
        question = torch.LongTensor(question).unsqueeze(0).to(device)

        gen = generate(model_dpo, question)
        print(i, tokenizer.decode(gen[0].tolist()))

model_dpo.to('cpu')
torch.save(model_dpo, 'dpo.model')

0 S38.30=77.53+1S.482E
1000 S108.85=-29.46--37.05+30.01E
2000 S44.90=-4.14/75.14+68.99E
3000 S-26.53=10.80/41.26+-49.15E
4000 S-36.64=-55.47/-89.69+-42.73E
5000 S5.54=55.63/-11.10+10.59E
6000 S-73.71=-11.88+-72.64+-20.22E
7000 S-36.85=61.31+-77.72+-34.73E
8000 S89.76=-22.73/-83.62+88.44E
9000 S-28.79=25.78/-87.84+-27.99E
10000 S-61.36=-90.41+-5.14+29.17E
11000 S-18.05=3.20/15.27+-10.82E
12000 S-6.66=32.52/-0.61+4.44E
13000 S6060.81=94.70*66.67+-66.46E
14000 S-57.69=-56.26/0.30+-37.07E
15000 S2590.33=-55.67*-39.12+-5.68E
16000 S-210.33=-93.29-25.44+-97.53E
17000 S24.79=84.88+36.65+-79.89E
18000 S-89.55=-86.22/84.67+-88.10E
19000 S89.63=69.98--21.02+1.25E
