In [1]:
import torch

# Use the GPU when one is available; `device` is used below to place models and tensors.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Execute the tokenizer notebook; it defines names used below (e.g. `tokenizer`).
%run 1.tokenizer.ipynb

tokenizer

<__main__.Tokenizer at 0x7f877a38c520>

In [2]:
%run 2.dataset.ipynb


def f(data):
    """Collate a batch of dataset records into a tensor of token ids.

    The 'text' field of every record is passed to the project tokenizer with
    BOS requested, EOS not requested, padding on the left, and truncation to
    at most 32 tokens; the tensor is created on `device`.

    Returns:
        The tokenizer's 'input_ids' entry for the batch.
    """
    texts = [record['text'] for record in data]
    encoding = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=32,
        device=device,
        add_bos_token=True,
        add_eos_token=False,
        padding_side='left',
    )
    return encoding['input_ids']


# Prompt batches collated by `f`; answers and negative labels are not requested.
loader = get_loader(f, with_answer=False, negative_label=False)

# Sanity check: number of batches and one sample batch of token ids.
len(loader), next(iter(loader))

(62500,
 tensor([[ 0, 10,  7,  4,  4, 14,  9,  8,  8, 11, 18],
         [ 0,  7, 11,  7,  5, 14,  7,  9,  9,  4, 18],
         [ 0,  6, 12, 11,  5, 14, 13,  8, 13,  4, 18],
         [ 2,  0, 13, 13,  4,  9, 14, 12, 13,  5, 18],
         [ 0,  5,  9,  8,  6, 14, 13, 10, 12,  8, 18],
         [ 0,  9,  7, 11,  5, 14,  6,  5,  6,  5, 18],
         [ 0, 11,  5,  6,  5, 14,  8,  7,  9, 10, 18],
         [ 0, 10,  5, 11,  4, 14,  9,  7,  7,  6, 18],
         [ 0, 10, 12,  7,  7, 14,  5, 10,  5, 12, 18],
         [ 0, 11, 10, 11, 11, 14,  7,  6,  8,  7, 18],
         [ 0,  8,  5,  5,  7, 14,  8, 13,  6, 13, 18],
         [ 0,  8, 13, 12,  6, 14, 10, 12, 12, 11, 18],
         [ 0, 12,  4, 13,  7, 14,  7, 12, 13, 10, 18],
         [ 0,  8,  7,  9, 13, 14, 13, 13, 11, 11, 18],
         [ 0, 11,  7, 10,  6, 14,  7, 10, 10,  8, 18],
         [ 0,  9, 12,  6, 11, 14, 12, 11, 12, 13, 18]], device='cuda:0'))

In [3]:
%run 3.model.ipynb

def _load_checkpoint(path):
    """Load a fully pickled model from `path` and move it to `device`."""
    # NOTE(review): weights_only=False unpickles arbitrary objects (and can run
    # code); only load checkpoints from trusted sources.
    return torch.load(path, weights_only=False).to(device)


# Trainable policy and value networks, plus reference copies of the same
# checkpoints. The reference copies are never passed to the optimizer.
model_actor = _load_checkpoint('model/actor')
model_actor_ref = _load_checkpoint('model/actor')

model_critic = _load_checkpoint('model/critic')
model_critic_ref = _load_checkpoint('model/critic')

# A single AdamW instance updates the actor and the critic together.
trainable_params = [*model_actor.parameters(), *model_critic.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)

In [4]:
def get_value(critic, question, answer, shift=True):
    """Score the concatenated sequence `question ++ answer` with the critic.

    Args:
        critic: value model exposing `.model(input_ids=..., attention_mask=...)`
            (returning hidden states) and `.score(hidden_states)`.
        question: prompt token ids, [b, q_len].
        answer: answer token ids, [b, a_len].
        shift: when True, keep only the a_len values aligned with the answer.
            The value for answer position t is taken from sequence position
            q_len - 1 + t, i.e. the position just before answer[t]. When False,
            return the critic scores for the whole sequence unchanged.

    Returns:
        shift=True: a [b, a_len] tensor (also when a_len == 1).
        shift=False: the raw output of `critic.score` for all positions.
    """
    input_ids = torch.cat((question, answer), 1)
    # Pad positions are excluded from attention and their ids replaced by 0
    # (presumably so the pad id is never fed to the embedding).
    attention_mask = input_ids != tokenizer.pad_token_id
    input_ids = torch.masked_fill(input_ids, ~attention_mask, 0)

    last_hidden_state = critic.model(input_ids=input_ids,
                                     attention_mask=attention_mask)

    # assumes score yields [b, lens] or [b, lens, 1] — TODO confirm in 3.model.ipynb
    value = critic.score(last_hidden_state)

    if shift:
        value = value[:, question.shape[1] - 1:-1]
        # Drop only a trailing singleton *feature* dim. A bare squeeze(-1) on a
        # 2-D [b, 1] slice (one-token answer) would remove the sequence axis
        # and return [b], breaking per-position indexing downstream.
        if value.dim() == 3:
            value = value.squeeze(-1)

    return value


# Smoke test: a random 5-token question and 15-token answer should give one
# value per answer token, i.e. shape [2, 15].
get_value(model_critic,
          torch.randint(0, 10, [2, 5]).to(device),
          torch.randint(0, 10, [2, 15]).to(device)).shape

torch.Size([2, 15])

In [5]:
def get_logprob(actor, question, answer, temperature=0.7):
    """Per-token, temperature-scaled log-probabilities of `answer` under `actor`.

    Args:
        actor: policy model called as `actor(input_ids=..., attention_mask=...)`
            and returning a pair whose second item is the logits.
        question: prompt token ids, [b, q_len].
        answer: answer token ids, [b, a_len].
        temperature: logits are divided by this before the softmax. The default
            0.7 is the value previously hard-coded here. It should presumably
            match the sampling temperature used by `generate` — TODO confirm.

    Returns:
        [b, a_len] tensor whose entry (i, t) is the log-probability assigned to
        answer[i, t] given the preceding tokens.
    """
    input_ids = torch.cat((question, answer), 1)
    attention_mask = input_ids != tokenizer.pad_token_id
    input_ids = torch.masked_fill(input_ids, ~attention_mask, 0)

    _, logits = actor(input_ids=input_ids, attention_mask=attention_mask)

    # Logits at position p predict token p + 1, so the predictions for the
    # answer start one position before the answer begins.
    # Out-of-place division: the slice is a view of the model's output, and an
    # in-place `/=` would silently mutate that tensor.
    logits = logits[:, question.shape[1] - 1:-1] / temperature

    logprob = logits.log_softmax(dim=-1)
    return logprob.gather(2, answer.unsqueeze(-1)).squeeze(-1)


# Smoke test: one log-probability per answer token, i.e. shape [2, 15].
get_logprob(model_actor,
            torch.randint(0, 10, [2, 5]).to(device),
            torch.randint(0, 10, [2, 15]).to(device)).shape

torch.Size([2, 15])

In [6]:
def get_advantage(value, reward_kl, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation (GAE) for a batch of trajectories.

    Args:
        value: [b, T] value estimates per step.
        reward_kl: [b, T] per-step rewards.
        gamma: discount factor. The default 1.0 (undiscounted) reproduces the
            previous hard-coded behaviour.
        lam: GAE lambda. The default 0.95 is the previously hard-coded value.

    Returns:
        [b, T] advantages computed backwards in time as
            delta_t = r_t + gamma * V_{t+1} - V_t   (V_T = 0 past the last step)
            A_t     = delta_t + gamma * lam * A_{t+1}
    """
    steps = value.shape[1]
    if steps == 0:
        # torch.stack([]) would raise; an empty trajectory has no advantages.
        return torch.zeros_like(value)

    discount = gamma * lam
    advantages = []
    running = 0
    for t in reversed(range(steps)):
        value_next = value[:, t + 1] if t < steps - 1 else 0.0
        delta = reward_kl[:, t] + gamma * value_next - value[:, t]
        running = delta + discount * running
        advantages.append(running)

    return torch.stack(advantages[::-1], dim=1)


# Smoke test: advantages keep the [batch, steps] shape of the inputs.
get_advantage(torch.randn(4, 25), torch.randn(4, 25)).shape

torch.Size([4, 25])

In [7]:
from trl.trainer.utils import first_true_indices


@torch.no_grad()
def get_data(question):
    """Roll out the actor on a batch of prompts and build one PPO training batch.

    Runs without gradient tracking. Steps:
      1. Sample answers from `model_actor` and keep only the generated part.
      2. Per row, find `end`: the first position equal to
         `tokenizer.pad_token_id` (via trl's `first_true_indices`; a row with
         no pad presumably yields the row length — TODO confirm for the
         installed trl version).
      3. Compute per-token log-probs under `model_actor` and `model_actor_ref`,
         and per-token values under `model_critic`, then mask positions past
         `end`.
      4. Reward = `model_critic_ref` score at the last answer token. The
         per-token reward is a KL penalty, -0.05 * (logp_actor - logp_ref),
         with that reward added on the terminal step.
      5. GAE advantages and returns; advantages are then normalized using the
         valid (pre-`end`) positions and zeroed from `end` onwards.

    Args:
        question: [b, q_len] prompt token ids (left-padded by the loader's
            collate function — assumed).

    Returns:
        (question, answer, ends, prob_old, value_old, advantage, returns),
        the positional arguments consumed by `train`.
    """
    #====answer====
    answer = generate(model_actor,
                      input_ids=question,
                      pad_token_id=tokenizer.pad_token_id,
                      eos_token_id=tokenizer.eos_token_id)

    # Keep only the newly generated tokens (this slice assumes `generate`
    # returns prompt + continuation).
    answer = answer[:, question.shape[1]:]

    # End position per row: index of the first pad token in the answer.
    ends = first_true_indices(answer == tokenizer.pad_token_id).tolist()

    #====prob,value====
    prob_old = get_logprob(model_actor, question, answer)
    prob_ref = get_logprob(model_actor_ref, question, answer)
    value_old = get_value(model_critic, question, answer)
    # No shift here: the reward may be read at the very last token, and
    # shifting would crop off the value at that final position.
    value_ref = get_value(model_critic_ref, question, answer, shift=False)

    # Mask everything after `end`. Log-probs get the same constant (1.0) in
    # both tensors, so their difference, and hence the KL penalty, is 0 there.
    # Values are zeroed strictly after `end`; position `end` itself is kept.
    for i, end in enumerate(ends):
        prob_old[i, end:] = 1.0
        prob_ref[i, end:] = 1.0
        value_old[i, end + 1:] = 0.0

    #====reward====
    reward = []
    for i, end in enumerate(ends):
        # Answers without an eos token were meant to get a reward of -1; that
        # is disabled below.
        # NOTE(review): this branch is a no-op (only `pass` remains); consider
        # removing it.
        if tokenizer.eos_token_id not in answer[i]:
            #reward.append(-1)
            #continue
            pass
        # Use the reference critic's score at the last answer token
        # (full-sequence index q_len + end - 1) as the reward.
        reward.append(value_ref[i, end + question.shape[1] - 1])
    # NOTE(review): relies on the legacy FloatTensor constructor converting a
    # list of (device) tensors; torch.stack(reward).float() would be the
    # explicit form — confirm shapes before changing.
    reward = torch.FloatTensor(reward).to(device)

    #====advantage====
    # Per-token KL penalty between the actor and the reference actor.
    reward_kl = -0.05 * (prob_old - prob_ref)

    # Add the reward onto the terminal step's KL term. With no pad in the row
    # (end == row length), the terminal step is the last position (-1).
    # Otherwise it is index `end` (the first pad position), whose value_old
    # entry was kept above.
    # NOTE(review): the reward is read at answer index end - 1 but placed at
    # index end when a pad exists; train() also includes index end in its
    # losses, so this looks intentional — confirm.
    for i, end in enumerate(ends):
        if end == len(answer[i]):
            end = -1
        #assert end == -1

        reward_kl[i, end] += reward[i]

    advantage = get_advantage(value_old, reward_kl)
    returns = advantage + value_old

    # Normalize advantages with statistics from valid positions only, for
    # numerical stability. `end` here comes from `ends` (the -1 rebinding
    # above was local to that loop).
    # NOTE(review): with exactly one valid position select.var() is NaN
    # (unbiased variance), and with none the mean is NaN.
    select = torch.cat([adv[:end] for adv, end in zip(advantage, ends)])
    advantage = (advantage - select.mean()) / (select.var() + 1e-8)**0.5

    # Zero the advantages at and after `end`.
    for i, end in enumerate(ends):
        advantage[i, end:] = 0

    return question, answer, ends, prob_old, value_old, advantage, returns


# Build and display one rollout batch from a fresh loader batch.
get_data(next(iter(loader)))

  from .autonotebook import tqdm as notebook_tqdm


(tensor([[ 0,  8, 10,  9, 13, 14, 10,  9,  4,  5, 18],
         [ 0,  6, 11,  6,  5, 14,  5, 12, 11, 12, 18],
         [ 0,  6, 10,  4,  5, 14, 12, 13,  4,  8, 18],
         [ 2,  0, 10,  4,  4,  4, 14,  9,  7, 12, 18],
         [ 0, 13,  4, 13,  6, 14, 11,  8, 13, 10, 18],
         [ 0, 13, 13,  7, 12, 14,  8, 11,  9, 12, 18],
         [ 0,  8,  8, 13,  8, 14,  6, 10, 13,  5, 18],
         [ 0,  6, 12, 10,  4, 14,  9,  7,  6,  7, 18],
         [ 2,  0, 12, 12,  6, 14,  8, 10,  8,  4, 18],
         [ 0,  9, 12, 10, 10, 14, 13,  8, 12, 11, 18],
         [ 2,  0,  6, 13,  5, 14,  6, 12, 13, 11, 18],
         [ 0,  6,  4,  9,  5, 14, 13,  5, 10, 10, 18],
         [ 0, 12, 10, 12,  9, 14, 12,  4, 11,  6, 18],
         [ 0,  8,  5,  4, 13, 14, 10, 10,  6,  8, 18],
         [ 0, 13,  6,  7,  4, 14,  8,  6,  6, 12, 18],
         [ 0, 12, 10, 12, 13, 14,  6, 12,  7, 13, 18]], device='cuda:0'),
 tensor([[ 5,  5,  5, 10,  4,  1],
         [ 8,  9, 13, 13,  1,  1],
         [ 5,  5,  9,  4,  9,  

In [8]:
def train(question, answer, ends, prob_old, value_old, advantage, returns,
          epochs=4, clip_eps=0.2, value_clip=0.2, vf_coef=0.05):
    """Run several PPO optimization epochs on one rollout batch from `get_data`.

    Each epoch does the following:
      - recompute the actor log-probs and critic values;
      - mask positions past each row's end the same way `get_data` does;
      - minimize loss = clipped policy loss + vf_coef * clipped value loss,
        averaged over positions [0, end] of every row;
      - take one `optimizer` step.

    Args:
        question, answer, ends, prob_old, value_old, advantage, returns:
            the tuple returned by `get_data`.
        epochs: number of passes over this batch (default 4).
        clip_eps: PPO ratio clip range. The ratio is clamped to
            [1 - clip_eps, 1 + clip_eps]; the default 0.2 gives [0.8, 1.2].
        value_clip: new values are clamped to within +/- value_clip of
            value_old for the clipped value loss (default 0.2).
        vf_coef: weight of the value loss in the total loss (default 0.05).
    """
    for _ in range(epochs):
        # Recompute values and log-probs with the current parameters.
        prob_new = get_logprob(model_actor, question, answer)
        value_new = get_value(model_critic, question, answer)

        # Mask positions after `end`, matching the masking of prob_old/value_old.
        for i, end in enumerate(ends):
            prob_new[i, end:] = 1.0
            value_new[i, end + 1:] = 0

        # Critic loss: pessimistic max of the unclipped and clipped squared errors.
        value_clipped = torch.clamp(value_new, value_old - value_clip,
                                    value_old + value_clip)
        loss_vf1 = (value_new - returns)**2
        loss_vf2 = (value_clipped - returns)**2
        loss_vf = torch.max(loss_vf1, loss_vf2)

        # Actor loss: PPO clipped surrogate objective (negated, for minimization).
        ratio = (prob_new - prob_old).exp()
        loss_pg1 = -advantage * ratio
        loss_pg2 = -advantage * torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
        loss_pg = torch.max(loss_pg1, loss_pg2)

        # Discard everything after `end` before averaging.
        loss_vf = torch.cat([xi[:end + 1] for xi, end in zip(loss_vf, ends)]).mean()
        loss_pg = torch.cat([xi[:end + 1] for xi, end in zip(loss_pg, ends)]).mean()

        loss = loss_pg + vf_coef * loss_vf
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


# One PPO update on a single fresh rollout batch, as a smoke test before the full loop.
train(*get_data(next(iter(loader))))

In [9]:
# PPO training. Each step rolls out a batch with `get_data` and updates the
# actor and critic with `train`. Every 200 steps, one sample generation is
# printed next to the expected result as a sanity check.
def _batches(data_loader):
    """Yield batches from `data_loader` forever, re-iterating when exhausted."""
    while True:
        yield from data_loader


# Use one persistent iterator instead of `next(iter(loader))` per step.
# A fresh iterator every step restarts the loader (re-shuffling or restarting
# workers, depending on how `get_loader` builds it), and a non-shuffled loader
# would return the same first batch every time.
batches = _batches(loader)

for i in range(5000):
    train(*get_data(next(batches)))

    if i % 200 == 0:
        print(i)
        input_ids = next(batches)[0:1]

        # Sampling for display only: no autograd graph is needed.
        with torch.no_grad():
            gen = generate(model_actor,
                           input_ids=input_ids,
                           pad_token_id=tokenizer.pad_token_id,
                           eos_token_id=tokenizer.eos_token_id)

        question = tokenizer.decode(input_ids[0])

        # Expected answer: evaluate the expression between BOS and '='.
        # NOTE(review): eval() on decoded dataset text; acceptable only because
        # the prompts come from our own dataset. Never use it on untrusted input.
        answer = question
        answer = answer[answer.index(tokenizer.bos_token) + 1:]
        answer = answer[:answer.index('=')]
        answer = int(eval(answer))
        gen = tokenizer.decode(gen[0, input_ids.shape[1]:])

        print({'question': question, 'answer': answer, 'gen': gen})

torch.save(model_actor.to('cpu'), 'model/ppo')

0
{'question': 'B7988+4118=', 'answer': 12106, 'gen': '12106E'}
200
{'question': 'B4659+174=', 'answer': 4833, 'gen': '8624E'}
400
{'question': 'B8235+7926=', 'answer': 16161, 'gen': '16161E'}
600
{'question': 'B7991+5433=', 'answer': 13424, 'gen': '13424E'}
800
{'question': 'B8954+9010=', 'answer': 17964, 'gen': '17964E'}
1000
{'question': 'B1389+1740=', 'answer': 3129, 'gen': '3129E'}
1200
{'question': 'B8365+9232=', 'answer': 17597, 'gen': '17597E'}
1400
{'question': 'B9809+6145=', 'answer': 15954, 'gen': '15954E'}
1600
{'question': 'B1243+946=', 'answer': 2189, 'gen': '2189E'}
1800
{'question': 'B9719+6990=', 'answer': 16709, 'gen': '16709E'}
2000
{'question': 'B4089+8954=', 'answer': 13043, 'gen': '13043E'}
2200
{'question': 'B4612+9409=', 'answer': 14021, 'gen': '14021E'}
2400
{'question': 'B8824+5445=', 'answer': 14269, 'gen': '14269E'}
2600
{'question': 'B9393+679=', 'answer': 10072, 'gen': '10172E'}
2800
{'question': 'B7936+2421=', 'answer': 10357, 'gen': '10357E'}
3000
{'ques