# Transducer implementation in PyTorch

*by Loren Lugosch*



In this notebook, we will implement a Transducer sequence-to-sequence model for inserting missing vowels into a sentence.

Example: "W wll mplmnt sm cd." --> "We will implement some code."

*Idea: we can change the target sentences to be specific to a domain.*

Default: "Hll, Wrld" --> "Hello, World".
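As a minimal sketch of how an input/target pair for this task can be built, the input is simply the target with its vowels removed. The helper name `strip_vowels` is an illustrative assumption, not a function defined in this notebook.

```python
VOWELS = "AEIOUaeiou"

def strip_vowels(sentence):
    """Return the sentence with all ASCII vowels removed (hypothetical helper)."""
    return "".join(c for c in sentence if c not in VOWELS)

target = "Hello, World"
source = strip_vowels(target)  # the model sees this and must recover `target`
print(source, "-->", target)
```

The model's job is the inverse mapping, which is ambiguous in general; that ambiguity is what makes this a sequence-to-sequence problem.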

In [None]:
import torch
import string
import numpy as np
import itertools
from collections import Counter
from tqdm import tqdm
!pip install unidecode
import unidecode


"""
DIRECTIONS: Aside from the given training data, find one other open-source dataset.
<15 min>
"""
# 1. Default training data.
!wget https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt
!pwd

# 2. Find a second training dataset.


--2022-09-11 18:46:19--  https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3196229 (3.0M) [text/plain]
Saving to: ‘war_and_peace.txt’


2022-09-11 18:46:19 (269 MB/s) - ‘war_and_peace.txt’ saved [3196229/3196229]

/content


# Building blocks

First, we will define the encoder, predictor, and joiner using standard neural nets.

<img src="https://lorenlugosch.github.io/images/transducer/transducer-model.png" width="25%">

In [None]:
"""
RHETORICAL Q: How does changing these numbers affect the performance of the model?
In training? In testing? In after-paper performance?
"""
NULL_INDEX = 0

encoder_dim = 1024
predictor_dim = 1024
joiner_dim = 1024

The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.


In [None]:
class Encoder(torch.nn.Module):
  def __init__(self, num_inputs):
    """
    @num_inputs: number of distinct input symbols (size of the input vocabulary, including the null index)

    DIRECTIONS: complete the variables for the input_size, hidden_size, and bidirectional arguments in self.rnn
    <3 min>
    """
    super(Encoder, self).__init__()
    self.embed = torch.nn.Embedding(num_inputs, encoder_dim)
    self.rnn = torch.nn.GRU(input_size=None, hidden_size=None, num_layers=3, batch_first=True, bidirectional=None, dropout=0.1)
    self.linear = torch.nn.Linear(encoder_dim*2, joiner_dim)

  def forward(self, x):
    out = x
    out = self.embed(out)
    out = self.rnn(out)[0]
    out = self.linear(out)
    return out

The predictor is any _causal_ network (= can't look at the future): in other words, unidirectional RNNs, causal convolutions, or masked self-attention. 
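To make "causal" concrete, here is a small sketch (with illustrative sizes, not the dimensions used in this notebook) that changes only the later time steps of an input and checks whether the earlier outputs move. It uses `torch.nn.GRU` purely for the demonstration; it is not the predictor defined in the next cell.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only; they are not the dimensions used in this notebook.
causal_rnn = torch.nn.GRU(input_size=4, hidden_size=8, batch_first=True)
noncausal_rnn = torch.nn.GRU(input_size=4, hidden_size=8, batch_first=True, bidirectional=True)

x = torch.randn(1, 5, 4)             # (batch, time, features)
x_future = x.clone()
x_future[0, 3:] = torch.randn(2, 4)  # change only time steps 3 and 4

with torch.no_grad():
    # Compare the outputs at time steps 0..2, which precede the change.
    causal_same = torch.allclose(causal_rnn(x)[0][0, :3], causal_rnn(x_future)[0][0, :3])
    noncausal_same = torch.allclose(noncausal_rnn(x)[0][0, :3], noncausal_rnn(x_future)[0][0, :3])

print("unidirectional outputs for t < 3 unchanged:", causal_same)
print("bidirectional outputs for t < 3 unchanged:", noncausal_same)
```

The unidirectional network ignores the future, so its early outputs stay the same; the bidirectional one reads the sequence backwards as well, so its early outputs change.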

In [None]:
class Predictor(torch.nn.Module):
  def __init__(self, num_outputs):
    """
    @num_outputs: number of distinct output labels (size of the output vocabulary, including the null index)

    DIRECTIONS: complete the variables for the input_size and hidden_size arguments in self.rnn
                complete the arguments to torch.nn.Linear in self.linear (hint: 2 required)
    <3 min>
    """
    super(Predictor, self).__init__()
    self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
    self.rnn = torch.nn.GRUCell(input_size=None, hidden_size=None)
    self.linear = torch.nn.Linear()
    
    self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
    self.start_symbol = NULL_INDEX # In the original paper, a vector of 0s is used; just using the null index instead is easier when using an Embedding layer.


  def forward_one_step(self, input, previous_state):
    """
    This is a helper function for the inherited forward method (since the input may vary).
    @input: decoder input
    @previous_state: state before passing the input through the RNN's forward method
    """
    embedding = self.embed(input)
    state = self.rnn.forward(embedding, previous_state)
    out = self.linear(state)
    return out, state


  def forward(self, y):
    """
    @y: label tensor of shape (B, U_max)

    DIRECTIONS: complete the variables for batch_size and U (hint: utilize how y is formatted)
                complete the variable in the for loop, i.e. replace the 'None'
    <5 min>
    """
    batch_size = None
    U = None 
    outs = []
    state = torch.stack([self.initial_state] * batch_size).to(y.device)
    for u in range(None): # hint: we want to get the NULL output for the final timestep 
      if u == 0:
        decoder_input = torch.tensor([self.start_symbol] * batch_size).to(y.device)
      else:
        decoder_input = y[:,u-1]
      out, state = self.forward_one_step(decoder_input, state)
      outs.append(out)
    out = torch.stack(outs, dim=1)
    return out

The joiner is a feedforward network/MLP with one hidden layer applied independently to each $(t,u)$ index.

(The linear part of the hidden layer is contained in the encoder and predictor, so we just do the nonlinearity here and then the output layer.)
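The "applied independently to each $(t,u)$ index" part relies on broadcasting: the encoder output gets an extra axis for $u$, the predictor output gets an extra axis for $t$, and their sum covers the whole grid. A minimal sketch with made-up sizes (the variable names and dimensions here are illustrative assumptions):

```python
import torch

B, T, U, D = 2, 4, 3, 8  # illustrative sizes, not the notebook's dimensions
encoder_out = torch.randn(B, T, D)        # one vector per input step t
predictor_out = torch.randn(B, U + 1, D)  # one vector per output position u (0..U)

# (B, T, 1, D) + (B, 1, U+1, D) broadcasts to (B, T, U+1, D)
summed = encoder_out.unsqueeze(2) + predictor_out.unsqueeze(1)
print(summed.shape)

# Each grid cell depends only on its own encoder and predictor vectors.
b, t, u = 1, 2, 3
cell_matches = torch.allclose(summed[b, t, u], encoder_out[b, t] + predictor_out[b, u])
print(cell_matches)
```

Any elementwise nonlinearity and the final linear layer then act on the last axis only, so every $(t,u)$ cell is processed on its own.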

In [None]:
class Joiner(torch.nn.Module):
  def __init__(self, num_outputs):
    """
    @num_outputs: size of softmax output over all labels
    """
    super(Joiner, self).__init__()
    self.linear = torch.nn.Linear(joiner_dim, num_outputs)

  def forward(self, encoder_out, predictor_out):
    """
    @encoder_out: 
    @predictor_out: 

    DIRECTIONS:    choose and apply a nonlinear function of your choice
    RHETORICAL Q:  why do we add nonlinearity in a neural network?
    <5 min>
    """
    out = encoder_out + predictor_out
    # out = ___ # add nonlinearity here
    out = self.linear(out)
    return out

# Transducer model + loss function

Using the encoder, predictor, and joiner, we will implement the Transducer model and its loss function.

<img src="https://lorenlugosch.github.io/images/transducer/forward-messages.png" width="25%">

We can use a simple PyTorch implementation of the loss function, relying on automatic differentiation to give us gradients.
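For reference, the recursion that `compute_forward_prob` implements in the next cell can be written as follows. The notation is an assumption introduced here for illustration: $h_{t,u}(k)$ is the joiner's log-probability of label $k$ at grid position $(t,u)$, $\varnothing$ is the null label, $t$ runs over $0,\dots,T-1$, $u$ runs over $0,\dots,U$, and $y_u$ is the $u$-th target label (stored as `y[:, u-1]` in the code).

$$
\begin{aligned}
\log\alpha(0,0) &= 0 \\
\log\alpha(t,0) &= \log\alpha(t-1,0) + h_{t-1,0}(\varnothing) \\
\log\alpha(0,u) &= \log\alpha(0,u-1) + h_{0,u-1}(y_u) \\
\log\alpha(t,u) &= \operatorname{logsumexp}\big(\log\alpha(t-1,u) + h_{t-1,u}(\varnothing),\ \log\alpha(t,u-1) + h_{t,u-1}(y_u)\big) \\
\log p(y \mid x) &= \log\alpha(T-1,U) + h_{T-1,U}(\varnothing)
\end{aligned}
$$

The first term inside the logsumexp corresponds to arriving from the left by emitting null, the second to arriving from above by emitting the label $y_u$.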

In [None]:
class Transducer(torch.nn.Module):
  def __init__(self, num_inputs, num_outputs):
    super(Transducer, self).__init__()
    self.encoder = Encoder(num_inputs)
    self.predictor = Predictor(num_outputs)
    self.joiner = Joiner(num_outputs)

    if torch.cuda.is_available(): self.device = "cuda:0"
    else: self.device = "cpu"
    self.to(self.device)

  def compute_forward_prob(self, joiner_out, T, U, y):
    """
    @joiner_out: tensor of shape (B, T_max, U_max+1, num_labels)
    @T: list of input lengths
    @U: list of output lengths 
    @y: label tensor (B, U_max)

    DIRECTIONS: draw out a couple iterations of the nested for loop below
    <15 min>
    """
    B = joiner_out.shape[0]  # batch size
    T_max = joiner_out.shape[1]
    U_max = joiner_out.shape[2] - 1
    log_alpha = torch.zeros(B, T_max, U_max+1).to(joiner_out.device)
    for t in range(T_max):
      for u in range(U_max+1):
          if u == 0:
            if t == 0:
              log_alpha[:, t, u] = 0.

            else: #t > 0
              log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX] 
                  
          else: #u > 0
            if t == 0:
              log_alpha[:, t, u] = log_alpha[:, t,u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
            
            else: #t > 0
              log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                  log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                  log_alpha[:, t, u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
              ]), dim=0)
    
    log_probs = []
    for b in range(B):
      log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
      log_probs.append(log_prob)
    log_probs = torch.stack(log_probs)
    return log_probs # log-likelihood of each example in the batch, shape (B,)

  def compute_loss(self, x, y, T, U):
    """
    @x: input/training tensor
    @y: label tensor
    @T: list of the length of input sequences
    @U: list of the length of output sequences
    """
    encoder_out = self.encoder.forward(x)
    predictor_out = self.predictor.forward(y)
    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
    loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
    return loss

Let's first verify that the forward algorithm actually correctly computes the sum (in log space, the [logsumexp](https://lorenlugosch.github.io/posts/2020/06/logsumexp/)) of all possible alignments, using a short input/output pair for which computing all possible alignments is feasible.

<img src="https://lorenlugosch.github.io/images/transducer/cat-align-1.png" width="25%">
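To get a feel for why enumeration is only feasible for short sequences, the sketch below counts the distinct alignments for the sizes used in the "CAT" check and shows a numerically stable logsumexp in plain Python. The helper `logsumexp` and the toy probabilities are illustrative assumptions, not part of the notebook's code.

```python
import itertools
import math

T, U = 4, 3  # same sizes as the "CAT" example: 4 input steps, 3 output labels

# An alignment is a sequence of T-1 "null" moves (0) and U "label" moves (1);
# the final null emission is fixed, so it is not part of the enumeration.
alignments = set(itertools.permutations([0] * (T - 1) + [1] * U))
print(len(alignments), math.comb(T - 1 + U, U))

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

# Toy check: two alignments with probabilities 0.1 and 0.2 sum to 0.3.
total = logsumexp([math.log(0.1), math.log(0.2)])
print(math.isclose(math.exp(total), 0.3))
```

The count grows as a binomial coefficient in $T$ and $U$, which is why the forward algorithm is needed for realistic sequence lengths.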

In [None]:
def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):
    """
    Computes the probability of one alignment, z.
    @encoder_out: encoder output for one example, shape (T, joiner_dim)
    @predictor_out: predictor output for one example, shape (U+1, joiner_dim)
    @T: length of the input sequence (int)
    @U: length of the output sequence (int)
    @z: 
    @y: label tensor

    DIRECTIONS: write a brief description of the argument 'z' above
                complete the variables for t_indices and u_indices
    <5 min>
    """
    t = 0; u = 0
    t_u_indices = []
    y_expanded = []
    for step in z:
      t_u_indices.append((t,u))
      if step == 0: # right (null)
        y_expanded.append(NULL_INDEX)
        t += 1
      if step == 1: # down (label)
        y_expanded.append(y[u])
        u += 1
    t_u_indices.append((T-1,U))
    y_expanded.append(NULL_INDEX)

    t_indices = None
    u_indices = None
    encoder_out_expanded = encoder_out[t_indices]
    predictor_out_expanded = predictor_out[u_indices]
    joiner_out = self.joiner.forward(encoder_out_expanded, predictor_out_expanded).log_softmax(1)
    logprob = -torch.nn.functional.nll_loss(input=joiner_out, target=torch.tensor(y_expanded).long().to(self.device), reduction="sum")
    return logprob

Transducer.compute_single_alignment_prob = compute_single_alignment_prob

In [None]:
# Generate example inputs/outputs
num_outputs = len(string.ascii_uppercase) + 1 # [null, A, B, ... Z]
model = Transducer(1, num_outputs)
y_letters = "CAT"
y = torch.tensor([string.ascii_uppercase.index(l) + 1 for l in y_letters]).unsqueeze(0).to(model.device)
T = torch.tensor([4]); U = torch.tensor([len(y_letters)]); B = 1

encoder_out = torch.randn(B, T, joiner_dim).to(model.device)
predictor_out = torch.randn(B, U+1, joiner_dim).to(model.device)
joiner_out = model.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)

#######################################################
# Compute loss by enumerating all possible alignments #
#######################################################
all_permutations = list(itertools.permutations([0]*(T-1) + [1]*U))
all_distinct_permutations = list(Counter(all_permutations).keys())
alignment_probs = []
for z in all_distinct_permutations:
  alignment_prob = model.compute_single_alignment_prob(encoder_out[0], predictor_out[0], T.item(), U.item(), z, y[0])
  alignment_probs.append(alignment_prob)
loss_enumerate = -torch.tensor(alignment_probs).logsumexp(0)

#######################################################
# Compute loss using the forward algorithm            #
#######################################################
loss_forward = -model.compute_forward_prob(joiner_out, T, U, y)

print("Loss computed by enumerating all possible alignments: ", loss_enumerate)
print("Loss computed using the forward algorithm: ", loss_forward)

Loss computed by enumerating all possible alignments:  tensor(21.0155)
Loss computed using the forward algorithm:  tensor(21.0155, device='cuda:0', grad_fn=<NegBackward>)


Now let's add the greedy search algorithm for predicting an output sequence.

---



(Note that I've assumed we're using RNNs for the predictor here. You would have to modify this code a bit if you want to use convolutions/self-attention instead.) 
<br/><br/>
<img src="https://lorenlugosch.github.io/images/transducer/greedy-search.png" width="50%">
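The control flow of the greedy search can be illustrated without any neural network. In the sketch below, a hand-written table stands in for the model: `argmax_table[t][u]` is assumed to hold the most likely label at grid cell $(t,u)$. This is a simplification, since in the real model the prediction depends on the labels emitted so far through the predictor state; the function name and table are hypothetical.

```python
import string

NULL = 0

def greedy_walk(argmax_table, T, U_max=10):
    """Walk the (t, u) grid: null moves to the next input step, a label moves to the next output position."""
    t, u, labels = 0, 0, []
    while t < T and u < U_max:
        best = argmax_table[t][u]
        if best == NULL:
            t += 1               # emit null: consume one input step
        else:
            labels.append(best)  # emit a label: stay on the same input step
            u += 1
    return labels

# Toy table for T = 3 input steps; labels use 1 = A, ..., 26 = Z.
argmax_table = [
    [3, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 20, 0],
]
labels = greedy_walk(argmax_table, T=3)
decoded = "".join(string.ascii_uppercase[i - 1] for i in labels)
print(labels, decoded)
```

The `U_max` cap plays the same role as in the notebook's greedy search: it stops the loop if the model keeps emitting labels without ever emitting null.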

In [None]:
"""
DIRECTIONS: YOU DO NOT NEED TO IMPLEMENT BEAM SEARCH
This is an *opportunity* to implement a beam search, which can improve on
greedy search. The greedy search code below is provided so that you can check
that everything is working; it is wrapped in a string literal, so remove the
enclosing triple quotes to enable it.

<might take a while>
"""

"""
def greedy_search(self, x, T):
  y_batch = []
  B = len(x)
  encoder_out = self.encoder.forward(x)
  U_max = 200
  for b in range(B):
    t = 0; u = 0; y = [self.predictor.start_symbol]; predictor_state = self.predictor.initial_state.unsqueeze(0)
    while t < T[b] and u < U_max:
      predictor_input = torch.tensor([ y[-1] ]).to(x.device)
      g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
      f_t = encoder_out[b, t]
      h_t_u = self.joiner.forward(f_t, g_u)
      argmax = h_t_u.max(-1)[1].item()
      if argmax == NULL_INDEX:
        t += 1
      else: # argmax == a label
        u += 1
        y.append(argmax)
    y_batch.append(y[1:]) # remove start symbol
  return y_batch

Transducer.greedy_search = greedy_search
"""

In [None]:
!pip install speechbrain
from speechbrain.nnet.loss.transducer_loss import TransducerLoss
transducer_loss = TransducerLoss(0)

def compute_loss(self, x, y, T, U):
    encoder_out = self.encoder.forward(x)
    predictor_out = self.predictor.forward(y)
    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
    #loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
    T = T.to(joiner_out.device)
    U = U.to(joiner_out.device)
    loss = transducer_loss(joiner_out, y, T, U) #, blank_index=NULL_INDEX, reduction="mean")
    return loss

Transducer.compute_loss = compute_loss


# Some utilities

Here we will add a bit of boilerplate code for training and loading data.
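The utilities in the next cell encode each character as its position in `string.printable` plus one, which keeps index 0 free for the null/padding label, and pad every sequence in a batch to the same length. A standalone sketch of that idea (the names `encode`, `decode`, and `pad` are illustrative stand-ins, not the notebook's functions):

```python
import string

NULL_INDEX = 0  # index 0 is reserved for null/padding, so characters start at 1

def encode(s):
    """Map each character to its position in string.printable, shifted by one."""
    return [string.printable.index(c) + 1 for c in s]

def decode(labels):
    """Inverse of encode; expects labels without padding."""
    return "".join(string.printable[i - 1] for i in labels)

def pad(sequences):
    """Right-pad every label list with NULL_INDEX to the longest length."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [NULL_INDEX] * (longest - len(seq)) for seq in sequences]

batch = [encode("Hll"), encode("Wrld")]
lengths = [len(seq) for seq in batch]  # true lengths, needed to undo the padding
padded = pad(batch)
print(padded, lengths)
print([decode(seq[:n]) for seq, n in zip(padded, lengths)])
```

Keeping the true lengths alongside the padded batch is what lets the loss and the printing code ignore the padding later.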

In [None]:
class TextDataset(torch.utils.data.Dataset):
  def __init__(self, lines, batch_size):
    """
    @lines: list of strings
    """
    lines = list(filter(("\n").__ne__, lines))

    self.lines = lines 
    collate = Collate()
    self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=1, shuffle=True, collate_fn=collate)

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    line = self.lines[idx].replace("\n", "")
    line = unidecode.unidecode(line) # transliterate non-ASCII characters to ASCII
    x = "".join(c for c in line if c not in "AEIOUaeiou") # remove vowels from input
    y = line
    return (x,y)

def encode_string(s):
  """
  @s: string
  """
  for c in s:
    if c not in string.printable:
      print(s)
  return [string.printable.index(c) + 1 for c in s]

def decode_labels(l):
  """
  @l: list of labels
  """
  return "".join([string.printable[c - 1] for c in l])


class Collate:
  def __call__(self, batch):
    """
    Returns a minibatch of strings, encoded as labels and padded to have the same length.
    @batch: list of tuples (input string, output string)

    DIRECTIONS: after obtaining results from training on the default text, train on your second training text 
    <10 min>
    """
    x = []; y = []
    batch_size = len(batch)
    for index in range(batch_size):
      x_,y_ = batch[index]
      x.append(encode_string(x_))
      y.append(encode_string(y_))

    # pad all sequences to have same length
    T = [len(x_) for x_ in x]
    U = [len(y_) for y_ in y]
    T_max = max(T)
    U_max = max(U)
    for index in range(batch_size):
      x[index] += [NULL_INDEX] * (T_max - len(x[index]))
      x[index] = torch.tensor(x[index])
      y[index] += [NULL_INDEX] * (U_max - len(y[index]))
      y[index] = torch.tensor(y[index])

    # stack into single tensor
    x = torch.stack(x)
    y = torch.stack(y)
    T = torch.tensor(T)
    U = torch.tensor(U)

    return (x,y,T,U)

with open("war_and_peace.txt", "r") as f:
  lines = f.readlines()

end = round(0.9 * len(lines))
train_lines = lines[:end]
test_lines = lines[end:]
train_set = TextDataset(train_lines, batch_size=64) #8)
test_set = TextDataset(test_lines, batch_size=64) #8)
train_set[0]

('"Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th',
 '"Well, Prince, so Genoa and Lucca are now just family estates of the')

In [None]:
class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
  
  def train(self, dataset, print_interval = 20):
    train_loss = 0
    num_samples = 0
    self.model.train()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      self.optimizer.zero_grad()
      pbar.set_description("%.2f" % loss.item())
      loss.backward()
      self.optimizer.step()
      train_loss += loss.item() * batch_size
      if idx % print_interval == 0:
        self.model.eval()
        guesses = self.model.greedy_search(x,T)
        self.model.train()
        print("\n")
        for b in range(2):
          print("input:", decode_labels(x[b,:T[b]]))
          print("guess:", decode_labels(guesses[b]))
          print("truth:", decode_labels(y[b,:U[b]]))
          print("")
    train_loss /= num_samples
    return train_loss

  def test(self, dataset, print_interval=1):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      pbar.set_description("%.2f" % loss.item())
      test_loss += loss.item() * batch_size
      if idx % print_interval == 0:
        print("\n")
        print("input:", decode_labels(x[0,:T[0]]))
        print("guess:", decode_labels(self.model.greedy_search(x,T)[0]))
        print("truth:", decode_labels(y[0,:U[0]]))
        print("")
    test_loss /= num_samples
    return test_loss
    

# Training the model

Now we will train a model. During training, the trainer prints a few sample output sequences every 20 batches.

In [None]:
num_chars = len(string.printable)
model = Transducer(num_inputs=num_chars+1, num_outputs=num_chars+1)
trainer = Trainer(model=model, lr=0.0003)

num_epochs = 1
train_losses=[]
test_losses=[]

for epoch in range(num_epochs):
    train_loss = trainer.train(train_set)
    test_loss = trainer.test(test_set)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print("Epoch %d: train loss = %f, test loss = %f" % (epoch, train_loss, test_loss))

In [None]:
print(train_losses)
print(test_losses)

Let's test the model on a new sentence:

In [None]:
"""
DIRECTIONS: Experiment with different test outputs. What are some things to keep in mind when changing the test outputs?
<5 min>
"""

test_output = "Most people have little difficulty reading this sentence"
test_input = "".join(c for c in test_output if c not in "AEIOUaeiou")
print("input: " + test_input)
x = torch.tensor(encode_string(test_input)).unsqueeze(0).to(model.device)
y = torch.tensor(encode_string(test_output)).unsqueeze(0).to(model.device)
T = torch.tensor([x.shape[1]]).to(model.device)
U = torch.tensor([y.shape[1]]).to(model.device)
guess = model.greedy_search(x,T)[0]
print("truth: " + test_output)
print("guess: " + decode_labels(guess))
print("")
y_guess = torch.tensor(guess).unsqueeze(0).to(model.device)
U_guess = torch.tensor(len(guess)).unsqueeze(0).to(model.device)

print("NLL of truth: " + str(model.compute_loss(x, y, T, U)))
print("NLL of guess: " + str(model.compute_loss(x, y_guess, T, U_guess)))

input: Mst ppl hv lttl dffclty rdng ths sntnc
truth: Most people have little difficulty reading this sentence
guess: Mos pepel have litle dificulty riding thes sentenc

NLL of truth: tensor(0.1161, device='cuda:0', grad_fn=<TransducerBackward>)
NLL of guess: tensor(1.4692, device='cuda:0', grad_fn=<TransducerBackward>)


Observe that the negative log-likelihood of the guess (1.4692) is higher, i.e. worse, than that of the true label sequence (0.1161). The model itself prefers the truth, but the greedy search failed to find it: this is known as a "[search error](https://www.aclweb.org/anthology/D19-1331.pdf)". This suggests that we could get better results using a beam search instead of the greedy search.