<a href="https://colab.research.google.com/github/cerasole/ml4hep/blob/main/RNNs/torch_Encoder_Decoder_for_Neural_Machine_Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Machine translation using RNNs

The theory of machine translation using RNNs is described, for instance, in Chapter 16 of https://github.com/ageron/handson-ml3.

Basics of the algorithm
- Input: sentence in a given language,
- Output: translation of the input sentence in a different language.

The architecture is an encoder-decoder system.
- The input sentence is fed to the encoder, which transforms it into a low-dimensional latent representation.
- The decoder transforms the latent representation into the output sentence.

In [56]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata

import string
import re
import random

import numpy as np

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
device

device(type='cuda')

We define a class for handling a language and its vocabulary.

This class needs to
- register the words of the sentences passed to it, assigning an integer index to each new word and counting its occurrences.

In the preprocessing, we will normalize the sentences to a standard form.

In [4]:
SOS_token = 0   # Start-of-Sentence
EOS_token = 1   # End-of-Sentence

class Lang:

    def __init__(self, name):
        self.name = name  # name of the language
        self.word2index = {}  # word -> index assigned when the word was first added
        self.word2count = {}  # word -> number of times the word was added
        self.index2word = {0: "SOS", 1: "EOS"}  # index -> word (inverse of word2index, plus the two special tokens)
        self.n_words = 2  # number of distinct words in the vocabulary, including SOS and EOS

    def addWord(self, word):
        # When we add a word, we first check whether it is already in the vocabulary, i.e. in self.word2index.
        # If it is not present, we
        #  - add it to self.word2index, using the current self.n_words as its index,
        #  - add it to self.word2count with a count of 1,
        #  - add it to self.index2word under the same index,
        #  - increase self.n_words, the number of distinct words, by 1.
        # If it is already present, we
        #  - increase the corresponding self.word2count entry by 1.
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def addSentence(self, sentence):
        # Add every word of the sentence, splitting on single spaces.
        # The sentence should already be normalized, otherwise punctuation
        # stays attached to the words (e.g. "noise?").
        # We will take care of this in the preprocessing of the sentences.
        for word in sentence.split(' '):
            self.addWord(word)


In [7]:
lang = Lang("prova")
print (lang.word2count, lang.word2index, lang.index2word, lang.n_words)
lang.addSentence("could you please stop the noise?")
print (lang.word2count, lang.word2index, lang.index2word, lang.n_words)

{} {} {0: 'SOS', 1: 'EOS'} 2
{'could': 1, 'you': 1, 'please': 1, 'stop': 1, 'the': 1, 'noise?': 1} {'could': 2, 'you': 3, 'please': 4, 'stop': 5, 'the': 6, 'noise?': 7} {0: 'SOS', 1: 'EOS', 2: 'could', 3: 'you', 4: 'please', 5: 'stop', 6: 'the', 7: 'noise?'} 8
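
Once the vocabulary is filled, `word2index` can turn a sentence into the list of integers that a network consumes, and `index2word` can turn it back. The following is a minimal sketch of that round trip. The helper names `sentence_to_indices` and `indices_to_sentence` are hypothetical (they are not part of the class above), and a compact copy of `Lang` is repeated only so that the snippet runs on its own.

```python
SOS_token = 0  # Start-of-Sentence
EOS_token = 1  # End-of-Sentence


class Lang:
    """Compact copy of the vocabulary class defined above."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.n_words = 2

    def addSentence(self, sentence):
        for word in sentence.split(" "):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.word2count[word] = 1
                self.index2word[self.n_words] = word
                self.n_words += 1
            else:
                self.word2count[word] += 1


def sentence_to_indices(lang, sentence):
    # Hypothetical helper: look up every word and append the EOS marker.
    # Raises KeyError for words that were never added to the vocabulary.
    return [lang.word2index[word] for word in sentence.split(" ")] + [EOS_token]


def indices_to_sentence(lang, indices):
    # Hypothetical helper: inverse mapping, dropping the special tokens.
    words = [lang.index2word[i] for i in indices if i not in (SOS_token, EOS_token)]
    return " ".join(words)


lang = Lang("demo")
lang.addSentence("could you please stop the noise ?")
indices = sentence_to_indices(lang, "please stop the noise ?")
print(indices)
print(indices_to_sentence(lang, indices))
```

Because the punctuation mark is separated by a space here, "noise" and "?" get their own entries instead of the fused "noise?" seen above.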


In the next cells, we look at how to standardize the input sentences:
- convert everything to lower case,
- strip accents and handle punctuation and other special characters,
- take abbreviations into account.

In [70]:
# https://stackoverflow.com/a/518232/2809427
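# Turn a Unicode string into (mostly) plain ASCII, i.e.
# - decompose each accented letter into base letter + combining mark (NFD normalization),
# - drop the combining marks (Unicode category 'Mn'), keeping the base letter.
# Characters without such a decomposition are left unchanged.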
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

In [16]:
unicodeToAscii("ciàò!")

'ciao!'
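
To see why filtering on the category `'Mn'` removes accents, it helps to look at what NFD normalization produces. The sketch below uses only the standard `unicodedata` module; the sample strings are illustrative and not taken from the dataset.

```python
import unicodedata


def unicodeToAscii(s):
    # Same function as above, repeated so the snippet is self-contained.
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


# NFD splits a precomposed accented letter into base letter + combining mark.
decomposed = unicodedata.normalize("NFD", "é")
for c in decomposed:
    print(repr(c), unicodedata.category(c), unicodedata.name(c))

# Letters without a canonical decomposition pass through unchanged.
print(unicodeToAscii("garçon"), unicodeToAscii("straße"))
```

The first loop prints two code points: the base letter (category `Ll`) and the combining acute accent (category `Mn`), which is the one the filter discards.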

In [69]:
def normalizeString(s):
    # s.lower() converts the string to lowercase.
    # s.strip() removes all leading and trailing whitespace.
    # unicodeToAscii(s) strips the accents.
    # re.sub(pattern, repl, string) returns the string obtained by replacing the
    #  non-overlapping occurrences of pattern in string with the replacement repl.
    # First re.sub:
    #  - The parentheses () define a capture group, so that the matched text can be reused in the replacement.
    #    In the replacement, r" \1" means "a space followed by whatever group 1 matched".
    #    The r prefix (raw string) is needed so that Python hands the backslash of \1 over to re
    #    instead of interpreting "\1" itself as an escape sequence.
    #  - The brackets [] define a character class: the pattern matches any one of ".", "!" and "?".
    #  - So this command transforms "." into " .", "!" into " !", "?" into " ?" and "??" into " ? ?".
    # Second re.sub:
    #  - Again, [] defines a character class.
    #  - The leading ^ negates the class: it matches every character that is *not* listed in the [].
    #    a-zA-Z excludes lowercase and uppercase letters (even though everything is already lowercase),
    #    .!? excludes the three characters ".", "!" and "?".
    #    So every character that is neither a letter nor one of . ! ? (digits included) is replaced by a single space.
    #  - Without the trailing +, a string like "***" would become "   ", which is a problem
    #    because later we will call .split(" "). With +, a run of one or more consecutive
    #    matching characters is replaced by just one space.
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [21]:
"Ciao!".lower()

'ciao!'

In [19]:
" ciao ! ".strip()

'ciao !'

In [30]:
re.sub(r"([?])", r" \1", "Ciao??")

'Ciao ? ?'

In [36]:
re.sub(r"[^a-zA-Z.!?]+", r" ", "a ?*^*:-_=+"), re.sub(r"[^a-zA-Z.!?]", r" ", "a ?*^*:-_=+")

('a ? ', 'a ?        ')
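
Putting the individual steps together, the sketch below runs the whole normalization on two illustrative sentences (they are not taken from the dataset). It also shows one corner case: when the input ends with a character that the last regex turns into a space, the result keeps a trailing space, and `split(" ")` then yields an empty token. The extra `.strip()` at the end is an assumption of this sketch, not something the function above does.

```python
import re
import unicodedata


def unicodeToAscii(s):
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalizeString(s):
    # Same steps as the function above.
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


clean = normalizeString("Ça va très bien, merci !")
print(repr(clean))

# A trailing non-letter character leaves a trailing space behind ...
messy = normalizeString("Wait... (what?)")
print(repr(messy))
# ... which split(" ") turns into an empty token.
print(messy.split(" "))
# Assumption: stripping once more (not done in the function above) avoids it.
print(messy.strip().split(" "))
```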

### Download the data

In [44]:
!wget https://download.pytorch.org/tutorial/data.zip
!unzip data.zip

--2024-08-26 09:14:19--  https://download.pytorch.org/tutorial/data.zip
Resolving download.pytorch.org (download.pytorch.org)... 18.238.238.114, 18.238.238.82, 18.238.238.23, ...
Connecting to download.pytorch.org (download.pytorch.org)|18.238.238.114|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2882130 (2.7M) [application/zip]
Saving to: ‘data.zip’


2024-08-26 09:14:19 (58.7 MB/s) - ‘data.zip’ saved [2882130/2882130]

Archive:  data.zip
   creating: data/
  inflating: data/eng-fra.txt        
   creating: data/names/
  inflating: data/names/Arabic.txt   
  inflating: data/names/Chinese.txt  
  inflating: data/names/Czech.txt    
  inflating: data/names/Dutch.txt    
  inflating: data/names/English.txt  
  inflating: data/names/French.txt   
  inflating: data/names/German.txt   
  inflating: data/names/Greek.txt    
  inflating: data/names/Irish.txt    
  inflating: data/names/Italian.txt  
  inflating: data/names/Japanese.txt  
  inflating: data/names/Kor

In [95]:
filename = "data/eng-fra.txt"
lines = open(filename, encoding = "utf-8").read().strip().split("\n")
pairs = [
    [
        normalizeString(s) for s in l.split("\t")
    ] for l in lines
]

In [96]:
len(pairs)

135842

In [97]:
pairs[0]

['go .', 'va !']

In [89]:
list(reversed(pairs[0]))

['va !', 'go .']

In [77]:
def reverse_pairs(pairs):
    return [list(reversed(pair)) for pair in pairs]

### Set the languages for the translator

In [115]:
############# Command for an English to French translator #############
lang1 = "eng"
lang2 = "fra"
english_index = 0
if pairs[0][0] != "go .":
    pairs = reverse_pairs(pairs)
input_lang, output_lang = Lang(lang1), Lang(lang2)

In [136]:
############# Command for a French to English translator #############
lang1 = "fra"
lang2 = "eng"
english_index = 1
if pairs[0][0] == "go .":
    pairs = reverse_pairs(pairs)
input_lang, output_lang = Lang(lang1), Lang(lang2)
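
The two cells above decide whether to reverse `pairs` by inspecting `pairs[0][0]`, so the result depends on which cell was run last. A possible alternative, sketched here under the assumption that the list read from the file is kept untouched in `[english, french]` order, is to derive the ordered pairs and `english_index` together. The name `select_direction` is hypothetical.

```python
def select_direction(eng_fra_pairs, source_lang):
    """Hypothetical helper. eng_fra_pairs holds [english, french] pairs.

    Returns (pairs ordered as [source, target], index of the English sentence).
    """
    if source_lang == "eng":
        return [list(pair) for pair in eng_fra_pairs], 0
    if source_lang == "fra":
        return [list(reversed(pair)) for pair in eng_fra_pairs], 1
    raise ValueError(f"unknown source language: {source_lang!r}")


raw_pairs = [["go .", "va !"], ["i am cold .", "j ai froid ."]]

fra_eng_pairs, english_index = select_direction(raw_pairs, "fra")
print(fra_eng_pairs[0], english_index)
print(raw_pairs[0])  # the original list is left untouched
```

Since nothing is modified in place, calling the helper twice, or in a different order, always gives the same result.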

Here we apply a **drastic** cut to the dataset.
Of the roughly 135k sentence pairs, we keep only those that satisfy both of these requirements:
- Each of the two sentences has fewer than MAX_LENGTH = 10 space-separated tokens (words and punctuation marks).
- The English sentence begins with a personal pronoun followed by the present simple of "to be", either in full ("he is") or contracted ("he s" after normalization).

In [122]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is ", "he s ",   # the trailing blank space matters: without it, "he stopped" would also match "he s"
    "she is ", "she s ",
    "you are ", "you re ",
    "we are ", "we re ",
    "they are ", "they re "
)

def filterPair(p, english_index):
    # Cut on the length (number of space-separated tokens) of the sentences in the two languages
    short_enough = all(len(p[i].split(" ")) < MAX_LENGTH for i in range(2))
    # Cut requiring the English sentence to start with one of the prefixes in eng_prefixes
    # (str.startswith accepts a tuple and returns True if any of its elements matches)
    return short_enough and p[english_index].startswith(eng_prefixes)

def filterPairs(pairs, english_index):
    return [pair for pair in pairs if filterPair(pair, english_index)]

cut_pairs = filterPairs(pairs, english_index)

In [132]:
print (f"From {len(pairs)} sentences, we ended up with {len(cut_pairs)} sentences after the cut.")

From 135842 sentences, we ended up with 10522 sentences after the cut.
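
To make the two cuts concrete, the sketch below applies the filter to a few hand-made, already normalized `[english, french]` pairs. The sample sentences are illustrative only and are not taken from the dataset.

```python
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is ", "he s ",
    "she is ", "she s ",
    "you are ", "you re ",
    "we are ", "we re ",
    "they are ", "they re ",
)


def filterPair(p, english_index):
    # Same two cuts as above.
    short_enough = all(len(p[i].split(" ")) < MAX_LENGTH for i in range(2))
    return short_enough and p[english_index].startswith(eng_prefixes)


# Hand-made [english, french] pairs, already normalized (illustrative only).
samples = [
    ["i m cold .", "j ai froid ."],          # kept: starts with "i m "
    ["he stopped .", "il s est arrete ."],   # dropped: "he s" is not followed by a space
    ["go .", "va !"],                        # dropped: no matching prefix
    ["we are " + "very " * 8 + "late .", "nous sommes tres en retard ."],  # dropped: too long
]

for pair in samples:
    print(filterPair(pair, english_index=0), pair[0])
```

Only the first pair survives, which gives an idea of why the cut reduces the dataset so strongly.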


In [138]:
pairs[0][0]

'va !'

In [139]:
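### Fill the Lang instances
# cut_pairs must have been built after the translation direction was set;
# if the direction is changed afterwards, re-run the filtering cell first, otherwise
# pair[0] and pair[1] are in the opposite order with respect to input_lang and output_lang.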

if input_lang.n_words == 2:  # if not already filled
    for pair in cut_pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

In [140]:
input_lang.name, input_lang.n_words

('fra', 2802)

In [141]:
output_lang.name, output_lang.n_words

('eng', 4341)