# Speech Recognition + LSTM + CTC + Torch Audio

# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """
    Bidirectional-LSTM encoder for frame-level input features.

    The input runs through a bidirectional LSTM. Its forward and backward
    outputs are concatenated, which doubles the width. A linear layer then
    projects the result back down to ``hidden_dim``.

    Args:
        input_dim (int): Feature size of each input frame.
        hidden_dim (int): LSTM hidden size per direction; also the output size.
        num_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1):
        super().__init__()
        # Submodules are created in the same order as before (LSTM, then Linear)
        # so that seeded parameter initialisation is unchanged.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(in_features=2 * hidden_dim, out_features=hidden_dim)

    def forward(self, x):
        """Map (batch, T, input_dim) features to (batch, T, hidden_dim) encodings."""
        # The LSTM returns (sequence_output, (h_n, c_n)); only the sequence is used.
        sequence_out = self.lstm(x)[0]  # (batch, T, 2 * hidden_dim)
        return self.fc(sequence_out)

class Predictor(nn.Module):
    """
    Prediction (label) network of the transducer.

    Embeds a sequence of token ids and runs it through a unidirectional LSTM.
    Each output step therefore depends only on the tokens up to and including
    that step.

    Args:
        vocab_size (int): Number of token ids covered by the embedding table.
        embed_dim (int): Embedding size.
        hidden_dim (int): LSTM hidden size; also the output size.
        num_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, y):
        """Map (batch, U) token ids to (batch, U, hidden_dim) hidden states.

        NOTE(review): ``y`` is consumed as-is. No start/blank token is prepended,
        so the output has U steps, not the U + 1 that RNN-T losses usually expect.
        """
        hidden_states, _final_state = self.lstm(self.embedding(y))
        return hidden_states

class JointNetwork(nn.Module):
    """
    Joint network: combines encoder and predictor outputs using a specified combination mode.

    The available modes are:
      - 'multiplicative' or 'mul': element-wise multiplication.
      - 'additive' or 'add': element-wise addition.

    Both inputs are first projected into a shared ``joint_dim`` space. The
    projections are combined, passed through tanh, and projected to the
    vocabulary. By default ``forward`` returns a softmax probability
    distribution. Pass ``return_logits=True`` to get the raw (unnormalized)
    scores instead, which is what RNN-T loss implementations such as
    ``torchaudio.functional.rnnt_loss`` expect.
    """
    MODES = {
        'multiplicative': lambda f, g: f * g,
        'mul': lambda f, g: f * g,
        'additive': lambda f, g: f + g,
        'add': lambda f, g: f + g
    }

    def __init__(self, hidden_dim, joint_dim, vocab_size, mode='additive', pred_dim=None):
        """
        Args:
            hidden_dim (int): Dimension of the encoder output. It is also used for
                the predictor output when ``pred_dim`` is not given.
            joint_dim (int): Dimension of the joint space.
            vocab_size (int): Number of tokens in the vocabulary.
            mode (str): Combination mode, one of 'multiplicative'/'mul' or 'additive'/'add'.
            pred_dim (int, optional): Dimension of the predictor output. Defaults to
                ``hidden_dim``. Set it whenever the predictor's hidden size differs
                from the encoder's; otherwise ``fc_pred`` rejects the predictor output.

        Raises:
            KeyError: If ``mode`` is not a supported combination mode.
        """
        super(JointNetwork, self).__init__()
        if mode not in self.MODES:
            raise KeyError(f"Unknown joint mode {mode!r}; expected one of {sorted(self.MODES)}")
        self.join_mode = self.MODES[mode]
        # Creation order (fc_enc, fc_pred, fc_out) is unchanged, so seeded init
        # matches the previous behavior whenever pred_dim == hidden_dim.
        self.fc_enc    = nn.Linear(hidden_dim, joint_dim)
        self.fc_pred   = nn.Linear(hidden_dim if pred_dim is None else pred_dim, joint_dim)
        self.fc_out    = nn.Linear(joint_dim, vocab_size)

    def forward(self, enc_out, pred_out, return_logits=False):
        """
        Combines the encoder and predictor outputs.

        Args:
            enc_out (Tensor): Encoder output of shape (batch, T, hidden_dim).
            pred_out (Tensor): Predictor output of shape (batch, U, pred_dim).
            return_logits (bool): If True, return the unnormalized scores instead of
                softmax probabilities. Defaults to False, the original behavior.

        Returns:
            Tensor: Shape (batch, T, U, vocab_size). Softmax probabilities over the
            vocabulary, or raw logits when ``return_logits`` is True.
        """
        # Project, then insert singleton axes so (T, 1) and (1, U) broadcast to (T, U):
        f_enc  = self.fc_enc(enc_out).unsqueeze(2)     # (batch, T, 1, joint_dim)
        f_pred = self.fc_pred(pred_out).unsqueeze(1)   # (batch, 1, U, joint_dim)

        # Combine using the specified mode and apply tanh:
        joint = torch.tanh(self.join_mode(f_enc, f_pred))  # (batch, T, U, joint_dim)

        # Project to the vocabulary space:
        logits = self.fc_out(joint)  # (batch, T, U, vocab_size)
        if return_logits:
            return logits
        return torch.softmax(logits, dim=-1)

class RNNT(nn.Module):
    """
    RNN-Transducer (RNN-T) model combining the Encoder, Predictor, and JointNetwork.
    """
    def __init__(self, input_dim, vocab_size, encoder_hidden_dim,
                 predictor_embed_dim, predictor_hidden_dim,
                 joint_dim, encoder_layers=1, predictor_layers=1, joint_mode='additive'):
        super(RNNT, self).__init__()
        self.encoder   = Encoder(input_dim, encoder_hidden_dim, encoder_layers)
        self.predictor = Predictor(vocab_size, predictor_embed_dim, predictor_hidden_dim, predictor_layers)
        # The predictor emits predictor_hidden_dim features, which need not equal
        # encoder_hidden_dim, so the joint network is told both widths explicitly.
        self.joint     = JointNetwork(encoder_hidden_dim, joint_dim, vocab_size,
                                      mode=joint_mode, pred_dim=predictor_hidden_dim)

    def forward(self, x, y, return_logits=False):
        """
        Args:
            x (Tensor): Input sequence (e.g., acoustic features) of shape (batch, T, input_dim).
            y (Tensor): Target token sequence of shape (batch, U).
            return_logits (bool): If True, return unnormalized joint scores (suitable
                for an RNN-T loss) instead of softmax probabilities.

        Returns:
            Tensor: Shape (batch, T, U, vocab_size). A vocabulary distribution by
            default, or raw logits when ``return_logits`` is True.
        """
        enc_out  = self.encoder(x)     # (batch, T, encoder_hidden_dim)
        pred_out = self.predictor(y)   # (batch, U, predictor_hidden_dim)
        return self.joint(enc_out, pred_out, return_logits=return_logits)

# In [4]:
# ---- Hyperparameters (toy sizes for a shape smoke test) ----
batch_size = 2
T = 50                      # number of input frames per example (e.g., acoustic frames)
U = 20                      # number of target tokens per example
input_dim = 40              # feature size of each input frame
vocab_size = 30             # vocabulary size (including blank token)
encoder_hidden_dim = 256
predictor_embed_dim = 128
predictor_hidden_dim = 256
joint_dim = 512

# ---- Model (built before the data so RNG consumption order is unchanged) ----
model = RNNT(
    input_dim=input_dim,
    vocab_size=vocab_size,
    encoder_hidden_dim=encoder_hidden_dim,
    predictor_embed_dim=predictor_embed_dim,
    predictor_hidden_dim=predictor_hidden_dim,
    joint_dim=joint_dim,
)

# ---- Random stand-in data ----
x = torch.randn(batch_size, T, input_dim)           # acoustic features
y = torch.randint(0, vocab_size, (batch_size, U))   # token ids drawn from [0, vocab_size)

# ---- Forward pass ----
# NOTE: the joint network ends in a softmax, so despite the name these are
# per-position vocabulary probabilities.
logits = model(x, y)
print("Logits shape:", logits.shape)  # expected: (batch, T, U, vocab_size)

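# To compute a training loss you would use an RNN-T loss function, for example
# torchaudio.functional.rnnt_loss or another implementation. Such losses
# typically expect unnormalized logits (they apply log-softmax themselves) over
# U + 1 predictor steps, i.e. with a blank/start token prepended to the
# targets. They do not take softmax probabilities over U steps, which is what
# the joint network above produces.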

# Output: Logits shape: torch.Size([2, 50, 20, 30])
