# Imports + Prepare Data

In [1]:
%matplotlib inline

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

!wget https://download.pytorch.org/tutorial/data.zip

!unzip data.zip

--2021-07-22 14:37:12--  https://download.pytorch.org/tutorial/data.zip
Resolving download.pytorch.org (download.pytorch.org)... 13.224.159.21, 13.224.159.119, 13.224.159.78, ...
Connecting to download.pytorch.org (download.pytorch.org)|13.224.159.21|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2882130 (2.7M) [application/zip]
Saving to: ‘data.zip’


2021-07-22 14:37:13 (33.4 MB/s) - ‘data.zip’ saved [2882130/2882130]

Archive:  data.zip
   creating: data/
  inflating: data/eng-fra.txt        
   creating: data/names/
  inflating: data/names/Arabic.txt   
  inflating: data/names/Chinese.txt  
  inflating: data/names/Czech.txt    
  inflating: data/names/Dutch.txt    
  inflating: data/names/English.txt  
  inflating: data/names/French.txt   
  inflating: data/names/German.txt   
  inflating: data/names/Greek.txt    
  inflating: data/names/Irish.txt    
  inflating: data/names/Italian.txt  
  inflating: data/names/Japanese.txt  
  inflating: data/names/Kore

In [2]:
SOS_token = 0
EOS_token = 1


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('eng', 'fra', True)

Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4345
eng 2803


# Get 1 Input, Output Pair

In [3]:
sample = random.choice(pairs)
sample

['il s en met plein les poches .', 'he s raking it in .']

In [4]:
input_sentence, target_sentence = sample

In [5]:
def indexes_from_sentence(lang, sentence):
  return [lang.word2index[word] for word in sentence.split(" ")]

def tensor_from_sentence(lang, sentence):
  indexes = indexes_from_sentence(lang, sentence)
  indexes.append(EOS_token)
  return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensor_from_pair(pair):
  input_tensor = tensor_from_sentence(input_lang, pair[0])
  output_tensor = tensor_from_sentence(output_lang, pair[1])
  return input_tensor, output_tensor 

input_tensor, output_tensor = tensor_from_pair(sample)
input_tensor.shape, output_tensor.shape

(torch.Size([9, 1]), torch.Size([7, 1]))

In [6]:
input_sentence

'il s en met plein les poches .'

# Encoder

In [7]:
input_size = input_lang.n_words
hidden_size = 256
encoder_embedding = nn.Embedding(input_size, hidden_size).to(device)
encoder_lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1).to(device)

encoder_outputs = torch.zeros(MAX_LENGTH, 256, device=device)
encoder_hidden = (
    torch.zeros(1, 1, 256).to(device),
    torch.zeros(1, 1, 256).to(device)
) # hidden state, cell state

for i in range(input_tensor.size()[0]):
  embedded_input = encoder_embedding(input_tensor[i].view(-1, 1))
  output, encoder_hidden = encoder_lstm(embedded_input, encoder_hidden)
  encoder_outputs[i] += output[0,0]


In [8]:
encoder_outputs.shape, encoder_hidden[0].shape, encoder_hidden[1].shape

(torch.Size([10, 256]), torch.Size([1, 1, 256]), torch.Size([1, 1, 256]))

# Decoder

In [9]:
decoded_words = []

output_size = output_lang.n_words
decoder_embedding = nn.Embedding(output_size, 256).to(device)
attn_weight_layer = nn.Linear(256*2, MAX_LENGTH).to(device)
input_to_decoder_lstm_layer = nn.Linear(256 * 2, 256).to(device)
decoder_lstm = nn.LSTM(256, 256).to(device)
output_word_layer = nn.Linear(256, output_lang.n_words).to(device)

# these will be overwritten below, used for input to the decoder
decoder_hidden = None
top_index = None

# running decoder twice, no teacher forcing
for i in range(2):
  if i == 0:
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden
  else:
    decoder_input = torch.tensor([[top_index.item()]], device=device)

  embedded_output = decoder_embedding(decoder_input)
  attn_weights = attn_weight_layer(torch.cat((embedded_output[0], decoder_hidden[0][0]), 1))
  attn_weights = F.softmax(attn_weights, dim=1)
  attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
  
  input_to_decoder_lstm = input_to_decoder_lstm_layer(torch.cat((embedded_output[0], attn_applied[0]), 1))
  input_to_decoder_lstm = input_to_decoder_lstm.unsqueeze(0)

  output, decoder_hidden = decoder_lstm(input_to_decoder_lstm, decoder_hidden)
  output = F.relu(output)
  output = F.softmax(output_word_layer(output[0]), dim = 1)

  top_value, top_index = output.data.topk(1)
  decoded_word = output_lang.index2word[top_index.item()]
  decoded_words.append(decoded_word)

print(decoded_words)

['insecure', 'sisters']
