
A data collator class for CTC (Connectionist Temporal Classification) with padding functionality.
This class handles the batching and padding of input features and labels for wav2vec2 model training.
It processes audio features and their corresponding transcription labels, ensuring proper padding
and tensor conversion.
Args:
	processor (Wav2Vec2Processor): The wav2vec2 processor for handling inputs and labels
	padding (Union[bool, str]): The padding strategy to use. Defaults to True.
	max_length (Optional[int]): Maximum length for input features padding. Defaults to None.
	max_length_labels (Optional[int]): Maximum length for labels padding. Defaults to None.
	pad_to_multiple_of (Optional[int]): Pad input features to be multiple of this value. Defaults to None.
	pad_to_multiple_of_labels (Optional[int]): Pad labels to be multiple of this value. Defaults to None.
Methods:
	__call__(features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		Processes a batch of features to create padded tensors suitable for model training.
		Args:
			features: List of dictionaries containing input values and labels
		Returns:
			Dict containing padded input tensors and processed labels with -100 for padding tokens


In [2]:
import json
import random
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import os
import numpy as np
import pandas as pd
import torch
import torchaudio
import transformers
from datasets import ClassLabel, load_dataset, load_metric, load_from_disk
from transformers import (Trainer, TrainingArguments, Wav2Vec2CTCTokenizer,
                          Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC,
                          Wav2Vec2Processor)

print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))



  from .autonotebook import tqdm as notebook_tqdm


True
1
NVIDIA GeForce RTX 4070 Ti SUPER


In [4]:
import argparse
parser = argparse.ArgumentParser() 
parser.add_argument('--model', type=str, default="facebook/wav2vec2-large-xlsr-53")
parser.add_argument('--unfreeze', action='store_true')
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--warmup', type=float, default=500)
parser.add_argument('-f', '--fff', help="dummy argument to avoid error in Jupyter", default="dummy_value")
args = parser.parse_args()

print(f"args: {args}")



args: Namespace(model='facebook/wav2vec2-large-xlsr-53', unfreeze=False, lr=0.0003, warmup=500, fff='c:\\Users\\westw\\AppData\\Roaming\\jupyter\\runtime\\kernel-v3f64c2e6cf0900acf9997538cf609a7651e8d62c7.json')


In [5]:
# 从本地磁盘加载数据集 Load Cantonese language only 
common_voice_train = load_dataset("mozilla-foundation/common_voice_13_0", "zh-HK", split="train")
common_voice_test = load_dataset("mozilla-foundation/common_voice_13_0", "zh-HK", split="test")

unused_cols = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
common_voice_train = common_voice_train.remove_columns(unused_cols)
common_voice_test = common_voice_test.remove_columns(unused_cols)



You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
common_voice_test

Dataset({
    features: ['path', 'audio', 'sentence', 'variant'],
    num_rows: 5593
})

In [6]:
# data preprocessing

chars_to_ignore_regex = '[\丶\,\?\.\!\-\;\:"\“\%\‘\”\�\．\⋯\！\－\：\–\。\》\,\）\,\？\；\～\~\…\︰\，\（\」\‧\《\﹔\、\—\／\,\「\﹖\·\']'

import string
def remove_special_characters(batch):
    sen = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
    if "d" in sen:
        if len([c for c in sen if c in string.ascii_lowercase]) == 1:
            sen = sen.replace("d", "啲")
    batch["sentence"] = sen
    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names,)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names,)
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_list = [char for char in vocab_list if not char.isascii()]
vocab_list.append(" ")

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

with open("vocab.json", "w") as vocab_file:
    json.dump(vocab_dict, vocab_file)


Map:   0%|          | 0/8425 [00:00<?, ? examples/s]

Map: 100%|██████████| 8425/8425 [00:00<00:00, 328875.47 examples/s]
Map: 100%|██████████| 5593/5593 [00:00<00:00, 302541.20 examples/s]


In [None]:
# init tokenizer

# resamplers = {
#     48000: torchaudio.transforms.Resample(48000, 16000),
#     44100: torchaudio.transforms.Resample(44100, 16000),
#     32000: torchaudio.transforms.Resample(32000, 16000), 
# }


# def load_and_resample(batch):
#     speech_array, sampling_rate = torchaudio.load(batch["path"])
#     batch["speech"] = resamplers[sampling_rate](speech_array).squeeze().numpy()
#     batch["sampling_rate"] = 16_000
#     batch["target_text"] = batch["sentence"]
#     return batch

# common_voice_train = common_voice_train.map(load_and_resample, remove_columns=common_voice_train.column_names,)
# common_voice_test = common_voice_test.map(load_and_resample, remove_columns=common_voice_test.column_names,)

# def prepare_dataset(batch):
#     batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
#     with processor.as_target_processor():
#         batch["labels"] = processor(batch["target_text"]).input_ids
#     return batch

# def prepare_dataset_wav2vec2(batch):
#     audio = batch["audio"] # This is a dict: {'array': ..., 'sampling_rate': ...}
#     # The processor handles both resampling (if needed) and feature extraction
#     features = processor(
#         audio["array"],
#         sampling_rate=audio["sampling_rate"],
#         text=batch["sentence"]
#     )
#     batch["input_values"] = features.input_values[0]
#     with processor.as_target_processor():
#         batch["labels"] = processor(batch["sentence"]).input_ids
#     return batch
	
	
# # set batch to false to let processor handle the batching
# common_voice_train = common_voice_train.map(prepare_dataset_wav2vec2, remove_columns=common_voice_train.column_names, batch_size=-1, num_proc=10, batched=False,)
# common_voice_test = common_voice_test.map(prepare_dataset_wav2vec2, remove_columns=common_voice_test.column_names, batch_size=-1, num_proc=10, batched=False,)



[]

In [7]:
# load datasets and resampling, the modern way
from datasets import Audio
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16000))


In [8]:
# print sample rows from common_voice_train
print(common_voice_train[0])

{'path': 'F:\\hf_home\\datasets\\downloads\\extracted\\3ee5ffca136c1c2287060526e62cd5c3b2bdcbca5812d1065a9fab9ec1ecb669\\zh-HK_train_0/common_voice_zh-HK_22942304.mp3', 'audio': {'path': 'F:\\hf_home\\datasets\\downloads\\extracted\\3ee5ffca136c1c2287060526e62cd5c3b2bdcbca5812d1065a9fab9ec1ecb669\\zh-HK_train_0/common_voice_zh-HK_22942304.mp3', 'array': array([ 5.45696821e-12,  2.72848411e-12,  3.63797881e-12, ...,
        1.48210138e-05,  9.73203896e-07, -4.09249424e-06]), 'sampling_rate': 16000}, 'sentence': '才能勇往直前 ', 'variant': ''}


In [9]:
# --- 3. Define the prepare_dataset function (like your Whisper one) ---
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True,)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained("./wav2vec2-large-xlsr-cantonese")

processor


Wav2Vec2Processor:
- feature_extractor: Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "Wav2Vec2Processor",
  "return_attention_mask": true,
  "sampling_rate": 16000
}

- tokenizer: Wav2Vec2CTCTokenizer(name_or_path='', vocab_size=3653, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'pad_token': '[PAD]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	3651: AddedToken("[UNK]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	3652: AddedToken("[PAD]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	3653: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3654: AddedToken("<

In [10]:
def prepare_dataset_for_batching(batch, processor_obj=None):
    # Extract audio data
    audio_arrays = [item["array"] for item in batch["audio"]]
    sampling_rates = [item["sampling_rate"] for item in batch["audio"]]
    sentences = batch["sentence"]  # List of strings

    # Process audio inputs (without padding)
    model_inputs = processor_obj(
        audio_arrays,
        sampling_rate=sampling_rates[0],
        padding=False,  # Crucial: no padding at this stage
        return_tensors=None,  # Get raw lists instead of tensors
    )

    batch["input_values"] = model_inputs.input_values

    # Process text labels (without padding)
    # Use tokenizer directly with add_special_tokens=False for CTC
    batch["labels"] = processor_obj.tokenizer(
        sentences, 
        add_special_tokens=False,  # No special tokens for CTC
        padding=False,  # No padding - handled by collator
    ).input_ids

    # Calculate audio lengths
    batch['input_length'] = [
        len(arr) / sr 
        for arr, sr in zip(audio_arrays, sampling_rates)
    ]

    return batch

# Then call map like this:
common_voice_train = common_voice_train.map(
    prepare_dataset_for_batching,
    #remove_columns=columns_to_remove_train,
    num_proc=3, # Can now safely increase this for parallel batch processing
    batched=True, # <--- IMPORTANT: Set to True
    fn_kwargs={"processor_obj": processor}, # Still good practice for num_proc > 1
    load_from_cache_file=False
)

common_voice_train[0]  # Check the first entry to see if it worked


Map (num_proc=3): 100%|██████████| 8425/8425 [04:03<00:00, 34.64 examples/s]


{'path': 'F:\\hf_home\\datasets\\downloads\\extracted\\3ee5ffca136c1c2287060526e62cd5c3b2bdcbca5812d1065a9fab9ec1ecb669\\zh-HK_train_0/common_voice_zh-HK_22942304.mp3',
 'audio': {'path': None,
  'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00, -3.05175781e-05]),
  'sampling_rate': 16000},
 'sentence': '才能勇往直前 ',
 'variant': '',
 'input_values': [0.00012427246838342398,
  0.00012427243927959353,
  0.00012427245383150876,
  0.0001242725847987458,
  0.00012427243927959353,
  0.00012427259935066104,
  0.0001242724247276783,
  0.00012427239562384784,
  0.00012427243927959353,
  0.00012427264300640672,
  0.00012427227920852602,
  0.00012427257024683058,
  0.0001242724829353392,
  0.00012427200272213668,
  0.0001242720609297976,
  0.00012427179899532348,
  0.00012427204637788236,
  0.00012427147885318846,
  0.0001242719154106453,
  0.00012427163892425597,
  0.0001242717116838321,
  0.0001242700091097504,
  0.000124267986393533

In [12]:
# Define a data collator for CTC with padding and masking
@dataclass

class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels

        return batch

In [20]:
# Metrics and model initialization, feature extractor, and model loading
import evaluate
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
# Load the built-in CER metric
# cer_metric = load_metric("cer")
cer_metric = evaluate.load("cer")

# def compute_metrics(pred):
#     pred_logits = pred.predictions
#     pred_ids = np.argmax(pred_logits, axis=-1)
#     pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
#     pred_str = processor.batch_decode(pred_ids)
#     label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
#     cer = cer_metric.compute(predictions=pred_str, references=label_str)
#     return {"cer": cer}

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Avoid in-place modification
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, group_tokens=False, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

model = Wav2Vec2ForCTC.from_pretrained(
    args.model,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

if not args.unfreeze:
    model.freeze_feature_extractor()


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Set training arguments and initialize the Trainer
device = "cuda" if torch.cuda.is_available() else "cpu"

training_args = TrainingArguments(
    output_dir="./wav2vec2-large-xlsr-cantonese",
    group_by_length=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    eval_steps=400,
    num_train_epochs=40,
    fp16=True,
    fp16_backend="amp",
    logging_strategy="steps",
    logging_steps=400,
    learning_rate=args.lr,
    warmup_steps=100,
    save_steps=2376,
    save_total_limit=3,
    dataloader_num_workers=20,
  #  optim="adamw_8bit"
)

trainer = Trainer(
    model=model.to(device),
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

trainer.train()

  trainer = Trainer(


RuntimeError: DataLoader worker (pid(s) 35524, 16296, 36476, 25264, 32348, 29900, 26460, 35828, 29052, 17460, 36148, 10988, 35252) exited unexpectedly