In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from sklearn.metrics import accuracy_score
from tqdm import tqdm_notebook as tqdm

# Fix the PyTorch and Python `random` seeds so repeated runs draw the same random numbers.
torch.manual_seed(1)
random.seed(1)

In [None]:
# Settings read by the cells below (dataset sizes, optimiser values, model sizes, file names).
config = dict(
    # number of generated samples per split
    num_train=200,
    num_valid=50,
    # batching / schedule values ('patient' is not read by any code visible in this notebook)
    batch=5,
    epoch=50,
    patient=20,
    # SGD settings
    lr=0.001,
    momentum=0.99,
    # layer sizes
    encoder_emb_size=64,
    decoder_emb_size=64,
    lstm_size=128,
    # upper bound on decoding steps at prediction time
    max_pred=10,
    # output files
    logfile="lstm-batch-loop.log",
    checkpoint="seq2seq-loop.pt",
)

# Create (or truncate) the log file so each run starts with an empty log.
with open(config['logfile'], 'w'):
    pass
def saveLogMsg(msg):
    """Print ``msg`` to stdout and append it, as one line, to ``config['logfile']``."""
    # print() is given msg and "\n" as two arguments, so the console output
    # ends with a space, a newline, and print's own newline.
    print(msg, "\n")
    with open(config['logfile'], "a") as log_file:
        log_file.write(msg + "\n")
saveLogMsg("Starting...")

In [None]:
def sorting_letters_dataset(size):
    """Generate ``size`` random (input, target) pairs for the letter-sorting task.

    Each input is a list of lowercase letters built from 3-10 random draws,
    every drawn letter being repeated 1-3 times in a row. The matching target
    is the sorted list of distinct letters in that input.

    Args:
        size: number of samples to generate; values <= 0 produce no samples.

    Returns:
        A pair ``(inputs, targets)`` of equal-length tuples. For ``size <= 0``
        both tuples are empty, so ``xs, ys = sorting_letters_dataset(0)``
        unpacks cleanly instead of raising.
    """
    if size <= 0:
        return (), ()

    inputs, targets = [], []
    for _ in range(size):
        letters = []
        for _ in range(random.randint(3, 10)):
            letter = chr(random.randint(97, 122))  # 97..122 -> 'a'..'z'
            letters.extend([letter] * random.randint(1, 3))
        inputs.append(letters)
        targets.append(sorted(set(letters)))
    return tuple(inputs), tuple(targets)

# Generate the training split first, then the validation split (both draw from the same RNG).
train_inp, train_out = sorting_letters_dataset(size=config['num_train'])
valid_inp, valid_out = sorting_letters_dataset(size=config['num_valid'])

saveLogMsg("Dataset for train and valid...")

In [None]:
class Vocab:
    """Two-way lookup between tokens and integer ids.

    ``itos`` is the token sequence passed in (stored by reference, not copied);
    ``stoi`` maps each token to its position in that sequence.
    """

    def __init__(self, vocab):
        self.itos = vocab
        self.stoi = {}
        for index, token in enumerate(vocab):
            # A token that appears more than once keeps its last position.
            self.stoi[token] = index

    def __len__(self):
        """Return the number of entries in ``itos``."""
        return len(self.itos)

# Source vocabulary: '<pad>' at index 0 followed by 'a'..'z'.
src_vocab = Vocab(['<pad>'] + list(map(chr, range(97, 123))))
# Target vocabulary: the same entries plus the '<start>' and '<stop>' markers at the end.
tgt_vocab = Vocab(['<pad>'] + list(map(chr, range(97, 123))) + ['<start>', '<stop>'])

# Integer ids of the two marker tokens in the target vocabulary.
START_IX = tgt_vocab.stoi['<start>']
STOP_IX = tgt_vocab.stoi['<stop>']

saveLogMsg("Vocab for source and target...")

In [None]:
def map_elems(elems, mapper):
    """Return a list with ``mapper[elem]`` for every ``elem`` of ``elems``, in order."""
    mapped = []
    for elem in elems:
        mapped.append(mapper[elem])
    return mapped

def map_many_elems(many_elems, mapper):
    """Run ``map_elems`` with ``mapper`` over each sequence in ``many_elems``; return the list of results."""
    mapped_sequences = []
    for elems in many_elems:
        mapped_sequences.append(map_elems(elems, mapper))
    return mapped_sequences

# Convert letters to ids: inputs through the source vocabulary, targets through the target one.
train_x, train_y = (
    map_many_elems(train_inp, src_vocab.stoi),
    map_many_elems(train_out, tgt_vocab.stoi),
)
valid_x, valid_y = (
    map_many_elems(valid_inp, src_vocab.stoi),
    map_many_elems(valid_out, tgt_vocab.stoi),
)

saveLogMsg("Mapping dataset through Vocab...")

In [None]:
class Encoder(nn.Module):
    """Embedding + single-layer LSTM over a batch of variable-length id sequences.

    ``z_type`` (stored as ``z_index``) selects what ``forward`` returns:
    ``1`` gives the final ``(h_n, c_n)`` pair of the LSTM, any other value
    gives the padded per-step outputs.
    """

    def __init__(self, vocab_size, emb_dim, lstm_size, z_type, dropout=0.5):
        super().__init__()
        self.z_index = z_type

        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, lstm_size, batch_first=True)
        self.drop = nn.Dropout(dropout)

    def forward(self, inputs):
        """Encode a batch given as a list of id lists (lengths may differ).

        Returns ``(h_n, c_n)``, each of shape ``(1, batch, lstm_size)``, when
        ``z_index == 1``; otherwise the zero-padded LSTM outputs of shape
        ``(batch, max_len, lstm_size)``.
        """
        device = next(self.parameters()).device
        lengths = [len(sample) for sample in inputs]

        # Pad with id 0 so the batch becomes one rectangular tensor.
        id_tensors = [torch.tensor(sample).to(device) for sample in inputs]
        padded_ids = pad_sequence(id_tensors, batch_first=True, padding_value=0)  # (batch, max_len)
        embedded = self.drop(self.emb(padded_ids))  # (batch, max_len, emb_dim)

        # Packing lets the LSTM stop at each sequence's true length;
        # enforce_sorted=False accepts the batch in arbitrary length order.
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_outs, (h_n, c_n) = self.lstm(packed)
        outs, _ = pad_packed_sequence(packed_outs, batch_first=True)

        if self.z_index != 1:
            return outs  # (batch, max_len, lstm_size)
        return h_n, c_n  # each (1, batch, lstm_size)

# Build the encoder; z_type=1 makes it return the LSTM's final (h_n, c_n) pair.
encoder = Encoder(
    len(src_vocab),
    config['encoder_emb_size'],
    config['lstm_size'],
    z_type=1,
)
saveLogMsg("encoder:\n{}".format(encoder))

In [None]:
class Decoder(nn.Module):
    """LSTM-cell decoder: teacher-forced loss in ``forward``, greedy decoding in ``predict``.

    Both entry points take ``batch_state = (h, c)`` with each tensor shaped
    ``(1, batch, lstm_size)`` and process the batch one sample at a time.
    """

    def __init__(self, vocab_size, emb_dim, lstm_size, dropout=0.5):
        super(Decoder, self).__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTMCell(emb_dim, lstm_size)
        self.clf = nn.Linear(lstm_size, vocab_size)

        self.drop = nn.Dropout(dropout)
        self.objective = nn.CrossEntropyLoss(reduction="none")

    @staticmethod
    def _per_sample_states(batch_state):
        """Yield one ``(h, c)`` pair per sample, each tensor shaped ``(1, lstm_size)``."""
        batch_state_h, batch_state_c = batch_state
        # (1, batch, lstm_size) -> (batch, 1, lstm_size); iterating the first
        # axis then hands LSTMCell a batch of one per sample.
        return zip(batch_state_h.transpose(0, 1), batch_state_c.transpose(0, 1))

    def forward(self, batch_state, batch_targets, curr_token_raw, last_token_raw):
        """Teacher-forced cross-entropy loss for a batch.

        Args:
            batch_state: ``(h, c)`` initial states, one slice per sample.
            batch_targets: list of target id lists (without start/stop ids).
            curr_token_raw: id fed as the first decoder input (start token).
            last_token_raw: id appended to every target (stop token).

        Returns:
            A one-element tensor: the sum over samples of each sample's mean
            per-token loss (``0.0`` as a float if the batch is empty).
        """
        device = next(self.parameters()).device

        batch_loss = 0.0
        for targets, (state_h, state_c) in zip(batch_targets, self._per_sample_states(batch_state)):
            curr_token = curr_token_raw
            state = (state_h, state_c)
            shifted = targets + [last_token_raw]  # gold outputs, ending with the stop id

            each_loss = 0.0
            for gold in shifted:
                inp = torch.tensor([curr_token]).to(device)

                emb = self.drop(self.emb(inp))

                state = self.lstm(emb, state)
                q_i = self.drop(state[0])

                scores = self.clf(q_i)
                target = torch.tensor([gold]).to(device)
                each_loss += self.objective(scores, target)

                # Teacher forcing: the next input is the gold token, not the prediction.
                curr_token = gold

            batch_loss += each_loss / len(shifted)

        return batch_loss

    def predict(self, batch_state, curr_token_raw, last_token_raw, maxlen=10):
        """Greedily decode every sample in the batch.

        Args:
            batch_state: ``(h, c)`` initial states, one slice per sample.
            curr_token_raw: id fed as the first decoder input (start token).
            last_token_raw: id that ends decoding (stop token); it is not
                included in the returned sequence.
            maxlen: maximum number of decoding steps per sample.

        Returns:
            A list with one list of predicted int ids per sample.
        """
        device = next(self.parameters()).device

        batch_preds = []
        with torch.no_grad():  # decoding never needs gradients
            for state_h, state_c in self._per_sample_states(batch_state):
                curr_token = curr_token_raw
                state = (state_h, state_c)

                each_preds = []
                for _ in range(maxlen):
                    inp = torch.tensor([curr_token]).to(device)
                    state = self.lstm(self.emb(inp), state)
                    scores = self.clf(state[0])

                    # argmax of the raw scores equals argmax of their softmax.
                    curr_token = scores.argmax(dim=1).item()
                    if curr_token == last_token_raw:
                        break
                    each_preds.append(curr_token)
                batch_preds.append(each_preds)
        return batch_preds
    
# Build the decoder over the target vocabulary (which includes the start/stop markers).
decoder = Decoder(
    len(tgt_vocab),
    config['decoder_emb_size'],
    config['lstm_size'],
)
saveLogMsg("decoder:\n{}".format(decoder))

In [None]:
def shuffle(x, y):
    """Shuffle ``x`` and ``y`` with one shared random permutation.

    Elements are paired up with ``zip``, the pairs are shuffled in place with
    ``random.shuffle``, and the result is returned as a ``zip`` object that
    yields the reordered ``x`` items and then the reordered ``y`` items.
    """
    paired = [(left, right) for left, right in zip(x, y)]
    random.shuffle(paired)
    return zip(*paired)

def track_best_model(model_path, model, epoch, best_acc, dev_acc, dev_loss, patient_track):
    """Save a checkpoint when ``dev_acc`` is not worse than ``best_acc``.

    Returns a triple ``(best accuracy, marker string, patience counter)``:
    on a save it is ``(dev_acc, ' * ', 0)``; otherwise it is
    ``(best_acc, '', patient_track + 1)`` and nothing is written.
    """
    # Ties count as an improvement, so an equal score also overwrites the checkpoint.
    improved = not (best_acc > dev_acc)
    if improved:
        torch.save(
            {
                'epoch': epoch,
                'acc': dev_acc,
                'loss': dev_loss,
                'model': model.state_dict(),
            },
            model_path,
        )
        return dev_acc, ' * ', 0
    return best_acc, '', patient_track + 1

def evaluate(encoder, decoder, sample_x, sample_y, batch_size):
    """Greedily decode ``sample_x`` and return the predictions as strings.

    Both models are switched to eval mode (and left there) and decoding runs
    under ``torch.no_grad()``. Samples are processed in consecutive batches of
    at most ``batch_size``; each predicted id sequence is mapped through
    ``tgt_vocab.itos`` and joined into one string.

    Args:
        encoder: callable mapping a list of id sequences to a decoder state.
        decoder: object exposing ``predict(state, start_ix, stop_ix, maxlen=...)``.
        sample_x: sequence of input id lists.
        sample_y: gold targets; accepted for interface compatibility but not
            used, since scoring happens in the caller.
        batch_size: positive number of samples per batch.

    Returns:
        A list of predicted strings, aligned with ``sample_x``.
    """
    encoder.eval()
    decoder.eval()

    predictions = []
    with torch.no_grad():
        for start in range(0, len(sample_x), batch_size):
            batch_x = list(sample_x[start:start + batch_size])
            batch_preds = decoder.predict(encoder(batch_x), START_IX, STOP_IX, maxlen=config['max_pred'])
            predictions.extend(
                ''.join(tgt_vocab.itos[ix] for ix in each_preds)
                for each_preds in batch_preds
            )

    return predictions

def train(encoder, decoder, train_x, train_y, batch_size=50, epochs=10, print_every=1):
    """Train ``encoder`` and ``decoder`` jointly with two SGD optimisers.

    The models are moved to CUDA when available (CPU otherwise) and put in
    training mode. Every epoch reshuffles the data with ``shuffle``, walks it
    in batches of ``batch_size`` (the final batch may be smaller), and takes
    one optimiser step per batch on the loss returned by
    ``decoder(encoder(batch_x), batch_y, START_IX, STOP_IX)``. Learning rate
    and momentum come from ``config``. Every ``print_every`` epochs the summed
    batch loss divided by the number of samples is logged via ``saveLogMsg``.

    Returns:
        The ``(encoder, decoder)`` pair that was passed in, after training.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for module in (encoder, decoder):
        module.to(device)

    sgd_settings = {'lr': config['lr'], 'momentum': config['momentum']}
    enc_optim = optim.SGD(encoder.parameters(), **sgd_settings)
    dec_optim = optim.SGD(decoder.parameters(), **sgd_settings)

    def reset_gradients():
        # Clear gradients through both the module and its optimiser, encoder first.
        for module, optimizer in ((encoder, enc_optim), (decoder, dec_optim)):
            module.zero_grad()
            optimizer.zero_grad()

    encoder.train()
    decoder.train()

    for epoch in range(1, epochs + 1):
        reset_gradients()

        train_x, train_y = shuffle(train_x, train_y)
        num_samples = len(train_x)

        epoch_loss = 0
        batch_x, batch_y = [], []
        for position, (sample_x, sample_y) in enumerate(zip(train_x, train_y), start=1):
            batch_x.append(sample_x)
            batch_y.append(sample_y)

            # Keep filling until the batch is full or the data runs out.
            batch_ready = len(batch_x) == batch_size or position == num_samples
            if not batch_ready:
                continue

            batch_loss = decoder(encoder(batch_x), batch_y, START_IX, STOP_IX)

            batch_loss.backward()
            enc_optim.step()
            dec_optim.step()
            reset_gradients()

            epoch_loss += batch_loss.item()
            batch_x, batch_y = [], []

        if epoch % print_every == 0:
            saveLogMsg(f"**** Epoch {epoch} - Loss: {epoch_loss / num_samples:.6f} ****")
    return encoder, decoder

In [None]:
saveLogMsg("Training with encoder and decoder...")
# Batch size and epoch count come from config; the loss is logged after every epoch.
encoder, decoder = train(
    encoder,
    decoder,
    train_x,
    train_y,
    batch_size=config['batch'],
    epochs=config['epoch'],
    print_every=1,
)

In [None]:
# Store both state dicts in one checkpoint file under the keys 'encoder' and 'decoder'.
torch.save(
    {
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
    },
    config['checkpoint'],
)
saveLogMsg("Saved model as {}...".format(config['checkpoint']))

In [None]:
# Decode the validation inputs with `evaluate` (the notebook defines no `predict`
# function) and compare each predicted string with the joined gold letters.
predictions = evaluate(encoder, decoder, valid_x, valid_y, batch_size=16)
groundtruth = [''.join(t) for t in valid_out]

# accuracy_score on strings counts a sample as correct only on an exact match.
accs = accuracy_score(groundtruth, predictions)
saveLogMsg("accuracy_score = {}...".format(accs))