# Fine-tuning Whisper For Korean ASR with 🤗 Transformers

> This tutorial is based on [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/fine_tune_whisper.ipynb).

## Whisper

- https://openai.com/blog/whisper/  
- Trained on 680,000 hours of multilingual and multitask supervised data collected from the web.
- Supports multiple languages (see the language list in https://github.com/openai/whisper/blob/main/whisper/tokenizer.py).

<img src="https://cdn.openai.com/whisper/asr-summary-of-model-architecture-desktop.svg">

## Load Dataset
- Uses the Korean speech dataset zeroth_korean

Install wandb and log in

In [None]:
!pip install wandb

import wandb
wandb.login()

Install libraries

In [None]:
!pip install accelerate
!pip install datasets
!pip install transformers==4.26.1
!pip install librosa
!pip install evaluate
!pip install jiwer

### Load dataset

In [None]:
from datasets import load_dataset

train_dataset = load_dataset("kresnik/zeroth_korean", split='train[:1800]')
test_dataset = load_dataset("kresnik/zeroth_korean", split='test[:200]')

In [None]:
print(train_dataset)
print(test_dataset)

### Remove unnecessary columns

In [None]:
train_ds = train_dataset.remove_columns(["speaker_id", "chapter_id", "id"])
test_ds = test_dataset.remove_columns(["speaker_id", "chapter_id", "id"])

In [None]:
print(train_ds)
print(test_ds)

## Prepare Feature Extractor, Tokenizer and Data
An ASR pipeline consists of three parts:
1. A feature extractor that pre-processes the raw audio inputs
2. The model, which performs the sequence-to-sequence mapping
3. A tokenizer that post-processes the model outputs into text

### Load WhisperFeatureExtractor

1. Pads or truncates the audio input to 30 seconds
    - Audio shorter than 30 s is padded to 30 s with silence (zeros); audio longer than 30 s is truncated to 30 s.
2. Converts the audio input to a log-Mel spectrogram
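The padding/truncation step can be sketched with NumPy as below. This is an illustration, not the `WhisperFeatureExtractor` code: `pad_or_trim` is a hypothetical helper, and the constants (16 kHz input, 30 s window, 160-sample hop) are the values we understand Whisper to use, giving 3,000 spectrogram frames per window.

```python
import numpy as np

SAMPLE_RATE = 16_000                      # Whisper expects 16 kHz audio
CHUNK_SECONDS = 30                        # fixed input window length
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS   # 480,000 samples per window
HOP_LENGTH = 160                          # assumed STFT hop (10 ms at 16 kHz)


def pad_or_trim(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad (silence) or truncate a 1-D waveform to exactly `length` samples."""
    if audio.shape[0] >= length:
        return audio[:length]
    return np.pad(audio, (0, length - audio.shape[0]))  # pads with zeros by default


short = pad_or_trim(np.ones(5 * SAMPLE_RATE, dtype=np.float32))       # 5 s clip
long_clip = pad_or_trim(np.ones(45 * SAMPLE_RATE, dtype=np.float32))  # 45 s clip
print(short.shape, long_clip.shape)   # (480000,) (480000,)
print(N_SAMPLES // HOP_LENGTH)        # 3000 frames per 30 s window
```

Both clips end up with the same length, which is why every `input_features` entry has the same shape regardless of the original duration.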

In [None]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

### Load WhisperTokenizer

The Whisper model outputs a sequence of token ids; the tokenizer maps each token id to its corresponding text string.

The target language and task can be passed to the tokenizer.
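As a rough illustration of what passing `language` and `task` changes, the sketch below builds the special-token layout the Whisper tokenizer is expected to wrap around each transcription: a prefix of start, language, task and no-timestamps tokens, then the text, then `<|endoftext|>`. It works on token strings rather than ids, and `whisper_label_layout` and the `LANG_CODES` subset are hypothetical helpers, not part of 🤗 Transformers.

```python
from typing import List

# Hypothetical subset of the language-name -> language-code mapping.
LANG_CODES = {"korean": "ko", "english": "en"}


def whisper_label_layout(text: str, language: str = "korean", task: str = "transcribe") -> List[str]:
    """Return the expected token layout (as strings) for one training target."""
    prefix = [
        "<|startoftranscript|>",          # decoder start token
        f"<|{LANG_CODES[language]}|>",    # language token
        f"<|{task}|>",                    # task token
        "<|notimestamps|>",               # no timestamp prediction
    ]
    return prefix + [text, "<|endoftext|>"]


print(whisper_label_layout("안녕하세요"))
```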

In [None]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean", task="transcribe")

### Combine To Create A WhisperProcessor

The feature extractor and tokenizer are wrapped into a single `WhisperProcessor` class, so training only needs the `processor` and the `model`.

In [None]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="korean", task="transcribe")

### Prepare Data

In [None]:
train_ds[0]

Resample to 16kHz

`cast_column`: rather than changing the audio in place, this tells the Dataset to resample each audio sample on the fly the first time it is loaded.
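The lazy behaviour can be pictured with the toy container below: the stored clips are never modified, and resampling happens only inside `__getitem__`. `LazyResampledClips` and `naive_resample` are hypothetical; the resampler is plain linear interpolation without an anti-aliasing filter, chosen only for brevity, and is not what 🤗 Datasets uses internally.

```python
import numpy as np


def naive_resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Linear-interpolation resampler (illustration only; no anti-aliasing filter)."""
    n_out = int(round(audio.shape[0] * target_sr / orig_sr))
    t_in = np.arange(audio.shape[0]) / orig_sr    # timestamps of input samples
    t_out = np.arange(n_out) / target_sr          # timestamps of output samples
    return np.interp(t_out, t_in, audio)


class LazyResampledClips:
    """Hypothetical stand-in for an audio column cast to a new sampling rate."""

    def __init__(self, clips, target_sr: int):
        self.clips = clips            # list of (array, original_sr); never modified
        self.target_sr = target_sr

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> dict:
        array, sr = self.clips[idx]
        if sr != self.target_sr:
            array = naive_resample(array, sr, self.target_sr)  # resample only on access
        return {"array": array, "sampling_rate": self.target_sr}


clips = [(np.zeros(48_000), 48_000)]      # one second of 48 kHz silence
ds = LazyResampledClips(clips, 16_000)
print(ds[0]["array"].shape, clips[0][0].shape)  # (16000,) (48000,)
```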

In [None]:
from datasets import Audio

train_ds = train_ds.cast_column("audio", Audio(sampling_rate=16000))
test_ds = test_ds.cast_column("audio", Audio(sampling_rate=16000))

Reloading the audio sample at index 0 now returns it resampled to 16kHz.

In [None]:
train_ds[0]

Inspect an audio sample

In [None]:
import IPython.display as ipd
import random

# pick a random training sample and decode it only once
rand_int = random.randint(0, len(train_ds) - 1)
sample = train_ds[rand_int]
audio = sample["audio"]

ipd.display(ipd.Audio(data=audio["array"], autoplay=True, rate=audio["sampling_rate"]))

print("Target text:", sample["text"])
print("Input array shape:", audio["array"].shape)
print("Sampling rate:", audio["sampling_rate"])
print("Input duration (s):", audio["array"].shape[0] / audio["sampling_rate"])

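1. Load the audio data by accessing `batch["audio"]`; it is resampled to 16kHz on load.

2. Compute the `input_features` (the log-Mel spectrogram) from the loaded audio array with the feature extractor.

3. Encode the transcription into label ids with the tokenizer.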

In [None]:
def prepare_dataset(batch):
    # load the audio (already resampled to 16kHz by cast_column)
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

In [None]:
train_ds = train_ds.map(prepare_dataset, remove_columns=train_ds.column_names, num_proc=4)
test_ds = test_ds.map(prepare_dataset, remove_columns=test_ds.column_names, num_proc=4)

In [None]:
train_ds[0]

In [None]:
test_ds[0].keys()

## Training and Evaluation

Training with the 🤗 Trainer involves four steps:

1. Define a data collator: the data collator takes the pre-processed data and prepares PyTorch tensors the model can consume.

2. Evaluation metrics: during evaluation the model is scored with the CER metric, computed by a `compute_metrics` function.

3. Load a pre-trained checkpoint: load a pre-trained checkpoint and configure it correctly for training.

4. Define the training configuration: define the training schedule for the 🤗 Trainer.

### Define a Data Collator

1. input_features
    - `input_features` have already been padded to 30 s and converted to log-Mel spectrograms by the feature extractor.
    - Convert `input_features` to PyTorch tensors (`return_tensors="pt"`).
2. labels
    - `labels` are still un-padded.
    - Pad `labels` to the maximum sequence length in the batch.
    - Convert `labels` to PyTorch tensors (`return_tensors="pt"`).
    - Replace padding tokens with -100 so they are ignored when computing the loss.
    - The model prepends the decoder start token (`<|startoftranscript|>`) itself during training, so if the tokenizer has already added it, cut it from the start of the sequence.
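The label handling can be checked in isolation with plain PyTorch and toy token ids. `pad_labels` is a hypothetical helper that mirrors the collator's label logic; `PAD_ID` and `START_ID` are assumptions (we believe Whisper pads with `<|endoftext|>` = 50257 and uses `<|startoftranscript|>` = 50258), and the token ids 10, 11, 12, 20 are made up.

```python
from typing import List

import torch

PAD_ID = 50257    # assumed: Whisper pads with <|endoftext|>
START_ID = 50258  # assumed: id of <|startoftranscript|>


def pad_labels(label_seqs: List[List[int]], pad_id: int = PAD_ID, start_id: int = START_ID) -> torch.Tensor:
    batch_size = len(label_seqs)
    max_len = max(len(seq) for seq in label_seqs)
    labels = torch.full((batch_size, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    for i, seq in enumerate(label_seqs):
        labels[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, : len(seq)] = 1
    # Padding positions become -100 so the cross-entropy loss ignores them.
    labels = labels.masked_fill(attention_mask.ne(1), -100)
    # Drop the start token if every row begins with it; the model re-adds it when shifting labels.
    if bool((labels[:, 0] == start_id).all()):
        labels = labels[:, 1:]
    return labels


toy = [[START_ID, 10, 11, 12], [START_ID, 20]]
print(pad_labels(toy).tolist())  # [[10, 11, 12], [20, -100, -100]]
```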

In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

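        # the tokenizer already prepends <|startoftranscript|>, and the model prepends
        # the decoder start token again when shifting labels, so cut it here
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]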

        batch["labels"] = labels

        return batch

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

- Character Error Rate(CER)
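CER is the character-level edit distance (substitutions + deletions + insertions) divided by the number of reference characters. The from-scratch sketch below illustrates the definition with a standard Levenshtein dynamic program, summed over the corpus. `levenshtein` and `corpus_cer` are hypothetical helpers; scores may differ slightly from `evaluate`/`jiwer` depending on how whitespace and normalization are handled.

```python
from typing import List


def levenshtein(ref: str, hyp: str) -> int:
    """Minimum number of character substitutions, deletions and insertions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (free if characters match)
            ))
        prev = curr
    return prev[-1]


def corpus_cer(references: List[str], predictions: List[str]) -> float:
    errors = sum(levenshtein(r, p) for r, p in zip(references, predictions))
    total = sum(len(r) for r in references)
    return errors / total


print(corpus_cer(["안녕하세요"], ["안녕하세유"]))  # 0.2 (1 substitution / 5 characters)
```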

In [None]:
import evaluate

metric = evaluate.load("cer")

In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

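    # replace -100 with the pad_token_id so the ids can be decoded
    # (predictions may also contain -100 when batches are padded and concatenated during evaluation)
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_ids[pred_ids == -100] = tokenizer.pad_token_id

    # decode token ids to strings, dropping special tokens such as the language/task prefix
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)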

    cer = metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

### Load a Pre-Trained Checkpoint

- Load the pre-trained `whisper-tiny` checkpoint
- Modify `model.config` so that the model generates output sequences freely (no forced or suppressed tokens)

In [None]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
)

In [None]:
model.config.forced_decoder_ids = None      # no tokens are forced as decoder outputs 
model.config.suppress_tokens = []           # no tokens are suppressed during generation 

### Define the Training Configuration

Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-tiny-korean",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    report_to="wandb",
    run_name="whisper-tiny-korean",
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
)
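To see what these settings mean for the 1,800-example training split, the arithmetic below estimates the effective batch size and the number of passes over the data. It assumes a single GPU and approximates how the Trainer counts optimizer updates per epoch (dataloader batches divided by the accumulation steps); the variable names are just for illustration.

```python
import math

per_device_train_batch_size = 8
gradient_accumulation_steps = 2
num_devices = 1            # assumption: single GPU
num_train_examples = 1800  # train[:1800]
max_steps = 4000

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
batches_per_epoch = math.ceil(num_train_examples / (per_device_train_batch_size * num_devices))
updates_per_epoch = batches_per_epoch // gradient_accumulation_steps
approx_epochs = max_steps / updates_per_epoch

print(effective_batch_size, batches_per_epoch, updates_per_epoch, round(approx_epochs, 1))
# 16 225 112 35.7
```

With only 1,800 examples, `max_steps=4000` amounts to roughly 36 epochs, which is why evaluating every 100 steps and keeping the best checkpoint by CER is useful here.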

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
processor.save_pretrained(training_args.output_dir)

In [None]:
trainer.train()

Gradient checkpointing is incompatible with the decoder's key/value cache, so `model.config.use_cache` should be `False` during training. For inference, make sure to set it back to `True`.

In [None]:
wandb.finish()

## Inference
- Inference with a Whisper-medium checkpoint fine-tuned on zeroth_korean: spow12/whisper-medium-zeroth_korean [link](https://huggingface.co/spow12/whisper-medium-zeroth_korean)

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import soundfile as sf
import torch
from jiwer import cer

In [None]:
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="ko", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained('spow12/whisper-medium-zeroth_korean').to('cuda')

In [None]:
ds = load_dataset("kresnik/zeroth_korean", "clean")

test_ds = ds['test']

In [None]:
def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

test_ds = test_ds.map(map_to_array)

In [None]:
test_ds

In [None]:
def map_to_pred(batch):
    input_features = processor(batch["speech"], sampling_rate=16000, return_tensors="pt").input_features.cuda()
    
    predicted_ids = model.generate(input_features)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    batch["transcription"] = transcription
    
    return batch

result = test_ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=["speech"])

In [None]:
ipd.display(ipd.Audio(result[0]['file']))
print("ref  : ", result[0]["text"])
print("trans: ", result[0]['transcription'])

In [None]:
print("CER:", cer(result["text"], result["transcription"]))