# EXPLAINING AUDIO MODELS USED IN THE RESEARCH.

## Whisper

In our project, we aimed to fine-tune several existing audio models, one of which is OpenAI's Whisper.

Whisper is an automatic speech recognition (ASR) system developed by OpenAI. It is an encoder-decoder Transformer trained to map audio directly to text, and its pipeline involves several steps:

1. Data Preparation: Whisper was trained on roughly 680,000 hours of multilingual audio paired with transcripts collected from the internet. This is weak supervision: the transcripts come from the web rather than from careful human annotation. Before entering the model, audio is resampled to 16 kHz, padded or trimmed to 30-second windows, and converted into a log-Mel spectrogram (80 Mel channels for most released checkpoints). MFCCs are not used.

2. Acoustic Modeling (audio encoder): Whisper does not use recurrent networks or LSTMs. Its encoder passes the log-Mel spectrogram through two 1-D convolution layers (the second with stride 2, which halves the number of frames), adds sinusoidal positional embeddings, and then applies a stack of Transformer blocks built on self-attention. The output is a sequence of audio feature vectors, one per downsampled frame.

3. Language Modeling (text decoder): Whisper has no separate external language model. Instead, its Transformer decoder predicts each text token from the previously generated tokens and, through cross-attention, from the encoder output, so linguistic context is learned jointly with the acoustics. Special tokens at the start of the decoder sequence specify the language, the task (transcription or translation into English), and whether timestamps are predicted.

4. Training: The whole model is trained end to end on the paired audio-text data by minimizing the cross-entropy between the predicted token distribution and the tokens of the reference transcript. The original paper optimizes this objective with AdamW, a variant of stochastic gradient descent.

5. Decoding: At inference time the decoder generates the transcript autoregressively, one token at a time, using greedy or beam search. The reference implementation also falls back to higher sampling temperatures when an output looks unreliable (for example, when it is too repetitive or has a low average log-probability). Whisper does not use Connectionist Temporal Classification (CTC); the alignment between audio and text is handled by attention.
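To make the input shapes concrete, the sketch below reproduces the frame arithmetic of Whisper's front end. It assumes the constants used by OpenAI's reference implementation (16 kHz audio, 30-second windows, a 400-sample STFT window, and a 160-sample hop, i.e. 10 ms). It then passes a placeholder spectrogram through two convolutions configured like the encoder's, using 16 output channels instead of the real model width to keep it small. Silence stands in for real audio, so only the shapes are meaningful.

```python
import torch
import torch.nn as nn

SAMPLE_RATE = 16_000  # Whisper resamples all audio to 16 kHz
CHUNK_LENGTH = 30     # seconds per input window
N_FFT = 400           # 25 ms analysis window
HOP_LENGTH = 160      # 10 ms hop between frames
N_MELS = 80           # Mel channels in most released checkpoints

# 30 s of silence as a stand-in for real audio
audio = torch.zeros(SAMPLE_RATE * CHUNK_LENGTH)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=torch.hann_window(N_FFT), return_complex=True)
n_frames = stft[..., :-1].shape[-1]  # the last STFT frame is dropped, leaving 3000 frames
print(n_frames)  # 3000

# encoder-style convolutions: the second one has stride 2 and halves the time axis
conv1 = nn.Conv1d(N_MELS, 16, kernel_size=3, padding=1)
conv2 = nn.Conv1d(16, 16, kernel_size=3, stride=2, padding=1)
mel = torch.zeros(1, N_MELS, n_frames)  # placeholder log-Mel spectrogram
out = conv2(conv1(mel))
print(out.shape)  # torch.Size([1, 16, 1500])
```

The 1500 downsampled frames are what the encoder's positional embedding (`n_audio_ctx`) has to match.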

OpenAI's implementation of the Whisper model is long, so in this part of the research we walk through each part of the model code in detail.

In [None]:
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

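This cell imports the libraries the model code depends on. The three relative imports (`.decoding` and `.transcribe`) refer to other modules of the `whisper` package, so this code runs as part of that package rather than as a standalone cell.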

In [None]:
@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

Analyzing the above code:

1. @dataclass: This decorator comes from the dataclasses module in Python's standard library. It automatically generates methods such as __init__ and __repr__ for classes that mainly hold data. Here it turns ModelDimensions into a simple container for the model's hyperparameters.

2. ModelDimensions: This data class stores the sizes of the model's components. n_mels is the number of Mel channels in the input spectrogram. The n_audio_* fields give the audio encoder's context length, hidden width, number of attention heads, and number of layers. n_vocab and the n_text_* fields give the vocabulary size and the same four quantities for the text decoder.

3. LayerNorm: This class is a subclass of nn.LayerNorm, PyTorch's layer normalization module. Its forward method casts the input to float32 with x.float(), applies the standard layer normalization, and casts the result back to the input's dtype with .type(x.dtype). This keeps normalization numerically stable when the rest of the model runs in half precision.

4. Linear: This class is a subclass of nn.Linear, a fully connected (dense) layer. Its forward method casts the weight and the bias (if present) to the input's dtype before calling F.linear. As a result, a layer whose parameters are stored in float32 can process half-precision inputs without a dtype mismatch error.
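The following self-contained sketch restates the three definitions above and checks the mixed-precision behaviour just described. The dimension values are illustrative assumptions in the range of the smallest checkpoint, not values read from a real checkpoint. The sketch uses bfloat16 rather than float16 because half-precision matrix multiplication is not available on CPU in every PyTorch build.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor, nn


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        # normalize in float32, then return the caller's dtype
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        # cast parameters to the input dtype so low-precision inputs work
        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))


# illustrative sizes (assumption)
dims = ModelDimensions(80, 1500, 384, 6, 4, 51865, 448, 384, 6, 4)

x = torch.randn(2, 5, dims.n_text_state, dtype=torch.bfloat16)
ln = LayerNorm(dims.n_text_state)  # parameters stay float32
proj = Linear(dims.n_text_state, 16)  # parameters stay float32
y = proj(ln(x))
print(ln.weight.dtype, proj.weight.dtype, y.dtype)  # torch.float32 torch.float32 torch.bfloat16
```

Keeping the parameters in float32 while computing in the input's precision lets the same checkpoint serve both full- and half-precision inference.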

In [None]:
class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

The provided Python code defines several classes and functions related to neural network modules used in a model. Let's break down the code step by step:

1. Conv1d: This class is a subclass of nn.Conv1d, a 1-dimensional convolutional layer. It overrides the internal _conv_forward method so that the weight and the bias (if not None) are cast to the input's dtype before the convolution is performed. This mirrors the Linear class above and supports half-precision inputs.

2. sinusoids: This function builds the fixed (non-learned) sinusoidal positional embedding used by the audio encoder. It takes three arguments: length (the number of positions), channels (the embedding width, which must be even), and max_timescale (the longest wavelength, 10000 by default). It computes channels/2 inverse timescales spaced geometrically between 1 and 1/max_timescale and multiplies them by each position index. It then concatenates the sines and cosines of the result along the channel dimension, returning a tensor of shape (length, channels).

3. MultiHeadAttention: This class implements multi-head attention and is a subclass of nn.Module. Its constructor takes n_state (the model width) and n_head (the number of attention heads, each of size n_state / n_head). It defines four linear projections: self.query, self.key (without a bias term), self.value, and self.out.

4. Forward method: This method performs the forward pass of the attention layer. It takes four inputs:
   - x, the input sequence;
   - xa, an optional second sequence to attend to (the encoded audio, for cross-attention); when xa is None the layer performs self-attention;
   - mask, an optional additive mask;
   - kv_cache, an optional dictionary of cached key/value tensors.

   The method first projects x to queries. Keys and values are normally projected from x (self-attention) or from xa (cross-attention). The exception is cross-attention whose keys and values are already in kv_cache: these are reused, so the audio features are projected only once per segment. For self-attention with a cache, forward hooks installed by Whisper.install_kv_cache_hooks append the new keys and values to the cached ones. The method then calls qkv_attention, applies the output projection self.out, and returns the result together with the raw attention scores.

5. qkv_attention method: This method performs the core attention computation. It reshapes and permutes q, k, and v to split them into n_head heads of size d = n_state / n_head. Both q and k are multiplied by d^-0.25, so their product is scaled by d^-0.5 = 1/sqrt(d), the usual scaled dot-product factor. It computes the scores q @ k and adds the mask (cropped to the current length n_ctx) if one is given. It then converts the scores to float32, applies softmax, and uses the weights to form a weighted sum of the values, merging the heads back together. It returns this result together with the detached pre-softmax scores qk, which can be used later, for example for word-level timestamps.

Overall, the code defines classes and functions related to convolutional layers (Conv1d), positional embedding generation (sinusoids), and multi-head attention mechanism (MultiHeadAttention) commonly used in neural network models.
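The sketch below makes two of the points above checkable. First, it restates sinusoids (using math.log in place of np.log) and confirms the output shape and the values at position 0. Second, it restates the unmasked path of qkv_attention and compares it with the textbook formula softmax(QK^T / sqrt(d))V on random tensors. The helper names whisper_style_attention and textbook_attention are hypothetical, introduced only for this comparison.

```python
import math

import torch
import torch.nn.functional as F


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Same construction as Whisper's sinusoids(), using math.log instead of np.log."""
    assert channels % 2 == 0
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


def whisper_style_attention(q, k, v, n_head):
    """Mirrors qkv_attention without a mask: q and k are each scaled by d ** -0.25."""
    n_batch, n_ctx, n_state = q.shape
    scale = (n_state // n_head) ** -0.25
    q = q.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 1, 3) * scale
    k = k.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 3, 1) * scale
    v = v.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 1, 3)
    w = F.softmax(q @ k, dim=-1)
    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


def textbook_attention(q, k, v, n_head):
    """softmax(Q K^T / sqrt(d)) V, computed per head."""
    n_batch, n_ctx, n_state = q.shape
    d = n_state // n_head
    qh, kh, vh = (t.view(n_batch, n_ctx, n_head, d).transpose(1, 2) for t in (q, k, v))
    w = F.softmax(qh @ kh.transpose(-2, -1) / math.sqrt(d), dim=-1)
    return (w @ vh).transpose(1, 2).reshape(n_batch, n_ctx, n_state)


pe = sinusoids(1500, 384)
print(pe.shape)  # torch.Size([1500, 384])

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 16) for _ in range(3))
same = torch.allclose(whisper_style_attention(q, k, v, 4), textbook_attention(q, k, v, 4), atol=1e-5)
print(same)  # True
```

At position 0 every angle is zero, so the sine half of the embedding is all zeros and the cosine half is all ones.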

In [None]:
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

The provided Python code defines three classes: ResidualAttentionBlock, AudioEncoder, and TextDecoder. Together they form the encoder-decoder architecture that Whisper uses for speech recognition (audio-to-text).

1. ResidualAttentionBlock: This class represents a residual attention block in the model. It is a subclass of nn.Module. The class constructor takes three arguments: n_state (the input dimension of the attention mechanism), n_head (the number of attention heads), and cross_attention (a boolean flag indicating whether cross-attention is performed). The class defines the following components:

  1.1. self.attn: An instance of the MultiHeadAttention class representing self-attention.

  1.2. self.attn_ln: An instance of the LayerNorm class that normalizes the input to self-attention (a pre-norm design).

  1.3. self.cross_attn (optional): An instance of the MultiHeadAttention class representing cross-attention if cross_attention is True, otherwise None.

  1.4. self.cross_attn_ln (optional): An instance of the LayerNorm class that normalizes the input to cross-attention if cross_attention is True, otherwise None.

  1.5. self.mlp: A two-layer feed-forward network: a Linear layer expanding n_state to 4 × n_state, a GELU activation, and a Linear layer projecting back to n_state.

  1.6. self.mlp_ln: An instance of the LayerNorm class that normalizes the input to the MLP.

  The forward method performs the forward pass of the residual attention block. It takes the input tensor x, an optional tensor xa (the encoded audio), an optional mask, and an optional key-value cache dictionary kv_cache. The block uses a pre-norm residual structure: each sublayer receives a layer-normalized copy of x, and the sublayer's output is added back to the unnormalized x. First, x = x + attn(attn_ln(x)), passing mask and kv_cache to the self-attention. If cross-attention is enabled, x = x + cross_attn(cross_attn_ln(x), xa), where the text attends to the audio features without a mask. Finally, x = x + mlp(mlp_ln(x)), and the resulting x is returned.

2. AudioEncoder: This class represents the audio encoder in the model. It is a subclass of nn.Module. The class constructor takes five arguments: n_mels (the number of mel spectrogram channels), n_ctx (the number of audio frames after convolutional downsampling, which fixes the length of the positional embedding), n_state (the dimension of the hidden state), n_head (the number of attention heads), and n_layer (the number of residual attention blocks). The class defines the following components:

  2.1. self.conv1: An instance of the Conv1d class representing the first convolutional layer applied to the mel spectrogram.

  2.2. self.conv2: An instance of the Conv1d class representing the second convolutional layer applied to the output of the first convolutional layer.

  2.3. self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)): A buffer tensor containing positional embeddings generated using the sinusoids function.

  2.4. self.blocks: A module list containing n_layer instances of the ResidualAttentionBlock class.

  2.5. self.ln_post: An instance of the LayerNorm class representing layer normalization applied to the output of the residual attention blocks.

  The forward method performs the forward pass of the audio encoder. It takes the input tensor x, the mel spectrogram of shape (batch_size, n_mels, frames). It applies the first convolution followed by a GELU activation, then the second (stride-2) convolution followed by another GELU, which halves the number of frames. It then permutes the tensor to (batch_size, frames, n_state). An assertion checks that this shape matches the positional embedding, so the input spectrogram must have the length the model was built for (a full 30-second window in the released models). The positional embeddings are added, each residual attention block in self.blocks is applied in turn, and the final layer normalization (self.ln_post) produces the output.

3. TextDecoder: This class represents the text decoder in the model. It is a subclass of nn.Module. The class constructor takes five arguments: n_vocab (the number of vocabulary tokens), n_ctx (the maximum sequence length), n_state (the dimension of the hidden state), n_head (the number of attention heads), and n_layer (the number of residual attention blocks). The class defines the following components:

  3.1. self.token_embedding: An embedding layer for the text tokens.

  3.2. self.positional_embedding: A learned positional embedding for the text tokens (unlike the encoder's fixed sinusoids). It is allocated with torch.empty and filled when the checkpoint weights are loaded.

  3.3. self.blocks: A module list containing n_layer instances of the ResidualAttentionBlock class with cross_attention set to True.

  3.4. self.ln: An instance of the LayerNorm class representing layer normalization applied to the output of the residual attention blocks.

  3.5. self.register_buffer("mask", mask, persistent=False): A causal mask with -inf above the diagonal and 0 elsewhere. When added to the attention scores, it lets each token attend only to itself and earlier tokens. Because the buffer is non-persistent, it is not saved in the state dict.

  The forward method performs the forward pass of the text decoder. It takes the token ids x, the encoded audio features xa, and an optional key-value cache dictionary kv_cache. When a cache is in use, offset is the number of positions already cached, so the positional embeddings for the new tokens start at the correct position. The method sums the token embeddings (self.token_embedding) and positional embeddings (self.positional_embedding) and casts the result to the dtype of xa. It then applies each residual attention block in self.blocks, passing xa and the mask self.mask, and applies layer normalization (self.ln). Finally, it computes the logits by multiplying the hidden states with the transposed token embedding matrix (self.token_embedding.weight). In other words, the output projection shares (ties) its weights with the input embedding. The logits are returned in float32 with shape (batch_size, n_tokens, n_vocab).
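Two decoder details can be checked in isolation: the causal mask and the tied output projection. The sketch below builds the mask exactly as TextDecoder does and applies softmax to equal raw scores, so only the mask shapes the attention weights. It then computes logits by multiplying with the transpose of an embedding matrix. The sizes (4 positions, width 8, vocabulary 10) are arbitrary and chosen for illustration.

```python
import numpy as np
import torch
import torch.nn.functional as F

n_ctx = 4
# same construction as TextDecoder: -inf strictly above the diagonal, 0 on and below it
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)

# equal raw scores, so only the mask decides where attention can go
scores = torch.zeros(n_ctx, n_ctx)
weights = F.softmax(scores + mask, dim=-1)
print(weights)
# row i spreads its weight uniformly over positions 0..i and gives 0 to later positions

# tied output projection: logits = hidden @ embedding.weight.T
embedding = torch.nn.Embedding(10, 8)  # hypothetical vocabulary of 10, width 8
hidden = torch.randn(1, n_ctx, 8)
logits = hidden @ embedding.weight.T
print(logits.shape)  # torch.Size([1, 4, 10])
```

Tying the weights means no separate output matrix has to be stored or trained.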

In [None]:
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

The provided Python code defines a class called Whisper, a subclass of nn.Module. It represents the complete Whisper model, used for speech recognition: converting audio input into text output.

Here is a breakdown of the code:

1. Initialization:

  1.1. The __init__ method takes an input argument dims, which is an instance of the ModelDimensions class. This class contains various dimensions and parameters related to the model.

  1.2. The method initializes the parent class nn.Module using super().__init__().

  1.3. It assigns the dims argument to an instance variable self.dims for later use.

  1.4. It creates an instance of the AudioEncoder class, passing the necessary dimensions from dims, and assigns it to self.encoder.

  1.5. It creates an instance of the TextDecoder class, passing the necessary dimensions from dims, and assigns it to self.decoder.

  1.6. It creates a boolean tensor of shape (n_text_layer, n_text_head) that marks which decoder attention heads are used for time alignment (for example, word-level timestamps). By default, all heads in the second half of the decoder layers are set to True and the rest to False. The tensor is converted to sparse format and registered as a non-persistent buffer, self.alignment_heads.

2. Setting Alignment Heads:

  2.1. The set_alignment_heads method takes a byte string dump as input.

  2.2. It first decodes the base85-encoded string and then decompresses the result with gzip.

  2.3. It converts the resulting array into a boolean array and reshapes it to (n_text_layer, n_text_head).

  2.4. It converts the mask to sparse format and registers it as the new self.alignment_heads buffer, replacing the default.

3. Audio Embedding:

  3.1. The embed_audio method takes a tensor mel representing mel spectrogram as input.

  3.2. It passes the mel tensor through the self.encoder and returns the result.

4. Logits Calculation:

  4.1. The logits method takes two tensors tokens and audio_features as input.

  4.2. It passes the tokens and audio_features tensors through the self.decoder and returns the result.

5. Forward Pass:

  5.1. The forward method takes two tensors mel and tokens as input.

  5.2. It passes the mel tensor through the self.encoder to obtain audio features.

  5.3. It passes the tokens and audio features through the self.decoder to obtain the output logits.

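  5.4. It returns the logits tensor of shape (batch_size, n_tokens, n_vocab). Despite the Dict[str, torch.Tensor] return annotation, the method returns a plain tensor, not a dictionary.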
  
6. Properties:

  6.1. The device property returns the device of the model parameters.

  6.2. The is_multilingual property returns a boolean indicating whether the model supports multiple languages based on the vocabulary size.

  6.3. The num_languages property returns the number of languages supported by the model based on the vocabulary size.

7. Key-Value Cache Hooks:

  7.1. The install_kv_cache_hooks method is used to install hooks for caching intermediate tensors during the key-value projection in the MultiHeadAttention module.

  7.2. It takes an optional cache dictionary as input, which stores the key-value caches.

  7.3. It makes a shallow copy of the given cache dictionary, or starts from an empty dictionary if cache is None.

  7.4. It defines two nested functions: save_to_cache and install_hooks.

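  7.5. The save_to_cache function is a forward hook that stores the output of a key or value projection in the cache. On the first call for a module, or when the output is longer than n_text_ctx (which happens only for cross-attention over the audio features), the output is stored as-is. Otherwise the new keys or values are concatenated to the cached ones along the sequence dimension. The hook returns the cached tensor, which replaces the projection's output.

  7.6. The install_hooks function is applied to every submodule of self.decoder. For each MultiHeadAttention module it registers save_to_cache as a forward hook on the key and value projections.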

  7.7. The method applies the install_hooks function to the self.decoder module and returns the cache dictionary and a list of hooks.
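The hook mechanism from item 7 can be tried on a single projection layer. The sketch below is a simplified version of save_to_cache: it omits the cross-attention branch, in which Whisper stores the audio keys once and never concatenates to them. It shows how returning a value from a forward hook replaces the layer's output, so each call sees all keys computed so far. The layer sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
cache = {}


def save_to_cache(module, _inputs, output):
    """Simplified version of Whisper's hook: append new keys along the sequence axis."""
    if module not in cache:
        cache[module] = output  # first call: store as-is
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]  # returning a value replaces the module's output


key_proj = nn.Linear(8, 8, bias=False)
handle = key_proj.register_forward_hook(save_to_cache)

prompt_keys = key_proj(torch.randn(1, 3, 8))  # 3 prompt tokens
step_keys = key_proj(torch.randn(1, 1, 8))    # 1 newly generated token
print(prompt_keys.shape, step_keys.shape)  # torch.Size([1, 3, 8]) torch.Size([1, 4, 8])

handle.remove()  # after removal the layer behaves normally again
plain = key_proj(torch.randn(1, 1, 8))
print(plain.shape)  # torch.Size([1, 1, 8])
```

Returning the handles is what lets callers remove the hooks once decoding is finished.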

The last three lines attach detect_language, transcribe, and decode as methods of the Whisper class. These functions are imported at the top of the file from the package's decoding and transcribe modules, and they implement language identification, full-length transcription, and decoding of a single 30-second segment for this speech recognition model.
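Two smaller pieces of the Whisper class can be checked on their own. The sketch below builds a hypothetical alignment-head mask for a 4-layer, 6-head decoder and encodes it with gzip and base85, the format set_alignment_heads expects. It then decodes the mask with the same steps as that method. It also reproduces the arithmetic of the num_languages property for two vocabulary sizes; the sizes are examples, not values read from a checkpoint.

```python
import base64
import gzip

import numpy as np
import torch

n_text_layer, n_text_head = 4, 6  # hypothetical decoder size

# build a hypothetical alignment-head mask and encode it for set_alignment_heads
heads = np.zeros((n_text_layer, n_text_head), dtype=bool)
heads[2, 3] = True
heads[3, 1] = True
dump = base64.b85encode(gzip.compress(heads.tobytes()))

# decoding steps as in set_alignment_heads: base85 decode, then gzip decompress
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
mask = torch.from_numpy(array).reshape(n_text_layer, n_text_head)
print(mask.nonzero().tolist())  # [[2, 3], [3, 1]]  (layer, head) pairs


def num_languages(n_vocab: int) -> int:
    """Same arithmetic as the Whisper.num_languages property."""
    is_multilingual = n_vocab >= 51865
    return n_vocab - 51765 - int(is_multilingual)


print(num_languages(51865), num_languages(51866))  # 99 100
```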

## Wav2Vec2

Wav2Vec2 (wav2vec 2.0) is a self-supervised approach to speech representation learning and automatic speech recognition (ASR) developed by Facebook AI Research. Hugging Face provides the widely used implementation in its Transformers library. The model learns from unlabeled audio first and is then fine-tuned to transcribe speech into text.

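Unlike traditional ASR models that rely entirely on supervised learning with paired audio-text data, Wav2Vec2 is first pretrained on raw audio without any transcripts. During pretraining, spans of the latent speech representations are masked. For each masked time step, the model must identify the true quantized latent representation among distractors sampled from other masked time steps (a contrastive task). This differs from the Contrastive Predictive Coding (CPC) style objective of the original wav2vec, which predicts future frames. After pretraining, the model is fine-tuned on labeled data with a Connectionist Temporal Classification (CTC) loss.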
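The core of wav2vec 2.0's contrastive pretraining objective can be sketched for a single masked time step. The context vector from the Transformer has to score the true quantized latent higher than K distractors, using cosine similarity divided by a temperature. This is a simplified illustration: the temperature default of 0.1 is an assumption, and the masking procedure, the Gumbel-softmax quantizer, and the codebook diversity loss are all omitted.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(context, positive, distractors, temperature=0.1):
    """InfoNCE-style loss for one masked time step (simplified sketch).

    context:     (d,)   Transformer output at the masked step
    positive:    (d,)   quantized latent of the true frame
    distractors: (K, d) quantized latents taken from other masked steps
    """
    candidates = torch.cat([positive.unsqueeze(0), distractors], dim=0)  # (K + 1, d)
    sims = F.cosine_similarity(context.unsqueeze(0), candidates, dim=-1) / temperature
    # the true latent sits at index 0, so this is cross-entropy with target class 0
    return F.cross_entropy(sims.unsqueeze(0), torch.tensor([0]))


torch.manual_seed(0)
d, K = 16, 5
positive = torch.randn(d)
distractors = torch.randn(K, d)

good = contrastive_loss(positive + 0.01 * torch.randn(d), positive, distractors)
bad = contrastive_loss(distractors[0], positive, distractors)
print(good.item() < bad.item())  # True
```

A context vector close to the true latent yields a small loss; one that matches a distractor yields a large loss.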

The central idea of Wav2Vec2 is to learn contextualized speech representations directly from raw audio. A convolutional feature encoder turns the waveform into a sequence of latent vectors (roughly one every 20 ms at 16 kHz). A quantization module maps these latents to discrete codebook entries, which serve as targets during pretraining. Finally, a Transformer context network aggregates information across the whole sequence.

These contextualized representations capture phonetic information that is useful for recognition, so the model can be fine-tuned with relatively little labeled data. The original paper reports competitive word error rates on LibriSpeech after fine-tuning on as little as ten minutes of labeled audio. In practice, this makes pretrained checkpoints a convenient starting point for adapting to new languages, accents, and speaking styles.

Beyond speech recognition, pretrained Wav2Vec2 representations have been applied to other speech tasks, such as speaker and language identification.
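The "one latent vector every 20 ms" figure follows from the shape of the convolutional feature encoder. The sketch below assumes the default conv_kernel and conv_stride values of Hugging Face's Wav2Vec2Config (seven layers, no padding) and assumes the checkpoint used here keeps those defaults. It computes the number of output frames for one second of 16 kHz audio, the receptive field, and the total hop in samples.

```python
# default feature-encoder layout in Hugging Face's Wav2Vec2Config (assumed unchanged here)
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def output_frames(num_samples: int) -> int:
    """Length after each unpadded convolution: floor((L - k) / s) + 1."""
    length = num_samples
    for k, s in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - k) // s + 1
    return length


def receptive_field_and_hop() -> tuple:
    """Samples seen by one output frame, and samples between consecutive frames."""
    rf, hop = 1, 1
    for k, s in zip(CONV_KERNEL, CONV_STRIDE):
        rf += (k - 1) * hop
        hop *= s
    return rf, hop


print(output_frames(16_000))      # 49 frames for one second of audio
print(receptive_field_and_hop())  # (400, 320): 25 ms window, 20 ms hop at 16 kHz
```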

In this research we fine-tuned a pretrained Wav2Vec2 model not only on the competition data but also on additional datasets: Shrutilipi, MADASR and ULCA.

As with Whisper, we analyze the code we wrote step by step.

In [None]:
import re
import os
from utils import *  # project-specific helpers
import pandas as pd
import json
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC

if __name__ == "__main__":

    data_path = "/path_to_training_data"
    pretrained_dir = "pretrained/ai4bharat/indicwav2vec_v1_bengali/"

    # download the pretrained processor and model, then save local copies
    processor = Wav2Vec2Processor.from_pretrained(
        "ai4bharat/indicwav2vec_v1_bengali"
    )
    model = Wav2Vec2ForCTC.from_pretrained("ai4bharat/indicwav2vec_v1_bengali")

    processor.save_pretrained(pretrained_dir)
    model.save_pretrained(pretrained_dir)

    # load the processed transcripts of every training corpus
    train_df = pd.read_csv(os.path.join(data_path, "train_comp_processed.csv"))
    train_shru = pd.read_csv(os.path.join(data_path, "train_shru_processed.csv"))
    train_respin = pd.read_csv(os.path.join(data_path, "train_respin_processed.csv"))
    train_ulca = pd.read_csv(os.path.join(data_path, "train_ulca_processed.csv"))

    all_data = pd.concat([train_df, train_shru, train_respin, train_ulca], ignore_index=True)
    all_data = all_data.dropna(subset=["sentence"])

    # collect every distinct character that appears in the transcripts
    vocab_list = set()
    for text in all_data["sentence"]:
        vocab_list.update(text)

    # load the existing character vocabulary (character -> token id)
    vocab_path = os.path.join(pretrained_dir, "vocab.json")
    with open(vocab_path, encoding="utf-8") as fp:
        old_vocab = json.load(fp)

    # characters missing from the pretrained vocabulary; the space is skipped because
    # Wav2Vec2CTCTokenizer marks word boundaries with its delimiter token ("|" by default).
    # Sorting makes the new ids reproducible across runs (set iteration order is not).
    new_vocab = sorted(vocab_list - set(old_vocab.keys()) - {" "})

    # append the new characters after the existing ids
    len_old_vocab = len(old_vocab)
    for k, char in enumerate(new_vocab):
        old_vocab[char] = k + len_old_vocab

    print(len_old_vocab, len(old_vocab))
    print(new_vocab)

    with open(vocab_path, "w", encoding="utf-8") as fp:
        json.dump(old_vocab, fp, ensure_ascii=False)

This script prepares the pretrained "indicwav2vec_v1_bengali" checkpoint from the AI4Bharat project, a Wav2Vec2 ASR model for Bengali, for fine-tuning on our data by extending its character vocabulary.

Let's break down the code:

1. Imports: The script imports necessary libraries/modules like re, os, utils, pandas, json, and components from the transformers library (Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC).

2. Pretrained Model Initialization: It initializes a Wav2Vec2 processor and model using the "ai4bharat/indicwav2vec_v1_bengali" pretrained weights and configurations.

3. Saving Pretrained Models: The script saves the processor and model to pretrained/ai4bharat/indicwav2vec_v1_bengali/. This lets the local vocab.json be edited and lets training load the checkpoint from disk.

4. Data Loading: It loads training data from multiple CSV files (train_comp_processed.csv, train_shru_processed.csv, train_respin_processed.csv, train_ulca_processed.csv) located in the specified data path.

5. Data Concatenation and Cleaning: It concatenates all the loaded dataframes into a single dataframe (all_data). It then drops rows with missing values in the 'sentence' column.

6. Vocabulary Preparation: It builds the set of unique characters (vocab_list) in the 'sentence' column and compares it with the pretrained vocabulary loaded from vocab.json. Characters not yet in the vocabulary form new_vocab. The space is skipped because the CTC tokenizer represents word boundaries with its word-delimiter token.

7. Updating Vocabulary: Each new character is appended with the next free index. This keeps the ids of existing tokens unchanged, along with the output-layer rows that correspond to them.

8. Saving the Vocabulary: Finally, it overwrites vocab.json with the extended vocabulary. The model's output layer still has the old size at this point; it is enlarged later with resize_lm_head.
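The vocabulary extension can be checked on a toy example. `extend_vocab` below is a hypothetical helper that condenses the script's logic into one function, and the toy vocabulary is invented. Like the script, it assumes the existing ids are exactly 0 to len(old_vocab) - 1.

```python
def extend_vocab(old_vocab: dict, transcripts: list, word_delimiter: str = " "):
    """Append characters from `transcripts` that are missing from `old_vocab` (hypothetical helper).

    Existing ids are left untouched; new characters receive the next free ids in sorted order.
    """
    chars = sorted(set("".join(transcripts)))
    # The space is not a token: the CTC tokenizer replaces it with a word-delimiter token.
    new_chars = [c for c in chars if c not in old_vocab and c != word_delimiter]
    extended = dict(old_vocab)
    for offset, c in enumerate(new_chars):
        extended[c] = len(old_vocab) + offset
    return extended, new_chars


old_vocab = {"<pad>": 0, "<unk>": 1, "|": 2, "a": 3, "b": 4}  # toy vocabulary
extended, added = extend_vocab(old_vocab, ["ab c", "cab d"])
print(added)  # ['c', 'd']
```

New characters are appended rather than the whole vocabulary being re-sorted. This keeps the pretrained output-layer rows aligned with their original characters.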

In [None]:
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import Wav2Vec2ForCTC
from transformers.modeling_outputs import CausalLMOutput

# Index at which the hidden states start in the tuple returned by Wav2Vec2Model when
# return_dict=False (mirrors the private constant in transformers' modeling_wav2vec2).
_HIDDEN_STATES_START_POSITION = 2

class Wav2Vec2ForCTCV2(Wav2Vec2ForCTC):
    def __init__(self, config):
        super().__init__(config)

    def resize_lm_head(self, new_num_tokens):
        """Replace the CTC output layer with one sized for `new_num_tokens` tokens.

        Rows of tokens present in both vocabularies are copied from the old layer;
        rows of newly added tokens keep the model's default initialization.
        """
        old_lm_head = self.lm_head
        old_num_tokens, old_lm_head_dim = old_lm_head.weight.size()
        has_bias = old_lm_head.bias is not None

        # Build and initialize the new head on the same device and dtype.
        new_lm_head = nn.Linear(old_lm_head_dim, new_num_tokens, bias=has_bias)
        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
        self._init_weights(new_lm_head)

        # Copy the weights of the tokens that already existed.
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        if has_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
        self.lm_head = new_lm_head


class Wav2Vec2ForCTCV3(Wav2Vec2ForCTC):
    def __init__(self, config):
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head_inter = nn.Linear(output_hidden_size, config.vocab_size)
        # self.dropout_inter = nn.Dropout(config.final_dropout)

    def resize_lm_head(self, new_num_tokens):
        """Replace the CTC output layer with one sized for `new_num_tokens` tokens.

        Rows of tokens present in both vocabularies are copied from the old layer;
        rows of newly added tokens keep the model's default initialization. The
        intermediate head is pointed at the same resized layer, so both heads share weights.
        """
        old_lm_head = self.lm_head
        old_num_tokens, old_lm_head_dim = old_lm_head.weight.size()
        has_bias = old_lm_head.bias is not None

        # Build and initialize the new head on the same device and dtype.
        new_lm_head = nn.Linear(old_lm_head_dim, new_num_tokens, bias=has_bias)
        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
        self._init_weights(new_lm_head)

        # Copy the weights of the tokens that already existed.
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        if has_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
        self.lm_head = new_lm_head
        self.lm_head_inter = new_lm_head

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The intermediate CTC head needs the per-layer hidden states, so they are always requested.
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)

        # outputs[2] holds all hidden states (the input to the first transformer layer,
        # then one entry per layer), so index 12 is the output of the 12th layer.
        hidden_states_inter = outputs[2][12]
        logits_inter = self.lm_head_inter(hidden_states_inter)

        loss = None
        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
            log_probs_inter = nn.functional.log_softmax(logits_inter, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )
                loss_inter = nn.functional.ctc_loss(
                    log_probs_inter,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )
                # Interpolate the final-layer and intermediate (layer-12) CTC losses.
                loss = 0.7 * loss + 0.3 * loss_inter

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

This code defines two custom classes, Wav2Vec2ForCTCV2 and Wav2Vec2ForCTCV3, which extend Wav2Vec2ForCTC from the Hugging Face transformers library. Both add a method for resizing the CTC output layer, and Wav2Vec2ForCTCV3 also adds an intermediate CTC loss.

Wav2Vec2ForCTCV2 adds a resize_lm_head method. Despite its name, lm_head is not a language model. It is the final linear layer that projects each frame's hidden state onto the character vocabulary to produce CTC logits. resize_lm_head builds a new linear layer with new_num_tokens outputs, copies the rows of the tokens shared with the old layer, and leaves the rows of newly added tokens at their default initialization. This is what lets the pretrained model use the vocabulary extended in the previous step.
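The row-copying logic of resize_lm_head can be illustrated without PyTorch. The NumPy sketch below uses a hypothetical helper, `resize_output_head`, which is not part of the project code. It assumes the weight matrix has shape (num_tokens, hidden_size), as in `nn.Linear`. It also assumes new rows are drawn from a normal distribution with standard deviation 0.02, which we understand to be Wav2Vec2's default `config.initializer_range`.

```python
import numpy as np


def resize_output_head(weight, bias, new_num_tokens, rng, std=0.02):
    """Return a resized (weight, bias) pair whose first rows copy the old output layer."""
    old_num_tokens, hidden_size = weight.shape
    # Fresh rows for every token; the rows of existing tokens are overwritten below.
    new_weight = rng.normal(0.0, std, size=(new_num_tokens, hidden_size))
    n = min(old_num_tokens, new_num_tokens)
    new_weight[:n] = weight[:n]

    new_bias = None
    if bias is not None:
        new_bias = np.zeros(new_num_tokens, dtype=bias.dtype)  # biases start at zero
        new_bias[:n] = bias[:n]
    return new_weight, new_bias


rng = np.random.default_rng(0)
w, b = rng.normal(size=(3, 4)), rng.normal(size=3)  # toy head: 3 tokens, hidden size 4
new_w, new_b = resize_output_head(w, b, 5, rng)
```

Copying only `min(old, new)` rows means the same function also handles shrinking a vocabulary. In that case the trailing tokens are dropped.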

Wav2Vec2ForCTCV3 has the same resize_lm_head method, which also points lm_head_inter at the resized head so that the two heads share weights. It overrides forward to add an intermediate CTC loss. The hidden states of the 12th transformer layer are passed through lm_head_inter, and a second CTC loss is computed against the same labels. The training loss is 0.7 × (final-layer loss) + 0.3 × (intermediate loss). The auxiliary loss is intended to push the middle layers toward representations that are already useful for transcription. Only the final-layer logits are returned.
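The intermediate loss in Wav2Vec2ForCTCV3 is an ordinary CTC loss applied to a second set of logits. The NumPy sketch below is for illustration only; in the model, `torch.nn.functional.ctc_loss` does this work. `ctc_nll` computes the CTC negative log-likelihood of one utterance with the standard forward (alpha) recursion over the blank-augmented label sequence. `inter_ctc_loss` combines the final-layer and intermediate losses with the same 0.7/0.3 weights used in forward. Both function names are hypothetical. The sketch returns a per-utterance value without the `ctc_loss_reduction="mean"` normalization that the training script configures.

```python
import numpy as np


def ctc_nll(log_probs: np.ndarray, target: list, blank: int = 0) -> float:
    """CTC negative log-likelihood of `target` for one utterance.

    log_probs: (T, V) per-frame log-probabilities; target: label ids without blanks.
    """
    T = log_probs.shape[0]
    # Blank-augmented labels: [b, y1, b, y2, ..., yL, b].
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S = len(ext)

    # alpha[t, s]: log-probability of all alignment prefixes of length t+1 ending at ext[s].
    alpha = np.full((T, S), -np.inf)
    alpha[0, 0] = log_probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            paths = [alpha[t - 1, s]]                   # stay on the same symbol
            if s >= 1:
                paths.append(alpha[t - 1, s - 1])       # move to the next symbol
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                paths.append(alpha[t - 1, s - 2])       # skip a blank between different labels
            alpha[t, s] = np.logaddexp.reduce(paths) + log_probs[t, ext[s]]

    # A valid alignment ends on the last label or on the trailing blank.
    if S == 1:
        log_likelihood = alpha[T - 1, 0]
    else:
        log_likelihood = np.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2])
    return float(-log_likelihood)


def inter_ctc_loss(final_log_probs, inter_log_probs, target, blank=0, inter_weight=0.3):
    """Interpolate final-layer and intermediate-layer CTC losses (0.7/0.3 by default)."""
    final_loss = ctc_nll(final_log_probs, target, blank)
    inter_loss = ctc_nll(inter_log_probs, target, blank)
    return (1.0 - inter_weight) * final_loss + inter_weight * inter_loss


uniform = np.log(np.full((2, 2), 0.5))  # 2 frames, vocabulary {blank=0, label=1}
print(round(ctc_nll(uniform, [1]), 4))  # 0.2877, i.e. -ln(0.75): paths (1,1), (0,1), (1,0)
```

A target that needs more frames than are available gets an infinite loss, such as [1, 1] over two frames, which needs a blank between the repeated labels. The training script sets `ctc_zero_infinity = True` so that such samples do not break training.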

Wav2Vec2ForCTCV2 keeps the forward method of the base class, while Wav2Vec2ForCTCV3 reimplements it. In both cases the forward pass computes logits from the encoder output and computes the CTC loss when labels are given. It returns the loss, the logits, and optionally the hidden states and attentions.

Note that resize_lm_head only replaces the layer. model.config.vocab_size still holds the old size and must be updated separately, because the label check in forward uses it. The training script does this right after resizing.

In [None]:
import os
import re

import librosa
import numpy as np
import pandas as pd
from bnunicodenormalizer import Normalizer
from tqdm.auto import tqdm

bnorm = Normalizer()

# Punctuation to strip from transcripts (includes the Bengali danda "।").
chars_to_ignore_regex = r"[।,?!;:\"—‘'‚“”…]"
data_path = "/path_to_training_data"
audio_dir = data_path + "/train_mp3s/"
num_workers = 16

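def get_audio_length(audio_path):
    """Return the number of samples in the clip after resampling to 16 kHz."""
    w = librosa.load(audio_path, sr=16_000, mono=False)[0]
    if w.ndim == 2:
        # Multi-channel file: the shape is (channels, samples).
        print(audio_path, w.shape)
    return w.shape[-1]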

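def remove_special_characters(text):
    # Replace punctuation with spaces, then collapse repeated whitespace.
    text = re.sub(chars_to_ignore_regex, ' ', text)
    text = ' '.join(text.split())
    return text

def normalize(text):
    # Normalize word by word; the normalizer returns None for words it cannot handle.
    _words = [bnorm(word)['normalized'] for word in text.split()]
    text = " ".join([word for word in _words if word is not None])
    return text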

# Prefilter data according to the hosts' recommendation.
# Keep clips whose NISQA predicted MOS is above 1.5.
df_qa = pd.read_csv(os.path.join(data_path, "NISQA_wavfiles.csv"), sep=",")
df_qa.rename(columns={'deg': 'id'}, inplace=True)            # rename to match the other dataframes
df_qa['id'] = df_qa['id'].apply(lambda x: x.split('.')[0])   # strip the ".wav" extension
df_qa = df_qa[df_qa.mos_pred > 1.5]

df = pd.read_csv(os.path.join(data_path, "train_metadata_corrected.csv"), sep=",")
df["path"] = df['id'].apply(lambda x: audio_dir + x + ".mp3")
df.set_index('id', inplace=True)
df = df.join(df_qa.set_index('id'), how='inner')   # keep only clips that passed the MOS filter

# Drop rows without a 'yellowking_preds' value and keep rows with ykg_wer below 3.
df = df.dropna(subset=['yellowking_preds'])
df = df[df.ykg_wer < 3]

# Normalize first, then strip punctuation from the normalized text.
df["sentence"] = [normalize(x) for x in tqdm(df["sentence"])]
df["sentence"] = [remove_special_characters(x) for x in tqdm(df["sentence"])]
df["length"] = [get_audio_length(x) for x in tqdm(df["path"])]

df = df[["path", "sentence", "length"]]
df.to_csv(os.path.join(data_path, "train_comp_processed.csv"), index=False)

What does the code above do?

1. Functions:

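  1.1. get_audio_length(audio_path): Loads an audio file with librosa at 16 kHz and returns its length in samples.

  1.2. remove_special_characters(text): Replaces the punctuation marks in chars_to_ignore_regex (including the Bengali danda "।") with spaces and collapses repeated whitespace.

  1.3. normalize(text): Normalizes each word with the Bengali Unicode normalizer from bnunicodenormalizer and drops words for which the normalizer returns None.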

2. Preprocessing Steps:

  2.1. Loading and Filtering Data:

  Reads two CSV files, NISQA_wavfiles.csv and train_metadata_corrected.csv.
  Keeps only clips whose NISQA predicted mean opinion score (mos_pred) is above 1.5, inner-joins the two tables on the clip ID, drops rows without a 'yellowking_preds' value, and keeps rows whose 'ykg_wer' is below 3.

  2.2. Data Normalization:

  Normalizes the sentences in the DataFrame using the defined normalize function.

  2.3. Removing Special Characters:

  Removes specific characters (specified in the chars_to_ignore_regex pattern) from the preprocessed sentences.

  2.4. Audio Length Calculation:

  Calculates the lengths of audio files using the get_audio_length function.

  2.5. Data Preparation and Export:

  Selects specific columns ("path", "sentence", "length") from the DataFrame.
  Saves the processed DataFrame as a CSV file named "train_comp_processed.csv" in the specified data path.
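The text-cleaning step can be tried in isolation. The snippet below copies chars_to_ignore_regex and remove_special_characters from the cell above; the only change is that the pattern is precompiled. This lets you check the behavior on a sample sentence without the audio files or the Bengali normalizer.

```python
import re

# Same character class as chars_to_ignore_regex above.
CHARS_TO_IGNORE = re.compile(r"[।,?!;:\"—‘'‚“”…]")


def remove_special_characters(text: str) -> str:
    text = CHARS_TO_IGNORE.sub(" ", text)   # punctuation becomes a space
    return " ".join(text.split())           # collapse runs of whitespace


print(remove_special_characters("আমি ভাত খাই। তুমি?"))  # আমি ভাত খাই তুমি
```

Punctuation is replaced with a space rather than deleted. This means a danda written without a following space ("খাই।তুমি") still separates the two words instead of merging them.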

In [None]:
import torch
import os
import numpy as np
import pandas as pd
import random
import wandb
from tqdm.auto import tqdm
import librosa
import warnings
import argparse
warnings.filterwarnings("ignore")

import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from torch.utils.data import DataLoader, IterableDataset
from functools import partial

from torch.utils.data import Dataset

import transformers
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, AutoProcessor, AutoConfig, AutoFeatureExtractor, AutoModelForSpeechSeq2Seq, AutoTokenizer
from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments, HfArgumentParser, set_seed
from datasets import load_dataset, load_metric, Audio
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Any, Union, Optional
from models import Wav2Vec2ForCTCV2
from audiomentations import *

os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_PROJECT"] = "b_speech"

# Training config class.
class CFG:
    dns_noise_path = "/path_to_DNS_challenge_noise"
    musan_path = "/path_to_MUSAN"

    pretrained_path = "pretrained/ai4bharat/indicwav2vec_v1_bengali"
    save_dir_stage_1 = "ckpt_stage1"
    save_dir_stage_2 = "ckpt_stage2"
    save_dir_stage_3 = "ckpt_stage3"

    sample_rate = 16000
    epochs = 5
    lr = 4e-5

    # Dropout configs for pretrained wav2vec2 model.
    attention_dropout = 0.1
    hidden_dropout = 0.1
    feat_proj_dropout = 0.1
    mask_time_prob = 0.1
    layerdrop = 0.1
    mask_feature_prob = 0.05

    max_input_length_in_sec = 13.0
    min_input_length_in_sec = 2.0

    # Trainer arguments. Note: epochs and lr are read once, when the class body runs;
    # changing CFG.epochs or CFG.lr afterwards does not update this object.
    trainer = TrainingArguments(
        output_dir="weights",
        group_by_length=False,
        length_column_name="input_length",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=1,
        num_train_epochs=epochs,
        gradient_checkpointing=False,
        fp16=True,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        logging_steps=100,  # number of batches after which to log metrics from the model
        report_to="wandb",
        learning_rate=lr,
        weight_decay=0.0025,
        dataloader_num_workers=16,
        warmup_ratio=0.1,
        save_total_limit=5,
        push_to_hub=False,
        load_best_model_at_end=True,
        greater_is_better=False,
        metric_for_best_model='eval_wer',
        lr_scheduler_type="cosine",
        remove_unused_columns=False,
    )

augment_read_speech = Compose([
    TimeStretch(min_rate=0.8, max_rate=2.0, p=0.5, leave_length_unchanged=False),
    RoomSimulator(
        p=0.3),
    OneOf([
        AddBackgroundNoise(
            sounds_path=[
                CFG.dns_noise_path,
            ],
            min_snr_in_db=5.0,
            max_snr_in_db=30.0,
            noise_transform=PolarityInversion(),
            p=1.0
        ),
        AddBackgroundNoise(
            sounds_path=[
                CFG.musan_path,
            ],
            min_snr_in_db=5.0,
            max_snr_in_db=30.0,
            noise_transform=PolarityInversion(),
            p=1.0
        ),
        AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0),
    ], p=0.7),
    Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.2),
    ])

augment_spontaneous_speech = Compose([
    TimeStretch(min_rate=0.8, max_rate=1.1, p=0.3, leave_length_unchanged=False),
    RoomSimulator(
        p=0.3),
    OneOf([
        AddBackgroundNoise(
            sounds_path=[
                CFG.dns_noise_path,
            ],
            min_snr_in_db=5.0,
            max_snr_in_db=30.0,
            noise_transform=PolarityInversion(),
            p=1.0
        ),
        AddBackgroundNoise(
            sounds_path=[
                CFG.musan_path,
            ],
            min_snr_in_db=5.0,
            max_snr_in_db=30.0,
            noise_transform=PolarityInversion(),
            p=1.0
        ),
        AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0),
    ], p=0.3),
    Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.2),
    ])

class BengaliDataset(Dataset):

    def __init__(self, config, df, processor, split):
        self.df = df
        self.cfg = config
        self.paths = df['path']
        self.sentences = df['sentence']
        self.sources = df['source']
        self.lengths = df['length'].to_numpy()
        self.len = len(self.df)
        self.sr = 16_000

        self.processor = processor
        self.split = split

    def __len__(self):
        return self.len

    def load_audio(self, idx):
        idx %= len(self.df)
        audio_path = self.paths[idx]
        sentence = self.sentences[idx]
        source = self.sources[idx]
        num_frames = self.lengths[idx]
        concat_augment=False

        wav = librosa.load(audio_path, sr=self.sr, mono=False)[0]
        wav = np.trim_zeros(wav, 'fb')

        if self.split == "train":
            # Concatenation augmentation: with probability 0.5, a clip shorter than half the
            # maximum length is joined with a random clip that keeps the total under the maximum.
            if (num_frames < (self.cfg.max_input_length_in_sec * self.sr / 2)) and (np.random.uniform() < 0.5):
                num_frames_concat = self.cfg.max_input_length_in_sec * self.sr - num_frames
                possible_indexes = np.where(self.lengths < num_frames_concat)[0]
                if len(possible_indexes) > 0:
                    concat_augment = True
                    selected_index = np.random.choice(possible_indexes)
                    audio_path_concat = self.paths[selected_index]
                    sentence_concat = self.sentences[selected_index]
                    wav_concat = librosa.load(audio_path_concat, sr=self.sr, mono=False)[0]
                    wav_concat = np.trim_zeros(wav_concat, 'fb')
                    wav = np.concatenate((wav, wav_concat))
                    sentence = sentence + " " + sentence_concat

            try:
                if source == "spontaneous":
                    wav = augment_spontaneous_speech(samples=wav, sample_rate=self.sr)
                else:
                    wav = augment_read_speech(samples=wav, sample_rate=self.sr)
            except Exception as e:
                # Keep the un-augmented waveform if augmentation fails.
                print(audio_path, e)

        wav = np.expand_dims(wav, axis=0)

        input_values = self.processor(wav, sampling_rate=self.sr).input_values[0]

        input_length = len(input_values)
        # Tokenize the transcript into character ids (as_target_processor() is deprecated).
        labels = self.processor.tokenizer(sentence).input_ids

        return {
            'input_values':input_values,
            'input_length':input_length,
            'labels':labels
        }

    def __getitem__(self, idx):
        if idx >= self.len:
            raise IndexError(f'index {idx} out of range {self.len}')
        return self.load_audio(idx)

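This cell holds the first half of the training script for fine-tuning Wav2Vec2 on Bengali speech with PyTorch and Hugging Face Transformers: configuration, audio augmentation, and the dataset class. The collator, metrics, and training loop follow in the next cell. Here's a detailed explanation: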

1. Libraries and Setup:
The code imports various libraries required for audio processing, data handling, model training, and augmentation, including torch, pandas, wandb, librosa, and others.

2. Configuration Class:
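The CFG class holds the configuration: paths to the noise datasets (DNS challenge, MUSAN) and to the checkpoints of each stage, sample rate, epochs, learning rate, dropout and masking settings for the pretrained model, the allowed input length (2 to 13 seconds), and the TrainingArguments passed to the Trainer.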

3. Augmentation Setup:
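Defines two augmentation pipelines with the audiomentations library. Each applies time stretching, room simulation, one of DNS-challenge noise, MUSAN noise, or Gaussian noise, and a random gain. Read speech (augment_read_speech) gets the stronger settings: time stretch from 0.8 to 2.0 with p=0.5, and noise with p=0.7. Spontaneous speech (augment_spontaneous_speech) gets milder ones: time stretch from 0.8 to 1.1 with p=0.3, and noise with p=0.3.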

4. Custom Dataset Class:

  4.1. BengaliDataset is a custom dataset class inheriting from torch.utils.data.Dataset.

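  4.2. It loads each audio clip, augments it if the split is "train", and converts it into model inputs. Training clips first go through concatenation augmentation. With probability 0.5, a clip shorter than half the maximum length is joined with a randomly chosen clip that keeps the total under the maximum, and the two transcripts are joined with a space. The result is then passed through the augmentation pipeline that matches the clip's source.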

  4.3. __len__() returns the length of the dataset.

  4.4. load_audio() loads and preprocesses audio files, handling augmentation, concatenation, and tokenization.

  4.5. __getitem__() returns a dictionary containing input values, input length, and labels for the model.

5. Data Preprocessing and Augmentation:

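  5.1. Audio files are loaded with librosa at 16 kHz, and leading and trailing zero-valued samples are removed with np.trim_zeros. Training clips are then augmented as described above.

  5.2. The processor's feature extractor turns each waveform into input_values, and its tokenizer maps each sentence to a sequence of character ids (labels).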
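The concatenation augmentation in load_audio can be isolated as a pure function. `maybe_concatenate` below is a hypothetical restatement, not the project's code, and differs from the original in three ways. It measures the current clip with `len(wav)` instead of the precomputed length column. It receives a `load_clip` callback instead of reading files. It uses a NumPy `Generator` instead of the global NumPy random state.

```python
import numpy as np


def maybe_concatenate(wav, sentence, lengths, load_clip, max_len_samples, rng, p=0.5):
    """Join a short clip with a random partner clip that fits in the remaining length budget.

    wav: 1-D waveform; lengths: per-clip lengths (in samples) of the whole dataset;
    load_clip(j) -> (waveform, sentence) for clip j.
    """
    if len(wav) < max_len_samples / 2 and rng.uniform() < p:
        budget = max_len_samples - len(wav)
        candidates = np.where(lengths < budget)[0]   # clips that keep the total under the max
        if len(candidates) > 0:
            j = rng.choice(candidates)
            other_wav, other_sentence = load_clip(j)
            return np.concatenate((wav, other_wav)), sentence + " " + other_sentence
    return wav, sentence


clips = [(np.zeros(1_000), "b"), (np.zeros(200_000), "c")]   # toy dataset
lengths = np.array([len(w) for w, _ in clips])
rng = np.random.default_rng(0)
wav, text = maybe_concatenate(np.ones(32_000), "a", lengths, clips.__getitem__, 13 * 16_000, rng, p=1.0)
```

Here only the 1,000-sample clip fits the 176,000-sample budget, so the result is 33,000 samples long with the transcript "a b". Because partners are drawn only from clips that fit, the combined clip stays under the maximum length before time stretching is applied.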

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels, since they have different lengths and need
        # different padding methods.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Audio inputs are padded by the feature extractor, labels by the tokenizer
        # (as_target_processor() is deprecated).
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

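def main():
    # compute_metrics() reads `processor` and `wer_metric` as module-level names,
    # so main() assigns them as globals instead of locals.
    global processor, wer_metric

    # seed
    seed = 20
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)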

    parser = argparse.ArgumentParser(
        description="ASR model training."
    )
    parser.add_argument(
        "--data_path",
        help="Path to folder with csv transcription file",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--stage",
        help="Training stage",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--filter_csv_path",
        help="Path to csv file for data filtering",
        default="",
        type=str,
    )

    args = parser.parse_args()

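    if args.stage == 2:
        CFG.epochs = 3
        CFG.lr = 3e-5
        CFG.pretrained_path = CFG.save_dir_stage_1
    elif args.stage == 3:
        CFG.epochs = 3
        CFG.lr = 2e-5
        CFG.pretrained_path = CFG.save_dir_stage_2

    # CFG.trainer was built when the class was defined, so the stage-specific
    # values must be copied into it explicitly.
    CFG.trainer.num_train_epochs = CFG.epochs
    CFG.trainer.learning_rate = CFG.lr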

    selected_cols = ["path","sentence","length"]

    train_comp = pd.read_csv(os.path.join(args.data_path, f"train_comp_processed.csv"))
    train_comp = train_comp[selected_cols]
    train_comp["source"] = "read"

    ext_file = f"train_shru_processed.csv"
    print(f"Loading {ext_file}...")
    train_shru = pd.read_csv(os.path.join(args.data_path, ext_file))
    train_shru = train_shru[selected_cols]
    train_shru["source"] = "spontaneous"

    ext_file = f"train_respin_processed.csv"
    print(f"Loading {ext_file}...")
    train_respin = pd.read_csv(os.path.join(args.data_path, ext_file))
    train_respin = train_respin[selected_cols]
    train_respin["source"] = "read"

    ext_file = f"train_ulca_processed.csv"
    print(f"Loading {ext_file}...")
    train_ulca = pd.read_csv(os.path.join(args.data_path, ext_file))
    train_ulca = train_ulca[selected_cols]
    train_ulca["source"] = "spontaneous"

    ext_file = f"valid_kathbath_processed.csv"
    print(f"Loading {ext_file}...")
    valid_kathbath = pd.read_csv(os.path.join(args.data_path, ext_file))
    valid_kathbath = valid_kathbath[selected_cols]
    valid_kathbath["source"] = "read"

    train_df = pd.concat([train_comp, train_respin, train_ulca, train_shru])
    valid_df = valid_kathbath

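    # Keep training clips between min_input_length_in_sec and max_input_length_in_sec (lengths are in samples).
    train_df = train_df[train_df.length > (CFG.min_input_length_in_sec * CFG.sample_rate)]
    train_df = train_df[train_df.length < (CFG.max_input_length_in_sec * CFG.sample_rate)]

    # Drop missing transcripts and very short ones (5 characters or fewer).
    train_df = train_df.dropna(subset=['sentence'])
    train_df = train_df[train_df.sentence.str.len() > 5]
    valid_df = valid_df.dropna(subset=['sentence'])
    valid_df = valid_df[valid_df.sentence.str.len() > 5]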

    if args.filter_csv_path != "":
        noise_df = pd.read_csv(args.filter_csv_path)
        # Drop the 10% of samples with the highest WER (the noisiest data).
        noise_df = noise_df.sort_values('wer', ascending=False)[:int(noise_df.shape[0]*0.1)]
        noise_df = noise_df["path"].tolist()
        train_df = train_df[~train_df.path.isin(noise_df)]

    train_df = train_df.reset_index()
    valid_df = valid_df.reset_index()
    print("Split length: ", len(train_df), len(valid_df))

    train_df = train_df[["path","sentence", "length", "source"]]
    valid_df = valid_df[["path","sentence", "length", "source"]]

    processor = Wav2Vec2Processor.from_pretrained(CFG.pretrained_path)

    train_dataset = BengaliDataset(CFG, train_df, processor, "train")
    valid_dataset = BengaliDataset(CFG, valid_df, processor, "valid")

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    wer_metric = load_metric("wer")

    # Loading model.
    print("Loading model...")
    model = Wav2Vec2ForCTCV2.from_pretrained(
        CFG.pretrained_path,
        ignore_mismatched_sizes=True,
        attention_dropout=CFG.attention_dropout,
        hidden_dropout=CFG.hidden_dropout,
        feat_proj_dropout=CFG.feat_proj_dropout,
        mask_time_prob=CFG.mask_time_prob,
        mask_feature_prob=CFG.mask_feature_prob,
        layerdrop=CFG.layerdrop,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    model.config.ctc_zero_infinity = True

    new_vocab_size = len(processor.tokenizer.get_vocab())
    print("New vocab size: ", new_vocab_size)
    model.resize_lm_head(new_num_tokens=new_vocab_size)
    model.config.vocab_size = new_vocab_size

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=CFG.trainer,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=processor.feature_extractor,
    )

    print("Start training...")
    trainer.train(
        # "weights/checkpoint-27036"
    )

    if args.stage==1:
        final_dir = CFG.save_dir_stage_1
    elif args.stage==2:
        final_dir = CFG.save_dir_stage_2
    else:
        final_dir = CFG.save_dir_stage_3

    trainer.save_model(final_dir)
    processor.save_pretrained(final_dir)

if __name__ == "__main__":
    main()

This continuation of the code sets up the training process for an Automatic Speech Recognition (ASR) model using Wav2Vec2. Here's a step-by-step explanation:

1. DataCollatorCTCWithPadding Class:

  1.1. This class is a data collator responsible for dynamically padding input sequences.

  1.2. It takes in a Wav2Vec2Processor and a padding strategy argument.

  1.3. The __call__ method pads the audio inputs and the label sequences separately, because they have different lengths and padding rules. It then replaces padded label positions with -100 so that the CTC loss ignores them.

2. compute_metrics Function:

  2.1. This function calculates the Word Error Rate (WER) metric for the model's predictions.

  2.2. It decodes model predictions and references using the processor's batch_decode method and computes the WER using the wer_metric loaded from the datasets library.

3. main Function:

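  3.1. Parses the command-line arguments: the data folder, the training stage (1, 2 or 3), and an optional CSV used to filter noisy samples.

  3.2. For stages 2 and 3, it sets the number of epochs to 3 and lowers the learning rate (3e-5 and 2e-5 respectively). It also starts from the checkpoint saved by the previous stage.

  3.3. Loads the four training CSVs and the Kathbath validation CSV and tags each with its source ('read' or 'spontaneous'), which selects the augmentation pipeline. It keeps training clips between 2 and 13 seconds, drops transcripts of 5 characters or fewer from both splits, optionally drops the 10% of samples with the highest WER in the filter CSV, and resets the indexes.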

  3.4. Creates instances of BengaliDataset using the processed train and validation datasets.

  3.5. Initializes a DataCollatorCTCWithPadding for preparing batches for the model.

  3.6. Loads Wav2Vec2ForCTCV2 with the dropout and masking settings from CFG, resizes its LM head to the size of the extended vocabulary, and updates model.config.vocab_size to match.

  3.7. Sets up a Trainer object with necessary parameters for training, including datasets, model, data collator, evaluation metrics, and training arguments.

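  3.8. Starts training with trainer.train(). WER is evaluated on the validation set after every epoch. Because load_best_model_at_end is set, the checkpoint with the lowest validation WER is restored at the end.

  3.9. Saves the final model and processor to the directory for the current stage (ckpt_stage1, ckpt_stage2 or ckpt_stage3).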

4. if __name__ == "__main__":

  4.1. Calls the main() function when the script is executed directly.
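compute_metrics delegates the WER computation to the `wer` metric. The pure-Python sketch below shows what that number means: the word-level Levenshtein distance between prediction and reference, summed over the evaluation set and divided by the total number of reference words. The helper names `word_edit_distance` and `corpus_wer` are hypothetical, and this is not the library's implementation. As far as we know, this corpus-level definition matches how the metric aggregates multiple pairs. Details such as text normalization may differ.

```python
def word_edit_distance(ref_words, hyp_words):
    """Minimum number of substitutions, deletions and insertions turning ref into hyp."""
    prev = list(range(len(hyp_words) + 1))        # distance from an empty reference
    for i, r in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, h in enumerate(hyp_words, start=1):
            curr[j] = min(
                prev[j] + 1,              # deletion of the reference word
                curr[j - 1] + 1,          # insertion of the hypothesis word
                prev[j - 1] + (r != h),   # substitution, or a match at no cost
            )
        prev = curr
    return prev[-1]


def corpus_wer(predictions, references):
    """Total word errors divided by total reference words (references must not all be empty)."""
    errors = sum(word_edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    total_words = sum(len(r.split()) for r in references)
    return errors / total_words


print(corpus_wer(["a b c"], ["a x c d"]))  # 0.5: one substitution and one deletion over 4 words
```

In compute_metrics, the predictions come from greedy CTC decoding: an argmax per frame, after which batch_decode merges repeated tokens and removes padding. The references are decoded with group_tokens=False so that genuinely repeated characters are kept.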

In [None]:
import typing as tp
from pathlib import Path
from functools import partial
from dataclasses import dataclass, field

import pandas as pd
import pyctcdecode
import numpy as np
from tqdm.auto import tqdm
import argparse

import librosa

import jiwer
import kenlm
import os
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
from bnunicodenormalizer import Normalizer

SAMPLING_RATE = 16_000

class BengaliSRTestDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        audio_paths: list[str],
        sampling_rate: int
    ):
        self.audio_paths = audio_paths
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, index: int):
        audio_path = self.audio_paths[index]
        sr = self.sampling_rate
        w = librosa.load(audio_path, sr=sr, mono=False)[0]

        return w

def main():
    parser = argparse.ArgumentParser(
        description="ASR model validation."
    )
    parser.add_argument(
        "--data_path",
        help="Path to folder with csv transcription file",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model_path",
        help="Path to inference model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--filter_csv_path",
        help="Path to csv file for data filtering",
        default="",
        type=str,
    )
    # These two thresholds replace the CFG values from the training script,
    # which are not defined in this file.
    parser.add_argument(
        "--min_input_length_in_sec",
        help="Discard clips shorter than this (seconds)",
        type=float,
        required=True,
    )
    parser.add_argument(
        "--max_input_length_in_sec",
        help="Discard clips longer than this (seconds)",
        type=float,
        required=True,
    )
    args = parser.parse_args()

    model = Wav2Vec2ForCTC.from_pretrained(args.model_path)
    processor = Wav2Vec2Processor.from_pretrained(args.model_path)

    selected_cols = ["path", "sentence", "length"]

    # Processed training CSVs and the speech style each one contains.
    csv_sources = [
        ("train_comp_processed.csv", "read"),
        ("train_respin_processed.csv", "read"),
        ("train_ulca_processed.csv", "spontaneous"),
        ("train_shru_processed.csv", "spontaneous"),
    ]
    frames = []
    for csv_name, source in csv_sources:
        print(f"Loading {csv_name}...")
        df = pd.read_csv(os.path.join(args.data_path, csv_name), usecols=selected_cols)
        df["source"] = source
        frames.append(df)
    train_df = pd.concat(frames)

    # "length" is measured in samples, so convert the thresholds from seconds.
    min_len = args.min_input_length_in_sec * SAMPLING_RATE
    max_len = args.max_input_length_in_sec * SAMPLING_RATE
    train_df = train_df[(train_df.length > min_len) & (train_df.length < max_len)]

    # Drop missing transcripts and very short ones (5 characters or fewer).
    train_df = train_df.dropna(subset=["sentence"])
    train_df = train_df[train_df["sentence"].str.len() > 5]
    train_df = train_df.reset_index(drop=True)

    # Explicit copy, so adding prediction columns later does not write into a slice.
    valid = train_df[["path", "sentence", "length", "source"]].copy()

    valid_audio_paths = valid["path"].tolist()

    valid_dataset = BengaliSRTestDataset(
        valid_audio_paths, SAMPLING_RATE
    )

    collate_func = partial(
        processor.feature_extractor,
        return_tensors="pt", sampling_rate=SAMPLING_RATE,
        padding=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=32, shuffle=False,
        num_workers=16, collate_fn=collate_func, drop_last=False,
        pin_memory=True,
    )

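    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fp16 weights and autocast are only used on GPU; on CPU the model stays in float32.
    use_fp16 = device.type == "cuda"

    model = model.to(device)
    model = model.eval()
    if use_fp16:
        model = model.half()

    pred_sentence_list = []

    with torch.no_grad():
        for batch in tqdm(valid_loader):
            x = batch["input_values"]
            x = x.to(device, non_blocking=True)
            with torch.autocast(device_type=device.type, enabled=use_fp16):
                y = model(x).logits
            # Greedy CTC decoding: take the most likely token at every frame.
            y = torch.argmax(y, dim=-1)
            y = y.detach().cpu().numpy()

            # decode() collapses repeated tokens and removes CTC blank (pad) tokens.
            for ids in y:
                sentence = processor.decode(ids)
                pred_sentence_list.append(sentence)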

    valid["pred_sentence"] = pred_sentence_list
    valid["wer"] = [
        jiwer.wer(s, p_s)
        for s, p_s in tqdm(valid[["sentence", "pred_sentence"]].values)
    ]

    # Per-utterance predictions and WER; --filter_csv_path is optional.
    if args.filter_csv_path:
        valid.to_csv(args.filter_csv_path, index=False)
    print(valid["wer"].mean())

if __name__ == "__main__":
    main()

This code validates an Automatic Speech Recognition (ASR) model fine-tuned with the Wav2Vec2 architecture. It transcribes audio with greedy CTC decoding and scores each transcript with the Word Error Rate (WER). Here's a breakdown:

1. BengaliSRTestDataset Class:

  1.1. A PyTorch dataset over the audio files to be evaluated.

  1.2. Stores the audio paths and, for each item, loads the file with librosa at 16 kHz and returns the waveform.

2. main() Function:

  2.1. Parses the command-line arguments: the folder with the transcription CSVs, the model path, and the output CSV used for data filtering.

  2.2. Loads the pre-trained Wav2Vec2 model and processor using the provided model path.

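  2.3. Reads four processed training CSVs (comp, respin, ULCA and SHRU) and tags each sample as read or spontaneous speech. It keeps only clips within the configured length bounds whose transcripts are longer than five characters, and uses the result as the evaluation set. Because these samples come from the training data, the per-utterance scores serve the data-filtering purpose of --filter_csv_path.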

  2.4. Builds a DataLoader that uses the processor's feature extractor as its collate function, so each batch of waveforms is padded to a common length and returned as PyTorch tensors.

  2.5. Selects the GPU if one is available (otherwise the CPU), moves the model there, and switches it to evaluation mode. It then runs inference in half precision (fp16) under autocast to speed it up.

  2.6. Runs inference over the DataLoader and takes the most likely token at each frame (greedy argmax). It then decodes the token ids with the processor, which collapses repeated tokens and removes CTC blank tokens.

  2.7. Computes the WER between each ground-truth sentence and its prediction using the jiwer library.

  2.8. Writes the predictions and per-utterance WER to the CSV given by --filter_csv_path and prints the mean WER across utterances.
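Steps 2.6 and 2.7 can be illustrated without a model. This pure-Python sketch mimics greedy CTC decoding: it collapses repeated ids, drops blanks, and turns the "|" word delimiter used by Wav2Vec2 vocabularies into spaces. It then computes WER as the word-level edit distance divided by the number of reference words. The toy vocabulary, the blank id of 0 and the function names are assumptions for illustration. In the script itself, these steps are handled by processor.decode and jiwer.wer.

```python
def ctc_greedy_decode(frame_ids, id_to_char, blank_id=0, word_delimiter="|"):
    """Collapse repeated frame ids, drop blanks, and map the remaining ids to text."""
    chars = []
    previous = None
    for idx in frame_ids:
        # A repeated id counts once unless a blank separates the repeats.
        if idx != previous and idx != blank_id:
            chars.append(id_to_char[idx])
        previous = idx
    return "".join(chars).replace(word_delimiter, " ").strip()


def word_error_rate(reference, hypothesis):
    """(substitutions + deletions + insertions) / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # dp[j] = edit distance between ref[:i] and hyp[:j] for the current row i.
    dp = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        prev_diag, dp[0] = dp[0], i
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            prev_diag, dp[j] = dp[j], min(
                dp[j] + 1,         # deletion of a reference word
                dp[j - 1] + 1,     # insertion of a hypothesis word
                prev_diag + cost,  # substitution, or match when cost == 0
            )
    return dp[-1] / len(ref)


# Toy vocabulary: id 0 is the CTC blank, "|" separates words.
id_to_char = {0: "<pad>", 1: "a", 2: "b", 3: "|", 4: "c"}
pred = ctc_greedy_decode([1, 1, 0, 1, 3, 2, 2, 0, 4], id_to_char)
print(pred)                            # aa bc
print(word_error_rate("aa bd", pred))  # 0.5
```

The script prints the mean of the per-utterance WERs, which weights a short utterance as heavily as a long one. Passing the full lists of references and predictions to jiwer.wer instead gives a corpus-level WER: total errors divided by total reference words.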