In [None]:
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
from torch.utils.data.dataloader import _use_shared_memory
# Always print arrays in full. A NaN threshold is rejected by current NumPy
# releases (ValueError), while +inf is accepted and never triggers summarization.
np.set_printoptions(threshold=np.inf)
import pdb

In [None]:
# Global switch read by the DataLoader / model-placement cells.
use_cuda = True

# Prefix prepended to every dataset file name ("" = current working directory).
path = ""

# Each split is iterated utterance-by-utterance later on, so the files are
# presumably pickled object arrays. NumPy >= 1.16.3 refuses to load those unless
# allow_pickle=True is passed explicitly (older releases defaulted to True).
# NOTE(review): unpickling executes arbitrary code -- only load trusted files.
trainX = np.load(path + 'train.npy', encoding='bytes', allow_pickle=True)
trainY = np.load(path + 'train_transcripts.npy', encoding='bytes', allow_pickle=True)

valX = np.load(path + 'dev.npy', encoding='bytes', allow_pickle=True)
valY = np.load(path + 'dev_transcripts.npy', encoding='bytes', allow_pickle=True)

In [None]:
def get_charmap(corpus):
    """Build the vocabulary of ``corpus``.

    Returns a ``(chars, charmap)`` pair: ``chars`` is the sorted list of the
    distinct items found in ``corpus`` and ``charmap`` maps each of those items
    to its position in ``chars``.
    """
    vocabulary = sorted(set(corpus))
    lookup = dict(zip(vocabulary, range(len(vocabulary))))
    return vocabulary, lookup


def map_corpus(corpus, charmap):
    """Encode ``corpus`` as a ``torch.IntTensor`` of indices looked up in ``charmap``."""
    indices = []
    for symbol in corpus:
        indices.append(int(charmap[symbol]))
    return torch.IntTensor(indices)


# Concatenate every training transcript (space separated) and derive the
# character vocabulary from the result.
corpus = " ".join(trainY)
chars, charmap = get_charmap(corpus)
# The whole corpus encoded as one index tensor.
array = map_corpus(corpus, charmap)
# chars holds distinct entries only, so the mapping has exactly as many keys.
charcount = len(charmap)
print("Unique character count: {}".format(charcount))

In [None]:
# Inspect the vocabulary: the sorted character list and its character -> index mapping.
print(chars)
print(charmap)

In [None]:
class SpeechDataset(Dataset):
    """Dataset pairing utterance feature tensors with encoded transcripts.

    Args:
        x: iterable of NumPy arrays, one per utterance; each is converted to a
            float tensor when the dataset is built.
        y: iterable of transcripts, encoded with ``map_corpus`` and the
            module-level ``charmap``. May be ``None`` when no transcripts
            exist (e.g. unlabeled evaluation data).
        evaluate: when truthy, ``__getitem__`` returns a placeholder target
            instead of the transcript.
    """

    def __init__(self, x, y, evaluate=False):
        self.x = [torch.from_numpy(utterance).float() for utterance in x]
        # Transcripts are optional so the dataset can also wrap unlabeled data.
        if y is None:
            self.y = None
        else:
            self.y = [map_corpus(utterance, charmap) for utterance in y]

        self.evaluate = evaluate

    def __len__(self):
        """Number of utterances."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for utterance ``idx``.

        The target is a one-element zero tensor when evaluating or when the
        dataset was built without transcripts.
        """
        if self.evaluate or self.y is None:
            # torch.Tensor(1) would hand back uninitialized memory; use a
            # deterministic placeholder of the same shape and dtype instead.
            return self.x[idx], torch.zeros(1)
        return self.x[idx], self.y[idx]

# Wrap the dev and train splits (features first, transcripts second).
valSet = SpeechDataset(valX, valY)
trainSet = SpeechDataset(trainX, trainY)

In [None]:
# Batch size and layer widths shared by the model cells.
batch_size, attention_dim, hidden_dim = 32, 128, 256
# Start/end-of-sequence token ids sit directly after the character ids.
SOS, EOS = charcount, charcount + 1


class SequencePooling(nn.Module):
    """Halve the time resolution of a packed sequence by averaging frame pairs."""

    def __init__(self):
        super(SequencePooling, self).__init__()

    def forward(self, h):
        """Average every two consecutive time steps of ``h``.

        Args:
            h: a ``PackedSequence`` (time-major, i.e. ``batch_first=False``).

        Returns:
            A ``PackedSequence`` whose per-sequence lengths are the original
            lengths floor-divided by two. A trailing odd frame is dropped.
        """
        padded, lengths = pad_packed_sequence(h)  # (time, batch, features)
        # Read the sizes from the tensor itself so the layer works for any
        # batch size (e.g. the last, smaller batch of an epoch) and width.
        max_length, batch, features = padded.size()
        new_length = max_length // 2

        padded = padded.transpose(0, 1)  # (batch, time, features)

        # An odd number of frames cannot be paired up: drop the last one.
        if max_length % 2 == 1:
            padded = padded[:, 0:max_length - 1, :]
        pooled = padded.contiguous().view(batch, new_length, 2, features)
        pooled = torch.mean(pooled, 2)
        pooled = pooled.transpose(0, 1)  # back to (time, batch, features)

        # Plain ints keep pack_padded_sequence happy whether `lengths` came
        # back as a list or as a tensor.
        halved_lengths = [int(length) // 2 for length in lengths]
        return pack_padded_sequence(pooled, halved_lengths, batch_first=False)

In [None]:
class TextEncoder(nn.Module):
    """Bidirectional-LSTM encoder with three pooled stages.

    A first BiLSTM reads the 40-dimensional input frames; it is followed by
    three (sequence pooling -> BiLSTM) stages. Two linear layers project the
    final outputs to attention keys and attention values.
    """

    def __init__(self):
        super(TextEncoder, self).__init__()

        self.bilstm = nn.LSTM(input_size=40, hidden_size=hidden_dim, batch_first=False, bidirectional=True)

        # Each BiLSTM emits 2 * hidden_dim features (forward + backward).
        self.pbilstm1 = nn.LSTM(input_size=hidden_dim*2, hidden_size=hidden_dim, batch_first=False, bidirectional=True)
        self.pbilstm2 = nn.LSTM(input_size=hidden_dim*2, hidden_size=hidden_dim, batch_first=False, bidirectional=True)
        self.pbilstm3 = nn.LSTM(input_size=hidden_dim*2, hidden_size=hidden_dim, batch_first=False, bidirectional=True)

        self.sequencepooling = SequencePooling()

        # linear1 -> attention keys, linear2 -> attention values.
        self.linear1 = nn.Linear(in_features=hidden_dim*2, out_features=attention_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim*2, out_features=attention_dim)

    def forward(self, utterance_inputs, lengths):
        """Encode a padded, time-major batch of utterances.

        Args:
            utterance_inputs: padded inputs, packed with ``batch_first=False``.
            lengths: per-utterance lengths (converted to a NumPy array for
                packing).

        Returns:
            ``(attention_keys, attention_values)``: time-major tensors, each
            with ``attention_dim`` features.

        Raises:
            RuntimeError: if the attention keys contain NaN.
        """
        h = pack_padded_sequence(utterance_inputs, lengths.data.cpu().numpy(), batch_first=False)

        h, _ = self.bilstm(h)

        # Three identical stages: pool the sequence, then run a BiLSTM.
        for pyramid_lstm in (self.pbilstm1, self.pbilstm2, self.pbilstm3):
            h = self.sequencepooling(h)
            h, _ = pyramid_lstm(h)

        h, _ = pad_packed_sequence(h)

        attention_keys = self.linear1(h)
        attention_values = self.linear2(h)

        # NaN is the only value that differs from itself. Fail loudly instead
        # of dropping into an interactive debugger, which stalls unattended runs.
        key_total = attention_keys.data.sum()
        if key_total != key_total:
            raise RuntimeError("TextEncoder produced NaN attention keys")

        return attention_keys, attention_values


In [None]:
class TextDecoder(nn.Module):
    """One-step decoder: embedding followed by a stack of three LSTM cells.

    The recurrent state is kept on the module between calls and is reset to
    the learned initial state whenever ``forward`` is called with
    ``char_index == 0``.
    """

    def __init__(self):
        super(TextDecoder, self).__init__()

        # charcount + 2 accounts for the two extra ids beyond the vocabulary.
        self.embedding = nn.Embedding(num_embeddings=charcount+2, embedding_dim=attention_dim)

        # Learned initial hidden / cell states, one row each; expanded to the
        # batch size at the start of every sequence.
        self.h_0_1 = nn.Parameter(torch.zeros(1, hidden_dim))
        self.c_0_1 = nn.Parameter(torch.zeros(1, hidden_dim))
        self.h_0_2 = nn.Parameter(torch.zeros(1, hidden_dim))
        self.c_0_2 = nn.Parameter(torch.zeros(1, hidden_dim))
        self.h_0_3 = nn.Parameter(torch.zeros(1, attention_dim))
        self.c_0_3 = nn.Parameter(torch.zeros(1, attention_dim))

        # Running state; populated by _reset_state on the first decoding step.
        self.h1 = self.h2 = self.h3 = None
        self.c1 = self.c2 = self.c3 = None

        # Input of the first cell is [embedding ; context], attention_dim each.
        self.cell = nn.LSTMCell(input_size=attention_dim*2, hidden_size=hidden_dim)
        self.cell2 = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim)
        self.cell3 = nn.LSTMCell(input_size=hidden_dim, hidden_size=attention_dim)

    def _reset_state(self, batch):
        """Expand the learned initial states to ``batch`` rows."""
        self.h1 = self.h_0_1.expand(batch, -1)
        self.h2 = self.h_0_2.expand(batch, -1)
        self.h3 = self.h_0_3.expand(batch, -1)

        self.c1 = self.c_0_1.expand(batch, -1)
        self.c2 = self.c_0_2.expand(batch, -1)
        self.c3 = self.c_0_3.expand(batch, -1)

    def forward(self, inputs, previous_context, char_index):
        """Advance the decoder by one step.

        Args:
            inputs: token ids for this step, one per batch element.
            previous_context: attention context from the previous step, either
                ``(batch, 1, attention_dim)`` or ``(batch, attention_dim)``.
            char_index: position in the output sequence; ``0`` resets the state.

        Returns:
            The hidden state of the last cell, ``(batch, attention_dim)``.
        """
        # Use the actual batch size so smaller (final) batches work too.
        batch = inputs.size(0)

        embedded = self.embedding(inputs)

        # Flatten explicitly: a bare squeeze() would also remove the batch
        # dimension when the batch holds a single element.
        previous_context = previous_context.contiguous().view(batch, -1)

        h = torch.cat((embedded, previous_context), 1)

        if char_index == 0 or self.h1 is None:
            self._reset_state(batch)

        self.h1, self.c1 = self.cell(h, (self.h1, self.c1))

        self.h2, self.c2 = self.cell2(self.h1, (self.h2, self.c2))

        self.h3, self.c3 = self.cell3(self.h2, (self.h3, self.c3))

        return self.h3

In [None]:
class LasModel(nn.Module):
    """Attention-based encoder/decoder mapping utterances to character logits."""

    def __init__(self):
        super(LasModel, self).__init__()

        self.encoder = TextEncoder()
        self.decoder = TextDecoder()

        # The MLP input is [context ; query], each attention_dim wide.
        self.linearMlp = nn.Linear(in_features=attention_dim*2, out_features=attention_dim)

        self.lrelu = nn.LeakyReLU()

        # One logit per character plus the two extra token ids.
        self.linear = nn.Linear(in_features=attention_dim, out_features=charcount+2)

    def forward(self, utterance_inputs, lengths, transcript_inputs, transcript_targets, transcript_lengths):
        """Run the model with teacher forcing on every step.

        Args:
            utterance_inputs: padded, time-major utterance batch for the encoder.
            lengths: per-utterance frame counts before pooling.
            transcript_inputs: ``(batch, max_transcript_length)`` input token ids.
            transcript_targets: unused here; kept for interface compatibility.
            transcript_lengths: unused here; kept for interface compatibility.

        Returns:
            Logits of shape ``(batch, max_transcript_length, charcount + 2)``.
        """
        attention_keys, attention_values = self.encoder(utterance_inputs, lengths)

        # Encoder outputs are time-major: keys -> (batch, dim, time),
        # values -> (batch, time, dim).
        attention_keys = attention_keys.transpose(0, 1).transpose(1, 2)
        attention_values = attention_values.transpose(0, 1)

        # Take sizes from the encoder output rather than from global constants,
        # so partial batches work and tensors land on the encoder's device.
        current_batch = attention_values.size(0)
        encoded_length = attention_values.size(1)
        context_dim = attention_values.size(2)

        context = Variable(attention_values.data.new(current_batch, 1, context_dim).zero_())

        max_transcript_length = transcript_inputs.size(1)

        # 1 for real encoder frames, 0 for padding. The encoder halves the
        # sequence three times, hence the division by 8.
        mask = attention_values.data.new(current_batch, 1, encoded_length).zero_()
        for i in range(current_batch):
            utterance_length = int(lengths[i]) // 8
            mask[i, 0, 0:utterance_length] = 1
        mask = Variable(mask)

        out = []
        for i in range(max_transcript_length):
            # Teacher forcing: always feed the reference token for step i.
            step_inputs = transcript_inputs[:, i]

            query = self.decoder(step_inputs, context, i)

            energy = torch.bmm(query.unsqueeze(1), attention_keys)

            # Softmax over all frames, zero the padded ones, then renormalize.
            attention = F.softmax(energy, dim=2) * mask
            attention = attention / torch.sum(attention, 2).unsqueeze(2)

            context = torch.bmm(attention, attention_values).squeeze(1)

            mlp_input = torch.cat((context, query), 1)

            logit = self.linear(self.lrelu(self.linearMlp(mlp_input)))

            out.append(logit)

        return torch.stack(out, 1)

In [None]:
def my_collate(batch):
    """Collate a list of ``(utterance, transcript)`` samples into padded tensors.

    Samples are ordered by decreasing utterance length (first dimension of the
    utterance tensor). Returns a 5-tuple:

    * ``utterance_inputs``: zero-padded utterances, laid out as
      ``(max_utterance_length, batch, 40)`` after the final transposes.
    * ``lengths``: the sorted utterance lengths (``IntTensor``).
    * ``transcript_inputs``: ``SOS`` followed by the transcript, zero-padded.
    * ``transcript_targets``: the transcript followed by ``EOS``, zero-padded.
    * ``transcript_lengths``: transcript length + 1 for each sample.

    Note that the local ``batch_size`` shadows the module-level constant and
    reflects the actual number of samples in ``batch``.
    """
    batch_size = len(batch)

        
    # Number of frames of every utterance, in the incoming (unsorted) order.
    lengths = [utterance[0].size(0) for utterance in batch]

    lengths = torch.IntTensor(lengths)
    
    # length_idx[k] is the position in `batch` of the k-th longest utterance.
    lengths, length_idx = lengths.sort(0, descending=True)

    max_utterance_length = lengths[0]
    
    transcript_lengths = [utterance[1].size(0) for utterance in batch]
    
    # +1 leaves room for the SOS (inputs) / EOS (targets) token.
    max_transcript_length = max(transcript_lengths)+1
    
    # NOTE(review): `_use_shared_memory` is a private flag imported from
    # torch.utils.data.dataloader, and this branch goes through
    # torch.*Storage.storage()._new_shared(...). It is presumably meant to
    # allocate the outputs in shared memory inside worker processes -- confirm
    # that it behaves as intended on the installed PyTorch version.
    if _use_shared_memory:
        utterance_inputs = torch.FloatStorage.storage()._new_shared(max_utterance_length*batch_size*40).new(batch_size, 40, max_utterance_length).zero_()
        transcript_inputs = torch.LongStorage.storage()._new_shared(batch_size*max_transcript_length).new(batch_size, max_transcript_length).zero_()
        transcript_targets = torch.LongStorage.storage()._new_shared(batch_size*max_transcript_length).new(batch_size, max_transcript_length).zero_()
        transcript_lengths = torch.IntStorage.storage()._new_shared(batch_size).new(batch_size,).zero_()
    else:
        # Zero-filled buffers; `transcript_lengths` is rebound from the list
        # above to a tensor that is filled in the loop below.
        utterance_inputs = torch.FloatTensor(batch_size, 40, max_utterance_length).zero_()
        transcript_inputs = torch.LongTensor(batch_size, max_transcript_length).zero_()
        transcript_targets = torch.LongTensor(batch_size, max_transcript_length).zero_()
        transcript_lengths = torch.IntTensor(batch_size,).zero_()

    for idx, sorted_idx in enumerate(length_idx):
        utterance_input = batch[sorted_idx][0] # (length, 40); transposed to (40, length) when stored
        utterance_inputs[idx, :, 0:utterance_input.size(0)] = torch.transpose(utterance_input, 0 , 1)
        
        mapped_transcript = batch[sorted_idx][1]
        transcript_lengths[idx] = mapped_transcript.size(0) + 1
        
        # Decoder input: SOS, c_0, ..., c_{n-1}
        transcript_inputs[idx, 1:mapped_transcript.size(0)+1] = mapped_transcript
        transcript_inputs[idx, 0] = SOS
        
        # Decoder target: c_0, ..., c_{n-1}, EOS
        transcript_targets[idx, 0:mapped_transcript.size(0)] = mapped_transcript
        transcript_targets[idx, mapped_transcript.size(0)] = EOS
    
    # (batch, 40, time) -> (time, 40, batch) -> (time, batch, 40)
    utterance_inputs = utterance_inputs.transpose(0,2).transpose(1,2)
    
    return utterance_inputs, lengths, transcript_inputs, transcript_targets, transcript_lengths

# pin_memory simply mirrors the use_cuda switch.
pin_memory = True if use_cuda else False
print("using cuda: ", use_cuda)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    trainSet,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=pin_memory,
    collate_fn=my_collate
)
val_loader = torch.utils.data.DataLoader(
    valSet,
    batch_size=batch_size,
    num_workers=2,
    pin_memory=pin_memory,
    collate_fn=my_collate
)




In [None]:
class CrossEntropyLoss3D(nn.CrossEntropyLoss):
    """Unreduced cross-entropy for inputs carrying the class scores on dim 2.

    The input is flattened to ``(-1, classes)`` and the target to ``(-1,)``
    before the parent loss is applied, so one loss value per element comes back.
    """

    def __init__(self):
        # reduce=False: keep the individual losses instead of averaging them.
        super(CrossEntropyLoss3D, self).__init__(reduce=False)

    def forward(self, input, target):
        num_classes = input.size(2)
        flat_input = input.view(-1, num_classes)
        flat_target = target.view(-1)
        return super(CrossEntropyLoss3D, self).forward(flat_input, flat_target)

# Build the network, placing it on the GPU when CUDA is enabled.
model = LasModel().cuda() if use_cuda else LasModel()

# Per-element loss; masking and reduction are done by the training cell.
loss_function = CrossEntropyLoss3D()

# Adam with its default settings.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
epochs = 40
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

model.train()

#model.load_state_dict(torch.load('model_epoch_84.pt'))
for epoch in range(epochs):
    epoch_loss = 0.
    for batch_index, sample in enumerate(train_loader):
        optimizer.zero_grad()

        # Wrap the collated tensors and move them to the GPU only when enabled.
        utterance_inputs, lengths, transcript_inputs, transcript_targets, transcript_lengths = [
            Variable(tensor).cuda() if use_cuda else Variable(tensor) for tensor in sample
        ]

        output = model(utterance_inputs, lengths, transcript_inputs, transcript_targets, transcript_lengths)

        # One loss value per (sample, position), flattened.
        loss = loss_function(output, transcript_targets)

        # NaN is the only value that differs from itself. Stop with a clear
        # error rather than opening a debugger, which stalls unattended runs.
        loss_total = loss.data.sum()
        if loss_total != loss_total:
            raise RuntimeError("NaN loss at epoch {}, batch {}".format(epoch, batch_index))

        # Use the real batch size: the last batch of an epoch may be smaller
        # than the configured batch_size.
        current_batch = transcript_targets.size(0)
        maxlength = transcript_targets.size(1)

        # 1 for real target positions (transcript + EOS), 0 for padding;
        # allocated next to the loss so it shares its type and device.
        transcript_mask = loss.data.new(current_batch, maxlength).zero_()
        for i in range(current_batch):
            transcript_length = int(transcript_lengths[i])
            transcript_mask[i, 0:transcript_length] = 1
        transcript_mask = Variable(transcript_mask)

        loss = loss.view(current_batch, maxlength) * transcript_mask

        # Sum over time, then average over the batch.
        loss = torch.mean(torch.sum(loss, dim=1))

        loss.backward()

        torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)

        optimizer.step()

        # Accumulate a plain float so the running total is not a tensor.
        batch_loss = float(loss.data.sum())
        epoch_loss += batch_loss

        print("batch_index", batch_index, "loss batch: ", batch_loss, ", epoch loss: ", epoch_loss/(batch_index+1))

    print("loss epoch: ", epoch_loss/len(train_loader))
    torch.save(model.state_dict(), "model_epoch_"+str(epoch)+".pt")

In [None]:
# Write the current model weights (state_dict) to "model_epoch_last.pt".
torch.save(model.state_dict(), "model_epoch_last.pt")