## Login & Initialization

In [1]:
!pip install --upgrade pip
!pip install --upgrade datasets transformers accelerate soundfile librosa evaluate jiwer tensorboard gradio

Collecting pip
  Downloading pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-24.0-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.2
    Uninstalling pip-23.3.2:
      Successfully uninstalled pip-23.3.2
Successfully installed pip-24.0
Collecting accelerate
  Downloading accelerate-0.29.2-py3-none-any.whl.metadata (18 kB)
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting jiwer
  Downloading jiwer-3.0.3-py3-none-any.whl.metadata (2.6 kB)
Collecting tensorboard
  Downloading tensorboard-2.16.2-py3-none-any.whl.metadata (1.6 kB)
Collecting gradio
  Downloading gradio-4.26.0-py3-none-any.whl.metadata (15 kB)
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl.metadata (29 kB

In [2]:
import os

from kaggle_secrets import UserSecretsClient
from huggingface_hub import HfFolder

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
HfFolder.save_token(hf_token)

model_name_or_path = "openai/whisper-small"
language = "Chinese"
language_abbr = "nan-tw"
task = "transcribe"
device_map = "auto"

In [3]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)

preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Prepare Dataset

In [4]:
from datasets import load_dataset, DatasetDict, concatenate_datasets

common_voice_16_1 = load_dataset("mozilla-foundation/common_voice_16_1", language_abbr, split=["train+validation+other", "test"], token=hf_token, trust_remote_code=True)
common_voice_15_0 = load_dataset("mozilla-foundation/common_voice_15_0", language_abbr, split=["train+validation+other", "test"], token=hf_token, trust_remote_code=True)

train_validation_other_combined = concatenate_datasets([
    common_voice_16_1[0],
    common_voice_15_0[0]
])

test_combined = concatenate_datasets([
    common_voice_16_1[1],
    common_voice_15_0[1]
])

common_voice = DatasetDict({
    "train": train_validation_other_combined,
    "test": test_combined
})

print(common_voice)

Downloading builder script:   0%|          | 0.00/8.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/77.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/56.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/40.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/319M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.23M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/864k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/395k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/522k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.50M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/125k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 3665it [00:00, 88988.41it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 1679it [00:00, 99393.62it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 2281it [00:00, 101885.02it/s]


Generating other split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 19451it [00:00, 100288.40it/s]


Generating invalidated split: 0 examples [00:00, ? examples/s]


Reading metadata...: 521it [00:00, 63895.68it/s]


Downloading builder script:   0%|          | 0.00/8.18k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/71.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/44.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/328M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.85M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/666k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/498k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/493k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.63M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/119k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 2821it [00:00, 91052.12it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 2122it [00:00, 96834.10it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 2155it [00:00, 98749.35it/s]


Generating other split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 20038it [00:00, 101537.54it/s]


Generating invalidated split: 0 examples [00:00, ? examples/s]


Reading metadata...: 496it [00:00, 91104.65it/s]


DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 49776
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 4436
    })
})


In [5]:
from datasets import DatasetDict, load_dataset, Audio
import re

def remove_pinyin(example):
    example['sentence'] = re.sub(r'\（[^)]*\）', '', example['sentence'])
    return example

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes", "variant"])
common_voice = common_voice.map(remove_pinyin)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)

Map:   0%|          | 0/49776 [00:00<?, ? examples/s]

Map:   0%|          | 0/4436 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/49776 [00:00<?, ? examples/s]

2024-04-12 00:05:04.725022: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 00:05:04.725022: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 00:05:04.725085: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 00:05:04.725140: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 00:05:04.860431: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory

Map (num_proc=2):   0%|          | 0/4436 [00:00<?, ? examples/s]

2024-04-12 00:14:07.940053: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 00:14:07.940058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 00:14:07.940104: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 00:14:07.940110: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 00:14:07.941855: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory

## Training and Evaluation

### Data Collator

In [6]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

In [7]:
import evaluate

metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    label_ids[label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    cer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

2024-04-12 00:15:02.553683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 00:15:02.553740: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 00:15:02.555201: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

### Load a Pre-Trained Checkpoint

In [8]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, device_map=device_map)

config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.87k [00:00<?, ?B/s]

In [9]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.language = "<|zh|>"
model.generation_config.task = "transcribe"

### Define the Training Configuration

In [10]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./linshoufanfork-whisper-small-nan-tw",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    num_train_epochs=1,
    gradient_checkpointing=True,
    save_strategy="steps",
    evaluation_strategy="steps",
    predict_with_generate=True,
    load_best_model_at_end=True,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    metric_for_best_model="cer",
    greater_is_better=False,
    push_to_hub=True,
    hub_strategy="checkpoint",
    save_total_limit=5,
)

## Train

In [11]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

model.config.use_cache = False
#trainer.train(resume_from_checkpoint="/kaggle/input/whisper-small-epoch-2/whisper-small-taiwanese/checkpoint-6000")
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Cer
500,0.7938,0.776787,55.834111
1000,0.5845,0.594726,41.152154
1500,0.459,0.513161,37.618349
2000,0.3512,0.470923,35.404721
2500,0.3758,0.436322,33.57781
3000,0.3191,0.42162,32.611015


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=3111, training_loss=0.6242340206525135, metrics={'train_runtime': 34920.6578, 'train_samples_per_second': 1.425, 'train_steps_per_second': 0.089, 'total_flos': 1.436462688632832e+19, 'train_loss': 0.6242340206525135, 'epoch': 1.0})

In [12]:
kwargs = {
    "dataset_tags": ["mozilla-foundation/common_voice_16_1", "mozilla-foundation/common_voice_15_0"],
    "dataset": ["Common Voice 16.1", "Common Voice 15.0"],
    "dataset_args": "config: nan-tw, split: test",
    "language": "nan",
    "model_name": "Whisper Small Taiwanese",
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

trainer.push_to_hub(**kwargs)

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}


events.out.tfevents.1712880914.8d00cfbe3768.26.0:   0%|          | 0.00/33.9k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/linshoufan/linshoufanfork-whisper-small-nan-tw/commit/301e0f24c0c2bfbda9659ebf76fcba4f42dc257b', commit_message='End of training', commit_description='', oid='301e0f24c0c2bfbda9659ebf76fcba4f42dc257b', pr_url=None, pr_revision=None, pr_num=None)