In [1]:
import glob
import torch
import nltk
import random
import os
import numpy as np
from collections import Counter
from itertools import dropwhile

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
# Pick the GPU only when CUDA is really usable. `is_available` must be *called*:
# the bare function object is always truthy, which would select "cuda" even on
# CPU-only machines and fail later on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class Vocab:
    """Bidirectional token <-> index mapping with four reserved special tokens.

    Reserved indices are fixed: pad=0, unk=1, sos=2, eos=3. Corpus tokens are
    appended after them in the iteration order of ``counter``.

    Args:
        counter:  mapping ``token -> occurrence count`` (e.g. ``collections.Counter``).
        sos, eos, pad, unk: the strings used for the special tokens.
        min_freq: a token is kept only if its count is *strictly greater* than
            this value; ``None`` is treated as 0 (keep every counted token).
    """

    def __init__(self, counter, sos, eos, pad, unk, min_freq=10):

        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk

        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3

        self._token2idx = {
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
            self.sos: self.sos_idx,
            self.eos: self.eos_idx,
        }
        self._idx2token = {idx: token for token, idx in self._token2idx.items()}

        idx = len(self._token2idx)
        min_freq = 0 if min_freq is None else min_freq

        for token, count in counter.items():
            # A corpus token spelled like a special token must not steal its
            # reserved index; doing so would also leave `vocab_size` smaller
            # than the largest index handed out.
            if token in self._token2idx:
                continue
            if count > min_freq:
                self._token2idx[token] = idx
                self._idx2token[idx] = token
                idx += 1

        self.vocab_size = len(self._token2idx)
        self.tokens = list(self._token2idx.keys())

    def token2idx(self, token):
        """Return the index of ``token``, or ``unk_idx`` if it is not in the vocabulary."""
        return self._token2idx.get(token, self.unk_idx)

    def idx2token(self, idx):
        """Return the token stored at ``idx``, or the unk string if the index is unknown."""
        return self._idx2token.get(idx, self.unk)

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self._token2idx)

In [4]:
class Dataset(object):
    """Sentence-level corpus read from ``*.txt`` files and encoded as index lists.

    The text is lower-cased, split into sentences and words with NLTK, and each
    sentence is stored in ``self.data_ids`` as a list of vocabulary indices
    (``self.words_vocab`` is built from the same tokens).

    Args:
        path:      directory that contains the ``*.txt`` files.
        val:       when true, ``get_batch`` returns the whole dataset instead of
                   a random sample.
        max_files: how many files to read, taken in sorted name order. The
                   default of 1 keeps the earlier behaviour of stopping after
                   the first file; pass ``None`` to read every file.

    Relies on the module-level ``device`` when building batches.
    """

    def __init__(self, path, val=False, max_files=1):

        assert os.path.exists(path), 'Path does not exist'

        self.val = val
        # Sorted so that the choice of files is reproducible across runs.
        files = sorted(glob.glob(os.path.join(path, '*.txt')))
        if max_files is not None:
            files = files[:max_files]

        # Tokenize every sentence exactly once and reuse the result for both
        # the frequency counts and the index encoding.
        tokenized = []
        for file_ in files:
            # Explicit encoding: the platform default is not reliable for
            # non-ASCII corpora. NOTE(review): assumes the files are UTF-8.
            with open(file_, 'r', encoding='utf-8') as f:
                text = f.read().lower()
            for sent in nltk.sent_tokenize(text):
                tokenized.append(nltk.word_tokenize(sent))

        words_counter = Counter(token for sent in tokenized for token in sent)

        eos = "<eos>"
        sos = "<sos>"
        pad = "<pad>"
        unk = "<unk>"
        self.words_vocab = Vocab(words_counter, sos, eos, pad, unk)
        self.data_ids = [
            [self.words_vocab.token2idx(token) for token in sent]
            for sent in tokenized
        ]

    def __len__(self):
        """Number of encoded sentences."""
        return len(self.data_ids)

    def get_batch(self, batch_size, train, noise, sos=True, eos=False):
        """Build a padded ``(batch, max_len)`` LongTensor of sentence indices.

        Args:
            batch_size: number of sentences sampled (with replacement) when the
                        dataset is not in validation mode.
            train:      training-mode flag.
            noise:      apply random adjacent-token swaps; only honoured when
                        ``train`` is also true.
            sos, eos:   accepted for interface compatibility; currently unused.

        Returns:
            LongTensor on ``device``; rows are right-padded with ``pad_idx``.

        Raises:
            ValueError: if the dataset holds no sentences.
        """
        if not self.data_ids:
            raise ValueError('Dataset is empty: no sentences were loaded')

        if self.val:
            batch_data = self.data_ids
        else:
            random_ids = np.random.randint(0, len(self.data_ids), batch_size)
            batch_data = [self.data_ids[idx] for idx in random_ids]

        max_length = max(len(sent) for sent in batch_data)
        pad_idx = self.words_vocab.pad_idx

        rows = []
        for sent in batch_data:
            if train and noise:
                # apply_noise returns a copy, so the stored corpus stays clean.
                sent = self.apply_noise(sent)
            rows.append(self.pad_single_seq(sent, pad_idx, max_length))

        # One tensor, one host->device transfer for the whole batch.
        return torch.LongTensor(rows).to(device)

    def pad_single_seq(self, sequence, pad_idx, max_length):
        '''
            Return a new list: `sequence` right-padded with `pad_idx` to `max_length`.
        '''
        return sequence + [pad_idx]*(max_length - len(sequence))

    def apply_noise(self, sent):
        '''
            Return a noised copy of `sent`: len(sent)//2 random swaps of
            adjacent tokens. Sentences of length <= 2 are returned unchanged.
            The input list is never modified.
        '''
        noisy = list(sent)
        length = len(noisy)
        if length > 2:
            for _ in range(length//2):
                j = random.randint(0, length-2)
                noisy[j], noisy[j+1] = noisy[j+1], noisy[j]

        return noisy

# Build the corpus used by the cells below.
# NOTE(review): the variable is called `train_dataset` but the path points at
# the 'test' directory — confirm this is intentional (e.g. a small debug split).
train_dataset = Dataset('kaz_rus/kaz/test')
# test_dataset = Dataset('kaz_rus/kaz/test')

In [5]:
# Smoke test: sample two sentences in training mode and get them back as one
# padded LongTensor.
btch = train_dataset.get_batch(2, train=True, noise=True)

In [19]:
class Encoder(nn.Module):
    """GRU sentence encoder whose embedding layer is supplied by the caller.

    The module owns only the recurrent part, so one encoder can be driven with
    different embedding tables.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, num_layers, bidirectional):
        super().__init__()
        self.vocab_size, self.emb_size = vocab_size, emb_size
        self.hidden_size, self.num_layers = hidden_size, num_layers
        self.bidirectional = bidirectional

        # No embedding table here: it is passed to forward() on every call.
        self.gru = nn.GRU(
            input_size=emb_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, sentence, enc_embedding, hidden=None):
        '''
        Input:
            sentence:      (batch x seq_len) token indices
            enc_embedding: callable mapping indices to (batch x seq_len x emb_size)
            hidden:        optional initial GRU state
        Output:
            final state:   (num_layers x batch x hidden_size), or
                           (num_layers x batch x 2*hidden_size) when bidirectional
        '''
        embedded = enc_embedding(sentence)

        # nn.GRU treats hx=None the same as omitting it (zero initial state).
        _, state = self.gru(embedded, hidden)

        if not self.bidirectional:
            return state

        # h_n interleaves directions per layer: even rows are forward, odd rows
        # backward. Join each layer's pair along the feature axis.
        return torch.cat((state[0::2], state[1::2]), dim=2)
            
class Decoder(nn.Module):
    """Single-step GRU decoder: one input token per sequence -> vocabulary logits.

    Args:
        vocab_size:  size of the output vocabulary (and of the embedding table).
        emb_size:    embedding dimension.
        hidden_size: hidden dimension of every stacked GRU cell.
        n_layers:    number of stacked GRU cells.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, n_layers):
        super(Decoder, self).__init__()
        self.vocab_size  = vocab_size
        self.emb_size    = emb_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.gru       = GRULayers(emb_size, hidden_size, n_layers)
        self.linear    = nn.Linear(hidden_size, vocab_size)

    def forward(self, inp, hidden):
        '''
        Input:
            inp:    (batch x 1) token indices for the current step
            hidden: per-layer previous states, indexable as hidden[layer]
                    with each entry (batch x hidden_size)
        Output:
            logit: (batch x vocab_size) unnormalised scores
        '''

        embedded = self.embedding(inp)
        embedded = embedded.squeeze(1)          # (batch x emb_size)

        # New states of all layers, stacked as (n_layers x batch x hidden_size).
        hidden = self.gru(embedded, hidden)

        # Project the TOP layer's state. Indexing `[:, -1, :]` here would pick
        # the last *batch element* of every layer and yield the wrong shape.
        logit = self.linear(hidden[-1])

        return logit

class GRULayers(nn.Module):
    """Stack of ``nn.GRUCell`` modules applied one after another for one time step."""

    def __init__(self, emb_size, hidden_size, n_layers):
        super().__init__()
        self.emb_size, self.hidden_size, self.n_layers = emb_size, hidden_size, n_layers

        # The first cell consumes the embedding; every later cell consumes the
        # hidden state produced by the cell below it.
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            in_features = emb_size if depth == 0 else hidden_size
            self.layers.append(nn.GRUCell(in_features, hidden_size))

    def forward(self, inp, prev_hidden):
        """Advance every layer by one step.

        ``inp`` feeds the first cell; ``prev_hidden[i]`` is the previous state
        of cell ``i``. Returns the new states of all cells stacked along a new
        leading dimension, bottom layer first.
        """
        new_states = []
        signal = inp
        for depth, cell in enumerate(self.layers):
            signal = cell(signal, prev_hidden[depth])
            new_states.append(signal)

        return torch.stack(new_states)
    
class Model(nn.Module):
    """Encoder/decoder pair bundled with the embedding tables of one direction.

    Args:
        enc_embedding: embedding module handed to the encoder on every call.
        dec_embedding: embedding module kept for the decoding side
                       (stored and mode-switched; not used by ``forward``).
        vocab_size:    output vocabulary size, used to flatten the logits.
        encoder:       module called as ``encoder(src_words, enc_embedding)``.
        decoder:       module called as ``decoder(input_, hidden)`` and
                       returning ``(batch x vocab_size)`` logits.
    """

    def __init__(self, enc_embedding, dec_embedding, vocab_size, encoder, decoder):
        super(Model, self).__init__()
        self.enc_embedding = enc_embedding
        self.dec_embedding = dec_embedding
        self.vocab_size    = vocab_size
        self.encoder       = encoder
        self.decoder       = decoder
        # NOTE(review): NLLLoss expects log-probabilities while the decoder
        # produces raw logits; this criterion is not used inside forward().
        self.criterion     = nn.NLLLoss()

    def set_mode(self, mode):
        """Switch every sub-module to training (``True``) or eval (``False``) mode."""
        for module in (self.enc_embedding, self.dec_embedding,
                       self.encoder, self.decoder, self.criterion):
            module.train(mode)

    def forward(self, input_, src_words, trg_words):
        """Encode ``src_words`` and greedily decode one logit vector per step.

        Args:
            input_:    (batch x 1) first decoder input (e.g. the start token).
            src_words: (batch x src_len) source indices fed to the encoder.
            trg_words: (batch x trg_len) target indices; only its length is
                       used, to decide how many steps to decode. If ``None``,
                       two steps are decoded.

        Returns:
            (batch * steps x vocab_size) logits, batch-major, so they line up
            with ``trg_words.view(-1)``.
        """
        # Encode the source batch that was actually passed in.
        hidden = self.encoder(src_words, self.enc_embedding)

        n_steps = trg_words.size(1) if trg_words is not None else 2

        logits = []
        for _ in range(n_steps):
            # The decoder is only asked for logits here, so the same encoder
            # state conditions every step.
            logit  = self.decoder(input_, hidden)
            # Greedy choice; softmax is monotonic, so top-1 of the logits
            # is the same token.
            input_ = logit.topk(1, dim=-1)[1]
            logits.append(logit)

        logits = torch.stack(logits, 1)          # (batch x steps x vocab)
        logits = logits.view(-1, self.vocab_size)

        return logits

In [None]:
# Four independent 5000 x 30 embedding tables: encoder-side and decoder-side
# tables for each of the two languages (src / trg). Created in this order.
(
    src_enc_embeddings,
    trg_enc_embeddings,
    src_dec_embeddings,
    trg_dec_embeddings,
) = (nn.Embedding(5000, 30) for _ in range(4))

In [20]:
# One encoder shared by both languages; one decoder per language.
# The bidirectional encoder emits 2 * 60 = 120 features per layer, matching the
# decoders' hidden size of 120.
encoder         = Encoder(10000, 30, 60, 2, bidirectional=True).to(device)
decoder_one     = Decoder(5000, 30, 120, 2).to(device)
decoder_two     = Decoder(5000, 30, 120, 2).to(device)

# Model's signature is (enc_embedding, dec_embedding, vocab_size, encoder, decoder).
# Pass by keyword so the vocabulary size cannot land in an embedding slot.
model_one2one   = Model(enc_embedding=src_enc_embeddings, dec_embedding=src_dec_embeddings,
                        vocab_size=5000, encoder=encoder, decoder=decoder_one).to(device)
model_two2two   = Model(enc_embedding=trg_enc_embeddings, dec_embedding=trg_dec_embeddings,
                        vocab_size=5000, encoder=encoder, decoder=decoder_two).to(device)
model_one2two   = Model(enc_embedding=src_enc_embeddings, dec_embedding=trg_dec_embeddings,
                        vocab_size=5000, encoder=encoder, decoder=decoder_two).to(device)
model_two2one   = Model(enc_embedding=trg_enc_embeddings, dec_embedding=src_dec_embeddings,
                        vocab_size=5000, encoder=encoder, decoder=decoder_one).to(device)

In [21]:
# Token-level classification loss applied to the decoder logits.
criterion = nn.CrossEntropyLoss()

# One Adam optimizer per translation direction, all with the same learning rate.
(
    optimizer_one2one,
    optimizer_two2two,
    optimizer_one2two,
    optimizer_two2one,
) = (
    optim.Adam(params=direction.parameters(), lr=2e-5)
    for direction in (model_one2one, model_two2two, model_one2two, model_two2one)
)

In [None]:
def train(epochs, batch_size):
    """Work-in-progress training step: autoencoding losses plus a back-translation stub.

    NOTE(review): this function is unfinished and does not match the other
    definitions in this notebook:
      * `epochs` is never used; only one pass over the batches is made.
      * `train_dataset_one` / `train_dataset_two` are not defined in the
        visible notebook (only `train_dataset` is).
      * no `backward()` or optimizer `step()` appears in the visible body.
    """
    # Number of batches is bounded by the smaller of the two corpora.
    len_data = min(len(train_dataset_one), len(train_dataset_two))
       
    for batch_idx in range(len_data//batch_size):
        optimizer_one2one.zero_grad()
        optimizer_two2two.zero_grad()
        
        # NOTE(review): Dataset.get_batch also requires `train` and `noise`.
        batch_one = train_dataset_one.get_batch(batch_size)
        batch_two = train_dataset_two.get_batch(batch_size)
        
        # NOTE(review): Encoder.forward also requires an embedding module.
        hidden_one = encoder(batch_one)
        hidden_two = encoder(batch_two)
        
        ### Simple autoencoding ###
        # NOTE(review): Decoder.forward takes (inp, hidden); only one argument is given.
        logits_one = decoder_one(hidden_one)      
        # NOTE(review): view(-1) also flattens the vocabulary axis, and the mask
        # compares logit *values* with pad_idx; presumably the padding mask was
        # meant to come from the target batch — confirm.
        # NOTE(review): Dataset stores its vocabulary as `words_vocab`, not `word_vocab`.
        logits_one = logits_one.view(-1)
        mask = logits_one != train_dataset_one.word_vocab.pad_idx
        loss_one = criterion(logits_one[mask], batch_one[mask])
                
        logits_two = decoder_two(hidden_two)   
        logits_two = logits_two.view(-1)
        # NOTE(review): uses train_dataset_one's vocabulary for the second corpus — confirm.
        mask = logits_two != train_dataset_one.word_vocab.pad_idx
        loss_two = criterion(logits_two[mask], batch_two[mask])
        
        ### Back Translation ###
        # Decode corpus one's encoding with the other language's decoder;
        # nothing is done with the result yet.
        logits_cycle_two = decoder_two(hidden_one)      
        logits_cycle_two = logits_cycle_two.view(-1)
        
        
        
        
        