# VAE Using Extended Kalman Filter for Speech Recognition using TIMIT  

The purpose of this demo is to help you learn about variational autoencoder. The algorithm is being implemented is from the paper "Auto-Encoding Variational Bayes", by Diederik P Kingma, Max Welling (https://arxiv.org/abs/1312.6114).

Followed the logic:

VAE Basic: https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/

VAE: https://github.com/ethanluoyc/pytorch-vae/blob/master/vae.py 

VAE: https://github.com/Baileyswu/pytorch-hmm-vae/blob/master/vae.py

TIMIT: https://github.com/jackjhliu/Pytorch-End-to-End-ASR-on-TIMIT

EKF: https://github.com/jnez71/kalmaNN

We are using TIMIT data.

You are free to change model acrhitecture, or any part of the logic. 

If you have any suggestions or find errors, please, don't be hesitate to text me at jayanta.jayantamukherjee@gmail.com

In [1]:
import argparse
import data
import eval_utils
import logging, sys
import matplotlib.pyplot as plt
import os
from prepare_data import prepare_csv
from show_history import plotLearning
import time 
import timeit
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
import torchvision
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
matplotlib.style.use('ggplot')

from torchvision import transforms
import torch.optim as optim
import yaml

  '"sox" backend is being deprecated. '


## Prepare Data from Raw WAV files: TIMIT


In [2]:
logging.basicConfig(stream=sys.stderr, level=logging.INFO)

prepare_csv("../TIMIT/TIMIT_DATA/")

TRAIN.csv is created.
DEV.csv is created.
TEST.csv is created.
Data preparation is complete !


#### Load Config & Clean up Previous Run Stats

In [3]:
cfgFile = "config/default.yaml" 
import shutil

with open(cfgFile) as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
        
if not cfg['logdir']:
    save_path = os.path.splitext(cfgFile)[0]
    
if os.path.exists(save_path):
    shutil.rmtree(save_path)

os.mkdir(save_path)

### Set the constants

In [4]:
gpu_id = 0
workers = 0
ckpt_freq = 10

## Linear VAE

In [5]:
features = 240 # 240 is the dimension of acoustic features.
# define a simple linear VAE
class LinearVAE(nn.Module):
    def __init__(self, target_size, hidden_size, encoder_layers, decoder_layers, drop_p=0.):
        """
        Args:
            target_size (integer): Target vocabulary size.
            hidden_size (integer): Size of GRU cells.
            encoder_layers (integer): EncoderRNN layers.
            decoder_layers (integer): DecoderRNN layers.
            drop_p (float): Probability to drop elements at Dropout layers.
        """
        super(LinearVAE, self).__init__()
        print("Init LinearVAE")

        #self.encoder = EncoderRNN(hidden_size, encoder_layers, drop_p)
        #self.decoder = DecoderRNN(target_size, hidden_size, decoder_layers, drop_p)
         
        # encoder
        self.enc1 = nn.Linear(in_features=features, out_features=512)
        self.enc2 = nn.Linear(in_features=512, out_features=hidden_size)
        self.enc22 = nn.Linear(in_features=512, out_features=hidden_size) 
 
        # decoder 
        self.dec1 = nn.Linear(in_features=hidden_size, out_features=512)
        self.dec2 = nn.Linear(in_features=512, out_features=features)
        
    def encode(self, x):
        h1 = F.relu(self.enc1(x))
        return self.enc2(h1), self.enc22(h1) 

    def decode(self, z):
        h3 = F.relu(self.dec1(z))
        return torch.sigmoid(self.dec2(h3))
        
    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling as if coming from the input space
        return sample
    
    def forward(self, xs, xlens, ys=None):
        """
        The forwarding behavior depends on if ground-truths are provided.

        Args:
            xs (torch.LongTensor, [batch_size, seq_length, dim_features]): A mini-batch of FBANK features.
            xlens (torch.LongTensor, [batch_size]): Sequence lengths before padding.
            ys (torch.LongTensor, [batch_size, padded_length_of_target_sentences]): Padded ground-truths.

        Returns: 
            predictions (torch.FloatTensor, [batch_size, max_length]): The sentence generated by Greedy Search. 
        """ 
        # encoding
        mu, log_var = self.encode(xs)
        print("mu shape = ", mu.shape)
        print("log_var shape = ", log_var.shape)
        
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)
        print("4 z shape = ", z.shape)
        
        reconstruction = self.decode(z)
        print("6 reconstruction shape = ", reconstruction.shape)
        
        return reconstruction, mu, log_var
  
            
    def get_lr(self, optimizer):
        """
        A helper function to retrieve the solver's learning rate.
        """
        for param_group in optimizer.param_groups:
            return param_group['lr']


    def log_history(self, save_path, message):
        """
        A helper function to log the history.
        The history text file is saved as: {SAVE_PATH}/history.txt

        Args:
            save_path (string): The location to log the history.
            message (string): The message to log.
        """
        fname = os.path.join(save_path,'history.csv')
        if not os.path.exists(fname):
            with open(fname, 'w') as f:
                f.write("datetime,epoch,learning rate,train loss,dev loss,error rate\n")
                f.write("%s\n" % message)
        else:
            with open(fname, 'a') as f:
                f.write("%s\n" % message)


    def save_checkpoint(self, filename, save_path, epoch, dev_error, cfg, weights):
        """
        Args:
            filename (string): Filename of this checkpoint.
            save_path (string): The location to save.
            epoch (integer): Epoch number.
            dev_error (float): Error rate on development set.
            cfg (dict): Experiment config for reconstruction.
            weights (dict): "state_dict" of this model.
        """
        filename = os.path.join(save_path, filename)
        info = {'epoch': epoch,
                'dev_error': dev_error,
                'cfg': cfg,
                'weights': weights}
        torch.save(info, filename) 


### Define Learning Parameters

In [6]:
# leanring parameters
epochs = 10
batch_size = 64
lr = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Initialize Linear VAE Model

In [7]:
gpu_id = 0
workers = 0
ckpt_freq = 10

#input_dim = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = cfg['train']['batch_size']
#transform = transforms.Compose([transforms.ToTensor()])

# Create dataset
train_loader = data.load(split='train', batch_size=batch_size, workers = workers)
val_loader = data.load(split='dev', batch_size=batch_size)
hidden_size = hidden_size=cfg['model']['hidden_size']
activation=cfg['model']['activation']

# Build model
tokenizer = torch.load('tokenizer.pth')
print("tokenizer.vocab length = ", len(tokenizer.vocab))
#      self, target_size, hidden_size, encoder_layers, decoder_layers, drop_p=0.
model = LinearVAE(target_size=len(tokenizer.vocab),
          hidden_size=cfg['model']['hidden_size'],
          encoder_layers=cfg['model']['encoder_layers'],
          decoder_layers=cfg['model']['decoder_layers'],
          drop_p=cfg['model']['drop_p'])

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss(reduction='sum')

TRAIN set size: 3696
DEV set size: 1152
tokenizer.vocab length =  66
Init LinearVAE


### Define Final Loss Method

In [8]:
def final_loss(bce_loss, mu, logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the 
    KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: recontruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    """
    BCE = bce_loss 
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

### Define Training Method

In [9]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(dataloader.dataset)/dataloader.batch_size)):
        (xs, xlens, ys) = data
        
        print("inside fit::: xs shape = ", xs.shape, ", xlens shape = ", xlens.shape, ", ys shape = ", ys.shape)
        #data, _ = data
        #data = data.to(device)
        #data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(xs, xlens)
        print("xs shape = ", xs.shape, "ys shape = ", ys.shape, ", reconstruction shape = ", reconstruction.shape)
        #ys1 = ys.reshape(reconstruction.shape)
        #reconstruction = reconstruction.reshape(ys.shape)
        #criterion = nn.CrossEntropyLoss()
        bce_loss = criterion(reconstruction, xs)
        loss = final_loss(bce_loss, mu, logvar)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

### Define Validation Method

In [10]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(dataloader.dataset)/dataloader.batch_size)):
            (xs, xlens, ys) = data
            print("inside validate::: xs shape = ", xs.shape, ", xlens shape = ", xlens.shape, ", ys shape = ", ys.shape)
            reconstruction, mu, logvar = model(xs, xlens)
            print("xs shape = ", xs.shape, "ys shape = ", ys.shape, ", reconstruction shape = ", reconstruction.shape)
            #criterion = nn.CrossEntropyLoss()
            bce_loss = criterion(reconstruction, xs)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()
        
            # save the last batch input and output of every epoch
            #if i == int(len(dataloader.dataset)/dataloader.batch_size) - 1:
            #    num_rows = 8
            #    both = torch.cat((data.view(batch_size, 1, 28, 28)[:8], 
            #                      reconstruction.view(batch_size, 1, 28, 28)[:8]))
            #    save_image(both.cpu(), f"../outputs/output{epoch}.png", nrow=num_rows)
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

### Train & Eval

In [11]:
import eval_utils

epochs = 2
train_loss = []
val_loss = []
RMS = []
train_epoch_durations = []
eval_epoch_durations = []
best_epoch = 0
best_error = float('inf')

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    
    print("Learning rate: %f" % lr)
    ssqrtm = eval_utils.get_error(train_loader, model)
    RMS.append(ssqrtm)
    
    
    start_train_epoch = time.time()
    train_epoch_loss = fit(model, train_loader)
    end_train_epoch = time.time()
    train_epoch_duration = end_train_epoch - start_train_epoch
    train_epoch_durations.append(train_epoch_duration)

    val_epoch_loss = validate(model, val_loader)
    end_eval_epoch = time.time()
    eval_epoch_duration = end_eval_epoch - end_train_epoch
    eval_epoch_durations.append(eval_epoch_duration) 
    
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    # Compute dev error rate
    error = eval_utils.get_error(val_loader, model)
    print ("Dev. loss: %.3f," % val_loss, end=' ')
    print ("dev. error rate: %.4f" % error)
    if error < best_error:
        best_error = error
        best_epoch = epoch
        # Save best model
        save_checkpoint("best.pth", save_path, best_epoch, best_error, cfg, model.state_dict())
    print ("Best dev. error rate: %.4f @epoch: %d" % (best_error, best_epoch))
    
    # Save checkpoint
    if not epoch%ckpt_freq or epoch==cfg['train']['epochs']:
        save_checkpoint("checkpoint_%05d.pth"%epoch, save_path, epoch, error, cfg, model.state_dict())

    # Logging
    datetime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    msg = "%s,%d,%f,%f,%f,%f" % (datetime, epoch, lr, train_loss,  val_loss, error)
    log_history(save_path, msg)
    
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 2
Learning rate: 0.000100
table =  {'aa': 'aa', 'ae': 'ae', 'ah': 'ah', 'ao': 'aa', 'aw': 'aw', 'ax': 'ah', 'ax-h': 'ah', 'axr': 'er', 'ay': 'ay', 'b': 'b', 'bcl': 'sil', 'ch': 'ch', 'd': 'd', 'dcl': 'sil', 'dh': 'dh', 'dx': 'dx', 'eh': 'eh', 'el': 'l', 'em': 'm', 'en': 'n', 'eng': 'ng', 'epi': 'sil', 'er': 'er', 'ey': 'ey', 'f': 'f', 'g': 'g', 'gcl': 'sil', 'h#': 'sil', 'hh': 'hh', 'hv': 'hh', 'ih': 'ih', 'ix': 'ih', 'iy': 'iy', 'jh': 'jh', 'k': 'k', 'kcl': 'sil', 'l': 'l', 'm': 'm', 'n': 'n', 'ng': 'ng', 'nx': 'n', 'ow': 'ow', 'oy': 'oy', 'p': 'p', 'pau': 'sil', 'pcl': 'sil', 'r': 'r', 's': 's', 'sh': 'sh', 't': 't', 'tcl': 'sil', 'th': 'th', 'uh': 'uh', 'uw': 'uw', 'ux': 'uw', 'v': 'v', 'w': 'w', 'y': 'y', 'z': 'z', 'zh': 'sh'}


  fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)


mu shape =  torch.Size([64, 174, 256])
log_var shape =  torch.Size([64, 174, 256])
4 z shape =  torch.Size([64, 174, 256])
6 reconstruction shape =  torch.Size([64, 174, 240])
preds_batch shape =  torch.Size([64, 174, 240])  ys shape =  torch.Size([64, 66])
ys[ 0 ] =  tensor([ 5, 45, 11, 12, 11, 34,  6, 10, 11, 12, 43, 31, 21, 18, 14, 15, 13, 53,
        12, 22, 34,  6, 20, 13, 43, 12, 11, 12, 43, 44, 32, 33, 26, 42, 29, 13,
        29, 49, 12, 11, 51, 23, 24, 39, 12, 32, 10, 50, 30, 31, 61, 43, 44, 49,
        18, 14, 15, 59, 26, 42, 18, 11, 43, 44,  5,  2])
gt length =  176  gt =  h# dh ih s ih pcl p l ih s tcl k ay n dcl d ix v s eh pcl p r ix tcl s ih s tcl t ah m y ux z ix z ey s ih ng gcl g el s ah l f kcl k en tcl t ey n dcl d pau y ux n ih tcl t h#
preds_batch[ 0  ] shape =  torch.Size([174, 240])
oned_preds shape =  torch.Size([1, 41760])


TypeError: only integer tensors of a single element can be converted to an index

#### Plot History

In [None]:
plotLearning("config/default/history.csv")

In [None]:
from torch.autograd import Variable
import torch.utils.data

batch_size = cfg['train']['batch_size']
input_dim = 28 * 28
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for step, (xs, xlens, ys) in enumerate(train_loader):
    xs = xs.to(device) 
    ys = ys.to(device)
    #inputs = Variable(data.resize_(batch_size, input_dim))

modeled_data = model(xs, xlens, ys) #model(inputs)
print("modeled_data (loss) = ", modeled_data)
print("xs = ", xs)
print("ys = ", ys)
#model_data = modeled_data[0].reshape(batch_size, input_dim)
#print(input_data)        
#plt.imshow(model_data[0].detach().numpy().reshape(28, 28), cmap='gray')
#plt.show(block=True)

### Eval

In [None]:
# Restore checkpoint
info = torch.load("config/default/best.pth")
split = 'test'
print ("Dev. error rate of checkpoint: %.4f @epoch: %d" % (info['dev_error'], info['epoch']))
    
# Build model
tokenizer = torch.load('tokenizer.pth')
model = KalmanVAE(input_size=cfg['model']['input_size'],
            target_size=len(tokenizer.vocab),
            hidden_size=cfg['model']['hidden_size'],
            encoder_layers=cfg['model']['encoder_layers'],
            decoder_layers=cfg['model']['decoder_layers'],
            activation=cfg['model']['activation'],
            drop_p=cfg['model']['drop_p'])
 
model.load_state_dict(info['weights'])
model.eval() 
model.cuda()

# Evaluate
error = eval_utils.get_error(train_loader, model)
print ("Error rate on %s set = %.4f" % (split, error))

### Plot Training & Eval time

In [None]:
# Note that using plt.subplots below is equivalent to using
# fig = plt.figure() and then ax = fig.add_subplot(111)
fig, ax = plt.subplots()

#now create y values for the second plot
y = train_epoch_durations
#calculate the values for the Gaussian curve
x = np.arange(len(train_epoch_durations))
#plot the Gaussian curve
ax.plot(x, y, label = "Train Time")

ax.set(xlabel='Epoch (s)', ylabel='Time (ms)',
       title='Training time')

xe = np.arange(len(eval_epoch_durations))
ye = eval_epoch_durations
#plot sine wave
ax.plot(xe, ye, label = "Evaluation Time")

ax.grid()

#show the legend
plt.legend()
plt.show()

In [None]:
# Note that using plt.subplots below is equivalent to using
# fig = plt.figure() and then ax = fig.add_subplot(111)
fig, ax = plt.subplots()

#now create y values for the second plot
y = RMS
#calculate the values for the Gaussian curve
x = np.arange(len(RMS))
#plot the Gaussian curve
ax.plot(x, y, label = "Root-mean Square Error")

ax.set(xlabel='Epoch (s)', ylabel='RMSE (%)',
       title='Root-mean Square Error')

ax.grid()

#show the legend
plt.legend()
plt.show()
plt.savefig('img/rmse-ekf.pdf')
plt.savefig('img/rmse-ekf.png')

#### Inference auxiliary method

In [None]:
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import argparse
from matplotlib.pyplot import figure


def showAttention(predictions, attentions):
    output_words = predictions.split()
    # Set up figure with colorbar
    fig = plt.figure(figsize=(10,15))
    #figure(num=None, figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')

    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap='bone')
    fig.colorbar(cax)

    ax.set_yticklabels([''] + output_words)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

### Inference

In [None]:
# Inference
with torch.no_grad():
    for (x, xlens, y) in train_loader:
        predictions, attentions = model(x.cuda(), xlens)
        predictions, attentions = predictions[0], attentions[0]
        predictions = tokenizer.decode(predictions)
        attentions = attentions[:len(predictions.split())].cpu().numpy()   # (target_length, source_length)
        ground_truth = tokenizer.decode(y[0])
        print ("Predict:")
        print (predictions)
        print ("Ground-truth:")
        print (ground_truth)
        print ()
        showAttention(predictions, attentions)

In [None]:
epoch = 0
with open('config/default/timing.csv', 'w') as filehandle:
    for (train_epoch_duration, eval_epoch_duration, RMSE) in zip(train_epoch_durations, eval_epoch_durations, RMS):
        msg = '%d, %s, %s, %s \n' % (epoch, train_epoch_duration, eval_epoch_duration, RMSE)
        #print(msg)
        epoch = epoch +1 
        filehandle.write(msg)