In [1]:
import random


class Tokenizer:
    """Character-level vocabulary plus a generator of random arithmetic samples.

    `get_data` encodes its strings one character at a time, so vocabulary
    entries longer than one character (such as '**') and the letters are
    present in the tables but never emitted by `get_data`.
    """

    def __init__(self, third_number=False):
        """Build the vocabulary tables.

        Args:
            third_number: when True, `get_data` appends '+<c>' to every
                expression, producing a three-operand problem.
        """
        groups = (
            ('mark', 'PAD,SOS,EOS,.,='),
            ('number', '0,1,2,3,4,5,6,7,8,9'),
            ('symbol', '+,-,*,/,**'),
            ('letter', 'a,b,c,d,e,x,y,z'),
        )
        self.vocab = {name: spec.split(',') for name, spec in groups}

        # Token ids follow vocab order: marks, digits, symbols, then letters.
        self.decoder = []
        for tokens in self.vocab.values():
            self.decoder.extend(tokens)
        self.encoder = dict(zip(self.decoder, range(len(self.decoder))))
        self.third_number = third_number

    def get_data(self):
        """Draw one random expression and return it as a (prompt, target) pair.

        The expression is evaluated with Python's `eval`, so precedence follows
        Python (e.g. '-2.00**2.00' evaluates to -4.0). The pair is deliberately
        reversed: the prompt is SOS + the 2-decimal result + '=', and the
        target is the expression text + EOS.

        Returns:
            (prompt_ids, target_ids): two lists of int token ids.
        """
        op = random.choice(self.vocab['symbol'])
        left = random.uniform(-100, 100)
        right = random.uniform(-100, 100)

        if op == '/':
            # Keep the divisor at least 0.01 in magnitude so it cannot round
            # to 0.00 when formatted below.
            while abs(right) < 0.01:
                right = random.uniform(-100, 100)
        elif op == '**':
            # Small exponents keep the power bounded...
            while abs(right) > 5:
                right = random.uniform(-5, 5)
            # ...and a base of at least 0.01 in magnitude avoids 0 ** negative.
            while abs(left) < 0.01:
                left = random.uniform(-100, 100)

        expression = '%.2f%s%.2f' % (left, op, right)
        if self.third_number:
            # Optional third operand, always joined with '+'.
            expression = expression + '+' + '%.2f' % random.uniform(-100, 100)

        result = '%.2f' % eval(expression)

        # Swap question and answer: the model is prompted with the result and
        # must produce the expression.
        prompt_ids = [self.encoder['SOS']]
        prompt_ids.extend(self.encoder[ch] for ch in result)
        prompt_ids.append(self.encoder['='])

        target_ids = [self.encoder[ch] for ch in expression]
        target_ids.append(self.encoder['EOS'])

        return prompt_ids, target_ids

    def decode(self, x):
        """Map a sequence of token ids back to the concatenated token strings."""
        return ''.join(self.decoder[token_id] for token_id in x)


# Seed Python's RNG: Tokenizer.get_data draws every sample from `random`, so
# seeding makes the generated samples reproducible on Restart & Run All.
SEED = 0
random.seed(SEED)

tokenizer = Tokenizer(third_number=False)

# One (prompt, target) pair decoded back to text.
[tokenizer.decode(ids) for ids in tokenizer.get_data()]

['SOS58.14=', '14.37--43.77EOS']

In [2]:
import torch


def get_loader(collate_fn, batch_size=64, dataset_size=10000, data_tokenizer=None):
    """Build a DataLoader over freshly generated arithmetic samples.

    The dataset ignores the index it is given: every `__getitem__` call draws
    a brand-new random sample via `get_data()`.

    Args:
        collate_fn: merges a list of (prompt, target) LongTensor pairs into a
            batch; the pairs can differ in length.
        batch_size: samples per batch (default 64).
        dataset_size: nominal number of samples per epoch (default 10000).
        data_tokenizer: object with a `get_data()` method returning two lists
            of token ids. Defaults to the notebook-level `tokenizer`, which is
            looked up each time a sample is drawn.

    Returns:
        A torch.utils.data.DataLoader that drops the final incomplete batch.
    """

    class Dataset(torch.utils.data.Dataset):

        def __len__(self):
            return dataset_size

        def __getitem__(self, i):
            # `i` is unused: each access draws a new random sample.
            source = tokenizer if data_tokenizer is None else data_tokenizer
            x, y = source.get_data()
            return torch.LongTensor(x), torch.LongTensor(y)

    return torch.utils.data.DataLoader(Dataset(),
                                       batch_size=batch_size,
                                       drop_last=True,
                                       collate_fn=collate_fn)

In [3]:
class Model(torch.nn.Module):
    """Small Llama decoder followed by a separate, bias-free linear LM head."""

    def __init__(self, vocab_size=None):
        """Build a 4-layer, 64-dim LlamaModel and its output projection.

        Args:
            vocab_size: output vocabulary size. Defaults to the length of the
                notebook-level `tokenizer.decoder`.
        """
        super().__init__()
        from transformers import LlamaConfig, LlamaModel

        if vocab_size is None:
            vocab_size = len(tokenizer.decoder)

        self.config = LlamaConfig(hidden_size=64,
                                  intermediate_size=64,
                                  max_position_embeddings=256,
                                  num_attention_heads=4,
                                  num_hidden_layers=4,
                                  num_key_value_heads=4,
                                  vocab_size=vocab_size)

        self.model = LlamaModel(self.config)
        # Take the head's input width from the config so it cannot drift out
        # of sync with hidden_size when the config above is edited.
        self.lm_head = torch.nn.Linear(self.config.hidden_size,
                                       self.config.vocab_size,
                                       bias=False)

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits for every input position.

        Args:
            input_ids: LongTensor of token ids; presumably (batch, seq).
            attention_mask: optional 1/0 mask, forwarded unchanged to
                LlamaModel.

        Returns:
            Logits whose last dimension is `config.vocab_size`; for a
            (batch, seq) input this is (batch, seq, vocab_size).
        """
        out = self.model(input_ids=input_ids,
                         attention_mask=attention_mask).last_hidden_state

        return self.lm_head(out)


model = Model()

# Smoke test: 2 sequences of 10 token ids (all id 1) with a full mask; the
# result should hold one logit vector per position.
with torch.no_grad():
    out = model(torch.ones(2, 10, dtype=torch.long),
                attention_mask=torch.ones(2, 10, dtype=torch.long))

out.shape

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([2, 10, 28])

In [4]:
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


# Wrapper used only to drive generation; not the focus of this notebook.
class GenerateModel(PreTrainedModel, GenerationMixin):
    """Adapter that lets `transformers`' `generate()` run the global `model`.

    `forward` delegates to the notebook-level `model` (its `.model` backbone
    and `.lm_head`). `GenerationMixin` is listed explicitly because newer
    `transformers` releases no longer mix it into `PreTrainedModel`, which
    would otherwise remove `generate()` from this class.
    """

    def __init__(self):
        super().__init__(model.config)
        # Dummy parameter: presumably present so PreTrainedModel can infer
        # `device`/`dtype` from this module's own parameters (it owns none
        # otherwise).
        self._ = torch.nn.Linear(1, 1)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the global backbone and LM head, packaged for `generate()`.

        Extra keyword arguments (position_ids, past_key_values, use_cache, ...)
        are forwarded to the backbone unchanged.
        """
        out = model.model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          **kwargs)
        logits = model.lm_head(out.last_hidden_state)

        return CausalLMOutputWithPast(logits=logits,
                                      past_key_values=out.past_key_values)

    @staticmethod
    def _cache_length(past_key_values):
        """Number of positions already stored in `past_key_values` (0 if none)."""
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, 'get_seq_length'):
            # `Cache` objects (e.g. DynamicCache) report their length directly;
            # their truthiness is not a reliable "is empty" test.
            return int(past_key_values.get_seq_length())
        # Legacy tuple cache: one (key, value) pair per layer, with the
        # sequence axis at dim -2.
        if len(past_key_values) == 0:
            return 0
        return past_key_values[0][0].shape[-2]

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      attention_mask=None,
                                      use_cache=True,
                                      past_key_values=None,
                                      **kwargs):
        """Build the forward() kwargs for one decoding step.

        Position ids count only unmasked tokens, so left-padded rows start at
        0; for an all-ones mask they are simply 0, 1, 2, .... Once the cache
        is non-empty, only the newest token and its position are fed. Slicing
        with `[:, -1:]` keeps the batch axis, so batch sizes > 1 work.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = attention_mask.long().cumsum(-1) - 1
        # Padded slots get a dummy position; they are masked out anyway.
        position_ids = position_ids.masked_fill(attention_mask == 0, 1)

        # After the first step the cache already covers the prefix.
        if self._cache_length(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
            position_ids = position_ids[:, -1:]

        return {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        }


generate_model = GenerateModel()

# Decode from an (untrained) model with the default generation settings:
# prompt of 10 copies of token id 1, total length capped at 40 tokens.
generate_model.generate(torch.ones(1, 10, dtype=torch.long), max_length=40)

tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 13, 14,  7,  1, 13, 14,  7,  1,
         13, 14,  7,  1, 13, 14,  7,  1, 13, 14,  7,  1, 13, 14,  7,  1, 13, 14,
          7,  1, 13, 14]])