## Tutorial

https://huggingface.co/blog/fine-tune-wav2vec2-english

In [1]:
import torch

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version torch was built with:", torch.version.cuda)

Torch version: 2.10.0+cu128
CUDA available: True
CUDA version torch was built with: 12.8


## Config HF Cache

In [255]:
import os

class cfg: 

    # to store HF pre-trained models weights and configs
    HF_CACHE_ROOT = os.path.join("..", "..", "..",
                                 "data",
                                 "05_cache", 
                                 "HF"
                                )



    # to store HF pre-trained models weights and configs
    HF_FINETUNE_ROOT = os.path.join("..", "..", "..",
                                    "data",
                                    "06_fine_tune",
                                    "01_tuto",
                                    "01_hug_llm",
                                    "ch03",
                                   )


    HF_FINETUNED_MODEL_SAVE_ROOT = os.path.join(HF_CACHE_ROOT,
                                                "FR_finetuned",
                                               )


## HF Cache management

https://huggingface.co/docs/datasets/en/cache

In [3]:
print("HF_HOME:", os.environ.get("HF_HOME"))
os.environ["HF_HOME"] = cfg.HF_CACHE_ROOT
print("HF_HOME:", os.environ.get("HF_HOME"))

HF_HOME: None
HF_HOME: ../../../data/05_cache/HF


In [4]:
print("HF_HUB_CACHE:", os.environ.get("HF_HUB_CACHE"))
os.environ["HF_HUB_CACHE"] = cfg.HF_CACHE_ROOT
print("HF_HUB_CACHE:", os.environ.get("HF_HUB_CACHE"))

HF_HUB_CACHE: None
HF_HUB_CACHE: ../../../data/05_cache/HF


## Import libraries

In [225]:
import io
import sys
import random

import numpy as np

#_________
import torch
import torchaudio

#__________
# HF stack
import transformers
from transformers import pipeline

import datasets 
from datasets import load_dataset 
# from datasets import load_metric # deprecated, Now I'm using evaluate
from datasets import Audio as Audio_ds # the instances generated from load_dataset use under the hood Audio to decode (Audio use Torchcoced).



import evaluate

#_________
import pandas as pd 

#_________
from dotenv import load_dotenv

from IPython.display import Audio
from IPython.display import display, HTML

In [6]:
transformers.__version__

'5.1.0'

In [7]:
torch.__version__

'2.10.0+cu128'

In [8]:
datasets.__version__

'4.5.0'

## Utilities

### disable_torchcodec()

In [9]:
import sys
from contextlib import contextmanager

@contextmanager  
def disable_torchcodec():
    """
    Description:
    ------------
        - torchcodec is required to decode (read/load) files
        - torchcodec follow a compatibility with 
            torch versions "https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec"
        - torchcode depend on ffmpeg [version 4 to 8]. I install ffmpeg inside 
            of the docker apt-get install -y ffmpeg
        - HF datasets.load_dataset() create an audio instance which expect 
            a version of torchcodec that has "AudioDecoder" class
            - However, torchcodec is only compatible to certain torch version,
                and torch version need compatibility to CUDA version. 
                Currently my CUDA version is 12.2 which allows me to upgrade 
                my torch to v2.5.1 which only allows me to update to torchcodec v0.1.
                which does not have the class "AudioDecoder" expected for 
                dataaudio instances created with datasets.load_dataset()
            - CUDA 12.2 → Torch 2.5.1 → Torchcodec 0.1 (no AudioDecoder) ← HF datasets (expects AudioDecoder)
            - Verification of classes in current torchcodec version
                    import torchcodec.decoders
                    print(dir(torchcodec.decoders))

        This function helps to:
        - Safely run code that would fail due to torchcodec import issues.
        - Temporarily disable torchcodec module.
    
    """
    # Create a real dummy module
    class DummyTorchCodec:
        class decoders:
            AudioDecoder = type('AudioDecoder', (), {})  # Empty class

    # out: <module 'torchcodec' from '/usr/local/lib/python3.10/dist-packages/torchcodec/__init__.py'>
    original_torchcodec = sys.modules.get("torchcodec")

    # out: <__main__.DummyTorchCodec object at 0x7de7e48cfac0>
    sys.modules["torchcodec"] = DummyTorchCodec() # simplint adding `sys.modules["torchcodec"] = None` generate ModuleNotFound error.

    #________________________________________________
    # try, except, finnally expected extructure form `with` operator in python 
    #  consider that with operator has a __enter__ and __exit__ method. 

    try:
        yield
    finally:
        if original_torchcodec is not None:
            sys.modules["torchcodec"] = original_torchcodec
        else:
            del sys.modules["torchcodec"]


## Service Token Authentication

In [10]:
# Verify token is loaded
load_dotenv()

HF_TOKEN_READ = os.getenv("07_FR_phone_TokenType_READ")
print(f"Token loaded: {'Yes' if HF_TOKEN_READ else 'No'}")

Token loaded: Yes


## Load data

In [11]:
dataset_name = 'kylelovesllms/timit_asr'
timit = load_dataset(dataset_name,
                     token=HF_TOKEN_READ
                    )
print(timit)

DatasetDict({
    train: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 3629
    })
    test: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 1340
    })
})


In [12]:
## Address

In [13]:
    # DatasetDict({
    #     train: Dataset({
    #         features: ['file', 'audio', 'text', 'phonetic_detail', 
    #                    'word_detail', 'dialect_region', 'sentence_type', 
    #                    'speaker_id', 'id'],
    #         num_rows: 4620
    #     })
    #     test: Dataset({
    #         features: ['file', 'audio', 'text', *'phonetic_detail', 
    #                    *'word_detail', *'dialect_region', *'sentence_type',
    #                    *'speaker_id', *'id'],
    #         num_rows: 1680
    #     })
    # })

    # timit = timit.remove_columns(["phonetic_detail", 
                                    # "word_detail", 
                                    # "dialect_region", 
                                    # "id", "sentence_type", "speaker_id"])



In [14]:
# '''
#     Only keeping 
#         - file 
#         - audio < 
#         - text <
# '''
# timit = timit.remove_columns(["phonetic_detail", 
#                               "word_detail", 
#                               "dialect_region", 
#                               "id", 
#                               "sentence_type", 
#                               "speaker_id"])
# timit

## Inspecting one sample audio

In [15]:
timit = timit.cast_column("audio", Audio_ds(decode=True)) # Lazy decoder

In [16]:
audio_obj = timit["train"][0]["audio"]
print(f"Type: {type(audio_obj)}")

Type: <class 'datasets.features._torchcodec.AudioDecoder'>


In [17]:
print(f"Docstring: {audio_obj.__doc__}")

Docstring: None


For Hugging Face's Audio feature with decode=True, you'll typically have:

- array: The actual audio data as a numpy array
- sampling_rate: The sampling rate of the audio
- path: The path to the audio file (if available)

In [18]:
audio_arr = timit["train"]["audio"][0]["array"]
audio_arr

array([-2.1362305e-04,  6.1035156e-05,  3.0517578e-05, ...,
       -3.0517578e-05, -9.1552734e-05, -6.1035156e-05],
      shape=(39936,), dtype=float32)

In [19]:
audio_sr = timit["train"]["audio"][0]["sampling_rate"]
audio_sr

16000

In [20]:
# audio_path = timit["train"]["audio"][0]["path"] # path is not available in this dataset


In [21]:
Audio(data=audio_arr, rate=audio_sr)

## Inspection function

In [22]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

### Proto

In [23]:
# x = -5
# assert x > 0, "Can't pick more elements than there are in the dataset."

In [24]:
x = 5
assert x > 0

### show_random_elements()

In [25]:
def show_random_elements(dataset, num_examples=10):
    '''
        Randomly selecting timit instances

        Pending
        -------
            Audio, pending to solve the issue with torchcodec

    '''
    
    
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    #==========================================================
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
# show_random_elements(timit["train"].remove_columns(["file", "audio"]))

In [26]:
show_random_elements(timit["train"]['text'])

Unnamed: 0,0
0,Count the number of teaspoons of soysauce that you add.
1,"While one element is announcing progress, another is delineating its problems."
2,We experience distress and frustration obtaining our degrees.
3,We've done our part.
4,She had your dark suit in greasy wash water all year.
5,Tim takes Sheila to see movies twice a week.
6,She had your dark suit in greasy wash water all year.
7,The mango and the papaya are in a bowl.
8,No more startling contrast to a system of sullen satellites could be imagined.
9,Don't ask me to carry an oily rag like that.


In [27]:
show_random_elements(timit["train"]['phonetic_detail'], num_examples=3)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51
0,"{'start': 0, 'stop': 2190, 'utterance': 'h#'}","{'start': 2190, 'stop': 2847, 'utterance': 'dh'}","{'start': 2847, 'stop': 3751, 'utterance': 'ix'}","{'start': 3751, 'stop': 5790, 'utterance': 'iy'}","{'start': 5790, 'stop': 6781, 'utterance': 's'}","{'start': 6781, 'stop': 7165, 'utterance': 'tcl'}","{'start': 7165, 'stop': 7897, 'utterance': 's'}","{'start': 7897, 'stop': 8400, 'utterance': 'ix'}","{'start': 8400, 'stop': 9590, 'utterance': 'n'}","{'start': 9590, 'stop': 9827, 'utterance': 'kcl'}","{'start': 9827, 'stop': 11318, 'utterance': 'k'}","{'start': 11318, 'stop': 14338, 'utterance': 'ow'}","{'start': 14338, 'stop': 15356, 'utterance': 's'}","{'start': 15356, 'stop': 15510, 'utterance': 'tcl'}","{'start': 15510, 'stop': 16280, 'utterance': 's'}","{'start': 16280, 'stop': 17280, 'utterance': 'ix'}","{'start': 17280, 'stop': 17978, 'utterance': 'z'}","{'start': 17978, 'stop': 18953, 'utterance': 'ix'}","{'start': 18953, 'stop': 20110, 'utterance': 'pcl'}","{'start': 20110, 'stop': 21197, 'utterance': 'p'}","{'start': 21197, 'stop': 21960, 'utterance': 'l'}","{'start': 21960, 'stop': 23403, 'utterance': 'ey'}","{'start': 23403, 'stop': 24147, 'utterance': 's'}","{'start': 24147, 'stop': 25359, 'utterance': 'f'}","{'start': 25359, 'stop': 26160, 'utterance': 'ax'}","{'start': 26160, 'stop': 27120, 'utterance': 'pcl'}","{'start': 27120, 'stop': 28114, 'utterance': 'p'}","{'start': 28114, 'stop': 28768, 'utterance': 'y'}","{'start': 28768, 'stop': 30280, 'utterance': 'ao'}","{'start': 30280, 'stop': 30820, 'utterance': 'pcl'}","{'start': 30820, 'stop': 32294, 'utterance': 'p'}","{'start': 32294, 'stop': 32956, 'utterance': 'l'}","{'start': 32956, 'stop': 34014, 'utterance': 'eh'}","{'start': 34014, 'stop': 35361, 'utterance': 'zh'}","{'start': 35361, 'stop': 36443, 'utterance': 'axr'}","{'start': 36443, 'stop': 37271, 'utterance': 'ix'}","{'start': 37271, 'stop': 37764, 'utterance': 'nx'}","{'start': 37764, 'stop': 38884, 'utterance': 'ix'}","{'start': 38884, 'stop': 39320, 'utterance': 'kcl'}","{'start': 39320, 'stop': 41290, 'utterance': 's'}","{'start': 41290, 'stop': 43729, 'utterance': 'ay'}","{'start': 43729, 'stop': 44428, 'utterance': 'tcl'}","{'start': 44428, 'stop': 45133, 'utterance': 'm'}","{'start': 45133, 'stop': 45896, 'utterance': 'en'}","{'start': 45896, 'stop': 46600, 'utterance': 'tcl'}","{'start': 46600, 'stop': 61280, 'utterance': 'h#'}",,,,,,
1,"{'start': 0, 'stop': 3462, 'utterance': 'h#'}","{'start': 3462, 'stop': 4469, 'utterance': 'ah'}","{'start': 4469, 'stop': 5347, 'utterance': 'y'}","{'start': 5347, 'stop': 6349, 'utterance': 'ux'}","{'start': 6349, 'stop': 7227, 'utterance': 'y'}","{'start': 7227, 'stop': 8563, 'utterance': 'ux'}","{'start': 8563, 'stop': 8981, 'utterance': 'dx'}","{'start': 8981, 'stop': 10527, 'utterance': 'el'}","{'start': 10527, 'stop': 12992, 'utterance': 'ay'}","{'start': 12992, 'stop': 13953, 'utterance': 'z'}","{'start': 13953, 'stop': 14830, 'utterance': 'ih'}","{'start': 14830, 'stop': 15460, 'utterance': 'ng'}","{'start': 15460, 'stop': 16434, 'utterance': 'v'}","{'start': 16434, 'stop': 17503, 'utterance': 'eh'}","{'start': 17503, 'stop': 18255, 'utterance': 'n'}","{'start': 18255, 'stop': 18478, 'utterance': 'd'}","{'start': 18478, 'stop': 19801, 'utterance': 'iy'}","{'start': 19801, 'stop': 20553, 'utterance': 'ng'}","{'start': 20553, 'stop': 21296, 'utterance': 'm'}","{'start': 21296, 'stop': 21751, 'utterance': 'ix'}","{'start': 21751, 'stop': 23431, 'utterance': 'sh'}","{'start': 23431, 'stop': 24856, 'utterance': 'iy'}","{'start': 24856, 'stop': 25805, 'utterance': 'n'}","{'start': 25805, 'stop': 26369, 'utterance': 'pcl'}","{'start': 26369, 'stop': 26978, 'utterance': 'p'}","{'start': 26978, 'stop': 27580, 'utterance': 'r'}","{'start': 27580, 'stop': 28783, 'utterance': 'ow'}","{'start': 28783, 'stop': 30899, 'utterance': 's'}","{'start': 30899, 'stop': 32459, 'utterance': 'iy'}","{'start': 32459, 'stop': 33503, 'utterance': 'dcl'}","{'start': 33503, 'stop': 34631, 'utterance': 'z'}","{'start': 34631, 'stop': 35236, 'utterance': 'tcl'}","{'start': 35236, 'stop': 35625, 'utterance': 't'}","{'start': 35625, 'stop': 36721, 'utterance': 'hv'}","{'start': 36721, 'stop': 37662, 'utterance': 'eh'}","{'start': 37662, 'stop': 38132, 'utterance': 'l'}","{'start': 38132, 'stop': 39156, 'utterance': 'pcl'}","{'start': 39156, 'stop': 39739, 'utterance': 'p'}","{'start': 39739, 'stop': 41548, 'utterance': 'ey'}","{'start': 41548, 'stop': 43587, 'utterance': 'f'}","{'start': 43587, 'stop': 44233, 'utterance': 'y'}","{'start': 44233, 'stop': 45524, 'utterance': 'axr'}","{'start': 45524, 'stop': 46474, 'utterance': 'pcl'}","{'start': 46474, 'stop': 47267, 'utterance': 'p'}","{'start': 47267, 'stop': 47739, 'utterance': 'r'}","{'start': 47739, 'stop': 48683, 'utterance': 'ow'}","{'start': 48683, 'stop': 49582, 'utterance': 'gcl'}","{'start': 49582, 'stop': 49845, 'utterance': 'g'}","{'start': 49845, 'stop': 50828, 'utterance': 'r'}","{'start': 50828, 'stop': 52796, 'utterance': 'ae'}","{'start': 52796, 'stop': 53462, 'utterance': 'm'}","{'start': 53462, 'stop': 55501, 'utterance': 'h#'}"
2,"{'start': 0, 'stop': 2577, 'utterance': 'h#'}","{'start': 2577, 'stop': 3007, 'utterance': 'g'}","{'start': 3007, 'stop': 3757, 'utterance': 'w'}","{'start': 3757, 'stop': 5438, 'utterance': 'eh'}","{'start': 5438, 'stop': 7018, 'utterance': 'n'}","{'start': 7018, 'stop': 7740, 'utterance': 'pcl'}","{'start': 7740, 'stop': 8990, 'utterance': 'p'}","{'start': 8990, 'stop': 9806, 'utterance': 'l'}","{'start': 9806, 'stop': 11440, 'utterance': 'ae'}","{'start': 11440, 'stop': 11892, 'utterance': 'n'}","{'start': 11892, 'stop': 12460, 'utterance': 'tcl'}","{'start': 12460, 'stop': 13139, 'utterance': 't'}","{'start': 13139, 'stop': 13909, 'utterance': 'ix'}","{'start': 13909, 'stop': 14755, 'utterance': 'dcl'}","{'start': 14755, 'stop': 15044, 'utterance': 'd'}","{'start': 15044, 'stop': 15940, 'utterance': 'gcl'}","{'start': 15940, 'stop': 16560, 'utterance': 'g'}","{'start': 16560, 'stop': 17253, 'utterance': 'r'}","{'start': 17253, 'stop': 19122, 'utterance': 'iy'}","{'start': 19122, 'stop': 20584, 'utterance': 'n'}","{'start': 20584, 'stop': 21219, 'utterance': 'bcl'}","{'start': 21219, 'stop': 21411, 'utterance': 'b'}","{'start': 21411, 'stop': 24647, 'utterance': 'iy'}","{'start': 24647, 'stop': 26219, 'utterance': 'n'}","{'start': 26219, 'stop': 27842, 'utterance': 's'}","{'start': 27842, 'stop': 28547, 'utterance': 'ix'}","{'start': 28547, 'stop': 29470, 'utterance': 'n'}","{'start': 29470, 'stop': 29971, 'utterance': 'hv'}","{'start': 29971, 'stop': 31482, 'utterance': 'er'}","{'start': 31482, 'stop': 32591, 'utterance': 'v'}","{'start': 32591, 'stop': 33951, 'utterance': 'eh'}","{'start': 33951, 'stop': 34374, 'utterance': 'tcl'}","{'start': 34374, 'stop': 35997, 'utterance': 'ch'}","{'start': 35997, 'stop': 36510, 'utterance': 'tcl'}","{'start': 36510, 'stop': 36716, 'utterance': 't'}","{'start': 36716, 'stop': 37000, 'utterance': 'ix'}","{'start': 37000, 'stop': 38081, 'utterance': 'bcl'}","{'start': 38081, 'stop': 38383, 'utterance': 'b'}","{'start': 38383, 'stop': 39800, 'utterance': 'el'}","{'start': 39800, 'stop': 40858, 'utterance': 'gcl'}","{'start': 40858, 'stop': 41346, 'utterance': 'g'}","{'start': 41346, 'stop': 43183, 'utterance': 'aa'}","{'start': 43183, 'stop': 44070, 'utterance': 'r'}","{'start': 44070, 'stop': 45115, 'utterance': 'dcl'}","{'start': 45115, 'stop': 46994, 'utterance': 'en'}","{'start': 46994, 'stop': 48480, 'utterance': 'h#'}",,,,,,


In [28]:
show_random_elements(timit["train"]['word_detail'], num_examples=3)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,"{'start': 2266, 'stop': 9037, 'utterance': 'according'}","{'start': 9037, 'stop': 10326, 'utterance': 'to'}","{'start': 10326, 'stop': 14161, 'utterance': 'my'}","{'start': 14161, 'stop': 27000, 'utterance': 'interpretation'}","{'start': 27000, 'stop': 28840, 'utterance': 'of'}","{'start': 28840, 'stop': 29800, 'utterance': 'the'}","{'start': 29800, 'stop': 40005, 'utterance': 'problem'}","{'start': 45960, 'stop': 49080, 'utterance': 'two'}","{'start': 49080, 'stop': 56399, 'utterance': 'lines'}","{'start': 57314, 'stop': 60360, 'utterance': 'must'}","{'start': 60360, 'stop': 63000, 'utterance': 'be'}","{'start': 63000, 'stop': 76040, 'utterance': 'perpendicular'}"
1,"{'start': 2270, 'stop': 3103, 'utterance': 'a'}","{'start': 3103, 'stop': 13534, 'utterance': 'connoisseur'}","{'start': 13534, 'stop': 16418, 'utterance': 'will'}","{'start': 16418, 'stop': 22613, 'utterance': 'enjoy'}","{'start': 22613, 'stop': 26778, 'utterance': 'this'}","{'start': 26778, 'stop': 36440, 'utterance': 'shellfish'}","{'start': 36440, 'stop': 41890, 'utterance': 'dish'}",,,,,
2,"{'start': 2120, 'stop': 6704, 'utterance': 'as'}","{'start': 4730, 'stop': 8154, 'utterance': 'you'}","{'start': 8154, 'stop': 10665, 'utterance': 'can'}","{'start': 10665, 'stop': 15560, 'utterance': 'count'}","{'start': 15560, 'stop': 18255, 'utterance': 'on'}","{'start': 18255, 'stop': 21540, 'utterance': 'me'}","{'start': 21540, 'stop': 23194, 'utterance': 'to'}","{'start': 23194, 'stop': 27070, 'utterance': 'do'}","{'start': 27070, 'stop': 29070, 'utterance': 'the'}","{'start': 29070, 'stop': 35025, 'utterance': 'same'}",,


In [29]:
timit["train"]['phonetic_detail'][0]

[{'start': 0, 'stop': 1960, 'utterance': 'h#'},
 {'start': 1960, 'stop': 2466, 'utterance': 'w'},
 {'start': 2466, 'stop': 3480, 'utterance': 'ix'},
 {'start': 3480, 'stop': 4000, 'utterance': 'dcl'},
 {'start': 4000, 'stop': 5960, 'utterance': 's'},
 {'start': 5960, 'stop': 7480, 'utterance': 'ah'},
 {'start': 7480, 'stop': 7880, 'utterance': 'tcl'},
 {'start': 7880, 'stop': 9400, 'utterance': 'ch'},
 {'start': 9400, 'stop': 9960, 'utterance': 'ix'},
 {'start': 9960, 'stop': 10680, 'utterance': 'n'},
 {'start': 10680, 'stop': 13480, 'utterance': 'ae'},
 {'start': 13480, 'stop': 15680, 'utterance': 'kcl'},
 {'start': 15680, 'stop': 15880, 'utterance': 't'},
 {'start': 15880, 'stop': 16920, 'utterance': 'ix'},
 {'start': 16920, 'stop': 18297, 'utterance': 'v'},
 {'start': 18297, 'stop': 18882, 'utterance': 'r'},
 {'start': 18882, 'stop': 19480, 'utterance': 'ix'},
 {'start': 19480, 'stop': 21723, 'utterance': 'f'},
 {'start': 21723, 'stop': 22516, 'utterance': 'y'},
 {'start': 22516, 's

In [30]:
timit["train"]['word_detail'][0]

[{'start': 1960, 'stop': 4000, 'utterance': 'would'},
 {'start': 4000, 'stop': 9400, 'utterance': 'such'},
 {'start': 9400, 'stop': 10680, 'utterance': 'an'},
 {'start': 10680, 'stop': 15880, 'utterance': 'act'},
 {'start': 15880, 'stop': 18297, 'utterance': 'of'},
 {'start': 18297, 'stop': 27080, 'utterance': 'refusal'},
 {'start': 27080, 'stop': 30120, 'utterance': 'be'},
 {'start': 30120, 'stop': 37720, 'utterance': 'useful'}]

In [31]:
timit["train"]['text'][0:2]

['Would such an act of refusal be useful?',
 "Don't ask me to carry an oily rag like that."]

## Remode special characters from text

### Proto

In [32]:
import re

In [39]:
chars_to_ignore_regex

'?'

In [34]:
txt_sample = timit["train"]['text'][0]
txt_sample

'Would such an act of refusal be useful?'

In [43]:
chars_to_ignore_regex = r"[\,\?\.\!\-\;\:\"]"
re.sub(chars_to_ignore_regex , '====', txt_sample).lower()

'would such an act of refusal be useful===='

In [None]:
re.sub(chars_to_ignore_regex, '', txt)

In [50]:
timit["train"]["text"][0:5]

['Would such an act of refusal be useful?',
 "Don't ask me to carry an oily rag like that.",
 'Butterscotch fudge goes well with vanilla ice cream.',
 'She had your dark suit in greasy wash water all year.',
 'I honor my mom.']

In [56]:
re.sub(chars_to_ignore_regex, '', timit["train"]["text"][0]).lower()

'would such an act of refusal be useful'

### remove_special_characters()

In [59]:
import re
chars_to_ignore_regex = r"[\,\?\.\!\-\;\:\"]"


#https://huggingface.co/docs/datasets/v4.5.0/en/package_reference/main_classes#datasets.DatasetDict.map
def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

timit = timit.map(remove_special_characters,
                  # input_columns=["text"], # dont required because function already specify the column text
                  # batched=False,
                  # with_indices=False,
                 )

Map:   0%|          | 0/3629 [00:00<?, ? examples/s]

Map:   0%|          | 0/1340 [00:00<?, ? examples/s]

In [58]:
timit["train"]["text"][0:5]

['would such an act of refusal be useful',
 "don't ask me to carry an oily rag like that",
 'butterscotch fudge goes well with vanilla ice cream',
 'she had your dark suit in greasy wash water all year',
 'i honor my mom']

## Extracting vocab

In CTC, it is common to classify speech chunks into letters, so we will do the same here. Let's extract all distinct letters of the training and test data and build our vocabulary from this set of letters.

We write a mapping function that concatenates all transcriptions into one long transcription and then transforms the string into a set of chars. It is important to pass the argument batched=True to the map(...) function so that the mapping function has access to all transcriptions at once.

In [60]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocabs = timit.map(extract_all_chars,
                   batched=True,
                   batch_size=-1,
                   keep_in_memory=True,
                   remove_columns=timit.column_names["train"], 
                  )

Map:   0%|          | 0/3629 [00:00<?, ? examples/s]

Map:   0%|          | 0/1340 [00:00<?, ? examples/s]

In [61]:
vocabs

DatasetDict({
    train: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
})

In [64]:
len(vocabs["train"]['vocab'][0]), len(vocabs["test"]['vocab'][0])

(28, 28)

In [71]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
vocab_list.sort()
len(vocab_list), vocab_list

(28,
 [' ',
  "'",
  'a',
  'b',
  'c',
  'd',
  'e',
  'f',
  'g',
  'h',
  'i',
  'j',
  'k',
  'l',
  'm',
  'n',
  'o',
  'p',
  'q',
  'r',
  's',
  't',
  'u',
  'v',
  'w',
  'x',
  'y',
  'z'])

In [73]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict 

{' ': 0,
 "'": 1,
 'a': 2,
 'b': 3,
 'c': 4,
 'd': 5,
 'e': 6,
 'f': 7,
 'g': 8,
 'h': 9,
 'i': 10,
 'j': 11,
 'k': 12,
 'l': 13,
 'm': 14,
 'n': 15,
 'o': 16,
 'p': 17,
 'q': 18,
 'r': 19,
 's': 20,
 't': 21,
 'u': 22,
 'v': 23,
 'w': 24,
 'x': 25,
 'y': 26,
 'z': 27}

Cool, we see that all letters of the alphabet occur in the dataset (which is not really surprising) and we also extracted the special characters " " and '. Note that we did not exclude those special characters because:

- The model has to learn to predict when a word finished or else the model prediction would always be a sequence of chars which would make it impossible to separate words from each other.
- In English, we need to keep the ' character to differentiate between words, e.g., "it's" and "its" which have very different meanings.
  
To make it clearer that " " has its own token class, we give it a more visible character |. In addition, we also add an "unknown" token so that the model can later deal with characters not encountered in Timit's training set.

In [74]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict 

{"'": 1,
 'a': 2,
 'b': 3,
 'c': 4,
 'd': 5,
 'e': 6,
 'f': 7,
 'g': 8,
 'h': 9,
 'i': 10,
 'j': 11,
 'k': 12,
 'l': 13,
 'm': 14,
 'n': 15,
 'o': 16,
 'p': 17,
 'q': 18,
 'r': 19,
 's': 20,
 't': 21,
 'u': 22,
 'v': 23,
 'w': 24,
 'x': 25,
 'y': 26,
 'z': 27,
 '|': 0}

Finally, we also add a padding token that corresponds to CTC's "blank token". The "blank token" is a core component of the CTC algorithm. For more information, please take a look at the "Alignment" section [here](https://distill.pub/2017/ctc/)

In [77]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))
vocab_dict

30


{"'": 1,
 'a': 2,
 'b': 3,
 'c': 4,
 'd': 5,
 'e': 6,
 'f': 7,
 'g': 8,
 'h': 9,
 'i': 10,
 'j': 11,
 'k': 12,
 'l': 13,
 'm': 14,
 'n': 15,
 'o': 16,
 'p': 17,
 'q': 18,
 'r': 19,
 's': 20,
 't': 21,
 'u': 22,
 'v': 23,
 'w': 24,
 'x': 25,
 'y': 26,
 'z': 27,
 '|': 0,
 '[UNK]': 30,
 '[PAD]': 30}

Cool, now our vocabulary is complete and consists of 30 tokens, which means that the linear layer that we will add on top of the pretrained Wav2Vec2 checkpoint will have an output dimension of 30

### Saving vocab.json

In [79]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

## Tokenizer in ASR Wav2Vec2 

The Tokenizer's Role in Wav2Vec2
You can think of a Wav2Vec2 model for ASR as a pipeline with three main stages:

1. The Audio Encoder (Wav2Vec2Model): This part takes your raw audio waveform and converts it into a sequence of hidden states. It's understanding the sounds.

2. The CTC Classifier (Wav2Vec2ForCTC): This is a final linear layer that takes each hidden state from the encoder and predicts which character (from your 30-token vocabulary) is most likely being spoken at that moment. The output here is a matrix of logits (scores) .

3. The Decoder (Wav2Vec2CTCTokenizer): This takes the sequence of predicted character logits from the classifier and decodes them into the final text transcript .

Your tokenizer is exclusively for stage 3. It's a rule-based decoder, not a learned text model.

In [85]:
from transformers import Wav2Vec2CTCTokenizer
# https://huggingface.co/docs/transformers/en/model_doc/wav2vec2#transformers.Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer(vocab_file="./vocab.json", 
                                 unk_token="[UNK]", 
                                 pad_token="[PAD]", 
                                 word_delimiter_token="|")
tokenizer

Wav2Vec2CTCTokenizer(name_or_path='', vocab_size=29, model_max_length=1000000000000000019884624838656, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'pad_token': '[PAD]', 'word_delimiter_token': '|'}, added_tokens_decoder={
	0: AddedToken("|", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	30: AddedToken("[PAD]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	31: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [82]:
repo_name = "wav2vec2-base-timit-demo-colab"

## Feature Extractor

Here is the breakdown of what the Wav2Vec2FeatureExtractor actually does :

- **checks Sampling Rate:** It verifies that the sampling rate of your audio file matches the rate the model was trained on (usually 16,000 Hz) .
- **Normalizes Audio:** It can zero-mean unit-variance normalize the raw waveform, which is crucial for good performance .
- **Pads & Truncates:** It pads audio arrays to the same length within a batch so they can be processed together, and can truncate very long audio files .

**return_attention_mask:** 
 - Whether the model should make use of an attention_mask for batched inference. In general, models should always make use of the attention_mask to mask padded tokens. However, due to a very specific design choice of Wav2Vec2's "base" checkpoint, better results are achieved when using no attention_mask. This is not recommended for other speech models. For more information, one can take a look at this issue. Important If you want to use this notebook to fine-tune large-lv60, this parameter should be set to True.

 - https://github.com/facebookresearch/fairseq/issues/3227
 - https://huggingface.co/facebook/wav2vec2-large-lv60
 

In [86]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, # feature dimension (1)
                                             sampling_rate=16000, 
                                             padding_value=0.0, # value that is added when padding.
                                             do_normalize=True, 
                                             return_attention_mask=False, 
                                            )
feature_extractor

Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}

## Wav2Vec2Processor

Wav2Vec2Processor is indeed like a scikit-learn Pipeline that combines two separate steps into one convenient object.

In [87]:
from transformers import Wav2Vec2Processor

# https://huggingface.co/docs/transformers/en/model_doc/wav2vec2#transformers.Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, 
                              tokenizer=tokenizer)
processor

Wav2Vec2Processor:
- feature_extractor: Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}

- tokenizer: Wav2Vec2CTCTokenizer(name_or_path='', vocab_size=29, model_max_length=1000000000000000019884624838656, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'pad_token': '[PAD]', 'word_delimiter_token': '|'}, added_tokens_decoder={
	0: AddedToken("|", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	30: AddedToken("[PAD]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	31: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

{
  "feature_extractor": {
    "do_normalize": true,
    "feature_extractor_type": "Wav2V

In [89]:
# print(timit[0]["path"])

In [91]:
timit

DatasetDict({
    train: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 3629
    })
    test: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 1340
    })
})

## Mapping input_values with class instance processor

This transformation is typical for ASR training because:

- Wav2Vec2 models expect exactly these two columns: input_values (audio features) and labels (tokenized text)

- Keeping the original columns would:

    - Waste memory (you don't need them for training)

    - Potentially confuse the data collator

    - Make batches larger and slower

However, if you need to keep some original information (like speaker_id or dialect_region for analysis), you should:

In [105]:
# def prepare_dataset(batch):
#     audio = batch["audio"]

#     # batched output is "un-batched" to ensure mapping is correct
#     batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    
#     with processor.as_target_processor():
#         batch["labels"] = processor(batch["text"]).input_ids
#     return batch

In [106]:
def prepare_dataset(batch):
    # Process both audio and text in one call
    processed = processor(
        audio=batch["audio"]["array"],
        sampling_rate=batch["audio"]["sampling_rate"],
        text=batch["text"]
    )
    
    batch["input_values"] = processed["input_values"]
    batch["labels"] = processed["labels"]  # Labels are automatically created
    return batch

In [102]:
audio_arr = timit["train"]["audio"][0]["array"]
audio_arr.shape

(39936,)

In [103]:
timit["train"]["text"][0]

'would such an act of refusal be useful'

In [104]:
timit

DatasetDict({
    train: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 3629
    })
    test: Dataset({
        features: ['audio', 'phonetic_detail', 'word_detail', 'text', 'duration', 'timit_path', 'dialect_region', 'dialect_region_name', 'speaker_id', 'speaker_sex', 'id', 'sentence_type'],
        num_rows: 1340
    })
})

In [107]:
timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)
timit

Map (num_proc=4):   0%|          | 0/3629 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1340 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels'],
        num_rows: 3629
    })
    test: Dataset({
        features: ['input_values', 'labels'],
        num_rows: 1340
    })
})

In [113]:
sample_input_values = np.array(timit["train"]['input_values'][0]) # from list to array
sample_input_values.shape

(1, 39936)

In [114]:
sample_input_values = sample_input_values.squeeze()
sample_input_values.shape

(39936,)

In [115]:
Audio(data=sample_input_values, rate=16000)

In [119]:
print(timit["train"]['labels'][2])

[3, 22, 21, 21, 6, 19, 20, 4, 16, 21, 4, 9, 0, 7, 22, 5, 8, 6, 0, 8, 16, 6, 20, 0, 24, 6, 13, 13, 0, 24, 10, 21, 9, 0, 23, 2, 15, 10, 13, 13, 2, 0, 10, 4, 6, 0, 4, 19, 6, 2, 14]


## Data colletor

Without going into too many details, in contrast to the common data collators, this data collator treats the input_values and labels differently and thus applies to separate padding functions on them (again making use of Wav2Vec2's context manager). This is necessary because in speech input and output are of different modalities meaning that they should not be treated by the same padding function. Analogous to the common data collators, the padding tokens in the labels with -100 so that those tokens are not taken into account when computing the loss

### Proto

In [170]:
class test1: 

    def __init__(self, x, y): 
        self.x=x
        self.y=y

    def __call__(self,):
        self.n_call =+1

        return self.n_call
        # print(n_call)
        

In [163]:
ksl = test1(x=1, y=2)
ksl.y

2

In [171]:
#__call__ make that the class instance behave like a function
print(ksl())

1


### Class DataCollatorCTCWithPadding

In [199]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from transformers import Wav2Vec2Processor

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Compatible with Transformers v4.49+ (post as_target_processor deprecation)
    """
    
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_values": np.array(feature["input_values"]).squeeze()} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        
        # Pad input_values using feature_extractor directly
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        
        # Pad labels using tokenizer directly
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )
        
        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        
        return batch

In [200]:
# Initialize your data collator
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
    max_length=None,  # Let it pad to longest in batch
    max_length_labels=None, # You can set to 8 for TPU optimization if needed
)


In [213]:
# Take a small batch of examples from your dataset
sample_size = 2  # Small batch to test
# sample_dataset = timit["train"].select(range(sample_size))
sample_dataset = timit["train"].select([2, 3])

# Convert to list of features (this is what Trainer would pass to collator)
raw_features = [sample_dataset[i] for i in range(sample_size)]

# print("Original features (first example only):")
# print(f"  input_values length: {len(raw_features[0]['input_values'])}")
# print(f"  labels length: {len(raw_features[0]['labels'])}")
# print(f"  input_values sample: {raw_features[0]['input_values'][:10]}...")
# print(f"  labels sample: {raw_features[0]['labels'][:10]}...")

In [214]:
sample_dataset 

Dataset({
    features: ['input_values', 'labels'],
    num_rows: 2
})

In [216]:
# raw_features[0]["input_values"]

In [217]:
[len(lbls['labels']) for lbls in raw_features]

[51, 52]

In [219]:
print(raw_features[0]["labels"])

[3, 22, 21, 21, 6, 19, 20, 4, 16, 21, 4, 9, 0, 7, 22, 5, 8, 6, 0, 8, 16, 6, 20, 0, 24, 6, 13, 13, 0, 24, 10, 21, 9, 0, 23, 2, 15, 10, 13, 13, 2, 0, 10, 4, 6, 0, 4, 19, 6, 2, 14]


In [220]:
# Pass the raw features to the collator
batched_output = data_collator(raw_features)

In [221]:
batched_output['input_values'].shape

torch.Size([2, 56116])

In [222]:
40448/16000

2.528

In [223]:
batched_output["labels"]

tensor([[   3,   22,   21,   21,    6,   19,   20,    4,   16,   21,    4,    9,
            0,    7,   22,    5,    8,    6,    0,    8,   16,    6,   20,    0,
           24,    6,   13,   13,    0,   24,   10,   21,    9,    0,   23,    2,
           15,   10,   13,   13,    2,    0,   10,    4,    6,    0,    4,   19,
            6,    2,   14, -100],
        [  20,    9,    6,    0,    9,    2,    5,    0,   26,   16,   22,   19,
            0,    5,    2,   19,   12,    0,   20,   22,   10,   21,    0,   10,
           15,    0,    8,   19,    6,    2,   20,   26,    0,   24,    2,   20,
            9,    0,   24,    2,   21,    6,   19,    0,    2,   13,   13,    0,
           26,    6,    2,   19]])

    ✅ Your data collator is working perfectly
    ✅ The -100 padding is correct
    ✅ Repeated characters like "21, 21" should NOT have padding between them
    ✅ CTC handles this automatically during training and decoding

## Evaluation metric WER

In [227]:
wer_metric = evaluate.load("wer")
wer_metric

EvaluationModule(name: "wer", module_type: "metric", features: {'predictions': Value('string'), 'references': Value('string')}, usage: """
Compute WER score of transcribed segments against references.

Args:
    references: List of references for each speech input.
    predictions: List of transcriptions to score.
    concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.

Returns:
    (float): the word error rate

Examples:

    >>> predictions = ["this is the prediction", "there is an other sample"]
    >>> references = ["this is the reference", "there is another one"]
    >>> wer = evaluate.load("wer")
    >>> wer_score = wer.compute(predictions=predictions, references=references)
    >>> print(wer_score)
    0.5
""", stored examples: 0)

In [228]:
processor.tokenizer.pad_token_id

30

In [229]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

## Load the checkpoint pretrained weigths

In [230]:
from transformers import Wav2Vec2ForCTC

In [232]:
checkpoint = "facebook/wav2vec2-base"

model = Wav2Vec2ForCTC.from_pretrained(
    pretrained_model_name_or_path=checkpoint , 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    token=HF_TOKEN_READ, 
)

Loading weights:   0%|          | 0/211 [00:00<?, ?it/s]

[1mWav2Vec2ForCTC LOAD REPORT[0m from: facebook/wav2vec2-base
Key                          | Status     | 
-----------------------------+------------+-
quantizer.weight_proj.bias   | UNEXPECTED | 
quantizer.weight_proj.weight | UNEXPECTED | 
project_q.weight             | UNEXPECTED | 
project_hid.bias             | UNEXPECTED | 
project_hid.weight           | UNEXPECTED | 
quantizer.codevectors        | UNEXPECTED | 
project_q.bias               | UNEXPECTED | 
lm_head.weight               | MISSING    | 
lm_head.bias                 | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


* MISSING → your new classification head was created (random init)
* UNEXPECTED → old pretraining heads (MLM/NSP) that were ignored

In [233]:
model

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder)

In [235]:
# Check if frozen properly
print("Feature extractor frozen:")
for name, param in model.wav2vec2.feature_extractor.named_parameters():
    print(f"  {name}: requires_grad = {param.requires_grad}")
    break  # Just check first one

Feature extractor frozen:
  conv_layers.0.conv.weight: requires_grad = True


In [237]:


print("\nEncoder still trainable:")
for name, param in model.wav2vec2.encoder.named_parameters():
    print(f"  {name}: requires_grad = {param.requires_grad}")
    break  # Just check first one


Encoder still trainable:
  pos_conv_embed.conv.bias: requires_grad = True


### Wav2Vec2FeatureEncoder this will be frozen, already trained 

In [238]:
'''

(feature_extractor): Wav2Vec2FeatureEncoder(
  (conv_layers): ModuleList(
    (0): Conv1d(1, 512, kernel_size=10, stride=5)  # First layer with GroupNorm
    (1-4): 4 x Conv1d(512, 512, kernel_size=3, stride=2)  # No layer norm
    (5-6): 2 x Conv1d(512, 512, kernel_size=2, stride=2)  # No layer norm
  )
)

'''

'\n\n(feature_extractor): Wav2Vec2FeatureEncoder(\n  (conv_layers): ModuleList(\n    (0): Conv1d(1, 512, kernel_size=10, stride=5)  # First layer with GroupNorm\n    (1-4): 4 x Conv1d(512, 512, kernel_size=3, stride=2)  # No layer norm\n    (5-6): 2 x Conv1d(512, 512, kernel_size=2, stride=2)  # No layer norm\n  )\n)\n\n'

### Feature proyection, trainable 

In [239]:
'''
(feature_projection): Wav2Vec2FeatureProjection(
  (layer_norm): LayerNorm(512)
  (projection): Linear(512 → 768)  # Projects to model dimension
  (dropout): Dropout(0.1)
)
'''

'\n(feature_extractor): Wav2Vec2FeatureEncoder(\n  (conv_layers): ModuleList(\n    (0): Conv1d(1, 512, kernel_size=10, stride=5)  # First layer with GroupNorm\n    (1-4): 4 x Conv1d(512, 512, kernel_size=3, stride=2)  # No layer norm\n    (5-6): 2 x Conv1d(512, 512, kernel_size=2, stride=2)  # No layer norm\n  )\n)\n'

### encoder (transformer) trainable 

In [243]:
'''
(encoder): Wav2Vec2Encoder(
  (pos_conv_embed): PositionalConvEmbedding  # Adds positional info
  (layers): 12 x Wav2Vec2EncoderLayer(       # 12 Transformer blocks
      (attention): Multi-head self-attention  # Contextualized representations
      (feed_forward): Linear(768 → 3072 → 768)  # FFN with GELU Wav2Vec2FeedForward(....)
  )
)

'''

'\n(encoder): Wav2Vec2Encoder(\n  (pos_conv_embed): PositionalConvEmbedding  # Adds positional info\n  (layers): 12 x Wav2Vec2EncoderLayer(       # 12 Transformer blocks\n      (attention): Multi-head self-attention  # Contextualized representations\n      (feed_forward): Linear(768 → 3072 → 768)  # FFN with GELU Wav2Vec2FeedForward(....)\n  )\n)\n\n'

### language modelling head, trainable

In [242]:
'''
(lm_head): Linear(in_features=768, out_features=32, bias=True)

'''

'\n(lm_head): Linear(in_features=768, out_features=32, bias=True)\n\n'

### *** Freezing Wav2Vec2FeatureEncoder

The first component of Wav2Vec2 consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretrainind and as stated in the paper does not need to be fine-tuned anymore. Thus, we can set the requires_grad to False for all parameters of the feature extraction part.

In [245]:
# model.freeze_feature_extractor()
model.freeze_feature_encoder()

## Fine tuning 

In [258]:
repo_name

'wav2vec2-base-timit-demo-colab'

In [260]:
root2save_weigts=os.path.join(cfg.HF_FINETUNED_MODEL_SAVE_ROOT, 
                              repo_name,
                             )
root2save_weigts

'../../../data/05_cache/HF/FR_finetuned/wav2vec2-base-timit-demo-colab'

In [261]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir=root2save_weigts,
                                  group_by_length=True,
                                  per_device_train_batch_size=32,
                                  
                                  # evaluation_strategy="steps",
                                  eval_strategy='epoch',
                                  
                                  num_train_epochs=30,
                                  fp16=True,
                                  gradient_checkpointing=True, 
                                  save_steps=5,
                                  eval_steps=5,
                                  logging_steps=5,
                                  learning_rate=1e-4,
                                  weight_decay=0.005,
                                  warmup_steps=1000,
                                  save_total_limit=2,
                                 )
training_args

    Found GPU0 NVIDIA GeForce GTX 1050 which is of cuda capability 6.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (7.0) - (12.0)
    
  queued_call()
    Please install PyTorch with a following CUDA
    configurations:  12.6 following instructions at
    https://pytorch.org/get-started/locally/
    
  queued_call()
NVIDIA GeForce GTX 1050 with CUDA capability sm_61 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_70 sm_75 sm_80 sm_86 sm_90 sm_100 sm_120.
If you want to use the NVIDIA GeForce GTX 1050 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

  queued_call()


TrainingArguments(
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=True,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=False,
enable_jit_checkpoint=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=5,
eval_strategy=epoch,
eval_use_gather_object=False,
fp16=True,
fp16_f

In [262]:
torch.__version__

'2.10.0+cu128'