This

-------------------

1. Prepare the dataset with the Whisper processor


In [65]:
from transformers import WhisperProcessor

model_name = "openai/whisper-base"
language = "english" # Change to your dataset's language
task = "transcribe" # Use "translate" if you're translating to English

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

In [76]:
from datasets import load_dataset, Audio, DatasetDict

# Chunked wav dataset
# ataset_path = "C:\\Users\\dacla\\Documents\\DALI-chunks-wav" #FOR PARS
dataset_path = "T:\\dl-project\\DALI-chunks-lines"

#raw_dataset = load_dataset("csv", data_files="metadata-wav.csv", split='train') #FOR PARS
raw_dataset = load_dataset("csv", data_files="metadata-word-level.csv", split='train')
print("Full dataset", raw_dataset)

# Make a train/test split at this point !
raw_dataset = raw_dataset.train_test_split(test_size=0.1, shuffle=True, seed=555)
downsampled_dataset = raw_dataset['test'].train_test_split(test_size=0.1, shuffle=True, seed=555)
print("\nSplit dataset", downsampled_dataset)

Generating train split: 0 examples [00:00, ? examples/s]

Full dataset Dataset({
    features: ['filename', 'words', 'transcript'],
    num_rows: 176774
})

Split dataset DatasetDict({
    train: Dataset({
        features: ['filename', 'words', 'transcript'],
        num_rows: 15910
    })
    test: Dataset({
        features: ['filename', 'words', 'transcript'],
        num_rows: 1768
    })
})


I was getting errors with a simpler prepare_dataset function. This one manually sets the tokens correctly

In [77]:
import librosa
from transformers import WhisperProcessor
import torch 

def prepare_dataset(batch, processor: WhisperProcessor, dataset_path: str):
    # Load and resample audio data
    audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['filename']]
    audio_arrays = [librosa.load(path, sr=16000)[0] for path in audio_paths]
    
    # Compute log-Mel input features from the audio
    batch["input_features"] = processor.feature_extractor(audio_arrays, sampling_rate=16000).input_features

    # Encode the transcriptions to label ids
    # THIS IS THE CORRECTED PART:
    # Use the processor's encode function for labels, specifying language and task
    # This automatically adds <|startoftranscript|>, <|lang|>, <|task|>, and <|endoftext|>
    # And potentially <|no_timestamps|> if you enable it.

    # First, get the forced_decoder_ids from the model's generation config
    # This is important to ensure the start tokens are consistent with what Whisper expects for training
    # For fine-tuning, you often want to ensure these are fixed.
    # If your model config isn't available here, you'd define them manually based on the model.
    # For tiny/base/small English, it's typically:
    # <|startoftranscript|> (50258), <|en|> (50259), <|transcribe|> (50359), <|no_timestamps|> (50362)
    
    # Let's define the correct prefix tokens based on common Whisper usage for English ASR
    # Assuming you want English transcription without timestamps
    bos_token_id = processor.tokenizer.bos_token_id
    lang_en_id = processor.tokenizer.convert_tokens_to_ids("<|en|>") # Assuming English
    transcribe_token_id = processor.tokenizer.encode("transcribe", add_special_tokens=False)[0]
    eos_token_id = processor.tokenizer.eos_token_id
    
    prefix_tokens = [bos_token_id, lang_en_id, transcribe_token_id, eos_token_id]

    # Process each transcript in the batch
    tokenized_labels = []
    for transcript in batch["transcript"]:
        # Tokenize the actual text content
        text_ids = processor.tokenizer(
            transcript,
            return_tensors="pt",
            truncation=True, # Ensure truncation, if max_length is defined later
            max_length=processor.tokenizer.model_max_length # Or a custom MAX_LABEL_LENGTH constant
        ).input_ids.squeeze(0) # Remove batch dim

        # Remove the auto-added <|startoftranscript|> from the text_ids if it's there
        # As we're adding our own prefix.
        if text_ids[0] == bos_token_id:
            text_ids = text_ids[1:] # Remove the initial BOS added by default

        # Remove the auto-added <|endoftext|> if it's there, as we'll explicitly add it at the end
        if text_ids[-1] == eos_token_id:
            text_ids = text_ids[:-1]

        # Combine prefix tokens, text tokens, and EOS token
        final_labels = torch.cat([
            torch.tensor(prefix_tokens, dtype=torch.long),
            text_ids,
            torch.tensor([eos_token_id], dtype=torch.long) # Explicitly add EOS at end
        ])
        
        tokenized_labels.append(final_labels.tolist()) # Convert back to list of lists

    batch["labels"] = tokenized_labels
    return batch

from functools import partial
# Assuming 'processor' and 'dataset_path' are defined above this line.
prepare_dataset_with_args = partial(prepare_dataset, processor=processor, dataset_path=dataset_path)

processed_dataset = downsampled_dataset.map(
    prepare_dataset_with_args,
    batched=True,
    batch_size=128,
    remove_columns=downsampled_dataset.column_names["train"]
)

Map:   0%|          | 0/15910 [00:00<?, ? examples/s]

Map:   0%|          | 0/1768 [00:00<?, ? examples/s]

Old prepare_dataset function

In [None]:
# import librosa

# def prepare_dataset(batch):
#     # Load and resample audio data
#     #audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['file-wav']] #PARS
#     audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['filename']]
#     audio_arrays = [librosa.load(path, sr=16000)[0] for path in audio_paths]
    
#     # Compute log-Mel input features from the audio
#     batch['input_features'] = processor.feature_extractor(audio_arrays, sampling_rate=16000).input_features

#     # Encode the transcriptions to label ids
#     batch['labels'] = processor.tokenizer(batch['transcript']).input_ids

#     return batch

# # Apply the function to the entire dataset
# processed_dataset = raw_dataset.map(prepare_dataset, batched=True, batch_size=8, remove_columns=raw_dataset.column_names["train"])

Map:   0%|          | 0/26690 [00:00<?, ? examples/s]

  "cipher": algorithms.TripleDES,
  "class": algorithms.Blowfish,
  "class": algorithms.TripleDES,


Map:   0%|          | 0/2966 [00:00<?, ? examples/s]

In [78]:
# Save dataset to disc
processed_dataset.save_to_disk('dataset-whisper-lines-.1')

Saving the dataset (0/31 shards):   0%|          | 0/15910 [00:00<?, ? examples/s]

Saving the dataset (0/4 shards):   0%|          | 0/1768 [00:00<?, ? examples/s]

The following does a check to make sure the inputs are formatted correctly

In [63]:
# Assuming processed_dataset is ready
print("\n--- Verifying processed_dataset labels after map ---")
# Get a sample from the processed_dataset (e.g., the first 5 samples)
sample_data = processed_dataset["train"].select(range(min(5, len(processed_dataset["train"]))))

processor_instance = processor # Use the processor you defined earlier

for i, sample in enumerate(sample_data):
    labels = sample["labels"] # These are the token IDs from prepare_dataset

    # Ensure labels is a list (if it came from prepare_dataset's list of lists)
    if isinstance(labels, torch.Tensor):
        labels_list = labels.tolist()
    else: # It's likely a list of lists if batched=True in map
        # If it's a single sample, it might just be a list
        labels_list = labels 
        if isinstance(labels_list[0], list): # If it's a list of lists (from batched=True)
            labels_list = labels_list[0] # Take the first one if you expect single samples here

    decoded_full = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=False)
    decoded_clean = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=True)
    
    eos_id = processor_instance.tokenizer.eos_token_id

    print(f"\nSample {i+1}:")
    print(f"  Raw Labels IDs: {labels_list}")
    print(f"  Decoded (with special tokens): '{decoded_full}'")
    print(f"  Decoded (clean text): '{decoded_clean}'")
    
    if labels_list and labels_list[-1] == eos_id:
        print(f"  Ends with EOS token ({eos_id}): YES")
    else:
        print(f"  Ends with EOS token ({eos_id}): NO - CRITICAL ISSUE AT prepare_dataset!")
        if labels_list:
            print(f"    Last token: {labels_list[-1]}")


--- Verifying processed_dataset labels after map ---

Sample 1:
  Raw Labels IDs: [50258, 50259, 50359, 50362, 50258, 50259, 50359, 50363, 332, 264, 472, 567, 534, 6752, 2478, 3186, 220, 488, 668, 6728, 259, 412, 428, 2853, 50257]
  Decoded (with special tokens): '<|startoftranscript|><|en|><|transcribe|><|nocaptions|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>im the one who really loves ya baby ive been knockin at your door<|endoftext|>'
  Decoded (clean text): 'im the one who really loves ya baby ive been knockin at your door'
  Ends with EOS token (50257): YES

Sample 2:
  Raw Labels IDs: [50258, 50259, 50359, 50362, 50258, 50259, 50359, 50363, 474, 321, 3373, 309, 3186, 321, 3727, 259, 293, 3373, 259, 321, 27524, 483, 264, 24244, 570, 321, 50257]
  Decoded (with special tokens): '<|startoftranscript|><|en|><|transcribe|><|nocaptions|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>and we roll it baby we rockin and rollin we wont get the blues because we<|end

---------------------------

2. Start from here if the dataset is already made
- Create processor
- Define data collator


In [79]:
from transformers import WhisperProcessor

model_name = "openai/whisper-base"
language = "english" 
task = "transcribe" 

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

This is old. Not needed anymore

In [None]:
# class CustomWhisperDataCollator:
#     """
#     Data collator for Whisper fine-tuning, implemented as a class.
#     It pads input_features (log-mel spectrograms) and labels (tokenized transcripts).
#     """
#     def __init__(self, processor: WhisperProcessor):
#         self.processor = processor

#     def __call__(self, features):
#         """
#         Args:
#             features (list): A list of dictionaries, where each dictionary
#                              represents a single data sample (e.g., from a Dataset's __getitem__).
#                              Expected to have:
#                              - 'input_features': A PyTorch tensor of log-mel spectrograms.
#                              - 'labels': A PyTorch tensor of tokenized transcript IDs.

#         Returns:
#             dict: A dictionary containing padded 'input_features' and 'labels' tensors.
#         """
#         input_features = [feature["input_features"] for feature in features]
#         labels = [feature["labels"] for feature in features]

#         # --- Padding Input Features (Audio Spectrograms) ---
#         max_input_frames = max(f.shape[-1] for f in input_features)
        
#         padded_input_features = []
#         for feat in input_features:
#             padding_needed = max_input_frames - feat.shape[-1]
#             padded_feat = torch.nn.functional.pad(feat, (0, padding_needed), "constant", 0.0)
#             padded_input_features.append(padded_feat)
        
#         input_features_batch = torch.stack(padded_input_features)

#         # --- Padding Labels (Tokenized Transcripts) ---
#         max_label_len = max(len(l) for l in labels)

#         padded_labels = []
#         for label in labels:
#             padding_needed = max_label_len - len(label)
#             padded_label = torch.nn.functional.pad(
#                 label, (0, padding_needed), "constant", self.processor.tokenizer.pad_token_id
#             )
#             padded_labels.append(padded_label)
        
#         labels_batch = torch.stack(padded_labels)

#         # --- Replace pad_token_id with -100 for loss computation ---
#         labels_batch = labels_batch.masked_fill(
#             labels_batch == self.processor.tokenizer.pad_token_id, -100
#         )

#         return {
#             "input_features": input_features_batch,
#             "labels": labels_batch,
#         }
    
# data_collator = CustomWhisperDataCollator(processor=processor)

For checking formatting


In [None]:

print("\n--- Verifying new class-based data collator ---")
sample_batch = next(iter(train_dataloader))

first_label_ids = sample_batch['labels'][0].tolist()
# Filter out the -100s to see the original tokens for decoding
original_tokens = [id for id in first_label_ids if id != -100] # <-- This is the sequence to check!

decoded_text_full = processor.tokenizer.decode(original_tokens, skip_special_tokens=False)
print(f"\nSample decoded label (with special tokens): {decoded_text_full}")

whisper_eos_id = processor.tokenizer.eos_token_id
if original_tokens and original_tokens[-1] == whisper_eos_id:
    print(f"Last non-padded token is EOS ({whisper_eos_id}): Yes")
else:
    print(f"Last non-padded token is EOS ({whisper_eos_id}): No - this indicates an issue!")
    if original_tokens:
        print(f"  Last token in non-padded sequence: {original_tokens[-1]}")
        print(f"  Expected EOS ID: {whisper_eos_id}")


--- Verifying new class-based data collator ---


NameError: name 'train_dataloader' is not defined

Use this for training

In [88]:
from transformers import DataCollatorForSeq2Seq
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

# --- Data Collator ---
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths and need different padding methods.
        # "input_features" for Whisper-based models (vs. "input_values" for wav2vec...)
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, 
                                                     return_tensors="pt",
                                                     return_attention_mask=True)
        
        labels_batch = self.processor.tokenizer.pad(label_features, 
                                                    return_tensors="pt",)

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, padding=True)

Create the Whisper model


In [81]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name)
# This is necessary for the model to work correctly with the Trainer
#model.config.forced_decoder_ids = None
#model.config.suppress_tokens = []

# send to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f'Model {model_name} loaded on {device}')

Model openai/whisper-base loaded on cuda


If continuing to train then load the previous parameters

In [6]:
finetuned_model_path = "C:\\Users\\dacla\\Documents\\auto-censoring-local\\whisper-ft"
model.load_state_dict(torch.load(finetuned_model_path, map_location=device))

PermissionError: [Errno 13] Permission denied: 'C:\\Users\\dacla\\Documents\\auto-censoring-local\\whisper-ft'

Regardless, freeze the parameters not from the final layer

In [82]:
# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Except those in the last layer
for param in model.proj_out.parameters():
        param.requires_grad = True

# Verify which layers are trainable
print("\nTrainable parameters after freezing:")
trainable_params = 0
frozen_params = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        trainable_params += param.numel()
        print(f"  - {name} (Trainable, shape: {param.shape})")
    else:
        frozen_params += param.numel()
        # print(f"  - {name} (Frozen)") # Uncomment to see all frozen params

total_params = trainable_params + frozen_params
print(f"\nTotal trainable parameters: {trainable_params}")
print(f"Total frozen parameters: {frozen_params}")
print(f"Total parameters: {total_params}")
print(f"Ratio of trained params to total params: {trainable_params / total_params:.4f}")



Trainable parameters after freezing:
  - model.decoder.embed_tokens.weight (Trainable, shape: torch.Size([51865, 512]))

Total trainable parameters: 26554880
Total frozen parameters: 46039040
Total parameters: 72593920
Ratio of trained params to total params: 0.3658


Downsample the dataset if needed

In [90]:
from datasets import load_from_disk

sample_percentage = .5

# Load full prepared dataset
prepared_dataset_path = 'dataset-whisper-lines-.1'
prepared_datasets = load_from_disk(prepared_dataset_path)
print("--- Full Prepared Dataset ---")
print(prepared_datasets)

# Sample 1% from the training set
train_split = prepared_datasets["train"]
sampled_train_split = train_split.train_test_split(train_size=sample_percentage, shuffle=True, seed=555)['train'] # We only want the 'train' part of this new split

test_split = prepared_datasets["test"]
sampled_test_split = test_split.train_test_split(train_size=sample_percentage, shuffle=True, seed=555)['train'] 

# Overwrite the original splits with the sampled splits
prepared_datasets['train'] = sampled_train_split
prepared_datasets['test'] = sampled_test_split

print(f"\n--- Sampled ({sample_percentage*100}%) Dataset ---")
print(prepared_datasets)

# Now, use this smaller `prepared_datasets` object for the rest of your script
# (creating DataLoaders, etc.)

Loading dataset from disk:   0%|          | 0/31 [00:00<?, ?it/s]

--- Full Prepared Dataset ---
DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 15910
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 1768
    })
})

--- Sampled (50.0%) Dataset ---
DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 7955
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 884
    })
})


Define dataloaders, optimizer, etc. Whisper has a built in loss function

In [93]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
import re
from tqdm import tqdm

def remove_punctuation(s):
    s = re.sub(r'[^a-zA-Z0-9\s]', '', s)
    return s.lower()

# Training parameters
learning_rate = .0002
train_batch_size = 64 # 16 might work?
eval_batch_size = 64

# Defined train and test DLs
train_dataloader = DataLoader(prepared_datasets["train"], shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
eval_dataloader = DataLoader(prepared_datasets["test"], collate_fn=data_collator, batch_size=eval_batch_size)

optimizer = AdamW(model.parameters(), lr=learning_rate)
scaler = torch.amp.GradScaler('cuda')

Set learning rate scheduler, WER metric, and some more training parameters. 

NOTE: jiwer, evaluate.load('wer'), etc. all have different normalization functions. jiwer seems to be the one use for our data

In [None]:
from transformers import get_scheduler
import jiwer

num_train_epochs = 20
num_warmup_steps = 0
total_steps = len(train_dataloader) * num_train_epochs

lr_scheduler = get_scheduler(name="linear",
                             optimizer=optimizer,
                             num_warmup_steps=num_warmup_steps,
                             num_training_steps=total_steps)


# Set initial WER max to inf
best_wer = float('inf')
output_dir = ".\\whisper-ft"

Main training cycle

In [None]:
for epoch in range(num_train_epochs):
    # --- TRAINING ---
    model.train()
    train_loss = 0
    
    # Use tqdm for a progress bar
    for batch in tqdm(train_dataloader, desc=f"Training epoch {epoch + 1}/{num_train_epochs}"):
        
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss

        # Backwards pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1} | Average training Loss: {avg_train_loss:.4f}")

    # --- EVALUATION ---
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Generate predictions. Note this is different than model.transcribe (which is used for untrained?)
            generated_ids = model.generate(input_features=batch["input_features"], 
                                           attention_mask=batch["attention_mask"], 
                                           #max_length=225
                                           )              
            
            # Decode predictions
            predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            
            # Decode labels, replacing -100 with pad token
            labels = batch["labels"].clone()
            labels[labels == -100] = processor.tokenizer.pad_token_id
            labels_str = processor.batch_decode(labels, skip_special_tokens=True)

            # Remove punctuation and capital letters from transcription
            predictions = [remove_punctuation(p) for p in predictions]

            all_predictions.extend(predictions)
            all_labels.extend(labels_str)

    # Compute WER
    wer = jiwer.wer(all_predictions, all_labels)
    
    # To see the output
    # for i in range(len(all_predictions)):
    #     print("Prediction:", all_predictions[i])
    #     print("Actual:", all_labels[i])
    #     print()
         
    print(f"WER: {wer:.5f}")
    print()

    # Save the model if it has the best WER so far
    if wer < best_wer:
        best_wer = wer
        print(f"New best WER: {best_wer}. Saving model...")
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        print()
        
print("\n--- Training Complete ---")
print(f"Best WER achieved: {best_wer}")

Training epoch 1/20:  66%|██████▌   | 82/125 [06:05<03:11,  4.46s/it]

--------------------------

3. Test run of WER for untrained Whisper

In [40]:
model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Generate predictions. Note this is different than model.transcribe (which is used for untrained?)
        generated_ids = model.generate(input_features=batch["input_features"], 
                                        attention_mask=batch["attention_mask"], 
                                        #max_length=225
                                        )              
        
        # Decode predictions
        predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Decode labels, replacing -100 with pad token
        labels = batch["labels"].clone()
        labels[labels == -100] = processor.tokenizer.pad_token_id
        labels_str = processor.batch_decode(labels, skip_special_tokens=True)

        # Remove punctuation and capital letters from transcription
        predictions = [remove_punctuation(p) for p in predictions]

        all_predictions.extend(predictions)
        all_labels.extend(labels_str)

    # Compute WER
    wer = jiwer.wer(all_predictions, all_labels)
    print(wer)

Evaluating: 100%|██████████| 37/37 [01:13<00:00,  1.97s/it]

0.8119815668202764





In [57]:
import jiwer
import pandas as pd


# evaluate
wer_metric = evaluate.load("wer")
ev_wer = wer_metric.compute(predictions=all_predictions, references=all_labels)

# jiwer
jwer = jiwer.wer(all_predictions, all_labels)
l = []

print(f'Evaluate wer: {ev_wer:.5f}')
print(f'Jiwer wer: {jwer:.5f}')
print('-----\n')

for i in range(len(all_predictions)):
    pred = all_predictions[i]
    actual = all_labels[i]
    wer = jiwer.wer(pred, actual)


    print(f'Predicted - {pred}')
    print(f'Actual    - {actual}')
    print(f'Jiwer WER: {wer}')
    print()

    l.append([pred, actual, wer])

df_wer = pd.DataFrame(l, columns=['Prediction', 'Actual', 'WER'])
df_wer.to_csv('wer-.1-sample-seed-555.csv', index=False)

Evaluate wer: 1.23883
Jiwer wer: 0.81198
-----

Predicted -  and you i see juddy and you i come stomp and you i feel so pretty and you i take sky and you i feel so hungry and you i crash past we must never be apart
Actual    - in you i see dirty in you i count stars in you i feel so pretty in you i taste god in you i feel so hungry in you i crash cars we must never be apart
Jiwer WER: 0.32432432432432434

Predicted -  cause if you go i go cause if you go i go
Actual    - cause if you go i go cause if you go i go
Jiwer WER: 0.0

Predicted -  but now its getting late and the moon is primal i want a celebrate see it shining in your eyes
Actual    - but now its gettin late and the moon is climbin high i want to celebrate see it shinin in your eye
Jiwer WER: 0.3

Predicted -  oh why cant you hear me in the street why cant i kiss you on the dance floor i wish that it could be mine i wish that it could be mine i wish that you could be mine i wish that you could be mine i wish that you could b

In [59]:
df_wer[df_wer['WER'] > 1]


Unnamed: 0,Prediction,Actual,WER
12,im happy for you,if happy is her if happy is her im happy for you,2.0
34,oh,woh uh oh ho now theres no welcome look in you...,39.0
37,i,taste me drink my soul show me all the things ...,49.0
45,i cant take it,theres nothing i can do im such a fool for you...,23.5
50,let them be,let there be love here in the here in the dark...,8.333333
52,we can,we are the ones we get knocked down we get bac...,19.5
60,i never take to be told wont you listen tonig...,i have a tale to be told wont you listen tonig...,1.315789
61,take everything i want,go on take everything take everything i want y...,4.0
70,im the geek in the pink,like the geek in the pink do do do do do do do...,8.5
74,,yeah yeah yeah yeah yeah yeah,6.0


------------------------------

4. Testing a trained model on a single audio track

In [94]:
import torchaudio

def test_transcribe(audio_path):
    # Put in evaluation mode
    model.eval()

    # Load audio file
    print(f"Loading audio from: {audio_path}...")
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample if necessary (Whisper expects 16kHz)
    if sample_rate != 16000:
        print(f"Resampling audio from {sample_rate}Hz to 16kHz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
        sample_rate = 16000 # Update sample rate after resampling

    # Ensure mono audio (Whisper expects single channel)
    if waveform.shape[0] > 1:
        print("Converting stereo audio to mono...")
        waveform = waveform.mean(dim=0, keepdim=True) # Average channels to mono

    # Convert to numpy array (required by feature_extractor for raw audio)
    audio_array = waveform.squeeze().numpy()

    # Extract features (Mel spectrogram)
    processed_audio = processor.feature_extractor(audio_array, 
                                                  sampling_rate=sample_rate, 
                                                  return_tensors="pt",
                                                  return_attention_mask=True,
                                                  )
 
    input_features = processed_audio.input_features.to(device)
    attention_mask = processed_audio.attention_mask.to(device)

    print("Generating transcription...")
    with torch.no_grad():
        generated_ids = model.generate(input_features=input_features, 
                                       attention_mask=attention_mask,
                                       max_new_tokens=400,
                                       temperature=0.0,
                                       #no_speech_threshold=.3 # Error when using this ?
                                       )
        
    # Create the transcription
    transcription = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription


# Load and preprocess the audio file
audio_path = 'vocals.wav'
print("\nTranscription:\n", test_transcribe(audio_path))

Loading audio from: vocals.wav...
Resampling audio from 44100Hz to 16kHz...
Converting stereo audio to mono...
Generating transcription...

Transcription:
  I can slowly say I don't give a fuck about your money Begun it means so much to you and all of your money Why doesn't it mean so much to you? You cite the land of greed and I'm talking about a world of need Money has nothing to do with the value of life


---------------


This checks for correct formatting

In [84]:
from collections import Counter

# ... (your dataloader setup) ...

print("\n--- Inspecting a sample batch from train_dataloader ---")
num_batches_to_inspect = 2 # Inspect the first 2 batches

for i, batch in enumerate(train_dataloader):
    if i >= num_batches_to_inspect:
        break

    print(f"\nBatch {i+1}:")
    input_features = batch["input_features"] # No need to move to device for inspection
    labels = batch["labels"]

    print(f"  Input Features Shape: {input_features.shape}")
    print(f"  Labels Shape: {labels.shape}")

    # Decode a few labels to see the actual text and special tokens
    for j in range(min(2, labels.shape[0])): # Inspect first 2 samples in batch
        sample_labels = labels[j].tolist()

        # Find the actual end of the non-padded sequence
        # Remember -100 means ignore for loss, but original token was PAD_ID
        true_label_ids = [
            idx for idx in sample_labels if idx != -100
        ]

        # If PAD_ID is same as EOS_ID, then tokens before -100 are still relevant
        # You want to see if the EOS_ID is present at the logical end of the transcript.

        # For Whisper, the actual EOS token is usually 50257 (<|endoftext|>)
        # Let's check for its presence at the end of the non-padded sequence.

        # Get the tokenizer's actual EOS ID
        whisper_eos_id = processor.tokenizer.eos_token_id

        # Decode the sequence, showing special tokens
        decoded_text_with_special = processor.tokenizer.decode(
            true_label_ids, skip_special_tokens=False
        )

        # Decode the sequence, skipping special tokens (for readability)
        decoded_text_without_special = processor.tokenizer.decode(
            true_label_ids, skip_special_tokens=True
        )

        print(f"    Sample {j+1} - Decoded (with special): '{decoded_text_with_special}'")
        print(f"    Sample {j+1} - Decoded (without special): '{decoded_text_without_special}'")

        # Check if EOS token is at the end of the *non-padded* sequence
        if true_label_ids and true_label_ids[-1] == whisper_eos_id:
            print(f"    Sample {j+1} - Ends with EOS token ({whisper_eos_id}): Yes")
        else:
            print(f"    Sample {j+1} - Ends with EOS token ({whisper_eos_id}): No - CHECK THIS!")
            if true_label_ids:
                print(f"      Last token: {true_label_ids[-1]}")


--- Inspecting a sample batch from train_dataloader ---

Batch 1:
  Input Features Shape: torch.Size([8, 80, 3000])
  Labels Shape: torch.Size([8, 76])
    Sample 1 - Decoded (with special): '<|startoftranscript|><|en|><|transcribe|><|nocaptions|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>girl youre my angel youre my darling angel closer than my peeps you are to me baby shorty youre my angel youre my darling angel girl youre my friend when im in need lady<|endoftext|>'
    Sample 1 - Decoded (without special): 'girl youre my angel youre my darling angel closer than my peeps you are to me baby shorty youre my angel youre my darling angel girl youre my friend when im in need lady'
    Sample 1 - Ends with EOS token (50257): Yes
    Sample 2 - Decoded (with special): '<|startoftranscript|><|en|><|transcribe|><|nocaptions|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>ive been down ive been beat ive been so tired i could not speak ive been so lost that i could not