In [None]:
from transformers import AutoProcessor, Wav2Vec2Model, AutoFeatureExtractor
import torch
from datasets import load_dataset

# Load the "clean" validation split of the LibriSpeech demo dataset and sort it by "id".
# NOTE(review): trust_remote_code=True allows the dataset repository to run its own
# loading code on this machine — confirm the source is trusted.
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

# Feature extractor and encoder are both loaded from the same checkpoint name.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")

# Turn the first utterance's waveform into PyTorch model inputs.
inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")



In [7]:
# Re-computes `inputs` for the first utterance (same statement as at the end of the first cell).
inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

In [9]:
# Shape of the extracted input values for the utterance prepared above.
inputs.input_values.shape

torch.Size([1, 93680])

In [17]:
# Extract features for the fourth utterance (index 3), with padding enabled.
inputs = feature_extractor(dataset[3]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt", padding=True)

# NOTE(review): `model.transcriber` is the attribute used by `Wav2vec2RNNT.forward`, so this
# cell presumably needs `model` to be the Wav2vec2RNNT built in a later cell rather than the
# plain Wav2Vec2Model loaded in the first cell — confirm the intended execution order.
with torch.no_grad():
    outputs = model.transcriber(**inputs)

# `outputs` is used as a tensor here; show its shape as a plain list.
list(outputs.shape)

[1, 494, 768]

In [37]:
import torch
from torchaudio.models import emformer_rnnt_base

# Reference Emformer RNN-T from torchaudio, used for comparison with the Wav2Vec2-based model.
emformer = emformer_rnnt_base(num_symbols=4097)


# Parameters for batch size 4
B = 4
T = 15                # maximum source sequence length (including padding/right context)
D = 80               # feature dimension for sources
U = 10                # maximum target sequence length

# Create a random sources tensor of shape (B, T, D)
sources = torch.randn(B, T, D)
# Random 2-D input of shape (B, 32000) fed to the Wav2Vec2-based model below
sources_wav2vec2 = torch.randn(B, 32000)
# source_lengths: valid source frame counts for each sample
source_lengths = torch.tensor([12, 15, 11, 14])

# For targets, assume our vocabulary size is 50
vocab_size = 50
targets = torch.randint(low=0, high=vocab_size, size=(B, U), dtype=torch.long)
# Target lengths for each sample in the batch
target_lengths = torch.tensor([8, 10, 9, 10])

# Optional predictor_state; set to None for this test
# (not passed to either model call below)
predictor_state = None

# Display the shapes and values
print("Batch Size 4")
print("Sources shape:", sources.shape)           # Expected shape: (B, T, D) = (4, 15, 80)
print("Source Lengths:", source_lengths)           # Expected: e.g., [12, 15, 11, 14]
print("Targets shape:", targets.shape)             # Expected shape: (4, 10)
print("Target Lengths:", target_lengths)           # Expected: e.g., [8, 10, 9, 10]

# NOTE: this call rebinds `source_lengths` and `target_lengths` to the values returned by the model.
output, source_lengths, target_lengths, pred_state = emformer(sources, source_lengths, targets, target_lengths)
# NOTE(review): `model` is called with the RNN-T signature (sources, source_lengths, targets,
# target_lengths), so it is presumably the Wav2vec2RNNT instance built in a later cell — confirm
# the execution order. `None` is passed for source_lengths.
output_wav2vec2, source_lengths_wav2vec2, target_lengths_wav2vec2, pred_state_wav2vec2 = model(sources_wav2vec2, None, targets, target_lengths)

Batch Size 4
Sources shape: torch.Size([4, 15, 80])
Source Lengths: tensor([12, 15, 11, 14])
Targets shape: torch.Size([4, 10])
Target Lengths: tensor([ 8, 10,  9, 10])


In [38]:
# Inspect what the Emformer RNN-T call returned (lengths are the values returned by the model).
print("Output shape:", output.shape)
print("Source Lengths:", source_lengths)
print("Target Lengths:", target_lengths)

Output shape: torch.Size([4, 2, 10, 4097])
Source Lengths: tensor([3, 3, 2, 3])
Target Lengths: tensor([ 8, 10,  9, 10])


In [39]:
# Inspect what the Wav2Vec2-based RNN-T call returned.
print("Output shape:", output_wav2vec2.shape)
print("Source Lengths:", source_lengths_wav2vec2)
print("Target Lengths:", target_lengths_wav2vec2)

Output shape: torch.Size([4, 99, 10, 4097])
Source Lengths: tensor([99, 99, 99, 99], dtype=torch.int32)
Target Lengths: tensor([ 8, 10,  9, 10])


In [36]:
from typing import Optional, List, Tuple
from torchaudio.models.rnnt import _Predictor, _Joiner
from torchaudio.models import RNNT
from transformers import Wav2Vec2Model

class Wav2Vec2HiddenStates(Wav2Vec2Model):
    """
    Wav2Vec2Model whose forward pass returns only the last hidden state tensor.

    All arguments are forwarded unchanged to ``Wav2Vec2Model.forward``; only the
    return value differs (a single tensor instead of a model-output object/tuple).
    """
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Run the parent forward pass and return its first output (the last hidden state)."""
        outputs = super().forward(
            input_values=input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The last hidden state is the first element of the parent's output in both the
        # model-output form and the plain-tuple form (return_dict=False). Attribute access
        # (`.last_hidden_state`) would raise AttributeError on the tuple form.
        return outputs[0]

class Wav2vec2RNNT(RNNT):
    """
    RNN-T whose transcriber is called on raw inputs and returns a single tensor of encodings.

    The transcriber is invoked as ``self.transcriber(input_values=sources)`` and its result
    is treated as a tensor of shape ``(batch, frames, dim)``.
    """

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """Streaming transcription is not supported; always raises NotImplementedError."""
        raise NotImplementedError("No streaming for Wav2Vec2Model.")

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """
        Encode sources, run the predictor on targets, and join both.

        Args:
            sources: input passed to the transcriber as ``input_values``.
            source_lengths: ignored; every sequence is reported as having the full
                number of encoder frames (may be ``None``).
            targets: target token tensor passed to the predictor.
            target_lengths: valid target lengths passed to the predictor.
            predictor_state: optional predictor state from a previous call.

        Returns:
            ``(output, source_lengths, target_lengths, predictor_state)`` as produced by
            the joiner and predictor.
        """
        source_encodings = self.transcriber(
            input_values=sources,
        )
        # Every sequence gets the full encoder length. Create the lengths on the same
        # device as the encodings so downstream ops (e.g. the RNN-T loss) do not see a
        # CPU tensor mixed with GPU logits.
        source_lengths = torch.full(
            (source_encodings.size(0),),
            source_encodings.size(1),
            dtype=torch.int32,
            device=source_encodings.device,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )

        return (
            output,
            source_lengths,
            target_lengths,
            predictor_state,
        )
        


def wav2vec2_rnnt_model(
    *,
    encoding_dim: int,
    num_symbols: int,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    pretrained_model_name: str = "facebook/wav2vec2-base",
) -> Wav2vec2RNNT:
    """
    Build an RNN-T with a pretrained Wav2Vec2 encoder, an LSTM predictor and a joiner.

    Args:
        encoding_dim: dimension shared by encoder output, predictor output and joiner input.
        num_symbols: size of the output symbol set.
        symbol_embedding_dim: embedding size of the predictor; also used as its LSTM hidden size.
        num_lstm_layers: number of LSTM layers in the predictor.
        lstm_layer_norm: whether the predictor LSTM layers use layer normalization.
        lstm_layer_norm_epsilon: epsilon for that layer normalization.
        lstm_dropout: dropout probability of the predictor LSTM.
        pretrained_model_name: checkpoint name or path passed to ``from_pretrained``
            for the encoder (defaults to the previously hard-coded checkpoint).

    Raises:
        ValueError: if ``encoding_dim`` differs from the encoder's hidden size, which
            would otherwise only fail later as a shape mismatch inside the joiner.
    """
    encoder = Wav2Vec2HiddenStates.from_pretrained(pretrained_model_name)
    encoder_hidden_size = encoder.config.hidden_size
    if encoding_dim != encoder_hidden_size:
        raise ValueError(
            f"encoding_dim ({encoding_dim}) must equal the encoder hidden size "
            f"({encoder_hidden_size}) of '{pretrained_model_name}'"
        )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols)
    return Wav2vec2RNNT(encoder, predictor, joiner)

# Build the Wav2Vec2-based RNN-T. This rebinds `model`, which the first cell used for the
# plain Wav2Vec2Model.
model = wav2vec2_rnnt_model(
    encoding_dim=768, # Wav2Vec2-base output dim
    num_symbols=4097,
    symbol_embedding_dim=512,
    num_lstm_layers=3,
    lstm_layer_norm=True,
    lstm_layer_norm_epsilon=1e-3,
    lstm_dropout=0.3,
)



# Transducer implementation in PyTorch

*by Loren Lugosch*



In this notebook, we will implement a Transducer sequence-to-sequence model for inserting missing vowels into a sentence ("Hll, Wrld" --> "Hello, World").

In [None]:
import torch
import string
import numpy as np
import itertools
from collections import Counter
from tqdm import tqdm
import unidecode

# # Some training data.
# # Poor Tolstoy, once again reduced to grist for the neural network mill!
# !wget https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt
# !pwd

# Building blocks

First, we will define the encoder, predictor, and joiner using standard neural nets.

<img src="https://lorenlugosch.github.io/images/transducer/transducer-model.png" width="25%">

In [None]:
# Label index of the null (blank) output; also used as the padding value and start symbol.
NULL_INDEX = 0

# Hidden sizes of the three sub-networks. Encoder and Predictor both project to `joiner_dim`.
encoder_dim = 1024
predictor_dim = 1024
joiner_dim = 1024

The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.


In [None]:
class Encoder(torch.nn.Module):
    """Embeds input token indices, runs a 3-layer bidirectional GRU, and projects to `joiner_dim`."""

    def __init__(self, num_inputs):
        super().__init__()
        nn = torch.nn
        self.embed = nn.Embedding(num_inputs, encoder_dim)
        self.rnn = nn.GRU(
            encoder_dim,
            encoder_dim,
            num_layers=3,
            dropout=0.1,
            bidirectional=True,
            batch_first=True,
        )
        # The GRU is bidirectional, so its output feature size is twice the hidden size.
        self.linear = nn.Linear(2 * encoder_dim, joiner_dim)

    def forward(self, x):
        """Map token indices `x` to per-position encodings of size `joiner_dim`."""
        rnn_out, _ = self.rnn(self.embed(x))
        return self.linear(rnn_out)

The predictor is any _causal_ network (= can't look at the future): in other words, unidirectional RNNs, causal convolutions, or masked self-attention. 

In [None]:
class Predictor(torch.nn.Module):
    """
    Causal label network: embedding -> 3-layer unidirectional GRU -> linear projection.

    The GRU starts from a learned initial state and is first fed the start symbol
    (`NULL_INDEX`), then the labels one at a time.
    """

    def __init__(self, num_outputs):
        super(Predictor, self).__init__()
        self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
        self.rnn = torch.nn.GRU(
            input_size=predictor_dim,
            hidden_size=predictor_dim,
            num_layers=3,
            batch_first=True,
            bidirectional=False,
            dropout=0.1,
        )
        self.linear = torch.nn.Linear(predictor_dim, joiner_dim)

        # Learned initial hidden state, shape (num_layers, predictor_dim); the batch
        # dimension is added when the state is used.
        self.initial_state = torch.nn.Parameter(torch.randn(self.rnn.num_layers, predictor_dim))
        self.start_symbol = NULL_INDEX  # Using null index for embedding start symbol

    def forward_one_step(self, input, previous_state):
        """
        Advance the predictor by one label.

        Args:
            input: label indices, shape (batch_size,).
            previous_state: GRU hidden state, shape (num_layers, batch_size, predictor_dim).

        Returns:
            (out, new_state): projected output of shape (batch_size, joiner_dim) and the
            updated GRU hidden state.
        """
        # Embed input and add sequence dimension
        embedding = self.embed(input).unsqueeze(1)  # Shape: (batch_size, 1, predictor_dim)
        # Pass through GRU; GRU returns (output, hidden_state)
        output, new_state = self.rnn(embedding, previous_state)
        # Remove sequence dimension before the linear layer
        out = self.linear(output.squeeze(1))
        return out, new_state

    def forward(self, y):
        """
        Run the predictor over a batch of label sequences.

        Args:
            y: label indices, shape (batch_size, U).

        Returns:
            Tensor of shape (batch_size, U + 1, joiner_dim): one output for the start
            symbol followed by one output per label.
        """
        batch_size = y.shape[0]
        U = y.shape[1]
        outs = []
        # Expand initial_state to shape (num_layers, batch_size, predictor_dim).
        # `expand` returns a non-contiguous view when batch_size > 1; the cuDNN GRU
        # rejects a non-contiguous hidden state, so materialize it.
        state = (
            self.initial_state.unsqueeze(1)
            .expand(-1, batch_size, -1)
            .to(y.device)
            .contiguous()
        )
        for u in range(U + 1):  # U+1 steps to include final timestep
            if u == 0:
                decoder_input = torch.tensor(
                    [self.start_symbol] * batch_size, device=y.device
                )
            else:
                decoder_input = y[:, u - 1]
            out, state = self.forward_one_step(decoder_input, state)
            outs.append(out)
        out = torch.stack(outs, dim=1)
        return out


The joiner is a feedforward network/MLP with one hidden layer applied independently to each $(t,u)$ index.

(The linear part of the hidden layer is contained in the encoder and predictor, so we just do the nonlinearity here and then the output layer.)

In [None]:
class Joiner(torch.nn.Module):
    """Adds encoder and predictor outputs, applies ReLU, then a linear output layer."""

    def __init__(self, num_outputs):
        super().__init__()
        self.linear = torch.nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_out, predictor_out):
        """Return output-layer scores for the (broadcast) sum of the two inputs."""
        hidden = torch.nn.functional.relu(encoder_out + predictor_out)
        return self.linear(hidden)

# Transducer model + loss function

Using the encoder, predictor, and joiner, we will implement the Transducer model and its loss function.

<img src="https://lorenlugosch.github.io/images/transducer/forward-messages.png" width="25%">

We can use a simple PyTorch implementation of the loss function, relying on automatic differentiation to give us gradients.

In [None]:
class Transducer(torch.nn.Module):
    """
    Transducer model: encoder, predictor and joiner, plus a pure-PyTorch loss.

    On construction the model moves itself to "cuda:0" when CUDA is available,
    otherwise to "cpu", and records the choice in `self.device`.
    """

    def __init__(self, num_inputs, num_outputs):
        super(Transducer, self).__init__()
        self.encoder = Encoder(num_inputs)
        self.predictor = Predictor(num_outputs)
        self.joiner = Joiner(num_outputs)

        if torch.cuda.is_available():
            self.device = "cuda:0"
        # elif torch.backends.mps.is_available():
        #     self.device = "mps"
        else:
            self.device = "cpu"
        self.to(self.device)
        print("Using device:", self.device)

    def compute_forward_prob(self, joiner_out, T, U, y):
        """
        Forward algorithm: log-probability of `y` summed over all alignments.

        joiner_out: tensor of shape (B, T_max, U_max+1, #labels); indexed as log-probabilities
        T: input lengths, indexable per batch element
        U: output lengths, indexable per batch element
        y: label tensor, indexed as y[:, u - 1] for u = 1..U_max (i.e. at least U_max columns)

        Returns a tensor of shape (B,) with one log-probability per batch element.
        """
        B = joiner_out.shape[0]
        T_max = joiner_out.shape[1]
        U_max = joiner_out.shape[2] - 1
        # log_alpha[b, t, u]: log-probability of reaching lattice node (t, u).
        log_alpha = torch.zeros(B, T_max, U_max + 1, device=self.device)
        for t in range(T_max):
            for u in range(U_max + 1):
                if u == 0:
                    if t == 0:
                        # Start node.
                        log_alpha[:, t, u] = 0.0

                    else:  # t > 0
                        # First row: only reachable by emitting null from (t-1, 0).
                        log_alpha[:, t, u] = (
                            log_alpha[:, t - 1, u] + joiner_out[:, t - 1, 0, NULL_INDEX]
                        )

                else:  # u > 0
                    if t == 0:
                        # First column: only reachable by emitting label y[u-1] from (0, u-1).
                        log_alpha[:, t, u] = log_alpha[:, t, u - 1] + torch.gather(
                            joiner_out[:, t, u - 1],
                            dim=1,
                            index=y[:, u - 1].view(-1, 1),
                        ).reshape(-1)

                    else:  # t > 0
                        # Interior node: logsumexp over arriving via null from (t-1, u)
                        # and via label y[u-1] from (t, u-1).
                        log_alpha[:, t, u] = torch.logsumexp(
                            torch.stack(
                                [
                                    log_alpha[:, t - 1, u]
                                    + joiner_out[:, t - 1, u, NULL_INDEX],
                                    log_alpha[:, t, u - 1]
                                    + torch.gather(
                                        joiner_out[:, t, u - 1],
                                        dim=1,
                                        index=y[:, u - 1].view(-1, 1),
                                    ).reshape(-1),
                                ]
                            ),
                            dim=0,
                        )

        # Read each sequence's result at its own final node (T[b]-1, U[b]) and add the
        # final null emission from that node.
        log_probs = []
        for b in range(B):
            log_prob = (
                log_alpha[b, T[b] - 1, U[b]] + joiner_out[b, T[b] - 1, U[b], NULL_INDEX]
            )
            log_probs.append(log_prob)
        log_probs = torch.stack(log_probs)
        return log_probs

    def compute_loss(self, x, y, T, U):
        """
        Mean negative log-likelihood of `y` given `x` using `compute_forward_prob`.

        NOTE: a later cell reassigns `Transducer.compute_loss` to a torchaudio-based version.
        """
        encoder_out = self.encoder.forward(x)
        predictor_out = self.predictor.forward(y)
        # Broadcast encoder (B, T, 1, D) against predictor (B, 1, U+1, D) and normalize
        # over the label dimension.
        joiner_out = self.joiner.forward(
            encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)
        ).log_softmax(3)
        loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
        return loss

Let's first verify that the forward algorithm actually correctly computes the sum (in log space, the [logsumexp](https://lorenlugosch.github.io/posts/2020/06/logsumexp/)) of all possible alignments, using a short input/output pair for which computing all possible alignments is feasible.

<img src="https://lorenlugosch.github.io/images/transducer/cat-align-1.png" width="25%">

In [None]:
def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):
    """
    Log-probability of a single alignment `z`.

    `z` is a sequence of steps: 0 emits null and advances t, 1 emits the next label
    of `y` and advances u. A final null emission from node (T - 1, U) is always added.
    """
    frame_indices, label_indices, emitted = [], [], []
    t, u = 0, 0
    for step in z:
        # Record the lattice node the emission is made from.
        frame_indices.append(t)
        label_indices.append(u)
        if step == 0:  # right (null)
            emitted.append(NULL_INDEX)
            t += 1
        elif step == 1:  # down (label)
            emitted.append(y[u])
            u += 1

    # Terminating null from the last node.
    frame_indices.append(T - 1)
    label_indices.append(U)
    emitted.append(NULL_INDEX)

    log_probs = self.joiner.forward(
        encoder_out[frame_indices], predictor_out[label_indices]
    ).log_softmax(1)
    emitted_labels = torch.tensor(emitted).long().to(self.device)
    # Summed NLL over the path, negated to give the path log-probability.
    return -torch.nn.functional.nll_loss(
        input=log_probs,
        target=emitted_labels,
        reduction="sum",
    )


Transducer.compute_single_alignment_prob = compute_single_alignment_prob

In [None]:
from torchaudio.functional import rnnt_loss


def compute_loss(self, x, y, T, U):
    """
    Transducer loss computed with torchaudio's `rnnt_loss`.

    The joiner output is passed as raw float logits (log-softmax is fused into the
    loss); targets and lengths are converted to int32, and the lengths are moved to
    the logits' device. Returns the mean loss over the batch.
    """
    enc = self.encoder(x)
    pred = self.predictor(y)
    # (batch, T_max, U_max+1, num_outputs) via broadcasting; no manual log_softmax here.
    joint_logits = self.joiner.forward(enc.unsqueeze(2), pred.unsqueeze(1)).float()

    target_device = joint_logits.device
    logit_lengths = T.to(target_device).to(torch.int32)
    label_lengths = U.to(target_device).to(torch.int32)

    return rnnt_loss(
        logits=joint_logits,
        targets=y.to(torch.int32),
        logit_lengths=logit_lengths,
        target_lengths=label_lengths,
        blank=NULL_INDEX,  # must match the blank label
        clamp=-1,  # default clamp value (no clamping)
        reduction="mean",
        fused_log_softmax=True,
    )


Transducer.compute_loss = compute_loss

In [None]:
# Generate example inputs/outputs
# NOTE: this rebinds the module-level names `model`, `T`, `U` and `B`.
num_outputs = len(string.ascii_uppercase) + 1  # [null, A, B, ... Z]
model = Transducer(1, num_outputs)
y_letters = "CAT"
# Labels are 1-based letter indices (0 is reserved for null); shape (1, len(y_letters)).
y = (
    torch.tensor([string.ascii_uppercase.index(l) + 1 for l in y_letters])
    .unsqueeze(0)
    .to(model.device)
)
# Lengths are one-element tensors so they can also be passed to rnnt_loss below.
T = torch.tensor([4])
U = torch.tensor([len(y_letters)])
B = 1


# Random encoder/predictor outputs stand in for real network activations.
encoder_out = torch.randn(B, T, joiner_dim).to(model.device)
predictor_out = torch.randn(B, U + 1, joiner_dim).to(model.device)
joiner_out = model.joiner.forward(
    encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)
).log_softmax(3)

#######################################################
# Compute loss by enumerating all possible alignments #
#######################################################
# Every alignment is an ordering of (T - 1) null steps and U label steps;
# Counter keys remove duplicate permutations.
all_permutations = list(itertools.permutations([0] * (T - 1) + [1] * U))
all_distinct_permutations = list(Counter(all_permutations).keys())
alignment_probs = []
for z in all_distinct_permutations:
    alignment_prob = model.compute_single_alignment_prob(
        encoder_out[0], predictor_out[0], T.item(), U.item(), z, y[0]
    )
    alignment_probs.append(alignment_prob)
loss_enumerate = -torch.tensor(alignment_probs).logsumexp(0)

#######################################################
# Compute loss using the forward algorithm            #
#######################################################
loss_forward = -model.compute_forward_prob(joiner_out, T, U, y)

#######################################################
# Compute loss using torchaudio implementation        #
#######################################################


# NOTE: `joiner_out` already had log_softmax applied above and is passed with
# fused_log_softmax=True.
loss_torchaudio = rnnt_loss(
    joiner_out,  # shape: (B, T, U+1, num_outputs)
    y.to(torch.int32).to(joiner_out.device),  # shape: (B, U)
    T.to(torch.int32).to(joiner_out.device),  # tensor of input lengths
    U.to(torch.int32).to(joiner_out.device),  # tensor of target lengths
    blank=NULL_INDEX,  # using the defined NULL_INDEX
    reduction="mean",
    fused_log_softmax=True,
)

# The three values are expected to agree.
print("Loss computed by enumerating all possible alignments: ", loss_enumerate)
print("Loss computed using the forward algorithm: ", loss_forward)
print("Loss computed using torchaudio implementation: ", loss_torchaudio)

Now let's add the greedy search algorithm for predicting an output sequence.

(Note that I've assumed we're using RNNs for the predictor here. You would have to modify this code a bit if you want to use convolutions/self-attention instead.) 
<br/><br/>
<img src="https://lorenlugosch.github.io/images/transducer/greedy-search.png" width="50%">

In [None]:
def greedy_search(self, x, T):
    """
    Greedy decoding: at each lattice node take the highest-scoring joiner output.

    Args:
        x: batch of input token indices passed to the encoder.
        T: per-example input lengths, indexable by batch position.

    Returns:
        A list with one list of predicted label indices per example (start symbol removed).
        Decoding of an example stops after T[b] null emissions or 200 labels.
    """
    y_batch = []
    B = len(x)
    encoder_out = self.encoder.forward(x)
    U_max = 200
    for b in range(B):
        t = 0
        u = 0
        y = [self.predictor.start_symbol]
        # initial_state is (num_layers, predictor_dim); the GRU expects
        # (num_layers, batch=1, predictor_dim), so the batch axis goes in position 1.
        predictor_state = self.predictor.initial_state.unsqueeze(1).contiguous()
        # Predictor output for the start symbol. The predictor is only advanced when a
        # label is emitted, so g_u always depends on exactly the labels emitted so far,
        # matching how Predictor.forward builds its outputs during training.
        predictor_input = torch.tensor([y[-1]], device=x.device)
        g_u, predictor_state = self.predictor.forward_one_step(
            predictor_input, predictor_state
        )
        while t < T[b] and u < U_max:
            f_t = encoder_out[b, t]
            h_t_u = self.joiner.forward(f_t, g_u)
            argmax = h_t_u.max(-1)[1].item()
            if argmax == NULL_INDEX:
                # Null: move to the next input position, keep the predictor output.
                t += 1
            else:  # argmax == a label
                u += 1
                y.append(argmax)
                predictor_input = torch.tensor([argmax], device=x.device)
                g_u, predictor_state = self.predictor.forward_one_step(
                    predictor_input, predictor_state
                )
        y_batch.append(y[1:])  # remove start symbol
    return y_batch


Transducer.greedy_search = greedy_search

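The pure-Python Transducer loss above works, but training with it would be very slow because the forward algorithm is written as Python loops. A fast implementation should be used for training instead: this notebook uses `rnnt_loss` from torchaudio (already patched into `Transducer.compute_loss` above); SpeechBrain also provides a fast implementation.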

# Some utilities

Here we will add a bit of boilerplate code for training and loading data.

In [None]:
class TextDataset(torch.utils.data.Dataset):
    """
    Dataset of text lines yielding (line without vowels, full line) pairs.

    The instance also owns a shuffling DataLoader over itself, available as `self.loader`.
    """

    def __init__(self, lines, batch_size):
        # Drop lines that consist of a newline only.
        self.lines = [ln for ln in lines if ln != "\n"]  # list of strings
        self.loader = torch.utils.data.DataLoader(
            self,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=Collate(),
        )

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return `(x, y)`: `y` is the transliterated line, `x` is `y` with vowels removed."""
        # Strip newlines, then transliterate special characters.
        full_text = unidecode.unidecode(self.lines[idx].replace("\n", ""))
        vowel_free = full_text.translate(str.maketrans("", "", "AEIOUaeiou"))
        return (vowel_free, full_text)


def encode_string(s):
    """
    Encode a string as 1-based positions of its characters in `string.printable`.

    The whole string is printed once for every character that is not in
    `string.printable`; such a character then makes the lookup raise ValueError.
    """
    alphabet = string.printable
    for ch in s:
        if ch not in alphabet:
            print(s)
    labels = []
    for ch in s:
        labels.append(1 + alphabet.index(ch))
    return labels


def decode_labels(l):
    """Decode 1-based label indices back to a string using `string.printable`."""
    alphabet = string.printable
    chars = []
    for label in l:
        chars.append(alphabet[label - 1])
    return "".join(chars)


class Collate:
    """Collate function turning (input string, output string) pairs into padded label tensors."""

    def __call__(self, batch):
        """
        batch: list of tuples (input string, output string)

        Returns `(x, y, T, U)` as int32 tensors: `x` and `y` hold the encoded strings
        padded with NULL_INDEX to the longest sequence in the batch, `T` and `U` hold
        the unpadded lengths.
        """
        # Encode input then output for each example, in batch order.
        encoded = [(encode_string(src), encode_string(tgt)) for src, tgt in batch]

        T = [len(src_labels) for src_labels, _ in encoded]
        U = [len(tgt_labels) for _, tgt_labels in encoded]
        T_max = max(T)
        U_max = max(U)

        def pad_to(labels, width):
            # Right-pad with the null label and convert to a tensor.
            return torch.tensor(labels + [NULL_INDEX] * (width - len(labels)))

        x = torch.stack([pad_to(src_labels, T_max) for src_labels, _ in encoded])
        y = torch.stack([pad_to(tgt_labels, U_max) for _, tgt_labels in encoded])

        return (
            x.to(torch.int32),
            y.to(torch.int32),
            torch.tensor(T).to(torch.int32),
            torch.tensor(U).to(torch.int32),
        )


# Read the training text from the current working directory (the download command near the
# top of the notebook is commented out).
# NOTE(review): no encoding is given, so the platform default is used — confirm it matches the file.
with open("war_and_peace.txt", "r") as f:
    lines = f.readlines()

# First 90% of the lines for training, the rest for testing (split by position, no shuffling).
end = round(0.9 * len(lines))
train_lines = lines[:end]
test_lines = lines[end:]
train_set = TextDataset(train_lines, batch_size=64)  # 8)
test_set = TextDataset(test_lines, batch_size=64)  # 8)
# Show the first (vowel-free input, full output) pair.
train_set.__getitem__(0)

In [None]:
import torch.amp as amp  # ensure we import amp


class Trainer:
    """
    Minimal training/evaluation helper built around an Adam optimizer.

    `train` expects an RNN-T style model called as
    `model(sources, source_lengths, targets, target_lengths)`; `test` expects a model
    exposing `.device` and `.compute_loss(x, y, T, U)` (the `Transducer` class).
    """

    def __init__(self, model, lr):
        self.model = model
        self.lr = lr
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

    def train(self, dataset, print_interval=20, use_bf16=False, device=None):
        """
        Train for one pass over `dataset.loader`.

        Args:
            dataset: object whose `.loader` yields `(x, y, T, U)` batches.
            print_interval: accepted for compatibility; currently unused.
            use_bf16: run the forward pass under bfloat16 autocast when True.
            device: device to move batches to (string or torch.device, e.g. "cuda:0").
                When None, the device of the model's first parameter is used.

        Returns:
            The sample-weighted mean training loss.
        """
        if device is None:
            device = next(self.model.parameters()).device
        # autocast takes a bare device *type* ("cuda", "cpu"); an indexed string such as
        # "cuda:0" is rejected, so strip the index here.
        device_type = torch.device(device).type

        train_loss = 0
        num_samples = 0
        self.model.train()

        pbar = tqdm(dataset.loader)
        for idx, batch in enumerate(pbar):
            x, y, T, U = batch
            # Use non_blocking transfers (ensure your DataLoader uses pin_memory=True)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).to(torch.int32)
            T = T.to(device, non_blocking=True).to(torch.int32)
            U = U.to(device, non_blocking=True).to(torch.int32)
            batch_size = len(x)
            num_samples += batch_size

            # Autocast is switched off entirely unless BF16 is requested, instead of
            # asking for float32 as an autocast dtype.
            with amp.autocast(
                device_type=device_type,
                dtype=torch.bfloat16,
                enabled=use_bf16,
            ):
                joint_out, output_lengths, target_lengths, _ = self.model(
                    x,          # sources
                    T,          # source_lengths
                    y,          # targets
                    U           # target_lengths
                )
                # Keep reported lengths within the time dimension actually produced.
                output_lengths = output_lengths.clamp(max=joint_out.size(1))

                loss = rnnt_loss(
                    joint_out.float(),              # shape: (batch, T_max, U_max+1, num_outputs)
                    y,                              # targets (batch, U_max)
                    output_lengths,                 # tensor of input lengths
                    target_lengths,                 # tensor of target lengths
                    blank=NULL_INDEX,               # defined blank label
                    reduction="mean",
                    fused_log_softmax=True,
                )

            self.optimizer.zero_grad(set_to_none=True)
            pbar.set_description("%.2f" % loss.item())

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item() * batch_size

        train_loss /= num_samples
        return train_loss

    def test(self, dataset, print_interval=1):
        """
        Evaluate on `dataset.loader` without gradients and return the mean loss.

        Requires `self.model` to provide `.device` and `.compute_loss(x, y, T, U)`;
        `print_interval` is accepted for compatibility and currently unused.
        """
        test_loss = 0
        num_samples = 0
        self.model.eval()
        pbar = tqdm(dataset.loader)
        with torch.no_grad():
            for idx, batch in enumerate(pbar):
                x, y, T, U = batch
                x = x.to(self.model.device)
                y = y.to(self.model.device)
                batch_size = len(x)
                num_samples += batch_size
                loss = self.model.compute_loss(x, y, T, U)
                pbar.set_description("%.2f" % loss.item())
                test_loss += loss.item() * batch_size
        test_loss /= num_samples
        return test_loss

In [None]:
def count_parameters(model):
    """Print the total and the trainable (requires_grad) number of parameter elements of `model`."""
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    print(f"Total parameters: {total_params}, \t Trainable parameters: {trainable_params}")

In [None]:
from torchaudio.models import Emformer
from torchaudio.models.rnnt import _TimeReduction, _Transcriber, RNNT, _Predictor

from typing import List, Optional, Tuple

class _EmformerEncoderEmbed(torch.nn.Module, _Transcriber):
    """
    Emformer-based transcriber that takes token indices instead of feature vectors.

    Pipeline (see `forward`): embedding -> bias-free linear projection -> time reduction
    -> Emformer -> linear projection -> layer normalization.
    """

    def __init__(
        self,
        *,
        num_tokens: int,              # New parameter: vocabulary size (e.g., num_chars + 1)
        input_dim: int,               # Dimension of the token embeddings (should match expected feature dim)
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        # Add an embedding layer to map token indices to continuous embeddings.
        self.embedding = torch.nn.Embedding(num_tokens, input_dim)
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        # The Emformer input size is the projected size multiplied by the stride.
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        # Segment and right-context lengths are given in input frames and converted to
        # time-reduced frames by integer division with the stride.
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a batch of token index sequences.

        Args:
            input: token indices, assumed to have shape (B, T).
            lengths: per-sequence lengths, passed to the time-reduction layer.

        Returns:
            (encodings, lengths): layer-normalized encodings and the lengths returned
            by the Emformer.
        """
        # Assume input is token indices of shape (B, T)
        embedded = self.embedding(input)  # Convert token indices to embeddings, shape: (B, T, input_dim)
        input_linear_out = self.input_linear(embedded)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """
        Inference variant of `forward` that threads Emformer state through the call.

        Same processing as `forward`, but uses `Emformer.infer` with `states` and
        additionally returns the updated states.
        """
        # Same modification for inference: embed token indices before further processing.
        embedded = self.embedding(input)
        input_linear_out = self.input_linear(embedded)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states

class _JoinerPad(torch.nn.Module):
    """
    Joiner that widens the target axis from U to U + 1 before combining with the sources.

    A zero "start" encoding is prepended to the target encodings, so joint position
    ``u`` is built from the encoding produced after the first ``u`` target symbols
    (position 0 has seen none). This is the alignment an RNN-T loss over logits of
    shape (B, T, U + 1, num_outputs) expects: the distribution used to predict target
    ``u`` must not depend on target ``u`` itself.

    Args:
        input_dim: dimension D of source and target encodings.
        output_dim: size of the output layer.
        activation: "relu" or "tanh".

    Raises:
        ValueError: if ``activation`` is not supported.
    """

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "tanh":
            self.activation = torch.nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(
        self,
        source_encodings: torch.Tensor,   # shape: (B, T, D)
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,   # shape: (B, U, D)
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            (output, source_lengths, target_lengths): ``output`` has shape
            (B, T, U + 1, output_dim); both length tensors are returned unchanged.
        """
        # Zero encoding standing in for "no target symbol consumed yet", shape (B, 1, D).
        start = target_encodings.new_zeros(
            target_encodings.size(0), 1, target_encodings.size(2)
        )
        # Prepend (not append): appending would place the encoding that has already
        # consumed target u at the position used to predict target u.
        target_encodings_padded = torch.cat([start, target_encodings], dim=1)  # shape: (B, U+1, D)

        # Broadcast (B, T, 1, D) + (B, 1, U+1, D) -> (B, T, U+1, D).
        joint_encodings = source_encodings.unsqueeze(2) + target_encodings_padded.unsqueeze(1)
        activation_out = self.activation(joint_encodings)
        output = self.linear(activation_out)
        return output, source_lengths, target_lengths


# Hyper-parameters of the Emformer-based RNN-T for the character task.
input_dim=80
encoding_dim=1024
# Labels produced by encode_string are 1-based positions in string.printable, and index 0
# is reserved for the null/blank symbol, so the symbol table needs one extra entry.
# Without the +1 the largest character label is out of range for the embedding tables
# and the output layer.
num_symbols=len(string.printable) + 1
segment_length=16
right_context_length=4
time_reduction_input_dim=128
time_reduction_stride=4
transformer_num_heads=8
transformer_ffn_dim=2048
transformer_num_layers=20
transformer_dropout=0.1
transformer_activation="gelu"
transformer_left_context_length=30
transformer_max_memory_size=0
transformer_weight_init_scale_strategy="depthwise"
transformer_tanh_on_mem=True
symbol_embedding_dim=512
num_lstm_layers=3
lstm_layer_norm=True
lstm_layer_norm_epsilon=1e-3
lstm_dropout=0.3

# Encoder over token indices.
encoder = _EmformerEncoderEmbed(
        num_tokens=num_symbols,
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
# Predictor whose LSTM hidden size equals the symbol embedding size.
predictor = _Predictor(
    num_symbols,
    encoding_dim,
    symbol_embedding_dim=symbol_embedding_dim,
    num_lstm_layers=num_lstm_layers,
    lstm_hidden_dim=symbol_embedding_dim,
    lstm_layer_norm=lstm_layer_norm,
    lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
    lstm_dropout=lstm_dropout,
)
# Joiner that widens the target axis to U + 1 for the RNN-T loss.
joiner = _JoinerPad(encoding_dim, num_symbols)

# Training the model

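Now we will train a model. The progress bar shows the loss of the most recent batch. (`Trainer.train` accepts a `print_interval` argument but does not currently use it, so no sample output sequences are generated during training.)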

In [None]:
# Number of distinct characters; models below use num_chars + 1 labels (index 0 is null).
num_chars = len(string.printable)
# Target device for the torchaudio-based model and for the training batches.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Number of characters: {num_chars}")

In [None]:
# GRU-based Transducer; the extra label is the null symbol at index 0.
# NOTE: `model` is rebound to the Emformer-based RNNT in the next cell.
model = Transducer(num_inputs=num_chars + 1, num_outputs=num_chars + 1)
count_parameters(model)

In [None]:
# Emformer-based RNN-T assembled from the components defined above; replaces the
# Transducer instance bound to `model` in the previous cell.
model = RNNT(encoder, predictor, joiner).to(device)
count_parameters(model)

In [None]:
# torch.compile returns the optimized module and leaves `model` itself untouched, so the
# result has to be kept and handed to the trainer; calling it and discarding the return
# value has no effect on training.
compiled_model = torch.compile(model)
trainer = Trainer(model=compiled_model, lr=0.00001)

num_epochs = 1
train_losses = []
test_losses = []  # stays empty: no evaluation pass is run in this loop

for epoch in range(num_epochs):
    train_loss = trainer.train(train_set, print_interval=10000, use_bf16=False, device=device)
    train_losses.append(train_loss)
    # Only the training loss is computed here, so only that is reported.
    print("Epoch %d: train loss = %f" % (epoch, train_loss))