# Fine-tuning Whisper

In this notebook, we will explore fine-tuning the `whisper-small.en` model on our dataset. 

In [3]:
import warnings

warnings.filterwarnings("ignore")

## Optional: Download the audio files

In [None]:
import os
from tqdm import tqdm
from google.cloud import storage

# Set up Google Cloud credentials and initialize the client
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'aphasia-chatter-5a70166fc2f1.json'
client = storage.Client()
bucket = client.get_bucket('speech-sit-bucket')  # Replace with your bucket name

# Define GCS directory and local download directory
directory_prefix = 'audio/'  # GCS directory prefix
download_directory = 'samples/audio'  # Local directory path

# Check if the local directory exists
if not os.path.exists(download_directory):
    # If directory doesn't exist, create it
    os.makedirs(download_directory)

    # List all blobs (files) in the specified GCS directory
    audios = bucket.list_blobs(prefix=directory_prefix)

    # Initialize tqdm, but set the total to len(audio_file_set), the actual number of files to download
    progress_bar = tqdm(total=14159, desc="Downloading Files", unit=" files", leave=False)
    for index, audio in enumerate(audios):
        if index != 0:
            audio_file_name = os.path.basename(audio.name)
            local_file_path = os.path.join(download_directory, audio_file_name)
            audio.download_to_filename(local_file_path)
            progress_bar.update(1)
    progress_bar.close()
else:
    # If the directory exists, read files from local storage
    print("Reading files from the local directory...")
    files = sorted(os.listdir(download_directory))
    print(f"Loaded {len(files)} files from local directory.")

## Optional: Download the train, validation, and test sets

In [None]:
import os
from tqdm import tqdm
from google.cloud import storage

# Set up Google Cloud credentials and initialize the client
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'aphasia-chatter-5a70166fc2f1.json'
client = storage.Client()
bucket = client.get_bucket('speech-sit-bucket')  # Replace with your bucket name

# Define GCS directory and local download directory
directory_prefix = 'transcripts/'  # GCS directory prefix
download_directory = 'samples/transcripts'  # Local directory path

# Check if the local directory exists
if not os.path.exists(download_directory):
    # If directory doesn't exist, create it
    os.makedirs(download_directory)

    # List all blobs (files) in the specified GCS directory
    audios = bucket.list_blobs(prefix=directory_prefix)

    # Initialize tqdm, but set the total to len(audio_file_set), the actual number of files to download
    progress_bar = tqdm(total=3, desc="Downloading Gold Dataset", unit=" files", leave=False)
    for index, audio in enumerate(audios):
        audio_file_name = os.path.basename(audio.name)
        local_file_path = os.path.join(download_directory, audio_file_name)
        audio.download_to_filename(local_file_path)
        progress_bar.update(1)
    progress_bar.close()
else:
    # If the directory exists, read files from local storage
    print("Reading files from the local directory...")
    files = sorted(os.listdir(download_directory))
    print(f"Loaded {len(files)} files from local directory.")

## Step 1: Load train, validation, and test set into DataFrame

In [4]:
import pandas as pd

train_df = pd.read_csv('train_set.csv')
val_df = pd.read_csv('val_set.csv')
test_df = pd.read_csv('test_set.csv')

print(train_df.groupby('patient').nunique())
print(val_df.groupby('patient').nunique())
print(test_df.groupby('patient').nunique())

         path  audio_base_path  gold_transcript
patient                                        
al_e026   101              101              101
al_e028    92               92               91
al_e078   114              114              114
al_e085   132              132              131
al_e099   134              134              134
al_e100   161              161              156
al_e101   153              153              153
al_e117   157              157              157
al_e118   154              154              152
al_e122   160              160              158
al_e132   109              109              109
al_e179   127              127              124
hl_e002   164              164              161
hl_e003   162              162              162
hl_e005   163              163              162
hl_e006   163              163              161
hl_e007   164              164              162
hl_e008   163              163              160
hl_e010   164              164          

In [5]:
train_df = train_df.dropna(ignore_index=True)
val_df = val_df.dropna(ignore_index=True)
test_df = test_df.dropna(ignore_index=True)

In [6]:
print(train_df.info())
print(val_df.info())
print(test_df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9422 entries, 0 to 9421
Data columns (total 4 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   patient          9422 non-null   object
 1   path             9422 non-null   object
 2   audio_base_path  9422 non-null   object
 3   gold_transcript  9422 non-null   object
dtypes: object(4)
memory usage: 294.6+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2047 entries, 0 to 2046
Data columns (total 4 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   patient          2047 non-null   object
 1   path             2047 non-null   object
 2   audio_base_path  2047 non-null   object
 3   gold_transcript  2047 non-null   object
dtypes: object(4)
memory usage: 64.1+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2211 entries, 0 to 2210
Data columns (total 4 columns):
 #   Column           Non-Null Count  Dtype 
--- 

In [7]:
train_df.head()

Unnamed: 0,patient,path,audio_base_path,gold_transcript
0,al_e026,samples/audio_processed/al_e026_A-02.wav,al_e026_A-02.wav,I do body. And I have a bag. Racking a... ......
1,al_e026,samples/audio_processed/al_e026_A-03.wav,al_e026_A-03.wav,"My body is a frog, frog, frog, cow, cow, cow,..."
2,al_e026,samples/audio_processed/al_e026_A-04.wav,al_e026_A-04.wav,A crack? It looks like a bag of people clicki...
3,al_e026,samples/audio_processed/al_e026_A-05.wav,al_e026_A-05.wav,"The broke, a broke croaking, cracked calf, cr..."
4,al_e026,samples/audio_processed/al_e026_A-06.wav,al_e026_A-06.wav,"This is a frog calf. The two persons, Bok is ..."


## Step 2: Load the Audio, Extract Audio Features, and Encode Text

In [8]:
import torch
import gc

torch.cuda.empty_cache()
gc.collect()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [9]:
gc.collect()

0

In [10]:
from transformers import WhisperFeatureExtractor

# To get the log-mel spectrogram representation of the audio

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en", revision="main")

In [11]:
from transformers import WhisperTokenizer

# To tokenize the 'cleaned_transcript' or the gold data
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small.en", language="en", task="transcribe")

### 3.1: Load Audio Waveform

#### Optinal: If already downloaded, run this block

In [12]:
import numpy as np
import os
import pandas as pd

def load_waveforms(df, base_directory="samples/waveform"):
    """
    Load .npy waveform files into the 'audio' column of the dataframe.
    
    Args:
        df (pd.DataFrame): Dataframe with an 'audio_base_path' column containing file identifiers.
        base_directory (str): Directory where .npy waveform files are stored.
        
    Returns:
        pd.DataFrame: DataFrame with the 'audio' column updated with loaded waveforms.
    """
    # Define a function to load the .npy file for each row
    def load_audio(row):
        if 'chunk' in row['audio_base_path']:
            filepath = os.path.join(base_directory, f"waveform_{row['audio_base_path']}.npy")
        filepath = os.path.join(base_directory, f"waveform_{row['audio_base_path']}.npy")
        if os.path.exists(filepath):
            return np.load(filepath)
        else:
            print(f"Warning: File {filepath} does not exist.")
            return None

    # Apply the loading function to each row in the dataframe
    df['audio'] = df.apply(load_audio, axis=1)
    df['sampling_rate'] = 16000
    return df

# Load waveforms into each dataframe
train_df = load_waveforms(train_df)
val_df = load_waveforms(val_df)
test_df = load_waveforms(test_df)

#### Otherwise download them

In [None]:
import librosa
from tqdm import tqdm
import numpy as np
import os

def load_audio(row: pd.Series) -> pd.Series:
    """
        Loads the audio file as a floating point
        time series (Amplitude-Time domain). Also
        down samples to 16kHz as Whisper requires
        it.

        Args:
            row: pd.Series - The row to modify

        Returns:
            row: pd.Series - The modified row
    """
    audio, sampling_rate = librosa.load(row["path"], sr=16000)
    file_name = f"waveform_{row['audio_base_path']}"
    if not os.path.exists(file_name): 
        os.path.join(f'samples/waveforms/{file_name}')
        np.save(file_name, audio)
    row['waveform_path'] = f"waveform_{row['audio_base_path']}"
    row['sampling_rate'] = 16000
    return row

### 3.2 Extract Audio Features

In [13]:
def extract_features(row: pd.Series) -> pd.Series:
    """
        Extracts the audio features by converting them into a 1-D
        array representation of a Log-Mel Spectrogram.

        Args:
            row: pd.Series - The row to modify

        Returns:
            row: pd.Series - The modified row, or None if extraction fails
    """
    try:
        # Extract features
        features = feature_extractor(row['audio'], sampling_rate=16000, return_tensors="pt").input_features[0].to(device)
        features_cpu = features.cpu().numpy()
        row['features'] = features_cpu
        torch.cuda.empty_cache()
        return row

    except Exception as e:
        # Log the error and return None to indicate removal
        print(f"Error processing row {row.name}: {e}")
        return None

In [14]:
from tqdm import tqdm

tqdm.pandas(desc="Extracting features", unit=" audio processed")

# Apply the extract_features function with progress tracking on each dataframe
train_df['features'] = None
val_df['features'] = None
test_df['features'] = None

train_df = train_df.progress_apply(lambda x: extract_features(x), axis=1)
val_df = val_df.progress_apply(lambda x: extract_features(x), axis=1)
test_df = test_df.progress_apply(lambda x: extract_features(x), axis=1)

Extracting features: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9422/9422 [05:11<00:00, 30.29 audio processed/s]
Extracting features: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2047/2047 [01:04<00:00, 31.90 audio processed/s]
Extracting features: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2211/2211 [01:10<00:00, 31.28 audio processed/s]


### 3.3 Encode the transcriptions with label ids with the tokenizer

In [15]:
def encode_transcriptions(row: pd.Series) -> pd.Series:
    row["labels"] = tokenizer(
        str(row["gold_transcript"]),
    ).input_ids
    return row

In [16]:
tqdm.pandas(desc="Encoding transcriptions...", unit=" transcription encoded")
train_df = train_df.progress_apply(lambda x: encode_transcriptions(x), axis=1)
val_df = val_df.progress_apply(lambda x: encode_transcriptions(x), axis=1)
test_df = test_df.progress_apply(lambda x: encode_transcriptions(x), axis=1)

Encoding transcriptions...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9422/9422 [00:04<00:00, 1955.66 transcription encoded/s]
Encoding transcriptions...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2047/2047 [00:00<00:00, 2151.06 transcription encoded/s]
Encoding transcriptions...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2211/2211 [00:01<00:00, 2136.54 transcription encoded/s]


### Step 4: Load the Pre-Trained Checkpoint Model (Basically the Model with it's latest updated weights)
Make sure to set the language to english, as by default `whisper-large-v3` will try to determine the language

In [17]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en").to(device)

In [18]:
# Manually define task-to-ID mapping
task_to_id = {"transcribe": 0, "translate": 1}

# Set the task-to-ID in the generation configuration
model.generation_config.task_to_id = task_to_id

# Set the desired task and other generation configuration parameters
model.generation_config.task = "transcribe"  # Task: transcribe or translate

### Step 5: Define a Data Collator

In [19]:
class AphasiaDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.features = df['features'].to_numpy()  # Audio features
        self.labels = df['labels'].to_numpy()      # Encoded labels
    
    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return {
            "input_features": self.features[idx],  # Audio features
            "labels": self.labels[idx]             # Encoded ground truth transcriptions
        }

In [20]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any  # This is now the combined Whisper processor
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Process audio features and pad them using processor
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Process labels (tokenized text) and pad them using processor
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore padding tokens in the loss calculation
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Remove decoder start token if it's at the start of each label sequence
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [21]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small.en", revision="main")

In [22]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id
)

### Step 6: Evaluation Metrics

In [23]:
from jiwer import wer

# Use Word Error Rate (WER) as a metric

In [24]:
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with the pad_token_id from the processor
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels using the processor
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Normalize each decoded string
    pred_str = [normalizer(pred) for pred in pred_str]
    label_str = [normalizer(label) for label in label_str]

    # Filter out empty predictions and labels
    filtered_pred_str = []
    filtered_label_str = []
    for p, l in zip(pred_str, label_str):
        if p.strip() and l.strip():  # Keep only non-empty strings
            filtered_pred_str.append(p)
            filtered_label_str.append(l)

    # Calculate Word Error Rate if there are non-empty pairs
    if filtered_pred_str and filtered_label_str:
        word_error_rate = 100 * wer(filtered_pred_str, filtered_label_str)
    else:
        word_error_rate = 0  # Set to 0 or another placeholder if no valid pairs

    return {"wer": word_error_rate}

### Step 7: Define the training arguments

In [25]:
import accelerate

print(accelerate.__version__)

1.1.1


In [26]:
from transformers import Seq2SeqTrainingArguments, Trainer
import os

if not os.path.exists("models"):
    os.mkdir("models")

training_args = Seq2SeqTrainingArguments(
    output_dir="./models/whisphasia",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    num_train_epochs=7, 
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="epoch",   # Evaluate at the end of each epoch
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    save_strategy="epoch",         # Save a checkpoint at the end of each epoch
    logging_steps=50,             
    report_to=["mlflow"],          # Report to MLflow or other platforms
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    warmup_steps=200
)


### Step 8: Define the trainer

In [27]:
import torch
torch.cuda.empty_cache()

In [29]:
from transformers import Seq2SeqTrainer

training_set = AphasiaDataset(train_df)
val_set = AphasiaDataset(val_df)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=training_set,
    eval_dataset=val_set,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [32]:
import mlflow

mlflow.set_tracking_uri("https://6e94-137-132-26-42.ngrok-free.app")  # Update with the IP if remote

In [33]:
trainer.train()

Exception ignored in: <function MLflowCallback.__del__ at 0x707937b239a0>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 1344, in __del__
    self._auto_end_run
AttributeError: 'MLflowCallback' object has no attribute '_auto_end_run'
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Epoch,Training Loss,Validation Loss,Wer
1,0.7921,1.090569,81.187396
2,0.5153,1.124141,78.626366
3,0.3629,1.159805,68.132731
4,0.1916,1.221436,72.794686
5,0.1287,1.290847,70.190493
6,0.0762,1.391731,69.147372
7,0.0653,1.439212,70.551509


You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50362]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_clas

TrainOutput(global_step=4123, training_loss=0.4370639836241717, metrics={'train_runtime': 6115.6683, 'train_samples_per_second': 10.784, 'train_steps_per_second': 0.674, 'total_flos': 1.903336149270528e+19, 'train_loss': 0.4370639836241717, 'epoch': 7.0})

### Step 9: Testing our model

In [44]:
import torch
torch.cuda.empty_cache()
gc.collect()

19010

In [39]:
import whisper
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

In [40]:
test_df['gold_transcript_normalized'] = test_df.progress_apply(lambda row: normalizer(str(row['gold_transcript'])), axis=1)

Encoding transcriptions...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2211/2211 [00:00<00:00, 8046.15 transcription encoded/s]


In [None]:
large_model = whisper.load_model('large-v2').to("cuda") # This model is available via the API, and is the one that we're using rn

In [45]:
from tqdm import tqdm

tqdm.pandas(desc="Transcribing Large...")
test_df['transcription_large_v2'] = test_df.progress_apply(
    lambda row: normalizer(
        large_model.transcribe(
            row['path'],
            language='en',
        )['text']).strip(), axis=1)

Transcribing Large...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2211/2211 [24:39<00:00,  1.49it/s]  


In [46]:
test_df['transcription_large_v2'].head()

0                                            handicaps
1                                             handycam
2                           it was a plane many planes
3    i just realized these badges literally says 3 ...
4                                             sheering
Name: transcription_large_v2, dtype: object

In [47]:
test_df = test_df[(test_df['gold_transcript'] != '') & (test_df['gold_transcript_normalized'] != '')]

In [48]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2202 entries, 0 to 2210
Data columns (total 10 columns):
 #   Column                      Non-Null Count  Dtype 
---  ------                      --------------  ----- 
 0   patient                     2202 non-null   object
 1   path                        2202 non-null   object
 2   audio_base_path             2202 non-null   object
 3   gold_transcript             2202 non-null   object
 4   audio                       2202 non-null   object
 5   sampling_rate               2202 non-null   int64 
 6   features                    2202 non-null   object
 7   labels                      2202 non-null   object
 8   gold_transcript_normalized  2202 non-null   object
 9   transcription_large_v2      2202 non-null   object
dtypes: int64(1), object(9)
memory usage: 189.2+ KB


In [139]:
from jiwer import wer

wer_large = wer(test_df['gold_transcript_normalized'].tolist(), test_df['transcription_large_v2'].tolist())

print(wer_large)

0.518875690932612


In [50]:
import torch
import whisper
torch.cuda.empty_cache()
gc.collect()

0

In [51]:
small_model = whisper.load_model('small.en').to('cuda')

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 461M/461M [00:54<00:00, 8.85MiB/s]


In [52]:
tqdm.pandas(desc="Transcribing Small...")
test_df['transcription_small'] = test_df.progress_apply(
    lambda row: normalizer(
        small_model.transcribe(
            row['path'],
            language='en',
        )['text']).strip(), axis=1)

Transcribing Small...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2202/2202 [10:15<00:00,  3.58it/s] 


In [140]:
from jiwer import wer

wer_small = wer(test_df['gold_transcript_normalized'].tolist(), test_df['transcription_small'].tolist())

print(wer_small)

0.5515700341056098


In [78]:
import torch
torch.cuda.empty_cache()
gc.collect()

776

In [79]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

def transcribe_fine_tune(model, processor, row, device="cuda"):
    # Extract input features
    inputs = processor(row['audio'], return_tensors="pt", sampling_rate=16000).input_features.to(device)
    
    # Ensure model is on the correct device
    model.to(device)
    
    # Generate the tokens (this will be an "encoded" form of the generated transcription)
    predictions = model.generate(inputs)
    
    # Decode the generated tokens
    transcription = processor.batch_decode(predictions, skip_special_tokens=True)
    
    return transcription[0]
    
# Load model and processor from the checkpoint (Checkpoint 6 is best)
fine_tune_model = WhisperForConditionalGeneration.from_pretrained("models/whisphasia/checkpoint-3534").to(device)
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en", revision="main")

In [83]:
tqdm.pandas(desc="Transcribing Fine-Tuned...")
test_df["transcription_fine_tune"] = test_df.progress_apply(
    lambda row: transcribe_fine_tune(fine_tune_model,
                                     processor,
                                     row, 
                                    ),
    axis=1)

Transcribing Fine-Tuned...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2179/2179 [04:25<00:00,  8.22it/s]


In [85]:
from jiwer import wer

In [86]:
test_df['transcription_fine_tune'] = test_df['transcription_fine_tune'].apply(normalizer)

In [87]:
test_df['transcription_large_v2'].head(20)

0                                             handicaps
1                                              handycam
2                            it was a plane many planes
3     i just realized these badges literally says 3 ...
4                                              sheering
5                                    this is a heli fox
6                           look caps look in your caps
7                             bending bands handy bands
8                                             applecise
9                           raisins parniers handy cams
10                                         buzzer plate
11                                           what a fun
12                                         brain brains
13                          loading arms loading planes
15    it is a lion it is a hairy face a baggy face h...
16                                             rory fox
17                                 plants boring plains
18                             sitting on the ta

In [88]:
test_df['transcription_small'].head(20)

0                                       and the cups
1                                                  .
2        what about the plane we are going to planes
3                      end your camps end your camps
4                                            hearing
5                                      hey haley fox
6                                               look
7                          ben ling bents andy bents
8                                         both sides
9                ladies please handy camp handy camp
10                                  what is the plan
11                                    what the frang
12                                      brain brains
13                      loading hours loading planes
15    it is a hairy face a peckery face a hairy face
16                                          rory fox
17                              blunds boring plains
18                          sitting on the taxi face
19                            i wonder what he

In [89]:
test_df['transcription_fine_tune'].head(20)

0                             and the cops and the cops
1                                             handy cam
2                           what is a plane belly plane
3                       and you can not and you can not
4                                              cheering
5                                              is the .
6                   look at the copse look at the copse
7                           benling ben is handy ben is
8                                              opposite
9               reading please on handy cams handy cams
10                            what is up with the plane
11                                          water front
12                                         brain brains
13                                        rolling pains
15    there is a lion it is a hairy face backy face ...
16                                            rory foxe
17                                  blunt boring planes
18                         sitting on the flexib

In [93]:
test_df['gold_transcript_normalized'] = test_df['gold_transcript'].apply(normalizer)

In [92]:
test_df = test_df[(test_df['gold_transcript_normalized'] != '') & (test_df['transcription_fine_tune'] != '')]

In [97]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2178 entries, 0 to 2210
Data columns (total 12 columns):
 #   Column                      Non-Null Count  Dtype 
---  ------                      --------------  ----- 
 0   patient                     2178 non-null   object
 1   path                        2178 non-null   object
 2   audio_base_path             2178 non-null   object
 3   gold_transcript             2178 non-null   object
 4   audio                       2178 non-null   object
 5   sampling_rate               2178 non-null   int64 
 6   features                    2178 non-null   object
 7   labels                      2178 non-null   object
 8   gold_transcript_normalized  2178 non-null   object
 9   transcription_large_v2      2178 non-null   object
 10  transcription_small         2178 non-null   object
 11  transcription_fine_tune     2178 non-null   object
dtypes: int64(1), object(11)
memory usage: 221.2+ KB


In [138]:
wer_ft = wer(test_df['gold_transcript_normalized'].tolist(), test_df['transcription_fine_tune'].apply(normalizer).tolist())
wer_small = wer(test_df['gold_transcript_normalized'].tolist(), test_df['transcription_small'].apply(normalizer).tolist())
wer_large_v2 = wer(test_df['gold_transcript_normalized'].tolist(), test_df['transcription_large_v2'].apply(normalizer).tolist())

print(f"Fine-tuned Error Rate: {wer_ft:.02f}")
print(f"Small Error Rate: {wer_small:.02f}")
print(f"Large V2 Error Rate: {wer_large_v2:.02f}")

Fine-tuned Error Rate: 0.48
Small Error Rate: 0.55
Large V2 Error Rate: 0.52


### Conclusion (Fine-Tuning)

As you can see, the fine-tuning process paid off! The fine-tuned model (`fine-tcheckpoint-3534` achived a WER of 0.48, or a 12.7% reduction in terms of error rate compared to the similarly sized small model. In addition, it achieved a 7.8% reduction WER compared to the currently in use `large-v2` model.

### Step 10: Publish the Model to Huggingface

In [102]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

In [106]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from huggingface_hub import HfApi

model_checkpoint_path = "./models/whisphasia/checkpoint-3534" # Best performing fine-tuned model (Step 1767 performs worse)
repo_name = "f-azm17/whisper-small-singapore-aphasia"
model = WhisperForConditionalGeneration.from_pretrained(model_checkpoint_path)
model.push_to_hub(repo_name)

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/f-azm17/whisper-small-singapore-aphasia/commit/e421b9d7bf8413f13332393d8f9e87dee46960a0', commit_message='Upload WhisperForConditionalGeneration', commit_description='', oid='e421b9d7bf8413f13332393d8f9e87dee46960a0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/f-azm17/whisper-small-singapore-aphasia', endpoint='https://huggingface.co', repo_type='model', repo_id='f-azm17/whisper-small-singapore-aphasia'), pr_revision=None, pr_num=None)

### Step 11: Performance Testing

In [114]:
import torch
torch.cuda.empty_cache()
gc.collect()

1602

#### 11.1 Performance Testing for `large-v2`

In [None]:
large_model = whisper.load_model('large-v2').to("cuda") # This model is available via the API, and is the one that we're using rn

In [128]:
from tqdm import tqdm
import time

start = time.time()
tqdm.pandas(desc="Transcribing Large...")
test_df['transcription_large_v2'] = test_df.progress_apply(
    lambda row: normalizer(
        large_model.transcribe(
            row['path'],
            language='en',
        )['text']).strip(), axis=1)
total_time_large = time.time() - start

Transcribing Large...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2178/2178 [24:55<00:00,  1.46it/s]  


In [133]:
print(f"Total time executed by large-v2 model: {total_time_large:.02f} seconds")
print(f"Average Inference time: {total_time_large / test_df.shape[0]:.02f} seconds")
print(f"Number of Inferences per second: {(test_df.shape[0] / total_time_large):.02f}")

Total time executed by large-v2 model: 1495.49 seconds
Average Inference time: 0.69 seconds
Number of Inferences per second: 1.46


#### 11.2 Performance Testing for `whisper-small-singapore-aphasia`

In [136]:
from tqdm import tqdm
import time

start = time.time()
tqdm.pandas(desc="Transcribing Fine-Tuned...")
test_df["transcription_fine_tune"] = test_df.progress_apply(
    lambda row: normalizer(transcribe_fine_tune(fine_tune_model,
                                     processor,
                                     row, 
                                    )),
    axis=1)
total_time = time.time() - start

Transcribing Fine-Tuned...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2178/2178 [04:23<00:00,  8.26it/s]


In [137]:
print(f"Total time executed by sg-aphasia-fine-tuned model: {total_time:.02f} seconds")
print(f"Average Inference time: {total_time / test_df.shape[0]:.02f} seconds")
print(f"Number of Inferences per second: {(test_df.shape[0] / total_time):.02f}")

Total time executed by sg-aphasia-fine-tuned model: 263.76 seconds
Average Inference time: 0.12 seconds
Number of Inferences per second: 8.26


### Conclusion (Performance Testing)

Based on our experiments, our newly adapted fine-tuned model achieved fantastic results in terms of inference time, as well as throughput (number of inferences per second). 

On our test set, the fine-tuned model achieved over 5x faster inference time compared to `whisper-large-v2`, whilst performing better in terms of WER.