In [1]:
import numpy as np
import string
import time
import matplotlib.pyplot as plt
import numpy as np
import Levenshtein as L
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import os
# Restrict this process to physical GPU 1 (it is then addressed as "cuda");
# the variable only takes effect if set before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Device string used for every model and tensor in the notebook.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

In [2]:
# The utterance / transcript arrays are pickled object arrays (hence
# encoding='latin1' for the feature files). NumPy >= 1.16.3 refuses to load
# object arrays unless allow_pickle=True is passed explicitly.
# NOTE(review): unpickling can execute code — only load trusted dataset files.
train_data = np.load('./dataset/train.npy', encoding='latin1', allow_pickle=True)
train_transcript_raw = np.load('./dataset/train_transcripts.npy', allow_pickle=True)
dev_data = np.load('./dataset/dev.npy', encoding='latin1', allow_pickle=True)
dev_transcript_raw = np.load('./dataset/dev_transcripts.npy', allow_pickle=True)
test_data = np.load('./dataset/test.npy', encoding='latin1', allow_pickle=True)

# Output vocabulary. Index 0 ('') is used as the start/end marker of every
# transcript and the last entry (' ') as the word separator (see preprocess_char).
char_set = ['',"'",'+','-','.','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','_',' ']

In [3]:
def preprocess_char(transcripts, charset=None):
    """Convert word-level transcripts into character-index sequences.

    Each transcript is a sequence of words (``bytes`` decoded as UTF-8, or
    ``str``). Words are spelled out character by character and joined by the
    separator index (the last entry of ``charset``). The whole sequence is
    wrapped in index 0 (the empty-string marker) at both ends.

    Args:
        transcripts: iterable of transcripts, each a sized sequence of words.
        charset: character vocabulary; defaults to the module-level
            ``char_set``. Index 0 must be the start/end marker and the last
            entry the word separator.

    Returns:
        1-D numpy object array whose i-th element is the integer index array
        of transcript i. The result is always object dtype, even when every
        transcript has the same length. ``np.array`` on equal-length lists
        would build a 2-D int array, and on ragged lists modern NumPy raises.

    Raises:
        ValueError: if a character is not present in ``charset``.
    """
    if charset is None:
        charset = char_set
    empty_idx, space_idx = 0, len(charset) - 1
    # O(1) lookup; setdefault keeps the first occurrence, like list.index.
    char_to_idx = {}
    for idx, ch in enumerate(charset):
        char_to_idx.setdefault(ch, idx)

    sequences = []
    for line in transcripts:
        seq = [empty_idx]
        for word in line:
            text = word.decode('utf-8') if isinstance(word, bytes) else word
            for ch in text:
                try:
                    seq.append(char_to_idx[ch])
                except KeyError:
                    raise ValueError('character %r is not in charset' % ch) from None
            seq.append(space_idx)
        if len(line) > 0:
            seq.pop()  # no separator after the last word
        seq.append(empty_idx)
        sequences.append(np.array(seq))

    # Fill element by element so numpy never tries to broadcast the rows.
    result = np.empty(len(sequences), dtype=object)
    for i, seq in enumerate(sequences):
        result[i] = seq
    return result

# Character-index sequences (with start/end markers) for the labelled splits.
train_transcript, dev_transcript = (
    preprocess_char(train_transcript_raw),
    preprocess_char(dev_transcript_raw),
)

In [4]:
class LASDataLoader(DataLoader):
    """Batch iterator over variable-length utterances (and transcripts).

    Only the ``DataLoader`` name is inherited. ``DataLoader.__init__`` is never
    called, and batching is implemented directly in ``__iter__``.

    Every batch is sorted by decreasing length, as ``pack_padded_sequence``
    expects. What each batch yields depends on the data:

    - Labelled data: ``(utterances, utter_order, trans_inputs, trans_labels, trans_order)``.
    - Unlabelled data (``transcript=None``): ``(utterances, utter_order)``.

    ``utter_order[j]`` / ``trans_order[j]`` is the pre-sort position, within
    the batch, of the j-th sorted element. The last batch may be smaller
    than ``batch_size``.

    Args:
        data: sequence of per-utterance feature arrays.
        transcript: sequence of per-utterance index arrays, or ``None``.
        batch_size: number of utterances per batch.
        shuffle: draw a fresh random order every epoch if true.
    """

    def __init__(self, data, transcript, batch_size, shuffle=True):
        # Fresh object arrays: the caller's arrays are left untouched, and
        # fancy indexing in __iter__ works even if all items share a length.
        self.data = self._to_object_array(data, torch.Tensor)
        self.transcript = (None if transcript is None
                           else self._to_object_array(transcript, torch.LongTensor))
        self.batch_size = batch_size
        self.shuffle = shuffle

    @staticmethod
    def _to_object_array(items, convert):
        """Return a 1-D object array holding ``convert(item)`` for each item."""
        out = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            out[i] = convert(item)
        return out

    @staticmethod
    def _sorted_by_length(seqs):
        """Indices of ``seqs`` by decreasing length (stable for equal lengths)."""
        return sorted(range(len(seqs)), key=lambda j: len(seqs[j]), reverse=True)

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        n = len(self.data)
        order = np.random.permutation(n) if self.shuffle else np.arange(n)
        # Step through the (possibly shuffled) order; the last slice may be short.
        for start in range(0, n, self.batch_size):
            idx = order[start:start + self.batch_size]
            utters = self.data[idx]
            utter_order = self._sorted_by_length(utters)
            batch_utter_list = [utters[j] for j in utter_order]

            if self.transcript is None:  # test data: no labels
                yield batch_utter_list, utter_order
                continue

            transcripts = self.transcript[idx]
            trans_order = self._sorted_by_length(transcripts)
            # Teacher-forcing pair: inputs drop the final token, labels drop the first.
            batch_trans_data = [transcripts[j][:-1] for j in trans_order]
            batch_trans_label = [transcripts[j][1:] for j in trans_order]
            yield batch_utter_list, utter_order, batch_trans_data, batch_trans_label, trans_order

In [5]:
from torch.nn import Parameter
class WeightDrop(torch.nn.Module):
    """DropConnect on selected weight tensors of a wrapped module.

    Appears to follow the AWD-LSTM ``WeightDrop`` wrapper. Each name in
    ``weights`` is removed from the wrapped module's registered parameters
    and re-registered as ``<name>_raw``. On every forward pass, a
    dropout-masked copy of the raw weight is assigned back under the original
    name before the wrapped module's ``forward`` runs.

    Args:
        module: module to wrap (used here with ``nn.LSTM`` and ``nn.LSTMCell``).
        weights: names of parameters on ``module`` to apply dropout to.
        dropout: drop probability passed to ``F.dropout``.

    NOTE(review): this assigns a plain tensor over a deleted parameter, which
    depends on the internals of the PyTorch version it was written for. Newer
    RNN modules cache a flat weight list for cuDNN and may not see the
    reassigned tensor. Confirm against the PyTorch version in use.
    """

    def __init__(self, module, weights, dropout=0):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        """No-op stand-in for the wrapped RNN's ``flatten_parameters``.

        Declared without ``self``: when called as a bound method, the instance
        lands in ``*args`` and is ignored.
        """
        # We need to replace flatten_parameters with a function that does nothing.
        # It must be a named function rather than a lambda, as otherwise pickling fails.
        # (The whimsical name is inherited from the original AWD-LSTM code.)
        return

    def _setup(self):
        """Disable weight flattening for RNNs and re-register the target weights as ``*_raw``."""
        # Workaround for cuDNN RNN weight compacting: flatten_parameters would
        # otherwise try to re-pack weights that are no longer registered parameters.
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            # Remove the original parameter; the learnable copy lives on as <name>_raw.
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        """Assign dropout-masked copies of the ``*_raw`` weights to the original names."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            # training=self.training: dropout is only applied in train mode.
            w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        """Refresh the dropped weights, then run the wrapped module's ``forward``.

        Note that this calls ``module.forward`` directly, bypassing the wrapped
        module's ``__call__`` (and therefore its hooks).
        """
        self._setweights()
        return self.module.forward(*args)

In [6]:
from torch.autograd import Variable

# Referenced from https://github.com/salesforce/awd-lstm-lm

class LockedDropout(nn.Module):
    """Variational ("locked") dropout: one mask shared along dimension 0.

    A single Bernoulli mask of shape ``(1, x.size(1), x.size(2))`` is drawn
    per call and broadcast over dimension 0, so every slice ``x[t]`` is
    dropped identically. For seq-first input ``(seq, batch, feature)`` this
    reuses the same mask at every time step, as in AWD-LSTM. Surviving
    activations are scaled by ``1 / (1 - dropout)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dropout=0.5):
        """Apply locked dropout with drop probability ``dropout`` to 3-D ``x``.

        Returns ``x`` unchanged in eval mode or when ``dropout`` is falsy.
        """
        if not self.training or not dropout:
            return x
        keep = 1 - dropout
        # new_empty has x's dtype/device and is not tracked by autograd; this
        # replaces the deprecated Variable(x.data.new(...), requires_grad=False).
        mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(keep) / keep
        return mask.expand_as(x) * x

In [7]:
class MaskConv(nn.Module):
    """Conv2d -> BatchNorm2d -> clipped ReLU, with padded time steps zeroed.

    Expects ``x`` shaped ``(batch, channels, freq, time)``. Time is the last
    axis; it is the one the length formula and the mask apply to. After the
    convolution, every output position at or beyond an item's output length
    is set to 0, so padding does not leak into later layers.
    """

    def __init__(self, input_channel, output_channel, kernel_size, stride, padding):
        super(MaskConv, self).__init__()
        self.conv = nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(output_channel)
        # Hardtanh(0, 20) is a ReLU clipped at 20 (the attribute name is historical).
        self.tanh = nn.Hardtanh(0, 20, inplace=True)

        nn.init.xavier_normal_(self.conv.weight)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x, input_length):
        """Return ``(masked_output, output_length)``.

        Args:
            x: ``(batch, channels, freq, time)`` input.
            input_length: integer tensor of valid time steps per item.

        Returns:
            The activated, masked conv output and the per-item output lengths
            along the time axis (same device and dtype as ``input_length``).
        """
        pad, kernel, stride = self.conv.padding[1], self.conv.kernel_size[1], self.conv.stride[1]
        # Floor division keeps lengths integral. True division ``/`` (the
        # modern default for integer tensors) yields floats that break masking.
        output_length = (input_length + 2 * pad - kernel) // stride + 1

        x = self.tanh(self.bn(self.conv(x)))
        positions = torch.arange(x.size(3), device=x.device).unsqueeze(0)      # (1, T)
        lengths = output_length.to(x.device).view(-1, 1)                       # (B, 1)
        pad_mask = positions >= lengths                                         # (B, T) bool
        x = x.masked_fill(pad_mask[:, None, None, :], 0)
        return x, output_length

In [8]:
# for each sample
class Beam:
    """One partial hypothesis of beam search for a single utterance.

    Attributes:
        beam: list of decoded token indices, starting with the start token.
        log_prob: cumulative log-probability of ``beam``.
        hidden_states: per-decoder-layer ``(h, c)`` pairs after the last token.
        attention: attention context produced together with the last token.
    """

    def __init__(self, beam, log_prob, hidden_states, attention):
        self.beam = beam
        self.log_prob = log_prob
        self.hidden_states = hidden_states
        self.attention = attention


class BeamSearch:
    """Beam-search bookkeeping for decoding one utterance.

    ``top_beams`` holds the live hypotheses (at most ``beam_size``).
    Hypotheses whose last token is 0 (the end marker) are moved to
    ``complete_beams``.
    """

    def __init__(self, beam_size):
        self.beam_size = beam_size
        self.top_beams = []
        self.complete_beams = []

    def last_top_beams(self):
        """Batch the live beams for one decoder step.

        Returns:
            tokens: LongTensor with each beam's last token, placed on the same
                device as the beams' attention vectors.
            hidden: one ``(h, c)`` per decoder layer, with the beams' states
                concatenated along dim 0 (any number of layers is supported).
            attention: the beams' attention vectors concatenated along dim 0.
        """
        beams = self.top_beams
        tokens = [beam.beam[-1] for beam in beams]
        hidden = []
        for layer in range(len(beams[0].hidden_states)):
            h = torch.cat([beam.hidden_states[layer][0] for beam in beams])
            c = torch.cat([beam.hidden_states[layer][1] for beam in beams])
            hidden.append((h, c))
        attention = torch.cat([beam.attention for beam in beams])
        return torch.tensor(tokens, dtype=torch.long, device=attention.device), hidden, attention

    def update_beams(self, probs, attention, hidden_states):
        """Extend the live beams by every token and keep the best extensions.

        Args:
            probs: ``(num_beams, vocab)`` next-token log-probabilities; row i
                belongs to ``top_beams[i]``.
            attention: per-beam attention output of this step (row i for beam i).
            hidden_states: per-layer ``(h, c)`` with one row per beam.

        The ``2 * beam_size`` best extensions are visited in score order.
        Those ending in token 0 go to ``complete_beams``; the rest fill
        ``top_beams`` until ``beam_size`` live beams have been collected.
        """
        num_beams = len(self.top_beams)
        vocab_size = probs.size(1)
        prev = torch.stack([torch.as_tensor(beam.log_prob, dtype=probs.dtype, device=probs.device)
                            for beam in self.top_beams])
        # Score of "beam i + token t" sits at flat index i * vocab_size + t.
        scores = (prev.unsqueeze(1) + probs[:num_beams]).reshape(-1)
        k = min(scores.numel(), 2 * self.beam_size)
        _, top_idx = torch.topk(scores, k)

        next_beams = []
        for flat in top_idx.tolist():
            parent, token = divmod(flat, vocab_size)
            # Only selected candidates materialise their (1, -1)-shaped states.
            hidden = [(h[parent].view(1, -1), c[parent].view(1, -1)) for h, c in hidden_states]
            candidate = Beam(self.top_beams[parent].beam + [token], scores[flat],
                             hidden, attention[parent])
            if token == 0:
                self.complete_beams.append(candidate)
            else:
                next_beams.append(candidate)
                if len(next_beams) == self.beam_size:
                    break
        self.top_beams = next_beams

    def best_beam(self):
        """Return the token list of the best hypothesis.

        Completed beams are preferred; if none completed, the live beams are
        used (and stored as ``complete_beams``). Ranking is by
        length-normalised log-probability.
        """
        if len(self.complete_beams) == 0:
            self.complete_beams = self.top_beams
        best = max(self.complete_beams, key=lambda beam: float(beam.log_prob) / len(beam.beam))
        return best.beam

In [9]:
class pBLSTM(nn.Module):
    """Pyramidal bidirectional LSTM layer.

    Averages each pair of consecutive time steps (halving the sequence length)
    and runs a bidirectional LSTM over the shorter sequence.

    Args:
        input_size: feature size of this layer's input.
        hidden_size: LSTM hidden size per direction (output has 2x features).
        weight_drop: if true, wrap the LSTM in ``WeightDrop`` (p=0.3) on
            ``weight_hh_l0``.
    """

    def __init__(self, input_size, hidden_size, weight_drop):
        super(pBLSTM, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=(2,1))
        self.blstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True, batch_first=True)
        if weight_drop:
            self.blstm = WeightDrop(self.blstm, ['weight_hh_l0'], dropout=0.3)

    def forward(self, x, input_length, hidden_states):
        """Run one pyramid step.

        Args:
            x: ``(batch, seq, feat)`` padded input, sorted by decreasing length.
            input_length: valid lengths per item (left unmodified).
            hidden_states: initial ``(h0, c0)`` for the LSTM.

        Returns:
            ``(output, (h_n, c_n), output_length)``. ``output`` is
            ``(batch, seq // 2, 2 * hidden_size)`` and ``output_length`` is
            ``input_length // 2``.
        """
        batch_size, seq_len, feat_len = x.shape
        half = seq_len // 2
        # Drop a trailing odd frame and group frames in pairs: (B, T/2, 2, F).
        # reshape (not view) also accepts non-contiguous input.
        pairs = x[:, :half * 2, :].reshape(batch_size, half, 2, feat_len)
        # Squeeze only the pooled axis so a batch or time size of 1 survives.
        out = self.pool(pairs).squeeze(2)
        # Out-of-place: the caller's length tensor must not be modified.
        input_length = input_length // 2
        out = nn.utils.rnn.pack_padded_sequence(out, input_length, batch_first=True)
        out, hidden = self.blstm(out, hidden_states)
        out, output_length = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        return out, hidden, output_length

class Listener(nn.Module):
    """Encoder: three masked conv layers followed by a stack of pBLSTMs.

    Args:
        input_size: feature bins per input frame (40 in this notebook).
        hidden_size: LSTM hidden size per direction in every pBLSTM.
        nlayers: number of pBLSTM layers (each halves the time axis).
        key_size, value_size: output sizes of the attention key/value projections.
        weight_drop: whether each pBLSTM applies WeightDrop to its recurrent weights.
    """

    def __init__(self, input_size, hidden_size, nlayers, key_size, value_size, weight_drop=True):
        super(Listener, self).__init__()
        # Learned initial (h0, c0) per layer, shape (2 directions, 1, hidden),
        # broadcast over the batch in forward().
        self.hidden_states = nn.ParameterList([nn.Parameter(torch.zeros(2, 1, hidden_size)) for i in range(nlayers)])
        self.cell_states = nn.ParameterList([nn.Parameter(torch.zeros(2, 1, hidden_size)) for i in range(nlayers)])

        self.conv1 = MaskConv(1, 32, kernel_size=(3, 3), stride=(2,1), padding=(1,1))
        self.conv2 = MaskConv(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1,1))
        self.conv3 = MaskConv(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1,1))
        # conv1 halves the frequency axis (kernel 3, stride 2, pad 1); conv2/3
        # preserve it. For input_size=40 this equals the former constant 32 * 20.
        conv_freq = (input_size + 2 * 1 - 3) // 2 + 1
        rnn_input_size = 32 * conv_freq

        # Registered as plstm_0, plstm_1, ... (names kept for checkpoint compatibility).
        self.plstms = []
        for i in range(nlayers):
            layer_input = rnn_input_size if i == 0 else 2 * hidden_size
            plstm = pBLSTM(layer_input, hidden_size, weight_drop)
            self.plstms.append(plstm)
            self.add_module('plstm_' + str(i), plstm)

        self.lockdrop = LockedDropout()
        self.fc_key_proj = nn.Linear(2 * hidden_size, key_size)
        self.fc_value_proj = nn.Linear(2 * hidden_size, value_size)

    def _locked_dropout(self, x, p):
        """Locked dropout on batch-first ``x`` with one mask per sequence.

        LockedDropout shares its mask along dim 0, so ``x`` is temporarily
        moved to (seq, batch, feat). Applied directly to (batch, seq, feat),
        the mask would be shared across the batch and resampled every time step.
        """
        return self.lockdrop(x.transpose(0, 1), p).transpose(0, 1)

    def forward(self, x, input_length):
        """Encode padded features into attention keys and values.

        Args:
            x: ``(batch, seq, input_size)`` padded features, sorted by decreasing length.
            input_length: valid frame counts per item.

        Returns:
            ``(key, value, output_length)``. ``key`` is ``(batch, seq', key_size)``,
            ``value`` is ``(batch, seq', value_size)`` and ``output_length`` holds
            the lengths after the pyramid reductions.
        """
        x = x.transpose(1, 2).unsqueeze(1)  # (batch, 1, freq, seq)
        x, output_length = self.conv1(x, input_length)
        x, output_length = self.conv2(x, output_length)
        x, output_length = self.conv3(x, output_length)
        batch_size, channels, freq, seq = x.size()
        x = x.view(batch_size, channels * freq, seq).transpose(1, 2)  # (batch, seq, channels*freq)

        initial_states = [(h.expand(2, batch_size, -1).contiguous(),
                           c.expand(2, batch_size, -1).contiguous())
                          for h, c in zip(self.hidden_states, self.cell_states)]
        for plstm, state in zip(self.plstms, initial_states):
            x = self._locked_dropout(x, 0.3)
            x, _, output_length = plstm(x, output_length, state)
        x = self._locked_dropout(x, 0.3)
        return self.fc_key_proj(x), self.fc_value_proj(x), output_length

class Speller(nn.Module):
    """Attention decoder ("speller") of the Listen-Attend-Spell model.

    Each step works as follows:

    1. ``[embedding(prev char); prev attention context]`` is fed through a
       stack of ``nlayers`` LSTMCells.
    2. The top hidden state is projected to a query that attends over the
       listener keys/values.
    3. ``[context; hidden]`` is mapped to vocabulary logits.

    Args:
        vocab_size: number of output characters.
        embed_size: character embedding size.
        hidden_size: LSTMCell hidden size (all layers).
        key_size, value_size: sizes of the listener's attention keys/values.
        nlayers: number of stacked LSTMCells.
        weight_drop: wrap each LSTMCell's ``weight_hh`` in WeightDrop (p=0.3).
    """

    def __init__(self, vocab_size, embed_size, hidden_size, key_size, value_size, nlayers, weight_drop=False):
        super(Speller, self).__init__()
        self.nlayers = nlayers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.projection = nn.Linear(hidden_size, key_size)
        self.softmax = nn.Softmax(dim=2)
        self.lockdrop = LockedDropout()

        # Learned initial (h0, c0) for every layer, shape (1, hidden_size).
        self.hidden_states = nn.ParameterList([nn.Parameter(torch.zeros(1, hidden_size)) for i in range(nlayers)])
        self.cell_states = nn.ParameterList([nn.Parameter(torch.zeros(1, hidden_size)) for i in range(nlayers)])

        self.rnn_cell_0 = nn.LSTMCell(input_size=embed_size+value_size, hidden_size=hidden_size)
        if weight_drop:
            self.rnn_cell_0 = WeightDrop(self.rnn_cell_0, ['weight_hh'], dropout=0.3)

        # Upper layers, registered as rnn_cell_1, rnn_cell_2, ... (checkpoint key names).
        self.rnn_cells = []
        for i in range(nlayers-1):
            rnn_cell = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
            if weight_drop:
                rnn_cell = WeightDrop(rnn_cell, ['weight_hh'], dropout=0.3)
            self.rnn_cells.append(rnn_cell)
            self.add_module('rnn_cell_'+str(i+1), rnn_cell)

        self.fc = nn.Linear(hidden_size+value_size, vocab_size)

    @staticmethod
    def _utter_mask(batch_size, max_len, lengths):
        """Float mask ``(batch_size, max_len)``: 1 at valid encoder steps, 0 at padding."""
        mask = torch.ones(batch_size, max_len, device=DEVICE)
        for i, length in enumerate(lengths):
            length = int(length)
            if length < max_len:
                mask[i, length:] = 0
        return mask

    def _prepare_attention(self, att_key, att_value, utter_input_length):
        """Zero padded encoder steps and return ``(keys^T, values, mask)``.

        ``keys^T`` is ``(batch, key_size, enc_len)``, ready for ``bmm`` with the
        query. ``mask`` is ``(batch, enc_len)`` with 1 at valid positions.
        """
        batch_size, max_len = att_key.size(0), att_key.size(1)
        utter_mask = self._utter_mask(batch_size, max_len, utter_input_length)
        att_key = (att_key * utter_mask.unsqueeze(2)).transpose(1, 2)
        att_value = att_value * utter_mask.unsqueeze(2)
        return att_key, att_value, utter_mask

    def step(self, h, last_hidden_states, att_key, att_value, utter_mask):
        """Run one decoder step.

        Args:
            h: ``(batch, embed_size + value_size)`` step input.
            last_hidden_states: per-layer ``(h, c)`` from the previous step.
            att_key: ``(batch, key_size, enc_len)`` transposed keys.
            att_value: ``(batch, enc_len, value_size)`` values.
            utter_mask: ``(batch, enc_len)``, 1 at valid encoder steps.

        Returns:
            ``(logits, energy, att, hidden_states)``:

            - ``logits``: ``(batch, vocab)``.
            - ``energy``: attention weights ``(batch, 1, enc_len)``.
            - ``att``: context ``(batch, value_size)``.
            - ``hidden_states``: the new per-layer states.
        """
        # NOTE(review): on a (batch, 1, feat) view LockedDropout shares one mask
        # across the batch and resamples it each step; it is not time-locked here.
        h = self.lockdrop(h.unsqueeze(1), 0.2).squeeze(1)
        h, c = self.rnn_cell_0(h, last_hidden_states[0])

        hidden_states = [(h,c)]
        for i, rnn_cell in enumerate(self.rnn_cells):
            h = self.lockdrop(h.unsqueeze(1), 0.2).squeeze(1)
            h, c = rnn_cell(h, last_hidden_states[i+1])
            hidden_states.append((h,c))

        h = self.lockdrop(h.unsqueeze(1), 0.2).squeeze(1)
        query = self.projection(h).unsqueeze(1)
        energy = torch.bmm(query, att_key)
        energy = self.softmax(energy)

        # Zero the weights on padded encoder steps and renormalise over valid ones.
        att_mask = utter_mask.unsqueeze(1)
        energy = energy * att_mask
        sum_ = energy.sum(dim=2, keepdim=True)
        energy = energy.div(sum_)

        att = torch.bmm(energy, att_value).squeeze(1)
        concat_h = torch.cat((att, h), dim=1)
        logits = self.fc(concat_h)
        return logits, energy, att, hidden_states

    def forward(self, x, att_key, att_value, utter_input_length, teacher_force=1.0, plot=False):
        """Decode a full target sequence, optionally with scheduled sampling.

        Args:
            x: ``(batch, target_len)`` padded input char indices.
            att_key, att_value: listener outputs ``(batch, enc_len, key/value size)``.
            utter_input_length: valid encoder length per item.
            teacher_force: probability of feeding the ground-truth char at a step.
                Otherwise the char is taken from a Gumbel-softmax sample of the
                previous step's logits (or the start token 0 at step 0).
            plot: if true, show attention maps after decoding.

        Returns:
            ``(batch, target_len, vocab_size)`` logits.
        """
        batch_size = att_key.size(0)
        att_key, att_value, utter_mask = self._prepare_attention(att_key, att_value, utter_input_length)
        target_len = x.shape[1]

        x = self.embed(x)
        assert(x.size(0) == att_value.size(0))
        att = torch.zeros(att_value.size(0), att_value.size(2)).to(DEVICE)
        hidden_states = [(hidden_state.expand(batch_size, -1), cell_state.expand(batch_size, -1))
                         for hidden_state, cell_state in zip(self.hidden_states, self.cell_states)]

        output = []
        att_list = []
        for i in range(target_len):
            indicator = np.random.random()
            if indicator > teacher_force:
                # use predicted chars
                if i == 0:
                    pred = torch.zeros(batch_size).long().to(DEVICE)
                else:
                    scores = F.gumbel_softmax(logits)
                    _, pred = torch.max(scores, dim=1)
                    pred = pred.long()
                pred = self.embed(pred)
                inp = torch.cat((pred, att), dim=-1)
            else:
                inp = torch.cat((x[:, i, :], att), dim=-1)
            logits, energy, att, hidden_states = self.step(inp, hidden_states, att_key, att_value, utter_mask)
            output.append(logits.unsqueeze(1))
            if plot:
                att_list.append(energy.squeeze(1).cpu().detach().numpy())
        if plot:
            self.plot_attention(att_list, utter_input_length)
        return torch.cat(output, dim=1)

    def generate(self, att_key, att_value, utter_input_length, beam_search=True, beam_size=32):
        """Decode every utterance independently with beam search.

        Args:
            att_key, att_value: listener outputs ``(batch, enc_len, key/value size)``.
            utter_input_length: valid encoder length per item.
            beam_search: if False, decode with a beam of width 1 (greedy-like).
                This flag was previously accepted but ignored.
            beam_size: beam width used when ``beam_search`` is true.

        Returns:
            One decoded token-index list per batch item, each starting with 0.
        """
        batch_size = att_key.size(0)
        att_key, att_value, utter_mask = self._prepare_attention(att_key, att_value, utter_input_length)
        att = torch.zeros(att_value.size(0), att_value.size(2)).to(DEVICE)
        width = beam_size if beam_search else 1

        output = []
        for batch in range(batch_size):
            beam_searcher = BeamSearch(width)
            init_hidden_states = [(hidden_state, cell_state) for hidden_state, cell_state
                                  in zip(self.hidden_states, self.cell_states)]
            beam_searcher.top_beams.append(Beam([0], 0., init_hidden_states, att[batch]))

            step = 0
            # Stop when enough hypotheses ended, or after 300 decoding steps.
            while len(beam_searcher.complete_beams) < width and step < 300:
                x, hidden_states, sample_att = beam_searcher.last_top_beams()
                x = self.embed(x)
                num_beams = x.size(0)
                sample_att = sample_att.view(num_beams, -1)
                inp = torch.cat((x, sample_att), dim=-1)

                # Same encoder memory for every live beam of this utterance.
                sample_att_key = att_key[batch].expand(num_beams, -1, -1)
                sample_att_value = att_value[batch].expand(num_beams, -1, -1)
                sample_utter_mask = utter_mask[batch].expand(num_beams, -1)

                logits, energy, sample_att, hidden_states = self.step(inp, hidden_states, sample_att_key,
                                                                      sample_att_value, sample_utter_mask)
                scores = F.log_softmax(logits, dim=1)
                beam_searcher.update_beams(scores, sample_att, hidden_states)
                step += 1
            output.append(beam_searcher.best_beam())
            print ('finished with sample %d' % (batch+1))
        return output

    def plot_attention(self, attention_list, utter_length):
        """Show one attention heat-map per batch item.

        Args:
            attention_list: one ``(batch, enc_len)`` array per decoder step.
            utter_length: valid encoder length for each batch item.
        """
        print (attention_list[0].shape)
        # Iterate over the actual batch, which may be smaller than BATCH_SIZE.
        for i in range(attention_list[0].shape[0]):
            # (decoder_steps, enc_len) for item i; keep only valid encoder columns.
            attention = np.array([att[i] for att in attention_list])
            attention = attention[:, :int(utter_length[i])]
            attention /= np.linalg.norm(attention, axis=1, keepdims=True)
            plt.imshow(attention, cmap='hot')
            plt.show()
            plt.show()

In [10]:
class Trainer:
    """Train, validate and run inference for a Listener/Speller pair.

    Both models share one Adam optimizer (lr 3e-4, weight decay 1e-6) and a
    cross-entropy criterion. ``train_losses`` / ``val_losses`` collect the
    mean per-epoch cross-entropy, which is reported as perplexity ``exp(loss)``.
    """

    def __init__(self, listener, speller, train_loader, dev_loader, test_loader, max_epochs=1, run_id='exp'):
        self.train_loader = train_loader
        self.dev_loader = dev_loader
        self.test_loader = test_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id

        self.listener = listener.to(DEVICE)
        self.speller = speller.to(DEVICE)

        params = list(listener.parameters()) + list(speller.parameters())
        self.optimizer = torch.optim.Adam(params, lr=3e-4, weight_decay=1e-6)
        self.criterion = nn.CrossEntropyLoss()

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def _encode(self, batch_utter, utter_order, target_order):
        """Run the listener on a length-sorted batch and reorder its outputs.

        ``batch_utter`` is sorted as described by ``utter_order`` (as yielded by
        the loader). In the returned keys, values and lengths, row j belongs to
        the batch element whose original position is ``target_order[j]``.
        """
        utter_input_length = torch.IntTensor([len(utter) for utter in batch_utter])
        padded = nn.utils.rnn.pad_sequence(batch_utter, batch_first=True).to(DEVICE)
        att_key, att_value, utter_output_length = self.listener(padded, utter_input_length)
        reorder = [utter_order.index(i) for i in target_order]
        return att_key[reorder], att_value[reorder], utter_output_length[reorder]

    def train_batch(self, batch_utter, utter_order, batch_trans_data, batch_trans_label, trans_order,
                    pre_train, teacher_force=1.0, plot=False, generation=False):
        """Return the cross-entropy loss of one batch (no optimizer step).

        With ``pre_train`` the listener is skipped and the speller attends
        over an all-zero dummy encoding of length 100.
        """
        if not pre_train:
            att_key, att_value, utter_output_length = self._encode(batch_utter, utter_order, trans_order)
        else:
            # NOTE(review): 128 must match the speller's key/value sizes.
            random_seq_len = 100
            batch_size = len(batch_trans_data)
            att_key = torch.zeros([batch_size, random_seq_len, 128]).to(DEVICE)
            att_value = torch.zeros([batch_size, random_seq_len, 128]).to(DEVICE)
            # One length per actual item: the last batch may be smaller than BATCH_SIZE.
            utter_output_length = torch.IntTensor([random_seq_len] * batch_size)

        trans_input_length = [len(trans) for trans in batch_trans_data]
        batch_trans_data = nn.utils.rnn.pad_sequence(batch_trans_data, batch_first=True).to(DEVICE)
        # Labels are already sorted by decreasing length, so their packed order
        # matches pack_padded_sequence over the logits below.
        batch_trans_label = nn.utils.rnn.pack_sequence(batch_trans_label).to(DEVICE)
        logits = self.speller(batch_trans_data, att_key, att_value, utter_output_length,
                              teacher_force=teacher_force, plot=plot)

        logits = nn.utils.rnn.pack_padded_sequence(logits, trans_input_length, batch_first=True)
        return self.criterion(logits.data, batch_trans_label.data)

    def train(self, pre_train=False, teacher_force=1.0):
        """Run one epoch over ``train_loader`` and record its mean loss."""
        self.listener.train()
        self.speller.train() # set to training mode
        # Progress denominator (full batches), taken from this trainer's own
        # loader instead of the notebook globals train_data / BATCH_SIZE.
        total_batches = len(self.train_loader.data) // self.train_loader.batch_size
        epoch_loss = 0
        num_batches = 0
        for batch_utter, utter_order, batch_trans_data, batch_trans_label, trans_order in self.train_loader:
            num_batches += 1
            batch_loss = self.train_batch(batch_utter, utter_order,
                                          batch_trans_data, batch_trans_label, trans_order,
                                          pre_train, teacher_force=teacher_force)
            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            batch_loss = batch_loss.item()
            epoch_loss += batch_loss
            if num_batches % 50 == 0:
                print ('[TRAIN] Iter [%d/%d]    Perplexity: %.4f'
                      % (num_batches, total_batches, np.exp(batch_loss)))
            torch.cuda.empty_cache()
        epoch_loss = epoch_loss / num_batches
        self.epochs += 1
        print ('[TRAIN] Epoch [%d/%d]    Perplexity: %.4f'
               % (self.epochs, self.max_epochs, np.exp(epoch_loss)))
        self.train_losses.append(epoch_loss)

    def val(self, pre_trained=False, plot=False):
        """Return the mean cross-entropy over ``dev_loader`` (full teacher forcing).

        With ``plot=True`` only the first batch is run (its attention maps are
        shown) and ``None`` is returned without recording a loss.
        """
        self.listener.eval()
        self.speller.eval()
        epoch_loss = 0
        num_batches = 0
        with torch.no_grad():
            for batch_utter, utter_order, batch_trans_data, batch_trans_label, trans_order in self.dev_loader:
                num_batches += 1
                batch_loss = self.train_batch(batch_utter, utter_order,
                                              batch_trans_data, batch_trans_label, trans_order,
                                              pre_trained, plot=plot)
                epoch_loss += batch_loss.item()
                if plot:
                    return None
            epoch_loss = epoch_loss / num_batches
            print ('[VAL] Epoch [%d/%d]    Perplexity: %.4f'
                   % (self.epochs, self.max_epochs, np.exp(epoch_loss)))
            self.val_losses.append(epoch_loss)
        return epoch_loss

    def test(self, val=False, random_search=False, beam_search=True, output_path='./try_conv_1.12.csv'):
        """Decode with ``Speller.generate``.

        Args:
            val: if true, decode ``dev_loader`` and report the mean Levenshtein
                distance to the references. Otherwise decode ``test_loader``
                and write an ``Id,Predicted`` CSV to ``output_path``.
            random_search: unused.
            beam_search: forwarded to ``Speller.generate``.
            output_path: destination of the test-set CSV.
        """
        self.listener.eval()
        self.speller.eval()
        with torch.no_grad():
            if val:
                L_distance = 0
                num_sentences = 0
                for batch_utter, utter_order, batch_trans_data, _, trans_order in self.dev_loader:
                    att_key, att_value, utter_output_length = self._encode(batch_utter, utter_order, trans_order)
                    output = self.speller.generate(att_key, att_value, utter_output_length, beam_search=beam_search)

                    for i in range(len(batch_trans_data)):
                        pred = ''.join([char_set[c] for c in output[i]])
                        true = ''.join([char_set[c] for c in batch_trans_data[i]])
                        L_distance += L.distance(pred, true)
                        if i == 5:
                            print ("PRED: ", pred)
                            print ("TRUE: ", true)
                    num_sentences += len(batch_trans_data)
                    print (L_distance / num_sentences)

                print ('[VAL] Epoch [%d/%d]    L distance: %.4f'
                      % (self.epochs, self.max_epochs, L_distance / num_sentences))
            else:
                index = 0
                # `with` closes the file even if decoding fails midway;
                # newline='' is what the csv module expects for file objects.
                with open(output_path, 'w', newline='') as f:
                    wr = csv.writer(f, dialect='excel')
                    wr.writerow(['Id', 'Predicted'])
                    for batch_utter, utter_order in self.test_loader:
                        # Restore the loader's original order so Ids follow the test set.
                        att_key, att_value, utter_output_length = self._encode(
                            batch_utter, utter_order, range(len(batch_utter)))
                        output = self.speller.generate(att_key, att_value, utter_output_length, beam_search=beam_search)

                        for i in range(len(batch_utter)):
                            pred = ''.join([char_set[c] for c in output[i]])
                            wr.writerow([index, pred])
                            index += 1
                print ('Finished predictions')

    def save(self, pre_trained=False):
        """Save state dicts to ``./conv_experiments/<run_id>/{listener,speller}-<epoch>.pt``.

        The listener is not saved when ``pre_trained`` is true (speller-only pre-training).
        """
        run_dir = os.path.join('./conv_experiments', self.run_id)
        # torch.save does not create missing directories.
        os.makedirs(run_dir, exist_ok=True)
        listener_path = os.path.join(run_dir, 'listener-{}.pt'.format(self.epochs))
        speller_path = os.path.join(run_dir, 'speller-{}.pt'.format(self.epochs))
        if not pre_trained:
            torch.save(self.listener.state_dict(), listener_path)
        torch.save(self.speller.state_dict(), speller_path)

    def load(self, listener_path, speller_path):
        """Load model weights; the listener is skipped when ``listener_path`` is falsy.

        ``map_location=DEVICE`` lets GPU-saved checkpoints load on a CPU-only machine.
        """
        if listener_path:
            self.listener.load_state_dict(torch.load(listener_path, map_location=DEVICE))
            print ("Loaded trained listener")
        self.speller.load_state_dict(torch.load(speller_path, map_location=DEVICE))
        print ("Loaded trained speller")

In [14]:
# Experiment id: Trainer.save writes checkpoints under ./conv_experiments/<run_id>/.
# For a new training run, the commented-out lines below build a timestamped id
# and create its directory; for inference only, a fixed id is used instead.
# topic = 'train_3by3_no_norm_on_spell_with_pretrain_0.7tf'
# run_id = topic + str(int(time.time()))
# if not os.path.exists('./conv_experiments'):
#     os.mkdir('./conv_experiments')
# os.mkdir('./conv_experiments/%s' % run_id)
# print("Saving models, predictions, and generated words to ./conv_experiments/%s" % run_id)
run_id = 'test'

In [15]:
# NUM_EPOCHS: Trainer.max_epochs; BATCH_SIZE: utterances per batch for every loader;
# PRE_TRAIN: skip the listener and feed the speller an all-zero dummy encoding.
NUM_EPOCHS, BATCH_SIZE, PRE_TRAIN = 20, 64, False

In [16]:
# Models: 3-layer pyramidal listener and 3-layer speller sharing 128-d keys/values.
listener = Listener(input_size=40, hidden_size=256, nlayers=3, key_size=128, value_size=128)
speller = Speller(
    vocab_size=len(char_set),
    embed_size=128,
    hidden_size=512,
    key_size=128,
    value_size=128,
    nlayers=3,
)

# Loaders: test data has no transcripts and is not shuffled, so CSV ids follow its order.
train_loader = LASDataLoader(train_data, train_transcript, BATCH_SIZE)
dev_loader = LASDataLoader(dev_data, dev_transcript, BATCH_SIZE, shuffle=True)
test_loader = LASDataLoader(test_data, None, BATCH_SIZE, shuffle=False)

trainer = Trainer(
    listener, speller, train_loader, dev_loader, test_loader,
    max_epochs=NUM_EPOCHS, run_id=run_id,
)

Applying weight drop of 0.3 to weight_hh_l0
Applying weight drop of 0.3 to weight_hh_l0
Applying weight drop of 0.3 to weight_hh_l0


In [17]:
# trainer.test(val=True)

In [18]:
# Restore the epoch-29 checkpoints, then write test-set predictions to CSV.
listener_checkpoint = './conv_experiments/train_3by3_no_norm_on_spell_with_pretrain_0.7tf1543267208/listener-29.pt'
speller_checkpoint = './conv_experiments/train_3by3_no_norm_on_spell_with_pretrain_0.7tf1543267208/speller-29.pt'
trainer.load(listener_checkpoint, speller_checkpoint)
# trainer.update_lr(1e-4)

# Alternatives: dev-set attention plots, or loading a speller-only pre-trained checkpoint.
# trainer.val(pre_trained=False, plot=True)
# trainer.load(None, './experiments/pretrain_with_learned_hidden1543130180/speller-5.pt')
trainer.test()
# trainer.load(None, './experiments/pretrain_with_learned_hidden1543130180/speller-5.pt')

Loaded trained listener
Loaded trained speller


  result = self.forward(*input, **kwargs)


finished with sample 1
finished with sample 2
finished with sample 3
finished with sample 4
finished with sample 5
finished with sample 6
finished with sample 7
finished with sample 8
finished with sample 9
finished with sample 10
finished with sample 11
finished with sample 12
finished with sample 13
finished with sample 14
finished with sample 15
finished with sample 16
finished with sample 17
finished with sample 18
finished with sample 19
finished with sample 20
finished with sample 21
finished with sample 22
finished with sample 23
finished with sample 24
finished with sample 25
finished with sample 26
finished with sample 27
finished with sample 28
finished with sample 29
finished with sample 30
finished with sample 31
finished with sample 32
finished with sample 33
finished with sample 34
finished with sample 35
finished with sample 36
finished with sample 37
finished with sample 38
finished with sample 39
finished with sample 40
finished with sample 41
finished with sample 42
f

finished with sample 25
finished with sample 26
finished with sample 27
finished with sample 28
finished with sample 29
finished with sample 30
finished with sample 31
finished with sample 32
finished with sample 33
finished with sample 34
finished with sample 35
finished with sample 36
finished with sample 37
finished with sample 38
finished with sample 39
finished with sample 40
finished with sample 41
finished with sample 42
finished with sample 43
finished with sample 44
finished with sample 45
finished with sample 46
finished with sample 47
finished with sample 48
finished with sample 49
finished with sample 50
finished with sample 51
finished with sample 52
finished with sample 53
finished with sample 54
finished with sample 55
finished with sample 56
finished with sample 57
finished with sample 58
finished with sample 59
finished with sample 60
finished with sample 61
finished with sample 62
finished with sample 63
finished with sample 64
finished with sample 1
finished with sam

In [None]:
# Training loop (disabled: this run only loads a checkpoint and predicts).
# NOTE(review): the loop uses `best_error` and `scheduler.step(error)`, which are
# double-commented below. Uncomment `best_error` and the ReduceLROnPlateau line
# as well, or the loop raises NameError.
# # best_error = 1e30  # set to super large value at first
# # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, 'min', factor=0.3, patience=1, verbose=True)
# # scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer, step_size=5, gamma=0.3)
# for epoch in range(NUM_EPOCHS):
# #     scheduler.step()
#     trainer.train(pre_train=PRE_TRAIN, teacher_force=0.7)
#     error = trainer.val(pre_trained=PRE_TRAIN)
#     scheduler.step(error)
#     if error < best_error:
#         best_error = error
#         print("Saving model, predictions and generated output for epoch " + 
#                 str(epoch)+" with Error: " + str(np.exp(best_error)))
#         trainer.save(pre_trained=PRE_TRAIN)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [None]:
# plot training curves
# Plot per-epoch training and validation losses (mean cross-entropy).
# The x-range comes from each list's own length. trainer.epochs only counts
# train() calls, so it disagrees with val_losses when val() runs without
# training, and plt.plot would then fail on the length mismatch.
fig, ax = plt.subplots()
ax.plot(range(1, len(trainer.train_losses) + 1), trainer.train_losses, label='Training losses')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(range(1, len(trainer.val_losses) + 1), trainer.val_losses, label='Validation NLL')
ax.set_xlabel('Epochs')
ax.set_ylabel('NLL')
ax.legend()
plt.show()