In [41]:
from hydra import initialize, compose
from omegaconf import OmegaConf

import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, WeightedRandomSampler

from models import Vocabulary
from models.autoencoder import EncoderRNN, DecoderRNN, Autoencoder

In [13]:
# Compose the Hydra config: the base "config" plus the phone/word/uniform
# autoencoder preset (see the displayed result for the resolved values).
autoencoder_overrides = ["+autoencoder=phone_word_uniform"]
with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=autoencoder_overrides,
    )
cfg

{'trainer': {'batch_size': 32}, 'autoencoder': {'data': {'sampler': {'strategy': 'word', 'weights': 'uniform'}, 'representation': {'form': 'phone-sequence', 'drop_extras': True}}}}

In [14]:
# Enable IPython autoreload (mode 2): modules are re-imported before each cell runs,
# so edits to the project code (e.g. `models`) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Prepare dataset

In [15]:
# "Extra" TIMIT labels that do not correspond to speech sounds; they are removed
# from the phone table when `drop_extras` is enabled in the config.
timit_extras = "pau epi h#".split()

In [16]:
# Merged TIMIT phone table; later cells read its dialect, speaker, sentence_idx,
# word_idx and phone columns.
timit_csv_path = "timit_merged.csv"
df = pd.read_csv(timit_csv_path)

In [17]:
# Optionally remove the non-speech "extra" labels before words are assembled.
drop_extras = cfg.autoencoder.data.representation.drop_extras
if drop_extras:
    is_extra = df["phone"].isin(timit_extras)
    df = df.loc[~is_extra]

In [18]:
# Rebuild each word token as a space-joined phone string (phones kept in row order
# within each group), then count how often each distinct word occurs.
# Aggregating only the `phone` column avoids the pandas >= 2.2 deprecation of
# `DataFrameGroupBy.apply` operating on the grouping columns, and is cheaper than
# handing every group's full frame to a lambda.
word_keys = ["dialect", "speaker", "sentence_idx", "word_idx"]
word_counts = (
    df.groupby(word_keys)["phone"]
    .agg(lambda phones: phones.str.cat(sep=" "))
    .value_counts()
)
# Distinct words as tuples of phones, plus their token counts in the same order.
all_words = [tuple(word.split(" ")) for word in word_counts.index]
all_word_freqs = word_counts.tolist()

In [19]:
# Build the phone vocabulary by registering every distinct word's phone sequence.
vocab = Vocabulary("")
for phone_seq in all_words:
    vocab.add_sequence(phone_seq)

In [35]:
# Prepare input tensor: one fixed-width row per distinct word, laid out as
# [SOS, phone ids..., EOS, zero-fill]. Words with more than max_length - 2 phones
# are truncated so SOS and EOS always fit.
max_length = 10
n = len(all_words)
input_ids = np.zeros((n, max_length), dtype=np.int64)

for idx, seq in enumerate(all_words):
    phone_ids = [vocab.token2index[token] for token in seq[:max_length - 2]]
    input_ids_i = [vocab.sos_token_id, *phone_ids, vocab.eos_token_id]
    input_ids[idx, :len(input_ids_i)] = input_ids_i

# torch.tensor copies, so later edits to `input_ids` do not leak into the dataset.
train_data = TensorDataset(torch.tensor(input_ids, dtype=torch.long))

# Choose how words are drawn during an epoch.
sampler_cfg = cfg.autoencoder.data.sampler
if sampler_cfg.strategy != "word":
    raise NotImplementedError(f"unsupported sampler strategy: {sampler_cfg.strategy!r}")
if sampler_cfg.weights == "uniform":
    # Each epoch is a random permutation of the distinct words.
    train_sampler = RandomSampler(train_data)
elif sampler_cfg.weights == "unigram":
    # Draw words in proportion to their token frequency. `num_samples` is a
    # required argument; one draw per word type keeps the epoch length equal to
    # the uniform sampler's.
    train_sampler = WeightedRandomSampler(all_word_freqs, num_samples=len(all_word_freqs))
else:
    # Fail here rather than with a NameError on `train_sampler` below.
    raise ValueError(f"unknown sampler weights: {sampler_cfg.weights!r}")
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=cfg.trainer.batch_size)

## Train

In [40]:
def train_epoch(dataloader, autoencoder, opt, criterion):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    The autoencoder is trained to reconstruct its input: each batch is used both
    as the encoder input and as the decoder target.

    Args:
        dataloader: iterable yielding 1-tuples ``(batch_ids,)`` of token-id tensors
            (e.g. a ``DataLoader`` over a single-tensor ``TensorDataset``).
        autoencoder: ``nn.Module`` called as ``autoencoder(input, target)`` and
            returning ``(dec_outputs, _)``; presumably ``dec_outputs`` holds one
            score vector per target position (last dim = number of classes).
        opt: optimiser over the autoencoder's parameters.
        criterion: loss called as ``criterion(scores[N, C], targets[N])``.

    Returns:
        Mean of the per-batch loss values, as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Ensure training-mode behaviour in case the model has mode-dependent layers.
    autoencoder.train()
    total_loss = 0.0
    n_batches = 0
    for (batch_ids,) in dataloader:
        opt.zero_grad()

        # Reconstruction objective: the target is the input itself.
        dec_outputs, _ = autoencoder(batch_ids, batch_ids)

        # reshape (not view): decoder outputs may be non-contiguous (e.g. after a
        # transpose), and view() raises on such tensors.
        loss = criterion(
            dec_outputs.reshape(-1, dec_outputs.size(-1)),
            batch_ids.reshape(-1),
        )
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("dataloader yielded no batches")
    # Divide by batches actually seen rather than len(dataloader), so loaders
    # without __len__ (e.g. over an IterableDataset) also work.
    return total_loss / n_batches

In [45]:
def train(train_dataloader, autoencoder, n_epochs,
          learning_rate=0.01, print_every=100, ignore_index=-100):
    """Train ``autoencoder`` with Adam and an NLL loss, printing progress periodically.

    Args:
        train_dataloader: passed unchanged to ``train_epoch`` on every epoch.
        autoencoder: module to optimise; its parameters are handed to Adam.
        n_epochs: number of passes over ``train_dataloader``.
        learning_rate: Adam learning rate.
        print_every: every this many epochs, print ``"<epoch> <loss>"`` where
            ``<loss>`` is the mean of the epoch losses since the previous print.
        ignore_index: target id excluded from the loss, forwarded to
            ``nn.NLLLoss``. The default ``-100`` is NLLLoss's own default, so no
            real target is ignored. Pass the padding id (the notebook zero-fills
            padded positions) to keep padding from contributing to the loss.

    Raises:
        ValueError: if ``print_every`` is less than 1.
    """
    # Fail before any training instead of with ZeroDivisionError after epoch 1.
    if print_every < 1:
        raise ValueError(f"print_every must be >= 1, got {print_every!r}")

    opt = optim.Adam(autoencoder.parameters(), lr=learning_rate)
    # NOTE(review): NLLLoss expects log-probabilities; assumes the decoder ends
    # in log_softmax — confirm in models.autoencoder.
    criterion = nn.NLLLoss(ignore_index=ignore_index)

    print_loss_total = 0
    for epoch in range(1, n_epochs + 1):
        print_loss_total += train_epoch(train_dataloader, autoencoder, opt, criterion)

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%d %.4f" % (epoch, print_loss_avg))

In [43]:
# Seed torch's global RNG right before model init so that the weight initialisation
# and the subsequent sampler draw order (RandomSampler / WeightedRandomSampler
# without an explicit generator use the global RNG) are reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)

hidden_size = 32
autoencoder = Autoencoder(hidden_size, vocab)

In [46]:
# Train for a fixed number of epochs, reporting the loss after every epoch.
n_epochs = 10
train(train_dataloader, autoencoder, n_epochs=n_epochs, print_every=1)

1 1.7004
2 1.2246
3 0.9881
4 0.8487
5 0.7524
6 0.6769
7 0.6239
8 0.5877
9 0.5458
10 0.5275
