In [48]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import string
import re
from pickle import dump, load
from unicodedata import normalize

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [57]:
# # don't need to run this code as it was used to generate the eng-fra.txt file
# from lxml import etree

# def tmx_to_tabbed_txt(tmx_file, output_file):
#     """Extracts English and Hawaiian translations from a TMX file and saves them as tab-separated pairs."""
    
#     # Define the XML namespace for `xml:lang`
#     namespaces = {'xml': 'http://www.w3.org/XML/1998/namespace'}
    
#     # Parse the TMX file
#     xml_tree = etree.parse(tmx_file)
#     trans_units = xml_tree.findall(".//tu")

#     pairs = []
    
#     # Open the output file for writing
#     with open(output_file, "w", encoding="utf-8") as out_file:
#         # Iterate over each translation unit
#         for trans_unit in trans_units:
#             pair = []
#             source_text = trans_unit.find(".//tuv[@xml:lang='en']/seg", namespaces)
#             target_text = trans_unit.find(".//tuv[@xml:lang='fr']/seg", namespaces)

#             # Write the tab-separated pair if both texts are available
#             if source_text is not None and target_text is not None and source_text.text and target_text.text:
#                 out_file.write(f"{source_text.text}\t{target_text.text}\n")
#                 pair.append(target_text.text)
#                 pair.append(source_text.text)
#                 pairs.append(pair)

#     return pairs

# if __name__ == "__main__":
#     tmx_file = "./data/wiki/en-fr.tmx"
#     output_file = "./data/wiki/eng-fra.txt"
#     pairs = tmx_to_tabbed_txt(tmx_file, output_file)


In [54]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2: "unk"}
        self.n_words = 3  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def trim_vocab(self, min_occurance):
        if "unk" not in self.word2count:
            self.word2count["unk"] = 0

        words_to_delete = [word for word, count in self.word2count.items() if count < min_occurance and word != "unk"]

        for word in words_to_delete:
            self.word2count["unk"] += self.word2count[word]
            del self.word2index[word]
            del self.word2count[word]

        self.index2word = {0: "SOS", 1: "EOS", 2: "unk"}
        self.n_words = 3  # Count SOS and EOS

        for word in self.word2count.keys():
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1

In [4]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()

In [5]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/wiki/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [6]:
# # Since there are a lot of example sentences and we want to train something quickly,
# # we'll trim the data set to only relatively short and simple sentences.
# # Here the maximum length is 10 words (that includes ending punctuation) and 
# # we're filtering to sentences that translate to the form "I am" or "He is" etc. (accounting for apostrophes replaced earlier).
# MAX_LENGTH = 12

# eng_prefixes = [
#     "i am ", "i m ",
#     "he is", "he s ",
#     "she is", "she s ",
#     "you are", "you re ",
#     "we are", "we re ",
#     "they are", "they re ", "I don t", "Do you", "I want", "Are you", "I have", "I think",
#        "I can t", "I was", "He is", "I m not", "This is", "I just", "I didn t",
#        "I am", "I thought", "I know", "Tom is", "I had", "Did you", "Have you",
#        "Can you", "He was", "You don t", "I d like", "It was", "You should",
#        "Would you", "I like", "It is", "She is", "You can t", "He has",
#        "What do", "If you", "I need", "No one", "You are", "You have",
#        "I feel", "I really", "Why don t", "I hope", "I will", "We have",
#        "You re not", "You re very", "She was", "I love", "You must", "I can"]
# eng_prefixes = (map(lambda x: x.lower(), eng_prefixes))
# eng_prefixes = tuple(eng_prefixes)

# def filterPair(p):
#     return len(p[0].split(' ')) < MAX_LENGTH and \
#         len(p[1].split(' ')) < MAX_LENGTH and \
#         p[1].startswith(eng_prefixes)


# def filterPairs(pairs):
#     return [pair for pair in pairs if filterPair(pair)]

In [59]:
# Read text file and split into lines, split lines into pairs
# Normalize text, filter by length and content
# Make word lists from sentences in pairs
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))

    # trim down the data according to specification above
    # pairs = filterPairs(pairs)
    # print("Trimmed to %s sentence pairs" % len(pairs))

    cleaned_pairs = [pair for pair in pairs if len(pair[0]) > 0 and len(pair[1]) > 0]
    print(f'After cleaning, there are {len(cleaned_pairs)} pairs of sentence.')
    
    print("Counting words...")
    for pair in cleaned_pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    # remove all words with a frequency below a threshold
    print("Mark all OOV with 'unk' for all lines")
    input_lang.trim_vocab(2)
    output_lang.trim_vocab(2)

    # update pairs with 'unk'
    new_pairs = []
    for pair in cleaned_pairs:
        new_pair = []
        input_tokens = []
        output_tokens = []
        for word in pair[0].split(' '):
            if word in input_lang.word2index:
                input_tokens.append(word)
            else:
                input_tokens.append('unk')

        for word in pair[1].split(' '):
            if word in output_lang.word2index:
                output_tokens.append(word)
            else:
                output_tokens.append('unk')

        new_pair.append(" ".join(input_tokens))
        new_pair.append(" ".join(output_tokens))
        new_pairs.append(new_pair)

    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)

    return input_lang, output_lang, new_pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
# print(random.choice(pairs))

Reading lines...
Read 1365840 sentence pairs
Counting words...
Mark all OOV with 'unk' for all lines
Counted words:
fra 303770
eng 278545


In [60]:
len(pairs)

1365840

In [61]:
for pair in pairs:
    if len(pair[0]) == 0:
        print(pair)

['', 'dutton mountain digital raster quadrangle map new york state department of transportation']
['', 'theodore roosevelt centennial site archived from the original on july']
['', 'fed display requires a vacuum to operate so the display tube has to be sealed and mechanically robust however since the distance between the emitters and phosphors is quite small generally a few millimeters the screen can be mechanically reinforced by placing spacer strips or posts between the front and back face of the tube']
['', 'they started work as newsies because their father suffered an accident at work resulting in the termination of his employment seeing young les as an opportunity to sell more papers jack offers to help the boys meanwhile the publisher of the new york world joseph pulitzer increases the cost of the newspapers to the delivery boys so as to outsell his competitors the bottom line']
['', 'on march brown claimed that he saw arsenal midfielder and captain cesc fabregas spit at the feet

In [45]:
output_lang.word2count['philosopher']

914

In [46]:
input_lang.word2count['philosophe']

856

In [63]:
!git --version

git version 2.39.3 (Apple Git-146)
