# Transducer PyTorch implementation
Goal: In this notebook, we'll implement an RNN-Transducer (RNN-T) in PyTorch and use it to insert missing vowels into a sentence.

This is an image of the architecture.

![RNN-T architecture](images/rnnt_architecture.png)

In [1]:
# Imports
import torch
from tqdm import tqdm
from torch import nn
import torch.nn.functional as F
import itertools
from collections import Counter
# from speechbrain.nnet.loss.transducer_loss import TransducerLoss
import unidecode
import IPython
import string
import os
import urllib.request

In [2]:
# Fetch the training corpus once; later runs reuse the local copy.
url = "https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt"
file_path = "war_and_peace.txt"

if os.path.exists(file_path):
    print("File already exists.")
else:
    print("Downloading file...")
    urllib.request.urlretrieve(url, file_path)
    print("File downloaded.")


File already exists.


In [3]:
# Model hyperparameters
NULL_INDEX = 0  # index reserved for the null output; also used as padding value and predictor start symbol
encoder_dim = predictor_dim = joiner_dim = 1024  # hidden sizes of the encoder, predictor and joiner
     

In [4]:
# Encoder network
# The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.

class Encoder(nn.Module):
    """Bidirectional GRU encoder over the input sequence.

    Input: integer indices (they are looked up in an Embedding layer).
        In this notebook the shape is [batch_size, T].

    Output: one context-aware encoding per input position, projected to
        joiner_dim. In this notebook the shape is [batch_size, T, joiner_dim].
    """
    def __init__(self, num_inputs):
        super().__init__()
        self.embed = nn.Embedding(num_inputs, encoder_dim)
        self.rnn = nn.GRU(
            input_size=encoder_dim,
            hidden_size=encoder_dim,
            num_layers=3,
            batch_first=True,
            bidirectional=True,
            dropout=0.1,
        )
        # The GRU is bidirectional, so its output feature size is encoder_dim*2.
        self.linear = nn.Linear(encoder_dim*2, joiner_dim)

    def forward(self, x):
        embedded = self.embed(x)
        hidden_states, _ = self.rnn(embedded)  # discard the final hidden state
        return self.linear(hidden_states)

In [5]:
# Predictor network
# The predictor is any causal network (= can't look at the future): in other words, unidirectional RNNs, causal convolutions, or masked self-attention.

class Predictor(nn.Module):
    """Autoregressive predictor network built on a GRUCell.

    Input: label indices of shape [batch_size, U].

    Output: one projected hidden state per autoregressive step, including the
        step that consumes the start symbol: [batch_size, U+1, joiner_dim].
    """
    def __init__(self, num_outputs):
        super().__init__()
        self.embed = nn.Embedding(num_outputs, predictor_dim)
        self.rnn = nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)
        self.linear = nn.Linear(predictor_dim, joiner_dim)

        # Learned initial hidden state, shared by every sequence.
        self.initial_state = nn.Parameter(torch.randn(predictor_dim))
        # The original paper uses a vector of 0s as the first input; feeding the
        # null index through the Embedding layer is simpler.
        self.start_symbol = NULL_INDEX

    def forward_one_step(self, input, previous_state):
        """Advance the GRU by one label; returns (projected output, new state)."""
        state = self.rnn.forward(self.embed(input), previous_state)
        return self.linear(state), state

    def forward(self, y):
        batch_size, U = y.shape[0], y.shape[1]
        state = torch.stack([self.initial_state for _ in range(batch_size)]).to(y.device)
        start_input = torch.tensor([self.start_symbol] * batch_size).to(y.device)
        outs = []
        # U+1 steps: the start symbol followed by each of the U labels, so that
        # the last output can predict null after the final label.
        for u in range(U + 1):
            decoder_input = start_input if u == 0 else y[:, u-1]
            out, state = self.forward_one_step(decoder_input, state)
            outs.append(out)
        return torch.stack(outs, dim=1)

In [6]:
class Joiner(nn.Module):
    """Combines encoder and predictor outputs into per-label scores."""
    def __init__(self, num_outputs):
        super().__init__()
        self.linear = nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_out, predictor_out):
        """Forward pass.

        The two inputs are added, so they must be broadcastable against each
        other and have joiner_dim as their last dimension.

        Args:
            encoder_out: Output of the encoder network. When computing the loss in this
                notebook it is passed with shape [batch_size, T, 1, joiner_dim].
            predictor_out: Output of the predictor network. When computing the loss in this
                notebook it is passed with shape [batch_size, 1, U+1, joiner_dim].

        Returns:
            Unnormalized scores whose last dimension is num_outputs.
        """
        combined = F.relu(encoder_out + predictor_out)
        return self.linear(combined)

Alignment matrix

![RNN-T alignment matrix](images/rnnt_alignment_matrix.png)

In [7]:
class Transducer(torch.nn.Module):
    """Transducer model combining an Encoder, a Predictor and a Joiner.

    Besides holding the three sub-networks, the class computes the transducer
    loss with the forward algorithm and can score a single alignment, which is
    useful for checking the forward algorithm on small examples.
    """
    def __init__(self, num_inputs, num_outputs):
        """
        Args:
            num_inputs: size of the encoder's input vocabulary.
            num_outputs: size of the output vocabulary, including the null label.
        """
        super(Transducer, self).__init__()
        self.encoder = Encoder(num_inputs)
        self.predictor = Predictor(num_outputs)
        self.joiner = Joiner(num_outputs)

        if torch.cuda.is_available():
            self.device = "cuda:0"
        else:
            self.device = "cpu"
        self.to(self.device)


    def compute_forward_prob(self, joiner_out, T, U, y):
        """Compute log p(y | x) for each batch element with the forward algorithm.

        Args:
            joiner_out: log-probabilities, tensor of shape (B, T_max, U_max+1, #labels)
            T: input lengths, one per batch element
            U: output lengths, one per batch element
            y: label tensor (B, U_max)

        Returns:
            log_probs: tensor of shape (B). Entry b is the log of the summed probability
                of every alignment path between input b and output b.
        """
        B = joiner_out.shape[0]
        T_max = joiner_out.shape[1]
        U_max = joiner_out.shape[2] - 1
        log_alpha = torch.zeros(B, T_max, U_max+1).to(self.device)
        for t in range(T_max):
            for u in range(U_max+1):
                # log_alpha[:, t, u] is the log of the summed probability of all paths that
                # reach lattice node (t, u). A path moves right (emit null, increment t) or
                # down (emit the next label, increment u).
                # Case (t == 0) and (u == 0): initial state; log-probability 0, i.e. probability 1.
                # Case (t > 0) and (u == 0): only reachable from (t-1, 0) by emitting null.
                # Case (t == 0) and (u > 0): only reachable from (0, u-1) by emitting label y[:, u-1].
                # Case (t > 0) and (u > 0): log-sum-exp of the two incoming paths:
                #     a) from (t-1, u) by emitting null,
                #     b) from (t, u-1) by emitting label y[:, u-1].
                if u == 0:
                    if t == 0:
                        log_alpha[:, t, u] = 0.
                    else: # t > 0
                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX]
                else: # u > 0
                    # Log-probability of emitting label y[:, u-1] at node (t, u-1).
                    label_log_prob = torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1)).reshape(-1)
                    via_label = log_alpha[:, t, u-1] + label_log_prob
                    if t == 0:
                        log_alpha[:, t, u] = via_label
                    else: # t > 0
                        via_null = log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX]
                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([via_null, via_label]), dim=0)

        log_probs = []
        for b in range(B):
            # Each sequence ends at its own final node (T[b]-1, U[b]) and must emit one last
            # null from there to terminate.
            log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
            log_probs.append(log_prob)
        # Return the whole batch. Returning only the last element would make the
        # mean in compute_loss ignore every other sample.
        log_probs = torch.stack(log_probs)
        return log_probs

    def compute_loss(self, x, y, T, U):
        """Mean negative log-likelihood of the labels y given inputs x.

        Args:
            x: encoder inputs for the batch.
            y: padded label tensor for the batch.
            T: input lengths, one per batch element.
            U: label lengths, one per batch element.
        """
        encoder_out = self.encoder.forward(x)
        predictor_out = self.predictor.forward(y)
        # unsqueeze so the joiner broadcasts every encoder step against every predictor step.
        joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
        loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
        return loss

    def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):
        """Computes the log probability of one alignment, z.

        Args:
            encoder_out: Encoder output for one sequence, one row per input step. Shape [T, joiner_dim].
            predictor_out: Predictor output for one sequence, one row per label-history length
                (0 to U labels consumed). Shape [U+1, joiner_dim].
            T: Number of encoder steps.
            U: Number of labels in y.
            z: Alignment. Sequence of 0s and 1s, where 0 means "right" (emit null, increment t)
                and 1 means "down" (emit the next label, increment u). The terminating null
                is appended by this method and must not be part of z.
            y: Label sequence.

        Returns:
            logprob: Log probability of the alignment.
        """
        t = 0; u = 0 # current position in the alignment lattice, used to index encoder_out and predictor_out
        t_u_indices = []
        y_expanded = []
        for step in z:
            t_u_indices.append((t,u))
            if step == 0: # right (null)
                y_expanded.append(NULL_INDEX)
                t += 1
            elif step == 1: # down (label)
                y_expanded.append(y[u])
                u += 1

        # Every alignment terminates with a null emitted from the final node.
        t_u_indices.append((T-1,U))
        y_expanded.append(NULL_INDEX)

        t_indices = [t for (t,u) in t_u_indices]
        u_indices = [u for (t,u) in t_u_indices]
        encoder_out_expanded = encoder_out[t_indices] # one encoder row per visited lattice node
        predictor_out_expanded = predictor_out[u_indices] # one predictor row per visited lattice node
        # Log-probabilities over labels at each visited node.
        joiner_out = self.joiner.forward(encoder_out_expanded, predictor_out_expanded).log_softmax(1)
        # Summed NLL of the emitted symbols, negated to give the alignment's log probability.
        logprob = -torch.nn.functional.nll_loss(input=joiner_out, target=torch.tensor(y_expanded).long().to(self.device), reduction="sum")
        return logprob

In [8]:
# Generate example inputs/outputs
num_outputs = len(string.ascii_uppercase) + 1 # [null, A, B, ... Z]
model = Transducer(1, num_outputs)
y_letters = "CAT"
y = torch.tensor([string.ascii_uppercase.index(l) + 1 for l in y_letters]).unsqueeze(0).to(model.device) # tokenize
T = torch.tensor([4]) # Number of encoder timesteps (arbitrary for this toy example)
U = torch.tensor([len(y_letters)]) # Number of labels
B = 1 # Batch
num_timesteps = T.item()
num_labels = U.item()

# encoder_out and predictor_out share the same feature dimension (joiner_dim), so the joiner can add them.
encoder_out = torch.randn(B, num_timesteps, joiner_dim).to(model.device)
# U+1 rows: the predictor produces one output for the start symbol plus one per label.
predictor_out = torch.randn(B, num_labels+1, joiner_dim).to(model.device)
joiner_out = model.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)

#######################################################
# Compute loss by enumerating all possible alignments #
#######################################################
# An alignment is an ordering of T-1 "right" moves (0) and U "down" moves (1);
# the final null emission is added inside compute_single_alignment_prob.
all_permutations = list(itertools.permutations([0]*(num_timesteps-1) + [1]*num_labels))
all_distinct_permutations = list(Counter(all_permutations).keys())
alignment_probs = []
for z in all_distinct_permutations:
    alignment_prob = model.compute_single_alignment_prob(encoder_out[0], predictor_out[0], num_timesteps, num_labels, z, y[0])
    alignment_probs.append(alignment_prob)
# torch.stack keeps the scores on the model's device, so the comparison below also works on CUDA.
loss_enumerate = -torch.stack(alignment_probs).logsumexp(0) # negative log of the summed alignment probabilities

#######################################################
# Compute loss using the forward algorithm            #
#######################################################
loss_forward = -model.compute_forward_prob(joiner_out, T, U, y) # negative log-probability of y summed over all alignments

print("Loss computed by enumerating all possible alignments: ", loss_enumerate)
print("Loss computed using the forward algorithm: ", loss_forward)
if torch.allclose(loss_enumerate, loss_forward):
    print("Losses are equal.") # Losses should be equal.
else:
    print("Losses differ!")

Loss computed by enumerating all possible alignments:  tensor(24.2029)
Loss computed using the forward algorithm:  tensor(24.2029, grad_fn=<NegBackward0>)
Losses are equal.


In [9]:
def greedy_search(self, x, T):
  """Greedy search algorithm used to predict an output sequence.

  For each sequence, walk the alignment lattice from (t=0, u=0): at every node
  take the most probable joiner output. Null moves to the next encoder step;
  any other label is emitted and extends the label history.

  Args:
    x: batch of encoder inputs.
    T: input lengths, one per batch element.

  Returns:
    y_batch: list (one entry per batch element) of lists of predicted label indices.
  """
  y_batch = []
  B = len(x)
  U_max = 200 # cap on emitted labels per sequence, so decoding always terminates
  with torch.no_grad(): # decoding only; no autograd graph is needed
    encoder_out = self.encoder.forward(x)
    for b in range(B):
      t = 0; u = 0; y = [self.predictor.start_symbol]; predictor_state = self.predictor.initial_state.unsqueeze(0)
      g_u = None # predictor output for the current label history; recomputed only when the history changes
      while t < T[b] and u < U_max:
        if g_u is None:
          # Step the predictor only after a new label (or the start symbol). During training the
          # predictor output depends only on the emitted labels, so a null must not advance it.
          predictor_input = torch.tensor([ y[-1] ]).to(x.device)
          g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
        f_t = encoder_out[b, t]
        h_t_u = self.joiner.forward(f_t, g_u)
        argmax = h_t_u.max(-1)[1].item()
        if argmax == NULL_INDEX:
          t += 1
        else: # argmax == a label
          u += 1
          y.append(argmax)
          g_u = None # label history changed
      y_batch.append(y[1:]) # remove start symbol
  return y_batch

Transducer.greedy_search = greedy_search

In [10]:
# This is a faster version of the transducer loss, from the speechbrain library (presumably faster because it's not written in pure Python).
# TODO This doesn't work on non-CUDA devices.

# transducer_loss = TransducerLoss(0)

# def compute_loss(self, x, y, T, U):
#     encoder_out = self.encoder.forward(x)
#     predictor_out = self.predictor.forward(y)
#     joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
#     #loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
#     T = T.to(joiner_out.device)
#     U = U.to(joiner_out.device)
#     loss = transducer_loss(joiner_out, y, T, U) #, blank_index=NULL_INDEX, reduction="mean")
#     return loss

# Transducer.compute_loss = compute_loss

In [11]:
class TextDataset(torch.utils.data.Dataset):
  """Text dataset whose inputs are the lines with vowels removed.

  Each item is a pair (line without vowels, original line). The dataset also
  builds its own shuffling DataLoader, available as `self.loader`.
  """
  def __init__(self, lines, batch_size):
    # Drop lines that consist solely of a newline.
    self.lines = [line for line in lines if line != "\n"] # list of strings
    self.loader = torch.utils.data.DataLoader(
      self,
      batch_size=batch_size,
      num_workers=0,
      shuffle=True,
      collate_fn=Collate(),
    )

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    # Strip newlines, then transliterate special characters to ASCII.
    y = unidecode.unidecode(self.lines[idx].replace("\n", ""))
    # The input is the same text with every vowel deleted.
    x = y.translate(str.maketrans("", "", "AEIOUaeiou"))
    return (x,y)

def encode_string(s):
  """Map each character of s to its 1-based position in string.printable.

  Index 0 is left free (it is never produced here). The whole string is
  printed once for every character that is not in string.printable, after
  which the lookup below raises ValueError for such a character.
  """
  num_unknown = sum(1 for ch in s if ch not in string.printable)
  for _ in range(num_unknown):
    print(s)
  return list(map(lambda ch: string.printable.index(ch) + 1, s))

def decode_labels(l):
  """Inverse of encode_string: turn 1-based label indices back into a string."""
  chars = []
  for label in l:
    chars.append(string.printable[label - 1])
  return "".join(chars)

class Collate:
  """Collate function for the DataLoader: encodes and pads a list of string pairs."""
  def __call__(self, batch):
    """
    batch: list of tuples (input string, output string)
    Returns (x, y, T, U): x and y are the encoded strings padded with NULL_INDEX
    to the longest sequence in the batch; T and U are the unpadded lengths.
    """
    x = []; y = []
    for input_str, output_str in batch:
      x.append(encode_string(input_str))
      y.append(encode_string(output_str))

    # true lengths, needed later to ignore the padding
    T = list(map(len, x))
    U = list(map(len, y))
    T_max = max(T)
    U_max = max(U)

    def pad(seq, target_len):
      # right-pad with the null index and convert to a tensor
      return torch.tensor(seq + [NULL_INDEX] * (target_len - len(seq)))

    # stack into single tensor
    x = torch.stack([pad(seq, T_max) for seq in x])
    y = torch.stack([pad(seq, U_max) for seq in y])

    return (x, y, torch.tensor(T), torch.tensor(U))

# Read the corpus and hold out the last 10% of its lines for testing.
with open("war_and_peace.txt", "r") as f:
  lines = list(f)

end = round(0.9 * len(lines))
train_lines, test_lines = lines[:end], lines[end:]
train_set = TextDataset(train_lines, batch_size=64) #8)
test_set = TextDataset(test_lines, batch_size=64) #8)
train_set[0]

('"Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th',
 '"Well, Prince, so Genoa and Lucca are now just family estates of the')

In [12]:
class Trainer:
  """Runs training and evaluation epochs for a Transducer model with Adam."""
  def __init__(self, model, lr):
    """
    Args:
      model: the model to optimize; must provide compute_loss, greedy_search and device.
      lr: learning rate for the Adam optimizer.
    """
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

  def train(self, dataset, print_interval = 20):
    """Train for one pass over dataset.loader.

    Every print_interval batches, up to two samples of the current batch are
    decoded greedily and printed next to their input and ground truth.

    Returns:
      The per-sample average of the batch losses.
    """
    train_loss = 0
    num_samples = 0
    self.model.train()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      loss_value = loss.item()
      self.optimizer.zero_grad()
      pbar.set_description("%.2f" % loss_value)
      loss.backward()
      self.optimizer.step()
      train_loss += loss_value * batch_size
      if idx % print_interval == 0:
        self.model.eval()
        with torch.no_grad(): # preview decoding needs no gradients
          guesses = self.model.greedy_search(x,T)
        self.model.train()
        print("\n")
        # The last batch of an epoch can hold a single sample, so don't assume two.
        for b in range(min(2, batch_size)):
          print("input:", decode_labels(x[b,:T[b]]))
          print("guess:", decode_labels(guesses[b]))
          print("truth:", decode_labels(y[b,:U[b]]))
          print("")

    train_loss /= num_samples
    return train_loss

  def test(self, dataset, print_interval=1):
    """Evaluate on one pass over dataset.loader without updating the model.

    Every print_interval batches, the first sample of the batch is decoded
    greedily and printed next to its input and ground truth.

    Returns:
      The per-sample average of the batch losses.
    """
    test_loss = 0
    num_samples = 0
    self.model.eval()
    pbar = tqdm(dataset.loader)
    # Evaluation never calls backward(), so skip building autograd graphs.
    with torch.no_grad():
      for idx, batch in enumerate(pbar):
        x,y,T,U = batch
        x = x.to(self.model.device); y = y.to(self.model.device)
        batch_size = len(x)
        num_samples += batch_size
        loss = self.model.compute_loss(x,y,T,U)
        loss_value = loss.item()
        pbar.set_description("%.2f" % loss_value)
        test_loss += loss_value * batch_size
        if idx % print_interval == 0:
          print("\n")
          print("input:", decode_labels(x[0,:T[0]]))
          print("guess:", decode_labels(self.model.greedy_search(x,T)[0]))
          print("truth:", decode_labels(y[0,:U[0]]))
          print("")
    test_loss /= num_samples
    return test_loss
    

In [13]:
# Train the model.
num_chars = len(string.printable)
# +1 because encode_string produces labels starting at 1, leaving index 0 for NULL_INDEX.
model = Transducer(num_inputs=num_chars+1, num_outputs=num_chars+1)
trainer = Trainer(model=model, lr=0.0003)

num_epochs = 1
train_losses, test_losses = [], []

for epoch in range(num_epochs):
    # One pass over the training set, then one over the held-out set.
    train_loss = trainer.train(train_set)
    test_loss = trainer.test(test_set)
    train_losses += [train_loss]
    test_losses += [test_loss]
    print("Epoch %d: train loss = %f, test loss = %f" % (epoch, train_loss, test_loss))

476.93:   0%|          | 0/709 [00:01<?, ?it/s]

hi
hi
