In [1]:
import os
import torch
import librosa
import torchaudio
import pandas as pd
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments
from datasets import Dataset, load_metric
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer


In [2]:
def read_prompts(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                data.append((parts[0], parts[1]))
    return pd.DataFrame(data, columns=['filename', 'transcript'])

def load_data(base_path):
    train_path = os.path.join(base_path, 'train')
    test_path = os.path.join(base_path, 'test')
    
    train_prompts = read_prompts(os.path.join(train_path, 'prompts.txt'))
    test_prompts = read_prompts(os.path.join(test_path, 'prompts.txt'))
    
    def create_filepath(filename, base_path):
        speaker_id = filename.split('_')[0]
        return os.path.join(base_path, 'waves', speaker_id, filename + '.wav')
    
    train_prompts['filepath'] = train_prompts['filename'].apply(lambda x: create_filepath(x, train_path))
    test_prompts['filepath'] = test_prompts['filename'].apply(lambda x: create_filepath(x, test_path))
    
    return train_prompts, test_prompts

In [3]:
base_path = 'vivos'
train_data, test_data = load_data(base_path)

In [4]:
train_data

Unnamed: 0,filename,transcript,filepath
0,VIVOSSPK01_R001,KHÁCH SẠN,vivos\train\waves\VIVOSSPK01\VIVOSSPK01_R001.wav
1,VIVOSSPK01_R002,CHỈ BẰNG CÁCH LUÔN NỖ LỰC THÌ CUỐI CÙNG BẠN MỚ...,vivos\train\waves\VIVOSSPK01\VIVOSSPK01_R002.wav
2,VIVOSSPK01_R003,TRONG SỐ CÁC QUỐC GIA CÔNG NGHIỆP PHÁT TRIỂN,vivos\train\waves\VIVOSSPK01\VIVOSSPK01_R003.wav
3,VIVOSSPK01_R004,ANH ĐÃ NHÌN THẤY TRONG NHỮNG LẢI NHẢI DÔNG DÀI...,vivos\train\waves\VIVOSSPK01\VIVOSSPK01_R004.wav
4,VIVOSSPK01_R005,KHỦNG HOẢNG MÔI TRƯỜNG CẦN ĐƯỢC NGĂN CHẶN,vivos\train\waves\VIVOSSPK01\VIVOSSPK01_R005.wav
...,...,...,...
11655,VIVOSSPK46_296,GIỐNG NHƯ NGHĨA CHÚ BÉ BÌNH ĂN NGỦ CÙNG THẾ GI...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_296.wav
11656,VIVOSSPK46_297,TIẾP TỤC CHẤM MỘT ÍT NƯỚC BÓNG LÊN PHẦN GIỮA C...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_297.wav
11657,VIVOSSPK46_298,NHỮNG LẦN ĐÓ CON TỦI THÂN TẠI SAO BA MẸ LẠI LẠ...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_298.wav
11658,VIVOSSPK46_299,TRONG NHỮNG NĂM ĐẦU TIÊN CHẮC CHẮN NĂNG SUẤT L...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_299.wav


In [5]:
train_data.tail()

Unnamed: 0,filename,transcript,filepath
11655,VIVOSSPK46_296,GIỐNG NHƯ NGHĨA CHÚ BÉ BÌNH ĂN NGỦ CÙNG THẾ GI...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_296.wav
11656,VIVOSSPK46_297,TIẾP TỤC CHẤM MỘT ÍT NƯỚC BÓNG LÊN PHẦN GIỮA C...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_297.wav
11657,VIVOSSPK46_298,NHỮNG LẦN ĐÓ CON TỦI THÂN TẠI SAO BA MẸ LẠI LẠ...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_298.wav
11658,VIVOSSPK46_299,TRONG NHỮNG NĂM ĐẦU TIÊN CHẮC CHẮN NĂNG SUẤT L...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_299.wav
11659,VIVOSSPK46_300,MỘT SỐ KHÁC TỚI HỎI THỦ TỤC SAU ĐÓ RA VỀ VÀ CH...,vivos\train\waves\VIVOSSPK46\VIVOSSPK46_300.wav


In [6]:
filepaths = train_data['filepath'].tolist()
transcripts = train_data['transcript'].tolist()

In [7]:
from transformers import Wav2Vec2Processor

model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)


Some weights of the model checkpoint at nguyenvulebinh/wav2vec2-base-vietnamese-250h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at nguyenvulebinh/wav2vec2-base-vietnamese-250h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You

In [8]:
def speech_file_to_array_fn(path):
    speech_array, sampling_rate = torchaudio.load(path)
    return speech_array.squeeze().numpy(), sampling_rate

data = {'path': filepaths, 'transcript': transcripts}
dataset = Dataset.from_dict(data)

In [9]:
def prepare_dataset(batch):
    speech_array, sampling_rate = speech_file_to_array_fn(batch["path"])
    batch["input_values"] = processor(speech_array, sampling_rate=sampling_rate).input_values[0]
    batch["labels"] = processor.tokenizer(batch["transcript"]).input_ids
    return batch

dataset = dataset.map(prepare_dataset, remove_columns=["path"])

Map:   0%|          | 0/11660 [00:00<?, ? examples/s]

In [10]:
def collate_fn(batch):
    input_values = [item['input_values'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # Padding input_values
    input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(i, dtype=torch.float32) for i in input_values], batch_first=True)
    
    # Padding labels
    label_lengths = [len(l) for l in labels]
    max_length = max(label_lengths)
    padded_labels = torch.zeros(len(labels), max_length, dtype=torch.int64)
    
    for i, label in enumerate(labels):
        padded_labels[i, :label_lengths[i]] = torch.tensor(label, dtype=torch.int64)
    
    return {'input_values': input_values, 'labels': padded_labels}

train_dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)


In [11]:
small_dataset = dataset.shuffle(seed=42).select(range(500)) 

In [12]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16,
    num_train_epochs=1, 
    save_steps=400,
    save_total_limit=2,
    logging_dir="./logs",
)

# Tạo Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=small_dataset,
    eval_dataset=small_dataset,
)

trainer.train()

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.1458837985992432, 'eval_runtime': 1310.3303, 'eval_samples_per_second': 0.382, 'eval_steps_per_second': 0.024, 'epoch': 1.0}
{'train_runtime': 7357.6522, 'train_samples_per_second': 0.068, 'train_steps_per_second': 0.004, 'train_loss': 3.3344829082489014, 'epoch': 1.0}


TrainOutput(global_step=32, training_loss=3.3344829082489014, metrics={'train_runtime': 7357.6522, 'train_samples_per_second': 0.068, 'train_steps_per_second': 0.004, 'total_flos': 4.3199022981888e+16, 'train_loss': 3.3344829082489014, 'epoch': 1.0})

In [13]:
# Đánh giá mô hình trên tập đánh giá
metrics = trainer.evaluate()
print(metrics)


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.1458837985992432, 'eval_runtime': 1330.522, 'eval_samples_per_second': 0.376, 'eval_steps_per_second': 0.024, 'epoch': 1.0}


In [21]:
from transformers import Wav2Vec2Processor

model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)


Some weights of the model checkpoint at nguyenvulebinh/wav2vec2-base-vietnamese-250h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at nguyenvulebinh/wav2vec2-base-vietnamese-250h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You

In [30]:
import soundfile as sf

def predict(path):
    speech_array, sampling_rate = speech_file_to_array_fn(path)
    
    input_values = processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values
    
    with torch.no_grad():
        logits = model(input_values).logits
    
    predicted_ids = torch.argmax(logits, dim=-1)
    
    transcription = processor.decode(predicted_ids[0])
    
    return transcription

In [32]:
audio_file = 'vivos/test/waves/VIVOSDEV01/VIVOSDEV01_R002.wav'

transcription = predict(audio_file)
print(f"Transcription: {transcription}")

Transcription: tiếng cộc cạch khuẩn lợi của những cốp sắc
