In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
SOS_token = 0
EOS_token = 1


class Lang:
    """
    name:语言名称
    index2word:标索引单词
    word2index：单词索引下标
    word2count：单词索引单次数量
    """
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {0: 'SOS_token', 1: 'EOS_token'}
        self.word2count = {}
        self.n_word = 2

    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index.keys():
            self.index2word[self.n_word] = word
            self.word2index[word] = self.n_word
            self.word2count[word] = 1
            self.n_word += 1
        else:
            self.word2count[word] += 1

In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicode2ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')


# Lowercase, trim, and remove non-letter characters


def normalize_string(s):
    s = unicode2ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [4]:
def readLang(lan1, lan2):
    """
    读取文本文件，返回两个语言类Lang实例和对应的成对的语言列表
    """
    print("reading lines of language %s to %s" % (lan1, lan2))

    with open('./data/eng-fra.txt', encoding='UTF-8') as f:
        lines = f.readlines()
    pairs = []
    for line in lines:
        sentences = line.strip().split('\t')
        if len(sentences) < 2:
            print(line)
            print('-------error!------')
            exit(-1)
        pairs.append([
            normalize_string(sentences[0].strip()),
            normalize_string(sentences[1].strip())
        ])
    lang_src = Lang(lan1)
    lang_dst = Lang(lan2)
#     for pair in pairs:
#         lang_src.add_sentence(pair[0])
#         lang_dst.add_sentence(pair[1])

    return lang_src, lang_dst, pairs


lang_src, lang_dst, lang_pairs = readLang('eng', 'fra')
print(lang_pairs[0:5])

reading lines of language eng to fra
[['go .', 'va !'], ['run !', 'cours !'], ['run !', 'courez !'], ['wow !', 'ca alors !'], ['fire !', 'au feu !']]


In [5]:
print(len(lang_pairs[0][0]))

4


In [6]:
MAX_WORD = 10


def filter_pairs(pairs):
    return [
        pair for pair in pairs if len(pair[0].split(' ')) < MAX_WORD
        and len(pair[1].split(' ')) < MAX_WORD
    ]

lang_pairs = filter_pairs(lang_pairs)
print(len(lang_pairs))

95170


In [7]:

for pair in lang_pairs:
    lang_src.add_sentence(pair[0])
    lang_dst.add_sentence(pair[1])

print(lang_src.name, lang_src.n_word)
print(lang_dst.name, lang_dst.n_word)


eng 10025
fra 16813


In [8]:
class EncoderRnn(nn.Module):
    def __init__(self, dict_size, output_size, hidden_size):
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(dict_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input_data, hidden):
        """
        gru要求的输入形状为 seqlen*batch*dims，
        此处输入为单个词的word2index索引，使用embedding转换为指定尾数的输入向量
        input_data: [index] 
        hidden: (1*1*hidden_size), 1*1为本处指定，句子长度为1，batch为1
        """
        embedded = self.embedding(input_data).view(1, 1, -1)
        output_data, hidden = self.gru(embedded, hidden)
        return output_data, hidden

    def init_hidden():
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [9]:
class DecoderRnn(nn.Module):
    
    def __init__(self, hidden_size, dict_size):
        super(DecoderRnn, self).__init__()
        self.embedding = nn.Embedding(dict_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, dict_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, input_data, hidden):
        embedded = self.embedding(input_data).view(1, 1, -1)
        embedded = F.relu(embedded)
        gru_out, gru_hidden = self.gru(embedded, hidden)
        gru_out = self.out(gru_out)
        output_data = self.softmax(gru_out[0])
        return output_data, hidden
    
        

In [10]:
class DecoderAttnRnn(nn.Module):
    def __init__(self, hidden_size, dict_size, drop = 0.1, maxlen = MAX_WORD):
        super(DecoderAttnRnn, self).__init__()
        
        self.hidden_size = hidden_size
        self.embedding_size = hidden_size
        
        self.embedding = nn.Embedding(dict_size, embedding_size)
        self.dropout = nn.Dropout(0.1)
        self.attn = nn.Linear(hidden_size + embedding_size, maxlen)
        self.attn_cat_hidden = nn.Linear(embedding_size*2, embedding_size)
        self.gru = nn.GRU(embedding_size, hidden_size)
        self.out = nn.Linear(hidden_size, dict_size)
        
    def forward(self, input_data, hidden, encoder_outputs):
        embedded = self.embedding(input_data).view(1, 1, -1)
        droped_embedded = self.dropout(embedded)
        
        attn_in = torch.cat((droped_embedded[0], hidden[0]), dim=1)
        attn_weight = F.softmax(self.attn(attn_in), dim=1)
        
        attn_applied = torch.bmm(attn_weight.unsqueeze(0)
                                encoder_outputs.unsqueeze(0))
        attn_cat_hidden = self.attn_cat_hidden(
            torch.cat((attn_applied[0], hidden[0]), dim=1)
        ).unsqueeze(0)
        gru_in = F.relu(attn_cat_hidden)
        out, hidden = self.gru(gru_in, hidden)
        out = self.out(F.log_softmax(out[0], dim=1))
        return out, hidden, attn_weight

SyntaxError: unexpected EOF while parsing (<ipython-input-10-1e096d7f8657>, line 1)