In [None]:
# Imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
from sklearn.model_selection import train_test_split
import string
from collections import Counter
from datasets import load_dataset

In [None]:
# Preprocessing

class TextProcessor:
    """
    Character-level conversion between text and integer sequences.

    Text is lowercased and every character becomes one token. Indices 0-3 are
    always the special tokens, in this order: <PAD>=0, <UNK>=1, <SOS>=2, <EOS>=3
    (the collate function pads with 0 and the decoder starts from 2 / stops at 3).
    """

    def __init__(self, vocab_size_limit=None):
        # Special tokens that help the model understand structure
        self.pad_token = '<PAD>'    # Padding for variable-length sequences
        self.unk_token = '<UNK>'    # Unknown/out-of-vocabulary characters
        self.sos_token = '<SOS>'    # Start of sequence
        self.eos_token = '<EOS>'    # End of sequence

        # Filled by build_vocabulary()
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = 0
        # Total vocabulary size cap *including* the 4 special tokens; falsy = no cap.
        self.vocab_size_limit = vocab_size_limit

    def build_vocabulary(self, texts):
        """
        Build the character <-> index mappings from an iterable of transcripts.

        Non-string entries (e.g. missing transcripts read as None or NaN) are
        skipped. With ``vocab_size_limit`` set, only the most frequent characters
        are kept so that the total size (special tokens included) is at most the limit.
        """
        char_counter = Counter()
        for text in texts:
            if isinstance(text, str):
                char_counter.update(text.lower())

        # Special tokens first so their indices are fixed at 0..3.
        vocab = [self.pad_token, self.unk_token, self.sos_token, self.eos_token]

        if self.vocab_size_limit:
            remaining = max(self.vocab_size_limit - len(vocab), 0)
            vocab.extend(char for char, _ in char_counter.most_common(remaining))
        else:
            vocab.extend(char_counter.keys())

        self.char_to_idx = {char: idx for idx, char in enumerate(vocab)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.vocab_size = len(vocab)

        print(f"Vocabulary size: {self.vocab_size}")
        print(f"Sample characters: {list(self.char_to_idx.keys())[:20]}")

    def text_to_sequence(self, text):
        """Encode ``text`` as [<SOS>, char indices..., <EOS>]; unseen characters map to <UNK>."""
        unk_idx = self.char_to_idx[self.unk_token]
        sequence = [self.char_to_idx[self.sos_token]]
        sequence.extend(self.char_to_idx.get(char, unk_idx) for char in text.lower())
        sequence.append(self.char_to_idx[self.eos_token])
        return sequence

    def sequence_to_text(self, sequence):
        """
        Decode a sequence of indices back to text.

        Decoding stops at the first <EOS>; anything after it (e.g. tokens a
        batched greedy decoder keeps emitting for finished rows) is ignored.
        <PAD> and <SOS> are dropped, <UNK> is kept as its literal token string,
        and indices outside the vocabulary are skipped. Accepts Python ints,
        numpy integers, or torch tensors.
        """
        eos_idx = self.char_to_idx.get(self.eos_token)
        dropped = {self.pad_token, self.sos_token}
        chars = []
        for raw_idx in sequence:
            # int() so 0-d torch tensors (hashed by identity) look up by value.
            idx = int(raw_idx)
            if idx == eos_idx:
                break
            char = self.idx_to_char.get(idx)
            if char is not None and char not in dropped:
                chars.append(char)
        return ''.join(chars)

In [15]:
# Audio Processing

class AudioProcessor:
    """
    Loads audio files and converts waveforms into log-mel-spectrogram features.
    """

    def __init__(self, sample_rate=16000, n_mels=80, n_fft=512, hop_length=256):
        self.sample_rate = sample_rate  # Target rate all audio is resampled to
        self.n_mels = n_mels            # Number of mel-frequency bands
        self.n_fft = n_fft              # FFT window size
        self.hop_length = hop_length    # Step size between windows

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            n_fft=n_fft,
            hop_length=hop_length
        )

        # Resample transforms keyed by source sample rate. A Resample transform
        # precomputes its filter kernel when constructed, so build one per
        # source rate and reuse it instead of rebuilding it for every file.
        self._resamplers = {}

    def _resampler_for(self, orig_sample_rate):
        """Return a cached Resample transform from ``orig_sample_rate`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sample_rate,
                new_freq=self.sample_rate
            )
            self._resamplers[orig_sample_rate] = resampler
        return resampler

    def load_audio(self, file_path):
        """
        Load an audio file as a 1-D mono waveform at ``self.sample_rate``.

        Multi-channel audio is averaged to mono. Any error while loading is
        printed and reported by returning None (best-effort, so one bad clip
        does not stop a data pass).
        """
        try:
            waveform, orig_sample_rate = torchaudio.load(file_path)

            # (channels, time) -> (1, time)
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            if orig_sample_rate != self.sample_rate:
                waveform = self._resampler_for(orig_sample_rate)(waveform)

            return waveform.squeeze(0)  # Remove channel dimension

        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None

    def extract_features(self, waveform):
        """
        Convert a 1-D waveform into log-mel features of shape (time_steps, n_mels).
        """
        mel_spec = self.mel_transform(waveform.unsqueeze(0))  # (1, n_mels, time)

        # Log compression; the epsilon avoids log(0) on silent frames.
        mel_spec = torch.log(mel_spec + 1e-8)

        return mel_spec.squeeze(0).transpose(0, 1)

In [16]:
# Dataset

class CommonVoiceDataset(Dataset):
    """
    PyTorch Dataset for Mozilla Common Voice data.

    Audio is read from ``<data_dir>/clips/<path>``, where ``path`` and
    ``sentence`` come from the tab-separated ``metadata_file``. Each item is
    ``(features, text_sequence)``. If an audio file cannot be loaded, a dummy
    ``(zeros(1, n_mels), tensor([0]))`` pair is returned instead.
    """

    def __init__(self, data_dir, metadata_file, text_processor, audio_processor, max_audio_length=16000*10):
        import csv  # for csv.QUOTE_NONE below

        self.data_dir = Path(data_dir)
        self.text_processor = text_processor
        self.audio_processor = audio_processor
        # Max samples kept per clip (default 10 s at 16 kHz); longer clips are truncated in __getitem__.
        self.max_audio_length = max_audio_length

        # Common Voice TSVs are plain tab-separated text, not CSV-quoted. With
        # pandas' default quoting, a transcript that starts with '"' opens a
        # quoted field that swallows the following rows, so quoting is disabled.
        self.metadata = pd.read_csv(metadata_file, sep='\t', quoting=csv.QUOTE_NONE)

        # Drop rows that cannot produce a training pair: no clip path, or no
        # transcript (a NaN sentence would crash text_to_sequence's .lower()).
        print(f"Total samples before filtering: {len(self.metadata)}")
        usable = self.metadata['path'].notna() & self.metadata['sentence'].notna()
        self.metadata = self.metadata[usable]
        print(f"Total samples after filtering: {len(self.metadata)}")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """
        Return ``(features, text_sequence)`` for row ``idx`` (positional index).
        """
        row = self.metadata.iloc[idx]

        audio_path = self.data_dir / "clips" / row['path']
        text = row['sentence']

        waveform = self.audio_processor.load_audio(audio_path)

        if waveform is None:
            # Dummy pair: one zero frame and an all-<PAD> target (ignored by the loss).
            return torch.zeros(1, self.audio_processor.n_mels), torch.tensor([0])

        if len(waveform) > self.max_audio_length:
            waveform = waveform[:self.max_audio_length]

        features = self.audio_processor.extract_features(waveform)
        text_sequence = torch.tensor(self.text_processor.text_to_sequence(text))

        return features, text_sequence

def collate_fn(batch):
    """
    Batch variable-length ``(features, text_sequence)`` pairs for a DataLoader.

    Features are zero-padded along their first (time) axis, and text sequences
    are padded with 0 (the <PAD> index) up to the longest item in the batch.

    Returns:
        ``(features_batch, texts_batch, feature_lengths, text_lengths)``. The
        two length tensors hold each item's size before padding.
    """
    feature_list, text_list = zip(*batch)

    feature_lens = torch.tensor([feat.size(0) for feat in feature_list])
    text_lens = torch.tensor([len(seq) for seq in text_list])

    padded_feature_batch = nn.utils.rnn.pad_sequence(feature_list, batch_first=True, padding_value=0.0)
    padded_text_batch = nn.utils.rnn.pad_sequence(text_list, batch_first=True, padding_value=0)

    return padded_feature_batch, padded_text_batch, feature_lens, text_lens

In [17]:
class HuggingFaceDataset(Dataset):
    """
    PyTorch Dataset wrapper for the Hugging Face CoVoST 2 dataset.

    Each item of ``hf_dataset`` is read as a mapping with an ``'audio'`` dict
    (``'array'`` and, optionally, ``'sampling_rate'``) and a ``'sentence'``.
    Items are returned as ``(features, text_sequence)``. On any processing error,
    a dummy ``(zeros(1, n_mels), tensor([0]))`` pair is returned instead.
    """

    def __init__(self, hf_dataset, text_processor, audio_processor, max_audio_length=16000*10):
        self.hf_dataset = hf_dataset
        self.text_processor = text_processor
        self.audio_processor = audio_processor
        # Max samples kept per clip, measured at audio_processor.sample_rate.
        self.max_audio_length = max_audio_length

        print(f"Total samples: {len(self.hf_dataset)}")

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """
        Return ``(features, text_sequence)`` for item ``idx``.
        """
        try:
            sample = self.hf_dataset[idx]
            audio = sample['audio']
            text = sample['sentence']

            waveform = torch.tensor(audio['array'], dtype=torch.float32)

            # The mel transform and max_audio_length both assume
            # audio_processor.sample_rate. Resample first if the decoded audio
            # reports a different rate.
            source_rate = audio.get('sampling_rate')
            target_rate = self.audio_processor.sample_rate
            if source_rate is not None and source_rate != target_rate:
                waveform = torchaudio.functional.resample(waveform, source_rate, target_rate)

            if len(waveform) > self.max_audio_length:
                waveform = waveform[:self.max_audio_length]

            features = self.audio_processor.extract_features(waveform)
            text_sequence = torch.tensor(self.text_processor.text_to_sequence(text))

            return features, text_sequence

        except Exception as e:
            # Best-effort: report and fall back to an all-<PAD> dummy item so one bad sample doesn't stop training.
            print(f"Error processing sample {idx}: {e}")
            return torch.zeros(1, self.audio_processor.n_mels), torch.tensor([0])

In [18]:
# Network

class AttentionMechanism(nn.Module):
    """
    Additive attention over encoder frames.

    Each encoder frame is scored against the current decoder hidden state as
    ``v(tanh(W [encoder_output; decoder_hidden]))``. The softmax of those scores
    weights the encoder outputs into a single context vector.
    """

    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super().__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size

        # Linear layers for computing attention scores
        self.attention = nn.Linear(encoder_hidden_size + decoder_hidden_size, decoder_hidden_size)
        self.v = nn.Linear(decoder_hidden_size, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden, mask=None):
        """
        Args:
            encoder_outputs: (batch_size, seq_len, encoder_hidden_size)
            decoder_hidden: (batch_size, decoder_hidden_size)
            mask: optional bool tensor (batch_size, seq_len), True for real
                (non-padded) frames. Masked frames get zero attention weight.

        Returns:
            context: (batch_size, encoder_hidden_size)
            attention_weights: (batch_size, seq_len)
        """
        seq_len = encoder_outputs.size(1)

        # Pair the decoder state with every encoder frame.
        expanded_hidden = decoder_hidden.unsqueeze(1).expand(-1, seq_len, -1)
        combined = torch.cat([encoder_outputs, expanded_hidden], dim=2)

        scores = self.v(torch.tanh(self.attention(combined))).squeeze(2)
        if mask is not None:
            # Padded frames are non-zero after the encoder's projection (its bias),
            # so they have to be excluded explicitly rather than relying on zeros.
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=1)

        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        return context.squeeze(1), attention_weights


class Encoder(nn.Module):
    """
    Bidirectional LSTM over audio features, projected back to ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,  # nn.LSTM only applies dropout between layers
            bidirectional=True,
            batch_first=True
        )

        # Project the concatenated forward/backward outputs back to hidden_size
        self.projection = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, features, feature_lengths):
        """
        Args:
            features: (batch_size, seq_len, input_size), padded along time.
            feature_lengths: (batch_size,) lengths before padding.

        Returns:
            outputs: (batch_size, max(feature_lengths), hidden_size). Positions
                past a row's length are padding: pad_packed_sequence fills them
                with zeros, and the projection then maps those to its bias.
            (hidden, cell): final LSTM states.
        """
        # Packing makes the LSTM skip padded frames; lengths must be on CPU.
        packed_features = nn.utils.rnn.pack_padded_sequence(
            features, feature_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_outputs, (hidden, cell) = self.lstm(packed_features)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)

        outputs = self.projection(outputs)

        return outputs, (hidden, cell)


class Decoder(nn.Module):
    """
    Attention-based LSTM decoder that emits one character logit vector per step.

    Training (``target_sequence`` given) uses teacher forcing. Inference decodes
    greedily until every row emits EOS or ``max_length`` steps have run.
    """

    def __init__(self, vocab_size, hidden_size, encoder_hidden_size, num_layers=2, dropout=0.1,
                 sos_idx=2, eos_idx=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        # Special-token indices; the defaults match TextProcessor's fixed layout (<PAD>, <UNK>, <SOS>, <EOS>).
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx

        # Character embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # Each step consumes [embedding; attention context]
        self.lstm = nn.LSTM(
            input_size=hidden_size + encoder_hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        self.attention = AttentionMechanism(encoder_hidden_size, hidden_size)

        # Output projection to vocabulary
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _step(self, current_input, decoder_state, encoder_outputs, mask):
        """Run one decoding step; returns (lstm_output of shape (batch, hidden), new decoder_state)."""
        embedded = self.embedding(current_input)
        # Attend using the top LSTM layer's hidden state.
        context, _ = self.attention(encoder_outputs, decoder_state[0][-1], mask)
        lstm_input = torch.cat([embedded, context], dim=1).unsqueeze(1)
        lstm_output, decoder_state = self.lstm(lstm_input, decoder_state)
        return lstm_output.squeeze(1), decoder_state

    def forward(self, encoder_outputs, target_sequence=None, max_length=100, encoder_lengths=None):
        """
        Args:
            encoder_outputs: (batch_size, seq_len, encoder_hidden_size)
            target_sequence: optional (batch_size, target_len) token indices
                starting with SOS; enables teacher forcing.
            max_length: maximum number of greedy steps in inference mode.
            encoder_lengths: optional (batch_size,) valid encoder lengths. When
                given, attention ignores padded encoder frames.

        Returns:
            Logits of shape (batch_size, steps, vocab_size). In training,
            steps == target_len - 1 (predictions for target positions 1..end).
        """
        batch_size, seq_len, _ = encoder_outputs.size()
        device = encoder_outputs.device

        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        cell = torch.zeros_like(hidden)
        decoder_state = (hidden, cell)

        mask = None
        if encoder_lengths is not None:
            positions = torch.arange(seq_len, device=device)
            mask = positions.unsqueeze(0) < encoder_lengths.to(device).unsqueeze(1)

        sos = torch.full((batch_size,), self.sos_idx, device=device, dtype=torch.long)
        outputs = []

        if target_sequence is not None:  # Training mode (teacher forcing)
            for t in range(target_sequence.size(1) - 1):  # last token is never an input
                current_input = sos if t == 0 else target_sequence[:, t]
                lstm_output, decoder_state = self._step(current_input, decoder_state, encoder_outputs, mask)
                outputs.append(self.out(self.dropout(lstm_output)))
        else:  # Inference mode (greedy)
            current_input = sos
            for _ in range(max_length):
                lstm_output, decoder_state = self._step(current_input, decoder_state, encoder_outputs, mask)
                output = self.out(lstm_output)
                outputs.append(output)
                current_input = output.argmax(dim=1)
                if (current_input == self.eos_idx).all():
                    break

        if not outputs:
            # e.g. target_len <= 1 or max_length <= 0; torch.stack([]) would raise.
            return encoder_outputs.new_zeros(batch_size, 0, self.vocab_size)
        return torch.stack(outputs, dim=1)


class Speech2TextModel(nn.Module):
    """
    Complete speech-to-text model: bidirectional LSTM encoder + attention decoder.
    """

    def __init__(self, vocab_size, input_size=80, encoder_hidden=256, decoder_hidden=256, num_layers=2):
        super().__init__()

        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=encoder_hidden,
            num_layers=num_layers
        )

        self.decoder = Decoder(
            vocab_size=vocab_size,
            hidden_size=decoder_hidden,
            encoder_hidden_size=encoder_hidden,
            num_layers=num_layers
        )

    def forward(self, features, feature_lengths, target_sequence=None, max_length=100):
        """
        Encode ``features`` and decode to character logits.

        ``feature_lengths`` is also passed to the decoder so attention ignores
        padded frames. ``max_length`` only applies when ``target_sequence`` is None.
        """
        encoder_outputs, _ = self.encoder(features, feature_lengths)
        return self.decoder(
            encoder_outputs,
            target_sequence,
            max_length=max_length,
            encoder_lengths=feature_lengths,
        )

In [19]:
# Training

def _batch_has_targets(texts, pad_idx=0):
    """True if a padded (batch, time) text batch has any non-pad token after position 0."""
    return texts.size(1) > 1 and bool((texts[:, 1:] != pad_idx).any())


def _batch_loss(model, criterion, features, texts, feature_lengths, device):
    """Teacher-forced loss for one batch. Targets are ``texts`` without column 0 (SOS)."""
    features = features.to(device)
    texts = texts.to(device)
    feature_lengths = feature_lengths.to(device)

    outputs = model(features, feature_lengths, texts)
    outputs = outputs.reshape(-1, outputs.size(-1))
    targets = texts[:, 1:].reshape(-1)
    return criterion(outputs, targets)


def train_model(model, train_loader, val_loader, text_processor, num_epochs=10, device='cpu',
                lr=0.001, save_path='best_model.pth'):
    """
    Train the speech-to-text model with Adam, gradient clipping and ReduceLROnPlateau.

    Loaders must yield ``(features, texts, feature_lengths, text_lengths)``
    batches, as produced by ``collate_fn``. Batches with no non-pad target
    tokens are skipped, because their loss would be NaN (e.g. every sample fell
    back to a dummy all-<PAD> text).

    The model is selected, and the LR schedule driven, by mean validation loss.
    When the validation loader yields no usable batches, the training loss is
    used instead. The best weights are saved to ``save_path``.

    Args:
        text_processor: unused; kept for call compatibility.
        lr: Adam learning rate.
        save_path: where the best ``state_dict`` is written.

    Returns:
        The best monitored loss, or NaN if no epoch had any usable batch.
    """
    import math

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is the <PAD> index
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_batches = 0.0, 0

        for batch_idx, (features, texts, feature_lengths, _) in enumerate(train_loader):
            if not _batch_has_targets(texts):
                continue

            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, features, texts, feature_lengths, device)
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_loss, val_batches = 0.0, 0

        with torch.no_grad():
            for features, texts, feature_lengths, _ in val_loader:
                if not _batch_has_targets(texts):
                    continue
                val_loss += _batch_loss(model, criterion, features, texts, feature_lengths, device).item()
                val_batches += 1

        avg_train_loss = train_loss / train_batches if train_batches else float('nan')
        avg_val_loss = val_loss / val_batches if val_batches else float('nan')

        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        monitored = avg_val_loss if val_batches else avg_train_loss
        label = 'validation' if val_batches else 'training'
        if math.isnan(monitored):
            continue  # nothing usable this epoch; keep previous best and LR

        if monitored < best_val_loss:
            best_val_loss = monitored
            torch.save(model.state_dict(), save_path)
            print(f'New best model saved with {label} loss: {monitored:.4f}')

        scheduler.step(monitored)

    return best_val_loss if best_val_loss != float('inf') else float('nan')

def predict_text(model, audio_file, text_processor, audio_processor, device='cpu'):
    """
    Transcribe a single audio file with the model's greedy decoding.

    The model is switched to eval mode. Features are built with
    ``audio_processor`` and sent to ``device``; the model itself is not moved.

    Returns:
        The decoded string, or "Error loading audio" when
        ``audio_processor.load_audio`` returns None.
    """
    model.eval()

    waveform = audio_processor.load_audio(audio_file)
    if waveform is None:
        return "Error loading audio"

    # (time, n_mels) -> (1, time, n_mels): the model expects a batch axis.
    batch = audio_processor.extract_features(waveform).unsqueeze(0).to(device)
    lengths = torch.tensor([batch.size(1)]).to(device)

    with torch.no_grad():
        logits = model(batch, lengths)
        token_ids = logits.argmax(dim=-1).squeeze(0).cpu().numpy()
        return text_processor.sequence_to_text(token_ids)

In [None]:
# main()

def _select_device():
    """Return CUDA if available, else Apple MPS, else CPU (printing a note in the CPU case)."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')  # For Apple Silicon Macs
    print("No GPU available, using CPU.")
    return torch.device('cpu')


def main(num_epochs=1, batch_size=8):
    """
    Train the speech-to-text model on CoVoST 2 (en_de) loaded via Hugging Face.

    The local dataset directory is read from the COVOST2_DATASET environment
    variable. To train from Common Voice TSV files instead:
    - build the vocabulary from the TSV's 'sentence' column;
    - use CommonVoiceDataset in place of HuggingFaceDataset.

    Args:
        num_epochs: number of training epochs (default 1).
        batch_size: batch size for both DataLoaders (default 8, small for demo).

    Returns:
        ``(model, text_processor)``. The model holds the final-epoch weights (the
        best checkpoint is written to disk by train_model), and ``text_processor``
        holds the exact vocabulary it was trained with.
    """
    device = _select_device()
    print(f"Using device: {device}")

    data_dir = os.getenv("COVOST2_DATASET")
    # NOTE(review): trust_remote_code=True executes the dataset's loading script from the Hub.
    covost = load_dataset(
        "facebook/covost2",
        "en_de",
        data_dir=data_dir,
        trust_remote_code=True
    )

    text_processor = TextProcessor(vocab_size_limit=100)
    audio_processor = AudioProcessor()

    # Build the vocabulary from training transcripts only. Its size fixes the
    # model's embedding/output layers, so inference must rebuild it identically.
    print("Building vocabulary...")
    text_processor.build_vocabulary(covost['train']['sentence'])

    print("Creating datasets...")
    train_dataset = HuggingFaceDataset(covost['train'], text_processor, audio_processor)
    val_dataset = HuggingFaceDataset(covost['validation'], text_processor, audio_processor)

    loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=0)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print("Creating model...")
    model = Speech2TextModel(
        vocab_size=text_processor.vocab_size,
        input_size=audio_processor.n_mels,
        encoder_hidden=128,  # Smaller for demo
        decoder_hidden=128,
        num_layers=1
    )

    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

    print("Starting training...")
    train_model(model, train_loader, val_loader, text_processor, num_epochs=num_epochs, device=device)

    return model, text_processor


# In a notebook __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()

In [None]:
# Quick prediction check on a single local audio file.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')  # For Apple Silicon Macs
else:
    print("No GPU available, using CPU.")

print(f"Using device: {device}")

text_processor = TextProcessor(vocab_size_limit=100)
audio_processor = AudioProcessor()

# Example prediction
print("\nTesting prediction...")
data_dir = os.getenv("COVOST2_DATASET")
sample_audio = f"{data_dir}/harvard.wav"  # Change this
model_path = "./best_model.pth"  # Checkpoint written by train_model

model = None
if os.path.exists(model_path):
    # The checkpoint's embedding/output sizes depend on the vocabulary, so it must
    # be rebuilt from the same training transcripts (an empty vocabulary gives
    # vocab_size 0 and the state_dict would not load).
    # NOTE(review): trust_remote_code=True executes the dataset's loading script from the Hub.
    covost = load_dataset(
        "facebook/covost2",
        "en_de",
        data_dir=data_dir,
        trust_remote_code=True
    )
    text_processor.build_vocabulary(covost['train']['sentence'])

    model = Speech2TextModel(
        vocab_size=text_processor.vocab_size,
        input_size=audio_processor.n_mels,
        encoder_hidden=128,  # Match your training config
        decoder_hidden=128,
        num_layers=1
    )
    # weights_only=True: a plain state_dict needs no arbitrary unpickling.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)  # predict_text moves only the features, not the model
else:
    print("Model file not found. Please train the model first.")

if model is not None and os.path.exists(sample_audio):
    prediction = predict_text(model, sample_audio, text_processor, audio_processor, device)
    print(f"Predicted text: {prediction}")

In [22]:
def load_and_test_model(model_path="./best_model.pth", sample_audio=None):
    """
    Load the trained checkpoint and transcribe one audio file.

    The model's embedding and output layers are sized by the vocabulary, so the
    vocabulary is rebuilt from the CoVoST 2 training transcripts with the same
    settings as training (vocab_size_limit=100).

    Args:
        model_path: checkpoint written by train_model.
        sample_audio: audio file to transcribe. Defaults to
            "<COVOST2_DATASET>/clips/common_voice_en_699711.mp3".
    """
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        print("No GPU available, using CPU.")

    print(f"Using device: {device}")

    # Check for the checkpoint before the (expensive) dataset load.
    if not os.path.exists(model_path):
        print("Model file not found. Please train the model first.")
        return

    data_dir = os.getenv("COVOST2_DATASET")
    if sample_audio is None:
        sample_audio = f"{data_dir}/clips/common_voice_en_699711.mp3"

    # NOTE(review): trust_remote_code=True executes the dataset's loading script from the Hub.
    covost = load_dataset(
        "facebook/covost2",
        "en_de",
        data_dir=data_dir,
        trust_remote_code=True
    )

    # Initialize processors with same settings as training
    text_processor = TextProcessor(vocab_size_limit=100)
    audio_processor = AudioProcessor()

    # Rebuild vocabulary from training data (must match training exactly)
    print("Building vocabulary...")
    text_processor.build_vocabulary(covost['train']['sentence'])

    # Same architecture as training
    model = Speech2TextModel(
        vocab_size=text_processor.vocab_size,
        input_size=audio_processor.n_mels,
        encoder_hidden=128,  # Must match training config
        decoder_hidden=128,
        num_layers=1
    )

    # weights_only=True: a plain state_dict needs no arbitrary unpickling (recent
    # PyTorch versions warn when it is left unset).
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    print("Model loaded successfully!")

    if os.path.exists(sample_audio):
        prediction = predict_text(model, sample_audio, text_processor, audio_processor, device)
        print(f"Predicted text: {prediction}")
    else:
        print(f"Audio file not found: {sample_audio}")
        print("Please provide a valid audio file path.")

# Call the function
load_and_test_model()

Using device: mps
Building vocabulary...
Vocabulary size: 100
Sample characters: ['<PAD>', '<UNK>', '<SOS>', '<EOS>', ' ', 'e', 'a', 't', 'i', 'o', 's', 'n', 'r', 'h', 'l', 'd', 'c', 'u', 'm', 'p']
Model loaded successfully!
Predicted text: the man in the man in the man in the man in the market and the community.


  model.load_state_dict(torch.load(model_path, map_location=device))
