# Whisper Fine-tuning

Purchase a copy of this notebook at [Trelis.com](https://trelis.com/ADVANCED-transcription).

Notebook built by Trelis Research. Find us at [Trelis.com](https://trelis.com/About), [YouTube](https://YouTube.com/@TrelisResearch) and on [HuggingFace](https://huggingface.co/Trelis).

*Subscribe to Trelis Research emails [here](https://blog.trelis.com) and get notified each time a new video tutorial is published.*

Built upon an [original notebook](https://colab.research.google.com/drive/1DOkD_5OUjFa0r5Ik3SgywJLJtEo2qLxO?usp=sharing#scrollTo=8d230e6d-624c-400a-bbf5-fa660881df25) by HuggingFace.

## Installation

In [1]:
import subprocess

# Run the command and capture the output
add_repo_output = subprocess.run(['add-apt-repository', '-y', 'ppa:jonathonf/ffmpeg-4'], capture_output=True, text=True)
update_output = subprocess.run(['apt', 'update', '-q'], capture_output=True, text=True)
install_ffmpeg_output = subprocess.run(['apt', 'install', '-y', 'ffmpeg', '-q'], capture_output=True, text=True)

# # Print the output from each command
# print("Add repository output:")
# print(add_repo_output.stdout)

# print("Update output:")
# print(update_output.stdout)

# print("Install FFmpeg output:")
# print(install_ffmpeg_output.stdout)

In [2]:
# Install dependencies with compatible versions
!pip install datasets -qU
!pip install transformers peft hf_transfer -qU
!pip install librosa -qU
!pip install evaluate -q
!pip install jiwer -qU
!pip install gradio -qU
!pip install bitsandbytes accelerate loralib -qU
!pip install numba
!pip install soundfile

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To updat

In [3]:
# allow for fast weight uploads and downloads
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
import os

# Select CUDA device index - this just let's the script know we want to use one gpu (the 0th one).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Input model
model_name_or_path = "openai/whisper-large-v3" # this will use more vram than whisper-small but is higher quality (and 2x faster at inference)
language = "Uzbek"
language_abbr = "uz"
task = "transcribe"

# Output model path (for pushing to HuggingFace)
org = "bekzod123"
trained_adapter_name = "whisper-turbo-llm-lingo-adapters"
trained_model_name = "whisper-turbo-llm-lingo"

trained_adapter_repo = org + "/" + trained_adapter_name
trained_model_repo = org + '/' + trained_model_name

In [None]:
from transformers import pipeline
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperTimeStampLogitsProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
import torch

In [None]:
# # These two lines allow the first tokens of the output to be forced as the language and task - which accelerates transcription.
# # To force the ids, generate_kwargs needs to be passed to the pipeline
# # This is unnecessary if the model is capable for determining the language and task.
# # Fine-tuning can adversely affect a model's ability to determine the language.
# processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)
# forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

whisper_asr = pipeline(
    "automatic-speech-recognition",
    model=model_name_or_path,
    chunk_length_s=30,
    device="cuda" if torch.cuda.is_available() else "mps", # for mac
    # device="cuda" if torch.cuda.is_available() else "cpu",
    return_timestamps=True,
)

In [8]:
import re

def format_time(seconds):
    """Formats seconds into HH:MM:SS.mmm format with period separating milliseconds."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    return f"{hours:02}:{minutes:02}:{seconds:06.3f}"  # Ensure period is used for milliseconds

def process_audio_and_create_vtt(audio_filename, audio_type, whisper_asr, output_filename=None):
    """Processes audio and creates a VTT file from Whisper ASR predictions."""
    # Process the audio file and get the transcription with timestamps
    prediction = whisper_asr(
        f"{audio_filename}.{audio_type}",
        return_timestamps=True,
        chunk_length_s=30,
        stride_length_s=10,
    )

    print(prediction)

    # Use the specified output file name or default to audio_filename if not provided
    vtt_file_name = output_filename if output_filename else f"{audio_filename}.vtt"

    # Create a VTT file
    with open(vtt_file_name, "w", encoding='utf-8') as vtt_file:
        vtt_file.write("WEBVTT\n\n")

        # Iterate over the ASR prediction chunks
        for i, chunk in enumerate(prediction.get("chunks", [])):
            start, end = chunk.get("timestamp", (None, None))
            text = chunk.get("text", "").strip()

            # Validate that start, end, and text are present
            if start is None or end is None or not text:
                print(f"Skipping invalid chunk at index {i}: start={start}, end={end}, text='{text}'")
                continue  # Skip this chunk if any part is invalid

            # Format the timestamps using the format_time function
            start_time = format_time(start)
            end_time = format_time(end)

            # Validate the generated timestamps
            if not re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3}$", start_time) or not re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3}$", end_time):
                print(f"Invalid timestamp format at index {i}: start={start_time}, end={end_time}")
                continue  # Skip if timestamp format is invalid

            # Write the valid VTT cue to the file
            vtt_file.write(f"{start_time} --> {end_time}\n{text}\n\n")

    print(f"VTT file saved to: {vtt_file_name}")

In [9]:
# Process the training audio file
process_audio_and_create_vtt("data/train", "mp3", whisper_asr)

FileNotFoundError: [Errno 2] No such file or directory: 'data/train.mp3'

In [None]:
# Process the validation audio file
process_audio_and_create_vtt("data/validation", "mp3", whisper_asr)

### Dataset Cleaning

train.vtt and validation.vtt files are now present in your folder structure.

You now need to manually (or with the help of openai below) correct that transcript (optionally, with the help of a GPT and the prompt below) and then upload the vtt file to the ADVANCED-transcription repo, into the data folder.

Here is a prompt you can use (but be very careful as the LLM can slightly mess up time-stamps and that breaks transcription):
```
I want your help in correcting a VTT file transcript. I'll provide a list of words that the ASR was not familiar with. Respond, in a code pen, with the contents of the updated VTT file:

<input-VTT>
...
</input-VTT>

<keywords>
...
</keywords>

Respond immediately with the corrected VTT.  Do not run any code.
```

In [None]:
!pip install openai -qU

In [None]:
import openai
import getpass

# Securely input your OpenAI API key using getpass
api_key = getpass.getpass("Enter your OpenAI API key: ")

In [None]:
from openai import OpenAI
from pydantic import BaseModel
from typing import List

client = OpenAI(api_key=api_key)

# Define the structure for the corrected VTT response
class CorrectedTextResponse(BaseModel):
    corrected_text: str

def get_context(vtt_lines, index, context_range=2):
    """Get up to two lines before and after the current index for context."""
    start_idx = max(0, index - context_range)
    end_idx = min(len(vtt_lines), index + context_range + 1)
    return vtt_lines[start_idx:end_idx]

def is_timestamp(line):
    """Check if a line is a timestamp in VTT format."""
    return re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}$", line)

def split_vtt_into_lines(vtt_contents):
    """Splits VTT file into individual timestamps and text lines."""
    lines = vtt_contents.strip().splitlines()
    vtt_pairs = []
    
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        if is_timestamp(line):  # This is a timestamp line
            timestamp = line
            if i + 1 < len(lines):
                text = lines[i + 1].strip()  # The text comes right after the timestamp
                vtt_pairs.append((timestamp, text))
            i += 2  # Move to the next potential timestamp
        else:
            i += 1  # Skip non-timestamp lines (e.g., "WEBVTT" header, blank lines)

    return vtt_pairs

def correct_single_line(vtt_lines, index_to_correct, keywords):
    """Correct a single line with context from the VTT file."""
    # Get up to 2 lines before and after the target line for context
    context_lines = get_context(vtt_lines, index_to_correct)

    # Prepare the input with context
    combined_input = "\n".join([f"{timestamp}\n{text}" for timestamp, text in context_lines])

    # Ensure the target line is clearly marked
    target_line = vtt_lines[index_to_correct][1]

    # Prepare the prompt
    prompt = f"""
I want your help in correcting the following VTT file transcript. I will provide up to two lines before and after the line that needs correction for context. 
Do not modify the timestamps. Do not modify any other lines except the target line of text.

The line that needs correction is marked as <target>:

<input-VTT>
{combined_input}
</input-VTT>

<target>
{target_line}
</target>

<keywords>
{keywords}
</keywords>

Please return only the corrected target line of text without changing any other lines.
"""

    # Make the OpenAI API call with structured response handling
    completion = client.beta.chat.completions.parse(
        model="gpt-4o", # if you use a weaker model, it's hard to get the accuracy needed.
        messages=[
            {"role": "system", "content": "Correct the specified line of text based on the provided context and keywords."},
            {"role": "user", "content": prompt},
        ],
        response_format=CorrectedTextResponse,
    )

    # Extract the corrected text
    corrected_text = completion.choices[0].message.parsed.corrected_text
    return corrected_text

def correct_vtt_lines(vtt_lines, keywords):
    corrected_lines = []
    for idx, (timestamp, text) in enumerate(vtt_lines):
        print(f"Correcting line {idx + 1} of {len(vtt_lines)}...")
        corrected_text = correct_single_line(vtt_lines, idx, keywords)
        corrected_lines.append((timestamp, corrected_text))

    return corrected_lines

def save_corrected_vtt(corrected_lines, output_file_path):
    """Save the corrected VTT file by combining corrected text with original timestamps."""
    with open(output_file_path, 'w', encoding='utf-8') as file:
        file.write("WEBVTT\n\n")
        for timestamp, text in corrected_lines:
            file.write(f"{timestamp}\n{text}\n\n")

# Main function
def clean_vtt_file(vtt_file_path, keywords):
    # Read the VTT file
    with open(vtt_file_path, 'r', encoding='utf-8') as file:
        vtt_contents = file.read()

    # Split VTT file into individual lines (timestamp + text)
    vtt_lines = split_vtt_into_lines(vtt_contents)

    # Correct each line one by one, providing context
    corrected_lines = correct_vtt_lines(vtt_lines, keywords)

    return corrected_lines

# Main function to handle both train and validation files
if __name__ == "__main__":
    train_vtt_path = "train.vtt"
    validation_vtt_path = "validation.vtt"
    output_train_vtt_path = "train_corrected.vtt"
    output_validation_vtt_path = "validation_corrected.vtt"
    keywords_file_name = "train_keywords.txt"
    
    # Load keywords
    with open(keywords_file_name, 'r', encoding='utf-8') as file:
        keywords_to_correct = file.read()

    # Clean and save the train VTT file
    corrected_train_vtt_lines = clean_vtt_file(train_vtt_path, keywords_to_correct)
    if corrected_train_vtt_lines:
        save_corrected_vtt(corrected_train_vtt_lines, output_train_vtt_path)
        print(f"Corrected Train VTT file saved to: {output_train_vtt_path}")

    # Clean and save the validation VTT file
    corrected_validation_vtt_lines = clean_vtt_file(validation_vtt_path, keywords_to_correct)
    if corrected_validation_vtt_lines:
        save_corrected_vtt(corrected_validation_vtt_lines, output_validation_vtt_path)
        print(f"Corrected Validation VTT file saved to: {output_validation_vtt_path}")


### Audio chunking
You must be logged into huggingface to run this, see above for login.

In [None]:
!pip install webvtt-py -qU

#### For processing on training and one validation mp3/vtt

In [None]:
# This is for handling single input train and validation files.
# For multiple input audio files (and transcripts) see `prepare-data-multi-input.py` in the `speech-to-text` folder of `ADVANCED-transcription`

from datasets import Dataset, DatasetDict, Audio
import webvtt
from datetime import datetime
import librosa
import soundfile as sf
import os
from huggingface_hub import login

# Setup
hf_username = "Trelis"  # Your Hugging Face username
repo_name = "llm-lingo"  # Name of the repository on the Hub
train_audio_file = "train.mp3"  # Path to the training audio file
train_vtt_file = "train_corrected.vtt"  # Path to the training VTT file
validation_audio_file = "validation.mp3"  # Path to the validation audio file
validation_vtt_file = "validation_corrected.vtt"  # Path to the validation VTT file
save_path = f"data/{repo_name}-dataset"  # Local save path

def parse_time(time_str):
    return datetime.strptime(time_str, '%H:%M:%S.%f')

def milliseconds(time_obj):
    return (time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second) * 1000 + time_obj.microsecond // 1000

def time_to_samples(time_ms, sr):
    return int((time_ms / 1000.0) * sr)

def transform_data(data):
    transformed = {"audio": [], "text": [], "start_time": [], "end_time": []}
    for item in data:
        for key in transformed:
            transformed[key].append(item[key])
    return transformed

def process_audio_file(audio_path, vtt_path, output_dir, max_duration=30):
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load the full audio file with librosa
    full_audio, sr = librosa.load(audio_path, sr=None, mono=True)

    # Parse VTT file
    captions = webvtt.read(vtt_path)

    # Prepare data for 🤗 Datasets
    data = []
    current_text = []
    current_start = None
    current_end = None
    accumulated_duration = 0
    segment_counter = 0

    for caption in captions:
        start_time = parse_time(caption.start)
        end_time = parse_time(caption.end)
        duration = (end_time - start_time).total_seconds()

        if current_start is None:
            current_start = start_time

        if accumulated_duration + duration <= max_duration:
            current_text.append(caption.text)
            current_end = end_time
            accumulated_duration += duration
        else:
            # Process and save the audio segment in MP3 format
            segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
            start_sample = time_to_samples(milliseconds(current_start.time()), sr)
            end_sample = time_to_samples(milliseconds(current_end.time()), sr)
            audio_segment = full_audio[start_sample:end_sample]
            sf.write(segment_filename, audio_segment, sr, format='mp3')

            # Add the segment info to the dataset
            data.append({
                "audio": segment_filename,
                "text": ' '.join(current_text),
                "start_time": current_start.strftime('%H:%M:%S.%f')[:-3],
                "end_time": current_end.strftime('%H:%M:%S.%f')[:-3]
            })

            # Prepare for the next segment
            current_text = [caption.text]
            current_start = start_time
            current_end = end_time
            accumulated_duration = duration
            segment_counter += 1

    # Process and save any remaining audio segment
    if current_text:
        segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
        start_sample = time_to_samples(milliseconds(current_start.time()), sr)
        end_sample = time_to_samples(milliseconds(current_end.time()), sr)
        audio_segment = full_audio[start_sample:end_sample]
        sf.write(segment_filename, audio_segment, sr, format='mp3')

        data.append({
            "audio": segment_filename,
            "text": ' '.join(current_text),
            "start_time": current_start.strftime('%H:%M:%S.%f')[:-3],
            "end_time": current_end.strftime('%H:%M:%S.%f')[:-3]
        })

    return data

def create_dataset(train_audio_file, train_vtt_file, validation_audio_file, validation_vtt_file, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Process training audio file
    train_data = process_audio_file(train_audio_file, train_vtt_file, f"{output_dir}/train")

    # Process validation audio file
    validation_data = process_audio_file(validation_audio_file, validation_vtt_file, f"{output_dir}/validation")

    # Transform data into the correct format for Dataset.from_dict
    train_dataset = Dataset.from_dict(transform_data(train_data))
    valid_dataset = Dataset.from_dict(transform_data(validation_data))

    # Create DatasetDict
    dataset_dict = DatasetDict({
        "train": train_dataset,
        "validation": valid_dataset
    })

    return dataset_dict

# Create dataset
dataset = create_dataset(train_audio_file, train_vtt_file, validation_audio_file, validation_vtt_file, save_path)

# Save dataset locally
dataset.save_to_disk(save_path)

# Cast the audio column to the Audio feature
dataset = dataset.cast_column("audio", Audio())
# Push the dataset to the Hub
dataset.push_to_hub(repo_id=f"{hf_username}/{repo_name}")

print(f"Dataset pushed to {hf_username}/{repo_name}")

#### For Processing Multiple Input mp3 and vtt files
Place your files within the `data` folder!

In [None]:
# import os
# from glob import glob
# from datasets import Dataset, DatasetDict, Audio
# import webvtt
# from datetime import datetime
# import librosa
# import soundfile as sf

# # Setup
# hf_username = "Trelis"  # Your Hugging Face username
# repo_name = "llm-lingo"  # Name of the repository on the Hub
# data_folder = "data"  # Path to the data folder
# save_path = f"data/{repo_name}-dataset"  # Local save path

# def parse_time(time_str):
#     return datetime.strptime(time_str, '%H:%M:%S.%f')

# def milliseconds(time_obj):
#     return (time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second) * 1000 + time_obj.microsecond // 1000

# def time_to_samples(time_ms, sr):
#     return int((time_ms / 1000.0) * sr)

# def transform_data(data):
#     transformed = {"audio": [], "text": [], "start_time": [], "end_time": []}
#     for item in data:
#         for key in transformed:
#             transformed[key].append(item[key])
#     return transformed

# def process_audio_file(audio_path, vtt_path, output_dir, max_duration=30):
#     # Create output directory if it doesn't exist
#     os.makedirs(output_dir, exist_ok=True)

#     # Load the full audio file with librosa
#     full_audio, sr = librosa.load(audio_path, sr=None, mono=True)

#     # Parse VTT file
#     captions = webvtt.read(vtt_path)

#     # Prepare data for 🤗 Datasets
#     data = []
#     current_text = []
#     current_start = None
#     current_end = None
#     accumulated_duration = 0
#     segment_counter = 0

#     for caption in captions:
#         start_time = parse_time(caption.start)
#         end_time = parse_time(caption.end)
#         duration = (end_time - start_time).total_seconds()

#         if current_start is None:
#             current_start = start_time

#         if accumulated_duration + duration <= max_duration:
#             current_text.append(caption.text)
#             current_end = end_time
#             accumulated_duration += duration
#         else:
#             # Process and save the audio segment in MP3 format
#             segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
#             start_sample = time_to_samples(milliseconds(current_start.time()), sr)
#             end_sample = time_to_samples(milliseconds(current_end.time()), sr)
#             audio_segment = full_audio[start_sample:end_sample]
#             sf.write(segment_filename, audio_segment, sr, format='mp3')

#             # Add the segment info to the dataset
#             data.append({
#                 "audio": segment_filename,
#                 "text": ' '.join(current_text),
#                 "start_time": current_start.strftime('%H:%M:%S.%f')[:-3],
#                 "end_time": current_end.strftime('%H:%M:%S.%f')[:-3]
#             })

#             # Prepare for the next segment
#             current_text = [caption.text]
#             current_start = start_time
#             current_end = end_time
#             accumulated_duration = duration
#             segment_counter += 1

#     # Process and save any remaining audio segment
#     if current_text:
#         segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
#         start_sample = time_to_samples(milliseconds(current_start.time()), sr)
#         end_sample = time_to_samples(milliseconds(current_end.time()), sr)
#         audio_segment = full_audio[start_sample:end_sample]
#         sf.write(segment_filename, audio_segment, sr, format='mp3')

#         data.append({
#             "audio": segment_filename,
#             "text": ' '.join(current_text),
#             "start_time": current_start.strftime('%H:%M:%S.%f')[:-3],
#             "end_time": current_end.strftime('%H:%M:%S.%f')[:-3]
#         })

#     return data

# def process_files(file_type):
#     audio_files = glob(os.path.join(data_folder, f"*{file_type}*.mp3"))
#     vtt_files = glob(os.path.join(data_folder, f"*{file_type}*.vtt"))
    
#     if not audio_files or not vtt_files:
#         raise ValueError(f"No {file_type} files found in the data folder.")
    
#     data = []
#     for audio_file, vtt_file in zip(audio_files, vtt_files):
#         data.extend(process_audio_file(audio_file, vtt_file, f"{save_path}/{file_type}"))
    
#     return data

# def create_dataset(output_dir):
#     os.makedirs(output_dir, exist_ok=True)

#     # Process training files
#     train_data = process_files("train")

#     # Process validation files
#     validation_data = process_files("validation")

#     # Transform data into the correct format for Dataset.from_dict
#     train_dataset = Dataset.from_dict(transform_data(train_data))
#     valid_dataset = Dataset.from_dict(transform_data(validation_data))

#     # Create DatasetDict
#     dataset_dict = DatasetDict({
#         "train": train_dataset,
#         "validation": valid_dataset
#     })

#     return dataset_dict

# # Create dataset
# dataset = create_dataset(save_path)

# # Save dataset locally
# dataset.save_to_disk(save_path)

# # Cast the audio column to the Audio feature
# dataset = dataset.cast_column("audio", Audio())

# # Push the dataset to the Hub
# dataset.push_to_hub(repo_id=f"{hf_username}/{repo_name}")

## Load Dataset

In [None]:
import os
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets, Audio

def trim_text(example):
    """Trims the beginning and end of text fields."""
    example["text"] = example["text"].strip()  # Trim whitespace and unwanted characters
    return example

def load_and_prepare_datasets(datasets_info):
    combined_dataset = DatasetDict()

    for dataset_info in datasets_info:
        dataset_name = dataset_info["name"]
        audio_col = dataset_info["audio_col"]
        text_col = dataset_info["text_col"]
        subset = dataset_info.get("subset")
        revision = dataset_info.get("revision", "main")
        limit = dataset_info.get("limit")  # New optional field
        custom_filter = dataset_info.get("filter_fn")  # Optional filter function

        print(f"Loading dataset: {dataset_name} | subset: {subset} | revision: {revision}")

        # Load dataset
        if subset:
            loaded_dataset = load_dataset(
                dataset_name,
                subset,
                token="",
                revision=revision
            )
        else:
            loaded_dataset = load_dataset(
                dataset_name,
                token="",
                revision=revision
            )

        # Identify train / test splits
        if isinstance(loaded_dataset, Dataset):
            # Single-split dataset
            ds_train = loaded_dataset
            ds_test = None
        else:
            ds_train = loaded_dataset.get("train")
            ds_test = loaded_dataset.get("validate")

            # If "validation" split exists, merge into train
            if loaded_dataset.get("validation"):
                ds_train = concatenate_datasets([ds_train, loaded_dataset["validation"]])

            # Merge or fallback for test
            if ds_test and loaded_dataset.get("test"):
                ds_test = concatenate_datasets([ds_test, loaded_dataset["test"]])
            else:
                ds_test = loaded_dataset.get("test") or ds_test

        # If no train split found, fall back on first (or only) available split
        if ds_train is None:
            available_splits = list(loaded_dataset.keys())
            if len(available_splits) == 1:
                single_split_name = available_splits[0]
                ds_train = loaded_dataset[single_split_name]
                ds_test = None
            else:
                raise ValueError(
                    f"No 'train' split found for {dataset_name}. Available splits: {available_splits}"
                )

        # If test split is missing, create it from train (90/10)
        if ds_test is None:
            split_result = ds_train.train_test_split(test_size=0.1, seed=42)
            ds_train = split_result["train"]
            ds_test = split_result["test"]

        # Apply custom filter if provided in the dataset_info
        if custom_filter is not None:
            ds_train = ds_train.filter(custom_filter)
            ds_test = ds_test.filter(custom_filter)

        # Rename columns to standard 'audio' and 'text'
        rename_map = {}
        if audio_col in ds_train.column_names:
            rename_map[audio_col] = "audio"
        if text_col in ds_train.column_names:
            rename_map[text_col] = "text"

        ds_train = ds_train.rename_columns(rename_map)
        ds_test = ds_test.rename_columns(rename_map)

        # Keep only "audio" and "text" columns
        columns_to_keep = ["audio", "text"]
        ds_train = ds_train.remove_columns([col for col in ds_train.column_names if col not in columns_to_keep])
        ds_test = ds_test.remove_columns([col for col in ds_test.column_names if col not in columns_to_keep])

        # Cast "audio" column to proper Audio feature if present
        if "audio" in ds_train.column_names:
            ds_train = ds_train.cast_column("audio", Audio(sampling_rate=16000))
        if "audio" in ds_test.column_names:
            ds_test = ds_test.cast_column("audio", Audio(sampling_rate=16000))

        # Apply text trimming
        ds_train = ds_train.map(trim_text)
        ds_test = ds_test.map(trim_text)

        # Apply limit if specified
        if limit is not None:
            ds_train = ds_train.select(range(min(limit, len(ds_train))))
            ds_test = ds_test.select(range(min(limit, len(ds_test))))

        # Concatenate into combined dataset
        if "train" not in combined_dataset:
            combined_dataset["train"] = ds_train
        else:
            combined_dataset["train"] = concatenate_datasets([combined_dataset["train"], ds_train])

        if "validation" not in combined_dataset:
            combined_dataset["validation"] = ds_test
        else:
            combined_dataset["validation"] = concatenate_datasets([combined_dataset["validation"], ds_test])

    return combined_dataset

# Example usage
datasets_info = [
{
        "name": "DavronSherbaev/uzbekvoice-filtered",
        "audio_col": "path",
        "text_col": "sentence",
        "limit": 20000,
        "filter_fn": lambda ex: (
            ex.get("reported_reasons") is None and 
            ex.get("downvotes_count", 0) == 0 and 
            ex.get("reported_count", 0) == 0 and 
            ex.get("client_id") not in [
                "56ac8e86-b8c9-4879-a342-0eeb94f686fc",
                "3d3fca02-6a07-41e2-9af4-60886ea60300",
                "231d3776-2dbe-4a42-a535-c67943427e3f",
                "e2716f95-70b5-4832-b903-eef2343591a4",
                "2a815774-e953-4031-931a-8a28052e5cf9",
                "d6fd3dc4-a55d-4a80-9bbf-b713325d05be",
                "10b29e87-bf01-4b16-bead-a044076f849b",
                "e3412d51-f079-4167-b3f9-311a976443ce"
            ]
        )
    },
    {
        "name": "Beehzod/dataset_for_STT_TTSmodels",
        "audio_col": "audio",
        "text_col": "transcription",
        "revision": "refs/convert/parquet"
    },
    # {
    #     "name": "mozilla-foundation/common_voice_17_0",
    #     "subset": "uz",
    #     "audio_col": "audio",
    #     "text_col": "sentence",
    # },
    {
        "name": "google/fleurs",
        "subset": "uz_uz",
        "audio_col": "audio",
        "text_col": "transcription",
    },
    {
        "name": "bekzod123/uzbek_voice",
        "audio_col": "audio",
        "text_col": "text",
        "revision": "refs/convert/parquet",
        "filter_fn": lambda ex: (
            ex.get("is_correct") === True or ex.get("is_correct") === "true"
        )

    },
    {
        "name": "bekzod123/uzbek_voice_2",
        "audio_col": "audio",
        "text_col": "sentence"
    },
    # {
    #     "name": "bekzod123/uzbek_voice_3",
    #     "audio_col": "audio",
    #     "text_col": "text"
    # }
]

# Load and combine all datasets
dataset = load_and_prepare_datasets(datasets_info)
print(dataset)

## Prepare Feature Extractor, Tokenizer and Data

In [7]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

In [8]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
tokenizer.pad_token = tokenizer.eos_token  # Explicitly setting pad_token


tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

In [9]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)

### Prepare Data

In [10]:
print(dataset["train"][0])

{'audio': {'path': '7bdcc7d15bee3827419232f6b17b717e892498b63ca098a83568dd759c34f81a.mp3', 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
       -3.63368381e-05, -4.90826205e-05, -3.52069983e-05]), 'sampling_rate': 16000}, 'text': "Xudo xohlasa, g'alaba qozonib, muxlislarni xursand qilamiz."}


Often, the input audio is sampled at 48kHz, so we need to _downsample_ it to
16kHz prior to passing it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model.

We'll set the audio inputs to the correct sampling rate using dataset's
[`cast_column`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column)
method. This operation does not change the audio in-place,
but rather signals to `datasets` to resample audio samples _on the fly_ the
first time that they are loaded:

In [11]:
# from datasets import Audio

# dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

Re-loading the first audio sample in the Common Voice dataset will resample
it to the desired sampling rate:

In [12]:
print(dataset["train"][0])

{'audio': {'path': '7bdcc7d15bee3827419232f6b17b717e892498b63ca098a83568dd759c34f81a.mp3', 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
       -3.63368381e-05, -4.90826205e-05, -3.52069983e-05]), 'sampling_rate': 16000}, 'text': "Xudo xohlasa, g'alaba qozonib, muxlislarni xursand qilamiz."}


Now we can write a function to prepare our data ready for the model:
1. We load and resample the audio data by calling `batch["audio"]`. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
2. We use the feature extractor to compute the log-Mel spectrogram input features from our 1-dimensional audio array.
3. We encode the transcriptions to label ids through the use of the tokenizer.

In [13]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

We can apply the data preparation function to all of our training examples using dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially.

In [14]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 6306
    })
    validation: Dataset({
        features: ['audio', 'text'],
        num_rows: 2910
    })
})


In [15]:
dataset = dataset.map(
    prepare_dataset,
    remove_columns=dataset.column_names["train"],
    num_proc=1
)

Map:   0%|          | 0/6306 [00:00<?, ? examples/s]

Map:   0%|          | 0/2910 [00:00<?, ? examples/s]

In [16]:
dataset["train"]

Dataset({
    features: ['input_features', 'labels'],
    num_rows: 6306
})

In [17]:
dataset["validation"]

Dataset({
    features: ['input_features', 'labels'],
    num_rows: 2910
})

In [18]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 6306
    })
    validation: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 2910
    })
})


## Training and Evaluation

### Define a Data Collator

In [19]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Define the target dtype for input features based on the use_bf16 flag
        input_dtype = torch.bfloat16 if use_bf16 else torch.float16

        # Split inputs and labels since they have to be of different lengths and need different padding methods
        # First treat the audio inputs (input features)
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Cast input features to bfloat16 if necessary
        batch = {k: v.to(input_dtype) for k, v in batch.items()}

        # Get the tokenized label sequences (input_ids)
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # Pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If bos token is appended in the previous tokenization step, remove it here
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        # IMPORTANT: Do NOT cast input_ids/labels to bfloat16; they should remain as Long tensors
        batch["labels"] = labels  # labels remain in their original type (Long)

        return batch


Let's initialise the data collator we've just defined:

In [20]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:

In [21]:
import evaluate

metric = evaluate.load("wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

We then simply have to define a function that takes our model
predictions and returns the WER metric. This function, called
`compute_metrics`, first replaces `-100` with the `pad_token_id`
in the `label_ids` (undoing the step we applied in the
data collator to ignore padded tokens correctly in the loss).
It then decodes the predicted and label ids to strings. Finally,
it computes the WER between the predictions and reference labels:

In [22]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### Load a Pre-Trained Checkpoint

Now let's load the pre-trained Whisper `turbo` checkpoint.

In [23]:
from transformers import WhisperForConditionalGeneration
import torch

use_bf16 = True # set false for colab

# del model # if you want to re-run and delete the model

model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    load_in_8bit=False,
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16, # or float16 for Colab
    device_map="auto")

# model.hf_device_map - this should be {" ": 0}
model.config.pad_token_id = tokenizer.pad_token_id

config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.90k [00:00<?, ?B/s]

In [24]:
print(model.dtype)  # should match input features

torch.bfloat16


Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):

In [25]:
model.config.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
language=language, task=task
)
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
language=language, task=task
)
model.generation_config.suppress_tokens = []



### Apply LoRA

Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) using `get_peft_model` utility function from `peft`.

In [26]:
# optionally inspect the model
print(model)

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bia

In [27]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=8,
                    use_rslora=True,
                    target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"], # or optionally just do "all-linear" for all linear layers
                    modules_to_save = ["model.embed_tokens"], # optionally train this if you are doing a heavy tune with lots of data and maybe a language Whisper is weak on.
                    lora_dropout=0.05, bias="none")

model = get_peft_model(model, config)

model.print_trainable_parameters()

trainable params: 57,671,680 || all params: 1,601,162,240 || trainable%: 3.6019


We are ONLY using **~3%** of the total trainable parameters, thereby performing **Parameter-Efficient Fine-Tuning**

### Define the Training Configuration

In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [28]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=trained_model_name,  # change to a repo name of your choice
    per_device_train_batch_size=11, # use 1 if you have a dataset of just 10. Use 4 if you've a dataset of 50-200, use 16 for larger than that, or 32.
    per_device_eval_batch_size=11,
    gradient_accumulation_steps=1,  # probably best to leave at 1 unless you see a lot of noise on the training loss.
    learning_rate=5e-4,
    # warmup_steps=50,
    weight_decay=0.02,
    warmup_ratio=0.1,                # Slightly increased warmup ratio for stability
    num_train_epochs=1,
    optim="adamw_torch",
    # load_best_model_at_end=True
    eval_strategy="steps",
    fp16=not use_bf16,
    bf16=use_bf16,
    generation_max_length=128,
    save_steps=10,
    save_total_limit=5,
    logging_steps=2,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    predict_with_generate=True,
    eval_steps=10, # run every two steps
    do_eval=True,
    lr_scheduler_type="constant",
    load_best_model_at_end=True,      # ⬅️ Ensures the best model is restored at the end
    metric_for_best_model="wer",      # ⬅️ Metric used to determine the best model
    greater_is_better=False,          # ⬅️ "wer" should be minimized (for accuracy, set to True)
)


**Important Notes:**
1. `remove_unused_columns=False` and `label_names=["labels"]` are required as the PeftModel's forward doesn't have the signature of the base model's forward.

2. If using INT8 - INT8 training required autocasting. `predict_with_generate` can't be passed to Trainer because it internally calls transformer's `generate` without autocasting leading to errors.

3. If using INT8 - Because of point 2, `compute_metrics` shouldn't be passed to `Seq2SeqTrainer` as seen below. (commented out)

In [29]:
from transformers import Seq2SeqTrainer, EarlyStoppingCallback

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!


  trainer = Seq2SeqTrainer(


In [None]:
trainer.train()

Step,Training Loss,Validation Loss,Wer
10,1.0516,1.212386,69.205524
20,0.6372,0.808943,59.288254
30,0.6746,0.71431,55.87994
40,0.5519,0.628524,50.468344
50,0.4175,0.584862,46.731847
60,0.3646,0.550886,46.454943
70,0.362,0.526994,44.171339
80,0.3075,0.495753,43.949132
90,0.4113,0.498131,42.25694
100,0.3718,0.493953,41.303159


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [None]:
## If you instead want to load an adapter from the hub
# adapter_to_push = f"{trained_model_name}/checkpoint-8"
import datetime

# Get today's date formatted as YYYY-MM-DD
today = datetime.date.today().strftime("%Y-%m-%d")

adapter_to_push = "bekzod123/whisper-llm-large-adapter"
adapter_to_push = f"{adapter_to_push}-{today}"

# import torch
# torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()

model.push_to_hub(trained_adapter_repo, private=True)
# print(f"Model pushed to {trained_adapter_repo}")


# print(adapter_to_push)

# from peft import PeftModel
# model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=False, device_map="auto")

# model = PeftModel.from_pretrained(
#     model,
#     adapter_to_push,
# )

# model.push_to_hub(trained_adapter_repo, private=True)
# print(f"Model pushed to {trained_adapter_repo}")

In [None]:
model = model.merge_and_unload()

In [34]:
# # Check the LoRa adapters are merged
# print(model)
print(model.dtype)  # should match input features

torch.bfloat16


In [36]:
# Save the model locally
model.save_pretrained(trained_model_name)
processor.save_pretrained(trained_model_name)




[]

## Evaluation

In [54]:
from transformers import pipeline
import torch

# del model # if you want to clear the gpu of the trained model

whisper_finetuned_asr = pipeline(
    "automatic-speech-recognition",
    model=trained_model_name,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch.bfloat16,
    return_timestamps=True,
    chunk_length_s=30, #this splits the input into 30 second chunks, the max length for whisper.
    # generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
    device="cuda" if torch.cuda.is_available() else "cpu" #ensures that the GPU is used, for speed up.
)

Device set to use cuda


In [44]:
# Process the validation audio file as a test
# process_audio_and_create_vtt("validation", "mp3", whisper_finetuned_asr, "evaluation.vtt")

In [57]:
import soundfile as sf
import librosa

file_path = "./xabar2_processed.wav"
audio, sample_rate = sf.read(file_path)

# Resample audio to match the processor's expected sampling rate if needed
if sample_rate != processor.feature_extractor.sampling_rate:
    audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=processor.feature_extractor.sampling_rate)

result = whisper_finetuned_asr(audio, generate_kwargs={
        "temperature": 0.5,
        "max_new_tokens": 120,
        "repetition_penalty": 0.8,
        # explicitly set forced_decoder_ids to None to overrid
    })
print(result["text"])

Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


birnibki ishtal   isalomligi ishtalal mal malda   ishlarma   iski ishtal     mani    varligning boribki    muroibnibnibul man    navul    besttumir sumanidan kaplayda ki   murobsan tomunningi    taqtili     jalda taqtishi taqtili  ishi taqsiri suqishibda muhavkamak ulazganida peishkom guringida peishkom guringishi man u'zim bilama peishkom mu   mash'inavlan uramanmi   ugani ging man hakorib su'li'ida yulaksida yulashshum tuzgan fa'miliy fa'miliyt yuravshum tuzgan fa'miliy sumarni karayniz yuravshum tuzaynum fa'miliyt kizm sabib tuzmiz nirkomiz bo'la otx


## Push Model and Processor to Hub

In [None]:
model.push_to_hub(trained_model_repo, safe_serialization=True)
processor.push_to_hub(trained_model_repo)

In [None]:
# ## Run this code if you face utf locale errors in Colab
# # https://stackoverflow.com/questions/56081324/why-are-google-colab-shell-commands-not-working
# import locale
# def getpreferredencoding(do_setlocale = True):
#     return "UTF-8"
# locale.getpreferredencoding = getpreferredencoding

## Convert HuggingFace Whisper Model to OpenAI format

In [None]:
!pip install git+https://github.com/openai/whisper.git -q

In [None]:
# Load the model in HuggingFace format
from transformers import WhisperForConditionalGeneration

# # del model # if you want to re-load a model
# trained_model_name = "whisper-turbo-llm-lingo" # comment this in and ensure correct if you have restarted the kernel
# trained_model_repo = "bekzod123/whisper-turbo-llm-lingo"

model = WhisperForConditionalGeneration.from_pretrained(
    trained_model_name,
    load_in_8bit=False,
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16, # or float16 for Colab
    device_map="auto")

In [None]:
trained_model_bin_format = f"{trained_model_name}-bin"

model.save_pretrained(trained_model_bin_format,
                      safe_serialization=False # to save to a pickled format
                     )
# processor.save_pretrained(trained_model_bin_format) # not necessary

In [None]:
# Credit to [ndunks](https://github.com/ndunks) for the snippet
#!/bin/env python3
import whisper
import re
import torch

def hf_to_whisper_states(text):
    text = re.sub('.layers.', '.blocks.', text)
    text = re.sub('.self_attn.', '.attn.', text)
    text = re.sub('.q_proj.', '.query.', text)
    text = re.sub('.k_proj.', '.key.', text)
    text = re.sub('.v_proj.', '.value.', text)
    text = re.sub('.out_proj.', '.out.', text)
    text = re.sub('.fc1.', '.mlp.0.', text)
    text = re.sub('.fc2.', '.mlp.2.', text)
    text = re.sub('.fc3.', '.mlp.3.', text)
    text = re.sub('.fc3.', '.mlp.3.', text)
    text = re.sub('.encoder_attn.', '.cross_attn.', text)
    text = re.sub('.cross_attn.ln.', '.cross_attn_ln.', text)
    text = re.sub('.embed_positions.weight', '.positional_embedding', text)
    text = re.sub('.embed_tokens.', '.token_embedding.', text)
    text = re.sub('model.', '', text)
    text = re.sub('attn.layer_norm.', 'attn_ln.', text)
    text = re.sub('.final_layer_norm.', '.mlp_ln.', text)
    text = re.sub('encoder.layer_norm.', 'encoder.ln_post.', text)
    text = re.sub('decoder.layer_norm.', 'decoder.ln.', text)
    text = re.sub('proj_out.weight', 'decoder.token_embedding.weight', text)
    return text

# Load HF Model
hf_state_dict = torch.load(f"{trained_model_bin_format}/pytorch_model.bin", map_location=torch.device('cpu'))    # pytorch_model.bin file

# Rename layers
for key in list(hf_state_dict.keys())[:]:
    new_key = hf_to_whisper_states(key)
    hf_state_dict[new_key] = hf_state_dict.pop(key)

openai_format_model = f"{trained_model_name}-openai.bin"

model = whisper.load_model('turbo')
dims = model.dims
# Save it
torch.save({
    "dims": model.dims.__dict__,
    "model_state_dict": hf_state_dict
}, openai_format_model)

In [None]:
## Upload to hub
from huggingface_hub import upload_file

openai_format_model = f"{trained_model_name}-openai.bin"
trained_model_repo = "Trelis/whisper-turbo-llm-lingo"

# Upload the file to the repository
upload_file(
    path_or_fileobj=openai_format_model,  # The file path or object
    path_in_repo=openai_format_model,  # The name under which to save it in the repo
    repo_id=trained_model_repo,  # The repo ID
    commit_message=f"Add {openai_format_model}",  # Optional: commit message
    commit_description="Uploading the OpenAI format bin model",  # Optional: description
    create_pr=False  # Set to True if you want to create a PR instead of pushing to the main branch
)

## Converting HuggingFace format to CTranslate2 (for Faster Whisper)

In [None]:
!pip install ctranslate2 -qU
# !pip uninstall ctranslate2 -y

In [None]:
trained_model_path = "whisper-turbo-llm-lingo"
local_ctranslate_repo = f"{trained_model_repo.split('/')[-1]}-ctranslate2"
org = "Trelis"
hf_ctranslate_repo = f"{org}/{local_ctranslate_repo}"

In [None]:
## IMPORTANT: To run this, you'll need to grab the tokenizer.json from
# https://huggingface.co/openai/whisper-large-v3-turbo/tree/main and
# then upload it to where you saved your trained model locally

# Convert to float16 and copy the necessary tokenizer files
!ct2-transformers-converter --model "{trained_model_path}" --output_dir "{local_ctranslate_repo}" --copy_files tokenizer.json preprocessor_config.json --quantization float16

In [None]:
## Upload to hub
from huggingface_hub import upload_folder, create_repo

create_repo(hf_ctranslate_repo, repo_type="model")

# Upload the file to the repository
upload_folder(
    folder_path=local_ctranslate_repo,  # The file path or object
    repo_id=hf_ctranslate_repo,  # The repo ID
    commit_message=f"Add {local_ctranslate_repo}",  # Optional: commit message
    commit_description="Uploading the ctranslate format model",  # Optional: description
    create_pr=False  # Set to True if you want to create a PR instead of pushing to the main branch
)

### Quick test of the ctranslate2 model

In [None]:
import ctranslate2
import librosa
import transformers

# Load and resample the audio file.
audio, _ = librosa.load("validation.mp3", sr=16000, mono=True) # you need to upload some mp3 (or wav or mp4) audio and name it validation.xxx

# Compute the features of the first 30 seconds of audio.
## IMPORTANT - THE PROCESSOR HERE IS BEING LOADED FROM THE HUGGINGFACE SAVED MODEL
processor = transformers.WhisperProcessor.from_pretrained(trained_model_repo)  # local directory
inputs = processor(audio, return_tensors="np", sampling_rate=16000)

# Convert the input features to the StorageView type for ctranslate2.
features = ctranslate2.StorageView.from_array(inputs.input_features)

# Load the model.
model = ctranslate2.models.Whisper(local_ctranslate_repo)

# Detect the language.
results = model.detect_language(features)
language, probability = results[0][0]
print("Detected language %s with probability %f" % (language, probability))

# Describe the task in the prompt.
prompt = processor.tokenizer.convert_tokens_to_ids(
    [
        "<|startoftranscript|>",
        language,
        "<|transcribe|>",
        "<|notimestamps|>",  # Remove this token to generate timestamps.
    ]
)

# Run generation for the 30-second window.
results = model.generate(features, [prompt])
transcription = processor.decode(results[0].sequences_ids[0])
print(transcription)