In [8]:
import torch
from torch import nn
import matplotlib.pyplot as plt

In [9]:
class Vocab:
    def __init__(self, tokens) -> None:
        if isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        import collections
        counter = collections.Counter(tokens)
        tokens = [
            k for k, _ in sorted(
                counter.items(), key=lambda item: item[1], reverse=True)
        ]
        tokens.insert(0, '<unk>')
        self.tokens_indicates = {
            token: idx
            for idx, token in enumerate(tokens)
        }
        self.indicates_tokens = {
            v: k
            for k, v in self.tokens_indicates.items()
        }

    @property
    def unk(self):
        return 0

    def __len__(self):
        return len(self.tokens_indicates)

    def __getitem__(self, keys):
        if isinstance(keys, str):
            return self.tokens_indicates[keys]
        if isinstance(keys, list):
            return [self.__getitem__(key) for key in keys]
        if isinstance(keys, (torch.Tensor)):
            keys = keys.reshape(-1)
            return ''.join(self.indicates_tokens[int(keys[i])]
                           for i in range(keys.numel()))
        return self.indicates_tokens[keys]

truncate = lambda sen, l: sen[:l] if len(sen) > l else sen + ['<pad>'] * (
    l - len(sen))

def tokenize_zh(sentence, steps):
    import jieba
    import zhconv
    sentence_zh = list(
            jieba.cut(zhconv.convert(sentence, 'zh-cn'), cut_all=False))

    sentence_zh = ['<bos>'] + sentence_zh + ['<eos>']
    sentence_zh = truncate(sentence_zh, steps)
    return sentence_zh

def tokenize_en(sentence, steps):
    import re
    sentence_en = [
        i for i in re.sub('[^A-Za-z ]+', lambda m: f' {m.group()} ',
                        sentence).lower().split(' ')
        if i != '' and i != ' '
    ]
    sentence_en = sentence_en + ['<eos>']
    sentence_en = truncate(sentence_en, steps)
    return sentence_en

def tokenize(lines: list, steps=32):
    en, zh = [], []
    for line in lines:
        sentence_en, sentence_zh, _ = line.split('\t')
        en.append(tokenize_en(sentence_en, steps))
        zh.append(tokenize_zh(sentence_zh, steps))
    return en, zh


def read_data():
    with open('../rnn/en_zh.trans.txt', 'r') as f:
        return f.readlines()


class _Dataset(torch.utils.data.Dataset):
    def __init__(self, data_raw) -> None:
        super().__init__()
        en, zh = data_raw
        self.vocab_en, self.vocab_zh = Vocab(en), Vocab(zh)
        self.corpus_en, self.corpus_zh = torch.tensor(
            self.vocab_en[en],
            dtype=torch.int64), torch.tensor(self.vocab_zh[zh],
                                               dtype=torch.int64)
        self.valid_len_en, self.valid_len_zh = (
            self.corpus_en != self.vocab_en['<pad>']).sum(
                dim=1), (self.corpus_zh != self.vocab_zh['<pad>']).sum(dim=1)
        self._len = len(self.corpus_en)
        del en, zh
        assert self._len == len(self.corpus_zh) == len(
            self.valid_len_en) == len(self.valid_len_zh)

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        return self.corpus_en[idx], self.valid_len_en[idx], self.corpus_zh[
            idx], self.valid_len_zh[idx]


In [10]:
class Encoder(nn.Module):
    def __init__(self,
                 vocab_size,
                 embed_size,
                 n_hiddens,
                 n_layers,
                 dropout=0.) -> None:
        super().__init__()
        # embedding layer similar as `one-hot` transformation.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, n_hiddens, n_layers, dropout=dropout)

    def forward(self, x: torch.Tensor):
        # x.shape = (batch_size, n_steps)
        # embed.shape = (n_steps, batch_size, embed_size)
        embed = self.embedding(x.T.type(torch.int64))
        # output.shape = (n_steps, batch_size, n_hiddens)
        # state.shape = (n_layers, batch_size, n_hiddens)
        output, state = self.rnn(embed)
        return output, state


class Decoder(nn.Module):
    def __init__(self,
                 vocab_size,
                 embed_size,
                 n_hiddens,
                 n_layers,
                 dropout=0.) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + n_hiddens,
                          n_hiddens,
                          n_layers,
                          dropout=dropout)
        self.fc = nn.Linear(n_hiddens, vocab_size)

    def init_state(self, encode_outputs):
        encode_output, encode_state = encode_outputs
        return encode_state

    def forward(self, x, es: torch.Tensor):
        embed = self.embedding(x.T.type(torch.int64))
        # context using last layer of encode state
        # c.shape = (batch_size, n_hiddens)
        c = es[-1]
        c = c.repeat(embed.shape[0], 1, 1)
        embed_c = torch.cat((embed, c), -1)
        # decoder initial state as final `output_state` of encoder
        output, state = self.rnn(embed_c, es)
        output = self.fc(output)
        # output.shape = (batch_size, n_steps, vocab_size)
        return output, state


class EncoderDecoder(nn.Module):
    def __init__(self,
                 vocab_size_source,
                 vocab_size_target,
                 embed_size,
                 n_hiddens,
                 n_layers,
                 dropout=0.) -> None:
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size_source,
                               embed_size=embed_size,
                               n_hiddens=n_hiddens,
                               n_layers=n_layers,
                               dropout=dropout)
        self.decoder = Decoder(vocab_size=vocab_size_target,
                               embed_size=embed_size,
                               n_hiddens=n_hiddens,
                               n_layers=n_layers,
                               dropout=dropout)
    
    def forward(self, x, y):
        output = self.encoder(x)
        decode_state = self.decoder.init_state(output)
        output, _ = self.decoder(y, decode_state)
        return output

def loss_masking(loss, y, pad_indicate):
    mask = (y != pad_indicate).type(torch.float32)
    return (loss * mask).sum() / mask.sum()

def grad_clip(net: nn.Module, clip_val=1):
    params = [p for p in net.parameters() if p.requires_grad]
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > clip_val:
        for parm in params:
            parm.grad[:] *= clip_val / norm

In [11]:
# hyperparameter
EPOCHS = 50
LR = 1e-3
BATCH_SIZE = 128
STEPS = 19
EMBED_SIZE = 256
HIDDENS = 256
LAYERS = 3
DROPOUT = .3

In [12]:

dataset = _Dataset(tokenize(read_data(), steps=STEPS))
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True)
net = EncoderDecoder(vocab_size_source=len(dataset.vocab_en),
                     vocab_size_target=len(dataset.vocab_zh),
                     embed_size=EMBED_SIZE,
                     n_hiddens=HIDDENS,
                     n_layers=LAYERS,
                     dropout=DROPOUT)
loss_fn = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
losses, loss = [], None
net.zero_grad()
for epoch in range(1, EPOCHS + 1):
    train_iter = iter(dataloader)
    for x, _, y, _ in train_iter:
        optimizer.zero_grad()
        # 1024 32 28
        output = net(x, y)
        # permute axises to (batch_size, vocab_size, n_steps)
        # loss_fn required input.shape=(B, C, ...) where C = vocab_size
        # see also `torch.nn.CrossEntropyLoss` for detail.
        output = output.permute((1, 2, 0))
        loss = loss_fn(output, y)
        loss = loss_masking(loss, y, dataset.vocab_zh['<pad>'])
        with torch.no_grad():
            loss.backward()
            grad_clip(net)
            optimizer.step()
    print(f"\repoch: [{epoch}/{EPOCHS}", end='')
    losses.append(loss.detach().numpy())
    # print(f'loss: {loss: .6f}')
print()
plt.plot(losses)

KeyboardInterrupt: 

In [15]:
def predict(sentence):
    softmax = nn.LogSoftmax(dim=1)
    token = tokenize_en(sentence, steps=STEPS)
    x = torch.tensor(dataset.vocab_en[token], dtype=torch.int64)
    x = x.unsqueeze(0)
    output = net.encoder(x)
    state = net.decoder.init_state(output)
    y = dataset.vocab_zh['<bos>']
    y = torch.tensor([y], dtype=torch.int64)
    # add batch_axis
    y = y.unsqueeze(0)
    outputs = []
    for _ in range(STEPS):
        output, _ = net.decoder(y, state)
        
        outputs.append(output.argmax(dim=2))
        y = outputs[-1]
    return ''.join(dataset.vocab_zh[outputs])

In [16]:
predict("hello")

'<bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>'