In [None]:
%load_ext autoreload
%autoreload 2

from nemo.collections.asr.models import EncDecMultiTaskModel

## Init the model and change the decoding strategy to greedy

In [None]:
from omegaconf import DictConfig
import torch

# Inference-only notebook: switch autograd off globally so no later cell records a graph.
torch.autograd.set_grad_enabled(False)

model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
model = model.eval().cpu()

# Greedy decoding; by its name, return_best_hypothesis presumably keeps only the top hypothesis.
decoding_strategy = DictConfig(
    {"strategy": "greedy", "return_best_hypothesis": True}
)
model.change_decoding_strategy(decoding_strategy)

## Load the audio and resample it for the encoder

In [None]:
import torchaudio

# Any audio file works here; a random waveform tensor would do as well if none is at hand.
waveform, sr = torchaudio.load("./audio_2.mp3")

# torchaudio.load returns (channels, samples). Later cells treat the signal as a single
# utterance (a length tensor with one entry, a decoder prompt with batch 1), so downmix
# multi-channel audio to mono. For mono input this leaves the values unchanged.
waveform = waveform.mean(dim=0, keepdim=True)

# Resample to 16 kHz when the file uses a different rate.
target_sr = 16_000
if sr != target_sr:
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
    waveform = resampler(waveform)
    sr = target_sr  # keep `sr` describing the current waveform

## Utility functions used within NeMo's `EncDecMultiTaskModel`

In [None]:
import torch

def lens_to_mask(lens, max_length):
    """
    Build a boolean validity mask from per-row lengths.

    Args:
        lens: 1-D tensor holding one length per batch row.
        max_length: number of positions (columns) in the mask.

    Returns:
        Bool tensor of shape (lens.shape[0], max_length) whose entry [b, t]
        is True exactly when t < lens[b].
    """
    n_rows = lens.shape[0]
    positions = torch.arange(max_length, device=lens.device).expand(n_rows, max_length)
    return positions < lens[:, None]

def mask_padded_tokens(tokens, pad_id):
    """Return a boolean tensor that is True wherever ``tokens`` differs from ``pad_id``."""
    return tokens != pad_id

## Run the preprocessor and encoder to prepare the decoder inputs

In [None]:
import torch

# Raw waveform plus its length in samples; the length tensor has a single entry (batch of one).
kwargs = dict(
    input_signal=waveform.detach(),
    length=torch.tensor([waveform.shape[-1]], dtype=torch.int32).detach(),
)

# Keep the first two outputs of get_features: the features (named log_mel here) and their lengths.
preprocessor_output = model.preprocessor.get_features(**kwargs)
log_mel, log_mel_length = preprocessor_output[0], preprocessor_output[1]

with torch.no_grad():
    encoded, encoded_len = model.encoder.forward_for_export(audio_signal=log_mel, length=log_mel_length)
    # Swap the last two axes so the sequence axis is axis 1 (enc_states.shape[1] is used as
    # the sequence length below), then apply the encoder->decoder projection.
    enc_states = model.encoder_decoder_proj(encoded.permute(0, 2, 1)).detach()
    # Valid-frame mask from the encoder's reported lengths.
    # NOTE(review): enc_mask is not used by any later cell shown here.
    enc_mask = lens_to_mask(encoded_len, enc_states.shape[1]).to(enc_states.dtype).detach()
    # Hard-coded decoder prompt token ids, batch of one.
    # NOTE(review): presumably the Canary task/language prompt for this checkpoint —
    # TODO derive it from the model's tokenizer/prompt formatter instead of hard-coding.
    input_ids = torch.tensor([[7, 4, 16, 62, 62, 5, 9, 11, 13]], dtype=torch.int64)

## Define a wrapper and export the decoder

In [None]:
class DecoderWrapper(torch.nn.Module):
    """
    Export-friendly single greedy decoding step around a transformer decoder.

    Runs embedding -> decoder -> classifier over the whole decoder prefix (no
    past-state cache), takes the argmax token at the last position and returns
    the prefix with that token appended.

    Args:
        embedding: token embedding; called as ``embedding.forward(ids, start_pos=pos)``.
        decoder: decoder module; called with ``return_mems=True`` and expected to
            return a sequence whose last element holds the final hidden states.
        classifier: output head; called as ``classifier.forward(hidden_states=...)``.
        pad_id: token id masked out of the decoder self-attention mask. Defaults
            to 2, the value previously hard-coded here.
            NOTE(review): should match the tokenizer's pad id — TODO confirm.
    """

    def __init__(self, embedding, decoder, classifier, pad_id=2):
        super().__init__()
        self.embedding = embedding
        self.decoder = decoder
        self.classifier = classifier
        self.pad_id = pad_id

    def forward(
        self,
        decoder_input_ids=None,
        encoder_hidden_states=None,
    ):
        """
        Args:
            decoder_input_ids: integer tensor (batch, seq) of prompt + already decoded tokens.
            encoder_hidden_states: tensor whose first two axes are (batch, enc_len).

        Returns:
            Tensor (batch, seq + 1): ``decoder_input_ids`` with the greedy next token appended.
        """
        input_ids = decoder_input_ids
        # Every encoder position is treated as valid, so the mask is all ones. It is built
        # from the input's own shape, dtype and device, so it works for any batch size and
        # device and does not create a tensor from a Python int during export tracing.
        encoder_input_mask = torch.ones(
            encoder_hidden_states.shape[:2],
            dtype=encoder_hidden_states.dtype,
            device=encoder_hidden_states.device,
        )
        logits, _ = self._one_step_forward(
            input_ids,
            encoder_hidden_states,
            encoder_input_mask,
            None,  # no past mems yet
            0,
        )

        # logits covers only the last position (see _one_step_forward); pick the best token.
        next_tokens = torch.argmax(logits[:, -1], dim=-1)
        return torch.cat((input_ids, next_tokens.unsqueeze(1)), dim=-1)

    def _one_step_forward(
        self,
        decoder_input_ids=None,
        encoder_hidden_states=None,
        encoder_input_mask=None,
        decoder_mems_list=None,
        pos=0,
    ):
        """
        Embed the prefix, run the decoder and score the last position.

        Returns:
            (logits, decoder_mems_list). logits is the classifier output for the
            final position only (sequence axis of length 1).
        """
        decoder_hidden_states = self.embedding.forward(decoder_input_ids, start_pos=pos)
        # 1.0 for real tokens, 0.0 for padding.
        decoder_input_mask = (decoder_input_ids != self.pad_id).float()

        if encoder_hidden_states is not None:
            decoder_mems_list = self.decoder.forward(
                decoder_hidden_states,
                decoder_input_mask,
                encoder_hidden_states,
                encoder_input_mask,
                decoder_mems_list,
                return_mems=True,
            )
        else:
            # Decoder-only call shape; forward() always passes encoder states, so this
            # branch is not reached from this wrapper.
            decoder_mems_list = self.decoder.forward(
                decoder_hidden_states,
                decoder_input_mask,
                decoder_mems_list,
                return_mems=True,
            )

        logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:])
        return logits, decoder_mems_list

In [None]:
# Pull the three sub-modules the wrapper needs out of the loaded NeMo model.
embedding = model.decoding.transformer_decoder.embedding
decoder = model.decoding.transformer_decoder.decoder
classifier = model.log_softmax.mlp

wrapper = DecoderWrapper(
    embedding=embedding,
    decoder=decoder,
    classifier=classifier,
)
wrapper.eval()

In [None]:
from torch.export import Dim

# Dynamic sequence-length axes for torch.export, each keyed by the input it describes.
# Previously each Dim (its name, max and comment) was attached to the other input.
# NOTE(review): 1024 is assumed to be the decoder's maximum token length
# (positional-embedding size) — TODO confirm against the model config.
# NOTE(review): 3000 encoder frames was intended to cover ~40 s of audio — verify
# against the encoder's frame rate / subsampling.
dynamic_shapes = {
    "decoder_input_ids": {1: Dim("decoder_input_ids_len", min=1, max=1024)},
    "encoder_hidden_states": {1: Dim("encoder_hidden_state_len", min=1, max=3000)},
}

with torch.no_grad():
    exported = torch.export.export(
        wrapper,
        args=(input_ids, enc_states),
        strict=False,
        dynamic_shapes=dynamic_shapes,
    )

In [None]:
# Sanity-check the export: it must produce exactly the same token ids as the eager
# wrapper. The shorter prefix also exercises the dynamic decoder_input_ids axis.
exported_decoder = exported.module()
for prefix in (input_ids, input_ids[:, :5]):
    assert torch.equal(
        exported_decoder(prefix, enc_states), wrapper(prefix, enc_states)
    ), f"exported decoder disagrees with eager wrapper for prefix length {prefix.shape[1]}"
exported_decoder(input_ids, enc_states)[0].dtype

In [None]:
from pathlib import Path

# Create ./bin first so saving does not fail when the directory is missing.
export_path = Path("./bin/nemo_decoder.pt2")
export_path.parent.mkdir(parents=True, exist_ok=True)
torch.export.save(exported, str(export_path))