In [1]:
%matplotlib inline

In [2]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
    
    def addWords(self, words):
        for word in words:
            addWord(word)

The files are all in Unicode. To simplify, we strip accents to turn Unicode
characters into (mostly) plain ASCII, make everything lowercase, and trim
most punctuation.




In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [16]:
def readLangs(articles_file, titles_file):
    print("Reading lines...")

    # Read the file and split into lines
    articles_lines = open('data/'+articles_file, encoding='utf-8').read().strip().split('\n')
    titles_lines = open('data/'+titles_file, encoding='utf-8').read().strip().split('\n')

    # Split every line into pairs and normalize
    articles_lines = [[normalizeString(s) for s in articles_lines]]
    titles_lines = [[normalizeString(s) for s in titles_lines]]

    print(articles_lines[0][0])
    

    #return input_lang, output_lang, pairs

In [17]:
# Smoke-test reading the small article/title files; the trailing ';' keeps
# the cell from displaying any return value.
readLangs(articles_file="mini.art.txt", titles_file="mini.tit.txt");

Reading lines...
['australia s current account deficit shrunk by a record . billion dollars lrb . billion us rrb in the june quarter due to soaring commodity prices figures released monday showed .', 'at least two people were killed in a suspected bomb attack on a passenger bus in the strife torn southern philippines on monday the military said .', 'australian shares closed down . percent monday following a weak lead from the united states and lower commodity prices dealers said .', 'south korea s nuclear envoy kim sook urged north korea monday to restart work to disable its nuclear plants and stop its typical brinkmanship in negotiations .', 'south korea on monday announced sweeping tax reforms including income and corporate tax cuts to boost growth by stimulating sluggish private consumption and business investment .', 'taiwan share prices closed down . percent monday on wall street weakness and lacklustre interim earnings from electronics manufacturing giant hon hai dealers said .',

In [7]:
# NOTE(review): `pairs` is not assigned by any cell above this one (readLangs
# does not set it; it is first assigned by the prepareData cell further down),
# so this only runs with leftover kernel state and fails on Restart & Run All.
# The saved output ('run !') also looks like it came from a different corpus.
pairs[1][0]

'run !'

Since there are a *lot* of example sentences and we want to train
something quickly, we'll trim the data set to only relatively short and
simple sentences. Here both sentences of a pair must have fewer than
`MAX_LENGTH` = 10 words (ending punctuation counts as a word). The second
(target) sentence must also start with a form like "I am" or "He is"
(written "i m", "he s", etc., because apostrophes were replaced by spaces
earlier).




In [8]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

The full process for preparing the data is:

-  Read the text files, split them into lines, and pair the lines up
-  Normalize the text, then filter pairs by length and content
-  Build word lists (vocabularies) from the sentences in the kept pairs




In [None]:
def prepareData(lang1, lang2, reverse=False):
    """Read, filter and index the sentence pairs.

    Expects ``readLangs(lang1, lang2, reverse)`` to return
    ``(input_lang, output_lang, pairs)``. It keeps only the pairs accepted by
    ``filterPairs`` and adds every kept sentence to the matching ``Lang``
    vocabulary, printing progress along the way.

    Returns:
        ``(input_lang, output_lang, pairs)`` with the filtered ``pairs``.
    """
    input_lang, output_lang, raw_pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(raw_pairs))

    kept_pairs = filterPairs(raw_pairs)
    print("Trimmed to %s sentence pairs" % len(kept_pairs))

    print("Counting words...")
    for pair in kept_pairs:
        source_sentence, target_sentence = pair[0], pair[1]
        input_lang.addSentence(source_sentence)
        output_lang.addSentence(target_sentence)

    print("Counted words:")
    for lang in (input_lang, output_lang):
        print(lang.name, lang.n_words)
    return input_lang, output_lang, kept_pairs


# NOTE(review): 'eng' and 'fra' are passed straight to readLangs, which opens
# 'data/' + each name. They look like leftovers from a translation setup; the
# earlier readLangs cell reads 'mini.art.txt' / 'mini.tit.txt'. Confirm the
# intended files.
input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
# NOTE(review): filterPairs keeps only targets starting with eng_prefixes
# ("i am ", "he s ", ...), which may leave no pairs for this data;
# random.choice raises IndexError on an empty list.
print(random.choice(pairs))