In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from tqdm import tqdm_notebook as tqdm

# Seed both random number generators used by this notebook so that a
# "Restart Kernel -> Run All" reproduces the same results:
#   * torch  - used when the nn.Module parameters are initialised and by dropout
#   * random - used by the dataset generator and by the per-epoch shuffle
# Keep random.seed(...) as the last statement of the cell: it returns None,
# so the cell displays nothing (torch.manual_seed returns a Generator object).
torch.manual_seed(1)
random.seed(1)

In [2]:
def sorting_letters_dataset(size):
    """Generate a toy "sort the distinct letters" dataset.

    Every input is built from 3 to 10 groups; each group is one random
    lowercase letter repeated 1 to 3 times. The matching target is the sorted
    list of the distinct letters that occur in the input.

    Args:
        size: Number of (input, target) examples to generate.

    Returns:
        A pair ``(inputs, targets)`` of equal-length tuples. ``inputs[i]`` is a
        list of single-character strings and ``targets[i]`` is the sorted list
        of its unique letters. When ``size`` is 0 (or negative) both tuples are
        empty, so ``inputs, targets = sorting_letters_dataset(0)`` still
        unpacks instead of failing on an empty iterator.
    """
    inputs, targets = [], []
    for _ in range(size):
        x = []
        # The order of the random calls (group count, then letter, then repeat
        # count) is kept stable so a seeded run yields the same data as before.
        for _group in range(random.randint(3, 10)):
            letter = chr(random.randint(97, 122))  # code points 97..122 are 'a'..'z'
            x.extend([letter] * random.randint(1, 3))
        inputs.append(x)
        targets.append(sorted(set(x)))
    # Concrete tuples (rather than a one-shot zip iterator) can be unpacked,
    # indexed and iterated more than once.
    return tuple(inputs), tuple(targets)

In [3]:
class Vocab:
    """Two-way lookup between tokens and their integer indices.

    Attributes:
        itos: The token list passed to the constructor (index -> token).
            It is stored by reference, not copied.
        stoi: Dict mapping each token to its position in ``itos``
            (token -> index). If a token occurs more than once, the last
            position wins.
    """

    def __init__(self, vocab):
        self.itos = vocab
        lookup = {}
        for index, token in enumerate(self.itos):
            lookup[token] = index
        self.stoi = lookup

    def __len__(self):
        """Return the number of entries in ``itos``."""
        return len(self.itos)

# Source vocabulary: the padding token followed by 'a'..'z' (27 entries).
src_vocab = Vocab(['<pad>', *map(chr, range(97, 123))])
# Target vocabulary: same as the source plus the decoder's boundary markers
# '<start>' and '<stop>' (29 entries).
tgt_vocab = Vocab(['<pad>', *map(chr, range(97, 123)), '<start>', '<stop>'])

# Small splits for quick iteration. The trailing numbers are the author's
# noted alternative sizes (presumably the full-size run).
# The training split is generated first; with a seeded `random` this order
# determines which examples land in each split.
train_inp, train_out = sorting_letters_dataset(200)  # 20_000
valid_inp, valid_out = sorting_letters_dataset(50)  # 5_000

In [4]:
# map the text data into numeric values

def map_elems(elems, mapper):
    """Look up every element of ``elems`` in ``mapper`` and return the results as a list.

    Args:
        elems: Iterable of keys.
        mapper: Any object supporting ``mapper[key]`` (e.g. a dict).

    Returns:
        A new list holding ``mapper[elem]`` for each element, in order.
        A key that ``mapper`` does not contain propagates the lookup error.
    """
    mapped = []
    for elem in elems:
        mapped.append(mapper[elem])
    return mapped

def map_many_elems(many_elems, mapper):
    """Apply ``map_elems`` to each sequence in ``many_elems``.

    Args:
        many_elems: Iterable of sequences of keys.
        mapper: Lookup object forwarded unchanged to ``map_elems``.

    Returns:
        A list with one mapped list per input sequence, in order.
    """
    return list(map(lambda elems: map_elems(elems, mapper), many_elems))

# Convert both splits from letters to vocabulary indices: inputs go through
# the source vocabulary, targets through the target vocabulary.
train_x, train_y = (
    map_many_elems(train_inp, src_vocab.stoi),
    map_many_elems(train_out, tgt_vocab.stoi),
)

valid_x, valid_y = (
    map_many_elems(valid_inp, src_vocab.stoi),
    map_many_elems(valid_out, tgt_vocab.stoi),
)

In [5]:
class Encoder(nn.Module):
    """Embedding -> dropout -> single-layer LSTM encoder for ONE sequence at a time.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding dimension (LSTM input size).
        lstm_size: LSTM hidden size.
        z_type: Selects what ``forward`` returns. ``1`` returns the final
            hidden and cell states; any other value returns the per-step
            outputs. Stored on the instance as ``z_index``.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, emb_dim, lstm_size, z_type, dropout=0.5):
        super().__init__()
        self.z_index = z_type

        # Sub-modules are created in the same order as before so that a seeded
        # parameter initialisation is unchanged.
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, lstm_size, batch_first=True)
        self.drop = nn.Dropout(dropout)

    def forward(self, inputs):
        """Encode one sequence of token indices.

        Args:
            inputs: A flat sequence of integer token indices (no batch axis);
                it is wrapped into a batch of size 1 here.

        Returns:
            If ``z_index == 1``: the tuple ``(h, c)`` of final hidden and cell
            states, each of shape ``(1, lstm_size)``.
            Otherwise: the LSTM outputs of shape ``(1, seqlen, lstm_size)``.
        """
        # Build the input on whatever device the module's parameters live on.
        device = next(self.parameters()).device
        seq = torch.tensor([inputs], device=device)  # (1, seqlen)

        embedded = self.drop(self.emb(seq))  # (1, seqlen, emb_dim)
        outs, (h_n, c_n) = self.lstm(embedded)

        if self.z_index != 1:
            return outs  # (1, seqlen, lstm_size)
        # h_n / c_n are (num_layers=1, batch=1, lstm_size); drop the layer axis.
        return h_n[0], c_n[0]  # each (1, lstm_size)

# Build an encoder and display its layer summary as the cell output.
# The original trailing note "64 128" presumably records larger settings
# to try (e.g. emb_dim=64, lstm_size=128) - confirm with the author.
encoder = Encoder(
    vocab_size=len(src_vocab),
    emb_dim=16,
    lstm_size=32,
    z_type=1,
)
encoder

Encoder(
  (emb): Embedding(27, 16)
  (lstm): LSTM(16, 32, batch_first=True)
  (drop): Dropout(p=0.5, inplace=False)
)

In [6]:
class Decoder(nn.Module):
    """LSTMCell decoder that processes ONE sequence at a time.

    ``forward`` computes a teacher-forced training loss; ``predict`` performs
    greedy decoding.

    Args:
        vocab_size: Size of the target vocabulary (embedding rows and number
            of output classes).
        emb_dim: Embedding dimension (LSTMCell input size).
        lstm_size: LSTMCell hidden size.
        dropout: Dropout probability used during training on both the input
            embedding and the hidden state fed to the classifier.
    """

    def __init__(self, vocab_size, emb_dim, lstm_size, dropout=0.5):
        super().__init__()
        # Sub-modules are created in the same order as before so that a seeded
        # parameter initialisation is unchanged.
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTMCell(emb_dim, lstm_size)
        self.clf = nn.Linear(lstm_size, vocab_size)

        self.drop = nn.Dropout(dropout)
        # reduction="none" keeps the per-example axis, so each step's loss
        # (and therefore the returned loss) has shape (1,).
        self.objective = nn.CrossEntropyLoss(reduction="none")

    def forward(self, state, targets, curr_token, last_token):
        """Teacher-forced loss for a single target sequence.

        Args:
            state: Initial ``(h, c)`` pair for the LSTMCell, each of shape
                ``(1, lstm_size)``.
            targets: Sequence (list, tuple, ...) of target token indices,
                without the stop token.
            curr_token: Index of the first token fed to the decoder
                (the start marker).
            last_token: Index of the stop marker; it is appended to
                ``targets`` so the model learns when to stop.

        Returns:
            Tensor of shape ``(1,)``: the cross-entropy averaged over the
            ``len(targets) + 1`` decoding steps.
        """
        device = next(self.parameters()).device

        loss = 0
        # list(...) lets callers pass tuples or other sequences, not just lists.
        shifted = list(targets) + [last_token]
        for target_token in shifted:
            inp = torch.tensor([curr_token], device=device)

            emb = self.drop(self.emb(inp))

            state = self.lstm(emb, state)
            q_i, _ = state
            q_i = self.drop(q_i)

            scores = self.clf(q_i)
            target = torch.tensor([target_token], device=device)
            loss = loss + self.objective(scores, target)

            # Teacher forcing: the next input is the gold token, not the prediction.
            curr_token = target_token

        return loss / len(shifted)

    def predict(self, state, curr_token, last_token, maxlen):
        """Greedy decoding from ``state``.

        No dropout is applied here, and the loop runs under ``torch.no_grad()``
        so no autograd graph is accumulated across decoding steps.

        Args:
            state: Initial ``(h, c)`` pair for the LSTMCell.
            curr_token: Index of the first input token (the start marker).
            last_token: Index of the stop marker; decoding ends as soon as it
                is predicted.
            maxlen: Maximum number of decoding steps.

        Returns:
            List of 0-dim integer tensors with the predicted token indices,
            excluding the stop token. At most ``maxlen`` entries.
        """
        device = next(self.parameters()).device
        preds = []
        with torch.no_grad():
            for _ in range(maxlen):
                inp = torch.tensor([curr_token], device=device)
                emb = self.emb(inp)

                state = self.lstm(emb, state)
                h_i, _ = state

                scores = self.clf(h_i)  # (1, vocab_size)
                # argmax of the raw scores; a softmax would not change the winner.
                pred = torch.argmax(scores)
                # Feed a plain int back in rather than wrapping a tensor in a list.
                curr_token = pred.item()

                if curr_token == last_token:
                    break
                preds.append(pred)
        return preds
    
# Build a decoder and display its layer summary as the cell output.
# The original trailing note "64 128" presumably records larger settings
# to try (e.g. emb_dim=64, lstm_size=128) - confirm with the author.
decoder = Decoder(
    vocab_size=len(tgt_vocab),
    emb_dim=16,
    lstm_size=32,
)
decoder

Decoder(
  (emb): Embedding(29, 16)
  (lstm): LSTMCell(16, 32)
  (clf): Linear(in_features=32, out_features=29, bias=True)
  (drop): Dropout(p=0.5, inplace=False)
  (objective): CrossEntropyLoss()
)

In [7]:
# Target-vocabulary indices of the decoder's sequence-boundary markers:
# START_IX is the first decoder input, STOP_IX is the token that ends decoding.
START_IX, STOP_IX = (tgt_vocab.stoi[marker] for marker in ('<start>', '<stop>'))

def shuffle(x, y):
    """Shuffle two parallel sequences with the same random permutation.

    The inputs are not modified. Uses the global ``random`` module, so the
    result is reproducible under ``random.seed``.

    Args:
        x: First sequence.
        y: Second sequence, aligned element-by-element with ``x``.

    Returns:
        A ``zip`` iterator that yields the shuffled ``x`` items as one tuple
        and then the shuffled ``y`` items as another; intended to be unpacked
        as ``x, y = shuffle(x, y)``.
    """
    # Pair the items up first so one permutation keeps every (x, y) aligned.
    pairs = [(x_item, y_item) for x_item, y_item in zip(x, y)]
    random.shuffle(pairs)
    # Transpose the list of pairs back into two parallel tuples.
    return zip(*pairs)


def train(encoder, decoder, train_x, train_y, batch_size=50, epochs=10, print_every=1):
    """Train the encoder/decoder pair with SGD using accumulated mini-batches.

    Examples are processed one at a time; their losses are summed and one
    optimizer step is taken every ``batch_size`` examples. A trailing partial
    batch (when the number of examples is not a multiple of ``batch_size``,
    or is smaller than ``batch_size``) also gets its own update and is
    included in the reported loss, so no example is silently skipped.

    Args:
        encoder: Module called as ``encoder(x)`` whose result is passed to the
            decoder as its initial state.
        decoder: Module called as ``decoder(state, y, START_IX, STOP_IX)`` and
            returning a one-element loss tensor.
        train_x: Sequence of encoded input sequences.
        train_y: Sequence of encoded target sequences, aligned with ``train_x``.
        batch_size: Number of examples whose losses are summed per update.
        epochs: Number of passes over the data.
        print_every: Print the mean per-example loss every this many epochs.

    Returns:
        The ``(encoder, decoder)`` pair, trained in place, left in training
        mode and on the device that was used (CUDA if available, else CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    encoder.to(device)
    decoder.to(device)

    enc_optim = optim.SGD(encoder.parameters(), lr=0.001, momentum=0.99)
    dec_optim = optim.SGD(decoder.parameters(), lr=0.001, momentum=0.99)

    encoder.train()
    decoder.train()

    def _apply_update(loss):
        """Backpropagate ``loss``, step both optimizers, clear grads; return the loss as a float."""
        loss.backward()
        enc_optim.step()
        dec_optim.step()
        # Each optimizer owns all parameters of its module, so clearing through
        # the optimizers resets every gradient.
        enc_optim.zero_grad()
        dec_optim.zero_grad()
        return loss.item()

    for epoch in range(1, epochs + 1):
        enc_optim.zero_grad()
        dec_optim.zero_grad()

        # Reshuffle every epoch; the data comes back as two parallel tuples.
        train_x, train_y = shuffle(train_x, train_y)
        n_examples = len(train_x)

        epoch_loss = 0
        batch_loss = 0

        for i in range(n_examples):
            batch_loss = batch_loss + decoder(encoder(train_x[i]), train_y[i], START_IX, STOP_IX)

            if (i + 1) % batch_size == 0:
                epoch_loss += _apply_update(batch_loss)
                batch_loss = 0

        # Flush the examples left over after the last full batch.
        if n_examples % batch_size != 0:
            epoch_loss += _apply_update(batch_loss)

        if epoch % print_every == 0:
            print(f"**** Epoch {epoch} - Loss: {epoch_loss / len(train_x):.6f} ****")
    return encoder, decoder

In [8]:
# Fresh, untrained models for the actual run. The encoder is constructed
# before the decoder so the seeded weight initialisation order is unchanged.
# The original trailing note "64 128" presumably records larger settings
# to try (e.g. emb_dim=64, lstm_size=128) - confirm with the author.
encoder = Encoder(
    vocab_size=len(src_vocab),
    emb_dim=16,
    lstm_size=32,
    z_type=1,
)
decoder = Decoder(
    vocab_size=len(tgt_vocab),
    emb_dim=16,
    lstm_size=32,
)

print(encoder, decoder, sep="\n")

# Short debugging run; the original note "50" next to this call presumably
# refers to the intended full batch size.
encoder, decoder = train(
    encoder,
    decoder,
    train_x,
    train_y,
    batch_size=5,
    epochs=2,
    print_every=1,
)

Encoder(
  (emb): Embedding(27, 16)
  (lstm): LSTM(16, 32, batch_first=True)
  (drop): Dropout(p=0.5, inplace=False)
)
Decoder(
  (emb): Embedding(29, 16)
  (lstm): LSTMCell(16, 32)
  (clf): Linear(in_features=32, out_features=29, bias=True)
  (drop): Dropout(p=0.5, inplace=False)
  (objective): CrossEntropyLoss()
)
**** Epoch 1 - Loss: 3.361181 ****
**** Epoch 2 - Loss: 3.208918 ****
