In [None]:
from transformers import WhisperProcessor
from utils.data_dataset import KruWhisperDataset, DataCollatorSpeechSeq2SeqWithPadding
from torch.utils.data import DataLoader
# 1) Load the Whisper processor (feature extractor + tokenizer).
#    `language`/`task` configure the tokenizer's prefix tokens for Korean transcription.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="ko",
    task="transcribe",
)

feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer


# 2) Load the train/test datasets from their CSV manifests.
DATA_DIR = "/workspace/kru_data"
SAMPLING_RATE = 16000  # Whisper's feature extractor expects 16 kHz audio

train_dataset = KruWhisperDataset(
    csv_path=f"{DATA_DIR}/train.csv",
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    sampling_rate=SAMPLING_RATE,
)

test_dataset = KruWhisperDataset(
    csv_path=f"{DATA_DIR}/test.csv",
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    sampling_rate=SAMPLING_RATE,
)

# 3) Data collator (from the official HF Whisper fine-tuning recipe).
# Whisper's decoder starts from <|startoftranscript|>. This is NOT
# tokenizer.bos_token_id, which is <|endoftext|> for WhisperTokenizer.
# NOTE(review): the collator presumably compares the first label token against
# this id to strip the prefix the tokenizer already prepends; with the wrong id
# the prefix is never stripped. Confirm in utils.data_dataset.
decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=decoder_start_token_id,
)

## Metric 정의

In [None]:
from utils.metrics import compute_metrics
import numpy as np

# Example: quick sanity check of compute_metrics on dummy token ids.
# predictions and label_ids must be numpy arrays for the metrics code to work.
# pred = type("Pred", (), {})()
# pred.predictions = np.array([[1, 2, 3], [4, 5, 6]])
# pred.label_ids = np.array([[1, 2, 3], [1, 2, 3]])

# print(compute_metrics(pred, tokenizer, "wer"))
# print(compute_metrics(pred, tokenizer, "cer"))

## 훈련 매개변수 및 모델 정의

In [None]:
import os

# Restrict training to a single GPU. CUDA reads CUDA_VISIBLE_DEVICES only when it
# is first initialised, so this must run before anything touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Disable NCCL's peer-to-peer and InfiniBand transports (only relevant for multi-GPU runs).
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

from transformers import Seq2SeqTrainingArguments
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Pin generation to Korean transcription. This way evaluation with
# predict_with_generate does not depend on automatic language detection.
# Also clear the legacy forced_decoder_ids so they cannot conflict with language/task.
model.generation_config.language = "korean"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

training_args = Seq2SeqTrainingArguments(
    # Local directory for checkpoints and logs (push_to_hub is disabled below).
    output_dir="/workspace/kru_data/results/whisper-small-ko",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,

    # Mixed precision: keep fp16 off and use bf16 instead (requires a bf16-capable GPU).
    fp16=False,
    bf16=True,  # chosen for RTX 4000-series GPUs
    # Enabling gradient checkpointing raised an error in this setup, so leave it off.
    gradient_checkpointing=False,

    # Evaluate by generating transcripts, so WER/CER are computed on decoded text.
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,

    # With load_best_model_at_end, save_steps must be a multiple of eval_steps.
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,

    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",  # looked up as "eval_cer" in the evaluation metrics
    greater_is_better=False,      # lower CER is better
    push_to_hub=False,            # do not upload model checkpoints to the Hub
)


from transformers import Seq2SeqTrainer


def compute_eval_metrics(eval_pred):
    """Adapt utils.metrics.compute_metrics to the Trainer's one-argument callback.

    Seq2SeqTrainer calls ``compute_metrics(eval_pred)`` with a single
    EvalPrediction. utils.metrics.compute_metrics, however, is invoked as
    ``compute_metrics(pred, tokenizer, metric_name)`` (see the example in the
    metrics cell). Passing it to the Trainer directly would therefore fail at the
    first evaluation.

    Returns a dict with WER and CER entries. The "cer" key is needed because
    training_args.metric_for_best_model is "cer".
    """
    metrics = {}
    for metric_name in ("wer", "cer"):
        result = compute_metrics(eval_pred, tokenizer, metric_name)
        # NOTE(review): assumes compute_metrics returns either a dict keyed by the
        # metric name or a bare number. Confirm in utils.metrics.
        if isinstance(result, dict):
            metrics.update(result)
        else:
            metrics[metric_name] = result
    return metrics


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_eval_metrics,
    # NOTE(review): newer transformers releases rename this argument to `processing_class`.
    tokenizer=processor.tokenizer,
)

# The Trainer saves only what is passed as `tokenizer`. Save the full processor
# (feature extractor + tokenizer) as well, so the output dir can be reloaded with
# WhisperProcessor.from_pretrained.
processor.save_pretrained(training_args.output_dir)

# Training
trainer.train()