In [106]:
import torch
from torch import nn
import matplotlib.pyplot as plt

In [107]:
class Vocab:
    """Two-way mapping between tokens and integer indices.

    Index 0 is always ``'<unk>'``. The remaining tokens are numbered by
    descending frequency; tokens with equal counts keep first-seen order.

    Attributes:
        tokens_indicates: dict mapping token (str) -> index (int).
        indicates_tokens: dict mapping index (int) -> token (str).
    """

    def __init__(self, tokens) -> None:
        """Build the vocabulary.

        Args:
            tokens: either a flat list of tokens or a list of token lists
                (one inner list per sentence). May be empty, in which case
                the vocabulary contains only ``'<unk>'``.
        """
        import collections

        # Flatten [[tok, ...], ...] into one list. The length guard keeps an
        # empty corpus from raising IndexError on tokens[0].
        if len(tokens) > 0 and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        counter = collections.Counter(tokens)
        # '<unk>' is pinned to index 0 below. If the corpus itself contains
        # it, drop it here so it is not numbered a second time (which would
        # leave index 0 without a forward mapping and make len() too small).
        counter.pop('<unk>', None)
        # most_common() orders by count descending and keeps first-seen order
        # among ties, the same order a stable reverse sort on counts gives.
        ordered = ['<unk>'] + [token for token, _ in counter.most_common()]
        self.tokens_indicates = {
            token: idx
            for idx, token in enumerate(ordered)
        }
        self.indicates_tokens = {
            idx: token
            for token, idx in self.tokens_indicates.items()
        }

    @property
    def unk(self):
        """Index reserved for the unknown token ``'<unk>'``."""
        return 0

    def __len__(self):
        """Number of distinct tokens, including ``'<unk>'``."""
        return len(self.tokens_indicates)

    def __getitem__(self, keys):
        """Translate between tokens and indices.

        Args:
            keys: one of
                * str -- a token; returns its index, or ``self.unk`` when the
                  token is not in the vocabulary;
                * list -- translated element by element (nesting allowed);
                * torch.Tensor -- a tensor of indices; it is flattened and the
                  matching tokens are concatenated into a single str;
                * anything else (typically int) -- an index; returns its token.

        Raises:
            KeyError: if an index (plain or inside a tensor) is not in the
                vocabulary.
        """
        if isinstance(keys, str):
            # Out-of-vocabulary tokens fall back to '<unk>' instead of raising.
            return self.tokens_indicates.get(keys, self.unk)
        if isinstance(keys, list):
            return [self[key] for key in keys]
        if isinstance(keys, torch.Tensor):
            # int() lets float-typed index tensors be decoded as well.
            return ''.join(self.indicates_tokens[int(idx)]
                           for idx in keys.reshape(-1).tolist())
        return self.indicates_tokens[keys]


def tokenize(lines: list, steps=32):
    """Tokenize tab-separated English/Chinese sentence pairs to fixed length.

    Each line is expected to look like ``"<english>\\t<chinese>[\\t<extra>...]"``.
    Only the first two fields are used; any further fields are ignored.
    Blank lines are skipped.

    English: every run of characters other than ASCII letters and spaces is
    set off with spaces, the text is lower-cased and split on single spaces,
    and ``'<eos>'`` is appended.
    Chinese: the text is passed through ``zhconv.convert(..., 'zh-cn')``,
    segmented with ``jieba.cut`` and wrapped in ``'<bos>'`` ... ``'<eos>'``.

    Both token lists are then cut to ``steps`` tokens or right-padded with
    ``'<pad>'`` up to ``steps`` tokens. A sentence longer than ``steps`` is
    cut from the right, so its ``'<eos>'`` is dropped.

    Args:
        lines: list of raw text lines (a trailing newline is allowed).
        steps: number of tokens in every returned sentence.

    Returns:
        Tuple ``(en, zh)`` of two equally long lists of token lists.

    Raises:
        ValueError: if a non-blank line has fewer than two tab-separated
            fields.
    """
    import re

    import jieba
    import zhconv

    def fit_length(sentence):
        # Cut to `steps` tokens, then right-pad with '<pad>' if still short.
        return sentence[:steps] + ['<pad>'] * max(0, steps - len(sentence))

    en, zh = [], []
    for lineno, line in enumerate(lines, start=1):
        line = line.rstrip('\r\n')
        if not line.strip():
            # Blank lines carry no sentence pair.
            continue
        fields = line.split('\t')
        if len(fields) < 2:
            raise ValueError(
                f'line {lineno}: expected at least 2 tab-separated fields, '
                f'got {len(fields)}')
        sentence_en, sentence_zh = fields[0], fields[1]

        # Surround punctuation/digit runs with spaces so they become tokens
        # of their own; empty strings from repeated spaces are filtered out.
        spaced = re.sub('[^A-Za-z ]+', lambda m: f' {m.group()} ',
                        sentence_en)
        tokens_en = [tok for tok in spaced.lower().split(' ') if tok != '']
        tokens_zh = list(
            jieba.cut(zhconv.convert(sentence_zh, 'zh-cn'), cut_all=False))

        en.append(fit_length(tokens_en + ['<eos>']))
        zh.append(fit_length(['<bos>'] + tokens_zh + ['<eos>']))
    return en, zh


def read_data(path='../rnn/en_zh.trans.txt', encoding='utf-8'):
    """Read the raw translation corpus and return its lines.

    Args:
        path: location of the text file; defaults to the corpus path this
            notebook has always used.
        encoding: text encoding of the file. It is given explicitly because
            the file contains Chinese text and the platform default encoding
            is not UTF-8 on every system.

    Returns:
        List of lines, each still carrying its trailing newline.
    """
    with open(path, 'r', encoding=encoding) as f:
        return f.readlines()


class _Dataset(torch.utils.data.Dataset):
    """Parallel English/Chinese corpus exposed as index tensors.

    A separate ``Vocab`` is built for each language from the given sentences.
    The sentences are converted to index tensors (stored as float32), and for
    every row the number of entries that differ from the ``'<pad>'`` index is
    recorded as that sentence's valid length.
    """

    def __init__(self, data_raw, steps=24) -> None:
        """
        Args:
            data_raw: pair ``(en, zh)`` of lists of token lists. The rows are
                fed to ``torch.tensor`` as they are, so within one language
                they must all have the same length.
            steps: accepted but not used anywhere in this class.
        """
        super().__init__()
        sentences_en, sentences_zh = data_raw

        self.vocab_en = Vocab(sentences_en)
        self.vocab_zh = Vocab(sentences_zh)

        self.corpus_en = torch.tensor(self.vocab_en[sentences_en],
                                      dtype=torch.float32)
        self.corpus_zh = torch.tensor(self.vocab_zh[sentences_zh],
                                      dtype=torch.float32)

        # Valid length = number of non-padding positions in each row.
        pad_en = self.vocab_en['<pad>']
        pad_zh = self.vocab_zh['<pad>']
        self.valid_len_en = (self.corpus_en != pad_en).sum(dim=1)
        self.valid_len_zh = (self.corpus_zh != pad_zh).sum(dim=1)

        self._len = len(self.corpus_en)
        # All four tensors must describe the same number of sentence pairs.
        other_sizes = (len(self.corpus_zh), len(self.valid_len_en),
                       len(self.valid_len_zh))
        assert all(size == self._len for size in other_sizes)

    def __len__(self):
        """Number of sentence pairs."""
        return self._len

    def __getitem__(self, idx):
        """Return ``(en_indices, en_valid_len, zh_indices, zh_valid_len)``."""
        return (
            self.corpus_en[idx],
            self.valid_len_en[idx],
            self.corpus_zh[idx],
            self.valid_len_zh[idx],
        )


In [108]:
# Read the corpus file, tokenize it and wrap the result in a dataset.
# NOTE(review): `steps=19` is passed to `_Dataset`, whose `__init__` never
# reads it; the sentence length actually comes from `tokenize`, which is
# called here with its default `steps=32`. Confirm which length is intended.
dataset = _Dataset(tokenize(read_data()), steps=19)
# Batches of up to 50 sentence pairs, drawn in shuffled order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)
# Iterator over `dataloader`; each `next(train_iter)` yields one batch.
train_iter = iter(dataloader)