
# Transducer implementation in PyTorch

*by Loren Lugosch*



In this notebook, we will implement a Transducer sequence-to-sequence model for Hungarian grapheme-to-phoneme conversion, mapping the spelling of a word to its phonemic transcription ("autóknál" --> "ɒ u t oː k n aː l").

In [39]:
import itertools
from collections import Counter
from tqdm import tqdm
import torch
import numpy as np


# Building blocks

First, we will define the encoder, predictor, and joiner using standard neural nets.

<img src="https://lorenlugosch.github.io/images/transducer/transducer-model.png" width="25%">

In [40]:
NULL_INDEX = 0

encoder_dim = 256
predictor_dim = 256
joiner_dim = 256

The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.


In [41]:
class Encoder(torch.nn.Module):
  def __init__(self, num_inputs):
    super(Encoder, self).__init__()
    self.embed = torch.nn.Embedding(num_inputs, encoder_dim)
    self.rnn = torch.nn.GRU(input_size=encoder_dim, hidden_size=encoder_dim, num_layers=2, batch_first=True, bidirectional=True, dropout=0.3)
    self.linear = torch.nn.Linear(encoder_dim * 2, joiner_dim)

  def forward(self, x):
    out = x
    out = self.embed(out)
    out = self.rnn(out)[0]
    out = self.linear(out)
    return out

The predictor is any _causal_ network (one that cannot look at future labels): in other words, unidirectional RNNs, causal convolutions, or masked self-attention.

In [42]:
class Predictor(torch.nn.Module):
  def __init__(self, num_outputs):
    super(Predictor, self).__init__()
    self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
    self.rnn = torch.nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)
    self.linear = torch.nn.Linear(predictor_dim, joiner_dim)
    
    self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
    self.start_symbol = NULL_INDEX # In the original paper, a vector of 0s is used; just using the null index instead is easier when using an Embedding layer.

  def forward_one_step(self, input, previous_state):
    embedding = self.embed(input)
    state = self.rnn.forward(embedding, previous_state)
    out = self.linear(state)
    return out, state

  def forward(self, y):
    batch_size = y.shape[0]
    U = y.shape[1]
    outs = []
    state = torch.stack([self.initial_state] * batch_size).to(y.device)
    for u in range(U+1): # need U+1 to get null output for final timestep 
      if u == 0:
        decoder_input = torch.tensor([self.start_symbol] * batch_size).to(y.device)
      else:
        decoder_input = y[:,u-1]
      out, state = self.forward_one_step(decoder_input, state)
      outs.append(out)
    out = torch.stack(outs, dim=1)
    return out
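To make "causal" concrete, here is a minimal pure-Python sketch (separate from the model above) of a causal 1-D convolution used as a predictor. As in `Predictor.forward`, its input is the start symbol followed by the U labels, so it produces U+1 outputs, and output u depends only on inputs 0..u. The scalar "embeddings" (the label IDs themselves) and the kernel weights are arbitrary assumptions made for illustration.

```python
NULL_INDEX = 0  # start symbol, as in the notebook


def causal_conv1d(xs, weights):
    """Causal 1-D convolution with zero left-padding.

    Output u is sum_i weights[i] * xs[u - (k - 1) + i], so it only sees
    xs[u-k+1 .. u] and never a future position.
    """
    k = len(weights)
    padded = [0.0] * (k - 1) + [float(x) for x in xs]
    return [sum(w * padded[u + i] for i, w in enumerate(weights))
            for u in range(len(xs))]


def toy_predictor(y, weights=(0.5, 1.0)):
    # Prepend the start symbol: U labels -> U+1 predictor outputs.
    inputs = [NULL_INDEX] + list(y)
    return causal_conv1d(inputs, weights)


out = toy_predictor([3, 1, 2])
print(out)  # [0.0, 3.0, 2.5, 2.5]

# Changing the last label only changes the last output.
out2 = toy_predictor([3, 1, 7])
print(out2[:3] == out[:3])  # True
```

A masked self-attention predictor satisfies the same contract; only the mechanism that hides future positions differs.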

The joiner is a feedforward network/MLP with one hidden layer applied independently to each $(t,u)$ index.

(The linear part of the hidden layer is contained in the encoder and predictor, so we just do the nonlinearity here and then the output layer.)

In [43]:
class Joiner(torch.nn.Module):
  def __init__(self, num_outputs):
    super(Joiner, self).__init__()
    self.linear = torch.nn.Linear(joiner_dim, num_outputs)

  def forward(self, encoder_out, predictor_out):
    out = encoder_out + predictor_out
    out = torch.tanh(out) # torch.nn.functional.tanh is deprecated in favor of torch.tanh
    out = self.linear(out)
    return out
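Why is adding the encoder and predictor outputs enough? A one-hidden-layer MLP applied to the concatenation $[f_t; g_u]$ computes $W[f_t; g_u] + b$. Splitting $W$ column-wise into $W_f$ and $W_g$ gives $W_f f_t + W_g g_u + b$, and the final `Linear` layers of `Encoder` and `Predictor` play the roles of $W_f$ and $W_g$ (both have biases, so $b$ is absorbed as well). The pure-Python sketch below checks this identity on small, arbitrarily chosen matrices.

```python
def matvec(W, v):
    # Dense matrix-vector product for lists of lists.
    return [sum(w * x for w, x in zip(row, v)) for row in W]


# Arbitrary toy values: hidden size 2, encoder dim 3, predictor dim 2.
W_f = [[1, 2, 0], [0, -1, 3]]  # columns of W acting on f_t
W_g = [[4, -2], [1, 1]]        # columns of W acting on g_u
f_t = [1, 0, 2]
g_u = [3, -1]

# Hidden pre-activation of an MLP applied to the concatenation [f_t; g_u]...
W = [rf + rg for rf, rg in zip(W_f, W_g)]  # W = [W_f | W_g]
concat = matvec(W, f_t + g_u)

# ...equals the sum of two separate projections, which is what Joiner adds.
split = [a + b for a, b in zip(matvec(W_f, f_t), matvec(W_g, g_u))]

print(concat, split)  # [15, 8] [15, 8]
```

Computing the projections separately is also cheaper: each $f_t$ and $g_u$ is projected once, instead of once per $(t, u)$ pair.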

# Transducer model + loss function

Using the encoder, predictor, and joiner, we will implement the Transducer model and its loss function.

<img src="https://lorenlugosch.github.io/images/transducer/forward-messages.png" width="25%">

We can use a simple PyTorch implementation of the loss function, relying on automatic differentiation to give us gradients.
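Written out, the recursion that `compute_forward_prob` implements is as follows. The notation is 0-indexed: $\log p(k \mid t,u)$ is `joiner_out[:, t, u, k]`, $\varnothing$ is `NULL_INDEX`, and $y_u$ is the $u$-th label (`y[:, u-1]`). Any term with $t-1<0$ or $u-1<0$ is dropped from the logsumexp.

$$
\begin{aligned}
\log\alpha(0,0) &= 0\\
\log\alpha(t,u) &= \operatorname{logsumexp}\Big(\log\alpha(t-1,u) + \log p(\varnothing \mid t-1,u),\ \log\alpha(t,u-1) + \log p(y_u \mid t,u-1)\Big)\\
\log p(\mathbf{y} \mid \mathbf{x}) &= \log\alpha(T-1,U) + \log p(\varnothing \mid T-1,U)
\end{aligned}
$$

The loss returned by `compute_loss` is $-\log p(\mathbf{y} \mid \mathbf{x})$, averaged over the batch.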

In [44]:
class Transducer(torch.nn.Module):
  def __init__(self, num_inputs, num_outputs):
    super(Transducer, self).__init__()
    self.encoder = Encoder(num_inputs)
    self.predictor = Predictor(num_outputs)
    self.joiner = Joiner(num_outputs)

    if torch.cuda.is_available(): self.device = "cuda:0"
    else: self.device = "cpu"
    self.to(self.device)

  def compute_forward_prob(self, joiner_out, T, U, y):
    """
    joiner_out: tensor of shape (B, T_max, U_max+1, #labels)
    T: list of input lengths
    U: list of output lengths 
    y: label tensor (B, U_max)
    """
    B = joiner_out.shape[0]
    T_max = joiner_out.shape[1]
    U_max = joiner_out.shape[2] - 1
    log_alpha = torch.zeros(B, T_max, U_max+1, device=joiner_out.device)
    for t in range(T_max):
      for u in range(U_max+1):
          if u == 0:
            if t == 0:
              log_alpha[:, t, u] = 0.

            else: #t > 0
              log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX] 
                  
          else: #u > 0
            if t == 0:
              log_alpha[:, t, u] = log_alpha[:, t,u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
            
            else: #t > 0
              log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                  log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                  log_alpha[:, t, u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
              ]), dim=0)
    
    log_probs = []
    for b in range(B):
      log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
      log_probs.append(log_prob)
    log_probs = torch.stack(log_probs)
    return log_probs

  def compute_loss(self, x, y, T, U):
    encoder_out = self.encoder.forward(x)
    predictor_out = self.predictor.forward(y)
    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
    loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
    return loss

Let's first check that the forward algorithm correctly computes the sum (in log space, the [logsumexp](https://lorenlugosch.github.io/posts/2020/06/logsumexp/)) over all possible alignments, using an input/output pair short enough that enumerating every alignment is feasible.

<img src="https://lorenlugosch.github.io/images/transducer/cat-align-1.png" width="25%">

In [45]:
def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):
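    """
    Computes the log probability of one alignment z for a single example.
    encoder_out: (T, joiner_dim); predictor_out: (U+1, joiner_dim)
    z: list of T-1 null steps (0, move right) and U label steps (1, move down);
       the final null emission at (T-1, U) is added automatically.
    """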
    t = 0; u = 0
    t_u_indices = []
    y_expanded = []
    for step in z:
      t_u_indices.append((t,u))
      if step == 0: # right (null)
        y_expanded.append(NULL_INDEX)
        t += 1
      if step == 1: # down (label)
        y_expanded.append(y[u])
        u += 1
    t_u_indices.append((T-1,U))
    y_expanded.append(NULL_INDEX)

    t_indices = [t for (t,u) in t_u_indices]
    u_indices = [u for (t,u) in t_u_indices]
    encoder_out_expanded = encoder_out[t_indices]
    predictor_out_expanded = predictor_out[u_indices]
    joiner_out = self.joiner.forward(encoder_out_expanded, predictor_out_expanded).log_softmax(1)
    logprob = -torch.nn.functional.nll_loss(input=joiner_out, target=torch.tensor(y_expanded).long().to(self.device), reduction="sum")
    return logprob

Transducer.compute_single_alignment_prob = compute_single_alignment_prob
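As a self-contained sketch of this check, the pure-Python code below (independent of the model; random log-probabilities stand in for the joiner output) enumerates every alignment of a short pair by brute force and compares the logsumexp of their scores with the forward recursion used in `compute_forward_prob`. The alignment encoding (0 = null step, 1 = label step, with the final null at $(T-1, U)$ appended) mirrors `compute_single_alignment_prob`; the sizes, seed, and label IDs are arbitrary.

```python
import itertools
import math
import random

NULL_INDEX = 0


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def random_log_probs(T, U, V, seed=0):
    # lp[t][u][k] = log p(k | t, u): a normalized distribution at every (t, u).
    rng = random.Random(seed)
    lp = []
    for _ in range(T):
        row = []
        for _ in range(U + 1):
            logits = [rng.gauss(0.0, 1.0) for _ in range(V)]
            z = logsumexp(logits)
            row.append([x - z for x in logits])
        lp.append(row)
    return lp


def forward_log_prob(lp, y):
    T, U = len(lp), len(y)
    a = [[-math.inf] * (U + 1) for _ in range(T)]
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                a[t][u] = 0.0
                continue
            terms = []
            if t > 0:  # arrive by emitting null at (t-1, u)
                terms.append(a[t - 1][u] + lp[t - 1][u][NULL_INDEX])
            if u > 0:  # arrive by emitting label y[u-1] at (t, u-1)
                terms.append(a[t][u - 1] + lp[t][u - 1][y[u - 1]])
            a[t][u] = logsumexp(terms)
    return a[T - 1][U] + lp[T - 1][U][NULL_INDEX]


def alignment_log_prob(lp, y, z):
    # Mirrors compute_single_alignment_prob: 0 = right (null), 1 = down (label).
    t = u = 0
    total = 0.0
    for step in z:
        if step == 0:
            total += lp[t][u][NULL_INDEX]
            t += 1
        else:
            total += lp[t][u][y[u]]
            u += 1
    return total + lp[len(lp) - 1][len(y)][NULL_INDEX]  # final null at (T-1, U)


def brute_force_log_prob(lp, y):
    T, U = len(lp), len(y)
    n_steps = T - 1 + U
    scores = []
    # An alignment is determined by which of the T-1+U steps are label steps.
    for label_steps in itertools.combinations(range(n_steps), U):
        z = [1 if i in label_steps else 0 for i in range(n_steps)]
        scores.append(alignment_log_prob(lp, y, z))
    return logsumexp(scores), len(scores)


lp = random_log_probs(T=4, U=3, V=5)
y = [2, 4, 1]  # toy label IDs; none equal NULL_INDEX
fwd = forward_log_prob(lp, y)
bf, n = brute_force_log_prob(lp, y)
print(n)                     # 20 alignments = C(6, 3)
print(abs(fwd - bf) < 1e-9)  # True
```

The number of alignments grows as $\binom{T-1+U}{U}$, which is why brute force is only usable as a test on tiny examples.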

Now let's add the greedy search algorithm for predicting an output sequence.

(Note that I've assumed we're using RNNs for the predictor here. You would have to modify this code a bit if you want to use convolutions/self-attention instead.) 
<br/><br/>
<img src="https://lorenlugosch.github.io/images/transducer/greedy-search.png" width="50%">

In [46]:
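@torch.no_grad() # decoding does not need gradients
def greedy_search(self, x, T):
  y_batch = []
  B = len(x)
  encoder_out = self.encoder.forward(x)
  U_max = 200 # cap on output length, in case the model never predicts null
  for b in range(B):
    t = 0; u = 0; y = [self.predictor.start_symbol]; predictor_state = self.predictor.initial_state.unsqueeze(0)
    # run the predictor on the start symbol; afterwards it only advances when a label is emitted,
    # matching how Predictor.forward consumes each label exactly once during training
    predictor_input = torch.tensor([ y[-1] ]).to(x.device)
    g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
    while t < T[b] and u < U_max:
      f_t = encoder_out[b, t]
      h_t_u = self.joiner.forward(f_t, g_u)
      argmax = h_t_u.max(-1)[1].item()
      if argmax == NULL_INDEX:
        t += 1 # move to the next encoder frame; the predictor state is unchanged
      else: # argmax == a label
        u += 1
        y.append(argmax)
        predictor_input = torch.tensor([ y[-1] ]).to(x.device)
        g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
    y_batch.append(y[1:]) # remove start symbol
  return y_batch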

Transducer.greedy_search = greedy_search
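To see the control flow of greedy decoding in isolation, here is a pure-Python sketch in which a hypothetical toy joiner replaces the networks. `toy_joiner(t, ys)` stands in for running the joiner on `encoder_out[b, t]` and the predictor output for label history `ys`, and returns log-probabilities over a 4-symbol vocabulary. In transducer decoding the label history changes only when a label is emitted; a null just moves to the next frame.

```python
import itertools
import math

NULL_INDEX = 0


def greedy_decode(joiner, T, max_len=200):
    """Greedy transducer decoding over frames 0..T-1.

    joiner(t, ys) returns log-probs over the vocabulary at lattice node
    (t, len(ys)). A null moves to the next frame; a label is appended to ys
    and decoding stays on the same frame.
    """
    t, ys = 0, []
    while t < T and len(ys) < max_len:
        scores = joiner(t, ys)
        k = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if k == NULL_INDEX:
            t += 1
        else:
            ys.append(k)
    return ys


def make_toy_joiner(plan, vocab_size=4, p=0.9):
    """Hypothetical joiner: frame t 'wants' to emit the labels in plan[t]."""
    target = [k for frame in plan for k in frame]
    cum = list(itertools.accumulate(len(frame) for frame in plan))
    other = math.log((1 - p) / (vocab_size - 1))

    def joiner(t, ys):
        on_track = ys == target[:len(ys)]
        want = target[len(ys)] if on_track and len(ys) < cum[t] else NULL_INDEX
        return [math.log(p) if k == want else other for k in range(vocab_size)]

    return joiner


def never_null(t, ys):
    # A degenerate model that always prefers label 1 over null.
    return [-5.0, 0.0, -5.0]


joiner = make_toy_joiner([[1, 2], [], [3]])
print(greedy_decode(joiner, T=3))                 # [1, 2, 3]
print(greedy_decode(joiner, T=3, max_len=2))      # [1, 2]
print(greedy_decode(never_null, T=3, max_len=5))  # [1, 1, 1, 1, 1]
```

The last line shows why a cap like `U_max` is needed: without it, a model that never predicts null would never advance `t`, and the loop would not terminate.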

The code above works, but training is very slow because `compute_forward_prob` loops over every $(t,u)$ cell in Python. You can use the fast implementation from SpeechBrain instead by running the block below. (SpeechBrain's `TransducerLoss` only accepts CUDA tensors, so this requires a GPU.)

In [47]:
!pip install speechbrain
from speechbrain.nnet.loss.transducer_loss import TransducerLoss
transducer_loss = TransducerLoss(NULL_INDEX) # blank (null) index

def compute_loss(self, x, y, T, U):
    encoder_out = self.encoder.forward(x)
    predictor_out = self.predictor.forward(y)
    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
    # (pure-Python alternative: loss = -self.compute_forward_prob(joiner_out, T, U, y).mean())
    T = T.to(joiner_out.device)
    U = U.to(joiner_out.device)
    loss = transducer_loss(joiner_out, y, T, U)
    return loss

Transducer.compute_loss = compute_loss



# Some utilities

Here we will add a bit of boilerplate code for training and loading data.

In [48]:
def read_corpus(path):
  with open(path) as f:
    return [line.strip().split("\t") for line in f if line != "\n"]

train = read_corpus("hun_train.tsv")
dev = read_corpus("hun_dev.tsv")
test = read_corpus("hun_test.tsv")

src, tgt = zip(*train)
src_itos = ["_"] + list(set(itertools.chain.from_iterable([list(s) for s in src])))
tgt_itos = ["_"] + list(set(itertools.chain.from_iterable([t.split() for t in tgt])))
print(src_itos)
print(tgt_itos)

src_stoi = {s: i for i, s in enumerate(src_itos)}
tgt_stoi = {t: i for i, t in enumerate(tgt_itos)}

class TextDataset(torch.utils.data.Dataset):
  def __init__(self, lines, batch_size):

    self.lines = lines # list of pairs of strings
    collate = Collate()
    self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=1, shuffle=True, collate_fn=collate)

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    x, y = self.lines[idx]
    return x, y

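# the source is split into characters, while the target is a space-separated
# sequence of phonemes, so the two sides need different encode/decode functions;
# encode turns strings into lists of indices, decode turns them back into strings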
def encode_src(s):
  return [src_stoi[c] for c in s]

def encode_tgt(t):
  return [tgt_stoi[c] for c in t.split()]

def decode_src(seq):
  return "".join([src_itos[i] for i in seq])

def decode_tgt(seq):
  return " ".join([tgt_itos[i] for i in seq])

class Collate:
  def __call__(self, batch):
    """
    batch: list of tuples (input string, output string)
    Returns a minibatch of strings, encoded as labels and padded to have the same length.
    """
    x = []; y = []
    batch_size = len(batch)
    for index in range(batch_size):
      x_, y_ = batch[index]
      x.append(encode_src(x_))
      y.append(encode_tgt(y_))

    # pad all sequences to have same length
    T = [len(x_) for x_ in x]
    U = [len(y_) for y_ in y]
    T_max = max(T)
    U_max = max(U)
    for index in range(batch_size):
      x[index] += [NULL_INDEX] * (T_max - len(x[index]))
      x[index] = torch.tensor(x[index])
      y[index] += [NULL_INDEX] * (U_max - len(y[index]))
      y[index] = torch.tensor(y[index])

    # stack into single tensor
    x = torch.stack(x)
    y = torch.stack(y)
    T = torch.tensor(T)
    U = torch.tensor(U)

    return x , y, T, U

train_set = TextDataset(train, batch_size=64)
dev_set = TextDataset(dev, batch_size=64)
test_set = TextDataset(test, batch_size=64)

['_', 'w', 'j', 'ő', 'l', 'í', 'p', 'b', 'á', 'd', 'k', 'ű', 'n', 'o', 'y', 'v', 'ó', 'i', 's', 'z', 'ü', 'u', 'r', 'ö', 't', 'é', 'ú', 'm', 'x', 'h', 'e', 'a', 'c', 'g', 'f']
['_', 'eː', 'cː', 'fː', 'j', 'ɡ', 'ʃ', 'l', 'ʝ', 'sː', 'øː', 'ɲː', 'jː', 't͡ʃ', 'pː', 'kː', 'ɦ', 'p', 'ɟ', 'b', 'uː', 'd͡ʒ', 'vː', 'tː', 'd', 'k', 'ɛ', 'ø', 'd͡zː', 'lː', 'ɒ', 'n', 'nː', 'o', 'ʃː', 'y', 'd͡ʒː', 'rː', 'v', 'f', 'i', 's', 'z', 'mː', 'ʒ', 'u', 'r', 't͡ʃː', 't͡s', 't', 'bː', 'yː', 'm', 'iː', 'x', 'ɲ', 'h', 'zː', 'ɡː', 'ɱ', 'oː', 'ɟː', 't͡sː', 'c', 'aː', 'dː', 'ŋ']
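The data pipeline above, stripped of PyTorch, looks roughly like the pure-Python sketch below. It parses two word pairs copied from the training log (an in-memory string replaces `hun_train.tsv`), tokenizes the source per character and the target per space-separated phoneme, and right-pads a batch the way `Collate` does. Using `sorted(set(...))` for the vocabularies is an assumption of this sketch, not what the cell above does. It makes the index assignment reproducible, whereas iteration order of a `set` of strings can change between Python sessions because of hash randomization.

```python
import io
import itertools

NULL_INDEX = 0

# Two (spelling, transcription) pairs from the training log, in hun_*.tsv format.
tsv = "autóknál\tɒ u t oː k n aː l\nárú\taː r uː\n\n"
pairs = [line.strip().split("\t") for line in io.StringIO(tsv) if line != "\n"]
src, tgt = zip(*pairs)

# Index 0 is reserved for "_" (NULL_INDEX); sorting fixes the order of the rest.
src_itos = ["_"] + sorted(set(itertools.chain.from_iterable(src)))
tgt_itos = ["_"] + sorted(set(itertools.chain.from_iterable(t.split() for t in tgt)))
src_stoi = {s: i for i, s in enumerate(src_itos)}
tgt_stoi = {p: i for i, p in enumerate(tgt_itos)}


def encode_src(s):
    return [src_stoi[c] for c in s]  # one ID per character


def encode_tgt(t):
    return [tgt_stoi[p] for p in t.split()]  # one ID per phoneme, e.g. "aː"


def decode_tgt(seq):
    return " ".join(tgt_itos[i] for i in seq)


def pad_batch(seqs, pad=NULL_INDEX):
    """Right-pad ID lists to a common length, as Collate does; keep true lengths."""
    lengths = [len(s) for s in seqs]
    padded = [s + [pad] * (max(lengths) - len(s)) for s in seqs]
    return padded, lengths


y, U = pad_batch([encode_tgt(t) for t in tgt])
print(U)                        # [8, 3]
print(decode_tgt(y[1][:U[1]]))  # aː r uː
print(y[1][U[1]:])              # [0, 0, 0, 0, 0]
```

Because the padding value is `NULL_INDEX`, padded label positions are indistinguishable from nulls by value alone; the returned lengths (T and U in `Collate`) are what let the loss ignore them.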


In [49]:
class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
  
  def train(self, dataset, print_interval = 60):
    train_loss = 0
    num_samples = 0
    self.model.train()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      self.optimizer.zero_grad()
      pbar.set_description("%.2f" % loss.item())
      loss.backward()
      self.optimizer.step()
      train_loss += loss.item() * batch_size
    
      if idx % print_interval == 0:
        self.model.eval()
        guesses = self.model.greedy_search(x,T)
        self.model.train()
        print("\n")
        for b in range(2):
          print("input:", decode_src(x[b,:T[b]]))
          print("guess:", decode_tgt(guesses[b]))
          print("truth:", decode_tgt(y[b,:U[b]]))
          print("")
    train_loss /= num_samples
    return train_loss

  @torch.no_grad() # evaluation does not need gradients
  def test(self, dataset):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device)
      y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      pbar.set_description("%.2f" % loss.item())
      test_loss += loss.item() * batch_size
    test_loss /= num_samples
    return test_loss

  @torch.no_grad()
  def generate(self, dataset):
    """
    Greedy-decodes every example and returns the word error rate: the fraction
    of words whose predicted transcription differs from the reference.
    """
    num_correct = 0
    num_samples = 0
    self.model.eval()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x, y, T, U = batch
      x = x.to(self.model.device)
      y = y.to(self.model.device)

      y_hat = self.model.greedy_search(x, T)
      y_str = [decode_tgt(y[i, :U[i]]) for i in range(len(x))]
      y_hat_str = [decode_tgt(y_hat_i) for y_hat_i in y_hat]

      num_correct += sum(yh == yi for yh, yi in zip(y_hat_str, y_str))
      num_samples += len(x)
    wer = 1 - num_correct / num_samples
    return wer
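Since each example is a single word, `generate` marks a word as wrong if any phoneme differs, so its "wer" is the fraction of words that are not exactly right. A finer-grained alternative, sketched below as an assumption rather than something this notebook computes, is the phoneme error rate: the total Levenshtein edit distance between predicted and reference phoneme sequences, divided by the total number of reference phonemes.

```python
def edit_distance(a, b):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,              # delete x
                           cur[j - 1] + 1,           # insert y
                           prev[j - 1] + (x != y)))  # substitute or match
        prev = cur
    return prev[-1]


def word_error_rate(hyps, refs):
    # Same quantity as Trainer.generate: fraction of words not exactly correct.
    return sum(h != r for h, r in zip(hyps, refs)) / len(refs)


def phoneme_error_rate(hyps, refs):
    errors = sum(edit_distance(h.split(), r.split()) for h, r in zip(hyps, refs))
    return errors / sum(len(r.split()) for r in refs)


# Two (guess, truth) pairs copied from the training log above.
hyps = ["m ɛ n ɛ k", "l aː ɲ o k"]
refs = ["m ɛ n t ɛ k", "l aː ɲ o k"]
print(word_error_rate(hyps, refs))               # 0.5
print(round(phoneme_error_rate(hyps, refs), 3))  # 0.091
```

Here one missing phoneme makes half the words wrong, but only 1 of 11 reference phonemes is in error.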
    

# Training the model

Now we will train a model. Every 60 batches (`print_interval`), it prints the greedy output for two training examples.

In [None]:
model = Transducer(num_inputs=len(src_itos), num_outputs=len(tgt_itos))
trainer = Trainer(model=model, lr=0.0003)

num_epochs = 20
train_losses=[]
test_losses=[]
test_wers = []

for epoch in range(num_epochs):
    train_loss = trainer.train(train_set)
    dev_loss = trainer.test(dev_set)
    dev_wer = trainer.generate(dev_set)
    train_losses.append(train_loss)
    test_losses.append(dev_loss)
    test_wers.append(dev_wer)
    print("Epoch %d: train loss = %f, dev loss = %f, dev wer = %f" % (epoch, train_loss, dev_loss, dev_wer))

5.66:   2%|▏         | 3/125 [00:00<00:19,  6.26it/s]



input: árú
guess: 
truth: aː r uː

input: autóknál
guess: 
truth: ɒ u t oː k n aː l



2.63:  50%|█████     | 63/125 [00:05<00:06,  9.89it/s]



input: költség
guess: k
truth: k ø l t͡ʃ eː ɡ

input: ismeretet
guess: ɛ
truth: i ʃ m ɛ r ɛ t ɛ t



1.08:  98%|█████████▊| 123/125 [00:09<00:00,  9.13it/s]



input: kezdődő
guess: k ɛ z d øː
truth: k ɛ z d øː d øː

input: lehessen
guess: l ɛ ʃ ɛ n
truth: l ɛ h ɛ ʃː ɛ n



1.13: 100%|██████████| 125/125 [00:09<00:00, 12.94it/s]
1.14: 100%|██████████| 16/16 [00:00<00:00, 28.43it/s]
100%|██████████| 16/16 [00:07<00:00,  2.19it/s]


Epoch 0: train loss = 2.713217, dev loss = 1.014923, dev wer = 0.839000


0.98:   2%|▏         | 3/125 [00:00<00:25,  4.78it/s]



input: írásos
guess: r aː ʃ o ʃ
truth: iː r aː ʃ o ʃ

input: ft
guess: f t
truth: f o r i n t



0.51:  50%|█████     | 63/125 [00:05<00:06,  9.06it/s]



input: szentháromság
guess: s ɛ n t h aː r o m aː ɡ
truth: s ɛ n t h aː r o m ʃ aː ɡ

input: részesít
guess: r eː s ɛ ʃ iː t
truth: r eː s ɛ ʃ iː t



0.22:  98%|█████████▊| 123/125 [00:09<00:00,  8.43it/s]



input: további
guess: t o v aː b i
truth: t o v aː bː i

input: ázsiai
guess: aː z i j ɒ j i
truth: aː ʒ i j ɒ j i



0.25: 100%|██████████| 125/125 [00:09<00:00, 12.84it/s]
0.27: 100%|██████████| 16/16 [00:00<00:00, 28.40it/s]
100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 1: train loss = 0.509435, dev loss = 0.246773, dev wer = 0.421000


0.22:   2%|▏         | 3/125 [00:00<00:26,  4.62it/s]



input: lejre
guess: l ɛ j r ɛ
truth: l ɛ j r ɛ

input: napot
guess: n ɒ p o t
truth: n ɒ p o t



0.18:  50%|█████     | 63/125 [00:05<00:07,  8.50it/s]



input: épületeket
guess: eː p y l ɛ t ɛ k ɛ t
truth: eː p y l ɛ t ɛ k ɛ t

input: mellőle
guess: m ɛ lː øː l ɛ
truth: m ɛ lː øː l ɛ



0.09:  98%|█████████▊| 123/125 [00:09<00:00,  9.05it/s]



input: abból
guess: ɒ bː oː l
truth: ɒ bː oː l

input: ötvenedik
guess: ø t v ɛ n ɛ d i k
truth: ø t v ɛ n ɛ d i k



0.09: 100%|██████████| 125/125 [00:09<00:00, 13.43it/s]
0.27: 100%|██████████| 16/16 [00:00<00:00, 30.38it/s]
100%|██████████| 16/16 [00:08<00:00,  2.00it/s]


Epoch 2: train loss = 0.191209, dev loss = 0.128609, dev wer = 0.405000


0.11:   2%|▏         | 3/125 [00:00<00:24,  4.97it/s]



input: röhög
guess: r ø ɡ
truth: r ø h ø ɡ

input: háborúban
guess: h aː b o r uː b ɒ n
truth: h aː b o r uː b ɒ n



0.11:  50%|█████     | 63/125 [00:04<00:06,  9.22it/s]



input: gyenge
guess: ɟ ɛ ŋ ɛ
truth: ɟ ɛ ŋ ɡ ɛ

input: élesen
guess: eː l ɛ ʃ ɛ n
truth: eː l ɛ ʃ ɛ n



0.17:  98%|█████████▊| 123/125 [00:08<00:00,  9.29it/s]



input: eljárást
guess: ɛ jː aː r aː ʃ t
truth: ɛ jː aː r aː ʃ t

input: élünk
guess: eː l y ŋ k
truth: eː l y ŋ k



0.06: 100%|██████████| 125/125 [00:08<00:00, 14.07it/s]
0.08: 100%|██████████| 16/16 [00:00<00:00, 30.51it/s]
100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 3: train loss = 0.113033, dev loss = 0.087486, dev wer = 0.363000


0.13:   2%|▏         | 3/125 [00:00<00:27,  4.48it/s]



input: park
guess: p ɒ r k
truth: p ɒ r k

input: dob
guess: d o b
truth: d o b



0.05:  50%|█████     | 63/125 [00:05<00:07,  8.77it/s]



input: távolsági
guess: t aː v o l ʃ aː ɡ i
truth: t aː v o l ʃ aː ɡ i

input: jövőbeli
guess: j ø v øː b ɛ l i
truth: j ø v øː b ɛ l i



0.05:  98%|█████████▊| 123/125 [00:09<00:00,  8.58it/s]



input: felmérések
guess: f ɛ l eː r eː ʃ ɛ k
truth: f ɛ l m eː r eː ʃ ɛ k

input: égtájak
guess: eː k t aː j ɒ k
truth: eː k t aː j ɒ k



0.04: 100%|██████████| 125/125 [00:09<00:00, 12.94it/s]
0.06: 100%|██████████| 16/16 [00:00<00:00, 29.64it/s]
100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 4: train loss = 0.077927, dev loss = 0.067994, dev wer = 0.386000


0.07:   2%|▏         | 3/125 [00:00<00:24,  4.92it/s]



input: hanyatlás
guess: h ɒ ɲ ɒ t l aː ʃ
truth: h ɒ ɲ ɒ t l aː ʃ

input: ártalmas
guess: aː r t ɒ l m ɒ ʃ
truth: aː r t ɒ l m ɒ ʃ



0.03:  50%|█████     | 63/125 [00:04<00:06,  8.90it/s]



input: megtud
guess: m ɛ k t u d
truth: m ɛ k t u d

input: természetes
guess: t ɛ r m eː s ɛ t ɛ ʃ
truth: t ɛ r m eː s ɛ t ɛ ʃ



0.03:  98%|█████████▊| 123/125 [00:09<00:00,  8.84it/s]



input: hibákat
guess: h i b aː k ɒ t
truth: h i b aː k ɒ t

input: hasznos
guess: h ɒ s o ʃ
truth: h ɒ s n o ʃ



0.07: 100%|██████████| 125/125 [00:09<00:00, 13.66it/s]
0.02: 100%|██████████| 16/16 [00:00<00:00, 29.46it/s]
100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 5: train loss = 0.058313, dev loss = 0.054402, dev wer = 0.375000


0.03:   2%|▏         | 3/125 [00:00<00:25,  4.78it/s]



input: mentek
guess: m ɛ n ɛ k
truth: m ɛ n t ɛ k

input: lányok
guess: l aː ɲ o k
truth: l aː ɲ o k



0.05:  50%|█████     | 63/125 [00:04<00:06,  9.29it/s]



input: rovarok
guess: r o v ɒ r o k
truth: r o v ɒ r o k

input: latin
guess: l ɒ t i n
truth: l ɒ t i n



0.05:  98%|█████████▊| 123/125 [00:08<00:00,  8.94it/s]



input: életben
guess: eː l ɛ d b ɛ n
truth: eː l ɛ d b ɛ n

input: jenő
guess: j ɛ n øː
truth: j ɛ n øː



0.02: 100%|██████████| 125/125 [00:09<00:00, 13.80it/s]
0.15: 100%|██████████| 16/16 [00:00<00:00, 29.77it/s]
100%|██████████| 16/16 [00:08<00:00,  1.98it/s]


Epoch 6: train loss = 0.045752, dev loss = 0.048279, dev wer = 0.317000


0.02:   2%|▏         | 3/125 [00:00<00:28,  4.32it/s]



input: atomerőmű
guess: ɒ t o m ɛ r øː m yː
truth: ɒ t o m ɛ r øː m yː

input: lényeges
guess: l eː ɲ ɛ ɡ ɛ ʃ
truth: l eː ɲ ɛ ɡ ɛ ʃ



0.02:  50%|█████     | 63/125 [00:05<00:06,  8.95it/s]



input: kísérletek
guess: k iː ʃ eː r ɛ t ɛ k
truth: k iː ʃ eː r l ɛ t ɛ k

input: szűz
guess: s
truth: s yː z



0.03:  98%|█████████▊| 123/125 [00:09<00:00,  8.78it/s]



input: szabály
guess: s ɒ b aː j
truth: s ɒ b aː j

input: kötetes
guess: k ø t ɛ t ɛ ʃ
truth: k ø t ɛ t ɛ ʃ



0.02: 100%|██████████| 125/125 [00:09<00:00, 13.18it/s]
0.02: 100%|██████████| 16/16 [00:00<00:00, 29.37it/s]
100%|██████████| 16/16 [00:07<00:00,  2.03it/s]


Epoch 7: train loss = 0.037440, dev loss = 0.042565, dev wer = 0.469000


0.02:   2%|▏         | 3/125 [00:00<00:24,  4.96it/s]



input: népszerű
guess: n eː p ɛ r yː
truth: n eː p s ɛ r yː

input: önökkel
guess: ø n ø kː ɛ l
truth: ø n ø kː ɛ l



0.01:  50%|█████     | 63/125 [00:04<00:07,  8.58it/s]



input: hetvennégy
guess: h ɛ t v ɛ nː eː ɟ
truth: h ɛ t v ɛ nː eː ɟ

input: hónapok
guess: h oː n ɒ p o k
truth: h oː n ɒ p o k



0.09:  98%|█████████▊| 123/125 [00:09<00:00,  8.88it/s]



input: faxon
guess: f ɒ k s o n
truth: f ɒ k s o n

input: álmai
guess: aː l ɒ j i
truth: aː l m ɒ j i



0.03: 100%|██████████| 125/125 [00:09<00:00, 13.60it/s]
0.01: 100%|██████████| 16/16 [00:00<00:00, 30.00it/s]
100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 8: train loss = 0.030524, dev loss = 0.039545, dev wer = 0.389000


0.05:   2%|▏         | 3/125 [00:00<00:25,  4.83it/s]



input: munkacsoport
guess: m u ŋ k ɒ t͡ʃ o p o r t
truth: m u ŋ k ɒ t͡ʃ o p o r t

input: segítséget
guess: ʃ ɛ ɡ iː t͡ʃː eː ɡ ɛ t
truth: ʃ ɛ ɡ iː t͡ʃː eː ɡ ɛ t



0.02:  50%|█████     | 63/125 [00:04<00:06,  9.18it/s]



input: árai
guess: aː r ɒ j i
truth: aː r ɒ j i

input: világgal
guess: v i l aː ɡː ɒ l
truth: v i l aː ɡː ɒ l



0.02:  98%|█████████▊| 123/125 [00:09<00:00,  8.57it/s]



input: árut
guess: aː r u t
truth: aː r u t

input: tevékenységet
guess: t ɛ v eː k ɛ ɲ eː ɡ ɛ t
truth: t ɛ v eː k ɛ ɲ ʃ eː ɡ ɛ t



0.01: 100%|██████████| 125/125 [00:09<00:00, 13.69it/s]
0.03: 100%|██████████| 16/16 [00:00<00:00, 30.93it/s]
100%|██████████| 16/16 [00:08<00:00,  1.99it/s]


Epoch 9: train loss = 0.026360, dev loss = 0.036387, dev wer = 0.403000


0.01:   2%|▏         | 3/125 [00:00<00:28,  4.26it/s]



input: tanácshoz
guess: t ɒ n aː t͡ʃ o z
truth: t ɒ n aː t͡ʃ h o z

input: szakemberekkel
guess: s ɒ k ɛ m ɛ r ɛ kː ɛ l
truth: s ɒ k ɛ m b ɛ r ɛ kː ɛ l



0.01:  50%|█████     | 63/125 [00:05<00:07,  8.24it/s]



input: tegyenek
guess: t ɛ ɟ ɛ n ɛ k
truth: t ɛ ɟ ɛ n ɛ k

input: összefüggő
guess: ø sː ɛ f y ɡː øː
truth: ø sː ɛ f y ɡː øː



0.02:  98%|█████████▊| 123/125 [00:09<00:00,  8.86it/s]



input: ered
guess: ɛ r ɛ d
truth: ɛ r ɛ d

input: japán
guess: j ɒ p aː n
truth: j ɒ p aː n



0.01: 100%|██████████| 125/125 [00:09<00:00, 12.67it/s]
0.01: 100%|██████████| 16/16 [00:00<00:00, 29.53it/s]
100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 10: train loss = 0.021996, dev loss = 0.033850, dev wer = 0.377000


0.01:   2%|▏         | 3/125 [00:00<00:25,  4.83it/s]



input: földet
guess: f ø l ɛ t
truth: f ø l d ɛ t

input: tetemes
guess: t ɛ t ɛ m ɛ ʃ
truth: t ɛ t ɛ m ɛ ʃ



0.01:  50%|█████     | 63/125 [00:04<00:06,  9.08it/s]



input: futnak
guess: f u t n ɒ k
truth: f u t n ɒ k

input: vigyázz
guess: v i
truth: v i ɟ aː zː



0.04:  98%|█████████▊| 123/125 [00:09<00:00,  8.79it/s]



input: fogvatartott
guess: f o ɡ ɒ t ɒ r t o tː
truth: f o ɡ v ɒ t ɒ r t o tː

input: közöttünk
guess: k ø z ø tː y ŋ k
truth: k ø z ø tː y ŋ k



0.01: 100%|██████████| 125/125 [00:09<00:00, 13.40it/s]
0.02: 100%|██████████| 16/16 [00:00<00:00, 30.16it/s]
100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 11: train loss = 0.018753, dev loss = 0.033718, dev wer = 0.360000


0.01:   2%|▏         | 3/125 [00:00<00:25,  4.83it/s]



input: alkalmat
guess: ɒ l k ɒ l ɒ t
truth: ɒ l k ɒ l m ɒ t

input: ezáltal
guess: ɛ z aː l t ɒ l
truth: ɛ z aː l t ɒ l



0.01:  50%|█████     | 63/125 [00:04<00:06,  9.48it/s]



input: jutott
guess: j u t o tː
truth: j u t o tː

input: volt
guess: v o l t
truth: v o l t



0.01:  98%|█████████▊| 123/125 [00:08<00:00,  9.28it/s]



input: regény
guess: r ɛ ɡ eː ɲ
truth: r ɛ ɡ eː ɲ

input: evangélium
guess: ɛ v ɒ ŋ eː l i j u m
truth: ɛ v ɒ ŋ ɡ eː l i j u m



0.01: 100%|██████████| 125/125 [00:08<00:00, 13.98it/s]
0.02: 100%|██████████| 16/16 [00:00<00:00, 30.84it/s]
100%|██████████| 16/16 [00:08<00:00,  1.99it/s]


Epoch 12: train loss = 0.016013, dev loss = 0.032175, dev wer = 0.381000


0.02:   2%|▏         | 3/125 [00:00<00:28,  4.35it/s]



input: óvszer
guess: oː f ɛ r
truth: oː f s ɛ r

input: felmérés
guess: f ɛ l eː r eː ʃ
truth: f ɛ l m eː r eː ʃ



0.01:  50%|█████     | 63/125 [00:05<00:07,  8.77it/s]



input: munkanélküli
guess: m u ŋ k ɒ n eː l k y l i
truth: m u ŋ k ɒ n eː l k y l i

input: bán
guess: b aː n
truth: b aː n



0.01:  98%|█████████▊| 123/125 [00:09<00:00,  8.66it/s]



input: egyszemélyes
guess: ɛ c ɛ m eː j ɛ ʃ
truth: ɛ c s ɛ m eː j ɛ ʃ

input: fogyó
guess: f o ɟ oː
truth: f o ɟ oː



0.00: 100%|██████████| 125/125 [00:09<00:00, 13.09it/s]
0.00: 100%|██████████| 16/16 [00:00<00:00, 28.86it/s]
100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 13: train loss = 0.013879, dev loss = 0.031379, dev wer = 0.468000


0.00:   2%|▏         | 3/125 [00:00<00:25,  4.79it/s]



input: foglalkozik
guess: f o ɡ ɒ l k o z i k
truth: f o ɡ l ɒ l k o z i k

input: mentes
guess: m ɛ n ɛ ʃ
truth: m ɛ n t ɛ ʃ



0.02:  50%|█████     | 63/125 [00:05<00:07,  8.67it/s]



input: vasárnapi
guess: v ɒ ʃ aː r ɒ p i
truth: v ɒ ʃ aː r n ɒ p i

input: divat
guess: d i v ɒ t
truth: d i v ɒ t



0.01:  98%|█████████▊| 123/125 [00:09<00:00,  9.21it/s]



input: tartalmát
guess: t ɒ r t ɒ l m aː t
truth: t ɒ r t ɒ l m aː t

input: gyűlt
guess: ɟ
truth: ɟ yː l t



0.01: 100%|██████████| 125/125 [00:09<00:00, 13.59it/s]
0.10: 100%|██████████| 16/16 [00:00<00:00, 30.52it/s]
100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 14: train loss = 0.012224, dev loss = 0.031787, dev wer = 0.391000


0.01:   2%|▏         | 3/125 [00:00<00:26,  4.65it/s]



input: nyelvészek
guess: ɲ ɛ l v eː s ɛ k
truth: ɲ ɛ l v eː s ɛ k

input: szereplő
guess: s ɛ r ɛ p øː
truth: s ɛ r ɛ p l øː



0.00:  50%|█████     | 63/125 [00:04<00:07,  8.76it/s]



input: főn
guess: f øː n
truth: f øː n

input: ritkaság
guess: r i t k ɒ ʃ aː ɡ
truth: r i t k ɒ ʃ aː ɡ



0.00:  98%|█████████▊| 123/125 [00:08<00:00,  9.21it/s]



input: szerintem
guess: s ɛ r i n ɛ m
truth: s ɛ r i n t ɛ m

input: fogvatartott
guess: f o ɡ ɒ t ɒ r t o tː
truth: f o ɡ v ɒ t ɒ r t o tː



0.00: 100%|██████████| 125/125 [00:09<00:00, 13.84it/s]
0.09: 100%|██████████| 16/16 [00:00<00:00, 29.87it/s]
100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 15: train loss = 0.010940, dev loss = 0.030687, dev wer = 0.447000


0.01:   2%|▏         | 3/125 [00:00<00:27,  4.46it/s]



input: oszét
guess: o s eː t
truth: o s eː t

input: építeni
guess: eː p iː t ɛ n i
truth: eː p iː t ɛ n i



0.00:  50%|█████     | 63/125 [00:05<00:07,  8.56it/s]



input: tavak
guess: t ɒ v ɒ k
truth: t ɒ v ɒ k

input: ötletet
guess: ø t l ɛ t ɛ t
truth: ø t l ɛ t ɛ t



0.01:  98%|█████████▊| 123/125 [00:09<00:00,  8.61it/s]



input: menet
guess: m ɛ n ɛ t
truth: m ɛ n ɛ t

input: feletti
guess: f ɛ l ɛ tː i
truth: f ɛ l ɛ tː i



0.01: 100%|██████████| 125/125 [00:09<00:00, 13.08it/s]
0.00: 100%|██████████| 16/16 [00:00<00:00, 29.29it/s]
100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 16: train loss = 0.009622, dev loss = 0.030972, dev wer = 0.393000


0.00:   2%|▏         | 3/125 [00:00<00:24,  4.88it/s]



input: annácska
guess: ɒ nː aː t͡ʃ ɒ
truth: ɒ nː aː t͡ʃ k ɒ

input: generáció
guess: ɡ ɛ n ɛ r aː t͡s i j oː
truth: ɡ ɛ n ɛ r aː t͡s i j oː



0.00:  50%|█████     | 63/125 [00:04<00:07,  8.41it/s]



input: sarkán
guess: ʃ ɒ r k aː n
truth: ʃ ɒ r k aː n

input: azokra
guess: ɒ z o k r ɒ
truth: ɒ z o k r ɒ



0.01:  98%|█████████▊| 123/125 [00:09<00:00,  9.03it/s]



input: időbeli
guess: i d øː b ɛ l i
truth: i d øː b ɛ l i

input: térségnek
guess: t eː r eː ɡ ɛ k
truth: t eː r ʃ eː ɡ n ɛ k



0.01: 100%|██████████| 125/125 [00:09<00:00, 13.57it/s]
0.00: 100%|██████████| 16/16 [00:00<00:00, 30.36it/s]
100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 17: train loss = 0.008737, dev loss = 0.030771, dev wer = 0.460000


0.00:   2%|▏         | 3/125 [00:00<00:25,  4.81it/s]



input: bevezetés
guess: b ɛ v ɛ z ɛ t eː ʃ
truth: b ɛ v ɛ z ɛ t eː ʃ

input: rá
guess: r aː
truth: r aː



0.01:  50%|█████     | 63/125 [00:04<00:06,  8.91it/s]



input: csillagkép
guess: t͡ʃ ɒ kː eː p
truth: t͡ʃ i lː ɒ kː eː p

input: hatóságok
guess: h ɒ t oː ʃ aː ɡ o k
truth: h ɒ t oː ʃ aː ɡ o k



0.01:  98%|█████████▊| 123/125 [00:08<00:00,  8.86it/s]



input: villanyt
guess: v i l ɒ ɲ
truth: v i lː ɒ ɲ t

input: programba
guess: p r o ɡ r ɒ m ɒ
truth: p r o ɡ r ɒ m b ɒ



0.00: 100%|██████████| 125/125 [00:09<00:00, 13.79it/s]
0.00: 100%|██████████| 16/16 [00:00<00:00, 29.31it/s]
 19%|█▉        | 3/16 [00:01<00:06,  1.94it/s]

In [None]:
print(train_losses)
print(test_losses)

Let's test the model on a new word:

In [None]:
test_output = "ɒ d m i n i s t r aː t o r n ɒ k"
test_input = "adminisztrátornak"
print("input: " + test_input)
x = torch.tensor(encode_src(test_input)).unsqueeze(0).to(model.device)
y = torch.tensor(encode_tgt(test_output)).unsqueeze(0).to(model.device)
T = torch.tensor([x.shape[1]]).to(model.device)
U = torch.tensor([y.shape[1]]).to(model.device)
guess = model.greedy_search(x,T)[0]
print("truth: " + test_output)
print("guess: " + decode_tgt(guess))
print("")
y_guess = torch.tensor(guess, dtype=torch.long).unsqueeze(0).to(model.device)
U_guess = torch.tensor(len(guess)).unsqueeze(0).to(model.device)

print("NLL of truth: %.4f" % model.compute_loss(x, y, T, U).item())
print("NLL of guess: %.4f" % model.compute_loss(x, y_guess, T, U_guess).item())

In [None]:
wer = trainer.generate(test_set)
print(wer)

If the guess is wrong, you may observe that its negative log-likelihood is actually higher than that of the true label sequence. The model then assigns the truth a higher probability, but greedy search failed to find it (a "[search error](https://www.aclweb.org/anthology/D19-1331.pdf)"). This suggests that we could get better results using beam search instead of greedy search.
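A full transducer beam search, such as the one in the original Transducer paper, is more involved than greedy search. The pure-Python sketch below is a simplified, time-synchronous variant and an assumption of this write-up, not an algorithm taken from this notebook or a library. At each frame, every hypothesis may emit up to `max_symbols` labels before consuming the frame with a null. Hypotheses reaching the next frame with the same label sequence are merged by log-sum-exp, and the `beam` best are kept. The hypothetical `make_toy_joiner` again stands in for running the joiner on `encoder_out[b, t]` and the predictor output for history `ys`.

```python
import itertools
import math

NULL_INDEX = 0


def logaddexp(a, b):
    """log(exp(a) + exp(b)), safe when either argument is -inf."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def top(hyps, beam):
    # Keep the `beam` highest-scoring hypotheses.
    return dict(sorted(hyps.items(), key=lambda kv: kv[1], reverse=True)[:beam])


def beam_decode(joiner, T, beam=4, max_symbols=3):
    """Simplified time-synchronous transducer beam search (a sketch).

    joiner(t, ys) -> log-probs over the vocabulary at lattice node (t, len(ys)).
    Returns [(label tuple, log score), ...], best first.
    """
    hyps = {(): 0.0}  # label sequence -> log score, before consuming frame t
    for t in range(T):
        next_hyps = {}  # hypotheses that consumed frame t with a null
        active = hyps
        for step in range(max_symbols + 1):
            expanded = {}
            for ys, score in active.items():
                lp = joiner(t, list(ys))
                # Null: consume frame t; alignments of the same ys are merged.
                next_hyps[ys] = logaddexp(next_hyps.get(ys, -math.inf),
                                          score + lp[NULL_INDEX])
                if step == max_symbols:
                    continue  # label budget for this frame is used up
                # Label k: extend ys and stay on frame t (keys are unique here).
                for k, lp_k in enumerate(lp):
                    if k != NULL_INDEX:
                        expanded[ys + (k,)] = score + lp_k
            active = top(expanded, beam)
        hyps = top(next_hyps, beam)
    return sorted(hyps.items(), key=lambda kv: kv[1], reverse=True)


def make_toy_joiner(plan, vocab_size=4, p=0.9):
    """Hypothetical joiner: frame t 'wants' to emit the labels in plan[t]."""
    target = [k for frame in plan for k in frame]
    cum = list(itertools.accumulate(len(frame) for frame in plan))
    other = math.log((1 - p) / (vocab_size - 1))

    def joiner(t, ys):
        on_track = ys == target[:len(ys)]
        want = target[len(ys)] if on_track and len(ys) < cum[t] else NULL_INDEX
        return [math.log(p) if k == want else other for k in range(vocab_size)]

    return joiner


joiner = make_toy_joiner([[1, 2], [], [3]])
best, score = beam_decode(joiner, T=3)[0]
print(best)  # (1, 2, 3)
```

To apply this to the real model, `joiner(t, ys)` would be replaced by the predictor state for `ys` (cached per hypothesis) combined with `encoder_out[b, t]` in `Joiner.forward`, followed by `log_softmax`.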