In [1]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

%run 1.tokenizer.ipynb

tokenizer

<__main__.Tokenizer at 0x7f1f0534c280>

In [2]:
%run 2.dataset.ipynb


def f(data):
    """Turn a batch of dataset samples into left-padded question token ids.

    Each sample is read via ``sample['text']``. The project tokenizer is asked
    to prepend BOS, skip EOS, pad on the left, truncate to 32 tokens and place
    the result on ``device``; only its ``'input_ids'`` entry is returned.
    Presumably used by ``get_loader`` as the batch-collation callback.
    """
    texts = [sample['text'] for sample in data]
    encoded = tokenizer(texts,
                        padding=True,
                        truncation=True,
                        max_length=32,
                        device=device,
                        add_bos_token=True,
                        add_eos_token=False,
                        padding_side='left')
    return encoded['input_ids']


loader = get_loader(f, negative_label=False, with_answer=False)

# Sanity check: number of batches and one sample batch of question token ids.
n_batches = len(loader)
sample_batch = next(iter(loader))
n_batches, sample_batch

(125000,
 tensor([[ 0,  5,  3,  8,  9, 13,  8,  5,  9,  3, 17],
         [ 0, 12,  5,  8,  4, 14,  8,  9,  7, 10, 17],
         [ 0,  6,  6,  3,  4, 14, 11, 10,  9, 12, 17],
         [ 0,  6, 11, 10,  8, 14,  4,  9, 12,  3, 17],
         [ 2,  0,  5,  3,  9,  8, 14,  9, 10, 10, 17],
         [ 0, 11, 11,  3,  3, 16,  4,  4,  4,  4, 17],
         [ 0,  7,  5,  6,  5, 14,  5,  3,  6, 12, 17],
         [ 0,  5, 12,  8, 12, 14,  9, 11,  3,  4, 17],
         [ 0, 11,  6,  5, 12, 14,  6,  8,  5,  4, 17],
         [ 2,  0, 10, 11,  4, 15,  9,  5, 10,  4, 17],
         [ 0,  8,  9,  5,  5, 15,  6,  5, 11,  9, 17],
         [ 0,  8,  3,  8, 10, 16,  5,  4,  8,  7, 17],
         [ 2,  0,  8,  7, 10, 11, 15, 10,  6,  9, 17],
         [ 0,  7,  6,  9, 10, 16,  9, 10,  9,  6, 17],
         [ 0,  5, 11, 10,  6, 16,  5,  8,  8,  3, 17],
         [ 2,  0,  9,  4,  8,  9, 15,  8,  8,  9, 17]], device='cuda:0'))

In [3]:
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
from trl.trainer.utils import disable_dropout_in_model

def _load_to_device(model_cls, path, **kwargs):
    """Load a pretrained checkpoint with ``model_cls`` and move it to ``device``."""
    return model_cls.from_pretrained(path, **kwargs).to(device)


# Trainable policy plus a second copy of the same checkpoint used as the
# reference policy; only model_actor's parameters go to the optimizer.
model_actor = _load_to_device(AutoModelForCausalLM, 'model/actor')
model_actor_ref = _load_to_device(AutoModelForCausalLM, 'model/actor')

# Trainable critic plus a second copy that get_data reads as the reward model;
# both are built with a single-output score head.
model_critic = _load_to_device(AutoModelForSequenceClassification,
                               'model/critic',
                               num_labels=1)
model_critic_ref = _load_to_device(AutoModelForSequenceClassification,
                                   'model/critic',
                                   num_labels=1)

# Clear the stored stop/pad ids on the actor's generation config; the
# generate() calls in this notebook pass pad/eos ids explicitly.
model_actor.generation_config.eos_token_id = None
model_actor.generation_config.pad_token_id = None

# Turn off dropout in all four networks.
for model in (model_actor, model_actor_ref, model_critic, model_critic_ref):
    disable_dropout_in_model(model)

# A single optimizer over actor parameters followed by critic parameters.
trainable_params = [*model_actor.parameters(), *model_critic.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-6)

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


In [4]:
def get_value(critic, question, answer, shift=True):
    """Score ``question ++ answer`` with the critic's backbone and score head.

    Tokens equal to ``tokenizer.pad_token_id`` are excluded through the
    attention mask and replaced by id 0 in the model input. Position ids count
    only non-pad tokens, so left padding does not shift positions.

    Returns:
        shift=True:  the score head output sliced to ``[q_len - 1, total - 1)``
                     and squeezed on the last dim, i.e. one value per answer
                     token; entry j is read at the position just before
                     ``answer[:, j]``.
        shift=False: the unsliced, unsqueezed score head output for every
                     position (for the num_labels=1 critics in this notebook,
                     presumably ``[b, q_len + a_len, 1]`` — TODO confirm).
    """
    full_ids = torch.cat((question, answer), 1)
    keep = full_ids != tokenizer.pad_token_id
    positions = keep.cumsum(1) - keep.long()
    full_ids = torch.masked_fill(full_ids, ~keep, 0)

    hidden = critic.gpt_neox(input_ids=full_ids,
                             attention_mask=keep,
                             position_ids=positions).last_hidden_state
    scores = critic.score(hidden)

    if not shift:
        return scores

    start = question.shape[1] - 1
    return scores[:, start:-1].squeeze(-1)


# Smoke test with random ids: 2 rows, 5 question tokens + 15 answer tokens.
_question_ids = torch.randint(0, 10, (2, 5)).to(device)
_answer_ids = torch.randint(0, 10, (2, 15)).to(device)
get_value(model_critic, _question_ids, _answer_ids).shape

torch.Size([2, 15])

In [None]:
def get_logprob(actor, question, answer, temperature=0.7):
    """Per-token log-probabilities of ``answer`` under ``actor`` given ``question``.

    ``question`` and ``answer`` are concatenated along dim 1. Tokens equal to
    ``tokenizer.pad_token_id`` are masked out of attention and replaced by id
    0. Position ids count only non-pad tokens, so left padding does not shift
    positions.

    Args:
        actor: causal LM whose output exposes ``.logits``.
        question: prompt token ids, ``[b, q_len]``.
        answer: continuation token ids, ``[b, a_len]``.
        temperature: logits are divided by this before the softmax. It should
            equal the temperature used when sampling ``answer``, so that the
            result is the log-prob of the policy that actually produced it.
            Default 0.7.

    Returns:
        ``[b, a_len]`` tensor; entry j is log p(answer[:, j] | question,
        answer[:, :j]) at the given temperature.
    """
    input_ids = torch.cat((question, answer), 1)
    attention_mask = input_ids != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(input_ids, ~attention_mask, 0)

    logits = actor(input_ids=input_ids,
                   attention_mask=attention_mask,
                   position_ids=position_ids).logits

    # Logits at position t predict token t+1: start one before the first
    # answer token and drop the final position.
    logits = logits[:, question.shape[1] - 1:-1]
    # Out-of-place: do not write into (a view of) the model's output tensor.
    logits = logits / temperature

    logprob = logits.log_softmax(dim=-1)
    return logprob.gather(2, answer.unsqueeze(-1)).squeeze(-1)


# Smoke test with random ids: expect one log-prob per answer token.
_prompt_ids = torch.randint(0, 10, (2, 5)).to(device)
_reply_ids = torch.randint(0, 10, (2, 15)).to(device)
get_logprob(model_actor, _prompt_ids, _reply_ids).shape

torch.Size([2, 15])

In [6]:
def get_advantage(value, reward_kl, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation (GAE), computed backwards over dim 1.

    Args:
        value: ``[b, T]`` value estimates V_t. The value after the last step
            is taken to be 0.
        reward_kl: ``[b, T]`` per-step rewards r_t.
        gamma: discount factor. Default 1.0 (undiscounted).
        lam: GAE lambda. Default 0.95.

    Returns:
        ``[b, T]`` tensor of advantages A_t = delta_t + gamma * lam * A_{t+1},
        with delta_t = r_t + gamma * V_{t+1} - V_t.
    """
    steps = value.shape[1]
    if steps == 0:
        # torch.stack([]) would raise; there is nothing to estimate.
        return torch.zeros_like(value)

    advantages = []
    running = 0.0
    for t in reversed(range(steps)):
        value_next = value[:, t + 1] if t + 1 < steps else 0.0
        delta = reward_kl[:, t] + gamma * value_next - value[:, t]
        running = delta + gamma * lam * running
        advantages.append(running)

    # Collected last-to-first; reverse back to time order.
    return torch.stack(advantages[::-1], dim=1)


# Smoke test on random tensors: batch of 4, 25 steps.
_values = torch.randn(4, 25)
_rewards = torch.randn(4, 25)
get_advantage(_values, _rewards).shape

torch.Size([4, 25])

In [7]:
from trl.trainer.utils import first_true_indices


@torch.no_grad()
def get_data(question):
    """Roll out the actor on a batch of prompts and build PPO training targets.

    Args:
        question: left-padded prompt token ids, ``[b, q_len]``.

    Returns:
        ``(question, answer, ends, prob_old, value_old, advantage, returns)``:
        ``answer`` is the generated continuation ``[b, a_len]``. ``ends[i]`` is
        the index of the first pad token in ``answer[i]`` (presumably ``a_len``
        when a row has none — see the reward placement below). The remaining
        tensors are ``[b, a_len]`` and are masked past each end as noted inline.
    """
    # ==== answer ====
    # Sample at the same temperature get_logprob divides by (0.7), so that
    # prob_old is the log-prob of the policy that actually produced `answer`.
    answer = model_actor.generate(
        input_ids=question,
        attention_mask=(question != tokenizer.pad_token_id).long(),
        min_length=-1,
        max_length=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        top_k=0,
        top_p=1.0,
        do_sample=True)
    answer = answer[:, question.shape[1]:]

    # Index of the first pad token in each row: answer[i, :end] is the output.
    ends = first_true_indices(answer == tokenizer.pad_token_id).tolist()

    # ==== log-probs and values ====
    prob_old = get_logprob(model_actor, question, answer)
    prob_ref = get_logprob(model_actor_ref, question, answer)
    value_old = get_value(model_critic, question, answer)
    # Read the reward model unshifted: the shifted slice stops one position
    # early, so the value at the very last token would be cut off.
    value_ref = get_value(model_critic_ref, question, answer, shift=False)

    # Mask past each end. Log-probs get the placeholder 1.0 in both tensors,
    # so their KL term is 0 there. Values are kept through position `end`,
    # because get_advantage bootstraps step end-1 from value[:, end].
    for i, end in enumerate(ends):
        prob_old[i, end:] = 1.0
        prob_ref[i, end:] = 1.0
        value_old[i, end + 1:] = 0.0

    # ==== reward ====
    # Sequence reward = reference critic's value at the last generated token
    # (answer index end-1, i.e. end + q_len - 1 in the concatenated input).
    # Rows that never emitted EOS are currently not penalized.
    rows = torch.arange(len(ends), device=value_ref.device)
    cols = torch.tensor(ends, device=value_ref.device) + question.shape[1] - 1
    reward = value_ref[rows, cols].reshape(-1).float()

    # ==== advantage ====
    # Per-token KL penalty against the reference policy.
    reward_kl = -0.05 * (prob_old - prob_ref)

    # Add the sequence reward at position `end` (the first pad slot, whose
    # value is kept above), or at the last position when a row has no padding.
    last_col = answer.shape[1] - 1
    for i, end in enumerate(ends):
        reward_kl[i, min(end, last_col)] += reward[i]

    advantage = get_advantage(value_old, reward_kl)
    returns = advantage + value_old

    # Whiten advantages with statistics from the valid (pre-end) tokens only.
    select = torch.cat([adv[:end] for adv, end in zip(advantage, ends)])
    advantage = (advantage - select.mean()) / (select.var() + 1e-8)**0.5

    # No advantage for padding positions.
    for i, end in enumerate(ends):
        advantage[i, end:] = 0

    return question, answer, ends, prob_old, value_old, advantage, returns


# Smoke test: one rollout on a fresh batch of prompts.
_prompts = next(iter(loader))
get_data(_prompts)

(tensor([[ 0,  9,  9,  5,  7, 13,  4, 10,  7,  8, 17],
         [ 0,  6, 10, 10,  4, 13,  6, 12, 12,  6, 17],
         [ 0,  5, 10, 11,  4, 14,  7, 12, 10, 11, 17],
         [ 2,  0,  8,  7,  3, 13, 12, 10,  7,  5, 17],
         [ 0, 10,  7,  3,  8, 13,  6,  8,  5,  9, 17],
         [ 0, 11,  6,  6,  6, 16, 11, 12,  5, 12, 17],
         [ 0,  9,  9,  5,  8, 14, 11, 12, 11, 10, 17],
         [ 0,  9,  6,  9,  4, 13,  8,  4,  7, 10, 17],
         [ 0,  6,  5, 12,  3, 15, 12,  6,  8,  8, 17],
         [ 0,  9,  8,  7,  4, 14, 12,  4,  9,  7, 17],
         [ 0, 11,  6,  7, 11, 15,  8,  9,  8,  5, 17],
         [ 0,  8, 11,  4,  3, 13,  8,  8,  9,  3, 17],
         [ 2,  0,  8,  4,  4,  5, 14,  4,  5, 10, 17],
         [ 0,  7,  6, 10, 12, 15,  9,  5, 10,  9, 17],
         [ 0, 10,  6,  3, 11, 14,  9,  3,  3,  4, 17],
         [ 0, 10,  9,  9,  5, 13,  8,  4,  5,  4, 17]], device='cuda:0'),
 tensor([[11,  6,  9, 12,  1,  2,  2,  2,  2],
         [10, 10,  9,  7,  1,  2,  2,  2,  2],
       

In [8]:
def train(question, answer, ends, prob_old, value_old, advantage, returns,
          ppo_epochs=4, clip_range=0.2, clip_range_value=0.2, vf_coef=0.05):
    """Run clipped-PPO updates on one rollout batch produced by ``get_data``.

    Args:
        question, answer, ends, prob_old, value_old, advantage, returns:
            the tuple returned by ``get_data``.
        ppo_epochs: optimisation passes over the batch. Default 4.
        clip_range: the probability ratio is clamped to
            ``[1 - clip_range, 1 + clip_range]``. Default 0.2.
        clip_range_value: the new value may differ from ``value_old`` by at
            most this much in the clipped value loss. Default 0.2.
        vf_coef: weight of the value loss in the total loss. Default 0.05.

    Side effects: steps and zeroes the global ``optimizer`` once per epoch.
    """
    for _ in range(ppo_epochs):
        # Recompute log-probs and values under the current parameters.
        prob_new = get_logprob(model_actor, question, answer)
        value_new = get_value(model_critic, question, answer)

        # Same masking as get_data.
        for i, end in enumerate(ends):
            prob_new[i, end:] = 1.0
            value_new[i, end + 1:] = 0

        # Clipped value loss.
        value_clip = torch.clamp(value_new,
                                 value_old - clip_range_value,
                                 value_old + clip_range_value)
        loss_vf = torch.max((value_new - returns)**2,
                            (value_clip - returns)**2)

        # Clipped policy-gradient loss.
        ratio = (prob_new - prob_old).exp()
        loss_pg = torch.max(
            -advantage * ratio,
            -advantage * torch.clamp(ratio, 1 - clip_range, 1 + clip_range))

        # Average each loss over its own valid span. Values are kept through
        # `end`, so the value loss uses [:end + 1]. The policy acts only on
        # the generated tokens [:end]; position `end` is padding with zero
        # advantage and would only dilute the mean.
        loss_vf = torch.cat([x[:end + 1] for x, end in zip(loss_vf, ends)]).mean()
        pg_terms = torch.cat([x[:end] for x, end in zip(loss_pg, ends)])
        # Guard: if no row has a generated token, mean() of empty would be NaN.
        loss_pg = pg_terms.mean() if pg_terms.numel() else pg_terms.sum()

        loss = loss_pg + vf_coef * loss_vf
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


# One PPO update on a fresh rollout, as a smoke test before the main loop.
_rollout = get_data(next(iter(loader)))
train(*_rollout)

In [9]:
def _cycle(batches):
    """Yield batches forever, restarting iteration whenever a pass ends.

    Stops (so next() raises StopIteration) if a full pass yields nothing,
    instead of spinning forever on an empty loader.
    """
    while True:
        empty = True
        for batch in batches:
            empty = False
            yield batch
        if empty:
            return


# One persistent iterator. Calling next(iter(loader)) every step would rebuild
# the loader iterator (and, for a shuffling DataLoader, reshuffle) each time.
batch_stream = _cycle(loader)

for step in range(10_000):
    train(*get_data(next(batch_stream)))

    if step % 200 == 0:
        print(step)
        input_ids = next(batch_stream)[0:1]

        # Explicit mask for the left-padded prompt, as in get_data.
        gen = model_actor.generate(
            input_ids=input_ids,
            attention_mask=(input_ids != tokenizer.pad_token_id).long(),
            min_length=-1,
            max_length=50,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            top_k=0,
            top_p=1.0,
            do_sample=True)

        question = tokenizer.decode(input_ids[0])

        # Expected result: the text between the BOS marker and '=', evaluated.
        # NOTE(review): eval() is tolerable only because the expression comes
        # from our own generated dataset; never use it on external input.
        expression = question[question.index(tokenizer.bos_token) + 1:]
        expression = expression[:expression.index('=')]
        answer = int(eval(expression))
        gen = tokenizer.decode(gen[0, input_ids.shape[1]:])

        print({'question': question, 'answer': answer, 'gen': gen})

model_actor.save_pretrained('model/ppo')

0
{'question': 'B6907+5732=', 'answer': 12639, 'gen': '12639E'}
200
{'question': 'B8821+480=', 'answer': 9301, 'gen': '9301E'}
400
{'question': 'B2729-7007=', 'answer': -4278, 'gen': '-4278E'}
600
{'question': 'B5344/9373=', 'answer': 0, 'gen': '0E'}
800
{'question': 'B8053-6391=', 'answer': 1662, 'gen': '1662E'}
1000
{'question': 'B3372/4496=', 'answer': 0, 'gen': '0E'}
1200
{'question': 'B4026*8955=', 'answer': 36052830, 'gen': '37897490E'}
1400
{'question': 'B3810/8578=', 'answer': 0, 'gen': '0E'}
1600
{'question': 'B8148-460=', 'answer': 7688, 'gen': '7788E'}
1800
{'question': 'B3106/1999=', 'answer': 1, 'gen': '1E'}
2000
{'question': 'B9615-6745=', 'answer': 2870, 'gen': '2870E'}
2200
{'question': 'B8147/290=', 'answer': 28, 'gen': '26E'}
2400
{'question': 'B9440-6679=', 'answer': 2761, 'gen': '2761E'}
2600
{'question': 'B3076/145=', 'answer': 21, 'gen': '13E'}
2800
{'question': 'B2811-8273=', 'answer': -5462, 'gen': '-5462E'}
3000
{'question': 'B8277/6264=', 'answer': 1, 'gen': '