In [1]:
import random
import re
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import unicodedata
from torchtext.data.metrics import bleu_score

import os

In [2]:
import nltk
from nltk.translate.bleu_score import corpus_bleu

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Data Preparation

In [4]:
sos_token = 0
eos_token = 1

# helper class to assign word to index and vice versa
# this class essentially builds a vocab for each language
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {0: 'SOS',
                         1: 'EOS'}
        self.num_words = 2 # initialize with the sos and eos tokens
        
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.num_words
            self.word2count[word] = 1
            self.idx2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

In [5]:
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

In [6]:
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [7]:
def readLangPair(source_lang, target_lang):
    print("Reading Sentence Pairs")
    
    # a backslash tells the interpreter to extend the logical line to 
    # the next physical line
    lines = open('data/%s-%s.txt' % (source_lang, target_lang), encoding='utf-8').\
        read().strip().split('\n')
    
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    input_lang = Lang(source_lang)
    output_lang = Lang(target_lang)
    
    return input_lang, output_lang, pairs

In [8]:
max_length = 10

eng_prefixes = (
    "i am", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(pair):
    return len(pair[0].split(' ')) < max_length and \
        len(pair[1].split(' ')) < max_length \
        and pair[0].startswith(eng_prefixes)

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
    

In [9]:
# readLangPair('eng', 'fra')

In [10]:
def prepareData(source_lang, target_lang, reverse=False):
    input_lang, output_lang, pairs = readLangPair(source_lang, target_lang)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Triimmed to %s sentence pairs" % len(pairs))
    print("Counting Words")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted Words: ")
    print(input_lang.name, input_lang.num_words)
    print(output_lang.name, output_lang.num_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra')
print(random.choice(pairs))

Reading Sentence Pairs
Read 135842 sentence pairs
Triimmed to 10601 sentence pairs
Counting Words
Counted Words: 
eng 2803
fra 4346
['you re considerate .', 'vous etes prevenant .']


### Encoder RNN

In [11]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        
    def forward(self, input, hidden):
        # ensure the first two dimensions are 1x1, -1 means infer this dimension
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        # the gru will output the current prediction and the next hidden state
        output, hidden = self.gru(output, hidden)
        return output, hidden
    
    def initHidden(self):
        # the initial hidden state
        return torch.zeros(1, 1, self.hidden_size, device=device)
        

### Decoder RNN with Attention

In [12]:
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, drop_rate=0.1, max_length=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.drop_rate = drop_rate
        self.max_length = max_length
        
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.drop_rate)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        
    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        
        
        # what is happening here ?
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
         
        output = F.log_softmax(self.out(output[0]), dim=1) # substituted for
                                                  # softmax
        return output, hidden, attn_weights
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
        
        

### Training Process

In [13]:
def indexesFromSentence(lang, sentence):
    return [lang.word2idx[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(eos_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [14]:
teacher_forcing_ratio = 0.5

def train(input, label, encoder, decoder, encoder_optimizer, decoder_optimizer, \
          criterion, max_length=10):
    
    encoder_hidden = encoder.initHidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    # encoder.train()
    # decoder.train()
    
    input_length = input.size(0)
    target_length = label.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0
    
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
                        input[ei], encoder_hidden)  
        encoder_outputs[ei] = encoder_output[0, 0] # why?
        
    decoder_input = torch.tensor([[sos_token]], device=device)
    decoder_hidden = encoder_hidden
    
    
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    if use_teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(\
                             decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, label[di])
            decoder_input = label[di] # use teacher forcing to
            # feed the correct answer instead of decoder output
            
    else: # use decoder predictions and not correct output
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(\
                            decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach() # remove from computation graph
            
            loss += criterion(decoder_output, label[di])
            if decoder_input.item() == eos_token:
                break
                
    loss.backward()
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.item() / target_length

In [15]:
def trainIters(encoder, decoder, n_iters, lr=0.01):
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=lr)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=lr)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                     for i in range(n_iters)]
    criterion = nn.NLLLoss()
    
    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter-1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]
        
        # prev_loss = 0
        loss = train(input_tensor, target_tensor, encoder,
                    decoder, encoder_optimizer, decoder_optimizer,
                    criterion)   
        # print(loss)
        # if(iter == 1):
        #   prev_loss = loss
        # if(loss < prev_loss):
        #   torch.save(encoder.state_dict(), 'encoder.pth')   
        #   torch.save(decoder.state_dict(), 'decoder.pth')   
        #else:
        #pass
    torch.save(encoder.state_dict(), './model_weights/encoder.pth')   
    torch.save(decoder.state_dict(), './model_weights/decoder.pth')  
          

### Evaluation 

In [16]:
def evaluate(encoder, decoder, sentence, max_length=10):
    # encoder.eval()
    # decoder.eval()
    # encoder = encoder.to(device)
    # decoder = decoder.to(device)
    with torch.no_grad():
        input = tensorFromSentence(input_lang, sentence)
        input_length = input.size()[0]
        encoder_hidden = encoder.initHidden()
        
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        
        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input[ei],
                                             encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]
            
        decoder_input = torch.tensor([[sos_token]], device=device) 
        decoder_hidden = encoder_hidden
        
        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)
        
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == eos_token:
                decoded_words.append('<eos>')
                break
            else:
                decoded_words.append(output_lang.idx2word[topi.item()])
            
            decoder_input = topi.squeeze().detach() # necessary?
            
        return decoded_words, decoder_attentions[:di+1]

In [17]:
def evaluateRandomly(encoder, decoder, n=3):
    # encoder.eval()
    # decoder.eval()
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])

        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

In [18]:
hidden_size = 512
encoder1 = EncoderRNN(input_lang.num_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.num_words, drop_rate=0.1).to(device)


if ((device.__str__() == 'cuda')):
  # trainIters(encoder1, attn_decoder1, n_iters=75000)
  encoder1.load_state_dict(torch.load('./model_weights/encoder.pth'))
  attn_decoder1.load_state_dict(torch.load('./model_weights/decoder.pth'))
  evaluateRandomly(encoder1, attn_decoder1, 5)
else:
  print(device)
  encoder1.load_state_dict(torch.load('./model_weights/encoder.pth', map_location='cpu'))
  attn_decoder1.load_state_dict(torch.load('./model_weights/decoder.pth', map_location='cpu'))

  device='cpu'
  encoder1.to(device)
  attn_decoder1.to(device)


  evaluateRandomly(encoder1, attn_decoder1, 5)


> i m impatient .
= je suis impatiente .
< je suis impatient . <eos>

> you re very observant .
= tu es tres observatrice .
< vous etes tres observateur . <eos>

> they re mad at you .
= elles sont furieuses apres vous .
< elles sont en colere apres vous . <eos>

> he is independent of his parents .
= il est independant de ses parents .
< il est independant de ses parents . <eos>

> i m a light sleeper .
= je dors peu .
< je dors peu . <eos>



### BLEU Score

In [61]:
def evaluateBleu(encoder, decoder, n=3):
    from nltk.translate.bleu_score import SmoothingFunction
    smoothie = SmoothingFunction().method4
    cand = []
    ref = []  
    encoder.eval()
    decoder.eval()
    for i in range(5000):
        pair = pairs[i]
        ref.append(pair[1].split())
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        cand.append(output_sentence.split()[:-1])
    return corpus_bleu(ref, cand, smoothing_function=smoothie)

In [62]:
if((device.__str__() == 'cuda')):
  print(evaluateBleu(encoder1, attn_decoder1, 1))
else:
  quantized_encoder = torch.quantization.quantize_dynamic(
    encoder1, {nn.GRU, nn.Linear, nn.Embedding}, dtype=torch.qint8
    )
  quantized_decoder = torch.quantization.quantize_dynamic(
    attn_decoder1, {nn.GRU, nn.Linear, nn.Embedding}, dtype=torch.qint8
    )
  print(evaluateBleu(quantized_encoder, quantized_decoder, 1))
  torch.save(quantized_encoder.state_dict(), './quantized_weights/q_encoder.pth')
  torch.save(quantized_decoder.state_dict(), './quantized_weights/q_decoder.pth')

0.21715226428238207


### Model Compression Block

In [21]:
if((device.__str__() == 'cpu')):
  print(os.path.getsize('./model_weights/encoder.pth')/1e6)
  print(os.path.getsize('./quantized_weights/q_encoder.pth')/1e6)
  print(os.path.getsize('./model_weights/decoder.pth')/1e6)
  print(os.path.getsize('./quantized_weights/q_decoder.pth')/1e6)

##### **Torch To ONNX**

In [None]:
# not yet supported by torch
dummy_input = torch.zeros(10, 1)

quantized_encoder.eval()

torch.onnx.export(
    quantized_encoder,
    dummy_input,
    "torch.onxx",
)