In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import time
import random
import math
from data_utils import tokenize, tokenize_data, load_data, plot, TranscriptionsDataset
from model import Encoder, Decoder, Seq2Seq

In [2]:
def generate_transcriptions(model, w_vocab, t_vocab, word):
        model.eval()
        word_tokenized = tokenize(word)
        word_tokenized = [w.upper() for w in word_tokenized]
        word_indexed = [w_vocab.token2idx(w) for w in word_tokenized]
        word_indexed = torch.LongTensor(word_indexed).to(device)
        
        outputs = model.predict(word_indexed)
        transcriptions = [t_vocab.idx2token(t) for t in outputs]
        transcriptions = ''.join(transcriptions)
        return transcriptions

In [3]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)   

In [4]:
def train(model, train_dataloader, optimizer, criterion, clip, epoch, train_loss_list):
    model.train()
    
    epoch_loss = 0
    for batch_idx, (words, sos_transcriptions, eos_transcriptions) in enumerate(train_dataloader):
        words = words.to(device)
        sos_transcriptions = sos_transcriptions.to(device)
        eos_transcriptions = eos_transcriptions.to(device)
        
        optimizer.zero_grad()
        
        output = model(words, sos_transcriptions)
        output = output.view(-1, output.shape[-1])
        eos_transcriptions = eos_transcriptions.view(-1)
       
        loss = criterion(output.to(device), eos_transcriptions)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        train_loss_list.append(loss.item())
        
#         if batch_idx % 50 == 0:
#             plot(epoch, batch_idx, train_loss_list)
    return epoch_loss / len(train_dataloader)

In [5]:
def evaluate(model, test_dataloader, criterion):
    model.eval()
    
    epoch_loss = 0
    with torch.no_grad():
        for (words, sos_transcriptions, eos_transcriptions) in test_dataloader:
            words = words.to(device)
            sos_transcriptions = sos_transcriptions.to(device)
            eos_transcriptions = eos_transcriptions.to(device)

            output = model(words, sos_transcriptions, 0) # turn off teacher forcing
            output = output.view(-1, output.shape[-1])
            eos_transcriptions = eos_transcriptions.view(-1)
            
            output = output.to(device)
            loss   = criterion(output, eos_transcriptions)
            
            epoch_loss += loss.item()                                   
    return epoch_loss / len(test_dataloader)

In [6]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [7]:
train_dataloader, test_dataloader = load_data()
w_vocab = pickle.load(open('word_vocab.pickle', 'rb'))
t_vocab = pickle.load(open('transcription_vocab.pickle', 'rb'))

In [8]:
INPUT_DIM   = len(w_vocab)
OUTPUT_DIM  = len(t_vocab)
ENC_EMB_DIM = 100
DEC_EMB_DIM = 100
HID_DIM  = 64
N_LAYERS = 1
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS).to(device)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS).to(device)
model = Seq2Seq(enc, dec).to(device)
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(56, 100)
    (LSTM): LSTM(100, 64, batch_first=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(56, 100)
    (LSTM): LSTM(100, 64, batch_first=True)
    (out): Linear(in_features=64, out_features=56, bias=True)
  )
)

In [9]:
PAD_IDX   = t_vocab.token2idx('<pad>')
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
optimizer = optim.Adam(model.parameters())

In [10]:
# N_EPOCHS = 50
# CLIP = 1
# data = tokenize_data()
# best_test_loss  = float('inf')
# train_loss_list = []
# for epoch in range(N_EPOCHS):
    
#     start_time = time.time()
#     train_loss = train(model, train_dataloader, optimizer, criterion, CLIP, epoch, train_loss_list)
#     test_loss  = evaluate(model, test_dataloader, criterion)
#     end_time   = time.time()
    
#     epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
#     if test_loss < best_test_loss:
#         best_test_loss = test_loss
#         torch.save(model.state_dict(), 'checkpoint.pth')
        
#     print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
#     print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
#     print(f'\t Val. Loss: {test_loss:.3f} |  Val. PPL: {math.exp(test_loss):7.3f}')

In [17]:
model.load_state_dict(torch.load('checkpoint.pth'))
for (words, y_in, y_out) in test_dataloader:
    for i in words:
        output = generate_transcriptions(model, w_vocab, t_vocab, ''.join(w_vocab.idx2sent(i.tolist())))
        print('WORD:         ', ''.join(w_vocab.idx2sent(i.tolist())))
        print('GENERATED:    ', output)    
        print('TRANSCRIPTION:', ''.join(t_vocab.idx2sent(i.tolist())))
        print('\n')
    break

WORD:          CONTRARIANS
GENERATED:     KAANTRAHHHAYNERZ
TRANSCRIPTION: CONTRARIANS


WORD:          CYLINDRICAL
GENERATED:     KIHLIHNDRIHKAHL
TRANSCRIPTION: CYLINDRICAL


WORD:          PRESSWOODS<pad>
GENERATED:     PREHSWEYDOWSIHN
TRANSCRIPTION: PRESSWOODS<pad>


WORD:          SOFTSPOKEN<pad>
GENERATED:     SAAPSKAHBAHLSIHNG
TRANSCRIPTION: SOFTSPOKEN<pad>


WORD:          WITTENMYER<pad>
GENERATED:     WIHTAHNMAYGREYT
TRANSCRIPTION: WITTENMYER<pad>


WORD:          INFATUATES<pad>
GENERATED:     IHNFAHGTEYSHAHTIHNG
TRANSCRIPTION: INFATUATES<pad>


WORD:          SHOREWARD<pad><pad>
GENERATED:     SHAORDAHMAARGAHND
TRANSCRIPTION: SHOREWARD<pad><pad>


WORD:          PARCPLACE<pad><pad>
GENERATED:     PAARPKAHLAHSEYBAHND
TRANSCRIPTION: PARCPLACE<pad><pad>


WORD:          SHAREWARE<pad><pad>
GENERATED:     SHERAHDRAHFAYAHLZAHZ
TRANSCRIPTION: SHAREWARE<pad><pad>


WORD:          DERIDDER<pad><pad><pad>
GENERATED:     DERIHDERBAEKTAHVAHND
TRANSCRIPTION: DERIDDER<pad><pad><pad>


WOR

In [12]:
for i in range(len(data)):
    print(i)
    break

NameError: name 'data' is not defined

In [None]:
for i in test_dataloader:
    print(w_vocab.idx2sent(i[0][0].tolist()))
    break

In [None]:
i[0][0].tolist()