# Transducer PyTorch implementation
Goal: implement an RNN Transducer (RNN-T) from scratch in PyTorch, and use it to insert missing vowels into a sentence.

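The diagram below shows the RNN-T architecture: an encoder network, a predictor network, and a joiner network that combines their outputs.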

![RNN-T architecture](images/rnnt_architecture.png)

In [1]:
# Get training data.
!wget https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt

--2023-04-04 12:23:56--  https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3196229 (3.0M) [text/plain]
Saving to: ‘war_and_peace.txt’


2023-04-04 12:23:57 (5.66 MB/s) - ‘war_and_peace.txt’ saved [3196229/3196229]



In [15]:
# Imports
import itertools
import math
import string
from collections import Counter

import IPython
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm


# Output index reserved for the null symbol; it is also used as the predictor's start symbol.
NULL_INDEX = 0

# Hidden sizes of the three sub-networks (each can be changed independently).
encoder_dim = predictor_dim = joiner_dim = 1024
     

In [None]:
# Encoder network
# The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.

class Encoder(nn.Module):
    """Input-side network of the transducer.

    Embeds integer input token indices, runs a 3-layer bidirectional GRU over
    them, and projects every timestep to ``joiner_dim`` features.

    Input: integer token-index tensor, presumably (batch, time) since the GRU
        is ``batch_first`` — TODO confirm for callers.
    Output: context-aware encodings with a trailing dimension of ``joiner_dim``.
    """

    def __init__(self, num_inputs):
        super(Encoder, self).__init__()
        self.embed = torch.nn.Embedding(num_inputs, encoder_dim)
        self.rnn = torch.nn.GRU(
            input_size=encoder_dim,
            hidden_size=encoder_dim,
            num_layers=3,
            batch_first=True,
            bidirectional=True,
            dropout=0.1,
        )
        # Bidirectional GRU concatenates both directions -> 2 * encoder_dim features.
        self.linear = torch.nn.Linear(encoder_dim * 2, joiner_dim)

    def forward(self, x):
        embedded = self.embed(x)
        hidden_seq, _ = self.rnn(embedded)  # discard the final hidden state
        return self.linear(hidden_seq)

In [None]:
# Predictor network
# The predictor is any causal network (= can't look at the future): in other words, unidirectional RNNs, causal convolutions, or masked self-attention.

class Predictor(torch.nn.Module):
    """Autoregressive, label-side network of the transducer.

    Feeds a start symbol followed by each label, one step at a time, through a
    ``GRUCell`` and projects every hidden state to ``joiner_dim``.

    Input: integer label tensor ``y`` of shape (batch, U).
    Output: tensor of shape (batch, U + 1, joiner_dim), one entry per prefix.
    """

    def __init__(self, num_outputs):
        super(Predictor, self).__init__()
        self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
        self.rnn = torch.nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)
        self.linear = torch.nn.Linear(predictor_dim, joiner_dim)
        # Learned initial hidden state, shared by every sequence in the batch.
        self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
        # The original paper feeds a zero vector at the first step; reusing the
        # null index is simpler when the input goes through an Embedding layer.
        self.start_symbol = NULL_INDEX

    def forward_one_step(self, input, previous_state):
        """Advance the predictor by one symbol; returns (projected output, new state)."""
        new_state = self.rnn.forward(self.embed(input), previous_state)
        return self.linear(new_state), new_state

    def forward(self, y):
        batch_size, num_labels = y.shape[0], y.shape[1]
        state = torch.stack([self.initial_state] * batch_size).to(y.device)
        start = torch.tensor([self.start_symbol] * batch_size).to(y.device)
        # Step 0 consumes the start symbol, step u consumes label u-1. The extra
        # (U+1)-th output lets the joiner emit null after the final label.
        step_inputs = [start] + [y[:, u] for u in range(num_labels)]
        outputs = []
        for step_input in step_inputs:
            step_out, state = self.forward_one_step(step_input, state)
            outputs.append(step_out)
        return torch.stack(outputs, dim=1)

In [None]:
class Joiner(torch.nn.Module):
    """Combines encoder and predictor features into output-label logits.

    The two inputs are summed (so their shapes must broadcast), passed through
    a ReLU, and projected to ``num_outputs`` logits along the last dimension.
    """

    def __init__(self, num_outputs):
        super(Joiner, self).__init__()
        self.linear = torch.nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_out, predictor_out):
        hidden = torch.nn.functional.relu(encoder_out + predictor_out)
        return self.linear(hidden)

In [None]:
class Transducer(torch.nn.Module):
    """RNN Transducer (RNN-T): an encoder, a predictor and a joiner.

    Args:
        num_inputs: input vocabulary size (passed to ``Encoder``).
        num_outputs: output vocabulary size, including the null symbol at
            ``NULL_INDEX`` (passed to ``Predictor`` and ``Joiner``).
    """
    def __init__(self, num_inputs, num_outputs):
        super(Transducer, self).__init__()
        self.encoder = Encoder(num_inputs)
        self.predictor = Predictor(num_outputs)
        self.joiner = Joiner(num_outputs)

        # Put the whole model on the first GPU when one is available.
        if torch.cuda.is_available():
            self.device = "cuda:0"
        else:
            self.device = "cpu"
        self.to(self.device)


    def compute_forward_prob(self, joiner_out, T, U, y):
        """Compute log P(y | x) for every sequence in the batch (RNN-T forward algorithm).

        The alignment lattice has nodes (t, u). Here t indexes encoder frames and
        u counts the labels emitted so far. From (t, u) an alignment can take one
        of two steps:
        - emit the null symbol and move to (t+1, u);
        - emit label y[u] and move to (t, u+1).

        log_alpha[b, t, u] is the log of the total probability of all partial
        alignments that reach (t, u). A complete alignment ends by emitting null
        from (T_b-1, U_b).

        Args:
            joiner_out: log-probabilities (log_softmax over the last dim), shape
                (B, T_max, U_max+1, num_outputs).
            T: per-sequence input lengths, indexable as T[b].
            U: per-sequence label lengths, indexable as U[b].
            y: integer label tensor of shape (B, U_max).

        Returns:
            log_probs: tensor of shape (B,) with log P(y_b | x_b).
        """
        B = joiner_out.shape[0]
        T_max = joiner_out.shape[1]
        U_max = joiner_out.shape[2] - 1
        # Allocate on joiner_out's device/dtype so this works wherever the inputs live.
        log_alpha = joiner_out.new_zeros(B, T_max, U_max + 1)
        for t in range(T_max):
            for u in range(U_max + 1):
                if u == 0:
                    if t == 0:
                        # Start node: log(1) = 0.
                        log_alpha[:, t, u] = 0.
                    else:
                        # Only reachable from the left, by emitting null at (t-1, 0).
                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX]
                else:
                    # Log-prob of emitting label y[u-1] from node (t, u-1).
                    label_logprob = torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:, u-1].view(-1, 1)).reshape(-1)
                    if t == 0:
                        # Only reachable from above, by emitting a label.
                        log_alpha[:, t, u] = log_alpha[:, t, u-1] + label_logprob
                    else:
                        # Reachable from the left (null) or from above (label).
                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                            log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                            log_alpha[:, t, u-1] + label_logprob,
                        ]), dim=0)

        # Termination: add the final null emission from each sequence's last node.
        log_probs = torch.stack([
            log_alpha[b, T[b] - 1, U[b]] + joiner_out[b, T[b] - 1, U[b], NULL_INDEX]
            for b in range(B)
        ])
        return log_probs

    def compute_loss(self, x, y, T, U):
        """Mean negative log-likelihood of labels y given inputs x over the batch."""
        encoder_out = self.encoder(x)
        predictor_out = self.predictor(y)
        # Broadcast (B, T, 1, D) + (B, 1, U+1, D) -> (B, T, U+1, num_outputs).
        joiner_out = self.joiner(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
        loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
        return loss

    def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):
        """Compute the log-probability of one alignment z for a single (unbatched) example.

        Args:
            encoder_out: encoder features of one example, indexed by t.
            predictor_out: predictor features of one example, indexed by u.
            T: input length (int).
            U: label length (int).
            z: sequence of steps. 0 means "right": emit null, t += 1.
                1 means "down": emit y[u], u += 1.
            y: label sequence of one example.

        Returns:
            Scalar tensor: sum of the log-probs of every emission along z,
            including the final null emitted from (T-1, U).
        """
        t = 0; u = 0
        t_u_indices = []
        y_expanded = []
        for step in z:
            t_u_indices.append((t, u))
            if step == 0:  # right (null)
                y_expanded.append(NULL_INDEX)
                t += 1
            if step == 1:  # down (label)
                y_expanded.append(int(y[u]))  # plain int, so torch.tensor below gets a homogeneous list
                u += 1
        t_u_indices.append((T - 1, U))
        y_expanded.append(NULL_INDEX)

        t_indices = [t for (t, u) in t_u_indices]
        u_indices = [u for (t, u) in t_u_indices]
        encoder_out_expanded = encoder_out[t_indices]
        predictor_out_expanded = predictor_out[u_indices]
        joiner_out = self.joiner(encoder_out_expanded, predictor_out_expanded).log_softmax(1)
        targets = torch.tensor(y_expanded, dtype=torch.long, device=joiner_out.device)
        logprob = -torch.nn.functional.nll_loss(input=joiner_out, target=targets, reduction="sum")
        return logprob

In [None]:
# Sanity check: on a tiny example, the forward algorithm must give the same loss
# as explicitly summing the probabilities of every possible alignment.
torch.manual_seed(0)  # reproducible model init and random features

# Generate example inputs/outputs
num_outputs = len(string.ascii_uppercase) + 1 # [null, A, B, ... Z]
model = Transducer(1, num_outputs)
y_letters = "CAT"
y = torch.tensor([string.ascii_uppercase.index(l) + 1 for l in y_letters]).unsqueeze(0).to(model.device)
T = torch.tensor([4]); U = torch.tensor([len(y_letters)]); B = 1
T_len, U_len = T.item(), U.item()  # plain ints for tensor sizes and list arithmetic

encoder_out = torch.randn(B, T_len, joiner_dim).to(model.device)
predictor_out = torch.randn(B, U_len + 1, joiner_dim).to(model.device)
joiner_out = model.joiner(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)

#######################################################
# Compute loss by enumerating all possible alignments #
#######################################################
# An alignment is a path from (0, 0) to (T-1, U): T-1 "right" steps (0 = emit
# null) and U "down" steps (1 = emit a label) in some order. Choosing which of
# the steps are "down" yields every distinct path exactly once.
num_steps = T_len - 1 + U_len
alignments = []
for down_steps in itertools.combinations(range(num_steps), U_len):
    z = [0] * num_steps
    for i in down_steps:
        z[i] = 1
    alignments.append(tuple(z))
alignment_probs = [
    model.compute_single_alignment_prob(encoder_out[0], predictor_out[0], T_len, U_len, z, y[0])
    for z in alignments
]
loss_enumerate = -torch.stack(alignment_probs).logsumexp(0)

#######################################################
# Compute loss using the forward algorithm            #
#######################################################
loss_forward = -model.compute_forward_prob(joiner_out, T, U, y)

print("Loss computed by enumerating all possible alignments: ", loss_enumerate)
print("Loss computed using the forward algorithm: ", loss_forward)
assert torch.allclose(loss_enumerate, loss_forward.reshape(()), atol=1e-3), \
    "forward algorithm disagrees with brute-force alignment enumeration"

