## Fine-Tuning Whisper Small

### Libraries:

In [None]:
# pip install evaluate
# pip install jiwer

In [None]:
import os
import pandas as pd
import numpy as np
import evaluate


### Data Augmentation:

Since the fine-tuning was done on Kaggle, using the GPU P100 accelerator, we had to adapt the data to that environment: the audio files were uploaded to Kaggle as zipped folders, and the Audio_WAV column was modified to hold the Kaggle paths of the audio files rather than the local paths.

In [8]:
train = pd.read_csv("/kaggle/input/train-test-checkps/training-checkpoint-semifin.csv")
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11493 entries, 0 to 11492
Data columns (total 7 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   Transcriptions    11493 non-null  object 
 1   Audio_URLs        11493 non-null  object 
 2   TA_ID             11493 non-null  int64  
 3   Audio_WAV         11493 non-null  object 
 4   Audio_Lengths     11493 non-null  float64
 5   Silenced_Paths    1149 non-null   object 
 6   All_Audios_Paths  11493 non-null  object 
dtypes: float64(1), int64(1), object(5)
memory usage: 628.6+ KB


In [9]:
test = pd.read_csv("/kaggle/input/train-test-checkps/test-checkpoint-semifin.csv")
test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2880 entries, 0 to 2879
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   Transcriptions  2880 non-null   object 
 1   Audio_URLs      2880 non-null   object 
 2   TA_ID           2880 non-null   int64  
 3   Audio_WAV       2880 non-null   object 
 4   Audio_Lengths   2880 non-null   float64
dtypes: float64(1), int64(1), object(3)
memory usage: 112.6+ KB


Paths of the files' directories:

- train_dir is the directory containing training files
- test_dir is the directory containing the testing files
- sil_dir is the directory containing the audio files that were augmented by adding seconds of silence. This step was already done in the Data-Augmentation-Done.ipynb notebook, which saved the augmented audio files to the local directory "Added-Silence-To--Audios". That directory was then uploaded to Kaggle so that the augmented files' paths could be added to the dataset.

In [10]:
train_dir = "/kaggle/input/train-test-folders/train-wav/train-wav"
test_dir = "/kaggle/input/train-test-folders/test-wav/test-wav"
sil_dir = "/kaggle/input/silenced-audios/Added-Silence-To--Audios" 

In [11]:
len(os.listdir(train_dir))

11493

In [12]:
for file in os.listdir(train_dir):
    id = os.path.splitext(file)[0]
    id = int(id)

    audio_path = os.path.join(train_dir, file)
    train.loc[train['TA_ID'] == id, 'Audio_WAV'] = audio_path
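
The per-file loop above scans the whole DataFrame once for every file. A minimal sketch of an alternative, assuming every file is named `<TA_ID>.wav` as the loop implies: build an ID-to-path dictionary once and look each `TA_ID` up in it. The helper name `build_path_map`, the sample file names, and the `/kaggle/input/demo` directory are illustrative and not part of the original notebook.

```python
import os
import posixpath


def build_path_map(filenames, base_dir):
    """Map the integer ID in each file name (e.g. '93810.wav') to its full path."""
    path_map = {}
    for name in filenames:
        stem, _ext = os.path.splitext(name)
        if stem.isdigit():  # skip anything that is not named <TA_ID>.<ext>
            path_map[int(stem)] = posixpath.join(base_dir, name)
    return path_map


# Illustrative file names; on Kaggle this list would come from os.listdir(train_dir)
path_map = build_path_map(["93810.wav", "24592.wav", "notes.txt"], "/kaggle/input/demo")
print(path_map[93810])
```

With pandas, `train['TA_ID'].map(path_map)` would then fill the whole column in one pass. Unlike the loop, `map` yields NaN for any ID that has no matching file, so the result may need a check before it overwrites `Audio_WAV`.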

In [13]:
train.head(3)

Unnamed: 0,Transcriptions,Audio_URLs,TA_ID,Audio_WAV,Audio_Lengths,Silenced_Paths,All_Audios_Paths
0,قالت باختصار شديد والدكم قرر ترك عمله في الخ...,https://www.ireadarabic.com/uploads/slides/52...,93810,/kaggle/input/train-test-folders/train-wav/tra...,12.120816,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...
1,سألت سحر أيها الفيل ترضع صغارك ولكنك تستعمل أ...,https://www.ireadarabic.com/uploads/slides/16...,24592,/kaggle/input/train-test-folders/train-wav/tra...,3.683265,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\2...
2,قرأنا الرسالتين فإذا هما تحتويان على الكلام ن...,https://www.ireadarabic.com/uploads/slides/26...,13278,/kaggle/input/train-test-folders/train-wav/tra...,19.226122,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\1...


In [14]:
def extract_id(filename):
    return int(filename.split('_')[1])

In [15]:
silenced_paths_dict = {}

for filename in os.listdir(sil_dir):
    id_from_filename = extract_id(filename)
    silenced_paths_dict[id_from_filename] = os.path.join(sil_dir, filename)

train['Silence_Paths'] = train['TA_ID'].map(silenced_paths_dict)
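
`extract_id` assumes that every file name carries the ID as its second underscore-separated field, and it raises an exception otherwise. A small sketch of a more forgiving variant that returns `None` for names that do not fit; the name `safe_extract_id` and the example file names are hypothetical, since the actual naming scheme inside `sil_dir` is not shown here.

```python
def safe_extract_id(filename):
    """Return the integer in the second '_'-separated field, or None if absent."""
    parts = filename.split("_")
    if len(parts) < 2:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


# Hypothetical file names, for illustration only
for name in ["silence_93810_.wav", "README.txt", "silence_abc_.wav"]:
    print(name, "->", safe_extract_id(name))
```

Entries that come back as `None` can simply be skipped when the dictionary is filled, so one stray file in the directory does not stop the cell.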


In [16]:
train

Unnamed: 0,Transcriptions,Audio_URLs,TA_ID,Audio_WAV,Audio_Lengths,Silenced_Paths,All_Audios_Paths,Silence_Paths
0,قالت باختصار شديد والدكم قرر ترك عمله في الخ...,https://www.ireadarabic.com/uploads/slides/52...,93810,/kaggle/input/train-test-folders/train-wav/tra...,12.120816,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...,/kaggle/input/silenced-audios/Added-Silence-To...
1,سألت سحر أيها الفيل ترضع صغارك ولكنك تستعمل أ...,https://www.ireadarabic.com/uploads/slides/16...,24592,/kaggle/input/train-test-folders/train-wav/tra...,3.683265,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\2...,
2,قرأنا الرسالتين فإذا هما تحتويان على الكلام ن...,https://www.ireadarabic.com/uploads/slides/26...,13278,/kaggle/input/train-test-folders/train-wav/tra...,19.226122,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\1...,
3,وأخيرا ارتداها الأسد جميعها ثم جلس على صخرة م...,https://www.ireadarabic.com/uploads/slides/34...,46048,/kaggle/input/train-test-folders/train-wav/tra...,11.781224,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\4...,
4,استيقظ عمر صباح يوم الجمعة على رائحة شهية يحب...,https://www.ireadarabic.com/uploads/slides/72...,42098,/kaggle/input/train-test-folders/train-wav/tra...,24.120771,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\4...,
...,...,...,...,...,...,...,...,...
11488,كنت أرتجف وأسناني bتصطكb من شدة البرد وأنا أت...,https://www.ireadarabic.com/uploads/slides/28...,61552,/kaggle/input/train-test-folders/train-wav/tra...,12.669388,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\6...,
11489,وفي نهاية الحفل تجمع الطلاب والتقطت لهم صورة ...,https://www.ireadarabic.com/uploads/slides/12...,16760,/kaggle/input/train-test-folders/train-wav/tra...,10.248000,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\1...,
11490,قالت ريما وهي تحاول الإمساك به من الأفضل لك ...,https://www.ireadarabic.com/uploads/slides/29...,85740,/kaggle/input/train-test-folders/train-wav/tra...,5.773061,,D:\GP1\.venv\ipynb_Files\ASR_Data2\train-wav\8...,
11491,قالت أنا لم أنس الشرط لكنك مت ...,https://www.ireadarabic.com/uploads/slides/33...,66243,/kaggle/input/train-test-folders/train-wav/tra...,3.604898,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...,D:\GP1\.venv\ipynb_Files\ASR_Data2\Added-Silen...,/kaggle/input/silenced-audios/Added-Silence-To...


In [18]:
train.drop(['All_Audios_Paths','Silenced_Paths'],axis=1,inplace=True)

In [19]:
train['All_Audio_Paths'] = np.where(train['Silence_Paths'].isna(), train['Audio_WAV'], train['Silence_Paths'])
train

Unnamed: 0,Transcriptions,Audio_URLs,TA_ID,Audio_WAV,Audio_Lengths,Silence_Paths,All_Audio_Paths
0,قالت باختصار شديد والدكم قرر ترك عمله في الخ...,https://www.ireadarabic.com/uploads/slides/52...,93810,/kaggle/input/train-test-folders/train-wav/tra...,12.120816,/kaggle/input/silenced-audios/Added-Silence-To...,/kaggle/input/silenced-audios/Added-Silence-To...
1,سألت سحر أيها الفيل ترضع صغارك ولكنك تستعمل أ...,https://www.ireadarabic.com/uploads/slides/16...,24592,/kaggle/input/train-test-folders/train-wav/tra...,3.683265,,/kaggle/input/train-test-folders/train-wav/tra...
2,قرأنا الرسالتين فإذا هما تحتويان على الكلام ن...,https://www.ireadarabic.com/uploads/slides/26...,13278,/kaggle/input/train-test-folders/train-wav/tra...,19.226122,,/kaggle/input/train-test-folders/train-wav/tra...
3,وأخيرا ارتداها الأسد جميعها ثم جلس على صخرة م...,https://www.ireadarabic.com/uploads/slides/34...,46048,/kaggle/input/train-test-folders/train-wav/tra...,11.781224,,/kaggle/input/train-test-folders/train-wav/tra...
4,استيقظ عمر صباح يوم الجمعة على رائحة شهية يحب...,https://www.ireadarabic.com/uploads/slides/72...,42098,/kaggle/input/train-test-folders/train-wav/tra...,24.120771,,/kaggle/input/train-test-folders/train-wav/tra...
...,...,...,...,...,...,...,...
11488,كنت أرتجف وأسناني bتصطكb من شدة البرد وأنا أت...,https://www.ireadarabic.com/uploads/slides/28...,61552,/kaggle/input/train-test-folders/train-wav/tra...,12.669388,,/kaggle/input/train-test-folders/train-wav/tra...
11489,وفي نهاية الحفل تجمع الطلاب والتقطت لهم صورة ...,https://www.ireadarabic.com/uploads/slides/12...,16760,/kaggle/input/train-test-folders/train-wav/tra...,10.248000,,/kaggle/input/train-test-folders/train-wav/tra...
11490,قالت ريما وهي تحاول الإمساك به من الأفضل لك ...,https://www.ireadarabic.com/uploads/slides/29...,85740,/kaggle/input/train-test-folders/train-wav/tra...,5.773061,,/kaggle/input/train-test-folders/train-wav/tra...
11491,قالت أنا لم أنس الشرط لكنك مت ...,https://www.ireadarabic.com/uploads/slides/33...,66243,/kaggle/input/train-test-folders/train-wav/tra...,3.604898,/kaggle/input/silenced-audios/Added-Silence-To...,/kaggle/input/silenced-audios/Added-Silence-To...


In [20]:
for file in os.listdir(test_dir):
    id = os.path.splitext(file)[0]
    id = int(id)

    audio_path = os.path.join(test_dir, file)
    test.loc[test['TA_ID'] == id, 'Audio_WAV'] = audio_path

In [21]:
test

Unnamed: 0,Transcriptions,Audio_URLs,TA_ID,Audio_WAV,Audio_Lengths
0,كبرت الشتول وأعطتbrخضارا طيبة ...,https://www.ireadarabic.com/uploads/slides/93...,2824,/kaggle/input/train-test-folders/test-wav/test...,3.996735
1,هذا التفاح أحمر ...,https://www.ireadarabic.com/uploads/slides/12...,1409,/kaggle/input/train-test-folders/test-wav/test...,1.593469
2,في المطبعة كبيرة تطبع ...,https://www.ireadarabic.com/uploads/slides/20...,5506,/kaggle/input/train-test-folders/test-wav/test...,4.649796
3,دققت النظر في الصور في هذه الصورة ماما تسبح م...,https://www.ireadarabic.com/uploads/slides/82...,5012,/kaggle/input/train-test-folders/test-wav/test...,15.792000
4,قالت سمكة السلمون نعم أنا سمكة ولي ذيل وزعانف...,https://www.ireadarabic.com/uploads/slides/15...,4657,/kaggle/input/train-test-folders/test-wav/test...,7.183673
...,...,...,...,...,...
2875,النملة والفيل ...,https://www.ireadarabic.com/uploads/slides/27...,6704,/kaggle/input/train-test-folders/test-wav/test...,1.253878
2876,في الرحلة التالية أرسل أبو أسعد على ظهر الحما...,https://www.ireadarabic.com/uploads/slides/46...,4410,/kaggle/input/train-test-folders/test-wav/test...,4.957664
2877,حكى للناس ما شاهده من غرائبbrفتعجبوا لأنه لم ...,https://www.ireadarabic.com/uploads/slides/53...,8790,/kaggle/input/train-test-folders/test-wav/test...,5.247664
2878,لن أتابع الركض ...,https://www.ireadarabic.com/uploads/slides/13...,2222,/kaggle/input/train-test-folders/test-wav/test...,1.384490


The next cells drop the rows whose transcriptions contain English (Latin) letters, which reduces the time needed for fine-tuning.

In [22]:
import re 

def contains_english(text):
    # Regular expression to match English letters
    english_pattern = re.compile(r'[a-zA-Z]')
    return bool(english_pattern.search(text))


train['contains_english'] = train['Transcriptions'].apply(contains_english)

# Filter the DataFrame to show rows where English letters are detected
english_letters_train = train[train['contains_english']]

# Print the rows where English letters are detected
print(english_letters_train)

test['contains_english'] = test['Transcriptions'].apply(contains_english)

# Filter the DataFrame to show rows where English letters are detected
english_letters_test = test[test['contains_english']]



                                          Transcriptions  \
1       سألت سحر أيها الفيل ترضع صغارك ولكنك تستعمل أ...   
2       قرأنا الرسالتين فإذا هما تحتويان على الكلام ن...   
3       وأخيرا ارتداها الأسد جميعها ثم جلس على صخرة م...   
4       استيقظ عمر صباح يوم الجمعة على رائحة شهية يحب...   
5       كان هذا الجار رجلا له رأي سديد وقلب من ذهب ول...   
...                                                  ...   
11479    قالت الآنسة صفاء وهزت ميرة رأسها BR موافقة ل...   
11482   لكنها كانت تراجع إجاباتها وتسألbr نفسها هل أح...   
11486   زنابق الماء العملاقة لها جذورbr تمتد تحت الما...   
11488   كنت أرتجف وأسناني bتصطكb من شدة البرد وأنا أت...   
11489   وفي نهاية الحفل تجمع الطلاب والتقطت لهم صورة ...   

                                              Audio_URLs  TA_ID  \
1       https://www.ireadarabic.com/uploads/slides/16...  24592   
2       https://www.ireadarabic.com/uploads/slides/26...  13278   
3       https://www.ireadarabic.com/uploads/slides/34...  46048   
4       htt

In [23]:
train = train[~train['contains_english']]

train = train.drop(columns=['contains_english'])

test = test[~test['contains_english']]

test = test.drop(columns=['contains_english'])
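
To illustrate what the filter matches, here is a small stand-alone sketch that applies the same `[a-zA-Z]` pattern to a few strings. The samples are shortened from transcriptions shown in the tables above; note that stray markup such as `br` inside the text is enough to flag a row.

```python
import re

ENGLISH_PATTERN = re.compile(r"[a-zA-Z]")


def contains_english(text):
    """True if the text holds at least one ASCII Latin letter."""
    return bool(ENGLISH_PATTERN.search(text))


samples = [
    "هذا التفاح أحمر",           # Arabic only
    "كبرت الشتول وأعطتbrخضارا",   # stray 'br' inside the text
    "لن أتابع الركض",            # Arabic only
]

# Keep only the samples without Latin letters, as the notebook does with ~mask
kept = [s for s in samples if not contains_english(s)]
print(len(kept), "of", len(samples), "samples kept")
```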

In [24]:
train.drop(['Silence_Paths','Audio_WAV'],axis=1,inplace=True)
train.rename(columns={'All_Audio_Paths': 'Audio_WAV'}, inplace=True)
train.head(3)

Unnamed: 0,Transcriptions,Audio_URLs,TA_ID,Audio_Lengths,Audio_WAV
0,قالت باختصار شديد والدكم قرر ترك عمله في الخ...,https://www.ireadarabic.com/uploads/slides/52...,93810,12.120816,/kaggle/input/silenced-audios/Added-Silence-To...
7,تمر ...,https://www.ireadarabic.com/uploads/slides/54...,23434,0.626939,/kaggle/input/train-test-folders/train-wav/tra...
8,تزلج على الجليد ...,https://www.ireadarabic.com/uploads/slides/33...,98696,1.752,/kaggle/input/train-test-folders/train-wav/tra...


In [3]:
train = pd.read_csv('/kaggle/input/fine-tune-data/fine-tune-ftrain.csv')
test = pd.read_csv('/kaggle/input/fine-tune-data/fine-tune-ftest.csv')


In [None]:
train.to_csv('/kaggle/working/fine-tune-ftrain.csv',index=False)
test.to_csv('/kaggle/working/fine-tune-ftest.csv',index=False)

This step finds the longest transcription in the training set, because its length is used later to set the tokenizer's max_length parameter.

In [4]:
import nltk

# Tokenize sequences and find the sequence with the longest length
max_length = 0
longest_sequence = None

for sequence in train['Transcriptions']:
    tokens = nltk.word_tokenize(sequence)
    if len(tokens) > max_length:
        max_length = len(tokens)
        longest_sequence = sequence

print("Longest sequence:", longest_sequence)
print("Length of longest sequence:", max_length)


Longest sequence:  سناء منذ انتصف النهار تتجول مع والدتها في القرية يقنعن النساء بالمشاركة في الانتخابات وفي مثالية نادرة كانت سناء تطلب منهن الانتخاب فقط دون ذكر شخص بعينه ورغم ما يعرفه الجميع من حرب هلال أبي الدهب على سناء ووالدها فقد احترمت قواعد اللعبة ولم تطالب أحدا بعدم انتخاب هلال أو التصويت لخالد كانت تقول للنساء أنتن نصف المجتمع وأصواتكن ستغير الدنيا الكثيرات اقتنعن والكثيرات أيضا ترددن إذ كيف سيذهبن إلى اللجان وهنا جاء دور والد سناء الذي استأجر هذه الحافلة لتوفير وسيلة نقل إلى اللجان كان يمر على المنزل فتنزل سناء ووالدتها لدعوة النساء وما إن توافق إحداهن حتى تطلب منها سناء ارتداء ملابس الخروج واللحاق بها حيث الحافلة تنتظر عند الباب  
Length of longest sequence: 112
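
The same search can be written more compactly with `max` and a key function. The sketch below uses plain whitespace splitting instead of `nltk.word_tokenize`, an assumption made only to keep it dependency-free; the two can give different counts when punctuation is present. The function name `longest_by_words` is illustrative.

```python
def longest_by_words(sequences):
    """Return (sequence, word_count) for the sequence with the most whitespace-separated words."""
    longest = max(sequences, key=lambda s: len(s.split()))
    return longest, len(longest.split())


# Short transcriptions taken from the tables above
sequences = ["تمر", "النملة والفيل", "تزلج على الجليد"]
sequence, length = longest_by_words(sequences)
print("Longest sequence:", sequence)
print("Length of longest sequence:", length)
```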


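Connect to Hugging Face using a write token. The token is read from an environment variable rather than written into the notebook, so that it is not exposed when the notebook is shared.

In [26]:
from huggingface_hub import login

# HF_TOKEN is an assumed environment variable name holding the write token
login(token=os.environ["HF_TOKEN"])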

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [27]:
test.shape

(2018, 5)

In [28]:
train.shape

(7861, 5)

Loading the dataset:

In [29]:
from datasets import Features, Value, Audio, load_dataset

sd = load_dataset(
    'csv', data_files={
        'train': 'fine-tune-ftrain.csv', 
        'test': 'fine-tune-ftest.csv',
    }
)

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Explicitly defining the features contained in both the training and test splits:

In [30]:
features = Features(
    {
        "Transcriptions": Value("string"), 
        "Audio_URLs": Value('string'),
        "TA_ID": Value("int64"),
        "Audio_Lengths": Value("float64"),
        "Audio_WAV": Audio(sampling_rate=16000)
    }
)

In [31]:
sd = sd.cast(features)
sd

Casting the dataset:   0%|          | 0/7861 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2018 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Transcriptions', 'Audio_URLs', 'TA_ID', 'Audio_Lengths', 'Audio_WAV'],
        num_rows: 7861
    })
    test: Dataset({
        features: ['Transcriptions', 'Audio_URLs', 'TA_ID', 'Audio_Lengths', 'Audio_WAV'],
        num_rows: 2018
    })
})

In [32]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

In [33]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Arabic", task="transcribe")


tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [34]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Arabic", task="transcribe")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [35]:
from datasets import Audio

sd = sd.cast_column("Audio_WAV", Audio(sampling_rate=16000))

In [36]:
sd['train'][0]

{'Transcriptions': ' قالت باختصار شديد  والدكم قرر ترك عمله في الخارج والعودة للعيش معنا هنا لقد رتب الأمور مع خالكم سوف يأتي غدا صباحا                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            ',
 'Audio_URLs': ' https://www.ireadarabic.com/uploads/slides/528/8449895552ae0361f230c080dc19096f.mp3',
 'TA_ID': 93810,
 'Audio_Lengths': 12.120816326530612,
 'Audio_WAV': {'path

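max_length is set to 112 because the earlier step tokenized the "Transcriptions" column of the training set with NLTK and found that the longest transcription contains 112 word tokens. Note that this is a word count, whereas the Whisper tokenizer counts subword tokens and special tokens, so with truncation=True the longest transcriptions may be cut off at 112 tokens.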

In [37]:
def prepare_dataset(batch):
    # load the audio; the Audio feature decodes it and resamples it to 16 kHz
    audio = batch["Audio_WAV"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["Transcriptions"],padding="max_length",truncation=True,max_length=112, add_special_tokens=True).input_ids
    return batch
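
The `padding="max_length"` and `truncation=True` arguments give every label sequence the same length. The sketch below mimics that behaviour on plain lists of token IDs; it is a simplified illustration rather than the tokenizer's implementation, and the IDs and the `PAD_ID` value are made up.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Cut token_ids to max_length, or right-pad with pad_id up to max_length."""
    clipped = token_ids[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


PAD_ID = 0  # made-up padding ID for the illustration

short = pad_or_truncate([11, 12, 13], max_length=5, pad_id=PAD_ID)
long = pad_or_truncate([11, 12, 13, 14, 15, 16, 17], max_length=5, pad_id=PAD_ID)
print(short)  # padded up to five entries
print(long)   # cut down to five entries
```

Anything beyond `max_length` is dropped, which is why the limit has to cover the longest transcription as counted by the tokenizer itself.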


In [38]:
sd = sd.map(prepare_dataset, remove_columns=sd.column_names["train"])


Map:   0%|          | 0/7861 [00:00<?, ? examples/s]


Map:   0%|          | 0/2018 [00:00<?, ? examples/s]

In [22]:
import transformers
transformers.__version__

'4.39.3'

In [39]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")


config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.87k [00:00<?, ?B/s]

In [40]:
model.generation_config.language = "arabic"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

In [41]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

from torch.nn.utils.rnn import pad_sequence

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length


        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")


        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
      

        # if a BOS token was prepended in the previous tokenization step,
        # cut it here because it is added again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
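
The masking step in the collator can be pictured on plain Python lists: wherever the attention mask is 0 the label becomes -100, the value that PyTorch's cross-entropy loss ignores by default. This is an illustrative sketch with made-up token IDs; the function name `mask_padding` is not part of the notebook, and it works on one sequence at a time, whereas the collator checks the BOS token across the whole batch.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_padding(label_ids, attention_mask, bos_id=None):
    """Replace padded positions with IGNORE_INDEX and drop a leading BOS token if present."""
    masked = [tok if keep == 1 else IGNORE_INDEX
              for tok, keep in zip(label_ids, attention_mask)]
    if bos_id is not None and masked and masked[0] == bos_id:
        masked = masked[1:]
    return masked


# Made-up IDs: 1 stands in for the BOS token, 0 for padding
labels = mask_padding([1, 42, 43, 0, 0], [1, 1, 1, 0, 0], bos_id=1)
print(labels)
```

Because the labels here were already padded to a fixed length in `prepare_dataset`, it may be worth checking that the attention mask produced by `tokenizer.pad` still marks those positions as padding; if it does not, the padded positions would count toward the loss.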


In [42]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


In [45]:
metric = evaluate.load("wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [46]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
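
For reference, word error rate is the word-level edit distance between prediction and reference, divided by the number of reference words. The sketch below is a plain-Python illustration of that definition for a single pair with a non-empty reference; it is not the `evaluate` implementation, which aggregates over all pairs passed to it.

```python
def word_error_rate(reference, prediction):
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = reference.split(), prediction.split()
    # previous[j] = edit distance between the reference words seen so far and hyp[:j]
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1] / len(ref)


# One substituted word out of three reference words
print(100 * word_error_rate("هذا التفاح أحمر", "هذا التفاح أخضر"))
```

As in `compute_metrics`, the ratio is multiplied by 100 to report a percentage.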


In [47]:

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="dana2002/latest-finetuned",  # checkpoints are written here; the name also determines the HF repository
    per_device_train_batch_size=16,  # each training step processes a batch of 16 audio files per device
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=3000,  # training stops after 3000 steps, which means several passes over the dataset
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,  # a model checkpoint is saved every 1000 steps
    eval_steps=1000,  # validation loss and WER are computed every 1000 steps
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,  # upload the model checkpoints to the HF repository
)
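
Because training is bounded by `max_steps` rather than by epochs, the number of passes over the data follows from the dataset size and the batch size. A short sketch of that arithmetic, assuming a single GPU so that the effective batch size equals `per_device_train_batch_size` times `gradient_accumulation_steps`:

```python
import math

num_train_samples = 7861   # rows in the training split
batch_size = 16            # per_device_train_batch_size, single device assumed
grad_accum = 1             # gradient_accumulation_steps
max_steps = 3000

# Number of optimizer steps needed to see every training sample once
steps_per_epoch = math.ceil(num_train_samples / (batch_size * grad_accum))
epochs = max_steps / steps_per_epoch

print("steps per epoch:", steps_per_epoch)
print("approximate epochs:", round(epochs, 1))
```

The result, roughly six passes over the training set, agrees with the epoch count that the training output reports.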


In [48]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=sd["train"],
    eval_dataset=sd["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)



In [49]:
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
1000,0.0309,0.048095,12.158039
2000,0.0054,0.048523,10.280809
3000,0.0013,0.051961,10.027409
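
With `metric_for_best_model="wer"` and `greater_is_better=False`, the checkpoint restored at the end of training is the one with the lowest WER. A small sketch of that selection rule applied to the three evaluation rows above; it only mirrors the rule and is not the Trainer's internal code.

```python
# (step, training loss, validation loss, WER) copied from the table above
evaluations = [
    (1000, 0.0309, 0.048095, 12.158039),
    (2000, 0.0054, 0.048523, 10.280809),
    (3000, 0.0013, 0.051961, 10.027409),
]

# Lowest WER wins, matching greater_is_better=False
best_step, _, _, best_wer = min(evaluations, key=lambda row: row[3])
# For comparison: the step that a validation-loss criterion would pick
lowest_loss_step = min(evaluations, key=lambda row: row[2])[0]

print("best by WER:", best_step, best_wer)
print("best by validation loss:", lowest_loss_step)
```

The two criteria disagree here: validation loss is lowest at step 1000, while WER keeps improving through step 3000.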


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}

TrainOutput(global_step=3000, training_loss=0.03287332394346595, metrics={'train_runtime': 28838.6034, 'train_samples_per_second': 1.664, 'train_steps_per_second': 0.104, 'total_flos': 1.383305257893888e+19, 'train_loss': 0.03287332394346595, 'epoch': 6.1})

In [50]:
kwargs = {
    "dataset_tags": "dana2002/fine-tuning-code",
    "dataset": "whisper-finetune", 
    "dataset_args": "config: ar, split: test",
    "language": "ar",
    "model_name": "Whisper Small AR LTMM",  
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

In [54]:
processor.push_to_hub('dana2002/laest-tokenizer')

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/dana2002/laest-tokenizer/commit/4de363e72f5298bfbd7d810c5b68cddd7f03f702', commit_message='Upload processor', commit_description='', oid='4de363e72f5298bfbd7d810c5b68cddd7f03f702', pr_url=None, pr_revision=None, pr_num=None)

In [55]:
trainer.push_to_hub('dana2002/latest-finetuned') # an extra push to make sure the final checkpoint is uploaded; the string is used as the commit message

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


CommitInfo(commit_url='https://huggingface.co/dana2002/latest-finetuned/commit/d741d88061d17d54e418d6bfa1e3640a92fea5b5', commit_message='dana2002/latest-finetuned', commit_description='', oid='d741d88061d17d54e418d6bfa1e3640a92fea5b5', pr_url=None, pr_revision=None, pr_num=None)