In [None]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoConfig
from torch.optim import AdamW
from tqdm.auto import tqdm
import numpy as np

# --- Configuration ---
# Model / data locations
MODEL_NAME = "facebook/wav2vec2-base-960h"
DATA_ROOT = "./data_librispeech_asr"
OUTPUT_DIR = "./wav2vec2_librispeech_finetuned"

# Optimisation hyper-parameters
NUM_EPOCHS = 5
BATCH_SIZE = 2
LEARNING_RATE = 3e-5
GRAD_CLIP_NORM = 1.0  # max global grad norm passed to clip_grad_norm_ in the training loop
MAX_SAMPLES = 300     # only the first MAX_SAMPLES utterances are used for training

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(DATA_ROOT, exist_ok=True)

# --- 1. Load Pre-trained Model, Processor and Config ---
print(f"Loading model, processor, and config for: {MODEL_NAME}...")
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model_config = AutoConfig.from_pretrained(MODEL_NAME)
# nn.Module.to returns the module itself, so loading and moving can be chained.
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(DEVICE)
# Disable gradients for the feature encoder so its weights are not updated while fine-tuning.
model.freeze_feature_encoder()

# --- 1.5 Verify Critical Tokenizer and Model Config ---
# model.config.pad_token_id is the id reported below as the CTC blank; it has to agree
# with the tokenizer's padding id.
print("\n--- Verifying Tokenizer and Model Config ---")
if model.config.pad_token_id == processor.tokenizer.pad_token_id:
    print(f"Model CTC Blank ID (model.config.pad_token_id): {model.config.pad_token_id}")
    print(f"Processor Padding ID (processor.tokenizer.pad_token_id): {processor.tokenizer.pad_token_id}")
    print("Pad token IDs match. OK.")
else:
    print(f"CRITICAL WARNING: Model's pad_token_id ({model.config.pad_token_id}) and "
          f"Processor's pad_token_id ({processor.tokenizer.pad_token_id}) DO NOT MATCH!")
    print("This will likely lead to incorrect CTC behavior. Ensure they are aligned.")
print(f"Model inputs_to_logits_ratio: {model_config.inputs_to_logits_ratio}")
print("--- End Verification ---\n")

# --- 2. Load Dataset ---
# Downloads (if needed) and opens the LibriSpeech train-clean-100 split under DATA_ROOT.
print("Loading LibriSpeech dataset...")
try:
    librispeech_dataset_full = torchaudio.datasets.LIBRISPEECH(
        root=DATA_ROOT,
        url="train-clean-100",
        download=True,
    )
except Exception as e:
    print(f"Error downloading/loading LibriSpeech. If it's a checksum error, try deleting the partially "
          f"downloaded 'train-clean-100.tar.gz' archive (and any partial 'LibriSpeech' folder) under "
          f"{DATA_ROOT} and re-running.")
    print(f"Details: {e}")
    # `exit()` is an interactive-site helper (absent under `python -S`, and inside a
    # Jupyter kernel it targets the kernel itself). SystemExit stops the script/cell
    # cleanly and keeps the original error chained as the cause.
    raise SystemExit(1) from e

class SubsetDataset(Dataset):
    """Expose only the first ``num_samples`` items of another dataset.

    When ``num_samples`` exceeds the wrapped dataset's length, every item is exposed.

    Args:
        full_dataset: any indexable object with ``len()``.
        num_samples: maximum number of leading items to expose.
    """

    def __init__(self, full_dataset, num_samples):
        self.full_dataset = full_dataset
        kept = min(num_samples, len(full_dataset))
        # Positions in full_dataset that this subset maps onto, in order.
        self.indices = [*range(kept)]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_index = self.indices[idx]
        return self.full_dataset[source_index]

# Training uses only the first MAX_SAMPLES utterances of the full split.
dataset = SubsetDataset(full_dataset=librispeech_dataset_full, num_samples=MAX_SAMPLES)
print(f"Using {len(dataset)} samples for training.")

# --- 3. Custom PyTorch Dataset for Preprocessing ---
class AudioDataset(Dataset):
    """Turn raw LibriSpeech-style tuples into CTC training examples.

    Each item of the wrapped dataset is unpacked as the 6-tuple
    ``(waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)``.

    ``__getitem__`` returns ``None`` for items that cannot be used for training:
    an item that fails to load, an empty transcript, audio shorter than one
    output frame, an empty label sequence, or a label sequence that CTC cannot
    align to the available frames. The collator drops ``None`` items.

    Args:
        torchaudio_dataset: indexable dataset yielding the 6-tuples above.
        processor: Wav2Vec2Processor-like object. It is called on the waveform to
            obtain ``input_values``, and its ``tokenizer`` produces the labels.
        target_sample_rate: audio at any other rate is resampled to this one.
        inputs_to_logits_ratio: audio samples per output frame. When ``None`` (the
            default), the module-level ``model_config.inputs_to_logits_ratio`` is used.
    """

    def __init__(self, torchaudio_dataset, processor, target_sample_rate=16000,
                 inputs_to_logits_ratio=None):
        self.torchaudio_dataset = torchaudio_dataset
        self.processor = processor
        self.target_sample_rate = target_sample_rate
        if inputs_to_logits_ratio is None:
            # Backward-compatible fallback to the global config loaded earlier in the script.
            inputs_to_logits_ratio = model_config.inputs_to_logits_ratio
        self.inputs_to_logits_ratio = inputs_to_logits_ratio
        # Resample transforms cached per source rate instead of being rebuilt for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.torchaudio_dataset)

    def _resample(self, waveform, sample_rate):
        """Return *waveform* at ``target_sample_rate`` (unchanged if it already is)."""
        if sample_rate == self.target_sample_rate:
            return waveform
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            self._resamplers[sample_rate] = resampler
        return resampler(waveform)

    @staticmethod
    def _min_ctc_frames(labels):
        """Return the minimum number of output frames CTC needs to emit *labels*.

        CTC must place a blank between two identical consecutive labels, so the
        requirement is ``len(labels)`` plus the number of adjacent repeats.
        """
        num_adjacent_repeats = int((labels[1:] == labels[:-1]).sum().item())
        return labels.shape[0] + num_adjacent_repeats

    def __getitem__(self, idx):
        try:
            waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = self.torchaudio_dataset[idx]
        except Exception:
            # Best-effort: unreadable items become None and are dropped by the collator.
            # NOTE(review): this also turns an out-of-range idx into None instead of
            # IndexError, so iterate with range(len(ds)), never with a bare `for x in ds`.
            return None

        if not utterance.strip():
            return None

        waveform_squeezed = self._resample(waveform, sample_rate).squeeze(0)
        # A scalar, or audio shorter than one output frame, cannot produce any logits.
        if waveform_squeezed.ndim == 0 or waveform_squeezed.shape[0] < self.inputs_to_logits_ratio:
            return None

        input_values = self.processor(
            waveform_squeezed,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
        ).input_values.squeeze(0)

        # Transcript is upper-cased before tokenising (presumably to match an
        # upper-case character vocabulary) and tokenised without special tokens.
        labels = self.processor.tokenizer(
            utterance.upper(),
            return_tensors="pt",
            padding=False,
            truncation=False,
            add_special_tokens=False,
        ).input_ids.squeeze(0)

        if labels.nelement() == 0:
            return None

        # Approximate number of output frames. Samples CTC cannot align would produce
        # an infinite loss (unless ctc_zero_infinity is enabled), so they are dropped here.
        expected_logits_len = waveform_squeezed.shape[0] // self.inputs_to_logits_ratio
        if expected_logits_len < self._min_ctc_frames(labels):
            return None

        return {
            "input_values": input_values,
            "labels": labels,
            "id": f"{speaker_id}-{chapter_id}-{utterance_id}",
            "waveform_len_debug": waveform_squeezed.shape[0],
            "label_len_debug": labels.shape[0],
        }

# Training view over the MAX_SAMPLES subset, plus a view over the whole split that the
# inference cell samples from. NOTE: both come from train-clean-100, so test_dataset is
# not a held-out test set.
processed_dataset = AudioDataset(torchaudio_dataset=dataset, processor=processor)
test_dataset = AudioDataset(torchaudio_dataset=librispeech_dataset_full, processor=processor)

# --- 4. Custom Data Collator ---
class CustomDataCollatorCTCWithPadding:
    """Collate AudioDataset items into one padded batch for Wav2Vec2ForCTC.

    ``None`` items (samples AudioDataset rejected) are dropped. If nothing is
    left, an empty dict is returned so the training loop can skip the batch.

    Audio is padded through ``processor.feature_extractor.pad``. Labels are
    right-padded with ``label_pad_value`` (default ``-100``). Wav2Vec2ForCTC only
    treats label ids ``>= 0`` as CTC targets. Padding with the tokenizer's pad id
    (0, which is also the CTC blank) would count the padding as real targets:
    target lengths become the batch maximum and blanks appear inside the targets,
    which yields inf/NaN losses.

    Args:
        processor: Wav2Vec2Processor-like object exposing ``feature_extractor.pad``.
        label_pad_value: value used to right-pad label sequences.
    """

    def __init__(self, processor, label_pad_value=-100):
        self.processor = processor
        # Not used below; the feature extractor's pad() decides the audio padding value.
        self.audio_padding_value = 0.0
        self.label_padding_token_id = label_pad_value

    def __call__(self, features):
        valid_features = [f for f in features if f is not None]
        if not valid_features:
            return {}

        # feature_extractor.pad expects a dict of lists, not a list of dicts.
        # NOTE(review): Hugging Face advises not passing attention_mask to models whose
        # feature extractor has return_attention_mask=False (e.g. wav2vec2-base);
        # confirm whether the mask should be forwarded for this checkpoint.
        batch_padded_audio = self.processor.feature_extractor.pad(
            {"input_values": [f["input_values"] for f in valid_features]},
            padding=True,
            return_tensors="pt",
            pad_to_multiple_of=None,
            return_attention_mask=True,
        )

        # Right-pad labels to the longest sequence in the batch with the ignore value.
        batch_labels = torch.nn.utils.rnn.pad_sequence(
            [f["labels"] for f in valid_features],
            batch_first=True,
            padding_value=self.label_padding_token_id,
        )

        return {
            "input_values": batch_padded_audio.input_values,
            "attention_mask": batch_padded_audio.attention_mask,
            "labels": batch_labels,
            "ids_debug": [f.get("id", "unknown") for f in valid_features],
            "waveform_lengths_debug": [f.get("waveform_len_debug", 0) for f in valid_features],
            "label_lengths_debug": [f.get("label_len_debug", 0) for f in valid_features],
        }

data_collator = CustomDataCollatorCTCWithPadding(processor=processor)

# Shuffled training loader; the collator removes items AudioDataset rejected (None).
train_dataloader = DataLoader(
    dataset=processed_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=2,  # use 0 to run data loading in the main process while debugging
)

# --- 5. Sanity Check Dataloader Output ---
print("\n--- Sanity Checking Dataloader Output ---")
# Runs AudioDataset preprocessing once per item just to count how many survive its
# filters. Indexed explicitly: AudioDataset returns None (instead of raising
# IndexError) past its end, so plain iteration over it would never stop.
num_valid_samples_in_processed_dataset = sum(
    1 for i in range(len(processed_dataset)) if processed_dataset[i] is not None
)
print(f"Total samples in original subset: {len(dataset)}")
print(f"Effective samples after AudioDataset filtering: {num_valid_samples_in_processed_dataset}")

if num_valid_samples_in_processed_dataset > 0:
    try:
        for i, batch_sanity in enumerate(train_dataloader):
            if i >= 1:
                break
            if not batch_sanity:
                print(f"Sanity check: Batch {i} is empty.")
                continue

            print(f"\nSanity Check Batch {i}:")
            print(f"  Input values shape: {batch_sanity['input_values'].shape}")
            print(f"  Attention mask shape: {batch_sanity['attention_mask'].shape}")
            print(f"  Labels shape: {batch_sanity['labels'].shape}")
            labels_batch = batch_sanity['labels']
            if labels_batch.numel() > 0:
                # These are reference labels, not logits:
                #  - drop negative ignore-index padding (e.g. -100), which would decode as <unk>;
                #  - disable CTC grouping, otherwise doubled letters (e.g. "LL" in "BELL")
                #    are collapsed and the printed transcript looks misspelled.
                sample0_ids = [t for t in labels_batch[0].tolist() if t >= 0]
                sample0_text = processor.decode(sample0_ids, group_tokens=False)
                print(f"  Sample 0 Decoded Labels (first 50 chars): '{sample0_text[:50]}'")
            else:
                print("  Sample 0 Labels are empty or not present for decoding.")

            if batch_sanity.get('waveform_lengths_debug'):
                wf_len_s0 = batch_sanity['waveform_lengths_debug'][0]
                lbl_len_s0 = batch_sanity['label_lengths_debug'][0]
                exp_logits_s0 = wf_len_s0 // model_config.inputs_to_logits_ratio
                is_valid_s0 = exp_logits_s0 >= lbl_len_s0
                print(f"  Sample 0 CTC check: WF_len={wf_len_s0}, ExpLogits={exp_logits_s0}, Lbl_len={lbl_len_s0}. Valid: {is_valid_s0}")
                if not is_valid_s0:
                    print("    WARNING: Sample 0 may fail CTC length constraint in model!")
            else:
                print("  Debug lengths not available for Sample 0.")
    except Exception as e:
        import traceback
        print(f"Error during dataloader sanity check: {e}")
        traceback.print_exc()  # full traceback for data processing / collation errors
        print("This might indicate an issue with data processing or collation.")
else:
    print("Effective dataset size is 0. Cannot perform dataloader sanity check. Check filtering in AudioDataset.")
print("--- End Sanity Checking ---\n")

# --- 6. Optimizer ---
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# --- 7. Training Loop ---
print("Starting training...")
model.train()
avg_epoch_loss = float('nan')  # stays NaN if no epoch runs
num_valid_batches_epoch = 0    # the save step reads the last epoch's count
pad_token_id = processor.tokenizer.pad_token_id

for epoch in range(NUM_EPOCHS):
    epoch_loss, num_valid_batches_epoch, num_nan_batches_epoch = 0.0, 0, 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for batch_idx, batch in enumerate(progress_bar):
        # The collator returns {} when every item of the batch was filtered out.
        # progress_bar.write keeps log lines from tearing the progress bar.
        if not batch or batch.get("input_values") is None:
            progress_bar.write(f"Skipping empty or invalid batch at epoch {epoch+1}, index {batch_idx}")
            continue

        optimizer.zero_grad()

        input_values = batch["input_values"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        batch_ids = batch.get("ids_debug", ["N/A"] * labels.size(0))

        try:
            outputs = model(input_values, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            if not torch.isfinite(loss):
                # Skip the update so one bad batch cannot poison the weights.
                num_nan_batches_epoch += 1
                progress_bar.write(f"\nNaN/Inf loss detected! Epoch {epoch+1}, Batch {batch_idx}. IDs: {batch_ids}")
                progress_bar.write(f"  Input values: min={input_values.min():.2e}, max={input_values.max():.2e}, mean={input_values.mean():.2e}")
                logits = getattr(outputs, 'logits', None)
                if logits is not None:
                    progress_bar.write(f"  Logits: min={logits.min():.2e}, max={logits.max():.2e}, mean={logits.mean():.2e}, has_nan={torch.isnan(logits).any()}")
                for i, label_row in enumerate(labels):
                    # Real targets are ids >= 0 (negative values are ignore-index padding)
                    # that are not the pad/blank id, whichever padding the collator used.
                    is_target = label_row >= 0
                    if pad_token_id is not None:
                        is_target &= label_row != pad_token_id
                    effective_label_length = int(is_target.sum().item())
                    if effective_label_length == 0 and model.config.ctc_zero_infinity is False:
                        progress_bar.write(f"  Sample {i} (ID: {batch_ids[i]}) has effective zero-length label. This can cause inf loss if ctc_zero_infinity=False.")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value
            num_valid_batches_epoch += 1
            progress_bar.set_postfix({"loss": loss_value, "lr": optimizer.param_groups[0]['lr'], "nans": num_nan_batches_epoch})

        except Exception as e:
            import traceback
            progress_bar.write(f"\nError during model forward/backward pass at Epoch {epoch+1}, Batch {batch_idx}: {e}")
            progress_bar.write(f"  Input values shape: {input_values.shape}, attention_mask shape: {attention_mask.shape if attention_mask is not None else 'None'}, labels shape: {labels.shape}")
            progress_bar.write(f"  Batch IDs: {batch_ids}")
            traceback.print_exc()
            optimizer.zero_grad(set_to_none=True)  # drop any partial gradients from the failed step
            num_nan_batches_epoch += 1  # counted as a problematic batch

    avg_epoch_loss = epoch_loss / num_valid_batches_epoch if num_valid_batches_epoch > 0 else float('nan')
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Avg Loss: {avg_epoch_loss:.4f}, Valid Batches: {num_valid_batches_epoch}, NaN Batches: {num_nan_batches_epoch}")

# --- 8. Saving the Model and Processor ---
# Save only when the final epoch produced at least one batch with a real (non-NaN) average loss.
training_produced_valid_loss = num_valid_batches_epoch > 0 and not np.isnan(avg_epoch_loss)
if not training_produced_valid_loss:
    print("\nSkipping model saving due to NaN losses or no valid batches during training.")
else:
    print(f"\nSaving fine-tuned model and processor to {OUTPUT_DIR}...")
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Training complete and model saved.")



Using device: cuda
Loading model, processor, and config for: facebook/wav2vec2-base-960h...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



--- Verifying Tokenizer and Model Config ---
Model CTC Blank ID (model.config.pad_token_id): 0
Processor Padding ID (processor.tokenizer.pad_token_id): 0
Pad token IDs match. OK.
Model inputs_to_logits_ratio: 320
--- End Verification ---

Loading LibriSpeech dataset...
Using 300 samples for training.

--- Sanity Checking Dataloader Output ---
Total samples in original subset: 300
Effective samples after AudioDataset filtering: 300

Sanity Check Batch 0:
  Input values shape: torch.Size([2, 252880])
  Attention mask shape: torch.Size([2, 252880])
  Labels shape: torch.Size([2, 271])
  Sample 0 Decoded Labels (first 50 chars): 'SO MATHEW AND I HAVE TALKED IT OVER OF AND ON EVER'
  Sample 0 CTC check: WF_len=230800, ExpLogits=721, Lbl_len=271. Valid: True
--- End Sanity Checking ---

Starting training...


Epoch 1/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 1, Batch 1. IDs: ['103-1240-40', '1040-133433-70']
  Input values: min=-1.45e+01, max=1.27e+01, mean=-1.18e-09
  Logits: min=-3.24e+01, max=1.20e+01, mean=-3.35e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 2. IDs: ['1034-121119-25', '1040-133433-15']
  Input values: min=-1.60e+01, max=1.66e+01, mean=7.39e-09
  Logits: min=-2.79e+01, max=7.67e+00, mean=-3.14e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 3. IDs: ['1034-121119-64', '1040-133433-6']
  Input values: min=-1.02e+01, max=1.16e+01, mean=4.27e-09
  Logits: min=-3.33e+01, max=1.39e+01, mean=-3.60e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 17. IDs: ['1069-133699-8', '1034-121119-93']
  Input values: min=-1.03e+01, max=9.04e+00, mean=-4.06e-09
  Logits: min=-2.59e+01, max=9.51e+00, mean=-2.93e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 19. IDs: ['103-1241-28', '1034-121119-39']
  Input values: min=-1.47e+01, max=1.12e+01, mean=4.60e-11
  Logits: min

Epoch 2/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 2, Batch 0. IDs: ['103-1240-42', '1034-121119-24']
  Input values: min=-1.41e+01, max=8.33e+00, mean=4.47e-09
  Logits: min=-2.81e+01, max=1.10e+01, mean=-3.21e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 1. IDs: ['1040-133433-48', '1040-133433-45']
  Input values: min=-1.03e+01, max=1.68e+01, mean=-9.22e-12
  Logits: min=-5.69e+01, max=1.72e+01, mean=-3.38e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 2. IDs: ['1034-121119-26', '1040-133433-32']
  Input values: min=-9.51e+00, max=1.81e+01, mean=2.41e-09
  Logits: min=-2.84e+01, max=9.31e+00, mean=-3.00e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 10. IDs: ['1034-121119-9', '103-1240-37']
  Input values: min=-2.37e+01, max=2.33e+01, mean=-7.72e-10
  Logits: min=-2.88e+01, max=1.10e+01, mean=-3.08e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 16. IDs: ['1040-133433-65', '1034-121119-88']
  Input values: min=-1.10e+01, max=1.74e+01, mean=4.37e-10
  Logits: mi

Epoch 3/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 3, Batch 2. IDs: ['103-1240-19', '1040-133433-6']
  Input values: min=-1.86e+01, max=1.29e+01, mean=5.49e-10
  Logits: min=-5.27e+01, max=1.70e+01, mean=-3.33e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 5. IDs: ['1040-133433-78', '103-1240-5']
  Input values: min=-1.27e+01, max=1.14e+01, mean=-7.42e-10
  Logits: min=-2.92e+01, max=9.71e+00, mean=-2.90e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 16. IDs: ['1040-133433-13', '1040-133433-31']
  Input values: min=-1.58e+01, max=1.88e+01, mean=1.80e-09
  Logits: min=-3.03e+01, max=1.09e+01, mean=-3.18e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 17. IDs: ['1034-121119-24', '1034-121119-64']
  Input values: min=-1.02e+01, max=1.16e+01, mean=4.05e-09
  Logits: min=-3.08e+01, max=1.14e+01, mean=-3.07e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 19. IDs: ['103-1240-32', '1040-133433-19']
  Input values: min=-2.09e+01, max=1.70e+01, mean=-1.32e-09
  Logits: min=-

Epoch 4/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 4, Batch 0. IDs: ['1034-121119-5', '1040-133433-24']
  Input values: min=-8.91e+00, max=1.14e+01, mean=1.60e-09
  Logits: min=-4.38e+01, max=1.49e+01, mean=-3.02e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 4. IDs: ['1040-133433-27', '1034-121119-93']
  Input values: min=-1.03e+01, max=1.28e+01, mean=1.23e-09
  Logits: min=-5.87e+01, max=1.62e+01, mean=-3.26e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 13. IDs: ['1040-133433-38', '1069-133699-8']
  Input values: min=-8.73e+00, max=1.13e+01, mean=-2.80e-09
  Logits: min=-2.57e+01, max=9.59e+00, mean=-2.88e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 14. IDs: ['103-1241-29', '1034-121119-0']
  Input values: min=-1.21e+01, max=1.27e+01, mean=-5.22e-10
  Logits: min=-2.90e+01, max=1.06e+01, mean=-2.88e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 17. IDs: ['103-1241-34', '1034-121119-4']
  Input values: min=-1.44e+01, max=1.09e+01, mean=1.02e-09
  Logits: min=

Epoch 5/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 5, Batch 7. IDs: ['1034-121119-91', '103-1241-19']
  Input values: min=-1.45e+01, max=1.19e+01, mean=2.52e-09
  Logits: min=-2.88e+01, max=1.01e+01, mean=-3.05e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 13. IDs: ['1034-121119-6', '1034-121119-88']
  Input values: min=-9.39e+00, max=8.50e+00, mean=-4.21e-09
  Logits: min=-3.19e+01, max=1.06e+01, mean=-3.09e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 15. IDs: ['103-1240-47', '1034-121119-59']
  Input values: min=-1.45e+01, max=1.18e+01, mean=3.70e-09
  Logits: min=-2.82e+01, max=1.19e+01, mean=-3.06e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 19. IDs: ['103-1241-16', '1040-133433-70']
  Input values: min=-1.51e+01, max=1.27e+01, mean=2.45e-09
  Logits: min=-2.85e+01, max=1.20e+01, mean=-2.94e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 21. IDs: ['1040-133433-26', '1040-133433-79']
  Input values: min=-1.33e+01, max=1.47e+01, mean=-4.08e-09
  Logits: min

In [None]:
# --- 9. Simple Inference Example ---
# Transcribe one utterance with the checkpoint saved by the training cell.
# INFERENCE_SAMPLE_INDEX indexes the full train-clean-100 split. SubsetDataset trained
# only on the first MAX_SAMPLES items, so a larger index was not seen during
# fine-tuning, although it comes from the same split (not a true held-out test).
INFERENCE_SAMPLE_INDEX = 567

# save_pretrained may write either weights format; accept both.
model_weights_name_bin = "pytorch_model.bin"
model_weights_name_sf = "model.safetensors"

model_path_bin = os.path.join(OUTPUT_DIR, model_weights_name_bin)
model_path_sf = os.path.join(OUTPUT_DIR, model_weights_name_sf)

if os.path.exists(model_path_bin) or os.path.exists(model_path_sf):
    print("\n--- Inference Example ---")
    try:
        # from_pretrained picks whichever weights file is present.
        model_inf = Wav2Vec2ForCTC.from_pretrained(OUTPUT_DIR).to(DEVICE)
        processor_inf = Wav2Vec2Processor.from_pretrained(OUTPUT_DIR)
        model_inf.eval()

        source_dataset = test_dataset.torchaudio_dataset
        if not 0 <= INFERENCE_SAMPLE_INDEX < len(source_dataset):
            print(f"Inference sample index {INFERENCE_SAMPLE_INDEX} is out of range "
                  f"(dataset has {len(source_dataset)} items); skipping inference example.")
        else:
            (waveform_orig, sr_orig, original_transcript,
             speaker_id, chapter_id, utterance_id) = source_dataset[INFERENCE_SAMPLE_INDEX]
            print(f"\nOriginal transcript (ID: {speaker_id}-{chapter_id}-{utterance_id}):")
            print(original_transcript)

            # Resample to the rate the feature extractor expects, if necessary.
            target_sr_inf = processor_inf.feature_extractor.sampling_rate
            if sr_orig != target_sr_inf:
                resampler = torchaudio.transforms.Resample(orig_freq=sr_orig, new_freq=target_sr_inf)
                waveform_orig = resampler(waveform_orig)

            waveform_squeezed_orig = waveform_orig.squeeze(0)
            if waveform_squeezed_orig.ndim == 0:
                raise ValueError("Waveform for inference became a scalar.")

            # A single utterance needs no padding.
            input_dict = processor_inf(
                waveform_squeezed_orig,
                sampling_rate=target_sr_inf,
                return_tensors="pt",
            )
            input_values_inf = input_dict.input_values.to(DEVICE)
            # Present only when the feature extractor is configured to return a mask.
            attention_mask_inf = input_dict.get("attention_mask")
            if attention_mask_inf is not None:
                attention_mask_inf = attention_mask_inf.to(DEVICE)

            with torch.no_grad():
                logits_inf = model_inf(input_values_inf, attention_mask=attention_mask_inf).logits

            # Greedy CTC decoding: most likely token per frame; batch_decode's default
            # CTC grouping then merges repeats and drops blanks.
            predicted_ids = torch.argmax(logits_inf, dim=-1)
            predicted_sentence = processor_inf.batch_decode(predicted_ids)[0]
            print("\nPredicted transcript:")
            print(predicted_sentence)

    except Exception as e:
        import traceback
        print(f"Error during inference example: {e}")
        traceback.print_exc()
else:
    print(f"\nInference example skipped as model weights ('{model_weights_name_bin}' or '{model_weights_name_sf}') not found in {OUTPUT_DIR}.")

print("\nScript finished.")


--- Inference Example ---
300

Original transcript (ID: 1088-134315-19):
HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCOLE A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES

Predicted transcript:
HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCO A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES

Script finished.
