In [1]:
import torch


class Tokenizer:
    """Character-level tokenizer for arithmetic strings.

    The vocabulary is the three special tokens ``B`` (begin), ``E`` (end)
    and ``P`` (pad) followed by the characters ``0123456789+-*/=.``. A
    token's id is its position in that list, so ``B`` is 0, ``E`` is 1,
    ``P`` is 2 and ``'0'`` is 3.
    """

    def __init__(self):
        vocab = ['B', 'E', 'P'] + list('0123456789+-*/=.')
        # Forward (token -> id) and reverse (id -> token) lookup tables.
        self.vocab_encode = {token: idx for idx, token in enumerate(vocab)}
        self.vocab_decode = dict(enumerate(vocab))

        self.bos_token = 'B'
        self.eos_token = 'E'
        self.pad_token = 'P'
        self.bos_token_id = self.vocab_encode['B']
        self.eos_token_id = self.vocab_encode['E']
        self.pad_token_id = self.vocab_encode['P']
        # Id of '=', exposed alongside the special-token ids.
        self.eq_token_id = self.vocab_encode['=']

    def encode(self,
               text,
               padding=True,
               truncation=True,
               max_length=128,
               add_bos_token=True,
               add_eos_token=True,
               padding_side='right',
               device='cpu'):
        """Encode a batch of strings into id and attention-mask tensors.

        Args:
            text: Iterable of strings, one sequence per string. Every
                character must be in the vocabulary.
            padding: Falsy for no padding, ``'max_length'`` to pad every
                sequence to ``max_length``, any other truthy value to pad
                to the longest sequence in the batch.
            truncation: If truthy, cut every sequence to at most
                ``max_length`` ids. Truncation runs after padding.
            max_length: Length used by ``padding='max_length'`` and by
                truncation.
            add_bos_token: Prepend the BOS id to every sequence.
            add_eos_token: Append the EOS id to every sequence. Together
                with ``truncation`` it also forces a truncated sequence to
                end in EOS.
            padding_side: ``'right'`` or ``'left'``; any other value leaves
                the sequences unpadded.
            device: Device the returned tensors are moved to.

        Returns:
            A dict with ``'input_ids'`` (a ``LongTensor`` built from the
            per-sequence id lists) and ``'attention_mask'`` (same shape,
            1 where the id is not the pad id, 0 where it is).

        Raises:
            KeyError: If a character is not in the vocabulary.
        """
        input_ids = [[self.vocab_encode[ch] for ch in seq] for seq in text]

        if add_bos_token:
            input_ids = [[self.bos_token_id] + seq for seq in input_ids]
        if add_eos_token:
            input_ids = [seq + [self.eos_token_id] for seq in input_ids]

        if padding:
            if padding == 'max_length':
                target_len = max_length
            else:
                # default=0 keeps an empty batch from raising in max().
                target_len = max((len(seq) for seq in input_ids), default=0)

            # A sequence already longer than target_len gets a non-positive
            # pad count, i.e. no padding; truncation below may shorten it.
            if padding_side == 'right':
                input_ids = [
                    seq + [self.pad_token_id] * (target_len - len(seq))
                    for seq in input_ids
                ]
            elif padding_side == 'left':
                input_ids = [
                    [self.pad_token_id] * (target_len - len(seq)) + seq
                    for seq in input_ids
                ]

        if truncation:
            input_ids = [seq[:max_length] for seq in input_ids]

            if add_eos_token:
                for seq in input_ids:
                    # If truncation cut off the EOS, overwrite the last kept
                    # id so the sequence still terminates. The `seq` guard
                    # covers max_length == 0, where nothing is left to patch.
                    if seq and seq[-1] not in (self.eos_token_id,
                                               self.pad_token_id):
                        seq[-1] = self.eos_token_id

        input_ids = torch.LongTensor(input_ids).to(device)
        # The comparison result already lives on the same device as input_ids.
        attention_mask = (input_ids != self.pad_token_id).long()

        return {'input_ids': input_ids, 'attention_mask': attention_mask}

    def decode(self, input_ids, ignore_pad=True):
        """Turn a single sequence of ids back into a string.

        Args:
            input_ids: A flat sequence of ids, or any object with a
                ``tolist()`` method that yields one.
            ignore_pad: If truthy, drop everything before the first BOS id
                and everything after the first EOS id that follows. Ids
                between them, and all ids when neither marker is present,
                are kept as they are.

        Returns:
            The decoded characters joined into one string; BOS and EOS
            themselves are kept as ``'B'`` and ``'E'``.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        if hasattr(input_ids, 'tolist'):
            input_ids = input_ids.tolist()

        if ignore_pad:
            # Use this instance's ids rather than a module-level `tokenizer`.
            if self.bos_token_id in input_ids:
                start = input_ids.index(self.bos_token_id)
                input_ids = input_ids[start:]

            if self.eos_token_id in input_ids:
                stop = input_ids.index(self.eos_token_id) + 1
                input_ids = input_ids[:stop]

        return ''.join(self.vocab_decode[idx] for idx in input_ids)

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.vocab_encode)

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`encode`."""
        return self.encode(*args, **kwargs)


# Module-level tokenizer instance used by the demo call below.
tokenizer = Tokenizer()

# Demo: encode two arithmetic strings without BOS/EOS markers, left-padded
# with the pad id to a fixed length of 32, and place the tensors on the GPU.
tokenizer.encode(
    ['5+1234=123456', '9999+9999=9999999'],
    add_bos_token=False,
    add_eos_token=False,
    padding='max_length',
    padding_side='left',
    truncation=True,
    max_length=32,
    device='cuda',
)

{'input_ids': tensor([[ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           2,  8, 13,  4,  5,  6,  7, 17,  4,  5,  6,  7,  8,  9],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 12, 12, 12,
          12, 13, 12, 12, 12, 12, 17, 12, 12, 12, 12, 12, 12, 12]],
        device='cuda:0'),
 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}