In [1]:
# Shared setup from the common notebook. `tokenizer`, `get_loader` and `generate_model`
# are used below but not defined in this notebook — presumably they come from here.
%run 1.common.ipynb

# NOTE(review): judging by the sample output (three operands), this flag makes the data
# generator emit expressions with a third number — confirm in 1.common.ipynb.
tokenizer.third_number = True

# Preview one generated (question, answer) sample as text.
[tokenizer.decode(i) for i in tokenizer.get_data()]

  from .autonotebook import tqdm as notebook_tqdm


['SOS191.53=', '-28.38*-4.27+70.35EOS']

In [2]:
from trl.trainer.utils import pad_to_length


def f(data):
    """Collate (question, answer) pairs into DPO chosen/rejected tensors.

    The "chosen" sequence is question + reference answer; the "rejected"
    sequence is question + EOS (i.e. an empty answer). Label tensors repeat the
    answer tokens and mask the question part with -100 so only the answer is
    scored by get_loss.

    Args:
        data: sequence of items whose [0] is the question and [1] the answer,
            each assumed to be a 1-D LongTensor of token ids (TODO confirm
            against get_loader).

    Returns:
        dict with 'input_ids_choice', 'input_ids_reject', 'label_choice',
        'label_reject', 'attention_mask_choice', 'attention_mask_reject',
        all of shape [batch, L] with one common L, so chosen and rejected can
        be concatenated along the batch dimension.
    """
    eos = torch.LongTensor([tokenizer.encoder['EOS']])
    pad = tokenizer.encoder['PAD']

    batch = {
        'input_ids_choice': [],
        'input_ids_reject': [],
        'label_choice': [],
        'label_reject': [],
    }

    for item in data:
        q, a = item[0], item[1]
        # The question part of every label is ignored (-100) by get_loss.
        q_ignored = torch.full_like(q, -100)

        batch['input_ids_choice'].append(torch.cat([q, a]))
        batch['input_ids_reject'].append(torch.cat([q, eos]))
        batch['label_choice'].append(torch.cat([q_ignored, a]))
        batch['label_reject'].append(torch.cat([q_ignored, eos]))

    # The common length must cover BOTH chosen and rejected sequences: get_loss
    # concatenates them along dim 0. With an empty answer the rejected sequence
    # (q + EOS) is longer than the chosen one, and pad_to_length never truncates.
    lens = max(len(seq)
               for k in ('input_ids_choice', 'input_ids_reject')
               for seq in batch[k])

    for k, seqs in batch.items():
        padding_value = -100 if k in ('label_choice', 'label_reject') else pad
        padded = torch.nn.utils.rnn.pad_sequence(seqs,
                                                 batch_first=True,
                                                 padding_value=padding_value)
        batch[k] = pad_to_length(padded, length=lens, pad_value=padding_value)

    batch['attention_mask_choice'] = (batch['input_ids_choice'] != pad).long()
    batch['attention_mask_reject'] = (batch['input_ids_reject'] != pad).long()

    return batch


# Build the data loader with `f` as the batch collation function and display one batch.
# NOTE(review): get_loader is defined outside this notebook; assumed to pass f as collate_fn — confirm.
loader = get_loader(f)

next(iter(loader))

{'input_ids_choice': tensor([[ 1, 16, 10,  ...,  0,  0,  0],
         [ 1,  7,  8,  ...,  0,  0,  0],
         [ 1, 16,  7,  ...,  0,  0,  0],
         ...,
         [ 1, 16,  6,  ...,  0,  0,  0],
         [ 1, 11, 13,  ...,  0,  0,  0],
         [ 1,  7, 13,  ...,  0,  0,  0]]),
 'input_ids_reject': tensor([[ 1, 16, 10,  ...,  0,  0,  0],
         [ 1,  7,  8,  ...,  0,  0,  0],
         [ 1, 16,  7,  ...,  0,  0,  0],
         ...,
         [ 1, 16,  6,  ...,  0,  0,  0],
         [ 1, 11, 13,  ...,  0,  0,  0],
         [ 1,  7, 13,  ...,  0,  0,  0]]),
 'label_choice': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100]]),
 'label_reject': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -

In [3]:
# Policy model (trained by DPO below) and reference model, both initialized from the
# SFT checkpoint. The checkpoint is a whole pickled module (model.train()/.parameters()
# are called on it later), so weights_only=False is required on torch>=2.6, where the
# default became True. Unpickling can execute arbitrary code — only load trusted files.
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load('sft.model', map_location='cpu', weights_only=False)
model_ref = torch.load('sft.model', map_location='cpu', weights_only=False)

# The reference model is never updated: freeze it and disable train-time behaviour
# (e.g. dropout) so its log-probs are deterministic.
model_ref.eval()
model_ref.requires_grad_(False)

In [4]:
@torch.no_grad()
def test(N=5, device='cpu'):
    """Print N generations and try to evaluate each generated expression.

    For each sample this prints the reference "value=expression" text, the text
    generated by `generate_model` from the prefix `x`, and — when the generated
    right-hand side of '=' can be evaluated — its numeric value. Evaluation
    failures are reported instead of silently ignored.

    Args:
        N: number of samples to generate.
        device: device to move `generate_model` and the prompt to.
    """
    generate_model.to(device)
    eq_id = tokenizer.encoder['=']
    eos_id = tokenizer.encoder['EOS']

    for _ in range(N):
        x, y = tokenizer.get_data()
        print(tokenizer.decode(x + y))

        prompt = torch.LongTensor(x).unsqueeze(0).to(device)
        out = generate_model.generate(prompt, max_length=40)[0]
        print(tokenizer.decode(out))

        # Try to evaluate the generated expression to the right of '='.
        try:
            tokens = out[1:]  # drop the leading start token
            # Cut at the first EOS. If generation hit max_length without an EOS,
            # keep everything rather than dropping a real token.
            eos_pos = (tokens == eos_id).nonzero()
            if len(eos_pos) > 0:
                tokens = tokens[:eos_pos[0].item()]
            # .item() raises unless there is exactly one '='.
            idx_eq = (tokens == eq_id).nonzero().item() + 1
            expr = tokenizer.decode(tokens[idx_eq:])
            # eval on model output: builtins are removed so only arithmetic on
            # literals can run.
            print(eval(expr, {'__builtins__': {}}, {}))
        except Exception as e:  # malformed generation: no/several '=', bad syntax, x/0, ...
            print(f'(could not evaluate: {type(e).__name__})')

        print('---------------')


# Sample generations before DPO training (generate_model is defined outside this view).
test()

SOS59.53=1.22+60.26+-1.95EOS
SOS59.53=-1.24+69.24EOS
68.0
---------------
SOS62.98=24.97**1.20+15.46EOS
SOS62.98=-1.29+69.24EOS
67.94999999999999
---------------
SOS89.19=61.74**-2.04+89.19EOS
SOS89.19=12.29+78.29EOS
90.58000000000001
---------------
SOS-1271.72=-54.15*24.32+45.21EOS
SOS-1271.72=-28.22*59.22EOS
-1671.1884
---------------
SOS92.49=-12.33/-45.52+92.22EOS
SOS92.49=-1.29*-60.24EOS
77.70960000000001
---------------


In [5]:
def get_loss(model, data, device='cpu'):
    """Return per-example log p(chosen answer) - log p(rejected answer) under `model`.

    Chosen and rejected sequences are concatenated along the batch dimension so
    a single forward pass scores both. Despite the name this is a log-prob
    margin, not yet the DPO loss (train() builds the loss from it).

    Args:
        model: callable accepting `input_ids` / `attention_mask` keyword
            arguments and returning logits, assumed [2b, L, vocab] — TODO confirm.
        data: batch dict as produced by the collate function `f`.
        device: device the batch tensors are moved to.

    Returns:
        Tensor of shape [b]: summed answer-token log-probs of the chosen
        sequence minus those of the rejected sequence.
    """
    b = data['input_ids_choice'].shape[0]

    # Stack chosen on top of rejected: [2b, L]
    input_ids = torch.cat([data['input_ids_choice'], data['input_ids_reject']],
                          dim=0).to(device)
    attention_mask = torch.cat(
        [data['attention_mask_choice'], data['attention_mask_reject']],
        dim=0).to(device)
    label = torch.cat([data['label_choice'], data['label_reject']],
                      dim=0).to(device)

    # [2b, L, V]
    logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Shift so the logits at position t are scored against token t+1.
    # label: [2b, L-1], logits: [2b, L-1, V]
    label = label[:, 1:]
    logits = logits[:, :-1]

    # log_softmax is numerically stable and exact, unlike log(softmax + eps),
    # which floors log-probs near log(1e-8) and zeroes their gradient there.
    log_probs = logits.log_softmax(dim=2)

    # gather cannot take -100 as an index; point ignored positions at token 0.
    # Their value is removed by the mask below.
    mask = label != -100
    index = label.clone()
    index[~mask] = 0

    # Log-prob of each label token: [2b, L-1]
    token_log_probs = torch.gather(log_probs, dim=2,
                                   index=index.unsqueeze(2)).squeeze(2)

    # Sequence log-prob over answer tokens only: [2b]
    seq_log_probs = (token_log_probs * mask).sum(1)

    return seq_log_probs[:b] - seq_log_probs[b:]


# Sanity check: chosen-minus-rejected log-prob margins of the freshly loaded model on one batch.
with torch.no_grad():
    print(get_loss(model, next(iter(loader))))

tensor([-40.7997, -52.1692, -52.0745, -75.4763, -43.0688, -55.2397, -40.2297,
        -59.4412, -49.9825, -44.1882, -30.6174, -36.7304, -54.4920, -49.2033,
        -45.2228, -51.9443, -52.3449, -49.2872, -51.2239, -58.6045, -56.5570,
        -40.5392, -46.0128, -55.3534, -44.3628, -54.7784, -45.2691, -38.7308,
        -79.1975, -53.3040, -51.6501, -45.6031, -50.3548, -45.0006, -45.5645,
        -77.4524, -41.4018, -60.1273, -58.8243, -45.4752, -57.2804, -36.0864,
        -60.5132, -48.3292, -51.0794, -62.7518, -49.1459, -58.4754, -49.0210,
        -53.7823, -43.1840, -53.7013, -40.9312, -62.1900, -45.2143, -48.8859,
        -43.3670, -28.1549, -68.4661, -55.2323, -28.2384, -36.2104, -55.9187,
        -34.3956])


In [6]:
def train(steps=100_000, lr=1e-4, beta=0.1, log_every=2000):
    """Fine-tune the global `model` with DPO against `model_ref`, then save it.

    Args:
        steps: number of optimizer steps.
        lr: Adam learning rate.
        beta: DPO scale applied to (policy margin - reference margin).
        log_every: print the step index and run test(2, device) every this
            many steps.

    Side effects: updates `model` in place, leaves both models on the CPU and
    writes the trained model to 'dpo.model'.
    """
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-8)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()
    model.to(device)
    # The reference model is only evaluated, never trained.
    model_ref.eval()
    model_ref.to(device)

    # Keep one iterator alive instead of calling iter(loader) every step (which
    # restarts the loader and, without shuffling, returns the same first batch
    # every time); re-create it only when an epoch is exhausted.
    batches = iter(loader)

    for i in range(steps):
        try:
            data = next(batches)
        except StopIteration:
            batches = iter(loader)
            data = next(batches)
        data = {k: v.to(device) for k, v in data.items()}

        margin = get_loss(model, data, device)
        with torch.no_grad():
            margin_ref = get_loss(model_ref, data, device)

        # DPO loss: -log sigmoid(beta * (policy margin - reference margin)).
        # logsigmoid(x) behaves like a smoothed min(x, 0), so pairs that are
        # already well separated contribute little.
        loss = -torch.nn.functional.logsigmoid(beta * (margin - margin_ref)).mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % log_every == 0:
            print(i)
            test(2, device)

    model.to('cpu')
    model_ref.to('cpu')
    torch.save(model, 'dpo.model')


# Run DPO training; the trained model is written to 'dpo.model'.
train()

0
SOS87.43=99.57+29.14+-41.28EOS
SOS87.43=-1.29*-29.24EOS
37.7196
---------------
SOS-4065.68=51.01*-79.32+-19.57EOS
SOS-4065.68=-40.44*84.08EOS
-3400.1951999999997
---------------
2000
SOS285.69=88.85*3.51+-26.17EOS
SOS285.69=-2.00*-73.02+-90.76EOS
55.27999999999999
---------------
SOS-63.77=-1.75+-18.11+-43.91EOS
SOS-63.77=-1.38**-1.08+-17.07EOS
-17.776204672506694
---------------
4000
SOS-9473.09=-98.52*96.63+46.90EOS
SOS-9473.09=-90.00*97.18+-9.18EOS
-8755.380000000001
---------------
SOS63.43=20.86**-2.20+63.43EOS
SOS63.43=-1.11**-2.18+69.87EOS
69.07348142841013
---------------
6000
SOS113.75=36.12*3.14+0.33EOS
SOS113.75=-1.11--97.14+27.17EOS
123.2
---------------
SOS40.45=54.29--12.06+-25.90EOS
SOS40.45=-40.44+-37.74+6.44EOS
-71.74000000000001
---------------
8000
SOS179.04=-3.32*-83.02+-96.59EOS
SOS179.04=99.99--19.37+49.32EOS
168.68
---------------
SOS-5279.83=-13.50**3.30+91.74EOS
SOS-5279.83=-61.11*89.62+-63.62EOS
-5540.2982
---------------
10000
SOS214.45=69.13--87.79+57.53E

SOS187.27=77.77--29.77+79.77EOS
187.31
---------------
84000
SOS998.58=27.83*35.02+23.97EOS
SOS998.58=-10.71*-97.11+-7.77EOS
1032.2781
---------------
SOS-124.30=-68.74-10.61+-44.95EOS
SOS-124.30=-97.77+-77.77+49.77EOS
-125.76999999999998
---------------
86000
SOS-159.39=3.38-75.41+-87.36EOS
SOS-159.39=-99.99-97.39+32.47EOS
-164.91
---------------
SOS-78.24=98.64**-0.34+-78.45EOS
SOS-78.24=-9.33+-91.09+22.07EOS
-78.35
---------------
88000
SOS21.22=-43.54**-1.15+21.23EOS
SOS21.22=-1.30**-1.11+20.30EOS
19.552651996026142
---------------
SOS-32.90=-86.47/-7.06+-45.15EOS
SOS-32.90=-1.30**-1.39+-32.40EOS
-33.09441443535542
---------------
90000
SOS-66.82=-11.46/30.71+-66.45EOS
SOS-66.82=-1.10**-1.88+-66.82EOS
-67.6559527906284
---------------
SOS72.45=-77.28/-55.42+71.06EOS
SOS72.45=-1.11*-11.11+49.08EOS
61.412099999999995
---------------
92000
SOS11.54=-86.36/25.17+14.97EOS
SOS11.54=-1.39**-4.19+11.37EOS
11.118366773113287
---------------
SOS-209.15=-62.71-78.75+-67.69EOS
SOS-209.15=-99.7