## ==============Translation with a Sequence to Sequence Network and Attention===============

In [1]:
from io import open
import string
import string
import re
import random
import unicodedata

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Reserved vocabulary indices for the start-of-sentence and end-of-sentence
# markers.
SOS_token, EOS_token = 0, 1

class Lang:
    """Vocabulary for one language: word <-> index maps plus word counts.

    Indices 0 and 1 are pre-assigned to the 'SOS' and 'EOS' entries of
    ``index2word``, so the first real word receives index 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: 'SOS', 1: 'EOS'}
        self.n_words = 2  # counts the two pre-assigned entries

    def addSentence(self, sentence):
        """Register every single-space-separated token of ``sentence``."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, wd):
        """Count ``wd``, assigning it the next free index on first sight."""
        if wd not in self.word2index:
            new_index = self.n_words
            self.word2index[wd] = new_index
            self.index2word[new_index] = wd
            self.word2count[wd] = 0
            self.n_words = new_index + 1
        self.word2count[wd] += 1
    

In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is NFD-decomposed so that an accented letter becomes a base
    letter followed by combining marks (Unicode category 'Mn'); those marks
    are then dropped. Characters that do not decompose are left untouched.
    """
    decomposed = unicodedata.normalize('NFD', s)
    return ''.join(c for c in decomposed if unicodedata.category(c) != 'Mn')


# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Normalize a raw sentence into space-separated lowercase tokens.

    Steps: lowercase and trim, strip accents, put a space in front of each
    '.', '!' or '?', then collapse every run of other non-letter characters
    into a single space.

    The result is stripped once more at the end: replacing a leading or
    trailing non-letter (quotes, brackets, digits, ...) with a space would
    otherwise leave stray whitespace, which turns into empty-string tokens
    when the sentence is later split on ' '.
    """
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s.strip()

In [4]:
def readLangs(lang1, lang2, reverse=False):
    """Read the ``lang1-lang2`` pair file and create empty vocabularies.

    Each line of the file is split on tabs and every field is passed through
    ``normalizeString``.

    Args:
        lang1: Name of the first language; used in the file name.
        lang2: Name of the second language; used in the file name.
        reverse: When true, each pair is reversed and ``lang2`` becomes the
            input language.

    Returns:
        ``(input_lang, output_lang, pairs)`` where the two ``Lang`` objects
        are still empty and ``pairs`` is a list of lists of normalized
        strings.
    """
    path = 'classify names with char rnn/data/{}-{}.txt'.format(lang1, lang2)
    # Context manager so the file handle is closed deterministically.
    with open(path, encoding='utf-8') as data_file:
        lines = data_file.read().strip().split('\n')
    pairs = [[normalizeString(s) for s in line.strip().split('\t')] for line in lines]

    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        # Was `Lange(...)`, an undefined name that raised NameError.
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
    

In [5]:
# Use only a subset of the dataset for this demo

In [6]:
# Upper bound used when filtering sentence pairs by token count.
MAX_LENGTH = 10

# Sentence openings a pair must start with to be kept by the filter.
# NOTE(review): some entries end with a space and some do not, so e.g.
# "he is" also matches any sentence beginning "he is..." without a word
# boundary -- confirm whether that is intended.
eng_prefixes = (
    "i am ",
    "i m ",
    "he is",
    "he s ",
    "she is",
    "she s",
    "you are",
    "you re ",
    "we are",
    "we re ",
    "they are",
    "they re ",
)


def filterPair(pair):
    """Return True when ``pair`` is short enough and has a wanted prefix.

    Both sentences must have fewer than ``MAX_LENGTH`` space-separated
    tokens after stripping, and the stripped second sentence must start with
    one of ``eng_prefixes``.
    """
    if len(pair[0].strip().split(' ')) >= MAX_LENGTH:
        return False
    if len(pair[1].strip().split(' ')) >= MAX_LENGTH:
        return False
    return pair[1].strip().startswith(eng_prefixes)

def filterPairs(pairs):
    """Return a new list with only the pairs accepted by ``filterPair``."""
    kept = []
    for candidate in pairs:
        if filterPair(candidate):
            kept.append(candidate)
    return kept
    


In [7]:
def prepareData(lang1, lang2, reverse=False):
    """Load, filter and index the sentence pairs for a language pair.

    Reads the pairs via ``readLangs``, keeps those accepted by
    ``filterPairs``, adds every kept sentence to the matching vocabulary and
    prints progress along the way.

    Returns:
        ``(input_lang, output_lang, pairs)`` with populated vocabularies and
        the filtered pairs.
    """
    print('=========start prepar date========')
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print('original pairs: ', len(pairs))

    pairs = filterPairs(pairs)
    print('after filter, paris: ', len(pairs))

    # pair[0] feeds the input vocabulary, pair[1] the output vocabulary.
    for sentence_pair in pairs:
        input_lang.addSentence(sentence_pair[0])
        output_lang.addSentence(sentence_pair[1])

    input_summary = 'Lang1/(input): ' + input_lang.name + ', its vocab size: ' + str(input_lang.n_words)
    print(input_summary)
    output_summary = 'Lang2/(output): ' + output_lang.name + ', its vocab size: ' + str(output_lang.n_words)
    print(output_summary)

    return input_lang, output_lang, pairs

# Load the 'eng'-'fra' pair file; reverse=True swaps each pair so that 'fra'
# becomes the input language and 'eng' the output language.
input_lang, output_lang, pairs = prepareData('eng', 'fra', reverse=True)
# Show one random pair as a sanity check.
print(random.choice(pairs))

original pairs:  135842
after filter, paris:  10853
Lang1/(input): fra, its vocab size: 4489
Lang2/(output): eng, its vocab size: 2925
['je ne suis pas surprise .', 'i m not surprised .']


In [8]:
class EncoderRNN(nn.Module):
    """Encoder: an embedding layer followed by a GRU, run one step per call."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embeddings = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input_, hidden):
        """Embed ``input_`` and feed it through the GRU.

        The embedding is reshaped to (1, 1, -1) before the GRU; presumably
        ``input_`` holds a single token index -- confirm against callers.

        Returns:
            The ``(output, hidden)`` tuple produced by the GRU.
        """
        step_input = self.embeddings(input_).view(1, 1, -1)
        return self.gru(step_input, hidden)

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros((1, 1, self.hidden_size), device=device)
             

In [9]:
class DecoderRNN(nn.Module):
    """Plain (attention-free) decoder: embedding -> ReLU -> GRU -> linear.

    Args:
        input_size: Size of the vocabulary; used both for the embedding
            table and for the width of the output distribution.
        hidden_size: Width of the embedding and of the GRU hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(DecoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.embeddings = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, input_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_, hidden):
        """Run one decoding step.

        The embedding is reshaped to (1, 1, -1); presumably ``input_`` holds
        a single token index -- confirm against callers.

        Returns:
            ``(output, hidden)`` where ``output`` holds log-probabilities
            over the vocabulary (log-softmax along dim 1) and ``hidden`` is
            the updated GRU state.
        """
        embed = self.embeddings(input_).view(1, 1, -1)
        # Was assigned to the misspelled name `ouput`, which left `output`
        # unbound when it was passed to the GRU on the next line.
        output = F.relu(embed)
        output, hidden = self.gru(output, hidden)
        output = self.linear(output[0])
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [11]:
class AttnDecoderRNN(nn.Module):
    """Decoder that attends over a fixed number of encoder outputs.

    Args:
        output_size: Size of the output vocabulary.
        hidden_size: Width of the embedding and of the GRU hidden state.
        max_len: Number of encoder positions the attention layer scores.
        dropout_p: Dropout probability applied to the embedded input.
    """

    def __init__(self, output_size, hidden_size, max_len=MAX_LENGTH, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_len = max_len
        self.dropout_p = dropout_p

        # Was `nn.embeddins`, which does not exist and raised AttributeError.
        self.embeddings = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(hidden_size * 2, max_len)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input_, hidden, encoder_outs):
        """Run one decoding step with attention.

        Args:
            input_: Token index tensor; its embedding is reshaped to
                (1, 1, -1), so presumably a single index -- confirm against
                callers.
            hidden: Current GRU hidden state.
            encoder_outs: Encoder outputs; the batched matmul with the
                (1, 1, max_len) attention weights requires shape
                (max_len, hidden_size).

        Returns:
            ``(out, hidden, attn_wei)``: log-probabilities over the output
            vocabulary, the updated hidden state, and the attention weights.
        """
        embed = self.embeddings(input_).view(1, 1, -1)
        embed = self.dropout(embed)

        # Score each encoder position from the embedded input and the
        # current hidden state.
        attn_wei = F.softmax(self.attn(torch.cat((embed[0], hidden[0]), 1)), dim=1)
        # Weighted sum of the encoder outputs.
        context_vec = torch.bmm(attn_wei.unsqueeze(0), encoder_outs.unsqueeze(0))

        attn_combine = self.attn_combine(torch.cat((embed[0], context_vec[0]), 1)).unsqueeze(0)
        gru_in = F.relu(attn_combine)

        output, hidden = self.gru(gru_in, hidden)
        out = F.log_softmax(self.out(output[0]), dim=1)
        return out, hidden, attn_wei

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size, device=device)
        
        
    
        

In [12]:
def indexesFromSentence(lang, sentence):
    """Map each single-space-separated token of ``sentence`` to its index.

    Raises:
        KeyError: If a token is not present in ``lang.word2index``.
    """
    indexes = []
    for token in sentence.split(' '):
        indexes.append(lang.word2index[token])
    return indexes

def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a column tensor of word indices ending in EOS.

    Returns:
        A ``torch.long`` tensor on ``device``, reshaped with ``view(-1, 1)``.
    """
    token_ids = indexesFromSentence(lang, sentence)
    token_ids.append(EOS_token)
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)

def tensorsFromPair(pair):
    """Convert a sentence pair into ``(input_tensor, target_tensor)``.

    ``pair[0]`` is encoded with the module-level ``input_lang`` and
    ``pair[1]`` with ``output_lang``.
    """
    source = tensorFromSentence(input_lang, pair[0])
    return source, tensorFromSentence(output_lang, pair[1])