In [6]:
from transformers import (
    AddedToken,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizerFast,
    get_scheduler,
    set_seed,
)
import json
from datasets import Audio
from torch.utils.data import DataLoader, Dataset

In [19]:
# Single source of truth for the checkpoint: every Whisper component below must
# come from the same model so feature settings, vocabulary and config agree.
MODEL_NAME = 'openai/whisper-large-v3'

feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizerFast.from_pretrained(MODEL_NAME)
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
# Rate the feature extractor expects its input audio to be sampled at.
sampling_rate = feature_extractor.sampling_rate
config = WhisperConfig.from_pretrained(MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [52]:
class Train(Dataset):
    """Training examples read from a JSON-lines manifest.

    Each non-blank line of ``file`` is parsed as one JSON object. ``__getitem__``
    reads the keys ``audio_filename``, ``score_ms``, ``score_en``, ``predict_ms``
    and ``predict_en`` from it.

    Note: ``__getitem__`` uses the notebook-level ``feature_extractor`` and
    ``tokenizer`` objects.

    Args:
        file: Path to the ``.jsonl`` manifest (read as UTF-8).
        sampling_rate: Rate in Hz that audio is decoded/resampled to.
            Defaults to 16000; should match ``feature_extractor.sampling_rate``.
    """

    def __init__(self, file, sampling_rate=16000):
        # Explicit UTF-8: transcripts contain non-ASCII text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(file, encoding='utf-8') as fopen:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            self.data = [json.loads(line) for line in fopen if line.strip()]

        self.audio = Audio(sampling_rate=sampling_rate)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return features, raw sample count and label token ids for row ``item``."""
        row = self.data[item]
        audio = self.audio.decode_example(
            self.audio.encode_example(row['audio_filename']))['array']
        # Tell the feature extractor the rate the audio was actually decoded at,
        # rather than a separate global that could disagree with it.
        inputs = feature_extractor(audio, sampling_rate=self.audio.sampling_rate)

        # Use whichever transcript scored higher; ties go to the Malay prediction.
        if row['score_ms'] >= row['score_en']:
            input_str = row['predict_ms']
        else:
            input_str = row['predict_en']

        # add_special_tokens=False: any special tokens must already be in the text.
        token_ids = tokenizer(input_str, add_special_tokens=False).input_ids

        return {
            'input_features': inputs.input_features[0],
            'input_length': [len(audio)],  # number of decoded audio samples
            'labels': token_ids,
        }

In [53]:
# Build the training dataset from the local JSON-lines manifest.
train_dataset = Train(file='sample-set.jsonl')

In [54]:
# Inspect one example: feature-extractor output, number of audio samples, and label token ids.
train_dataset[0]

{'input_features': array([[ 0.2577381 ,  0.12420219, -0.0346607 , ..., -0.02739227,
         -0.04189217,  0.03024662],
        [ 0.3553025 ,  0.22176659,  0.0629037 , ...,  0.07017213,
          0.05567229,  0.12781101],
        [ 0.38467056,  0.5888243 ,  0.55428416, ...,  0.35781485,
          0.14236629,  0.29945827],
        ...,
        [ 0.01472914, -0.4861127 , -0.53696144, ..., -0.39464724,
         -0.51083815, -0.345137  ],
        [-0.01119864, -0.52233994, -0.53696144, ..., -0.53696144,
         -0.53696144, -0.38162374],
        [-0.01916075, -0.5298321 , -0.53696144, ..., -0.53696144,
         -0.53696144, -0.38669133]], dtype=float32),
 'input_length': [480000],
 'labels': [50258,
  50282,
  50360,
  21851,
  27294,
  803,
  2760,
  1063,
  282,
  25835,
  20451,
  17922,
  7834,
  4072,
  647,
  1538,
  803,
  2760,
  1063,
  282,
  803,
  86,
  17017,
  44988,
  963,
  64,
  9160,
  18943,
  16281,
  1706,
  66,
  3504,
  23063,
  28042,
  409,
  9160,
  16281,
  1706

In [55]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that pads a list of speech/label examples into a training batch.

    Input features and label token ids are padded separately, since they have
    different lengths and padding schemes. The padded labels are then split into
    teacher-forcing ``decoder_input_ids`` (every token but the last) and
    ``labels`` (every token but the first). Label padding, and any prompt tokens
    that precede ``decoder_start_token_id`` in the shifted labels, are set to
    -100 so the loss ignores them.

    Args:
        processor (:class:`~transformers.WhisperProcessor`):
            The processor whose ``feature_extractor`` pads the inputs and whose
            ``tokenizer`` pads the labels.
        decoder_start_token_id (:obj:`int`):
            The start-of-sequence token id of the decoder.
        decoder_prev_token_id (:obj:`int`):
            The start-of-prompt token id of the decoder. Stored for reference;
            not currently read by ``__call__``.
        input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Strategy used to pad the input features:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch.
            * :obj:`'max_length'`: pad to the feature extractor's maximum length.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding.
        target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Strategy used to pad the label sequences; same choices as above.
            With :obj:`'max_length'`, labels are padded to ``max_target_length``.
        max_target_length (:obj:`int`, `optional`):
            Length the labels are padded to when ``target_padding`` is
            ``'max_length'``. Labels are never truncated.

    Raises:
        ValueError: if ``target_padding`` is ``'max_length'`` and some label is
            longer than ``max_target_length``.
    """

    processor: Any
    decoder_start_token_id: int
    decoder_prev_token_id: int
    input_padding: str = "max_length"
    target_padding: str = "max_length"
    max_target_length: Optional[int] = None

    def __call__(self, features):
        # Inputs and labels have different lengths and need different padding,
        # so they are split and padded separately.
        model_input_name = self.processor.model_input_names[0]

        # The dataloader passes a list of per-example dicts; regroup into column dicts.
        input_features = {model_input_name: [feature[model_input_name] for feature in features]}
        label_ids = [feature["labels"] for feature in features]

        # tokenizer.pad does not truncate. An over-long label would otherwise produce
        # a cryptic tensor-creation error, or silently exceed the cap, depending on
        # what else is in the batch. Fail clearly instead.
        if self.max_target_length is not None and self.target_padding == "max_length":
            too_long = [len(ids) for ids in label_ids if len(ids) > self.max_target_length]
            if too_long:
                raise ValueError(
                    f"{len(too_long)} label sequence(s) exceed max_target_length="
                    f"{self.max_target_length} (longest: {max(too_long)} tokens); "
                    "filter or truncate them before collation."
                )
        label_features = {"input_ids": label_ids}

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.input_padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            return_tensors="pt",
        )

        # Shift by one: the decoder sees tokens [0..n-1] and predicts tokens [1..n].
        labels = labels_batch["input_ids"]
        decoder_input_ids = labels[:, :-1]
        labels = labels[:, 1:]
        labels_mask = labels_batch.attention_mask[:, 1:]

        # Padding positions are ignored by the loss.
        labels = labels.masked_fill(labels_mask.ne(1), -100)

        # If a prompt precedes the start token, mask everything before the start
        # token. When the start token is absent from the shifted labels, argmax
        # returns 0 and nothing is masked.
        bos_index = torch.argmax((labels == self.decoder_start_token_id).long(), dim=1)
        prompt_mask = torch.arange(labels.shape[1], device=labels.device) < bos_index[:, None]
        labels = torch.where(prompt_mask, -100, labels)

        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch

In [56]:
# Upper bound on label length in tokens; labels are padded to exactly this length.
max_label_length = 384
decoder_start_token_id = config.decoder_start_token_id
# Look up the start-of-prompt token by name: indexing into all_special_ids
# silently depends on the tokenizer's special-token ordering.
decoder_prev_token_id = tokenizer.convert_tokens_to_ids('<|startofprev|>')
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=decoder_start_token_id,
    decoder_prev_token_id=decoder_prev_token_id,
    input_padding="longest",
    target_padding="max_length",
    max_target_length=max_label_length,
)

In [57]:
# Collate the first few examples to sanity-check shapes and label masking.
examples = [train_dataset[i] for i in range(3)]
batch = data_collator(examples)
batch

{'input_features': tensor([[[ 0.2577,  0.1242, -0.0347,  ..., -0.0274, -0.0419,  0.0302],
         [ 0.3553,  0.2218,  0.0629,  ...,  0.0702,  0.0557,  0.1278],
         [ 0.3847,  0.5888,  0.5543,  ...,  0.3578,  0.1424,  0.2995],
         ...,
         [ 0.0147, -0.4861, -0.5370,  ..., -0.3946, -0.5108, -0.3451],
         [-0.0112, -0.5223, -0.5370,  ..., -0.5370, -0.5370, -0.3816],
         [-0.0192, -0.5298, -0.5370,  ..., -0.5370, -0.5370, -0.3867]],

        [[ 0.3003, -0.0457,  0.0751,  ...,  0.0324,  0.0659,  0.1541],
         [ 0.3979,  0.0518,  0.1726,  ...,  0.1299,  0.1635,  0.2517],
         [ 0.3618, -0.0771,  0.2485,  ...,  0.5078,  0.4999,  0.5185],
         ...,
         [ 0.1584, -0.3193, -0.5183,  ..., -0.5183, -0.5183, -0.5183],
         [ 0.1155, -0.3962, -0.5183,  ..., -0.5183, -0.5183, -0.5183],
         [ 0.1029, -0.4082, -0.5183,  ..., -0.5183, -0.5183, -0.5183]],

        [[ 0.7460, -0.0912,  0.0428,  ..., -0.0328,  0.1653, -0.0675],
         [ 0.8435,  0.0064