This notebook contains the necessary code to create a Whisper dataset from raw audio files

-------------------

Define Whisper processor
- openai/whisper-base is common


In [None]:
from transformers import WhisperProcessor
import torch

model_name = "openai/whisper-base"
language = "english" 
task = "transcribe" 

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

  from .autonotebook import tqdm as notebook_tqdm


Metadata file must be created before this can be run. Must contain
- Audio file paths
- Transcript (full transcript, not word-level)


In [None]:
from datasets import load_dataset, Audio, DatasetDict

# Chunked wav dataset

dataset_path = "T:\\dl-project\\DALI-chunks-lines"
# dataset_path = "C:\\Users\\dacla\\Documents\\chunks" #ela's dataset

## Metadata file is crucial
raw_dataset = load_dataset("csv", data_files="metadata-wer0-lines.csv", split='train')
print("Full dataset\n", raw_dataset)

# Make a train/test split at this point !
raw_dataset = raw_dataset.train_test_split(test_size=0.1, shuffle=True, seed=555)

print("----------")
print("Split dataset\n", raw_dataset)

Full dataset
 Dataset({
    features: ['words', 'start', 'end', 'transcript', 'WER', 'filename'],
    num_rows: 25847
})
----------
Split dataset
 DatasetDict({
    train: Dataset({
        features: ['words', 'start', 'end', 'transcript', 'WER', 'filename'],
        num_rows: 23262
    })
    test: Dataset({
        features: ['words', 'start', 'end', 'transcript', 'WER', 'filename'],
        num_rows: 2585
    })
})


Prepare the dataset using the Whisper processor


In [None]:
import librosa

def prepare_dataset(batch):
    # Load and resample audio data
    audio_paths = [f"{dataset_path}\\{fname}" for fname in batch['filename']]
    audio_arrays = [librosa.load(path, sr=16000)[0] for path in audio_paths]
    
    # Compute log-Mel input features from the audio
    batch['input_features'] = processor.feature_extractor(audio_arrays, 
                                                          sampling_rate=16000,
                                                          return_tensors='pt').input_features

    # Encode the transcriptions to label ids
    batch['labels'] = processor.tokenizer(batch['words'],
                                           max_length=processor.tokenizer.model_max_length,
                                           truncation=True,
                                           #padding='do_not_pad', # no padding is the default. Pad in the collator
                                           return_tensors=None).input_ids


    return batch

# Apply the function to the entire dataset
processed_dataset = raw_dataset.map(prepare_dataset, 
                                    batched=True, 
                                    batch_size=8, 
                                    remove_columns=raw_dataset.column_names["train"])

# And save to the disc
processed_dataset.save_to_disk('wer0-dataset-fixed-padding')

Map: 100%|██████████| 23262/23262 [08:27<00:00, 45.81 examples/s]
Map: 100%|██████████| 2585/2585 [01:32<00:00, 27.96 examples/s]
Saving the dataset (45/45 shards): 100%|██████████| 23262/23262 [00:52<00:00, 446.40 examples/s]
Saving the dataset (5/5 shards): 100%|██████████| 2585/2585 [00:06<00:00, 401.79 examples/s]


The following does a check to make sure the inputs are formatted correctly

In [118]:
# Assuming processed_dataset is ready
print("\n--- Verifying processed_dataset labels after map ---")
# Get a sample from the processed_dataset (e.g., the first 5 samples)
sample_data = processed_dataset["train"].select(range(min(5, len(processed_dataset["train"]))))

processor_instance = processor # Use the processor you defined earlier

for i, sample in enumerate(sample_data):
    labels = sample["labels"] # These are the token IDs from prepare_dataset

    # Ensure labels is a list (if it came from prepare_dataset's list of lists)
    if isinstance(labels, torch.Tensor):
        labels_list = labels.tolist()
    else: # It's likely a list of lists if batched=True in map
        # If it's a single sample, it might just be a list
        labels_list = labels 
        if isinstance(labels_list[0], list): # If it's a list of lists (from batched=True)
            labels_list = labels_list[0] # Take the first one if you expect single samples here

    decoded_full = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=False)
    decoded_clean = processor_instance.tokenizer.decode(labels_list, skip_special_tokens=True)
    
    eos_id = processor_instance.tokenizer.eos_token_id

    print(f"\nSample {i+1}:")
    print(f"  Raw Labels IDs: {labels_list}")
    print(f"  Decoded (with special tokens): '{decoded_full}'")
    print(f"  Decoded (clean text): '{decoded_clean}'")
    
    if labels_list and labels_list[-1] == eos_id:
        print(f"  Ends with EOS token ({eos_id}): YES")
    else:
        print(f"  Ends with EOS token ({eos_id}): NO - CRITICAL ISSUE AT prepare_dataset!")
        if labels_list:
            print(f"    Last token: {labels_list[-1]}")


--- Verifying processed_dataset labels after map ---

Sample 1:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 1353, 428, 3172, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>to your wish<|endoftext|>'
  Decoded (clean text): 'to your wish'
  Ends with EOS token (50257): YES

Sample 2:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 5616, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>you<|endoftext|>'
  Decoded (clean text): 'you'
  Ends with EOS token (50257): YES

Sample 3:
  Raw Labels IDs: [50257, 50259, 50359, 50257, 50258, 50259, 50359, 50363, 13301, 1106, 7670, 1106, 50257]
  Decoded (with special tokens): '<|endoftext|><|en|><|transcribe|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>eh ho eh ho<|endoftext|>

---------------------------