In [None]:
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio
!pip install pyannote.audio
!pip install jiwer


In [None]:
# Install cpu/cuda pytorch (>=1.9) dependency from pytorch.org, e.g.:
# NOTE(review): the extra index passed with -f is the CPU wheel list, while the training
# arguments later in this notebook request fp16 — confirm which torch build actually gets used.
!pip install torch torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html
# DeepFilterNet provides the `df.enhance` module used for denoising below.
!pip install deepfilternet
!pip install deepfilternet[train]

# bitsandbytes is only referenced by the commented-out 8-bit loading line in the model cell.
!pip install -U bitsandbytes

In [None]:
import os
import pandas as pd
import torchaudio


# Root directory that get_audio_files() walks to find .wav files.
data_dir = '/kaggle/input/data-100h/vlsp2020_train_set_02' 

# Upper bound on the number of audio files used; also used for the split boundaries.
MAX_SAMPLES = 10000
# Fractions used by create_dataset() to assign samples to the train and validation splits.
TRAIN_SIZE = 0.7
VAL_SIZE = 0.2
# NOTE(review): TEST_SIZE is not referenced in the visible code; the test split is whatever remains.
TEST_SIZE = 0.1
# NOTE(review): `rows` is not referenced anywhere in the visible notebook.
rows = []

from df.enhance import enhance, init_df, load_audio
import torch

# Sample rate (Hz) that enhanced audio is resampled to.
TARGET_SR = 16000

# Initialise DeepFilterNet; the third return value of init_df() is discarded.
df_model, df_state, _ = init_df()
# Resampler from the DeepFilterNet state's sample rate to TARGET_SR, built once and reused per file.
resampler = torchaudio.transforms.Resample(orig_freq=df_state.sr(), new_freq=TARGET_SR)

def enhance_waveform(noisy_path):
    """Denoise one audio file and resample the result to TARGET_SR.

    The file is loaded at the DeepFilterNet state's sample rate, passed through
    `enhance` with the module-level model/state, and then run through the
    module-level `resampler`.

    Args:
        noisy_path: Path of the audio file to enhance.

    Returns:
        A tuple `(resampled_audio, TARGET_SR)`.
    """
    # The second value returned by load_audio is not needed here.
    noisy_audio, _ = load_audio(noisy_path, sr=df_state.sr())
    denoised = enhance(df_model, df_state, noisy_audio)
    return resampler(denoised), TARGET_SR

# Display the sample rate reported by the DeepFilterNet state (cell output only).
df_state.sr()


In [None]:
import os

from huggingface_hub import login

# Never paste an access token into the notebook: it would be saved with the file and its outputs.
# The token is read from the HF_TOKEN environment variable instead; when the variable is unset
# (or empty) `None` is passed so that `login` falls back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor

# Feature extractor used by prepare_dataset() to turn waveforms into model input features.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")

# Tokenizer used for label encoding and for decoding in compute_metrics().
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="vietnamese", task="transcribe")

# Processor handed to the data collator and saved alongside the model.
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="vietnamese", task="transcribe")

In [None]:
import evaluate

# Word error rate metric consumed by compute_metrics().
metric = evaluate.load("wer")

In [None]:
from transformers import WhisperForConditionalGeneration

# Initialise the base model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")

# Generation configuration: clear the forced decoder ids and the suppressed-token list
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Language and task used when generating
model.generation_config.language = "vietnamese"
model.generation_config.task = "transcribe"


In [None]:
from peft import LoraConfig, get_peft_model

# LoRA configuration
lora_config = LoraConfig(
    r=8,  # LoRA rank; can be increased if the GPU has memory to spare
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # attention projection modules that receive adapters
    lora_dropout=0.05,
    bias="none",
    # task_type=TaskType.SEQ_2_SEQ_LM
)

# Apply LoRA to the model.
# NOTE: this rebinds `model`, so re-running the cell wraps an already wrapped model;
# re-run the model-loading cell first.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
def get_audio_files():
    """Collect matching audio/transcript path pairs found under `data_dir`.

    Every file whose name ends with '.wav' is paired with a '.txt' path that has
    the same directory and stem. The pairs are sorted by audio path and truncated
    to at most MAX_SAMPLES entries.

    Returns:
        A tuple `(audio_files, transcript_files)` of equally long lists where
        `transcript_files[i]` is the transcript path for `audio_files[i]`.
        Existence of the transcript file is not checked here.
    """
    wav_ext = '.wav'
    pairs = []

    for root, _dirs, files in os.walk(data_dir):
        for file in files:
            if file.endswith(wav_ext):
                # Swap only the trailing extension; str.replace would also rewrite a
                # '.wav' occurring in the middle of the file name.
                transcript_name = file[:-len(wav_ext)] + '.txt'
                pairs.append((os.path.join(root, file), os.path.join(root, transcript_name)))

    # Sort the pairs together so each audio file stays aligned with its transcript;
    # sorting two separate lists can order '.wav' and '.txt' names differently.
    pairs.sort()

    # Limit to MAX_SAMPLES
    pairs = pairs[:MAX_SAMPLES]

    audio_files = [audio_path for audio_path, _ in pairs]
    transcript_files = [transcript_path for _, transcript_path in pairs]

    return audio_files, transcript_files

def load_sample(audio_path, transcript_path):
    """Build one dataset record from an audio file and its transcript file.

    Args:
        audio_path: Path passed to enhance_waveform().
        transcript_path: Path of the UTF-8 text file holding the transcript.

    Returns:
        A dict with keys 'audio' (NumPy array of the enhanced waveform),
        'transcript' (stripped text) and 'sample_rate', or None when anything
        fails; the failure is reported on stdout instead of being raised.
    """
    try:
        # Enhanced waveform first, then the transcript text.
        enhanced_audio, sample_rate = enhance_waveform(audio_path)

        with open(transcript_path, 'r', encoding='utf-8') as handle:
            text = handle.read().strip()

        record = {
            'audio': enhanced_audio.numpy(),
            'transcript': text,
            'sample_rate': sample_rate,
        }
    except Exception as e:
        # Best effort: a broken sample is skipped by the caller.
        print(f"Error processing {audio_path}: {str(e)}")
        record = None
    return record

def create_dataset():
    """Create train/val/test dataset splits from the discovered audio files.

    Files are assigned by position in the sorted file list: the first TRAIN_SIZE
    fraction goes to 'train', the next VAL_SIZE fraction to 'val' and the rest to
    'test'. Samples that load_sample() could not load are skipped.

    Returns:
        A DatasetDict with the keys 'train', 'val' and 'test'.
    """
    # Get audio files
    audio_files, transcript_files = get_audio_files()

    # Derive the split boundaries from the number of files actually found. Using
    # MAX_SAMPLES here would leave 'val'/'test' empty whenever fewer files exist.
    total = len(audio_files)
    train_end = round(TRAIN_SIZE * total)
    val_end = round((TRAIN_SIZE + VAL_SIZE) * total)

    train_data = []
    val_data = []
    test_data = []

    # Process files with progress bar
    for i in tqdm.tqdm(range(total)):
        sample = load_sample(audio_files[i], transcript_files[i])
        if sample is None:
            continue
        if i < train_end:
            train_data.append(sample)
        elif i < val_end:
            val_data.append(sample)
        else:
            test_data.append(sample)

    # The train split is converted in chunks and concatenated afterwards
    # (presumably to keep each conversion small — TODO confirm).
    chunk_size = 1000
    chunks = [Dataset.from_list(train_data[i:i+chunk_size]) for i in range(0, len(train_data), chunk_size)]

    vlsp_dict = DatasetDict()
    # concatenate_datasets() rejects an empty list, so fall back to an empty Dataset.
    vlsp_dict["train"] = concatenate_datasets(chunks) if chunks else Dataset.from_list(train_data)
    vlsp_dict["val"] = Dataset.from_list(val_data)
    vlsp_dict["test"] = Dataset.from_list(test_data)

    # Print dataset statistics
    print(f"Train set size: {len(vlsp_dict['train'])}")
    print(f"Validation set size: {len(vlsp_dict['val'])}")
    print(f"Test set size: {len(vlsp_dict['test'])}")

    return vlsp_dict

In [None]:
import warnings
from datasets import Dataset, DatasetDict, concatenate_datasets
import tqdm as tqdm
from IPython.display import Audio
# Silence all Python warnings from here on.
warnings.filterwarnings("ignore")

# Enhance every audio file and build the train/val/test splits (processes each file once).
vlsp_dict = create_dataset()

In [None]:
# Display the split summary.
vlsp_dict

In [None]:
from datasets import load_dataset, Audio



def prepare_dataset(batch):
    """Convert one raw sample into the model inputs and label ids.

    Args:
        batch: A single example with the keys 'audio', 'sample_rate' and
            'transcript' as produced by load_sample().

    Returns:
        The same mapping with 'input_features' and 'labels' added.
    """
    # The waveform was already resampled to TARGET_SR by enhance_waveform(), so no
    # resampling happens here. Index 0 selects the first row of the stored array
    # (assumed to be the channel axis — TODO confirm).
    audio = batch["audio"]

    # compute the input features from the audio array
    batch["input_features"] = feature_extractor(audio[0], sampling_rate=batch["sample_rate"]).input_features[0]

    # encode target text to label ids; the length limit is read from the model config,
    # which is available on both the plain and the PEFT-wrapped model
    batch["labels"] = tokenizer(
        batch["transcript"],
        truncation=True,
        max_length=model.config.max_target_positions,
    ).input_ids
    return batch



# Apply prepare_dataset to every split and drop the raw columns.
# NOTE: `vlsp_dict` is rebound, so this cell cannot be re-run without rebuilding the dataset.
vlsp_dict = vlsp_dict.map(prepare_dataset, remove_columns=vlsp_dict.column_names["train"], num_proc=1)

In [None]:
# Display the processed train split.
vlsp_dict["train"]

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into one padded batch.

    Attributes are documented inline; calling the instance with a list of
    examples (each holding 'input_features' and 'labels') returns a dict with
    the padded 'input_features' and the 'labels' tensor.
    """

    # Object exposing `feature_extractor.pad` and `tokenizer.pad`.
    processor: Any
    # Token id that is stripped from the front of the labels when every row starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio side: pad through the feature extractor.
        audio_items = [{"input_features": example["input_features"]} for example in features]
        padded_audio = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad through the tokenizer.
        token_items = [{"input_ids": example["labels"]} for example in features]
        padded_tokens = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Positions outside the attention mask become -100.
        token_ids = padded_tokens["input_ids"]
        is_padding = padded_tokens.attention_mask.ne(1)
        labels = token_ids.masked_fill(is_padding, -100)

        # Drop the leading start token only when every row begins with it.
        every_row_has_start = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if every_row_has_start:
            labels = labels[:, 1:]

        return {
            "input_features": padded_audio["input_features"],
            "labels": labels,
        }

In [None]:
# Collator instance handed to the trainer; the start-token id comes from the model config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
def compute_metrics(pred):
    """Compute the word error rate (as a percentage) for an evaluation result.

    Args:
        pred: Object exposing `predictions` and `label_ids`. `label_ids` is
            modified in place: every -100 entry is replaced by the pad token id.

    Returns:
        A dict `{"wer": value}` where value is 100 times the metric result.
    """
    predicted_ids, reference_ids = pred.predictions, pred.label_ids

    # -100 cannot be decoded, so swap it for the pad token id (in place)
    ignored_positions = reference_ids == -100
    reference_ids[ignored_positions] = tokenizer.pad_token_id

    # decode predictions first, then references, dropping special tokens
    hypotheses, references = (
        tokenizer.batch_decode(ids, skip_special_tokens=True)
        for ids in (predicted_ids, reference_ids)
    )

    score = metric.compute(predictions=hypotheses, references=references)
    return {"wer": 100 * score}

In [None]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Return a copy of `input_ids` moved one position to the right.

    Column 0 of the result holds `decoder_start_token_id`, the last column of the
    input is dropped, and any -100 left in the result is replaced by
    `pad_token_id`. The input tensor itself is not modified.
    """
    # Same shape, dtype and device as the input.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1].clone()

    # -100 marks ignored label positions; it must not reach the decoder as an id
    ignored_positions = shifted == -100
    shifted.masked_fill_(ignored_positions, pad_token_id)

    return shifted

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Training configuration for LoRA fine-tuning.
# NOTE(review): the output directory says "whisper-small" although the checkpoint loaded above
# is whisper-medium; the path is kept unchanged because it also determines the hub repo name.
training_args = Seq2SeqTrainingArguments(
    output_dir="/kaggle/working/whisper-small-vi",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    num_train_epochs=2,
    warmup_steps=500,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing does not require the layer inputs to require grad, so the
    # adapters inside the otherwise frozen model still receive gradients.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    fp16=True,
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    # Evaluation and checkpointing both happen once per epoch, so no step intervals are set
    # (the former save_steps/eval_steps values had no effect with these strategies).
    eval_strategy="epoch",
    save_strategy="epoch",
    # The PEFT wrapper's forward() does not expose the base model's argument names, so the
    # trainer must neither drop dataset columns by signature nor guess the label column.
    remove_unused_columns=False,
    label_names=["labels"],
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=vlsp_dict["train"],
    eval_dataset=vlsp_dict["val"],
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Smoke test: push one training batch through the model before starting the real run.
train_dataloader = trainer.get_train_dataloader()
batch = next(iter(train_dataloader))

print("\nBatch keys:", batch.keys())
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(f"{k} shape:", v.shape)
        print(f"{k} dtype:", v.dtype)

# Prepare decoder input ids
decoder_input_ids = shift_tokens_right(
    batch["labels"],
    model.config.pad_token_id,
    model.config.decoder_start_token_id
)

# Call the module itself (not .forward) so registered hooks run, and skip autograd:
# the loss is only printed, so there is no reason to keep a graph in memory.
with torch.no_grad():
    test_outputs = model(
        input_features=batch["input_features"],
        labels=batch["labels"],
        decoder_input_ids=decoder_input_ids,
        return_dict=True
    )

print("\nTest forward pass loss:", test_outputs.loss)

In [None]:
# Show which keys the collated batch contains.
print(batch.keys())


In [None]:
# Save the processor files into the training output directory.
processor.save_pretrained(training_args.output_dir)

In [None]:
# Start fine-tuning with the configuration defined above.
trainer.train()

In [None]:
# Metadata passed along with the upload.
kwargs = {
    "dataset": "VLSP 10000",  # a 'pretty' name for the training dataset
    "language": "vi",
    "model_name": "Whisper Medium Vi - ASR",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-medium",
    "tasks": "automatic-speech-recognition",
}

# Upload once; the second identical call only repeated the same upload.
trainer.push_to_hub(**kwargs)