### RUN ALL THE CELLS SEQUENTIALLy
#### Train the model for atleast 20000-30000 iterations (3 epochs)

In [95]:
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

### Class to build the vocabulary and track the index

In [57]:
SOS_token = 0


class Scrambled:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS"}
        self.n_words = 1  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [58]:
def readLangs():
    print("Reading lines...")

    with open('../data/train_original.txt') as f:
        lines_original = f.read().splitlines()
    
    with open('../data/train_scrambled.txt') as f:
        lines_scrambled = f.read().splitlines()

    pairs = list(zip(lines_scrambled,lines_original))
    scrambled = Scrambled('train_vocab')  # x and y vocab remains the same

    return scrambled, pairs

#### Prepare the data by creating pairs

In [59]:
def prepareData():
    scrambled, pairs = readLangs()
    print("Read %s sentence pairs" % len(pairs))
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        scrambled.addSentence(pair[0])
        
    print("Counted words:")
    print(scrambled.name, scrambled.n_words)
    return scrambled, pairs


scrambled, pairs = prepareData()
print(random.choice(pairs))

Reading lines...
Read 10000 sentence pairs
Trimmed to 10000 sentence pairs
Counting words...
Counted words:
train_vocab 21421
('effectively. also function can States, Member more 25 than with future, the of EU that the ensure to have we now EU; the enlarged have We', 'We have enlarged the EU; we now have to ensure that the EU of the future, with more than 25 Member States, can also function effectively.')


In [60]:
class BagOfEmbeddings(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BagOfEmbeddings, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, hidden_size)
    def forward(self, input):
        embedded = self.embedding(input)
        linear_out = self.linear(embedded)          
        return linear_out

In [61]:
with open('../data/train_original.txt') as f:
    lines_original = f.read().splitlines()
MAX_LENGTH = max([len(line.split()) for line in lines_original])
MAX_LENGTH = MAX_LENGTH + 1

#### Decoder uses Bahdanaus Attention mechanism to decode each word

In [62]:
class AttentionLSTMDecoder(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttentionLSTMDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0][0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
 

        output = torch.cat((embedded[0], attn_applied[0]), 1)
 
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        
        output, hidden = self.lstm(output, hidden)
        

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size),torch.zeros(1, 1, self.hidden_size)

In [63]:
def indexesFromSentence(scrambled, sentence):
    return [scrambled.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(scrambled, sentence):
    indexes = indexesFromSentence(scrambled, sentence)
    return torch.tensor(indexes, dtype=torch.long).view(-1, 1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(scrambled, pair[0])
    target_tensor = tensorFromSentence(scrambled, pair[1])
    return (input_tensor, target_tensor)

#### Teacher forcing concept used to decode words. Results might improve if we follow scheduled sampling technique

In [64]:

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):


    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size)


    loss = 0

    final_encoder_outputs = encoder(input_tensor.view(1,input_tensor.shape[0]).squeeze(0))


    for ei in range(input_length):
        encoder_outputs[ei] = final_encoder_outputs[ei]    
    
    
    decoder_input = torch.tensor([[SOS_token]])

    decoder_hidden = decoder.initHidden()
  


    for di in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_tensor[di])
        decoder_input = target_tensor[di]  # Teacher forcing


    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [65]:
import time
import math


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [66]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        print("EPOCH:",iter)
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            torch.save(encoder, "encoder_self_attn"+str(iter)+".pth")
            torch.save(decoder, "decoder_self_attn"+str(iter)+".pth")
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

In [None]:
hidden_size = 256
encoder1 = BagOfEmbeddings(scrambled.n_words, hidden_size)
attn_decoder1 = AttentionLSTMDecoder(hidden_size, scrambled.n_words, dropout_p=0.1)

trainIters(encoder1, attn_decoder1, 30000, print_every=5000)

In [103]:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence_test(scrambled, sentence)
    
        
        idx = input_tensor.view(input_tensor.shape[0])
  
 
        
        input_length = input_tensor.size()[0]


        encoder_outputs = torch.zeros(max_length, encoder.hidden_size)
        final_encoder_outputs = encoder(input_tensor.view(1,input_tensor.shape[0]).squeeze(0))
        for ei in range(input_length):
            encoder_outputs[ei] = final_encoder_outputs[ei]



        decoder_input = torch.tensor([[SOS_token]])  # SOS

        decoder_hidden = decoder.initHidden()

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        ## FOCUS only on source sentence
        for di in range(len(input_tensor)-1):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            temp = decoder_output.squeeze(0)
            temp = temp[idx]
        
            
   
            

            topv, topi = temp.data.topk(1)
            
            top_index = topi.item()
            topi = idx[topi]
        
            idx = torch.cat([idx[0:top_index], idx[top_index+1:]])
            
            decoded_words.append(scrambled.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words

In [69]:
def tensorFromSentence_test(scrambled, sentence):
    indexes = indexesFromSentence_test(scrambled, sentence)
    #indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long).view(-1, 1)

def indexesFromSentence_test(scrambled, sentence):
    idx = []
    for word in sentence.split(' '):
        if word in scrambled.word2index.keys():
            idx.append(scrambled.word2index[word])

    return idx

In [70]:
def readScrambled_test():
    print("Reading lines...")

    with open('../data/test_scrambled.txt') as f:
        lines_test = f.read().splitlines()
    
 
        

   

    return lines_test

In [99]:
lines_test = readScrambled_test()

Reading lines...


In [72]:
def evaluateRandomly_test(encoder, decoder, n=1):
    for i in range(n):
        pair = random.choice(lines_test)
        print('>', pair)

        output_words, attentions = evaluate(encoder, decoder, pair)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

### Run the below cell to sample from test set

In [94]:
encoder1.eval()
attn_decoder1.eval()
evaluateRandomly_test(encoder1, attn_decoder1)

> that. support to ready I am itself In
< itself that. I am to In support



### Run below code to get result for all the test set

In [104]:
def evaluate_all_test(encoder,decoder,lines_test):
    unscrambled = []
    for line in lines_test:
    
        output_words = evaluate(encoder, decoder, line)
      
        output_sentence = ' '.join(output_words)
        unscrambled.append(output_sentence)
    
    return unscrambled

In [105]:
out = evaluate_all_test(encoder1,attn_decoder1,lines_test)

### Here are the results

In [106]:
print(out)

