In [1]:
!pip install transformers datasets torchaudio jiwer

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting jiwer
  Downloading jiwer-3.0.5-py3-none-any.whl.metadata (2.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
[

In [2]:
#!/bin/bash
!kaggle datasets download kynthesis/vivos-vietnamese-speech-corpus-for-asr

Dataset URL: https://www.kaggle.com/datasets/kynthesis/vivos-vietnamese-speech-corpus-for-asr
License(s): Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
Downloading vivos-vietnamese-speech-corpus-for-asr.zip to /content
100% 1.37G/1.37G [00:17<00:00, 89.8MB/s]
100% 1.37G/1.37G [00:17<00:00, 83.5MB/s]


In [3]:
!unzip vivos-vietnamese-speech-corpus-for-asr.zip -d vivos-vietnamese-speech-corpus-for-asr

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_272.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_273.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_274.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_275.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_276.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_277.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_278.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_279.wav  
  inflating: vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK26/VIVOSSPK26_280.wav  
  inflating: vivo

In [8]:
import os
from datasets import DatasetDict, Dataset

# Function to parse PROMPTS.txt and load data
def load_vivos_subset(subset_path):
    data = {"path": [], "transcription": []}
    prompts_path = os.path.join(subset_path, "prompts.txt")

    with open(prompts_path, "r", encoding="utf-8") as f:
        for line in f:
            # Split the line: first part is file path, the rest is transcription
            file_path, transcription = line.strip().split(" ", 1)

            # Extract the subdirectory name from the file_path
            speaker_folder = file_path.split("_")[0]  # e.g., "VIVOSSPK01"

            # Construct the correct file path
            audio_path = os.path.join(subset_path, "waves", speaker_folder, file_path + ".wav")

            data["path"].append(audio_path)
            data["transcription"].append(transcription)

    return Dataset.from_dict(data)

# Reload the dataset
def load_vivos_dataset(dataset_path):
    return DatasetDict({
        "train": load_vivos_subset(os.path.join(dataset_path, "train")),
        "test": load_vivos_subset(os.path.join(dataset_path, "test")),
    })

# Set dataset path and load
DATASET_PATH = "/content/vivos-vietnamese-speech-corpus-for-asr/vivos"
vivos = load_vivos_dataset(DATASET_PATH)

# Verify loaded data
print(vivos)
print(vivos["test"][0])

DatasetDict({
    train: Dataset({
        features: ['path', 'transcription'],
        num_rows: 11660
    })
    test: Dataset({
        features: ['path', 'transcription'],
        num_rows: 760
    })
})
{'path': '/content/vivos-vietnamese-speech-corpus-for-asr/vivos/test/waves/VIVOSDEV02/VIVOSDEV02_R106.wav', 'transcription': 'TRỞ NÊN THỤ ĐỘNG'}


In [5]:
# Import required libraries
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset, DatasetDict
from torchaudio.transforms import Resample
from jiwer import wer
import numpy as np
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [6]:
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer  # For calculating Word Error Rate

# Set device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Pre-trained Model and Processor
model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")

# Function to preprocess audio
def preprocess_audio(path):
    # Load audio using torchaudio
    speech, sample_rate = torchaudio.load(path)

    # Resample to 16kHz if not already in the expected sample rate
    if sample_rate != 16000:
        transform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                                   new_freq=16000)
        speech = transform(speech)

    # Convert to a 1D array
    speech = speech.squeeze().numpy()
    return speech

# Test on the first input from the test set
test_sample = vivos["test"][0]  # Replace `vivos` with your DatasetDict variable
audio_path = test_sample["path"]
print(audio_path)
ground_truth = test_sample["transcription"]

# Preprocess the audio
speech = preprocess_audio(audio_path)

# Tokenize and make predictions
input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values.to(device)
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)

# Decode predictions to text
predicted_transcription = processor.batch_decode(predicted_ids)[0]

# Calculate Word Error Rate (WER)
error_rate = wer(ground_truth.lower(), predicted_transcription.lower())

# Output results
print(f"Ground Truth: {ground_truth}")
print(f"Predicted Transcription: {predicted_transcription}")
print(f"WER: {error_rate}")

config.json:   0%|          | 0.00/1.65k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

/content/vivos-vietnamese-speech-corpus-for-asr/vivos/test/waves/VIVOSDEV02/VIVOSDEV02_R106.wav
Ground Truth: TRỞ NÊN THỤ ĐỘNG
Predicted Transcription: trở  nên thụ động
WER: 0.0


In [9]:
vivos = load_vivos_dataset(DATASET_PATH)
for test_sample in vivos["test"]:
    audio_path = test_sample["path"]
    ground_truth = test_sample["transcription"]

    # Preprocess, predict, and evaluate like the second code
    speech = preprocess_audio(audio_path)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values.to(device)
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_transcription = processor.batch_decode(predicted_ids)[0]

    print(f"Ground Truth: {ground_truth}")
    print(f"Predicted: {predicted_transcription}")
    print(f"WER: {wer(ground_truth.lower(), predicted_transcription.lower())}")
    break

Ground Truth: TRỞ NÊN THỤ ĐỘNG
Predicted: trở  nên thụ động
WER: 0.0


In [10]:
# # Clear RAM cache
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [11]:
vivos

DatasetDict({
    train: Dataset({
        features: ['path', 'transcription'],
        num_rows: 11660
    })
    test: Dataset({
        features: ['path', 'transcription'],
        num_rows: 760
    })
})

In [12]:
# Import random for reproducibility
import random

# Define a seed for reproducibility
random.seed(42)

# Sample 1% of train and test splits
vivos = DatasetDict({
    "train": vivos["train"].select(random.sample(range(len(vivos["train"])), int(0.01 * len(vivos["train"])))),
    "test": vivos["test"].select(random.sample(range(len(vivos["test"])), int(0.01 * len(vivos["test"]))))
})

# Print the new dataset sizes
print(vivos)

DatasetDict({
    train: Dataset({
        features: ['path', 'transcription'],
        num_rows: 116
    })
    test: Dataset({
        features: ['path', 'transcription'],
        num_rows: 7
    })
})


In [13]:
vivos["train"][0]

{'path': '/content/vivos-vietnamese-speech-corpus-for-asr/vivos/train/waves/VIVOSSPK43/VIVOSSPK43_017.wav',
 'transcription': 'CHÚNG GIÚP NGĂN NGỪA BỆNH TIM CHỐNG VIÊM KHỚP VÀ KÍCH THÍCH NÃO'}

In [46]:
import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from torchaudio.transforms import Resample
import torchaudio
from jiwer import wer

# Initialize device and models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)

def preprocess_audio(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    resampler = Resample(sampling_rate, 16000)
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    batch["sampling_rate"] = 16000
    return batch

def prepare_dataset(batch):
    # Process audio input
    inputs = processor(batch["speech"],
                      sampling_rate=batch["sampling_rate"],
                      return_tensors="pt",
                      padding=True)

    # Store the original transcription
    batch["input_values"] = inputs.input_values[0]
    batch["labels"] = batch["transcription"]  # Keep the original text
    return batch

def collate_fn(batch):
    # Process input values
    input_values = [torch.tensor(item["input_values"]) for item in batch]
    input_values = processor.pad(
        {"input_values": input_values},
        padding=True,
        return_tensors="pt"
    )["input_values"]

    # Keep original text labels
    labels = [item["labels"] for item in batch]

    return {
        "input_values": input_values,
        "labels": labels  # Return original text labels
    }

def evaluate(model, processor, data_loader):
    model.eval()
    predictions, references = [], []

    with torch.no_grad():
        for batch in data_loader:
            input_values = batch["input_values"].to(device)
            references.extend(batch["labels"])  # Use original text labels

            # Forward pass
            logits = model(input_values).logits
            pred_ids = torch.argmax(logits, dim=-1)
            predictions.extend(processor.batch_decode(pred_ids))

    # Calculate WER
    # Lower the array before calculate error rate
    references = [ref.lower() for ref in references]
    predictions = [pred.lower() for pred in predictions]
    wer_score = wer(references, predictions)
    return wer_score, predictions, references

# Process the dataset
vivos_processed = vivos.map(preprocess_audio)
vivos_prepared = vivos_processed.map(
    prepare_dataset,
    remove_columns=["speech", "sampling_rate", "path"]  # Keep transcription
)

# Create data loaders
train_loader = DataLoader(
    vivos_prepared["train"],
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    vivos_prepared["test"],
    batch_size=4,
    collate_fn=collate_fn
)

# Evaluate
print("Evaluating model...")
wer_score, predictions, references = evaluate(model, processor, test_loader)
print(f"WER: {wer_score}")

# Print sample predictions
print("\nSample Predictions:")
for pred, ref in zip(predictions[:3], references[:3]):
    print(f"Reference: {ref}")
    print(f"Predicted: {pred}")
    print("-" * 50)



Map:   0%|          | 0/116 [00:00<?, ? examples/s]

Map:   0%|          | 0/7 [00:00<?, ? examples/s]

Map:   0%|          | 0/116 [00:00<?, ? examples/s]

Map:   0%|          | 0/7 [00:00<?, ? examples/s]

Evaluating model...
WER: 0.13559322033898305

Sample Predictions:
Reference: ngủ trên nệm không phù hợp
Predicted: ngủ trang niệm không phù hợp
--------------------------------------------------
Reference: hai mươi bốn hai mươi lăm
Predicted: hai mươi bốn hai mươi lăm
--------------------------------------------------
Reference: gí hòn than vào rơm
Predicted: guý hoàng thang vàu rươm
--------------------------------------------------


In [24]:
for batch in train_loader:
    print(batch)
    print(len(batch['labels']))
    break

{'input_values': tensor([[ 2.9091e-06,  2.9091e-06, -4.1950e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.0097e-01,  5.4827e-02,  1.6560e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0109e-07,  1.0109e-07,  1.0109e-07,  ..., -9.0240e-03,
         -7.6356e-03, -9.3710e-03],
        [-2.0010e-05, -2.0010e-05, -8.0151e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]]), 'labels': ['ĐÂU PHẢI LÀ BA CÁI CHUYỆN BE BÉ XINH XINH THẾ NÀY', 'TÔI TÌM MỌI CÁCH ĐỂ NÍU KÉO NHƯNG KHÔNG CÓ KẾT QUẢ', 'ĐỘI CỦA BỆNH VIỆN PHÁP VIỆT HÓA TRANG VỚI MÀU XANH ĐẶC TRƯNG CỦA MÌNH ẢNH MINH ĐỨC', 'NHIỀU HÃNG PHIM ĐÃ VÀO CUỘC ĐỂ ĐÒI QUYỀN LỢI CỦA MÌNH']}
4


In [35]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
import os
from datetime import datetime

class CTCTrainer:
    def __init__(self, model, train_loader, test_loader, processor, device,
                 learning_rate=1e-4, num_epochs=10):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.processor = processor
        self.device = device
        self.num_epochs = num_epochs

        # Initialize optimizer and loss function
        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.criterion = nn.CTCLoss(blank=processor.tokenizer.pad_token_id)

        # Initialize best metrics for model saving
        self.best_wer = float('inf')

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc='Training')

        for batch in progress_bar:
            # Move input to device
            input_values = batch["input_values"].to(self.device)

            # Clear gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(input_values)
            logits = outputs.logits

            # Process labels batch by batch
            with self.processor.as_target_processor():
                labels_batch = []
                for text in batch["labels"]:
                    labels = self.processor(text).input_ids
                    labels_batch.append(labels)

            # Pad labels to the same length
            max_label_length = max(len(labels) for labels in labels_batch)
            padded_labels = torch.full((len(labels_batch), max_label_length),
                                    fill_value=self.processor.tokenizer.pad_token_id,
                                    device=self.device)

            label_lengths = []
            for i, labels in enumerate(labels_batch):
                label_length = len(labels)
                label_lengths.append(label_length)
                padded_labels[i, :label_length] = torch.tensor(labels, device=self.device)

            # Calculate input lengths
            input_lengths = torch.full(size=(logits.shape[0],),
                                    fill_value=logits.shape[1],
                                    device=self.device)

            # Convert label lengths to tensor
            label_lengths = torch.tensor(label_lengths, device=self.device)

            # Calculate loss
            loss = self.criterion(logits.transpose(0, 1), padded_labels,
                                input_lengths, label_lengths)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        return total_loss / len(self.train_loader)

    def evaluate(self):
        self.model.eval()
        predictions, references = [], []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc='Evaluating'):
                input_values = batch["input_values"].to(self.device)
                references.extend(batch["labels"])

                # Forward pass
                logits = self.model(input_values).logits
                pred_ids = torch.argmax(logits, dim=-1)
                predictions.extend(self.processor.batch_decode(pred_ids))

        # Calculate WER
        references = [ref.lower() for ref in references]
        predictions = [pred.lower() for pred in predictions]
        wer_score = wer(references, predictions)

        return wer_score, predictions, references

    def save_checkpoint(self, epoch, wer_score):
        # Create checkpoint directory if it doesn't exist
        checkpoint_dir = "checkpoints"
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Save model if it's the best so far
        if wer_score < self.best_wer:
            self.best_wer = wer_score
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"wav2vec2_vietnamese_epoch{epoch}_wer{wer_score:.4f}.pt"
            )

            # Save model state
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'wer': wer_score,
            }, checkpoint_path)

            print(f"\nSaved best model checkpoint to {checkpoint_path}")

    def train(self):
        print("Starting training...")

        # Get initial WER
        initial_wer, _, _ = self.evaluate()
        print(f"Initial WER: {initial_wer}")

        # Training loop
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")

            # Train for one epoch
            avg_loss = self.train_epoch()
            print(f"Average loss: {avg_loss:.4f}")

            # Evaluate
            wer_score, predictions, references = self.evaluate()
            print(f"WER: {wer_score}")

            # Save checkpoint
            self.save_checkpoint(epoch + 1, wer_score)

            # Print some sample predictions
            print("\nSample Predictions:")
            for pred, ref in zip(predictions[:3], references[:3]):
                print(f"Reference: {ref}")
                print(f"Predicted: {pred}")
                print("-" * 50)

# Initialize trainer
trainer = CTCTrainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    processor=processor,
    device=device,
    learning_rate=1e-4,
    num_epochs=1
)

# Start training
trainer.train()

# Load best checkpoint and test
def load_best_model(model, checkpoint_dir="checkpoints"):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
    best_checkpoint = min(checkpoints, key=lambda x: float(x.split('wer')[-1].split('.pt')[0]))
    checkpoint_path = os.path.join(checkpoint_dir, best_checkpoint)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nLoaded best model from {checkpoint_path}")
    print(f"Best WER: {checkpoint['wer']}")

    return model

# Test best model
best_model = load_best_model(model)
trainer.model = best_model
final_wer, final_predictions, final_references = trainer.evaluate()
print(f"\nFinal WER: {final_wer}")

Starting training...


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  5.64it/s]


Initial WER: 0.13559322033898305

Epoch 1/1


Training: 100%|██████████| 29/29 [00:16<00:00,  1.71it/s, loss=-24.5]


Average loss: -29.6267


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  4.64it/s]


WER: 1.0

Saved best model checkpoint to checkpoints/wav2vec2_vietnamese_epoch1_wer1.0000.pt

Sample Predictions:
Reference: ngủ trên nệm không phù hợp
Predicted: 
--------------------------------------------------
Reference: hai mươi bốn hai mươi lăm
Predicted: 
--------------------------------------------------
Reference: gí hòn than vào rơm
Predicted: 
--------------------------------------------------


  checkpoint = torch.load(checkpoint_path)



Loaded best model from checkpoints/wav2vec2_vietnamese_epoch1_wer1.0000.pt
Best WER: 1.0


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  5.76it/s]


Final WER: 1.0





In [33]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
import os
from datetime import datetime
from jiwer import wer

class CTCTrainer:
    def __init__(self, model, train_loader, test_loader, processor, device,
                 learning_rate=1e-5,  # Reduced learning rate
                 num_epochs=10):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.processor = processor
        self.device = device
        self.num_epochs = num_epochs

        # Initialize optimizer with weight decay and lower learning rate
        self.optimizer = AdamW(model.parameters(),
                             lr=learning_rate,
                             weight_decay=0.01)  # Added weight decay

        # Initialize loss function with ignore_index
        self.criterion = nn.CTCLoss(blank=processor.tokenizer.pad_token_id,
                                  zero_infinity=True,
                                  reduction='mean')

        self.best_wer = float('inf')

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc='Training')

        for batch_idx, batch in enumerate(progress_bar):
            # Move input to device
            input_values = batch["input_values"].to(self.device)

            # Clear gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(input_values)
            logits = outputs.logits

            # Debug print logits statistics
            if batch_idx == 0:
                print("\nLogits statistics:")
                print("Mean:", logits.mean().item())
                print("Std:", logits.std().item())
                print("Max:", logits.max().item())
                print("Min:", logits.min().item())

            # Process labels
            with torch.no_grad():
                labels_batch = self.processor(text=batch["labels"],
                                           padding=True,
                                           return_tensors="pt")
                labels = labels_batch.input_ids.to(self.device)

                # Calculate input lengths from logits
                input_lengths = torch.full(
                    size=(logits.shape[0],),
                    fill_value=logits.shape[1],
                    dtype=torch.long,
                    device=self.device
                )

                # Calculate label lengths
                label_lengths = torch.sum(labels != self.processor.tokenizer.pad_token_id, dim=1)

            # Calculate loss with proper log_softmax
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            log_probs = log_probs.transpose(0, 1)

            loss = self.criterion(log_probs, labels,
                                input_lengths, label_lengths)

            # Check for invalid loss
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Warning: Invalid loss {loss.item()} detected. Skipping batch.")
                continue

            # Backward pass with gradient clipping
            loss.backward()

            # Print gradient norms for debugging
            if batch_idx == 0:
                total_norm = 0.0
                for p in self.model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** 0.5
                print(f"Gradient norm: {total_norm}")

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

            # Add predictions debugging for first batch
            if batch_idx == 0:
                with torch.no_grad():
                    pred_ids = torch.argmax(logits, dim=-1)
                    print("\nSample predictions from batch:")
                    print("Pred_ids unique values:", torch.unique(pred_ids))
                    try:
                        predictions = self.processor.batch_decode(pred_ids)
                        print("First prediction:", predictions[0])
                    except Exception as e:
                        print("Decoding error:", str(e))

        return total_loss / len(self.train_loader)

    def evaluate(self):
        self.model.eval()
        predictions, references = [], []

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(self.test_loader, desc='Evaluating')):
                input_values = batch["input_values"].to(self.device)
                references.extend(batch["labels"])

                # Debug print for input shape
                print(f"\nBatch {batch_idx} debugging:")
                print("Input values shape:", input_values.shape)

                # Forward pass
                logits = self.model(input_values).logits
                print("Logits shape:", logits.shape)

                # Get predictions
                pred_ids = torch.argmax(logits, dim=-1)
                print("Pred_ids shape:", pred_ids.shape)
                print("First few pred_ids:", pred_ids[0, :10])  # Show first 10 predictions

                # Debug the processor decode step
                try:
                    pred_transcripts = self.processor.batch_decode(pred_ids)
                    print("Successfully decoded predictions")
                    print("First decoded transcript:", pred_transcripts[0] if pred_transcripts else "Empty")
                except Exception as e:
                    print("Error in batch_decode:", str(e))
                    # Try single sample decoding as fallback
                    pred_transcripts = []
                    for pred_id in pred_ids:
                        try:
                            transcript = self.processor.decode(pred_id)
                            print("Single decode result:", transcript)
                            pred_transcripts.append(transcript)
                        except Exception as e:
                            print("Error in single decode:", str(e))
                            pred_transcripts.append("")

                predictions.extend(pred_transcripts)

                # Debug print references and predictions
                print("\nFirst sample in batch:")
                print("Reference:", references[-len(pred_transcripts)])
                print("Predicted:", pred_transcripts[0] if pred_transcripts else "Empty")

        # Print overall statistics
        print("\nOverall statistics:")
        print(f"Total number of references: {len(references)}")
        print(f"Total number of predictions: {len(predictions)}")
        print("\nFirst few predictions:")
        for i in range(min(3, len(predictions))):
            print(f"Reference: {references[i]}")
            print(f"Predicted: {predictions[i]}")
            print("-" * 50)

        # Calculate WER
        references = [ref.lower() for ref in references]
        predictions = [pred.lower() for pred in predictions]
        wer_score = wer(references, predictions)

        return wer_score, predictions, references

    def save_checkpoint(self, epoch, wer_score):
        checkpoint_dir = "checkpoints"
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        if wer_score < self.best_wer:
            self.best_wer = wer_score
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"wav2vec2_vietnamese_epoch{epoch}_wer{wer_score:.4f}.pt"
            )

            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'wer': wer_score,
            }, checkpoint_path)

            print(f"\nSaved best model checkpoint to {checkpoint_path}")

    def train(self):
        print("Starting training...")

        # Get initial WER
        initial_wer, _, _ = self.evaluate()
        print(f"Initial WER: {initial_wer}")

        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")

            # Train for one epoch
            avg_loss = self.train_epoch()
            print(f"Average loss: {avg_loss:.4f}")

            # Evaluate
            wer_score, predictions, references = self.evaluate()
            print(f"WER: {wer_score}")

            # Save checkpoint
            self.save_checkpoint(epoch + 1, wer_score)

            # Print some sample predictions
            print("\nSample Predictions:")
            for pred, ref in zip(predictions[:3], references[:3]):
                print(f"Reference: {ref}")
                print(f"Predicted: {pred}")
                print("-" * 50)

# Initialize trainer
trainer = CTCTrainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    processor=processor,
    device=device,
    learning_rate=1e-4,
    num_epochs=1
)

# Start training
trainer.train()

# Load best checkpoint and test
def load_best_model(model, checkpoint_dir="checkpoints"):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
    best_checkpoint = min(checkpoints, key=lambda x: float(x.split('wer')[-1].split('.pt')[0]))
    checkpoint_path = os.path.join(checkpoint_dir, best_checkpoint)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nLoaded best model from {checkpoint_path}")
    print(f"Best WER: {checkpoint['wer']}")

    return model

# Test best model
best_model = load_best_model(model)
trainer.model = best_model
final_wer, final_predictions, final_references = trainer.evaluate()
print(f"\nFinal WER: {final_wer}")

Starting training...


Evaluating:  50%|█████     | 1/2 [00:00<00:00,  2.84it/s]


Batch 0 debugging:
Input values shape: torch.Size([4, 54000])
Logits shape: torch.Size([4, 168, 110])
Pred_ids shape: torch.Size([4, 168])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: ngủ trang niệm không phù hợp

First sample in batch:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: ngủ trang niệm không phù hợp


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.13it/s]



Batch 1 debugging:
Input values shape: torch.Size([3, 62000])
Logits shape: torch.Size([3, 193, 110])
Pred_ids shape: torch.Size([3, 193])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: cuộc biểu tình của áo đỏ có thể sẽ biến thành bạo lực

First sample in batch:
Reference: CUỘC BIỂU TÌNH CỦA ÁO ĐỎ CÓ THỂ SẼ BIẾN THÀNH BẠO LỰC
Predicted: cuộc biểu tình của áo đỏ có thể sẽ biến thành bạo lực

Overall statistics:
Total number of references: 7
Total number of predictions: 7

First few predictions:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: ngủ trang niệm không phù hợp
--------------------------------------------------
Reference: HAI MƯƠI BỐN HAI MƯƠI LĂM
Predicted: hai mươi bốn hai mươi lăm
--------------------------------------------------
Reference: GÍ HÒN THAN VÀO RƠM
Predicted: guý hoàng thang vàu rươm
--------------------------------------------------
Initial WER: 0.1355

Training:   0%|          | 0/29 [00:00<?, ?it/s]


Logits statistics:
Mean: -4.597815036773682
Std: 4.600584506988525
Max: 15.41457748413086
Min: -21.01230239868164


Training:   3%|▎         | 1/29 [00:00<00:21,  1.30it/s, loss=13.1]

Gradient norm: 34.603086331562714

Sample predictions from batch:
Pred_ids unique values: tensor([  3,   7,   8,   9,  12,  13,  15,  17,  18,  20,  24,  26,  30,  31,
         32,  36,  40,  45,  46,  47,  49,  50,  52,  54,  56,  57,  58,  62,
         67,  71,  72,  73,  79,  82,  83,  85,  88,  89,  92,  95,  96,  98,
         99, 102, 105, 106, 109], device='cuda:0')
First prediction: khi chạy lớn lên thình thoảng có thể có nước mắt chảy quay lỡ này


Training: 100%|██████████| 29/29 [00:18<00:00,  1.55it/s, loss=0.654]


Average loss: 1.9336


Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]


Batch 0 debugging:
Input values shape: torch.Size([4, 54000])


Evaluating:  50%|█████     | 1/2 [00:00<00:00,  4.16it/s]

Logits shape: torch.Size([4, 168, 110])
Pred_ids shape: torch.Size([4, 168])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: 

First sample in batch:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: 


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  4.41it/s]


Batch 1 debugging:
Input values shape: torch.Size([3, 62000])
Logits shape: torch.Size([3, 193, 110])
Pred_ids shape: torch.Size([3, 193])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: 

First sample in batch:
Reference: CUỘC BIỂU TÌNH CỦA ÁO ĐỎ CÓ THỂ SẼ BIẾN THÀNH BẠO LỰC
Predicted: 

Overall statistics:
Total number of references: 7
Total number of predictions: 7

First few predictions:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: 
--------------------------------------------------
Reference: HAI MƯƠI BỐN HAI MƯƠI LĂM
Predicted: 
--------------------------------------------------
Reference: GÍ HÒN THAN VÀO RƠM
Predicted: 
--------------------------------------------------
WER: 1.0






Saved best model checkpoint to checkpoints/wav2vec2_vietnamese_epoch1_wer1.0000.pt

Sample Predictions:
Reference: ngủ trên nệm không phù hợp
Predicted: 
--------------------------------------------------
Reference: hai mươi bốn hai mươi lăm
Predicted: 
--------------------------------------------------
Reference: gí hòn than vào rơm
Predicted: 
--------------------------------------------------


  checkpoint = torch.load(checkpoint_path)



Loaded best model from checkpoints/wav2vec2_vietnamese_epoch1_wer1.0000.pt
Best WER: 1.0


Evaluating:  50%|█████     | 1/2 [00:00<00:00,  5.34it/s]


Batch 0 debugging:
Input values shape: torch.Size([4, 54000])
Logits shape: torch.Size([4, 168, 110])
Pred_ids shape: torch.Size([4, 168])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: 

First sample in batch:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: 

Batch 1 debugging:
Input values shape: torch.Size([3, 62000])


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  5.51it/s]

Logits shape: torch.Size([3, 193, 110])
Pred_ids shape: torch.Size([3, 193])
First few pred_ids: tensor([109, 109, 109, 109, 109, 109, 109, 109, 109, 109], device='cuda:0')
Successfully decoded predictions
First decoded transcript: 

First sample in batch:
Reference: CUỘC BIỂU TÌNH CỦA ÁO ĐỎ CÓ THỂ SẼ BIẾN THÀNH BẠO LỰC
Predicted: 

Overall statistics:
Total number of references: 7
Total number of predictions: 7

First few predictions:
Reference: NGỦ TRÊN NỆM KHÔNG PHÙ HỢP
Predicted: 
--------------------------------------------------
Reference: HAI MƯƠI BỐN HAI MƯƠI LĂM
Predicted: 
--------------------------------------------------
Reference: GÍ HÒN THAN VÀO RƠM
Predicted: 
--------------------------------------------------

Final WER: 1.0





In [47]:
class CTCTrainer:
    def __init__(self, model, train_loader, test_loader, processor, device,
                 learning_rate=1e-5,
                 num_epochs=1):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.processor = processor
        self.device = device
        self.num_epochs = num_epochs

        # Initialize optimizer with weight decay
        self.optimizer = AdamW(model.parameters(),
                             lr=learning_rate,
                             weight_decay=0.01)

        # Initialize CTC loss
        self.criterion = nn.CTCLoss(blank=processor.tokenizer.pad_token_id,
                                  zero_infinity=True)

        self.best_wer = float('inf')

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc='Training')

        for batch_idx, batch in enumerate(progress_bar):
            # Move input to device
            input_values = batch["input_values"].to(self.device)

            # Clear gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(input_values)
            print(f"Outputs shape: {outputs.logits.shape}")
            logits = outputs.logits
            print(f"Logits shape: {logits.shape}")

            # Process labels using processor
            with torch.no_grad():
                # Convert text to input IDs
                labels_batch = self.processor(
                    text=batch["labels"],
                    padding=True,
                    return_tensors="pt"
                )

                # Get input IDs and move to device
                labels = labels_batch.input_ids.to(self.device)

                # Calculate input lengths (from logits)
                input_lengths = torch.full(
                    size=(logits.shape[0],),
                    fill_value=logits.shape[1],
                    dtype=torch.long,
                    device=self.device
                )

                # Calculate label lengths (excluding padding)
                label_lengths = (labels != self.processor.tokenizer.pad_token_id).sum(dim=-1)

            # Apply log softmax
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            print(f"Log probs shape: {log_probs.shape}")

            # Transpose log_probs for CTC loss (time, batch, class)
            log_probs = log_probs.transpose(0, 1)
            print(f"Transposed log probs shape: {log_probs.shape}")

            # Calculate CTC loss
            loss = self.criterion(log_probs, labels,
                                input_lengths, label_lengths)
            print(f"CTC Loss: {loss.item()}")

            # Skip invalid loss
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Warning: Invalid loss {loss.item()} detected. Skipping batch.")
                continue

            # Backward pass
            loss.backward()

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Optimizer step
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        return total_loss / len(self.train_loader)

    def evaluate(self):
        self.model.eval()
        predictions, references = [], []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc='Evaluating'):
                input_values = batch["input_values"].to(self.device)
                references.extend(batch["labels"])

                # Forward pass
                outputs = self.model(input_values)
                logits = outputs.logits

                # Get predictions
                pred_ids = torch.argmax(logits, dim=-1)

                # Decode predictions
                decoded_preds = self.processor.batch_decode(pred_ids)
                predictions.extend(decoded_preds)

        # Calculate WER
        references = [ref.lower() for ref in references]
        predictions = [pred.lower() for pred in predictions]
        wer_score = wer(references, predictions)

        return wer_score, predictions, references

    def train(self):
        print("Starting training...")

        # Get initial WER
        initial_wer, _, _ = self.evaluate()
        print(f"Initial WER: {initial_wer}")

        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")

            # Train epoch
            avg_loss = self.train_epoch()
            print(f"Average loss: {avg_loss:.4f}")

            # Evaluate
            wer_score, predictions, references = self.evaluate()
            print(f"WER: {wer_score}")

            # Print samples
            print("\nSample Predictions:")
            for pred, ref in zip(predictions[:3], references[:3]):
                print(f"Reference: {ref}")
                print(f"Predicted: {pred}")
                print("-" * 50)

# Initialize trainer with fixed implementation
trainer = CTCTrainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    processor=processor,
    device=device,
    learning_rate=1e-5,  # Lower learning rate
    num_epochs=1  # Increase epochs
)

# Start training
trainer.train()

Starting training...


Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Decoded predictions: ['ngủ trang niệm không phù hợp', 'hai mươi bốn hai mươi lăm', 'guý hoàng thang vàu rươm', 'trận chiến xích bích']
Decoded predictions: ['cuộc biểu tình của áo đỏ có thể sẽ biến thành bạo lực', 'vì vậy công ty bê phải có trách nhiệm bồi thường', 'họ cũng luôn tránh xa những tình huống có thể gây hại cho họ']
Initial WER: 0.13559322033898305

Epoch 1/1


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Outputs shape: torch.Size([4, 324, 110])
Logits shape: torch.Size([4, 324, 110])
Log probs shape: torch.Size([4, 324, 110])
Transposed log probs shape: torch.Size([324, 4, 110])
CTC Loss: 12.591304779052734
Outputs shape: torch.Size([4, 276, 110])
Logits shape: torch.Size([4, 276, 110])
Log probs shape: torch.Size([4, 276, 110])
Transposed log probs shape: torch.Size([276, 4, 110])
CTC Loss: 11.592634201049805
Outputs shape: torch.Size([4, 190, 110])
Logits shape: torch.Size([4, 190, 110])
Log probs shape: torch.Size([4, 190, 110])
Transposed log probs shape: torch.Size([190, 4, 110])
CTC Loss: 10.232593536376953
Outputs shape: torch.Size([4, 162, 110])
Logits shape: torch.Size([4, 162, 110])
Log probs shape: torch.Size([4, 162, 110])
Transposed log probs shape: torch.Size([162, 4, 110])
CTC Loss: 11.128286361694336
Outputs shape: torch.Size([4, 306, 110])
Logits shape: torch.Size([4, 306, 110])
Log probs shape: torch.Size([4, 306, 110])
Transposed log probs shape: torch.Size([306, 4, 

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Decoded predictions: ['<unk><unk><unk>', '<unk><unk>', '<unk><unk><unk><unk><unk><unk>', '<unk><unk><unk><unk><unk><unk><unk><unk>']
Decoded predictions: ['<unk><unk>cbểuaàbạ<unk>', '<unk>mờ<unk><unk><unk><unk><unk><unk><unk><unk>', '<unk>ngnngạ<unk>']
WER: 1.0

Sample Predictions:
Reference: ngủ trên nệm không phù hợp
Predicted: <unk><unk><unk>
--------------------------------------------------
Reference: hai mươi bốn hai mươi lăm
Predicted: <unk><unk>
--------------------------------------------------
Reference: gí hòn than vào rơm
Predicted: <unk><unk><unk><unk><unk><unk>
--------------------------------------------------
