In [1]:
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Execute the tokenizer notebook inside this kernel's namespace; it is expected
# to define the `tokenizer` object displayed below.
%run 1.tokenizer.ipynb

# Display the tokenizer to confirm that it was loaded.
tokenizer

<__main__.Tokenizer at 0x7f487d2ac550>

In [2]:
# Execute the dataset notebook inside this kernel's namespace; presumably it
# defines the `get_loader` helper used below - verify against 2.dataset.ipynb.
%run 2.dataset.ipynb


def f(data):
    """Turn a batch of samples into a tensor of prompt token ids.

    Args:
        data: Iterable of mapping-like samples, each exposing a 'text' entry.

    Returns:
        The 'input_ids' entry produced by the global `tokenizer` for the
        collected texts, tokenized on `device` with left padding and without
        an appended EOS token.
    """
    texts = []
    for sample in data:
        texts.append(sample['text'])

    encoded = tokenizer(texts,
                        device=device,
                        add_eos_token=False,
                        padding_side='left')
    return encoded['input_ids']


# Build the prompt loader; `f` is handed over as the batch-processing callback
# (presumably used as the collate function - confirm in 2.dataset.ipynb).
# NOTE(review): the exact meaning of negative_label / with_answer is defined in
# 2.dataset.ipynb; the sample output suggests question-only prompts.
loader = get_loader(f, negative_label=False, with_answer=False)

# Show the number of batches and one sample batch of token ids.
len(loader), next(iter(loader))

(62500,
 tensor([[ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  2,  2,  2,  0,  7, 13, 13,  5, 14, 11, 10, 11,  9, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  2,  2,  2,  0, 12, 13,  9,  7, 14, 11, 10,  4,  6, 18],
         [ 0,  9,  4,  7,  4, 14, 11,  9,  7, 10, 18,  5,  6,  9, 10, 10, 14,  9,
           8, 12,  4, 18,  5, 12,  4,  8, 10, 14,  9,  5,  5,  7, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0,  7,  8,  7, 11, 14,
           5, 10,  7,  7, 18,  9,  4, 11,  4, 14,  6, 12,  8, 10, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0,  9, 13,  4,  9, 14, 11,
          13,  4,  7, 18,  5,  7, 12,  4, 12, 14, 13,  6, 10,  4, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0, 11,  6, 13,  7,
          14,  9,  5,  5, 18, 11, 12,  4,  8, 14, 12,  7, 10,  5, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0, 12,  9,  9, 12, 

In [3]:
# Execute the model notebook inside this kernel's namespace; presumably it
# provides the model classes needed to unpickle the checkpoints and the
# `generate` helper used later - verify against 3.model.ipynb.
%run 3.model.ipynb

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so these
# checkpoints must come from a trusted source.
# Trainable actor plus a second copy of the same checkpoint kept as reference
# (the reference copy is not handed to the optimizer below).
model_actor = torch.load('model/actor', weights_only=False).to(device)
model_actor_ref = torch.load('model/actor', weights_only=False).to(device)

# Trainable critic plus a second copy of the same checkpoint kept as reference
# (again not handed to the optimizer).
model_critic = torch.load('model/critic', weights_only=False).to(device)
model_critic_ref = torch.load('model/critic', weights_only=False).to(device)

# One optimizer updates the actor and the critic parameters together.
optimizer = torch.optim.AdamW(list(model_actor.parameters()) +
                              list(model_critic.parameters()),
                              lr=1e-5)

In [4]:
def get_value(critic, question, answer, shift=True):
    """Score every position of `question + answer` with a critic.

    Args:
        critic: Object exposing `model(input_ids=..., attention_mask=...)` and
            `score(hidden_state)`.
        question: Prompt token ids, concatenated with `answer` along dim 1.
        answer: Answer token ids.
        shift: When True, keep only the scores from the last prompt position
            up to (but excluding) the final position and squeeze the trailing
            dimension, giving one column per answer token. When False, return
            the raw output of `critic.score` for the whole sequence.

    Returns:
        The critic scores, sliced and squeezed as described by `shift`.
    """
    tokens = torch.cat((question, answer), 1)
    mask = tokens != tokenizer.pad_token_id
    # Pad positions are replaced by token id 0 before being fed to the model;
    # the attention mask still marks them as padding.
    tokens = tokens.masked_fill(~mask, 0)

    # Hidden states of the backbone, then one score per position
    # (presumably a trailing dimension of size 1, given the squeeze below).
    hidden = critic.model(input_ids=tokens, attention_mask=mask)
    scores = critic.score(hidden)

    if not shift:
        return scores

    start = question.shape[1] - 1
    return scores[:, start:-1].squeeze(-1)


# Smoke test with random token ids: a 5-token question and a 15-token answer
# should give one value per answer token.
get_value(model_critic,
          torch.randint(0, 10, [2, 5]).to(device),
          torch.randint(0, 10, [2, 15]).to(device)).shape

torch.Size([2, 15])

In [5]:
def get_logprob(actor, question, answer, temperature=0.7):
    """Return the log-probability the actor assigns to each answer token.

    Args:
        actor: Callable model returning a pair whose second item is the
            logits for every input position.
        question: Prompt token ids, concatenated with `answer` along dim 1.
        answer: Answer token ids; also used as gather indices into the
            vocabulary dimension.
        temperature: Divisor applied to the logits before the softmax.
            Defaults to 0.7, the value that was previously hard-coded.
            NOTE(review): this should match the temperature used when
            sampling in `generate` - confirm in 3.model.ipynb.

    Returns:
        Tensor with one log-probability per answer token (same leading shape
        as `answer`).
    """
    input_ids = torch.cat((question, answer), 1)
    attention_mask = input_ids != tokenizer.pad_token_id
    # Pad positions are replaced by token id 0; the mask still flags them.
    input_ids = torch.masked_fill(input_ids, ~attention_mask, 0)

    _, logits = actor(input_ids=input_ids, attention_mask=attention_mask)

    # Take the logits one position before each answer token so that column j
    # lines up with answer[:, j], then apply the temperature.
    logits = logits[:, question.shape[1] - 1:-1] / temperature

    logprob = logits.log_softmax(dim=-1)
    # Pick the log-probability of the token that was actually generated.
    return logprob.gather(2, answer.unsqueeze(-1)).squeeze(-1)


# Smoke test with random token ids: a 5-token question and a 15-token answer
# should give one log-probability per answer token.
get_logprob(model_actor,
            torch.randint(0, 10, [2, 5]).to(device),
            torch.randint(0, 10, [2, 15]).to(device)).shape

torch.Size([2, 15])

In [6]:
def get_advantage(value, reward_kl, gamma=1.0, lam=0.95):
    """Compute per-step advantages with a backward recursion over time.

    For each step t (walking from the last column to the first):

        delta_t = reward_kl[:, t] + gamma * value[:, t + 1] - value[:, t]
        adv_t   = delta_t + gamma * lam * adv_{t + 1}

    where the value and the advantage after the final step are taken as 0.

    Args:
        value: Tensor of shape (batch, steps) with value estimates.
        reward_kl: Tensor of shape (batch, steps) with per-step rewards.
        gamma: Discount factor. Defaults to 1.0 (no discounting), which is
            what the previously hard-coded formula used.
        lam: Decay applied to the accumulated advantage. Defaults to 0.95,
            the previously hard-coded value.

    Returns:
        Tensor of shape (batch, steps) with the advantages.
    """
    steps = value.shape[1]
    if steps == 0:
        # Nothing to accumulate; torch.stack would raise on an empty list.
        return torch.zeros_like(value)

    advantage = []
    last = 0
    for i in reversed(range(steps)):
        # There is no successor for the final step, so its next value is 0.
        value_next = value[:, i + 1] if i < steps - 1 else 0.0

        delta = reward_kl[:, i] + gamma * value_next - value[:, i]
        last = delta + gamma * lam * last

        advantage.append(last)

    # The list was built back-to-front; restore chronological order.
    return torch.stack(advantage[::-1], dim=1)


# Smoke test: the advantages keep the (batch, steps) shape of the inputs.
get_advantage(torch.randn(4, 25), torch.randn(4, 25)).shape

torch.Size([4, 25])

In [7]:
from trl.trainer.utils import first_true_indices


@torch.no_grad()
def get_data(question):
    """Generate answers for a batch of prompts and assemble one PPO rollout.

    Steps: sample answers with the current actor, compute log-probabilities
    with the actor and the reference actor, compute values with the critic
    and the reference critic, use the reference critic's value at the last
    answer token as a scalar reward, add a per-token KL penalty, and derive
    advantages and returns.

    Args:
        question: Left-padded prompt token ids; the indexing below assumes
            shape (batch, question_len).

    Returns:
        Tuple (question, answer, ends, prob_old, value_old, advantage,
        returns). `ends` is a Python list holding, per row, the index of the
        first pad token in `answer`; the other new items are tensors with one
        column per answer position.
    """
    #====answer====
    answer = generate(model_actor,
                      input_ids=question,
                      pad_token_id=tokenizer.pad_token_id,
                      eos_token_id=tokenizer.eos_token_id)

    # Keep only the newly generated part.
    answer = answer[:, question.shape[1]:]

    # Index of the first pad token in every answer. The code further down
    # treats end == len(answer[i]) as "this row contains no padding".
    ends = first_true_indices(answer == tokenizer.pad_token_id).tolist()

    #====prob,value====
    prob_old = get_logprob(model_actor, question, answer)
    prob_ref = get_logprob(model_actor_ref, question, answer)
    value_old = get_value(model_critic, question, answer)
    # No shift here: the reward below may be read at the very last position,
    # which the shifted slice would cut off.
    value_ref = get_value(model_critic_ref, question, answer, shift=False)

    # Neutralise everything after the end of each answer. Both log-prob
    # tensors get the same filler, so their difference is 0 there.
    for i, end in enumerate(ends):
        prob_old[i, end:] = 1.0
        prob_ref[i, end:] = 1.0
        value_old[i, end + 1:] = 0.0

    #====reward====
    # The reference critic's value at the last non-pad token of the full
    # sequence (question + answer) serves as the scalar reward.
    # (A fixed -1 penalty for answers without EOS was sketched here earlier
    # but disabled; the no-op check that remained has been removed.)
    reward = [
        value_ref[i, end + question.shape[1] - 1]
        for i, end in enumerate(ends)
    ]
    reward = torch.FloatTensor(reward).to(device)

    #====advantage====
    # Per-token KL penalty between the actor and the reference actor.
    reward_kl = -0.05 * (prob_old - prob_ref)

    # Add the scalar reward at index `end`, or at the last column when the
    # row has no padding.
    for i, end in enumerate(ends):
        if end == len(answer[i]):
            end = -1

        reward_kl[i, end] += reward[i]

    advantage = get_advantage(value_old, reward_kl)
    returns = advantage + value_old

    # Normalise with statistics taken over the valid (pre-end) positions only,
    # for numerical stability.
    select = torch.cat([adv[:end] for adv, end in zip(advantage, ends)])
    advantage = (advantage - select.mean()) / (select.var() + 1e-8)**0.5

    # Zero the advantage from the end position onwards.
    for i, end in enumerate(ends):
        advantage[i, end:] = 0

    return question, answer, ends, prob_old, value_old, advantage, returns


# Smoke test: build one rollout from the first batch of a fresh loader iterator.
get_data(next(iter(loader)))

  from .autonotebook import tqdm as notebook_tqdm


(tensor([[ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0, 12, 13,
           8, 14,  8,  4, 18, 13,  7,  8, 14,  9,  4,  6,  7, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  2,  2,  0,  5, 12,  6,  5, 14, 11,  4, 12,  8, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  2,  2,  0,  7, 12, 11, 12, 14, 10,  9,  4,  8, 18],
         [ 2,  2,  0,  6,  7,  4,  4, 14, 10,  7,  8, 18,  6, 13,  7,  8, 14,  5,
           4, 12, 13, 18,  8,  4,  6,  7, 14,  8, 11,  5,  5, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0,  6,  9, 12,  7, 14, 10,
           5, 12,  8, 18, 12, 11, 10, 11, 14,  8,  4,  6,  5, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  0,  5,  5, 10,  4, 14,  5,
           4,  7,  9, 18,  6,  5, 13,  9, 14, 11,  5, 12, 10, 18],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  2,  2,  0

In [8]:
def train(question, answer, ends, prob_old, value_old, advantage, returns,
          epochs=4, clip_value=0.2, clip_ratio=0.2, vf_coef=0.05):
    """Run several PPO update steps on one rollout.

    The first seven arguments are exactly the tuple returned by `get_data`.
    Uses the global `model_actor`, `model_critic` and `optimizer`.

    Args:
        question: Prompt token ids.
        answer: Generated answer token ids.
        ends: Per-row index of the first pad token in `answer`.
        prob_old: Answer-token log-probabilities recorded at rollout time.
        value_old: Critic values recorded at rollout time.
        advantage: Normalised advantages.
        returns: Regression targets for the critic.
        epochs: Number of optimisation steps on this rollout (default 4).
        clip_value: Half-width of the clipping window around `value_old`
            (default 0.2).
        clip_ratio: The probability ratio is clipped to
            [1 - clip_ratio, 1 + clip_ratio] (default 0.2, i.e. 0.8 .. 1.2).
        vf_coef: Weight of the critic loss in the total loss (default 0.05).

    The defaults reproduce the previously hard-coded constants.
    """
    for _ in range(epochs):
        # Re-evaluate log-probabilities and values with the current weights.
        prob_new = get_logprob(model_actor, question, answer)
        value_new = get_value(model_critic, question, answer)

        # Apply the same post-end fillers as were applied to the rollout
        # tensors, so that ratio == 1 beyond the end of each answer.
        for i, end in enumerate(ends):
            prob_new[i, end:] = 1.0
            value_new[i, end + 1:] = 0

        # Critic loss: the larger of the unclipped and the clipped squared
        # error.
        value_clip = torch.clamp(value_new, value_old - clip_value,
                                 value_old + clip_value)
        loss_vf1 = (value_new - returns)**2
        loss_vf2 = (value_clip - returns)**2
        loss_vf = torch.max(loss_vf1, loss_vf2)

        # Actor loss: clipped surrogate objective on the probability ratio.
        ratio = (prob_new - prob_old).exp()
        loss_pg1 = -advantage * ratio
        loss_pg2 = -advantage * torch.clamp(ratio, 1.0 - clip_ratio,
                                            1.0 + clip_ratio)
        loss_pg = torch.max(loss_pg1, loss_pg2)

        # Average only over positions up to and including `end` of each row.
        loss_vf = [xi[:end + 1] for xi, end in zip(loss_vf, ends)]
        loss_pg = [xi[:end + 1] for xi, end in zip(loss_pg, ends)]
        loss_vf = torch.cat(loss_vf).mean()
        loss_pg = torch.cat(loss_pg).mean()

        loss = loss_pg + vf_coef * loss_vf
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


# Smoke test: one rollout followed by one round of PPO updates.
train(*get_data(next(iter(loader))))

In [9]:
def _endless(batches):
    """Yield batches forever, restarting the loader each time it runs out.

    Stops instead of spinning if the loader yields nothing at all.
    """
    while True:
        empty = True
        for batch in batches:
            empty = False
            yield batch
        if empty:
            return


# Reuse one iterator for the whole run instead of calling iter(loader) on
# every step, which rebuilds the loader's iterator each time.
batches = _endless(loader)

for i in range(50_000):
    train(*get_data(next(batches)))

    # Every 200 steps, print one sampled prompt and the actor's answer.
    if i % 200 == 0:
        print(i)
        input_ids = next(batches)[0:1]

        # Sampling here is for logging only, so no autograd graph is needed.
        with torch.no_grad():
            gen = generate(model_actor,
                           input_ids=input_ids,
                           pad_token_id=tokenizer.pad_token_id,
                           eos_token_id=tokenizer.eos_token_id)

        question = tokenizer.decode(input_ids[0])
        gen = tokenizer.decode(gen[0, input_ids.shape[1]:])

        print({'question': question, 'gen': gen})

# Module.to() moves the model in place, so model_actor lives on the CPU after
# this line. The whole module object is pickled, matching how the checkpoints
# are loaded with torch.load(..., weights_only=False).
torch.save(model_actor.to('cpu'), 'model/ppo')

0
{'question': 'B177+9401=', 'gen': '9578+4601=14179E'}
200
{'question': 'B2401+6696=9097+7166=', 'gen': '16263E'}
400
{'question': 'B8995+5001=', 'gen': '13996E'}
600
{'question': 'B5843+6778=', 'gen': '12621E'}
800
{'question': 'B3193+6516=', 'gen': '9709E'}
1000
{'question': 'B4075+6829=10904+6940=', 'gen': '17844E'}
1200
{'question': 'B7498+7584=15082+590=', 'gen': '15672E'}
1400
{'question': 'B9685+5812=15497+1720=17217+140=', 'gen': '17357E'}
1600
{'question': 'B3977+2479=6456+6597=', 'gen': '13053E'}
1800
{'question': 'B3643+6348=9991+4412=14403+5166=', 'gen': '19569E'}
2000
{'question': 'B8540+4172=12712+4179=16891+9758=', 'gen': '26649E'}
2200
{'question': 'B4614+5868=', 'gen': '10482E'}
2400
{'question': 'B7421+9344=16765+4075=20840+1330=', 'gen': '22170E'}
2600
{'question': 'B8788+5388=', 'gen': '14176E'}
2800
{'question': 'B227+7443=', 'gen': '7670E'}
3000
{'question': 'B8202+5517=13719+8693=', 'gen': '22412E'}
3200
{'question': 'B7374+222=7596+4547=12143+6562=', 'gen': '18

27400
{'question': 'B1299+8578=9877+7841=', 'gen': '17718E'}
27600
{'question': 'B6248+9153=15401+5944=21345+8830=', 'gen': '30175E'}
27800
{'question': 'B7687+2624=', 'gen': '10311E'}
28000
{'question': 'B843+276=1119+4278=', 'gen': '5397E'}
28200
{'question': 'B5330+1040=6370+8668=15038+2427=', 'gen': '17465E'}
28400
{'question': 'B9591+7655=17246+1044=', 'gen': '18290E'}
28600
{'question': 'B2744+2624=5368+8623=', 'gen': '13991E'}
28800
{'question': 'B6018+4965=10983+2851=13834+163=', 'gen': '13997E'}
29000
{'question': 'B9630+296=', 'gen': '9926E'}
29200
{'question': 'B5850+4378=10228+3933=', 'gen': '14161E'}
29400
{'question': 'B4574+4863=9437+8197=', 'gen': '17634E'}
29600
{'question': 'B4577+9369=13946+1717=15663+1550=', 'gen': '17213E'}
29800
{'question': 'B7411+5341=12752+2707=15459+9494=', 'gen': '24953E'}
30000
{'question': 'B359+3569=3928+9223=', 'gen': '13151E'}
30200
{'question': 'B1821+6868=', 'gen': '8689E'}
30400
{'question': 'B7902+210=8112+5522=', 'gen': '13634E'}
30