In [1]:
import os
import torch
import torchvision
import torchvision.transforms as t
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import jiwer
import cv2
from typing import List
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, Dataset
from datasets import load_dataset, Audio, Dataset
import random 
from torchvision.io import read_image
import time
import torchaudio
from transformers import (Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments,
                          Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, DataCollatorWithPadding, AutoConfig)

## DATA LOADING

In [4]:
seed = 132
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:

train_audio_path = 'G:/data/vivos/train/waves'
train_prompts_path = 'G:/data/vivos/train/prompts.txt'
train_genders_path = 'G:/data/vivos/train/genders.txt'
test_audio_path = 'G:/data/vivos/test/waves'
test_prompts_path = 'G:/data/vivos/test/prompts.txt'
test_genders_path = 'G:/data/vivos/test/genders.txt'


In [6]:
def load_prompts(prompts_path):
    transcripts = []
    with open(prompts_path, 'r', encoding='utf-8') as f:
        for line in f:
            id, text = line.strip().split(' ', 1)
            transcripts.append({'id': id, 'text': text.lower()})
    return pd.DataFrame(transcripts)


train_transcripts = load_prompts(train_prompts_path)
test_transcripts = load_prompts(test_prompts_path)

In [7]:
def get_audio_path(audio_base_path, audio_id):
    speaker = audio_id.split('_')[0]
    return os.path.join(audio_base_path, speaker, audio_id + '.wav')

In [8]:
train_transcripts['audio'] = train_transcripts['id'].apply(lambda x: get_audio_path(train_audio_path, x))
test_transcripts['audio'] = test_transcripts['id'].apply(lambda x: get_audio_path(test_audio_path, x))

In [9]:
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"“%‘”�]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

In [10]:
train_transcripts = train_transcripts.apply(remove_special_characters, axis=1)
test_transcripts = test_transcripts.apply(remove_special_characters, axis=1)

## DATA PIPELINE AND TRAIN TEST SPLIT

In [32]:
train_dataset = Dataset.from_pandas(train_transcripts)
test_dataset = Dataset.from_pandas(test_transcripts)


train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16000))

train_valid_split = train_dataset.train_test_split(test_size=0.1, seed=42)  # 90% train, 10% validation
train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]

In [40]:
processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h") 

In [44]:
from torch.utils.data import DataLoader

def collate_fn(batch, preprocessor, sampling_rate=16000):
    audio_tensors = [torch.tensor(item["audio"]["array"], dtype=torch.float32) for item in batch]
    text_data = [item["text"] for item in batch]

    # Pad audio data
    audio_padded = torch.nn.utils.rnn.pad_sequence(audio_tensors, batch_first=True, padding_value=0)

    # Tokenize text data (convert to input IDs with padding)
    text_encodings = preprocessor(text=text_data, padding=True, return_tensors="pt")
    input_ids = text_encodings.input_ids

    # Ensure that the audio field is passed correctly
    audio_encodings = preprocessor(audio=audio_padded, sampling_rate=sampling_rate, padding=True, return_tensors="pt")

    return {
        "audio": audio_encodings.input_values,
        "input_ids": input_ids,
    }




train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, processor, sampling_rate=16000),
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, processor, sampling_rate=16000),
)



# Check the size of the datasets
print(f"Training Set Size: {len(train_loader.dataset)}")
print(f"Validation Set Size: {len(valid_loader.dataset)}")


Training Set Size: 10494
Validation Set Size: 1166


## MODELS

In [13]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

class VietLip(nn.Module):

    def __init__(self, num_classes, feature_size, hidden_size,
                num_layers, dropout, bidirectional, device='cpu'):
        super(VietLip, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.directions = 2 if bidirectional else 1
        self.device = device
        self.layernorm = nn.LayerNorm(feature_size)
        self.lstm = nn.LSTM(input_size=feature_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=dropout,
                            bidirectional=bidirectional)
        self.classifier = nn.Linear(hidden_size*self.directions, num_classes)

    def _init_hidden(self, batch_size):
        n, d, hs = self.num_layers, self.directions, self.hidden_size
        return (torch.zeros(n*d, batch_size, hs).to(self.device),
                torch.zeros(n*d, batch_size, hs).to(self.device))

    def forward(self, x):
        # x.shape => seq_len, batch, feature
        x = self.layernorm(x)
        hidden = self._init_hidden(x.size()[1])
        out, (hn, cn) = self.lstm(x, hidden)
        out = self.classifier(hn)
        return out


## TRAINING

In [14]:
from torch.optim.lr_scheduler import LambdaLR
from torchaudio.transforms import Resample

class CTCLoss(nn.Module):
    def __init__(self):
        super(CTCLoss, self).__init__()

    def forward(self, y_true, y_pred):
        batch_size = y_true.size(0)
        input_length = y_pred.size(1)
        label_length = y_true.size(1)

        # Create length tensors
        input_length = torch.full((batch_size,), input_length, dtype=torch.long)
        label_length = torch.full((batch_size,), label_length, dtype=torch.long)

        loss = nn.CTCLoss()(y_pred.log_softmax(2), y_true, input_length, label_length)
        return loss


In [15]:
def set_device():
    if torch.cuda.is_available():
        dev = "cuda"
    else:
        dev = "cpu"
    return torch.device(dev)

In [26]:
def save_checkpoint(model, optimizer, epoch, loss, wer, filename="best_model.pth"):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "wer": wer,
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")



In [27]:
def evaluate_model(model, dataloader, preprocessor, criterion, device):
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["audio"]["array"]
            targets = batch["text"]

            # Preprocess audio
            inputs = [preprocessor(inputs[i], sampling_rate=16000).input_values[0] for i in range(len(inputs))]
            inputs = torch.tensor(inputs).to(device).unsqueeze(1)  # Add batch and channel dims

            # Forward pass
            outputs = model(inputs)
            outputs = outputs.permute(1, 0, 2)  # [seq_len, batch_size, num_classes]

            # Compute loss (optional during evaluation)
            input_lengths = torch.full((outputs.size(1),), outputs.size(0), dtype=torch.long)
            target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)
            targets_encoded = preprocessor(targets, padding=True, return_tensors="pt").input_ids
            targets_encoded = targets_encoded.to(device)

            loss = criterion(outputs, targets_encoded, input_lengths, target_lengths)
            total_loss += loss.item()

            # Decode predictions for WER
            predicted_indices = outputs.argmax(dim=2).cpu().numpy()
            all_predictions.extend(predicted_indices)
            all_targets.extend(targets)

    # Compute WER
    wer = calculate_wer(all_predictions, all_targets)
    return total_loss / len(dataloader), wer

In [47]:
def train_model(
    model,
    train_loader,
    valid_loader,
    optimizer,
    criterion,
    preprocessor,
    num_epochs,
    device,
    example_callback=None,
    checkpoint_path="best_model.pth",
):
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1 if epoch < 30 else 0.1)
    best_wer = float("inf")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch in train_loader:
            
            inputs = batch["audio"]
            targets = batch["text"]

            # Preprocess audio
            inputs = [preprocessor(inputs[i], sampling_rate=16000).input_values[0] for i in range(len(inputs))]
            inputs = torch.tensor(inputs).to(device).unsqueeze(1)  # Add batch and channel dims

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            outputs = outputs.permute(1, 0, 2)  # [seq_len, batch_size, num_classes]

            # Compute loss
            input_lengths = torch.full((outputs.size(1),), outputs.size(0), dtype=torch.long)
            target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)
            targets_encoded = preprocessor(targets, padding=True, return_tensors="pt").input_ids
            targets_encoded = targets_encoded.to(device)

            loss = criterion(outputs, targets_encoded, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Learning rate scheduling
        lr_scheduler.step()

        # Evaluate on validation set
        valid_loss, wer = evaluate_model(model, valid_loader, preprocessor, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}] | Train Loss: {running_loss / len(train_loader):.4f} | "
              f"Validation Loss: {valid_loss:.4f} | WER: {wer:.2f}")

        # Save the best model based on WER
        if wer < best_wer:
            best_wer = wer
            save_checkpoint(model, optimizer, epoch + 1, valid_loss, wer, checkpoint_path)

        # Example callback
        if example_callback:
            example_callback.on_epoch_end(epoch)


In [48]:
def calculate_wer(predictions, references):
    """Computes WER using the jiwer library."""
    pred_texts = ["".join(map(str, decode_predictions(pred))) for pred in predictions]
    ref_texts = ["".join(map(str, decode_predictions(ref))) for ref in references]
    wer = jiwer.wer(ref_texts, pred_texts)
    return wer

In [49]:
import torch.optim as optim

device = set_device()
model = VietLip(1, 40, 256, 1, 0.1,  False).to(device)  
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = CTCLoss()




In [50]:
train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    criterion=criterion,
    preprocessor=processor,
    num_epochs=50,
    device=device,
    checkpoint_path="vietlip_best_model.pth",
)


tensor([[[ 2.0110e-05,  2.0110e-05,  2.0110e-05,  ...,  5.4945e-03,
          -1.6222e-03, -1.0381e-02],
         [-5.4724e-02, -2.5235e-01, -2.0965e-01,  ...,  2.0110e-05,
           2.0110e-05,  2.0110e-05],
         [ 2.0110e-05, -5.2733e-04,  5.6755e-04,  ...,  2.0110e-05,
           2.0110e-05,  2.0110e-05],
         ...,
         [ 2.0110e-05, -5.2733e-04,  5.6755e-04,  ...,  2.0110e-05,
           2.0110e-05,  2.0110e-05],
         [ 2.0110e-05,  5.6755e-04,  2.0110e-05,  ...,  2.0110e-05,
           2.0110e-05,  2.0110e-05],
         [ 2.0110e-05,  2.0110e-05, -5.2733e-04,  ...,  2.0110e-05,
           2.0110e-05,  2.0110e-05]]])


IndexError: too many indices for tensor of dimension 3

NameError: name 'batch' is not defined