In [1]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoConfig
from torch.optim import AdamW
from tqdm.auto import tqdm
import numpy as np

# --- Configuration ---
MODEL_NAME = "facebook/wav2vec2-base-960h"
DATA_ROOT = "./data_librispeech_asr"
OUTPUT_DIR = "./wav2vec2_librispeech_finetuned"
NUM_EPOCHS = 5
BATCH_SIZE = 2
LEARNING_RATE = 3e-5
GRAD_CLIP_NORM = 1.0
MAX_SAMPLES = 300
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(DATA_ROOT, exist_ok=True)

# --- 1. Load Pre-trained Model, Processor and Config ---
print(f"Loading model, processor, and config for: {MODEL_NAME}...")
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model_config = AutoConfig.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.freeze_feature_encoder()

# --- 1.5 Verify Critical Tokenizer and Model Config ---
print("\n--- Verifying Tokenizer and Model Config ---")
if model.config.pad_token_id != processor.tokenizer.pad_token_id:
    print(f"CRITICAL WARNING: Model's pad_token_id ({model.config.pad_token_id}) and "
          f"Processor's pad_token_id ({processor.tokenizer.pad_token_id}) DO NOT MATCH!")
    print("This will likely lead to incorrect CTC behavior. Ensure they are aligned.")
else:
    print(f"Model CTC Blank ID (model.config.pad_token_id): {model.config.pad_token_id}")
    print(f"Processor Padding ID (processor.tokenizer.pad_token_id): {processor.tokenizer.pad_token_id}")
    print("Pad token IDs match. OK.")
print(f"Model inputs_to_logits_ratio: {model_config.inputs_to_logits_ratio}")
print("--- End Verification ---\n")

# --- 2. Load Dataset ---
print("Loading LibriSpeech dataset...")
try:
    librispeech_dataset_full = torchaudio.datasets.LIBRISPEECH(
        root=DATA_ROOT,
        url="train-clean-100",
        download=True
    )
except Exception as e:
    print(f"Error downloading/loading LibriSpeech. If it's a checksum error, try deleting the downloaded archive in {DATA_ROOT}/LibriSpeech and re-running.")
    print(f"Details: {e}")
    exit()

class SubsetDataset(Dataset):
    def __init__(self, full_dataset, num_samples):
        self.full_dataset = full_dataset
        self.indices = list(range(min(num_samples, len(full_dataset))))

    def __getitem__(self, idx):
        return self.full_dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

dataset = SubsetDataset(librispeech_dataset_full, MAX_SAMPLES)
print(f"Using {len(dataset)} samples for training.")

# --- 3. Custom PyTorch Dataset for Preprocessing ---
class AudioDataset(Dataset):
    def __init__(self, torchaudio_dataset, processor, target_sample_rate=16000):
        self.torchaudio_dataset = torchaudio_dataset
        self.processor = processor
        self.target_sample_rate = target_sample_rate
        self.inputs_to_logits_ratio = model_config.inputs_to_logits_ratio

    def __len__(self):
        return len(self.torchaudio_dataset)

    def __getitem__(self, idx):
        try:
            waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = self.torchaudio_dataset[idx]
        except Exception:
            return None

        if not utterance.strip():
            return None

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            waveform = resampler(waveform)

        waveform_squeezed = waveform.squeeze(0)
        if waveform_squeezed.ndim == 0:
            return None

        if waveform_squeezed.shape[0] < self.inputs_to_logits_ratio:
            return None

        input_values = self.processor(
            waveform_squeezed,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt"
        ).input_values.squeeze(0)

        # Process transcript for labels (removed `as_target_processor` context)
        labels_dict = self.processor.tokenizer(
            utterance.upper(),
            return_tensors="pt",
            padding=False,
            truncation=False,
            add_special_tokens=False
        )
        labels = labels_dict.input_ids.squeeze(0)

        if labels.nelement() == 0:
            return None

        expected_logits_len = waveform_squeezed.shape[0] // self.inputs_to_logits_ratio
        if expected_logits_len < labels.shape[0]:
            return None

        return {
            "input_values": input_values,
            "labels": labels,
            "id": f"{speaker_id}-{chapter_id}-{utterance_id}",
            "waveform_len_debug": waveform_squeezed.shape[0],
            "label_len_debug": labels.shape[0]
        }

processed_dataset = AudioDataset(dataset, processor)
test_dataset = AudioDataset(librispeech_dataset_full, processor)

# --- 4. Custom Data Collator ---
class CustomDataCollatorCTCWithPadding:
    def __init__(self, processor):
        self.processor = processor
        self.audio_padding_value = 0.0
        self.label_padding_token_id = self.processor.tokenizer.pad_token_id
        if self.label_padding_token_id is None:
            print("Warning: processor.tokenizer.pad_token_id is None. Defaulting label padding to 0.")
            self.label_padding_token_id = 0

    def __call__(self, features):
        valid_features = [f for f in features if f is not None]
        if not valid_features:
            return {}

        input_values_list = [feature["input_values"] for feature in valid_features]
        labels_list = [feature["labels"] for feature in valid_features]

        # Pad audio by passing a dictionary to feature_extractor.pad
        audio_to_pad = {"input_values": input_values_list}
        batch_padded_audio = self.processor.feature_extractor.pad(
            audio_to_pad, # Pass the dictionary
            padding=True,
            return_tensors="pt",
            pad_to_multiple_of=None,
            return_attention_mask=True  # Ensure mask is returned
        )

        # Pad labels
        max_label_len = max(len(lab) for lab in labels_list)
        padded_labels_list = []
        for lab in labels_list:
            padding_needed = max_label_len - len(lab)
            padded_lab = torch.cat([lab, torch.full((padding_needed,), self.label_padding_token_id, dtype=lab.dtype)], dim=0)
            padded_labels_list.append(padded_lab)
        batch_labels = torch.stack(padded_labels_list)

        batch_ids = [f.get("id", "unknown") for f in valid_features]
        waveform_lengths = [f.get("waveform_len_debug", 0) for f in valid_features]
        label_lengths = [f.get("label_len_debug", 0) for f in valid_features]

        return {
            "input_values": batch_padded_audio.input_values,
            "attention_mask": batch_padded_audio.attention_mask,
            "labels": batch_labels,
            "ids_debug": batch_ids,
            "waveform_lengths_debug": waveform_lengths,
            "label_lengths_debug": label_lengths
        }

data_collator = CustomDataCollatorCTCWithPadding(processor=processor)

train_dataloader = DataLoader(
    processed_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
    shuffle=True,
    num_workers=2 # Set to 0 if you still face issues, to debug in main thread
)

# --- 5. Sanity Check Dataloader Output ---
print("\n--- Sanity Checking Dataloader Output ---")
num_valid_samples_in_processed_dataset = 0
for i in range(len(processed_dataset)):
    if processed_dataset[i] is not None:
        num_valid_samples_in_processed_dataset +=1
print(f"Total samples in original subset: {len(dataset)}")
print(f"Effective samples after AudioDataset filtering: {num_valid_samples_in_processed_dataset}")

if num_valid_samples_in_processed_dataset > 0:
    try:
        for i, batch_sanity in enumerate(train_dataloader):
            if i >= 1: break
            if not batch_sanity: print(f"Sanity check: Batch {i} is empty."); continue

            print(f"\nSanity Check Batch {i}:")
            print(f"  Input values shape: {batch_sanity['input_values'].shape}")
            print(f"  Attention mask shape: {batch_sanity['attention_mask'].shape}")
            print(f"  Labels shape: {batch_sanity['labels'].shape}")
            # Ensure labels tensor is not empty before decoding and accessing index 0
            if batch_sanity['labels'].numel() > 0 and len(batch_sanity['labels']) > 0:
                 print(f"  Sample 0 Decoded Labels (first 50 chars): '{processor.decode(batch_sanity['labels'][0].tolist())[:50]}'")
            else:
                print("  Sample 0 Labels are empty or not present for decoding.")

            if batch_sanity.get('waveform_lengths_debug') and len(batch_sanity['waveform_lengths_debug']) > 0:
                wf_len_s0 = batch_sanity['waveform_lengths_debug'][0]
                lbl_len_s0 = batch_sanity['label_lengths_debug'][0]
                exp_logits_s0 = wf_len_s0 // model_config.inputs_to_logits_ratio
                print(f"  Sample 0 CTC check: WF_len={wf_len_s0}, ExpLogits={exp_logits_s0}, Lbl_len={lbl_len_s0}. Valid: {exp_logits_s0 >= lbl_len_s0}")
                if not (exp_logits_s0 >= lbl_len_s0):
                    print(f"    WARNING: Sample 0 may fail CTC length constraint in model!")
            else:
                print("  Debug lengths not available for Sample 0.")
    except Exception as e:
        print(f"Error during dataloader sanity check: {e}")
        import traceback
        traceback.print_exc() # Print full traceback for dataloader errors
        print("This might indicate an issue with data processing or collation.")
else:
    print("Effective dataset size is 0. Cannot perform dataloader sanity check. Check filtering in AudioDataset.")
print("--- End Sanity Checking ---\n")

# --- 6. Optimizer ---
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# --- 7. Training Loop ---
print("Starting training...")
model.train()
avg_epoch_loss = float('nan') # Initialize in case no valid batches
num_valid_batches_epoch = 0   # Initialize

for epoch in range(NUM_EPOCHS):
    epoch_loss, num_valid_batches_epoch, num_nan_batches_epoch = 0, 0, 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for batch_idx, batch in enumerate(progress_bar):
        if not batch or "input_values" not in batch or batch["input_values"] is None:
            print(f"Skipping empty or invalid batch at epoch {epoch+1}, index {batch_idx}")
            continue

        optimizer.zero_grad()

        input_values = batch["input_values"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        try:
            outputs = model(input_values, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            if torch.isnan(loss) or torch.isinf(loss):
                num_nan_batches_epoch += 1
                print(f"\nNaN/Inf loss detected! Epoch {epoch+1}, Batch {batch_idx}. IDs: {batch.get('ids_debug', 'N/A')}")
                print(f"  Input values: min={input_values.min():.2e}, max={input_values.max():.2e}, mean={input_values.mean():.2e}")
                if hasattr(outputs, 'logits') and outputs.logits is not None:
                    print(f"  Logits: min={outputs.logits.min():.2e}, max={outputs.logits.max():.2e}, mean={outputs.logits.mean():.2e}, has_nan={torch.isnan(outputs.logits).any()}")
                for i in range(labels.size(0)):
                    current_labels = labels[i]
                    effective_label_length = (current_labels != processor.tokenizer.pad_token_id).sum().item()
                    if effective_label_length == 0 and model.config.ctc_zero_infinity is False:
                        print(f"  Sample {i} (ID: {batch.get('ids_debug', ['N/A']*labels.size(0))[i]}) has effective zero-length label. This can cause inf loss if ctc_zero_infinity=False.")
                optimizer.zero_grad(set_to_none=True)
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()

            epoch_loss += loss.item()
            num_valid_batches_epoch += 1
            progress_bar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]['lr'], "nans": num_nan_batches_epoch})

        except Exception as e:
            print(f"\nError during model forward/backward pass at Epoch {epoch+1}, Batch {batch_idx}: {e}")
            print(f"  Input values shape: {input_values.shape}, attention_mask shape: {attention_mask.shape if attention_mask is not None else 'None'}, labels shape: {labels.shape}")
            print(f"  Batch IDs: {batch.get('ids_debug', 'N/A')}")
            import traceback
            traceback.print_exc()
            optimizer.zero_grad(set_to_none=True) # Ensure grads are cleared
            num_nan_batches_epoch +=1 # Count as a problematic batch
            continue


    avg_epoch_loss = epoch_loss / num_valid_batches_epoch if num_valid_batches_epoch > 0 else float('nan')
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Avg Loss: {avg_epoch_loss:.4f}, Valid Batches: {num_valid_batches_epoch}, NaN Batches: {num_nan_batches_epoch}")

# --- 8. Saving the Model and Processor ---
if num_valid_batches_epoch > 0 and not np.isnan(avg_epoch_loss):
    print(f"\nSaving fine-tuned model and processor to {OUTPUT_DIR}...")
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Training complete and model saved.")
else:
    print("\nSkipping model saving due to NaN losses or no valid batches during training.")



Using device: cuda
Loading model, processor, and config for: facebook/wav2vec2-base-960h...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



--- Verifying Tokenizer and Model Config ---
Model CTC Blank ID (model.config.pad_token_id): 0
Processor Padding ID (processor.tokenizer.pad_token_id): 0
Pad token IDs match. OK.
Model inputs_to_logits_ratio: 320
--- End Verification ---

Loading LibriSpeech dataset...


100%|██████████| 5.95G/5.95G [02:51<00:00, 37.2MB/s]


Using 300 samples for training.

--- Sanity Checking Dataloader Output ---
Total samples in original subset: 300
Effective samples after AudioDataset filtering: 300

Sanity Check Batch 0:
  Input values shape: torch.Size([2, 249040])
  Attention mask shape: torch.Size([2, 249040])
  Labels shape: torch.Size([2, 338])
  Sample 0 Decoded Labels (first 50 chars): 'PALE AND COLD ADIEU SIR SHE SAID ADIEU MADAME REPL'
  Sample 0 CTC check: WF_len=166400, ExpLogits=520, Lbl_len=132. Valid: True
--- End Sanity Checking ---

Starting training...


Epoch 1/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 1, Batch 1. IDs: ['103-1240-40', '103-1241-43']
  Input values: min=-1.45e+01, max=1.09e+01, mean=-6.70e-10
  Logits: min=-3.43e+01, max=1.46e+01, mean=-3.85e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 8. IDs: ['1034-121119-78', '1034-121119-80']
  Input values: min=-9.56e+00, max=8.91e+00, mean=1.01e-09
  Logits: min=-3.15e+01, max=1.71e+01, mean=-3.99e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 11. IDs: ['1034-121119-17', '1034-121119-45']
  Input values: min=-1.09e+01, max=1.18e+01, mean=-1.65e-09
  Logits: min=-3.69e+01, max=1.52e+01, mean=-4.13e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 13. IDs: ['103-1240-7', '1034-121119-9']
  Input values: min=-1.57e+01, max=1.11e+01, mean=0.00e+00
  Logits: min=-5.72e+01, max=1.66e+01, mean=-3.35e+00, has_nan=False

NaN/Inf loss detected! Epoch 1, Batch 18. IDs: ['103-1241-23', '1034-121119-12']
  Input values: min=-3.26e+01, max=1.92e+01, mean=3.79e-09
  Logits: min=-2.9

Epoch 2/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 2, Batch 4. IDs: ['103-1240-46', '1034-121119-65']
  Input values: min=-1.51e+01, max=9.70e+00, mean=3.60e-09
  Logits: min=-4.39e+01, max=1.51e+01, mean=-2.90e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 8. IDs: ['103-1241-25', '103-1240-25']
  Input values: min=-2.77e+01, max=2.10e+01, mean=6.78e-10
  Logits: min=-3.15e+01, max=1.19e+01, mean=-3.08e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 14. IDs: ['1069-133699-18', '1034-121119-54']
  Input values: min=-1.06e+01, max=1.06e+01, mean=6.84e-10
  Logits: min=-2.63e+01, max=8.57e+00, mean=-2.85e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 24. IDs: ['1040-133433-32', '1034-121119-38']
  Input values: min=-9.59e+00, max=1.81e+01, mean=3.22e-09
  Logits: min=-2.84e+01, max=1.12e+01, mean=-3.01e+00, has_nan=False

NaN/Inf loss detected! Epoch 2, Batch 31. IDs: ['1034-121119-22', '103-1240-27']
  Input values: min=-1.44e+01, max=9.08e+00, mean=2.77e-09
  Logits: min=-3.3

Epoch 3/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 3, Batch 0. IDs: ['1034-121119-19', '1034-121119-86']
  Input values: min=-9.63e+00, max=9.22e+00, mean=6.93e-09
  Logits: min=-3.24e+01, max=1.23e+01, mean=-3.34e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 5. IDs: ['103-1240-17', '1040-133433-38']
  Input values: min=-1.53e+01, max=1.13e+01, mean=4.12e-10
  Logits: min=-3.23e+01, max=1.19e+01, mean=-3.11e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 13. IDs: ['1034-121119-90', '103-1240-0']
  Input values: min=-1.13e+01, max=1.21e+01, mean=2.78e-09
  Logits: min=-3.07e+01, max=1.22e+01, mean=-3.18e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 15. IDs: ['103-1240-38', '103-1240-25']
  Input values: min=-1.19e+01, max=8.14e+00, mean=1.92e-09
  Logits: min=-3.06e+01, max=1.19e+01, mean=-3.02e+00, has_nan=False

NaN/Inf loss detected! Epoch 3, Batch 21. IDs: ['103-1240-32', '1034-121119-36']
  Input values: min=-2.09e+01, max=1.17e+01, mean=2.10e-10
  Logits: min=-3.68e+0

Epoch 4/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 4, Batch 2. IDs: ['1040-133433-61', '1034-121119-41']
  Input values: min=-1.52e+01, max=1.90e+01, mean=-2.37e-09
  Logits: min=-2.80e+01, max=9.42e+00, mean=-3.12e+00, has_nan=False


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79690c34b380>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79690c34b380>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16


NaN/Inf loss detected! Epoch 4, Batch 10. IDs: ['1034-121119-80', '1040-133433-34']
  Input values: min=-1.01e+01, max=1.26e+01, mean=1.53e-09
  Logits: min=-3.34e+01, max=9.89e+00, mean=-3.21e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 23. IDs: ['1034-121119-6', '1034-121119-60']
  Input values: min=-1.61e+01, max=1.05e+01, mean=-8.48e-09
  Logits: min=-2.98e+01, max=1.01e+01, mean=-3.21e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 26. IDs: ['103-1240-16', '1040-133433-48']
  Input values: min=-2.53e+01, max=1.68e+01, mean=3.68e-09
  Logits: min=-3.06e+01, max=1.09e+01, mean=-3.16e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 38. IDs: ['1034-121119-18', '103-1241-36']
  Input values: min=-1.26e+01, max=8.26e+00, mean=6.89e-10
  Logits: min=-4.92e+01, max=1.95e+01, mean=-3.16e+00, has_nan=False

NaN/Inf loss detected! Epoch 4, Batch 41. IDs: ['1034-121119-29', '103-1241-32']
  Input values: min=-1.30e+01, max=9.05e+00, mean=2.17e-10
  Logits: min

Epoch 5/5:   0%|          | 0/150 [00:00<?, ?it/s]


NaN/Inf loss detected! Epoch 5, Batch 0. IDs: ['1069-133699-1', '1034-121119-35']
  Input values: min=-1.04e+01, max=1.04e+01, mean=3.01e-11
  Logits: min=-3.07e+01, max=1.14e+01, mean=-3.16e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 2. IDs: ['1040-133433-43', '1040-133433-26']
  Input values: min=-1.33e+01, max=1.61e+01, mean=-2.17e-09
  Logits: min=-2.99e+01, max=1.01e+01, mean=-3.34e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 4. IDs: ['1040-133433-15', '1040-133433-70']
  Input values: min=-1.60e+01, max=1.66e+01, mean=-7.58e-10
  Logits: min=-2.77e+01, max=1.03e+01, mean=-3.12e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 11. IDs: ['1040-133433-65', '103-1241-25']
  Input values: min=-2.77e+01, max=2.10e+01, mean=-7.89e-10
  Logits: min=-2.73e+01, max=9.38e+00, mean=-3.05e+00, has_nan=False

NaN/Inf loss detected! Epoch 5, Batch 12. IDs: ['1034-121119-55', '1040-133433-38']
  Input values: min=-8.41e+00, max=1.13e+01, mean=1.33e-08
  Logits

In [4]:
pip install jiwer


Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.1.0 rapidfuzz-3.13.0


In [8]:
# --- 9. Simple Inference Example ---
# Check for either pytorch_model.bin or model.safetensors
model_weights_name_bin = "pytorch_model.bin"
model_weights_name_sf = "model.safetensors" # Standard name for safetensors weights

model_path_bin = os.path.join(OUTPUT_DIR, model_weights_name_bin)
model_path_sf = os.path.join(OUTPUT_DIR, model_weights_name_sf)

if os.path.exists(model_path_bin) or os.path.exists(model_path_sf):
    print("\n--- Inference Example ---")
    try:
        # Wav2Vec2ForCTC.from_pretrained will automatically find the correct weights file
        model_inf = Wav2Vec2ForCTC.from_pretrained(OUTPUT_DIR).to(DEVICE)
        processor_inf = Wav2Vec2Processor.from_pretrained(OUTPUT_DIR)
        model_inf.eval()

        # Your logic to select a sample
        print(f"Length of processed_dataset: {len(processed_dataset)}") # Original print
        # The original code directly picks from test_dataset.torchaudio_dataset
        # Let's stick to that for consistency with the provided snippet.
        sample_idx = 567 # As per your code
        if sample_idx >= len(test_dataset.torchaudio_dataset):
            print(f"Warning: Sample index {sample_idx} is out of bounds for test_dataset. Using index 0.")
            sample_idx = 5
            if len(test_dataset.torchaudio_dataset) == 0:
                raise ValueError("test_dataset.torchaudio_dataset is empty!")


        sample_candidate = test_dataset.torchaudio_dataset[sample_idx]
        waveform_orig, sr_orig, original_transcript, speaker_id, chapter_id, utterance_id = sample_candidate
        inference_sample_data = sample_candidate # Assuming this sample is always valid for the example

        if inference_sample_data:
            waveform_orig, sr_orig, original_transcript, speaker_id, chapter_id, utterance_id = inference_sample_data
            print(f"\nOriginal transcript (ID: {speaker_id}-{chapter_id}-{utterance_id}):")
            print(original_transcript)

            # Resample if necessary
            target_sr_inf = processor_inf.feature_extractor.sampling_rate
            if sr_orig != target_sr_inf:
                print(f"Resampling waveform from {sr_orig} Hz to {target_sr_inf} Hz.")
                resampler = torchaudio.transforms.Resample(orig_freq=sr_orig, new_freq=target_sr_inf)
                waveform_orig_resampled = resampler(waveform_orig)
            else:
                waveform_orig_resampled = waveform_orig

            # Squeeze and check dimensions
            waveform_squeezed_orig = waveform_orig_resampled.squeeze(0)
            if waveform_squeezed_orig.ndim == 0:
                 raise ValueError("Waveform for inference became a scalar.")
            if waveform_squeezed_orig.ndim > 1 : # Ensure it's 1D
                if waveform_squeezed_orig.shape[0] == 1: # like (1, N) after squeeze from (1,1,N)
                    waveform_squeezed_orig = waveform_squeezed_orig.squeeze(0)
                else: # Potentially (C, N) for stereo - ASR models usually expect mono
                    print(f"Warning: Waveform has {waveform_squeezed_orig.shape[0]} channels. Taking the first channel.")
                    waveform_squeezed_orig = waveform_squeezed_orig[0]


            # Process for inference
            # The processor expects a 1D array or list of floats
            input_dict = processor_inf(
                waveform_squeezed_orig.numpy(), # Processor usually expects numpy array or list of floats
                return_tensors="pt",
                sampling_rate=target_sr_inf,
                padding=True
            )
            input_values_inf = input_dict.input_values.to(DEVICE)
            attention_mask_inf = input_dict.attention_mask.to(DEVICE) if hasattr(input_dict, 'attention_mask') and input_dict.attention_mask is not None else None

            with torch.no_grad():
                logits_inf = model_inf(input_values_inf, attention_mask=attention_mask_inf).logits

            predicted_ids = torch.argmax(logits_inf, dim=-1)
            # `batch_decode` handles joining tokens and some cleanup.
            # The output is usually uppercase and with certain punctuation removed/normalized.
            predicted_sentence = processor_inf.batch_decode(predicted_ids)[0]
            print(f"\nPredicted transcript:")
            print(predicted_sentence)

            # --- METRIC CALCULATION ---
            try:
                import jiwer
                import re

                # Normalization function:
                # Wav2Vec2 models are often trained on uppercase text with minimal punctuation.
                # The processor.batch_decode() output is typically already normalized (e.g. uppercase).
                # We need to apply similar normalization to the original_transcript for a fair comparison.
                def normalize_text_for_asr_metrics(text):
                    text = str(text).upper() # Convert to uppercase
                    # Remove common punctuation. This list can be expanded.
                    # Consider what punctuation your model is expected to handle or ignore.
                    # For LibriSpeech, transcripts are typically clean.
                    punctuation_to_remove = r"[.,?!-;:]" # Add others like quotes if needed: r"[.,?!-;:“”\"]"
                    text = re.sub(punctuation_to_remove, "", text)
                    text = text.replace("-", " ") # Treat hyphens as spaces or remove if tokenizer does
                    text = " ".join(text.split()) # Normalize whitespace (remove multiple spaces, trim)
                    return text

                # Normalize both original and predicted transcripts
                # `predicted_sentence` from `batch_decode` is often already quite clean/normalized.
                # Applying the same normalization ensures consistency.
                normalized_original = normalize_text_for_asr_metrics(original_transcript)
                normalized_predicted = normalize_text_for_asr_metrics(predicted_sentence)

                print(f"\n--- Metrics Calculation ---")
                print(f"Normalized Original: '{normalized_original}'")
                print(f"Normalized Predicted: '{normalized_predicted}'")

                if not normalized_original: # Handle cases where original transcript becomes empty
                    print("Warning: Normalized original transcript is empty. Metrics might be misleading.")
                    wer = float('inf') if normalized_predicted else 0.0 # If pred is also empty, 0 error
                    cer = float('inf') if normalized_predicted else 0.0
                else:
                    wer = jiwer.wer(normalized_original, normalized_predicted)
                    cer = jiwer.cer(normalized_original, normalized_predicted)

                print(f"Word Error Rate (WER): {wer:.4f}")
                print(f"Character Error Rate (CER): {cer:.4f}")

            except ImportError:
                print("\n[METRICS] Please install jiwer for metrics calculation: pip install jiwer")
            except Exception as e_metric:
                print(f"[METRICS] Error during metric calculation: {e_metric}")
                import traceback
                traceback.print_exc()
            # --- END OF METRIC CALCULATION ---

        else:
            print("Not enough valid samples in the dataset to run inference example.")
            # The original print was:
            # print(f"(Checked {len(processed_dataset)} entries in processed_dataset for a suitable sample).")
            # This seems mismatched with how the sample was actually selected above (directly from test_dataset).
            # For clarity, if inference_sample_data can be None based on other logic:
            print(f"No suitable inference_sample_data found.")


    except Exception as e:
        print(f"Error during inference example: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"\nInference example skipped as model weights ('{model_weights_name_bin}' or '{model_weights_name_sf}') not found in {OUTPUT_DIR}.")

print("\nScript finished.")


--- Inference Example ---
Length of processed_dataset: 300

Original transcript (ID: 1088-134315-19):
HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCOLE A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES

Predicted transcript:
HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCOL A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES

--- Metrics Calculation ---
Normalized Original: 'HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCOLE A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES'
Normalized Predicted: 'HE RANG THE BELL THIS TIME FOR HIS VALET FISHER HE SAID I AM EXPECTING A VISIT FROM A GENTLEMAN NAMED GATHERCOL A ONE ARMED GENTLEMAN WHOM YOU MUST LOOK AFTER IF HE COMES'
Word Error Rate (WER): 0.0294
Character Error Rate (CER): 0.0058

Script finished.
