In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import *
from phoneme_list import *
from torch.utils import data
from ctcdecode import CTCBeamDecoder
import Levenshtein
from Levenshtein import distance
train_data = np.load('train.npy', allow_pickle=True)
train_labels = np.load('train_labels.npy', allow_pickle=True)
cuda = True

device = torch.device("cuda:0" if cuda else "cpu")

dev_data = np.load('dev.npy', allow_pickle=True)
dev_labels = np.load('dev_labels.npy', allow_pickle=True)



In [2]:
class MyDataset(data.Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        X = self.X[index]
        Y = self.Y[index] + 1
        return torch.from_numpy(X).float(), X.shape[0], torch.from_numpy(Y).long(), Y.shape[0]



In [3]:
num_workers = 8 if cuda else 0
batch_size = 64


In [4]:
def collate(sequences):
    X_list, X_shapes, Y_list, Y_shapes = zip(*sequences)
    return pad_sequence(X_list, batch_first=True), torch.Tensor(np.asarray(X_shapes)), pad_sequence(Y_list, batch_first=True), torch.tensor(np.asarray(Y_shapes))


In [5]:
train_dataset = MyDataset(train_data, train_labels)

train_loader_args = dict(shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                         collate_fn=collate) if cuda else dict(shuffle=True, batch_size=64, collate_fn=collate)


train_loader = data.DataLoader(train_dataset, **train_loader_args)


In [6]:
class Model(nn.Module):
    def __init__(self, out_vocab, input_feature_size):
        super(Model, self).__init__()

        self.lstm = nn.LSTM(input_feature_size, 256, bidirectional=True,
                            batch_first=True, num_layers=4, dropout=0.3)
        self.output1 = nn.Linear(512, 256)
        self.output2 = nn.Linear(256, out_vocab)

    def forward(self, X, lengths):

        packed_X = pack_padded_sequence(X, lengths, enforce_sorted=False, batch_first=True)
        packed_out = self.lstm(packed_X)[0]

        out, out_lens = pad_packed_sequence(packed_out, batch_first=True)
        
        out = self.output1(out)
        out = self.output2(out)
        out = out.log_softmax(2)
        return out, out_lens


In [7]:
model = Model(len(PHONEME_LIST), len(train_dataset[0][0][0]))
model.to(device)


Model(
  (lstm): LSTM(13, 256, num_layers=4, batch_first=True, dropout=0.3, bidirectional=True)
  (output1): Linear(in_features=512, out_features=256, bias=True)
  (output2): Linear(in_features=256, out_features=42, bias=True)
)

In [8]:
criterion = nn.CTCLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0, min_lr=1e-4, verbose=True)


In [9]:
dev_dataset = MyDataset(dev_data, dev_labels)

dev_loader_args = dict(shuffle=False, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                       collate_fn=collate) if cuda else dict(shuffle=False, batch_size=64, collate_fn=collate)
dev_loader = data.DataLoader(dev_dataset, **dev_loader_args)


In [10]:
decoder = CTCBeamDecoder(PHONEME_MAP, beam_width=20, num_processes=8,
                         blank_id=0, log_probs_input=True)


In [11]:
def seq2phoneme(sequence, seq_len):
    result = ''
    if isinstance(seq_len, torch.Tensor):
        seq_len = seq_len.item()
    for i in range(seq_len):
        result += PHONEME_MAP[sequence[i]]
    return result


In [12]:
print(len(train_loader))
for epoch in range(30):
    model.train()
    count = 0
    avg_train_loss = 0
    for (X, X_lens, Y, Y_lens) in train_loader:
        optimizer.zero_grad()

        X = X.to(device)
        X_lens = X_lens.to(device)
        Y = Y.to(device)
        Y_lens = Y_lens.to(device)
        count += 1
        out, out_lens = model(X, X_lens)
        out = out.permute(1,0,2)
        loss = criterion(out, Y, out_lens, Y_lens)
        if count % 100 == 0:
            print("AVG_loss", count, 'loss=%f' % (avg_train_loss/100))
            avg_train_loss = 0
        else:
            avg_train_loss += float(loss)
        
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()
        del X
        del X_lens
        del Y
        del Y_lens
        del loss
        
        
    with torch.no_grad():
        model.eval()
        total_dist = []
        loss = None
        for (test_X, test_X_lens, test_Y, test_Y_lens) in dev_loader:

            test_X = test_X.to(device)
            test_X_lens = test_X_lens.to(device)
            test_Y = test_Y.to(device)
            test_Y_lens = test_Y_lens.to(device)

            out, out_lens = model(test_X, test_X_lens)
            out = out.permute(1,0,2)
            if loss is None:
                loss = float(criterion(out, test_Y, out_lens, test_Y_lens))
            else:
                loss += float(criterion(out, test_Y, out_lens, test_Y_lens))
            out = out.permute(1, 0, 2)
            
            
            Y_hat, _, _, Y_hat_lens = decoder.decode(out, out_lens)
            assert(len(Y_hat) == len(test_Y))
            val_seq_array = []
            for i in range(len(Y_hat)):
                val_seq_array.append(
                    (seq2phoneme(Y_hat[i][0], Y_hat_lens[i][0]), seq2phoneme(test_Y[i], test_Y_lens[i])))

            for ex in val_seq_array:
                if val_seq_array.index(ex) % 100 == 0:
                    if epoch % 3 == 0:
                total_dist.append(distance(ex[0], ex[1]))
            torch.cuda.empty_cache()
            del test_X
            del test_X_lens
            del test_Y
            del test_Y_lens
            
        print("Average Lev Distance", epoch, np.mean(total_dist))
        print("AVG VAL LOSS PER BATCH", loss/len(dev_loader))
        scheduler.step(loss/len(dev_loader))
    if (epoch > 9) and (epoch % 3 == 0 or epoch == 29):
        checkpoint = {'epoch': epoch, "model": model.state_dict(
        ), 'optimizer': optimizer.state_dict(), 'loss': criterion, 'scheduler': scheduler.state_dict()}
        torch.save(checkpoint, "checkpointLAST_%d.pth" % (epoch))


344
AVG_loss 100 loss=4.279128
AVG_loss 200 loss=3.383192
AVG_loss 300 loss=3.386090
h
.mistrkWiltrizDIhpashlhvDhmidhlklAshz.hndWIrglAdtUWelkhmHizgasphl.
65
h
.hWEvhvdispeROldwtfrm!.bRyhnsenstit.AndnUDhfifTpYnt.WhzHiz.
58
h
.onDhlAsAtidEinEpRhl.DhnU?oRktymzphbliSthnhkwnthvDhstRyk.kamplhkEShnz.WicWrdilEiNAlhgzAndrznUjrzIbRij.hndstEtidDAtDIenjhniRHimselfWhzintwn.AndAtHizofhsonWest.tenTstRIt.
167
h
.tUmIt.WhztUfyndIchDr.
21
h
.veRIhsdiSizrfRIkWhntlIoRnhmentid.hndgaRniSt.WiDitsgREsfhlIvz.hndDIzrshmtymzbYldinsUps.olDOitizmoR?UZhlIkhnfyndiniNgliSkukrI.tUDhmAkrhlsosAzHIRgivhn.
148
h
.DIzprshnzDen.displEdtoRdzIchDr.pRisyslIDhsEmp?URIljelhsIz.WicAnhmEtDhmenhvdimakRhsIz.DhsEm.IgrnhsthsnAcDhsmolhsthdvAntijiz.WicDeR.IkWhlzkhntesthd.AndDhsEmdizyrthprEdasthntEShslI.DOzhvWicDEWrinphzeShn.
200
h
.HIgEvTANksfoRaRfUdhndkhmfrt.hndpREdfrDhpURAndesthtUtingREtsitIz.HWeRDhstRhghlfrlyfWhzHaRdrDhnitWhzHIRWiThs.
107
h
.duRiNDIakShnHIWenthbwtWiDHizHedwnhndnevrliftidHizyz.
52
h
.ingudfETHwevr.HIiznatshfiShntlIimA

.Dhkelhs.bhtikhdhlIvhpofhl.smwbEglIWhDhzklUdlips.ispEstImOnTEstUt.hvhDhWIthvneliNkhlIhset.
.mykelis.DhtikhthvlIvhpashl.smyldvEglIWiDHizglUdlips.HizpEstImUnfEsdRUpt.hndrDhWEthvmelhnkalIhsent.
32
.AnSIsolDIhDrbrdzRapiNhbwpintWitrInHwpDhslI.
.AndSIsoDIhDrbrdzHapiNhbwtAndtWitriNHelplhslI.
11
.myfRAndhWiTmenIhDrz.WhzbEnkeRItwt.thdyolsWeR.
.myfRend.WiTmenIhDrz.WhzbIiNkeRIdwt.thdyelsWeR.
7
.WhzHizmyndWandriNhntshmDhcREn.hvsot.
.WhzHizmyndWandriNinthshmhDrtREn.hvTot.
6
.iWhshvviNteRWhltUidhssyhd.AdidiN.DhmicdispehvlhtIEHhztiRsnhdlIWhldwnHiztIks.
.itWhzshmTiNteRhbhltUWitnhsDhsylhnt.AghnI.Dhm?Ut.dispeRhv!HUztiRzsylhntlIROldwnHizcIks.
38
.DhsekhmpikinsR.AnEnjhl.DhTeRkHiR.drson.ynel.
.DhsekhnbiginzHIR.AnEnjhl.DhTrdHIR.diRsoN.ynO.
15
Average Lev Distance 3 19.471269296740996
AVG VAL LOSS PER BATCH 0.9380115302833351
AVG_loss 100 loss=0.859706
AVG_loss 200 loss=0.806642
AVG_loss 300 loss=0.764517
Average Lev Distance 4 16.328473413379072
AVG VAL LOSS PER BATCH 0.7843438805760564
AVG_loss 100 loss=0.

.DIzpashnzDem.DisplEtWhtchph.prsyzHIDhsEmp?uRoldelsIz.WicAnhmytDhmhnhvdhmakRhsIz.DhsEv.IvinezthsnhctsmUlzudvadijiz.WrsoRtIkWhlskhmpestit.AndDsEmdizaRthprEthshpEShslI.DOzhHWicDEWrhnphziShn.
.DIzprshnzDen.displEdtoRdzIchDr.pRisyslIDhsEmp?URIljelhsIz.WicAnhmEtDhmenhvdimakRhsIz.DhsEm.IgrnhsthsnAcDhsmolhsthdvAntijiz.WicDeR.IkWhlzkhntesthd.AndDhsEmdizyrthprEdasthntEShslI.DOzhvWicDEWrinphzeShn.
59
.HIgEvTANksfoRhfUdAndkhmfrt.hndpREdfrDhpURindistitUdingRIsitIz.WiDhstRhghlfrlytWhzHaRtrDhnitWhzHIRWiTest.
.HIgEvTANksfoRaRfUdhndkhmfrt.hndpREdfrDhpURAndesthtUtingREtsitIz.HWeRDhstRhghlfrlyfWhzHaRdrDhnitWhzHIRWiThs.
16
.driN?okShnHIWenthbwtWidHizHeddwnnevrleftidHizyz.
.duRiNDIakShnHIWenthbwtWiDHizHedwnhndnevrliftidHizyz.
11
.AndgudfEtHwevr.HIHUznatshfiShntmImAjhnhti.thflAtrHimselfWiTislythstHOphvDiskynd.
.ingudfETHwevr.HIiznatshfiShntlIimAjhnhtiv.thflAtrHimselfWiTDhslythstHOphvDiskynd.
10
.itWhzhndrshcaspiShsTrkhmztAnsizDAtSIztytidwtDismoRniNthlhkfrWrk.
.itWhzhndrshcaspiShsrkhmstAnsizDAtSIstaRthdwtDi

.?UdOmINtisEHIzgOiNthflodkrklhn.
.?UdOntmIntisEHIzgOiNtiflagkrklhnd.
7
.hndDEinkhnsisthnoRmhnfelhpanHizbhtmIbRestWIpinkOpIhslIp.
.AndDIinkhnsisthntWumhnfelhpanHizbhthnIbRestWIpiNkOp?hslI.
10
.DEHAlitkhnbenthkOlzhnIOkhnfIldz.And.bIiNtRItidWiTgREtfyhlhntsinkUltIgWyDhskRatiSkhvrt.WrDhnWhnsthkWaRnstuRiNnOsREnz.
.DEHeldkhnvenikhlzinDIOphnfIldz.And.bIiNtRItidWiTgREtvyhlhnsAndkRUltIbyDhskatiSghvrmhnt.moRDhnWhnstukaRmzdriNDOzREnz.
34
.WInkaRnEShnhnDhkaNvrshspiRhts.
.RIinkaRnEShnhndDhkanvrshvspiRits.
6
.yluktAthtpRIklOs.hndysedgREt.gad.
.yluktAtitpRitIklOs.hndysedgREt.gad.
3
.DisOphniNlukfr.olDhWrld.lykhmwThvHil.inDhWrdzhvDIytenrhntp?uRhthnpRIcrz.HUtrndhWEfrmitWiTHoR.
.DisOphniNluktfr.olDhWrld.lykhmwThvHel.inDhWrdzhvDI.ytinrhntp?uRhthnpRIcrz.HUtrndhWEfrmitWiTHoRr.
5
.ykudnatHelpsmyliNtismluksibEgonHizlithlWoRs.HizloNlegnwhndDenthciNDhgRwndmEdHimluklykhsikslitidsentaR.
.ykudnatHelpsmyliNtisIimluksObiganHizlithlHoRs.HizloNlegznwhndDenthciNDhgRwndmEdHimluklykhsiksfutidsentoR.
10
.bhtHIlukbAkkAtrRh

.mytfRendWiTmenIhDrz.WhzbIiNkeRIdwt.tidyelsWeR.
.myfRend.WiTmenIhDrz.WhzbIiNkeRIdwt.thdyelsWeR.
3
.WhzHizmyndWandriNinthshmhDrtREn.hvTot.
.WhzHizmyndWandriNinthshmhDrtREn.hvTot.
0
.itWhzshmTiNteRbhltUitnhs.hsydhnt.AdhnI.DhmUtdispeRhvWhtEHUztiRzsylhdlIWOldanHizcIks.
.itWhzshmTiNteRhbhltUWitnhsDhsylhnt.AghnI.Dhm?Ut.dispeRhv!HUztiRzsylhntlIROldwnHizcIks.
15
.DhsekhnbiginzHIR.AndEnShl.DhTeRdIR.drRshN.ynw.
.DhsekhnbiginzHIR.AnEnjhl.DhTrdHIR.diRsoN.ynO.
8
Average Lev Distance 15 9.40394511149228
AVG VAL LOSS PER BATCH 0.48007070937672175
AVG_loss 100 loss=0.336543
AVG_loss 200 loss=0.342009
AVG_loss 300 loss=0.341347
Average Lev Distance 16 9.360205831903945
AVG VAL LOSS PER BATCH 0.4802830348143706
Epoch    17: reducing learning rate of group 0 to 1.0000e-04.
AVG_loss 100 loss=0.294635
AVG_loss 200 loss=0.278574
AVG_loss 300 loss=0.271658
Average Lev Distance 17 8.19253859348199
AVG VAL LOSS PER BATCH 0.4272573513759149
AVG_loss 100 loss=0.260090
AVG_loss 200 loss=0.256486
AVG_loss 300 loss

.HIgEvTANksfWeRhfUdAndkhmfrt.hndpREdfrDhpURindhsthtUdingREtsitIz.WeDhstRhghlfrlytWhzHaRtrDhnitWhzHIRWiThst.
.HIgEvTANksfoRaRfUdhndkhmfrt.hndpREdfrDhpURAndesthtUtingREtsitIz.HWeRDhstRhghlfrlyfWhzHaRdrDhnitWhzHIRWiThs.
13
.duiN?okShnHIWenthbwtWitHizHedwnnevrleftidHizyz.
.duRiNDIakShnHIWenthbwtWiDHizHedwnhndnevrliftidHizyz.
9
.AndgudfEtHwevr.HIHAznatshfiShntlIimAjhnhti.thflAtrHimselfWiTDhslythstHOphvDiskynd.
.ingudfETHwevr.HIiznatshfiShntlIimAjhnhtiv.thflAtrHimselfWiTDhslythstHOphvDiskynd.
6
.itWhzhndrshcospiShsrkhmstAnsizDAtSIstytidwtDismoRniNthlukfrWrk.
.itWhzhndrshcaspiShsrkhmstAnsizDAtSIstaRthdwtDismoRniNthlukfrWrk.
4
.HIcRAbhldevrmenIlibeldItElT.AndtoktprfhNktoRlIthevRIbadI.
.HItRhbhldOvrmenIlithldItElz.AndtoktprfhNktRhlIthevRIbadI.
8
.DhHOTiNWhzhtRyfhlad.
.DhHOlTiNWhzhtRyfhlad.
1
.Azitiz.hndlesyAmistEkhn.shmhvDhRendiN.WhlbIon.aRsyd.AndDEnOit.
.Azitiz.hnlesyAmistEkhn.shmhvDhRendiN.WhlbIan.aRsyd.AndDEnOit.
2
.HIsmyldgilthlIezHIAdid.bhtymhstAdmedyWhzmoRDhnhlithlkhnsrndmyself.
.HIsmyldg

.WIinkrRnEShnhnDhkanvrshspiRhnts.
.RIinkaRnEShnhndDhkanvrshvspiRits.
6
.yluktDAtAtpRIklOs.hndysedgREt.gad.
.yluktAtitpRitIklOs.hndysedgREt.gad.
4
.DisOphniNluktfr.olDhWrld.lykhmwThvHil.inDhWrdzhvDIytenrhntp?uRhthnpRIcrz.HUtrndhWEfrmitWiTHaR.
.DisOphniNluktfr.olDhWrld.lykhmwThvHel.inDhWrdzhvDI.ytinrhntp?uRhthnpRIcrz.HUtrndhWEfrmitWiTHoRr.
5
.ykudnatHelpsmyliNtisIUmluksObEganHizlithlHoRs.HizloNlegnwhnDenthciNDhgRwnddmEdHimluklykhsikslutitsentaR.
.ykudnatHelpsmyliNtisIimluksObiganHizlithlHoRs.HizloNlegznwhndDenthciNDhgRwndmEdHimluklykhsiksfutidsentoR.
8
.bhtHIlukbAkAcaRlsthn.DhtgEDhvolzonDhb?Uthfhl.WhzRilhfekShn.
.bhtHIluktbAkAtcaRhlsthn.DhgEDhvalhthlhndDhb?Uthfhl.WiTRIlhfekShn.
14
.WiThSwlt.DhbYstASpElmElthmItDhpektREn.AndfaliNHinbiHyndDhslOmrDiNbrlz.HrSdmon.WiTthvezhkSwtshndshdRIRivwndiNslepsanDhAnhmhlzlyNks.
.WiThSwt.DhbYzdAStpelmelthmItDhpAktREn.AndfaliNinbiHyndDhslOmUviNbrOz.rjdDhmon.WiTdrisivSwtshndshndRIRIswndiNslApsonDhAnhmhlzflANks.
29
.Wynathlw?oRsithlvrtoftsthlhgZoRIEt.inhnhcRh

.DhsekhnbiginzHIR.AnEnShl.DhTrdHIR.duRshN.ynw.
.DhsekhnbiginzHIR.AnEnjhl.DhTrdHIR.diRsoN.ynO.
4
Average Lev Distance 27 7.779159519725558
AVG VAL LOSS PER BATCH 0.4260690151839643
AVG_loss 100 loss=0.202936
AVG_loss 200 loss=0.208316
AVG_loss 300 loss=0.208654
Average Lev Distance 28 7.7397084048027445
AVG VAL LOSS PER BATCH 0.4261687510722392
AVG_loss 100 loss=0.201805
AVG_loss 200 loss=0.203343
AVG_loss 300 loss=0.203904
Average Lev Distance 29 7.732847341337908
AVG VAL LOSS PER BATCH 0.42723275680799744
