In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import torch.optim as optim
from nltk.translate.bleu_score import sentence_bleu , SmoothingFunction

In [2]:
def sentence_to_words(sentence):
    words = sentence.split()
    last = words.pop()
    last, end = last[:-1], last[-1]
    words.append(last)
    words.append(end)
    return words

In [3]:
class TrEngDataset(Dataset):
    def __init__(self, path=None, limit=None):
        self.path = "/Users/musazenbilci/Desktop/mosesopposite/bisiler/nlp/basic_rnn/data/tur-eng/tur.txt" if not path else path
        self.data = []
        self._wordlimits = [84, 1123, 28555]
        count = 0
        with open(self.path, "r") as f:
            line = f.readline()
            while line:
                count+=1
                if limit and self._wordlimits[limit-1] < count:
                    break
                eng, tr = line.split('\t')[:2]
                self.data.append((tr, eng))
                line = f.readline()


    def __getitem__(self, index):
        return super().__getitem__(index)
    
    def __len__(self):
        return len(self.data)

In [4]:
dataset = TrEngDataset(limit=3)

In [5]:
len(dataset.data)

28555

In [6]:
class Vocabulary():
    def __init__(self, sentence_list):
        self.vocab = {}
        self.sos_token = "<SOS>"
        self.eos_token = "<EOS>"
        self.pad_token = "<PAD>"
        self.vocab_size = 0
        for token in [self.sos_token, self.eos_token, self.pad_token]:
            self.vocab[token] = self.vocab_size
            self.vocab_size+=1

        for sentence in sentence_list:
            words = sentence_to_words(sentence)
            for word in words:
                if word not in self.vocab:
                    self.vocab[word] = self.vocab_size
                    self.vocab_size+=1

        self.i2w = {index: word for word, index in self.vocab.items()}
    
    def get_word(self, index):
        return self.i2w[index]
    
    def get_index(self, word):
        return self.vocab.get(word, 'NAN')
    

In [7]:
turkish_sentences = np.array(dataset.data)[:,0]
tr_vocab = Vocabulary(turkish_sentences)
english_sentences = np.array(dataset.data)[:,1]
eng_vocab = Vocabulary(english_sentences)

In [8]:
tr_vocab.vocab_size

14615

In [9]:
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        combined_size = input_size + hidden_size
        # forget gate
        self.forget = nn.Linear(combined_size, hidden_size, bias=True)
        # update gate
        self.update = nn.Linear(combined_size, hidden_size, bias=True)
        # candidate gate
        self.candidate = nn.Linear(combined_size, hidden_size, bias=True)
        # output gate
        self.output = nn.Linear(combined_size, hidden_size, bias=True)
    
    def forward(self, x, prev_a, prev_c):
        input_vector = torch.concat((x, prev_a))
        forget = torch.sigmoid(self.forget(input_vector))
        update = torch.sigmoid(self.update(input_vector))
        candidate = torch.tanh(self.candidate(input_vector))
        output = torch.sigmoid(self.output(input_vector))

        c = (forget * prev_c) + (update * candidate)
        a =  torch.tanh(c) * output

        return a, c

class LSTMStacked(nn.Module):
    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.cells = nn.Sequential(LSTMCell(input_size, hidden_size), *(LSTMCell(hidden_size, hidden_size) for _ in range(self.num_layers)))
        self.final = nn.Linear(hidden_size, output_size, bias=True)
    
    def forward(self, x, prev_a_c_list):
        new_a_c_list = torch.zeros(self.num_layers, 2, self.hidden_size)
        for c in range(self.num_layers):
            prev_a, prev_c = prev_a_c_list[c]
            new_a, new_c = self.cells[c](x, prev_a, prev_c)
            new_a_c_list[c][0] = new_a
            new_a_c_list[c][0] = new_c
            x = new_a
        
        x = self.final(x)
        return x, new_a_c_list


In [10]:
make_word_onehot = lambda word, vocab: nn.functional.one_hot(torch.tensor(vocab.get_index(word)), num_classes=vocab.vocab_size).to(torch.float32)

<img src="images/attention_context.png">

In [11]:
class Seq2SeqAttention(nn.Module):
    def __init__(self, vocab, output_vocab, hidden_size, num_layers=3, sos_token='<SOS>'):
        super().__init__()
        self.vocab = vocab
        self.output_vocab = output_vocab
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.sos_token = sos_token
        self.eos_token_indexes = [self.output_vocab.get_index(token) for token in [self.vocab.eos_token, '.', '!', '?']]
        self.encoder = nn.LSTM(self.vocab.vocab_size, self.hidden_size, self.num_layers, batch_first=True)
        # self.encoder = LSTMStacked(self.num_layers, self.vocab.vocab_size, self.hidden_size, self.hidden_size)
        self.alignment = nn.Linear(2 * self.hidden_size, 1)
        self.decoder = nn.LSTMCell(self.output_vocab.vocab_size+self.hidden_size, self.hidden_size)
        # self.decoder = LSTMStacked(self.num_layers, self.output_vocab.vocab_size, self.hidden_size, self.output_vocab.vocab_size)
        self.final = nn.Linear(self.hidden_size, self.output_vocab.vocab_size)
    
    def forward(self, seq, target_seq):
        Tx = len(seq)
        seq = torch.unsqueeze(torch.stack(seq), 1)
        h0 = torch.zeros((self.num_layers, self.hidden_size))
        c0 = torch.zeros((self.num_layers, self.hidden_size))
        hn, cn = h0, c0
        h = torch.zeros((Tx, self.hidden_size)) # 2, 3, 64
        for t in range(Tx):
            _, (hn,cn) = self.encoder(seq[t], (hn,cn)) # hn.shape = 3, 64
            h[t] = torch.mean(hn, dim=0)
        
        input = make_word_onehot(self.sos_token, self.output_vocab)
        target_seq = [input] + target_seq
        # target_seq = torch.unsqueeze(torch.stack(target_seq), 0)
        output = []

        a_i = torch.zeros((Tx))
        c_i = torch.zeros((Tx, self.num_layers, 1, self.hidden_size))
        hn = torch.mean(hn, dim=0)
        cn = torch.mean(cn, dim=0)

        for target_word in target_seq:
            s_prev = hn
            e_i = []
            for j in range(Tx):
                e_i.append(self.alignment(torch.concat((s_prev, h[j]))))
            e_i = torch.stack(e_i)
            a_i = torch.softmax(a_i, 0)
            c_i = torch.sum(h * a_i[:, None], dim=0) # 3, 64
            
            input = torch.concat((target_word, torch.flatten(c_i)))
            hn, cn = self.decoder(input, (hn, cn))

            out = self.final(hn)
            output.append(out)
        
        return output

    def predict(self, seq): # this has not been synced yet
        Tx = len(seq)
        seq = torch.unsqueeze(torch.stack(seq), 1)
        h0 = torch.zeros((self.num_layers, self.hidden_size))
        c0 = torch.zeros((self.num_layers, self.hidden_size))
        hn, cn = h0, c0
        h = torch.zeros((Tx, self.hidden_size)) # 2, 3, 64
        for t in range(Tx):
            _, (hn,cn) = self.encoder(seq[t], (hn,cn)) # hn.shape = 3, 64
            h[t] = torch.mean(hn, dim=0)
        
        input = make_word_onehot(self.sos_token, self.output_vocab)
        output = []

        a_i = torch.zeros((Tx))
        c_i = torch.zeros((Tx, self.num_layers, 1, self.hidden_size))
        hn = torch.mean(hn, dim=0)
        cn = torch.mean(cn, dim=0)

        while torch.argmax(input) not in self.eos_token_indexes and len(output) < len(seq)+5:
            s_prev = hn
            e_i = []
            for j in range(Tx):
                e_i.append(self.alignment(torch.concat((s_prev, h[j]))))
            e_i = torch.stack(e_i)
            a_i = torch.softmax(a_i, 0)
            c_i = torch.sum(h * a_i[:, None], dim=0) # 3, 64
            
            input = torch.concat((input, torch.flatten(c_i)))
            hn, cn = self.decoder(input, (hn, cn))

            out = self.final(hn)
            output.append(out)
            input = torch.softmax(out, 0)
        
        return output

        

In [12]:
def pad_couple(first, sec, pad):
    while len(first) > len(sec):
        sec.append(pad)
    while len(sec) > len(first):
        first.append(pad)
    return first, sec


In [13]:
model = Seq2SeqAttention(tr_vocab, eng_vocab, hidden_size=64)
model.load_state_dict(torch.load('models/seq2seq_3word_ep319.pth', weights_only=True))

<All keys matched successfully>

In [14]:
torch.set_grad_enabled(True)

# model = Seq2SeqAttention(tr_vocab, eng_vocab, hidden_size=64)
initial_lr = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
EPOCHS = 319 #300

for ep in range(EPOCHS):
    epoch_loss = 0
    count = 0
    for s in range(len(turkish_sentences)):
        count+=1
        # print(f"Going for the {count}. with input {turkish_sentences[s]} and output {english_sentences[s]}")
        optimizer.zero_grad()
        
        sentence = turkish_sentences[s]
        word_list = sentence_to_words(sentence)
        embedding_list = [make_word_onehot(word, tr_vocab) for word in word_list]

        target_sentence = english_sentences[s]
        target_list = sentence_to_words(target_sentence)
        target_embedding_list = [make_word_onehot(word, eng_vocab) for word in target_list]

        logits = model(embedding_list, target_embedding_list)

        if len(logits) != len(target_embedding_list):
            logits, target_embedding_list = pad_couple(logits, target_embedding_list, make_word_onehot(eng_vocab.pad_token, eng_vocab))
        
        loss = 0
        for i in range(len(logits)):
            loss += criterion(logits[i], target_embedding_list[i])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print("EPOCH", 319+ep, epoch_loss / len(turkish_sentences))
    lr_scheduler.step(epoch_loss / len(turkish_sentences))
# print([eng_vocab.get_word(torch.argmax(logit).item()) for logit in logits])
# print(target_embedding_list[0].shape) 339


EPOCH 319 9.078273192166531
EPOCH 320 7.359861195488729
EPOCH 321 6.917773964297143
EPOCH 322 6.660366865601948
EPOCH 323 6.539584641394714
EPOCH 324 6.4651405270641105
EPOCH 325 6.396212602244578
EPOCH 326 6.379039014192017
EPOCH 327 6.381940246406943
EPOCH 328 6.334288177670687
EPOCH 329 6.293466918665418
EPOCH 330 6.2738746616035055
EPOCH 331 6.24383313572437
EPOCH 332 6.267677899889845
EPOCH 333 6.287024193500141
EPOCH 334 6.239238347781871
EPOCH 335 6.216641887499539
EPOCH 336 6.24901606435952
EPOCH 337 6.260651462194674
EPOCH 338 6.234904765424723
EPOCH 339 6.260188861915066
EPOCH 340 6.23726961193559
EPOCH 341 6.235784455682972
EPOCH 342 6.217132539951411
EPOCH 343 6.19683290789756
EPOCH 344 6.196336123254917
EPOCH 345 6.197807610095476
EPOCH 346 6.212202474196766
EPOCH 347 6.23698941538806
EPOCH 348 6.1750154648274265
EPOCH 349 6.2386289756689655
EPOCH 350 6.253463703160151
EPOCH 351 6.221455632768816
EPOCH 352 6.167433694701931
EPOCH 353 6.197945106884546
EPOCH 354 6.190444217

KeyboardInterrupt: 

In [16]:
torch.set_grad_enabled(False)
try:
    chencherry = SmoothingFunction()
    test_count = 0
    outputs = []
    targets = []
    total_bleu_score = 0
    for s in range(len(turkish_sentences)):
        test_count+=1
        print(f"Going for the {test_count}. with input {turkish_sentences[s]} and output {english_sentences[s]}")
        
        sentence = turkish_sentences[s]
        word_list = sentence_to_words(sentence)
        embedding_list = [make_word_onehot(word, tr_vocab) for word in word_list]

        target_sentence = english_sentences[s]
        target_list = sentence_to_words(target_sentence)
        target_embedding_list = [make_word_onehot(word, eng_vocab) for word in target_list]

        logits = model.predict(embedding_list)
        probs = [torch.softmax(logit, 0) for logit in logits]
        # print(probs)
        output = [eng_vocab.get_word(torch.argmax(prob).item()) for prob in probs]
        while len(output) > 0 and output[-1] == eng_vocab.pad_token:
            output.pop()
        outputs.append(output)
        targets.append([target_list])
        # score = bleu_score(target_list, output)
        print('Output', output)
        print('Target', target_list)
        # print('Score', score)
        # break
        score = sentence_bleu([target_list], output, smoothing_function=chencherry.method2)
        total_bleu_score += score
        print('Score', score)
    print('Total Avg Score', total_bleu_score / len(turkish_sentences))
    
finally:
    torch.set_grad_enabled(True)

Going for the 1. with input Selam. and output Hi.
Output ["I'm", 'smart', '.']
Target ['Hi', '.']
Score 0.408248290463863
Going for the 2. with input Merhaba. and output Hi.
Output ["I'm", 'smart', '.']
Target ['Hi', '.']
Score 0.408248290463863
Going for the 3. with input Kaç! and output Run!
Output ["I'm", '!']
Target ['Run', '!']
Score 0.5
Going for the 4. with input Koş! and output Run!
Output ["I'm", 'it', '!']
Target ['Run', '!']
Score 0.408248290463863
Going for the 5. with input Kaç! and output Run.
Output ["I'm", '!']
Target ['Run', '.']
Score 0
Going for the 6. with input Koş! and output Run.
Output ["I'm", 'it', '!']
Target ['Run', '.']
Score 0
Going for the 7. with input Kim? and output Who?
Output ['Is', 'it', '?']
Target ['Who', '?']
Score 0.408248290463863
Going for the 8. with input Vay canına! and output Wow!
Output ["It's", '!']
Target ['Wow', '!']
Score 0.5
Going for the 9. with input Ördek! and output Duck!
Output ["I'm", '!']
Target ['Duck', '!']
Score 0.5
Going fo

In [15]:
torch.save(model.state_dict(), 'models/seq2seq_3word_ep362.pth')