In [1]:
import k2
import torch
import torch.nn.functional as F
from typing import Union, List, Literal
from transformers import AutoTokenizer, BertModel
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, TrainingArguments, Trainer
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# trying to understand torch gather function
t = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]])
torch.gather(t, 1, torch.tensor([[0, 0, 1],[1, 0, 0],[1, 0, 0],[0, 1, 0]]))

tensor([[ 1,  1,  2],
        [ 5,  4,  4],
        [ 8,  7,  7],
        [10, 11, 10]])

In [41]:
import pickle

with open('sample_input.pkl', 'rb') as file:
    sample_input = pickle.load(file)
    print(sample_input)

with open('sample_output.pkl', 'rb') as file:
    sample_output = pickle.load(file)
    print(sample_output)

emissions = sample_output["logits"]
log_probs = F.log_softmax(emissions[0], dim=-1, dtype=torch.float32)
print(torch.argmax(log_probs, dim=1))
sampled_logits = F.gumbel_softmax(emissions[0], tau=10, hard=True, dim=-1)
path = torch.zeros(sampled_logits.shape)
index_list = torch.argmax(sampled_logits, dim=1)
paths_list = []
for i in index_list:
    paths_list.append([i])
print(torch.tensor(paths_list).shape)

{'input_values': tensor([[-0.0584, -0.0568, -0.0510,  ..., -0.0050, -0.0039, -0.0039]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32), 'labels': tensor([[11, 15, 14, 20, 18, 15, 12, 12,  0, 13,  5,  4,  0, 18,  5,  9, 19,  5,
          4,  9,  5, 20, 20,  0, 15,  7,  0, 18,  5, 16, 18,  5, 19,  5, 14, 20,
          1, 19, 10, 15, 14, 19, 18,  5,  7, 14,  9, 14,  7,  5, 18,  0]])}
CausalLMOutput(loss=tensor(0.0778, requires_grad=True), logits=tensor([[[  5.2293,  -0.6872,  -4.2354,  ...,   3.9036, -15.1366, -15.2423],
         [  7.3447,  -1.7595,  -4.6908,  ...,   5.4498, -15.2072, -15.2022],
         [ -1.0451,  -0.9015,  -3.7738,  ...,   1.9222, -12.4645, -12.4136],
         ...,
         [ -4.7462,  -2.5563,  -5.0150,  ...,   9.7356, -14.3328, -14.3965],
         [ -3.1324,  -2.8448,  -5.0748,  ...,   9.3147, -14.6481, -14.7300],
         [ 11.1722,  -2.1108,  -3.5456,  ...,   3.0746, -15.1404, -15.1994]]],
       requires_grad=True), hidden_states=None, a

In [16]:
sample_output["logits"]

tensor([[[  5.2293,  -0.6872,  -4.2354,  ...,   3.9036, -15.1366, -15.2423],
         [  7.3447,  -1.7595,  -4.6908,  ...,   5.4498, -15.2072, -15.2022],
         [ -1.0451,  -0.9015,  -3.7738,  ...,   1.9222, -12.4645, -12.4136],
         ...,
         [ -4.7462,  -2.5563,  -5.0150,  ...,   9.7356, -14.3328, -14.3965],
         [ -3.1324,  -2.8448,  -5.0748,  ...,   9.3147, -14.6481, -14.7300],
         [ 11.1722,  -2.1108,  -3.5456,  ...,   3.0746, -15.1404, -15.1994]]],
       requires_grad=True)

In [42]:
def sample_n_paths(num_of_paths, softmax_ctc):
    """
    sample n paths with respect to the ctc probabilites
    :param softmax_ctc: model output ctc after softmax
    :return: N sampled paths
    """
    ctc_swapped = softmax_ctc
    paths = [list(torch.utils.data.WeightedRandomSampler(line, num_of_paths, replacement=True)) for line in
                ctc_swapped]
    return paths

t = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [7,8,9]])
print("input tensor shape:", t.shape)
paths = torch.tensor(sample_n_paths(1, t)).type(torch.LongTensor)
paths_list = sample_n_paths(1, t)
print("path:", paths)
print("paths shape:", paths.shape)

def paths_mass_prob(paths, softmax_ctc, model_pred_length, eps: float = 1e-7):
    """
    compute the path probability mass
    :param paths: ctc alignments
    :param softmax_ctc: model logits after softmax
    :param model_pred_length:  max length of all given paths
    :return: avg of the paths probability
    """
    log_indexes_probs = softmax_ctc.gather(1, paths)
    print("after gather shape:", log_indexes_probs.shape)
    for idx, pred_length in enumerate(model_pred_length):
        log_indexes_probs[idx, pred_length:, :] = torch.zeros(
            (log_indexes_probs.shape[1] - pred_length, 1))
    return torch.sum(log_indexes_probs, dim=1) / (model_pred_length.unsqueeze(-1))

mass_prob = paths_mass_prob(paths, t, 3)
print(mass_prob)

input tensor shape: torch.Size([4, 3])
path: tensor([[2],
        [0],
        [1],
        [2]])
paths shape: torch.Size([4, 1])
after gather shape: torch.Size([4, 1])


TypeError: 'int' object is not iterable

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
model_name = "NbAiLab/nb-wav2vec2-300m-bokmaal"

processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
processor_woLM = Wav2Vec2Processor.from_pretrained(model_name)

# model = Wav2Vec2ForCTC.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(
    model_name,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)
model_tokenizer = AutoTokenizer.from_pretrained(model_name)


Please use `allow_patterns` and `ignore_patterns` instead.
Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 1687.34it/s]


# k2 debugging

In [4]:
!python3 -m k2.version

Collecting environment information...

k2 version: 1.24.4
Build type: Release
Git SHA1: 8f976a1e1407e330e2a233d68f81b1eb5269fdaa
Git date: Thu Jun 6 02:13:08 2024
Cuda used to build k2: 12.1
cuDNN used to build k2: 
Python version used to build k2: 3.9
OS used to build k2: CentOS Linux release 7.9.2009 (Core)
CMake version: 3.29.3
GCC version: 9.3.1
CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -lineinfo --

In [None]:
ref_texts = [ [0, 2], [1, 2, 3] ]

refs = k2.levenshtein_graph(ref_texts, device=device)

In [None]:
refs = k2.levenshtein_graph([ [1, 2], [1, 2, 3] ], device=device)

In [6]:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2023 Lucky Wong
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""CTC based Minimum Word Error Rate Loss definition."""

class MWERLoss(torch.nn.Module):
    '''Minimum Word Error Rate Loss compuration in k2.

    See equation 2 of https://arxiv.org/pdf/2106.02302.pdf about its definition.
    '''

    def __init__(
        self,
        vocab_size: int,
        subsampling_factor: int,
        search_beam: int = 20,
        output_beam: int = 8,
        min_active_states: int = 30,
        max_active_states: int = 10000,
        temperature: float = 1.0,
        num_paths: int = 100,
        use_double_scores: bool = True,
        nbest_scale: float = 0.5,
        reduction: Literal['none', 'mean', 'sum'] = 'none'
    ) -> None:
        """
        Args:
          search_beam:
            Decoding beam, e.g. 20.  Smaller is faster, larger is more exact
            (less pruning). This is the default value; it may be modified by
            `min_active_states` and `max_active_states`.
          output_beam:
             Beam to prune output, similar to lattice-beam in Kaldi.  Relative
             to best path of output.
          min_active_states:
            Minimum number of FSA states that are allowed to be active on any given
            frame for any given intersection/composition task. This is advisory,
            in that it will try not to have fewer than this number active.
            Set it to zero if there is no constraint.
          max_active_states:
            Maximum number of FSA states that are allowed to be active on any given
            frame for any given intersection/composition task. This is advisory,
            in that it will try not to exceed that but may not always succeed.
            You can use a very large number if no constraint is needed.
          subsampling_factor:
            The subsampling factor of the model.
          temperature:
            For long utterances, the dynamic range of scores will be too large
            and the posteriors will be mostly 0 or 1.
            To prevent this it might be a good idea to have an extra argument
            that functions like a temperature.
            We scale the logprobs by before doing the normalization.
          use_double_scores:
            True to use double precision floating point.
            False to use single precision.
          reduction:
            Specifies the reduction to apply to the output:
            'none' | 'sum' | 'mean'.
            'none': no reduction will be applied.
                    The returned 'loss' is a k2.RaggedTensor, with
                    loss.tot_size(0) == batch_size.
                    loss.tot_size(1) == total_num_paths_of_current_batch
                    If you want the MWER loss for each utterance, just do:
                    `loss_per_utt = loss.sum()`
                    Then loss_per_utt.shape[0] should be batch_size.
                    See more example usages in 'k2/python/tests/mwer_test.py'
            'sum': sum loss of each path over the whole batch together.
            'mean': divide above 'sum' by total num paths over the whole batch.
          nbest_scale:
            Scale `lattice.score` before passing it to :func:`k2.random_paths`.
            A smaller value leads to more unique paths at the risk of being not
            to sample the path with the best score.
          num_paths:
            Number of paths to **sample** from the lattice
            using :func:`k2.random_paths`.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.search_beam = search_beam
        self.output_beam = output_beam
        self.min_active_states = min_active_states
        self.max_active_states = max_active_states

        self.num_paths = num_paths
        self.nbest_scale = nbest_scale
        self.subsampling_factor = subsampling_factor

        self.mwer_loss = k2.MWERLoss(
            temperature=temperature,
            use_double_scores=use_double_scores,
            reduction=reduction
        )

    def forward(
        self,
        emissions: torch.Tensor,
        emissions_lengths: torch.Tensor,
        labels: torch.Tensor,
        labels_length: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            labels (torch.FloatTensor): CPU tensor of shape `(batch, label_len)` storing labels.
            emissions_lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of
                in time axis of the output Tensor in each batch.
            labels_length (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of
                label in each batch.

        Returns:
            torch.FloatTensor:
                Minimum Word Error Rate loss.
        """
        H = k2.ctc_topo(
            max_token=self.vocab_size-1,
            modified=False,
            device=emissions.device,
        )

        supervision_segments = torch.stack(
            (
                torch.tensor(range(emissions_lengths.shape[0])),
                torch.zeros(emissions_lengths.shape[0]),
                emissions_lengths.cpu(),
            ),
            1,
        ).to(torch.int32)

        lattice = get_lattice(
            nnet_output=emissions,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=self.search_beam,
            output_beam=self.output_beam,
            min_active_states=self.min_active_states,
            max_active_states=self.max_active_states,
            subsampling_factor=self.subsampling_factor,
        )

        token_ids = []
        for i in range(labels_length.size(0)):
            # token_ids.append(labels[i, : labels_length[i]].cpu().tolist())
            temp = labels[i, : labels_length[i]].cpu().tolist()
            token_ids.append(list(filter(lambda num: num != 0, temp)))

        loss = self.mwer_loss(
            lattice, token_ids,
            nbest_scale=self.nbest_scale,
            num_paths=self.num_paths
        )

        return loss

In [7]:
def get_lattice(
    nnet_output: torch.Tensor,
    decoding_graph: k2.Fsa,
    supervision_segments: torch.Tensor,
    search_beam: float,
    output_beam: float,
    min_active_states: int,
    max_active_states: int,
    subsampling_factor: int = 1,
) -> k2.Fsa:
    """Get the decoding lattice from a decoding graph and neural
    network output.
    Args:
      nnet_output:
        It is the output of a neural model of shape `(N, T, C)`.
      decoding_graph:
        An Fsa, the decoding graph. It can be either an HLG
        (see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
      supervision_segments:
        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
        Each row contains information for a supervision segment. Column 0
        is the `sequence_index` indicating which sequence this segment
        comes from; column 1 specifies the `start_frame` of this segment
        within the sequence; column 2 contains the `duration` of this
        segment.
      search_beam:
        Decoding beam, e.g. 20.  Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      output_beam:
         Beam to prune output, similar to lattice-beam in Kaldi.  Relative
         to best path of output.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any given
        frame for any given intersection/composition task. This is advisory,
        in that it will try not to have fewer than this number active.
        Set it to zero if there is no constraint.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any given
        frame for any given intersection/composition task. This is advisory,
        in that it will try not to exceed that but may not always succeed.
        You can use a very large number if no constraint is needed.
      subsampling_factor:
        The subsampling factor of the model.
    Returns:
      An FsaVec containing the decoding result. It has axes [utt][state][arc].
    """
    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output,
        supervision_segments,
        allow_truncate=subsampling_factor - 1,
    )

    lattice = k2.intersect_dense_pruned(
        decoding_graph,
        dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
        min_active_states=min_active_states,
        max_active_states=max_active_states,
    )

    return lattice

In [4]:
import pickle

with open('sample_input.pkl', 'rb') as file:
    sample_input = pickle.load(file)
    print(sample_input)

with open('sample_output.pkl', 'rb') as file:
    sample_output = pickle.load(file)
    print(sample_output)

{'input_values': tensor([[-0.0584, -0.0568, -0.0510,  ..., -0.0050, -0.0039, -0.0039]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32), 'labels': tensor([[11, 15, 14, 20, 18, 15, 12, 12,  0, 13,  5,  4,  0, 18,  5,  9, 19,  5,
          4,  9,  5, 20, 20,  0, 15,  7,  0, 18,  5, 16, 18,  5, 19,  5, 14, 20,
          1, 19, 10, 15, 14, 19, 18,  5,  7, 14,  9, 14,  7,  5, 18,  0]])}
CausalLMOutput(loss=tensor(0.0778, requires_grad=True), logits=tensor([[[  5.2293,  -0.6872,  -4.2354,  ...,   3.9036, -15.1366, -15.2423],
         [  7.3447,  -1.7595,  -4.6908,  ...,   5.4498, -15.2072, -15.2022],
         [ -1.0451,  -0.9015,  -3.7738,  ...,   1.9222, -12.4645, -12.4136],
         ...,
         [ -4.7462,  -2.5563,  -5.0150,  ...,   9.7356, -14.3328, -14.3965],
         [ -3.1324,  -2.8448,  -5.0748,  ...,   9.3147, -14.6481, -14.7300],
         [ 11.1722,  -2.1108,  -3.5456,  ...,   3.0746, -15.1404, -15.1994]]],
       requires_grad=True), hidden_states=None, a

In [5]:
emissions = sample_output["logits"]
labels = sample_input["labels"]

labels_length = torch.zeros((labels.size(dim=0)))
for i in range(labels.size(dim=0)):
    labels_mask = labels[i] >= 0
    labels_length[i] = len(labels[i].masked_select(labels_mask))
labels_length = labels_length.to(torch.int)
print(labels)

emissions_lengths = torch.tensor([emissions.shape[1]])

tensor([[11, 15, 14, 20, 18, 15, 12, 12,  0, 13,  5,  4,  0, 18,  5,  9, 19,  5,
          4,  9,  5, 20, 20,  0, 15,  7,  0, 18,  5, 16, 18,  5, 19,  5, 14, 20,
          1, 19, 10, 15, 14, 19, 18,  5,  7, 14,  9, 14,  7,  5, 18,  0]])


In [10]:
log_probs = F.log_softmax(emissions, dim=-1)

In [11]:
token_ids = []
for i in range(labels_length.size(0)):
    temp = labels[i, : labels_length[i]].cpu().tolist()
    token_ids.append(list(filter(lambda num: num != 0, temp)))

In [99]:
H = k2.ctc_topo(
    max_token=33,
    modified=False,
    device=emissions.device,
)

# isym = k2.SymbolTable.from_str('''
# | 0
# a 1
# b 2
# c 3
# d 4
# e 5
# f 6
# g 7
# h 8
# i 9
# j 10
# k 11
# l 12
# m 13
# n 14
# o 15
# p 16
# q 17
# r 18
# s 19
# t 20
# u 21
# v 22
# w 23
# x 24
# y 25
# z 26
# å 27
# æ 28
# ø 29
# ''')

# osym = k2.SymbolTable.from_str('''
# | 0
# a 1
# b 2
# c 3
# d 4
# e 5
# f 6
# g 7
# h 8
# i 9
# j 10
# k 11
# l 12
# m 13
# n 14
# o 15
# p 16
# q 17
# r 18
# s 19
# t 20
# u 21
# v 22
# w 23
# x 24
# y 25
# z 26
# å 27
# æ 28
# ø 29
# ''')

# H.labels_sym = isym
# H.aux_labels_sym = osym

In [100]:
supervision_segments = torch.stack(
    (
        torch.tensor(range(emissions_lengths.shape[0])),
        torch.zeros(emissions_lengths.shape[0]),
        emissions_lengths.cpu(),
    ),
    1,
).to(torch.int32)

print(log_probs.shape)
print(emissions_lengths)
print(supervision_segments)

lattice = get_lattice(
    nnet_output=log_probs,
    decoding_graph=H,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=1,
)

torch.Size([1, 166, 34])
tensor([166])
tensor([[  0,   0, 166]], dtype=torch.int32)


In [101]:
nbest = k2.Nbest.from_lattice(
            lattice=lattice,
            num_paths=100,
            use_double_scores=True,
            nbest_scale=0.01,
        )

In [102]:
def _get_texts(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """Extract the texts (as word IDs) from the best-path FSAs.

    Note:
        Used by Nbest.build_levenshtein_graphs during MWER computation.
        Copied from icefall.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(-1)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(-1)

    assert aux_labels.num_axes == 2
    if return_ragged:
        return aux_labels
    else:
        return aux_labels.tolist()

In [103]:
word_ids = _get_texts(nbest.fsa, return_ragged=False)
print(len(word_ids))

100


In [105]:
for i in range(50):
    id_list = word_ids[i]
    id_list_filtered = filter(lambda num: num != 31, id_list)
    chars_temp = model_tokenizer.convert_ids_to_tokens(id_list)
    print(len(id_list))
    print(re.sub(" +", " ", "".join(chars_temp).replace("|", " ")))

166
a k[PAD]on[PAD] t[PAD] ro[PAD]l [PAD] le[PAD] m ed [PAD] r e [PAD]i[PAD] s[PAD]e r[PAD] d i [PAD] je[PAD] t [PAD] t[PAD] [PAD]o[PAD]g [PAD] r e[PAD] p[PAD]re[PAD] s[PAD]e n[PAD] t[PAD] a[PAD] s jo [PAD] n [PAD]s[PAD] r[PAD]eg [PAD] ni[PAD]n[PAD]g[PAD]e[PAD] 
166
n k[PAD]on[PAD] t[PAD]r o[PAD]l [PAD] l[PAD] r med [PAD] r e [PAD]i[PAD] s[PAD] di [PAD] j[PAD] t[PAD] d[PAD] [PAD]o[PAD]g [PAD] [PAD] r[PAD]e[PAD] p[PAD]re[PAD] s[PAD] en [PAD] t [PAD]a[PAD] sj o[PAD] n[PAD] s[PAD] r e[PAD]y[PAD] n[PAD]in g [PAD]e[PAD]r[PAD] 
166
v k[PAD]on[PAD] t[PAD]r o[PAD]l[PAD] l n [PAD] med r a[PAD]i [PAD] s[PAD] er[PAD] d e[PAD] j[PAD]e[PAD]t [PAD] d[PAD]e o[PAD]g [PAD] r e[PAD] p[PAD]r[PAD]e[PAD] s[PAD]e n[PAD] t[PAD]a[PAD] s[PAD]jo[PAD] n [PAD] s[PAD] r[PAD]e[PAD]k[PAD] n i[PAD]n g[PAD] r[PAD] 
166
h k[PAD]on[PAD] t[PAD] ro[PAD]l [PAD] l [PAD] m e[PAD]d[PAD] [PAD] r e[PAD] i[PAD] s[PAD]e [PAD] d i[PAD] h[PAD] et [PAD] t[PAD] o[PAD]g [PAD] r e[PAD] pr e[PAD] s[PAD] e n[PAD] t[PAD] a[PAD] s[PAD]jo [

In [16]:
mwer = MWERLoss(vocab_size=32, subsampling_factor=1, reduction="mean")
masd_loss = mwer(emissions=log_probs, emissions_lengths=emissions_lengths, labels=labels, labels_length=labels_length)
print(masd_loss)

tensor(0.3665, dtype=torch.float64, grad_fn=<MeanBackward0>)


In [None]:
emissions = sample_output["logits"]

H = k2.ctc_topo(
    max_token=31,
    modified=False,
    device=emissions.device,
)

In [None]:
isym = k2.SymbolTable.from_str('''
| 0
a 1
b 2
[PAD] 3
''')

osym = k2.SymbolTable.from_str('''
| 0
a 1
b 2
[PAD] 3
''')

try_topo = k2.ctc_topo(max_token=3, modified=False)

try_topo.labels_sym = isym
try_topo.aux_labels_sym = osym

try_topo.draw('try_topo.svg')

In [None]:
def _get_texts(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """Extract the texts (as word IDs) from the best-path FSAs.

    Note:
        Used by Nbest.build_levenshtein_graphs during MWER computation.
        Copied from icefall.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2
    if return_ragged:
        return aux_labels
    else:
        return aux_labels.tolist()

In [None]:
'''used in MWER loss forward function to get the hyps, function copied from k2/python/k2/nbest.py'''
def build_levenshtein_graphs(self) -> k2.Fsa:
    """Return an FsaVec with axes [utt][state][arc]."""
    word_ids = _get_texts(self.fsa, return_ragged=True)  # get_texts function copied in above!!!
    return k2.levenshtein_graph(word_ids)

# Pytorch Nbest CTC Decoder

In [9]:
from torchaudio.models.decoder import ctc_decoder

In [10]:
def compute_CTCloss_nbest(reference_text, output_logits, input_lengths, asd_model, asd_tokenizer):
    decoder = ctc_decoder(lexicon=None, tokens="tokens.txt", nbest=10, beam_size=100, blank_token="[PAD]",
                          sil_token="|", unk_word="[UNK]")
    targets = []
    target_lengths = []
    log_probs = F.log_softmax(output_logits, dim=-1, dtype=torch.float32).transpose(0, 1)

    for i in range(len(reference_text)):
        ref_text = reference_text[i].replace("[UNK]", "")
        logits = output_logits[i]
        # get nbest hypotheses and rank them
        nbest_list = decoder(logits.type(torch.float32).detach().cpu()[None, :, :])
        nbest_token_list = []
        asd_score_list = [0] * len(nbest_list[0])
        hyp_list = []
        for j, item in enumerate(nbest_list[0]):
            tokens = item.tokens
            for k in range(len(tokens)):
                if tokens[k] == 0:
                    tokens_mod = tokens[k+1:]
                else:
                    break
            chars = decoder.idxs_to_tokens(tokens_mod)
            nbest_token_list.append(tokens_mod)
            hyp_text = re.sub(" +", " ", "".join(chars).replace("|", " "))
            hyp_list.append(hyp_text)
            asd_score_list[j] = compute_asd_score_single_utt(asd_model, asd_tokenizer, ref_text, hyp_text)
        targets.append(torch.tensor(nbest_token_list[np.argmin(asd_score_list)]))
        target_lengths.append(torch.tensor(len(nbest_token_list[np.argmin(asd_score_list)])))

    targets_tensor = torch.cat(targets, dim=0)
    targets_len_tensor = torch.tensor(target_lengths)

    with torch.backends.cudnn.flags(enabled=False):
        loss = F.ctc_loss(
                log_probs,
                targets_tensor,
                input_lengths,
                targets_len_tensor,
                blank=31,
                reduction="mean",
                zero_infinity=True,
                )

    return loss

In [11]:
decoder = ctc_decoder(lexicon=None, tokens="tokens.txt", nbest=10, beam_size=100, blank_token="[PAD]",
                          sil_token="|", unk_word="[UNK]")

In [12]:
logits = sample_output["logits"][0]
log_probs = F.log_softmax(logits, dim=-1)
nbest_list = decoder(logits.type(torch.float32).detach().cpu()[None, :, :])

In [33]:
print(logits[-1][-1])

tensor(-15.1994, grad_fn=<SelectBackward0>)


In [20]:
print(logits)
print(logits.shape)
print(logits[165,0])

tensor([[  5.2293,  -0.6872,  -4.2354,  ...,   3.9036, -15.1366, -15.2423],
        [  7.3447,  -1.7595,  -4.6908,  ...,   5.4498, -15.2072, -15.2022],
        [ -1.0451,  -0.9015,  -3.7738,  ...,   1.9222, -12.4645, -12.4136],
        ...,
        [ -4.7462,  -2.5563,  -5.0150,  ...,   9.7356, -14.3328, -14.3965],
        [ -3.1324,  -2.8448,  -5.0748,  ...,   9.3147, -14.6481, -14.7300],
        [ 11.1722,  -2.1108,  -3.5456,  ...,   3.0746, -15.1404, -15.1994]],
       grad_fn=<SelectBackward0>)
torch.Size([166, 34])
tensor(11.1722, grad_fn=<SelectBackward0>)


In [13]:
nbest_path_probs = torch.zeros((len(nbest_list[0])), requires_grad=True, device=device).double()

for j, item in enumerate(nbest_list[0]):
    path_probs = torch.zeros((len(item.tokens)), requires_grad=True, device=device).double()
    chars = decoder.idxs_to_tokens(item.tokens)
    hyp_text = re.sub(" +", " ", "".join(chars).replace("|", " "))
    print(hyp_text)
    for i, token in enumerate(item.tokens):
        # if item.timesteps[i] < logits.shape[0]:
        if i < len(item.timesteps) - 1:
            start = item.timesteps[i]
            end = item.timesteps[i+1]
            path_probs[i] = logits[start:end,token].sum()
        else:
            path_probs[i] = logits[-1,token]
    nbest_path_probs[j] = path_probs.sum()


print(nbest_path_probs.sum())
print(nbest_path_probs.mean())


 kontroll med reisediett og representasjonsregninger 
 kontroll med reisediett og representasjonsregninger 
 kontroll med reisediette og representasjonsregninger 
 kontroll med reise diett og representasjonsregninger 
 kontroll med reisediette og representasjonsregninger 
 kontroll med reise diett og representasjonsregninger 
 kontroll med reisediett og representasjonsreininger 
 r kontroll med reisediett og representasjonsregninger 
 kontroll med reisedeett og representasjonsregninger 
 kontroll med reisedyett og representasjonsregninger 
tensor(2857.0664, device='cuda:0', dtype=torch.float64, grad_fn=<SumBackward0>)
tensor(285.7066, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)


In [None]:
def compute_nbest_asd(reference_text, output_logits, asd_model, asd_tokenizer, reduction="mean"):
    decoder = ctc_decoder(lexicon=None, tokens="tokens.txt", nbest=10, beam_size=100, blank_token="[PAD]",
                          sil_token="|", unk_word="[UNK]")
    loss = torch.zeros((len(reference_text)), requires_grad=True, device=device).double()
    for i in range(len(reference_text)):
        ref_text = reference_text[i].replace("[UNK]", "")
        logits = output_logits[i]
        # get nbest hypotheses and get path log probs * asd score
        nbest_list = decoder(logits.type(torch.float32).detach().cpu()[None, :, :])
        nbest_asd_loss = torch.zeros((len(nbest_list[0])), requires_grad=True, device=device).double()
        for j, item in enumerate(nbest_list[0]):
            path_probs = torch.zeros((len(item.tokens)), requires_grad=True, device=device).double()
            chars = decoder.idxs_to_tokens(item.tokens)
            hyp_text = re.sub(" +", " ", "".join(chars).replace("|", " "))
            for k, token in enumerate(item.tokens):
                if item.timesteps[k] < logits.shape[0]:
                    start = item.timesteps[k]
                    end = item.timesteps[k+1]
                    path_probs[k] = logits[start:end,token].sum()
                else:
                    path_probs[k] = logits[-1,token]
            nbest_asd_loss[j] = compute_asd_score_single_utt(asd_model, asd_tokenizer, ref_text, hyp_text) * path_probs.sum()
        loss[i] = nbest_asd_loss.mean()
    if reduction == "mean":
        return loss.mean()
    else:
        return loss.sum()