In [1]:
from hydra import initialize, compose
from omegaconf import OmegaConf

import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, WeightedRandomSampler, random_split

from models import Vocabulary
from models.autoencoder import EncoderRNN, DecoderRNN, Autoencoder

In [2]:
with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="config", overrides=["+autoencoder=phone_word_uniform"])
cfg

{'trainer': {'batch_size': 32}, 'autoencoder': {'data': {'sampler': {'strategy': 'word', 'weights': 'uniform'}, 'representation': {'form': 'phone-sequence', 'drop_extras': True}}}}

In [3]:
%load_ext autoreload
%autoreload 2

## Prepare dataset

In [4]:
# "Extra" phones in TIMIT not corresponding to speech sounds
timit_extras = ["pau", "epi", "h#"]

In [5]:
df = pd.read_csv("timit_merged.csv")

In [6]:
if cfg.autoencoder.data.representation.drop_extras:
    df = df[~df.phone.isin(timit_extras)]

In [7]:
all_words = df.groupby(["dialect", "speaker", "sentence_idx", "word_idx"]).apply(lambda xs: xs.phone.str.cat(sep=" ")).value_counts()
all_words, all_word_freqs = [tuple(word.split(" ")) for word in all_words.index], list(all_words)

In [8]:
vocab = Vocabulary("")
for word in all_words:
    vocab.add_sequence(word)

In [9]:
# Prepare input tensor
max_length = 10
n = len(all_words)
input_ids = np.zeros((n, max_length), dtype=np.int32)
# target_ids = np.zeros((n, max_length), dtype=np.int32)

for idx, seq in enumerate(all_words):
    input_ids_i = [vocab.sos_token_id] + \
        [vocab.token2index[token] for token in seq][:max_length - 2] + \
        [vocab.eos_token_id]
    input_ids[idx, :len(input_ids_i)] = input_ids_i
    
all_data = TensorDataset(torch.LongTensor(input_ids))
train_data, val_data, test_data = random_split(all_data, [0.8, 0.1, 0.1])

if cfg.autoencoder.data.sampler.strategy != "word":
    raise NotImplementedError()
if cfg.autoencoder.data.sampler.weights == "uniform":
    train_sampler = RandomSampler(train_data)
elif cfg.autoencoder.data.sampler.weights == "unigram":
    train_sampler = WeightedRandomSampler(all_word_freqs)

train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=cfg.trainer.batch_size)
val_dataloader = DataLoader(val_data)

## Train

In [10]:
def train_epoch(dataloader, autoencoder, opt, criterion):
    total_loss = 0
    for data in dataloader:
        opt.zero_grad()
        
        batch_input, = data
        batch_target, = data
        
        dec_outputs, _ = autoencoder(batch_input, batch_target)
        
        loss = criterion(
            dec_outputs.view(-1, dec_outputs.size(-1)),
            batch_target.view(-1)
        )
        loss.backward()
        
        opt.step()
        total_loss += loss.item()
        
    return total_loss / len(dataloader)

In [42]:
def evaluate(dataloader, autoencoder, criterion):
    with torch.no_grad():
        total_loss = 0
        decoded_sequences = []
        for data in dataloader:
            batch_input, = data
            batch_target, = data
            
            dec_outputs, _ = autoencoder(batch_input, batch_target)
            loss = criterion(
                dec_outputs.view(-1, dec_outputs.size(-1)),
                batch_target.view(-1)
            )
            total_loss += loss.item()
            
            _, topi = dec_outputs.topk(1)
            decoded_ids = topi.squeeze()
            decoded_tokens = []
            for idx in decoded_ids:
                if idx == autoencoder.vocabulary.eos_token_id:
                    decoded_tokens.append("<EOS>")
                    break
                decoded_tokens.append(autoencoder.vocabulary.index2token[idx.item()])
                
            decoded_sequences.append(decoded_tokens)
    
    total_loss /= len(val_dataloader)
    return total_loss, decoded_sequences

In [63]:
def train(train_dataloader, val_dataloader, autoencoder, n_epochs,
          learning_rate=0.01, print_every=100, val_every=100):
    print_loss_total = 0
    
    opt = optim.Adam(autoencoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()
    
    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, autoencoder, opt, criterion)
        print_loss_total += loss
        
        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%d %.4f" % (epoch, print_loss_avg))
            
        if epoch % val_every == 0:
            val_loss, _ = evaluate(val_dataloader, autoencoder, criterion)
            print("--- val loss: %.4f" % val_loss)

In [73]:
hidden_size = 256
autoencoder = Autoencoder(hidden_size, vocab)

In [74]:
train(train_dataloader, val_dataloader, autoencoder, 4, print_every=1, val_every=1)

1 0.8865
--- val loss: 0.4856
2 0.3730
--- val loss: 0.3971
3 0.3076
--- val loss: 0.3355
4 0.2793
--- val loss: 0.3177


## Decode

In [75]:
val_loss, val_outputs = evaluate(val_dataloader, autoencoder, nn.NLLLoss())

In [76]:
val_loss

0.31559363552704894

In [77]:
val_inputs = [" ".join(autoencoder.vocabulary.index2token[idx.item()] for idx in seq.squeeze())
              for seq, in val_dataloader]

In [78]:
list(zip(val_inputs[:20], [" ".join(toks) for toks in val_outputs[:20]]))

[('<SOS> kcl k ax n tcl t ey n <EOS>', '<SOS> kcl k ax n tcl t ey v <EOS>'),
 ('<SOS> dcl d ey dx ix <EOS> <SOS> <SOS> <SOS>',
  '<SOS> dcl d ey dx ix ng <SOS> <SOS> <SOS>'),
 ('<SOS> n ah m bcl b er <EOS> <SOS> <SOS>', '<SOS> n ah m bcl b er <EOS>'),
 ('<SOS> ix kcl k s pcl p er ih <EOS>', '<SOS> ix kcl k s pcl p er ih <EOS>'),
 ('<SOS> pcl p r aa s eh s tcl <EOS>', '<SOS> pcl p r aa s tcl tcl tcl <EOS>'),
 ('<SOS> r iy ix sh er ix ng <EOS> <SOS>', '<SOS> r ix ix ng er ix ng <EOS>'),
 ('<SOS> kcl k ax n s er n <EOS> <SOS>', '<SOS> kcl k ax n s er s <EOS>'),
 ('<SOS> bcl b aa dh axr <EOS> <SOS> <SOS> <SOS>',
  '<SOS> bcl b aa dh axr <EOS>'),
 ('<SOS> pcl p ow eh m z <EOS> <SOS> <SOS>', '<SOS> pcl p ow m z z <EOS>'),
 ('<SOS> r ix z uh l tcl t s <EOS>', '<SOS> r uh z l l tcl t s <EOS>'),
 ('<SOS> sh er l iy <EOS> <SOS> <SOS> <SOS> <SOS>', '<SOS> sh er l iy <EOS>'),
 ('<SOS> b aa bcl k ae tcl <EOS> <SOS> <SOS>',
  '<SOS> b aa tcl b ae tcl <EOS>'),
 ('<SOS> hh ix pcl p ix kcl k r <EOS>', 