In [25]:
import random


class Tokenizer:
    """Character-level tokenizer for a toy "repeat-and-multiply" task.

    Each sample is ``<prefix>:<question>=<answer>``. ``question`` is a random
    4-digit integer and ``answer`` is ``int(str(question) * 4) * 4``. The answer
    is written in one of four scripts: Arabic digits, the letters
    ``pqwertyuio``, lower-case Chinese numerals, or upper-case (financial)
    Chinese numerals. The chosen script is the sample's class label.
    """

    def __init__(self):
        self.vocab = {
            'mark': list('PSEU'),  # P: padding, S: start, E: end, U: unknown
            'number': list('0123456789'),
            'letter': list('pqwertyuio'),
            'chinese_lower': list('〇一二三四五六七八九'),
            'chinese_upper': list('零壹贰叁肆伍陆柒捌玖'),
            'other': list('数字大写小母:=_'),
        }

        # Flat id -> token table; ids follow the insertion order of self.vocab.
        self.decoder = [token for group in self.vocab.values() for token in group]
        self.encoder = {token: index for index, token in enumerate(self.decoder)}

        # Class id of each answer script. The keys double as keys into self.vocab.
        self.label = {
            'number': 0,
            'letter': 1,
            'chinese_lower': 2,
            'chinese_upper': 3
        }
        # Prompt prefix per label id:
        # 数字 = digits, 字母 = letters, 小写 = lower-case, 大写 = upper-case.
        self.prefix = ['数字', '字母', '小写', '大写']

    def encode(self, text):
        """Map each character of ``text`` to its id.

        Characters outside the vocabulary map to the 'U' (unknown) id.
        """
        unknown_id = self.encoder['U']
        return [self.encoder.get(char, unknown_id) for char in text]

    def decode(self, x):
        """Convert a sequence of token ids back into a string."""
        return ''.join([self.decoder[i] for i in x])

    def get_data(self, prefix: bool) -> tuple[int, list[int]]:
        """Generate one random sample.

        Args:
            prefix: if True, the sequence starts with the label's Chinese prefix
                (e.g. '小写'). Otherwise it starts with the placeholder '__'.

        Returns:
            ``(label_id, token_ids)``. ``token_ids`` is wrapped in the 'S' ... 'E'
            markers, e.g. S 小 写 : 4 1 1 0 = 一 六 四 四 ... 〇 E.
        """
        # Question and answer, e.g. 4110 -> int('4110' * 4) * 4 = 16441644164416440.
        question = random.randint(1000, 9999)
        answer = str(int(str(question) * 4) * 4)

        # Pick the script the answer is written in; this is the class label.
        label = random.choice(list(self.label.keys()))
        label_id = self.label[label]

        # Rewrite every answer digit with the chosen script's symbol for it.
        answer_tokens = [self.vocab[label][int(digit)] for digit in answer]

        # Separate name so the `prefix` flag is not rebound to a list.
        head = list(self.prefix[label_id]) if prefix else list('__')

        tokens = head + [':'] + list(str(question)) + ['='] + answer_tokens

        token_ids = [self.encoder[t] for t in tokens]
        token_ids = [self.encoder['S']] + token_ids + [self.encoder['E']]
        return label_id, token_ids

    def get_batch_data(self, prefix: bool, batch_size: int) -> tuple[list, list, list]:
        """Generate ``batch_size`` samples.

        Returns:
            ``(labels, input_ids, attention_mask)``, right-padded with 'P'.
        """
        data = [self.get_data(prefix=prefix) for _ in range(batch_size)]

        label = [sample[0] for sample in data]
        token = [sample[1] for sample in data]

        return label, *self.batch_pad(token=token)

    def batch_pad(self, text=None, token=None):
        """Right-pad a batch to its longest sequence using the 'P' id.

        Args:
            text: optional batch of strings (or character lists). When given, it
                is encoded via ``encode`` and takes precedence over ``token``.
            token: batch of token-id lists. Used when ``text`` is None.

        Returns:
            ``(input_ids, attention_mask)``. Both are lists of equal-length lists.
            The mask is 1 for real tokens and 0 for padding. An empty batch
            yields ``([], [])``.

        Raises:
            ValueError: if neither ``text`` nor ``token`` is provided.
        """
        if text is not None:
            token = [self.encode(sequence) for sequence in text]
        if token is None:
            raise ValueError('batch_pad() needs either text or token')

        # default=0 keeps an empty batch from crashing max().
        max_len_in_batch = max((len(ids) for ids in token), default=0)
        pad_id = self.encoder['P']

        input_ids = []
        attention_mask = []
        for ids in token:
            n_pad = max_len_in_batch - len(ids)
            attention_mask.append([1] * len(ids) + [0] * n_pad)
            input_ids.append(list(ids) + [pad_id] * n_pad)

        return input_ids, attention_mask


tokenizer = Tokenizer()

# Render a prefixed sample batch back to text as a quick sanity check of the data format.
list(map(tokenizer.decode, tokenizer.get_batch_data(prefix=True, batch_size=4)[1]))[:10]

['S字母:1168=ryuwryuwryuwryuwEP',
 'S小写:1668=六六七二六六七二六六七二六六七二EP',
 'S数字:2899=11597159715971596E',
 'S数字:4190=16761676167616760E']

In [27]:
list(map(tokenizer.decode, tokenizer.get_batch_data(prefix=False, batch_size=8)[1]))[:10]  # un-prefixed ('__') samples

['S__:7082=wieepieepieepiewiE',
 'S__:8723=三四八九五四八九五四八九五四八九二E',
 'S__:4814=qowtuowtuowtuowtyE',
 'S__:1892=7568756875687568EP',
 'S__:4539=18157815781578156E',
 'S__:3460=13841384138413840E',
 'S__:2402=玖陆零捌玖陆零捌玖陆零捌玖陆零捌EP',
 'S__:2138=ittwittwittwittwEP']

In [26]:
import torch
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

In [12]:
class ModelGEN(torch.nn.Module):
    """Small GPT-2 language model over the tokenizer's vocabulary.

    ``feature`` (the GPT-2 trunk) and ``fc_out`` (a bias-free LM head) are
    separate attributes so other code can reuse them directly.

    Args:
        n_embd: hidden size of the trunk (and input width of ``fc_out``).
        n_head: number of attention heads.
        n_layer: number of transformer blocks.
        n_positions: maximum sequence length.
    """

    def __init__(self, n_embd: int = 64, n_head: int = 4, n_layer: int = 4,
                 n_positions: int = 128):
        super().__init__()
        from transformers import GPT2Config, GPT2Model

        self.config = GPT2Config(bos_token_id=tokenizer.encoder['S'],
                                 eos_token_id=tokenizer.encoder['E'],
                                 n_embd=n_embd,
                                 n_head=n_head,
                                 n_layer=n_layer,
                                 n_positions=n_positions,
                                 vocab_size=len(tokenizer.decoder))

        self.feature = GPT2Model(self.config)

        # The head's input width must equal the trunk's hidden size. Read it from
        # the config rather than repeating a literal so the two cannot drift apart.
        self.fc_out = torch.nn.Linear(self.config.n_embd,
                                      self.config.vocab_size,
                                      bias=False)

        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return next-token logits whose last dimension is ``vocab_size``.

        Presumably ``input_ids``/``attention_mask`` are (batch, seq) id and mask
        tensors as produced by ``tokenizer.batch_pad``. TODO confirm at call sites.
        """
        out = self.feature(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state

        return self.fc_out(out)

In [13]:
class ModelCLS(torch.nn.Module):
    """Small BERT classifier predicting which label (answer script) a sequence uses."""

    def __init__(self):
        super().__init__()
        from transformers import BertConfig, BertModel

        self.config = BertConfig(hidden_size=64,
                                 intermediate_size=64,
                                 max_position_embeddings=128,
                                 num_attention_heads=4,
                                 num_hidden_layers=4,
                                 # State the padding id explicitly instead of relying
                                 # on BertConfig's default happening to equal it.
                                 pad_token_id=tokenizer.encoder['P'],
                                 vocab_size=len(tokenizer.decoder))

        self.feature = BertModel(self.config)

        # Derive the head's sizes from the config and the tokenizer's label set
        # instead of hard-coding 64 and 4.
        self.fc_out = torch.nn.Sequential(
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(self.config.hidden_size, len(tokenizer.label)))

        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Classify from BertModel's ``pooler_output``.

        Returns logits whose last dimension is ``len(tokenizer.label)``.
        """
        out = self.feature(input_ids=input_ids,
                           attention_mask=attention_mask).pooler_output

        return self.fc_out(out)

In [5]:
class ModelPPO(torch.nn.Module):
    """Policy/value wrapper around a generator model.

    Reuses ``model_gen.feature`` (trunk) and ``model_gen.fc_out`` (LM head) for
    the policy logits, and adds a scalar value head over the same hidden states.
    """

    def __init__(self, model_gen):
        super().__init__()
        self.model_gen = model_gen
        # The value head reads the same hidden states that fc_out consumes, so
        # size it from fc_out rather than a hard-coded 64.
        self.v_head = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.Linear(model_gen.fc_out.in_features, 1))

        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, value)`` from a single trunk pass.

        ``value`` is the value head's output with its trailing size-1 dim
        squeezed. Assumes inputs are (batch, seq), which would make it one value
        per position. TODO confirm.
        """
        # Only the final hidden state is needed. Not requesting
        # output_hidden_states avoids keeping every layer's activations around.
        last_hidden_state = self.model_gen.feature(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state

        logits = self.model_gen.fc_out(last_hidden_state)
        value = self.v_head(last_hidden_state).squeeze(-1)

        return logits, value

In [6]:
generater = None


def generate(model_gen, input_ids, max_new_tokens=25):
    """Sample continuations of ``input_ids`` from ``model_gen``.

    A GPT2LMHeadModel is built only as a wrapper so Hugging Face's ``generate``
    can be used. Its ``transformer`` and ``lm_head`` are replaced by
    ``model_gen.feature`` and ``model_gen.fc_out``, so sampling uses
    ``model_gen``'s own weights.

    The wrapper is cached in the module-level ``generater`` and rebuilt whenever
    it is not wired to the given ``model_gen``. Calls with different models
    therefore never sample from a stale model.

    Args:
        model_gen: model exposing ``config``, ``feature`` and ``fc_out``.
        input_ids: prompt token ids.
        max_new_tokens: maximum number of tokens to append (default 25).
    """
    global generater
    if (generater is None
            or generater.transformer is not model_gen.feature
            or generater.lm_head is not model_gen.fc_out):
        from transformers import GPT2LMHeadModel
        generater = GPT2LMHeadModel(model_gen.config)
        generater.transformer = model_gen.feature
        generater.lm_head = model_gen.fc_out
        generater.to(device)

    return generater.generate(input_ids=input_ids,
                              # NOTE(review): -1 appears intended to disable any
                              # minimum-length constraint. Confirm for the installed
                              # transformers version.
                              min_length=-1,
                              # top_k=0 with top_p=1.0 leaves the distribution
                              # unfiltered, i.e. plain sampling.
                              top_k=0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.encoder['P'],
                              max_new_tokens=max_new_tokens,
                              eos_token_id=tokenizer.encoder['E'])