-------------------

1. Prepare the dataset with the Whisper processor


In [19]:
from transformers import WhisperProcessor

model_name = "openai/whisper-base"
language = "english" # Change to your dataset's language
task = "transcribe" # Use "translate" if you're translating to English

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

In [16]:
from datasets import load_dataset, Audio, DatasetDict

# Chunked wav dataset
# dataset_path = "C:\\Users\\dacla\\Documents\\DALI-chunks-wav" #FOR PARS
dataset_path = "T:\\dl-project\\DALI-chunks-lines"
# dataset_path = "C:\\Users\\dacla\\Documents\\chunks" #ela's dataset


#raw_dataset = load_dataset("csv", data_files="metadata-wav.csv", split='train') #FOR PARS
raw_dataset = load_dataset("csv", data_files="metadata-word-level-no-tra.csv", split='train')
print("Full dataset\n", raw_dataset)

# Make a train/test split at this point !
raw_dataset = raw_dataset.train_test_split(test_size=0.1, shuffle=True, seed=555)

raw_dataset_sampled = raw_dataset['test'].train_test_split(train_size=.9, shuffle=True, seed=555)

print("----------\nSplit dataset\n", raw_dataset_sampled)

Full dataset
 Dataset({
    features: ['filename', 'transcript'],
    num_rows: 176774
})
----------
Split dataset
 DatasetDict({
    train: Dataset({
        features: ['filename', 'transcript'],
        num_rows: 15910
    })
    test: Dataset({
        features: ['filename', 'transcript'],
        num_rows: 1768
    })
})


Prepare the dataset


In [21]:
import librosa

def prepare_dataset(batch):
    # Load and resample audio data
    #audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['file-wav']] #PARS
    audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['filename']]
    audio_arrays = [librosa.load(path, sr=16000)[0] for path in audio_paths]
    
    # Compute log-Mel input features from the audio
    batch['input_features'] = processor.feature_extractor(audio_arrays, sampling_rate=16000).input_features

    # Encode the transcriptions to label ids
    batch['labels'] = processor.tokenizer(batch['transcript']).input_ids

    return batch

# Apply the function to the entire dataset
processed_dataset = raw_dataset_sampled.map(prepare_dataset, batched=True, batch_size=8, remove_columns=raw_dataset.column_names["train"])

Map:   0%|          | 0/15910 [00:00<?, ? examples/s]

Map:   0%|          | 0/1768 [00:00<?, ? examples/s]

In [22]:
processed_dataset.save_to_disk('lines-sampled-fixed')

Saving the dataset (0/31 shards):   0%|          | 0/15910 [00:00<?, ? examples/s]

Saving the dataset (0/4 shards):   0%|          | 0/1768 [00:00<?, ? examples/s]

The following does a check to make sure the inputs are formatted correctly

In [118]:
# Assuming processed_dataset is ready
print("\n--- Verifying processed_dataset labels after map ---")
# Get a sample from the processed_dataset (e.g., the first 5 samples)
sample_data = processed_dataset["train"].select(range(min(5, len(processed_dataset["train"]))))

processor_instance = processor # Use the processor you defined earlier

for i, sample in enumerate(sample_data):
    labels = sample["labels"] # These are the token IDs from prepare_dataset

    # Ensure labels is a list (if it came from prepare_dataset's list of lists)
    if isinstance(labels, torch.Tensor):
        labels_list = labels.tolist()
    else: # It's likely a list of lists if batched=True in map
        # If it's a single sample, it might just be a list
        labels_list = labels 
        if isinstance(labels_list[0], list): # If it's a list of lists (from batched=True)
            labels_list = labels_list[0] # Take the first one if you expect single samples here

    decoded_full = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=False)
    decoded_clean = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=True)
    
    eos_id = processor_instance.tokenizer.eos_token_id

    print(f"\nSample {i+1}:")
    print(f"  Raw Labels IDs: {labels_list}")
    print(f"  Decoded (with special tokens): '{decoded_full}'")
    print(f"  Decoded (clean text): '{decoded_clean}'")
    
    if labels_list and labels_list[-1] == eos_id:
        print(f"  Ends with EOS token ({eos_id}): YES")
    else:
        print(f"  Ends with EOS token ({eos_id}): NO - CRITICAL ISSUE AT prepare_dataset!")
        if labels_list:
            print(f"    Last token: {labels_list[-1]}")


--- Verifying processed_dataset labels after map ---

Sample 1:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 1353, 428, 3172, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>to your wish<|endoftext|>'
  Decoded (clean text): 'to your wish'
  Ends with EOS token (50257): YES

Sample 2:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 5616, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>you<|endoftext|>'
  Decoded (clean text): 'you'
  Ends with EOS token (50257): YES

Sample 3:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 13301, 1106, 7670, 1106, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>eh ho eh ho<|endoftext|>

---------------------------

2. Start from here if the dataset is already made
- Create processor
- Define data collator


In [42]:
from transformers import WhisperProcessor

model_name = "openai/whisper-base"
language = "english" 
task = "transcribe" 

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

For checking formatting


In [None]:

print("\n--- Verifying new class-based data collator ---")
sample_batch = next(iter(train_dataloader))

first_label_ids = sample_batch['labels'][0].tolist()
# Filter out the -100s to see the original tokens for decoding
original_tokens = [id for id in first_label_ids if id != -100] # <-- This is the sequence to check!

decoded_text_full = processor.tokenizer.decode(original_tokens, skip_special_tokens=False)
print(f"\nSample decoded label (with special tokens): {decoded_text_full}")

whisper_eos_id = processor.tokenizer.eos_token_id
if original_tokens and original_tokens[-1] == whisper_eos_id:
    print(f"Last non-padded token is EOS ({whisper_eos_id}): Yes")
else:
    print(f"Last non-padded token is EOS ({whisper_eos_id}): No - this indicates an issue!")
    if original_tokens:
        print(f"  Last token in non-padded sequence: {original_tokens[-1]}")
        print(f"  Expected EOS ID: {whisper_eos_id}")


--- Verifying new class-based data collator ---


NameError: name 'train_dataloader' is not defined

Use this for training

In [23]:
from transformers import DataCollatorForSeq2Seq
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

# --- Data Collator ---
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths and need different padding methods.
        # "input_features" for Whisper-based models (vs. "input_values" for wav2vec...)
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, 
                                                     return_tensors="pt",
                                                     return_attention_mask=True)
        
        labels_batch = self.processor.tokenizer.pad(label_features, 
                                                    return_tensors="pt",)

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, padding=True)

Create the Whisper model


In [24]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name)
# This is necessary for the model to work correctly with the Trainer
#model.config.forced_decoder_ids = None
#model.config.suppress_tokens = []

# send to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f'Model {model_name} loaded on {device}')


## Fine tuned parameters
# finetuned_model_path = "C:\\Users\\dacla\\Documents\\auto-censoring-local\\whisper-ft"
# model.load_state_dict(torch.load(finetuned_model_path, map_location=device))

Model openai/whisper-base loaded on cuda


Regardless, freeze the parameters not from the final layer

In [25]:
# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Except those in the last layer
for param in model.proj_out.parameters():
        param.requires_grad = True

# Verify which layers are trainable
print("\nTrainable parameters after freezing:")
trainable_params = 0
frozen_params = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        trainable_params += param.numel()
        print(f"  - {name} (Trainable, shape: {param.shape})")
    else:
        frozen_params += param.numel()
        # print(f"  - {name} (Frozen)") # Uncomment to see all frozen params

total_params = trainable_params + frozen_params
print(f"\nTotal trainable parameters: {trainable_params}")
print(f"Total frozen parameters: {frozen_params}")
print(f"Total parameters: {total_params}")
print(f"Ratio of trained params to total params: {trainable_params / total_params:.4f}")



Trainable parameters after freezing:
  - model.decoder.embed_tokens.weight (Trainable, shape: torch.Size([51865, 512]))

Total trainable parameters: 26554880
Total frozen parameters: 46039040
Total parameters: 72593920
Ratio of trained params to total params: 0.3658


Downsample the dataset if needed

In [28]:
from datasets import load_from_disk

sample_percentage = .9999

# Load full prepared dataset
prepared_dataset_path = 'lines-sampled-fixed'
prepared_datasets = load_from_disk(prepared_dataset_path)
print("--- Full Prepared Dataset ---")
print(prepared_datasets)

# Sample 1% from the training set
train_split = prepared_datasets["train"]
sampled_train_split = train_split.train_test_split(train_size=sample_percentage, shuffle=True, seed=555)['train'] # We only want the 'train' part of this new split

test_split = prepared_datasets["test"]
sampled_test_split = test_split.train_test_split(train_size=sample_percentage, shuffle=True, seed=555)['train'] 

# Overwrite the original splits with the sampled splits
prepared_datasets['train'] = sampled_train_split
prepared_datasets['test'] = sampled_test_split

print(f"\n--- Sampled ({sample_percentage*100}%) Dataset ---")
print(prepared_datasets)

# Now, use this smaller `prepared_datasets` object for the rest of your script
# (creating DataLoaders, etc.)

Loading dataset from disk:   0%|          | 0/31 [00:00<?, ?it/s]

--- Full Prepared Dataset ---
DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 15910
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 1768
    })
})

--- Sampled (99.99%) Dataset ---
DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 15908
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 1767
    })
})


Define dataloaders, optimizer, etc. Whisper has a built in loss function

In [29]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
import re
from tqdm import tqdm

def remove_punctuation(s):
    s = re.sub(r'[^a-zA-Z0-9\s]', '', s)
    return s.lower()

# Training parameters
learning_rate = .001
train_batch_size = 64 # 64 works with 16GB of VRAM
eval_batch_size = 64

# Defined train and test DLs
train_dataloader = DataLoader(prepared_datasets["train"], shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
eval_dataloader = DataLoader(prepared_datasets["test"], collate_fn=data_collator, batch_size=eval_batch_size)

optimizer = AdamW(model.parameters(), lr=learning_rate)
scaler = torch.amp.GradScaler('cuda')

Set learning rate scheduler, WER metric, and some more training parameters. 

NOTE: jiwer, evaluate.load('wer'), etc. all have different normalization functions. jiwer seems to be the one use for our data

In [33]:
from transformers import get_scheduler
import jiwer

num_train_epochs = 20
num_warmup_steps = 0
total_steps = len(train_dataloader) * num_train_epochs

lr_scheduler = get_scheduler(name="linear",
                             optimizer=optimizer,
                             num_warmup_steps=num_warmup_steps,
                             num_training_steps=total_steps)


# Set initial WER max to inf
best_wer = float('inf')
output_dir = ".\\whisper-ft"

Main training cycle

In [35]:
patience = 0 # use for early stopping of training if no increase in WER is detected

for epoch in range(num_train_epochs):
    # train loop
    model.train()
    train_loss = 0
    
    # Use tqdm for a progress bar
    for batch in tqdm(train_dataloader, desc=f"(Epoch {epoch+1} / {num_train_epochs}) Training "):
        
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss

        # Backwards pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_dataloader)

    # eval loop
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Generate predictions. Note this is different than model.transcribe (which is used for untrained?)
            generated_ids = model.generate(input_features=batch["input_features"], 
                                           attention_mask=batch["attention_mask"], 
                                           #max_length=225
                                           )              
            
            # Decode predictions
            predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            
            # Decode labels, replacing -100 with pad token
            labels = batch["labels"].clone()
            labels[labels == -100] = processor.tokenizer.pad_token_id
            labels_str = processor.batch_decode(labels, skip_special_tokens=True)

            # Remove punctuation and capital letters from transcription
            predictions = [remove_punctuation(p) for p in predictions]

            all_predictions.extend(predictions)
            all_labels.extend(labels_str)


    # Compute WER. I'm not sure if truncating at 1 is "mathematically valid", but whatever
    wer = min(jiwer.wer(all_predictions, all_labels), 1)
    
    # # To see the output:
    # for i in range(len(all_predictions)):
    #     pred = all_predictions[i]
    #     actual = all_labels[i]
    #     wer = jiwer.wer(pred, actual)

    #     print(f'Predicted - {pred}')
    #     print(f'Actual    - {actual}')
    #     print(f'Jiwer WER: {wer}')
    #     print()
         
    print(f"Avg training loss: {avg_train_loss:.4f} | Eval. WER: {wer:.5f}")
    print()

    # Save the model if it has the best WER so far
    if wer < best_wer:
        patience = 0 # reset patience counter
        
        best_wer = wer
        print(f"(!) New best WER: {best_wer}. Saving model...")
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        print()
    
    else: 
        patience += 1

    if patience == 5: 
        print('No increase in WER detected in 5 rounds, breaking')
        break
        
print("\n--- Training Complete ---")
print(f"Best WER achieved: {best_wer}")

(Epoch 1 / 20) Training : 100%|██████████| 249/249 [17:38<00:00,  4.25s/it]
Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Evaluating: 100%|██████████| 28/28 [05:33<00:00, 11.92s/it]


Avg training loss: 1.8088 | Eval. WER: 0.99238

(!) New best WER: 0.9923807320116685. Saving model...
Model saved to .\whisper-ft



(Epoch 2 / 20) Training : 100%|██████████| 249/249 [18:55<00:00,  4.56s/it]
Evaluating: 100%|██████████| 28/28 [05:18<00:00, 11.39s/it]


Avg training loss: 0.9107 | Eval. WER: 0.94071

(!) New best WER: 0.940708446866485. Saving model...
Model saved to .\whisper-ft



(Epoch 3 / 20) Training : 100%|██████████| 249/249 [18:56<00:00,  4.56s/it]
Evaluating: 100%|██████████| 28/28 [05:24<00:00, 11.61s/it]


Avg training loss: 0.5895 | Eval. WER: 0.95109



(Epoch 4 / 20) Training : 100%|██████████| 249/249 [18:58<00:00,  4.57s/it]
Evaluating: 100%|██████████| 28/28 [05:11<00:00, 11.14s/it]


Avg training loss: 0.4256 | Eval. WER: 0.92223

(!) New best WER: 0.9222289208509172. Saving model...
Model saved to .\whisper-ft



(Epoch 5 / 20) Training : 100%|██████████| 249/249 [18:56<00:00,  4.56s/it]
Evaluating: 100%|██████████| 28/28 [05:16<00:00, 11.31s/it]


Avg training loss: 0.3196 | Eval. WER: 0.93610



(Epoch 6 / 20) Training : 100%|██████████| 249/249 [19:01<00:00,  4.58s/it]
Evaluating: 100%|██████████| 28/28 [05:13<00:00, 11.18s/it]


Avg training loss: 0.2505 | Eval. WER: 0.91996

(!) New best WER: 0.9199620461585689. Saving model...
Model saved to .\whisper-ft



(Epoch 7 / 20) Training : 100%|██████████| 249/249 [18:57<00:00,  4.57s/it]
Evaluating: 100%|██████████| 28/28 [04:46<00:00, 10.25s/it]


Avg training loss: 0.2027 | Eval. WER: 0.88106

(!) New best WER: 0.8810576432579493. Saving model...
Model saved to .\whisper-ft



(Epoch 8 / 20) Training : 100%|██████████| 249/249 [19:00<00:00,  4.58s/it]
Evaluating: 100%|██████████| 28/28 [04:46<00:00, 10.23s/it]


Avg training loss: 0.1654 | Eval. WER: 0.88479



(Epoch 9 / 20) Training : 100%|██████████| 249/249 [19:02<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [04:58<00:00, 10.68s/it]


Avg training loss: 0.1401 | Eval. WER: 0.89339



(Epoch 10 / 20) Training : 100%|██████████| 249/249 [19:03<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [05:11<00:00, 11.13s/it]


Avg training loss: 0.1184 | Eval. WER: 0.90774



(Epoch 11 / 20) Training : 100%|██████████| 249/249 [18:58<00:00,  4.57s/it]
Evaluating: 100%|██████████| 28/28 [05:04<00:00, 10.86s/it]


Avg training loss: 0.1058 | Eval. WER: 0.89369



(Epoch 12 / 20) Training : 100%|██████████| 249/249 [19:07<00:00,  4.61s/it]
Evaluating: 100%|██████████| 28/28 [04:58<00:00, 10.65s/it]


Avg training loss: 0.0927 | Eval. WER: 0.86782

(!) New best WER: 0.8678222679063683. Saving model...
Model saved to .\whisper-ft



(Epoch 13 / 20) Training : 100%|██████████| 249/249 [19:00<00:00,  4.58s/it]
Evaluating: 100%|██████████| 28/28 [04:50<00:00, 10.38s/it]


Avg training loss: 0.0826 | Eval. WER: 0.87076



(Epoch 14 / 20) Training : 100%|██████████| 249/249 [19:03<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [05:00<00:00, 10.74s/it]


Avg training loss: 0.0797 | Eval. WER: 0.88265



(Epoch 15 / 20) Training : 100%|██████████| 249/249 [19:02<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [04:53<00:00, 10.49s/it]


Avg training loss: 0.0673 | Eval. WER: 0.87781



(Epoch 16 / 20) Training : 100%|██████████| 249/249 [19:01<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [05:04<00:00, 10.87s/it]


Avg training loss: 0.0610 | Eval. WER: 0.88899



(Epoch 17 / 20) Training : 100%|██████████| 249/249 [19:02<00:00,  4.59s/it]
Evaluating: 100%|██████████| 28/28 [04:54<00:00, 10.52s/it]

Avg training loss: 0.0554 | Eval. WER: 0.88406

No increase in WER detected in 5 rounds, breaking

--- Training Complete ---
Best WER achieved: 0.8678222679063683





--------------------------

--------------------------

3. Test run of WER for untrained Whisper

In [54]:
model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Generate predictions. Note this is different than model.transcribe (which is used for untrained?)
        generated_ids = model.generate(input_features=batch["input_features"], 
                                        attention_mask=batch["attention_mask"], 
                                        #max_length=225
                                        )              
        
        # Decode predictions
        predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Decode labels, replacing -100 with pad token
        labels = batch["labels"].clone()
        labels[labels == -100] = processor.tokenizer.pad_token_id
        labels_str = processor.batch_decode(labels, skip_special_tokens=True)

        # Remove punctuation and capital letters from transcription
        predictions = [remove_punctuation(p) for p in predictions]

        all_predictions.extend(predictions)
        all_labels.extend(labels_str)

    # Compute WER
    wer = jiwer.wer(all_predictions, all_labels)
    print(wer)

Evaluating:   0%|          | 0/2 [01:47<?, ?it/s]


KeyboardInterrupt: 

In [36]:
import jiwer
import pandas as pd
import evaluate

# evaluate
wer_metric = evaluate.load("wer")
ev_wer = wer_metric.compute(predictions=all_predictions, references=all_labels)

# jiwer
jwer = jiwer.wer(all_predictions, all_labels)
l = []

print(f'Evaluate wer: {ev_wer:.5f}')
print(f'Jiwer wer: {jwer:.5f}')
print('-----\n')

for i in range(len(all_predictions)):
    pred = all_predictions[i]
    actual = all_labels[i]
    wer = jiwer.wer(pred, actual)


    print(f'Predicted - {pred}')
    print(f'Actual    - {actual}')
    print(f'Jiwer WER: {wer}')
    print()

    l.append([pred, actual, wer])

df_wer = pd.DataFrame(l, columns=['Prediction', 'Actual', 'WER'])


Evaluate wer: 2.51898
Jiwer wer: 0.88406
-----

Predicted - lyeah
Actual    - we
Jiwer WER: 1.0

Predicted - i am soabout
Actual    - i am spellbound
Jiwer WER: 0.3333333333333333

Predicted - dont you know
Actual    - dont you know
Jiwer WER: 0.0

Predicted - oh you gotta see
Actual    - oh you gotta sing
Jiwer WER: 0.25

Predicted - were like diamonds in the sky
Actual    - were like diamonds in the sky
Jiwer WER: 0.0

Predicted - im in love with the light
Actual    - im in love with her eyes
Jiwer WER: 0.3333333333333333

Predicted - eaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheaheahe

In [37]:
import pandas as pd

df_wer = pd.read_csv('small-dataset-whisper-trained-outputs.csv')
pd.set_option('display.max_colwidth', None)

#df_wer = df_wer.sort_values(by='WER')
#df_wer.to_csv('small-dataset-whisper-trained-outputs.csv', index=False)

df_wer['wer-limited'] = df_wer['WER'].apply(lambda x : min(x,1) )

df_wer


Unnamed: 0,Prediction,Actual,WER,wer-limited
0,this is the night this is the night i will be the star let me into your heart this is my life this is the life ive been waiting for ill be back with more,this is the night this is the night i will be the star let me into your heart this is my life this is the life ive been waiting for ill be back with more,0.0,0.0
1,girls who want boys who like boys to be girls who do boys like theyre girls who do girls like theyre boys always should be someone you really love,girls who want boys who like boys to be girls who do boys like theyre girls who do girls like theyre boys always should be someone you really love,0.0,0.0
2,liar killer demon back to the river aras,liar killer demon back to the river aras,0.0,0.0
3,maybe you shouldnt come back,maybe you shouldnt come back,0.0,0.0
4,i feel angry i feel helpless want to change the world yeah i feel violent i feel alone dont try and change my mind no,i feel angry i feel helpless want to change the world yeah i feel violent i feel alone dont try and change my mind no,0.0,0.0
5,just tonight i will stay and well throw it all away when the light hits your eyes its telling me im right and if i i am through and its all because of you just tonight,just tonight i will stay and well throw it all away when the light hits your eyes its telling me im right and if i i am through but its all because of you just tonight,0.027778,0.027778
6,i feel angry and feel helpless want to change the world yeah i feel violent i feel alone dont try and change my mind no,i feel angry i feel helpless want to change the world yeah i feel violent i feel alone dont try and change my mind no,0.04,0.04
7,tell me how to with your heart for i havent got a clue but let me start by saying i love you,tell me how to win your heart for i havent got a clue but let me start by saying i love you,0.045455,0.045455
8,its in the water baby its in the lake bring you down its in the water baby its in your bag of golden brown its in the water baby its in your frequency,its in the water baby its in the pills that bring you down its in the water baby its in your bag of golden brown its in the water baby its in your frequency,0.060606,0.060606
9,the day i first met you you told me youll never fall in love now that i get you i know fear is what it really is,the day i first met you you told me youll never fall in love but now that i get you i know fear is what it really was,0.074074,0.074074


------------------------------

4. Testing a trained model on a single audio track

In [38]:
import torchaudio

def test_transcribe(audio_path):
    # Put in evaluation mode
    model.eval()

    # Load audio file
    print(f"Loading audio from: {audio_path}...")
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample if necessary (Whisper expects 16kHz)
    if sample_rate != 16000:
        print(f"Resampling audio from {sample_rate}Hz to 16kHz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
        sample_rate = 16000 # Update sample rate after resampling

    # Ensure mono audio (Whisper expects single channel)
    if waveform.shape[0] > 1:
        print("Converting stereo audio to mono...")
        waveform = waveform.mean(dim=0, keepdim=True) # Average channels to mono

    # Convert to numpy array (required by feature_extractor for raw audio)
    audio_array = waveform.squeeze().numpy()

    # Extract features (Mel spectrogram)
    processed_audio = processor.feature_extractor(audio_array, 
                                                  sampling_rate=sample_rate, 
                                                  return_tensors="pt",
                                                  return_attention_mask=True,
                                                  )
 
    input_features = processed_audio.input_features.to(device)
    attention_mask = processed_audio.attention_mask.to(device)

    print("Generating transcription...")
    with torch.no_grad():
        generated_ids = model.generate(input_features=input_features, 
                                       attention_mask=attention_mask,
                                       max_new_tokens=400,
                                       temperature=0.0,
                                       #no_speech_threshold=.3 # Error when using this ?
                                       )
        
    # Create the transcription
    transcription = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription


# Load and preprocess the audio file
audio_path = 'vocals.wav'
print("\nTranscription:\n", test_transcribe(audio_path))

Loading audio from: vocals.wav...


LibsndfileError: Error opening 'vocals.wav': System error.

---------------


This checks for correct formatting

In [52]:
from collections import Counter

# ... (your dataloader setup) ...

print("\n--- Inspecting a sample batch from train_dataloader ---")
num_batches_to_inspect = 2 # Inspect the first 2 batches

for i, batch in enumerate(train_dataloader):
    if i >= num_batches_to_inspect:
        break

    print(f"\nBatch {i+1}:")
    input_features = batch["input_features"] # No need to move to device for inspection
    labels = batch["labels"]

    print(f"  Input Features Shape: {input_features.shape}")
    print(f"  Labels Shape: {labels.shape}")

    # Decode a few labels to see the actual text and special tokens
    for j in range(min(2, labels.shape[0])): # Inspect first 2 samples in batch
        sample_labels = labels[j].tolist()

        # Find the actual end of the non-padded sequence
        # Remember -100 means ignore for loss, but original token was PAD_ID
        true_label_ids = [
            idx for idx in sample_labels if idx != -100
        ]

        # If PAD_ID is same as EOS_ID, then tokens before -100 are still relevant
        # You want to see if the EOS_ID is present at the logical end of the transcript.

        # For Whisper, the actual EOS token is usually 50257 (<|endoftext|>)
        # Let's check for its presence at the end of the non-padded sequence.

        # Get the tokenizer's actual EOS ID
        whisper_eos_id = processor.tokenizer.eos_token_id

        # Decode the sequence, showing special tokens
        decoded_text_with_special = processor.tokenizer.decode(
            true_label_ids, skip_special_tokens=False
        )

        # Decode the sequence, skipping special tokens (for readability)
        decoded_text_without_special = processor.tokenizer.decode(
            true_label_ids, skip_special_tokens=True
        )

        print(f"    Sample {j+1} - Decoded (with special): '{decoded_text_with_special}'")
        print(f"    Sample {j+1} - Decoded (without special): '{decoded_text_without_special}'")

        # Check if EOS token is at the end of the *non-padded* sequence
        if true_label_ids and true_label_ids[-1] == whisper_eos_id:
            print(f"    Sample {j+1} - Ends with EOS token ({whisper_eos_id}): Yes")
        else:
            print(f"    Sample {j+1} - Ends with EOS token ({whisper_eos_id}): No - CHECK THIS!")
            if true_label_ids:
                print(f"      Last token: {true_label_ids[-1]}")


--- Inspecting a sample batch from train_dataloader ---

Batch 1:
  Input Features Shape: torch.Size([32, 80, 3000])
  Labels Shape: torch.Size([32, 51])
    Sample 1 - Decoded (with special): '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>country roads take me home to the place i belong west virginia mountain momma take me home country roads<|endoftext|>'
    Sample 1 - Decoded (without special): 'country roads take me home to the place i belong west virginia mountain momma take me home country roads'
    Sample 1 - Ends with EOS token (50257): Yes
    Sample 2 - Decoded (with special): '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>oh oh oh oh oh oh oh<|endoftext|>'
    Sample 2 - Decoded (without special): 'oh oh oh oh oh oh oh'
    Sample 2 - Ends with EOS token (50257): Yes

Batch 2:
  Input Features Shape: torch.Size([32, 80, 3000])
  Labels Shape: torch.Size([32, 69])
    Sample 1 - Decoded (with special): '<|startoftranscript|><|en|><|transcribe|><|notime