In [None]:
# Log in to the Hugging Face Hub interactively; the create_repo / push_to_hub
# cells further down need an authenticated session.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
from interface.TimitInterface import TimitInterface

In [None]:
# Load the TIMIT corpus through the project's TimitInterface. The next cell reads
# its .data / .train / .test / .valid attributes.
# NOTE(review): presumably this indexes files under ./TIMIT/data — TODO confirm.
dataset = TimitInterface()

In [None]:
# Whole corpus plus the predefined train / test / validation splits.
all_data = dataset.data
train_data, test_data, val_data = dataset.train, dataset.test, dataset.valid

In [None]:
# Materialise each split's utterance keys once so they can be indexed positionally.
train_data_keys, test_data_keys, val_data_keys = (
    list(split.keys()) for split in (train_data, test_data, val_data)
)

In [None]:
from phoneme import read_phoneme_file, del_unnecessary_phonetic, sentence_being_read, extract_needed_data

In [None]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Render an HTML table of randomly chosen utterances from a TIMIT split.

    Args:
        dataset: mapping from utterance key to an entry dict holding at least the
            'phonetic_file' and 'word_file' paths (the shape of train_data above).
        num_examples: number of distinct utterances to show.

    Raises:
        AssertionError: if num_examples exceeds the number of utterances.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    keys = list(dataset.keys())
    # random.sample draws distinct indices directly, replacing the retry loop.
    picks = random.sample(range(len(keys)), num_examples)

    phonetic_file = []
    text_file = []
    for pick in picks:
        # Use the split that was passed in, not the global train_data.
        entry = dataset[keys[pick]]
        phonetics = "-".join(del_unnecessary_phonetic(read_phoneme_file(entry['phonetic_file'])))
        phonetic_file.append(phonetics)
        # The sentence transcript sits next to the .WRD file with a .TXT extension.
        txt_path = entry['word_file'][:-3] + 'TXT'
        text_file.append(sentence_being_read(txt_path))

    df = pd.DataFrame({
        'Index': picks,
        'Phonetic File': phonetic_file,
        'Text': text_file
    })
    display(HTML(df.to_html()))

show_random_elements(train_data)

In [None]:
from transformers import Wav2Vec2CTCTokenizer

# CTC tokenizer over the label vocabulary in ./vocab.json, with "|" as word delimiter.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    word_delimiter_token="|",
)
# Alternative phoneme vocabulary (no word delimiter):
# tokenizer = Wav2Vec2CTCTokenizer("./phoneme.json", unk_token="[UNK]", pad_token="[PAD]")

In [None]:
import os

repo_name = "wav2vec2-tune-timit-asr-for-phoneme"
# Never paste the Hub token into the notebook: read it from the environment.
# When HF_TOKEN is unset this is None, and huggingface_hub falls back to the
# locally saved login from notebook_login().
token = os.environ.get("HF_TOKEN")

In [None]:
from huggingface_hub import HfApi

# Create the Hub model repository that the tokenizer and checkpoints are pushed to.
# exist_ok=True makes re-running this cell harmless if the repo already exists.
api = HfApi()
api.create_repo(repo_id=repo_name, exist_ok=True, token=token)

In [None]:
# Upload the tokenizer files to the Hub repo named repo_name.
tokenizer.push_to_hub(repo_name)

In [None]:
from transformers import Wav2Vec2FeatureExtractor

# Raw-waveform front end: 16 kHz input, one value per time step, zero padding,
# input normalisation enabled, and no attention mask in its outputs.
feature_extractor = Wav2Vec2FeatureExtractor(
    sampling_rate=16000,
    feature_size=1,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

In [None]:
from transformers import Wav2Vec2Processor

# Bundle the feature extractor (audio -> input_values) and the tokenizer
# (label text <-> ids) so prepare_dataset and the collator use a single object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def prepare_dataset(batch):
    """Turn one utterance into model inputs and CTC label ids.

    Args:
        batch: dict with "audio_arr" (passed to the feature extractor) and
            "phoneme" (the target text passed to the tokenizer), as assembled
            from extract_needed_data(...).

    Returns:
        The same dict, updated in place with "input_values" and "labels".
    """
    # The feature extractor always returns a batch; this is one utterance, so
    # take element 0 to "un-batch" it.
    batch["input_values"] = processor.feature_extractor(
        batch["audio_arr"],
        sampling_rate=processor.feature_extractor.sampling_rate,
    ).input_values[0]

    # Tokenize the labels directly instead of going through the deprecated
    # processor.as_target_processor() context manager, which dispatched here anyway.
    batch["labels"] = processor.tokenizer(batch["phoneme"]).input_ids
    return batch

In [None]:
# Smoke-test the preprocessing on the first training utterance.
path = f"./TIMIT/data/{train_data_keys[0]}"
print(f"{path}.WAV")
audio_arr, line, phoneme = extract_needed_data(path)
batch = prepare_dataset(dict(audio_arr=audio_arr, line=line, phoneme=phoneme))

In [None]:
# Path prefix of the first validation utterance (display only; not loaded here).
'./TIMIT/data/'+val_data_keys[0]

In [None]:
# Audio returned by extract_needed_data for the first training utterance.
# NOTE(review): presumably raw waveform samples at 16 kHz — confirm in phoneme.py.
audio_arr

In [None]:
# Text line returned by extract_needed_data (not used by prepare_dataset).
line

In [None]:
# Target label text that prepare_dataset feeds to the tokenizer.
phoneme

In [None]:
# Feature-extractor output for this utterance (do_normalize=True was configured).
batch["input_values"]

In [None]:
# Tokenizer ids for the label text; these become the CTC targets.
batch["labels"]

In [None]:
# Same smoke test on a fixed TEST-split utterance (TEST/DR3/MBWM0/SX134).
path = "./TIMIT/data/TEST/DR3/MBWM0/SX134"
print(f"{path}.WAV")
audio_arr, line, phoneme = extract_needed_data(path)
batch = prepare_dataset(dict(audio_arr=audio_arr, line=line, phoneme=phoneme))

In [None]:
# Label text for the fixed TEST utterance loaded in the previous cell.
phoneme

In [None]:
from datasets import Dataset

In [None]:
def build_split_dataset(keys, data_root="./TIMIT/data/"):
    """Preprocess every utterance of one TIMIT split into a `datasets.Dataset`.

    Args:
        keys: utterance keys of the split (e.g. train_data_keys); each key is
            appended to data_root to form the path passed to extract_needed_data.
        data_root: directory prefix the keys are relative to.

    Returns:
        A Dataset with columns "input_values" and "labels", one row per key,
        in the order of `keys`.
    """
    input_values = []
    labels = []
    for key in keys:
        audio_arr, line, phoneme = extract_needed_data(data_root + key)
        sample = prepare_dataset({"audio_arr": audio_arr, "line": line, "phoneme": phoneme})
        input_values.append(sample["input_values"])
        labels.append(sample["labels"])
    return Dataset.from_dict({"input_values": input_values, "labels": labels})


# One helper replaces three copy-pasted loops. The loop variables stay local
# instead of leaking into the notebook namespace and being overwritten there.
train_dataset = build_split_dataset(train_data_keys)
test_dataset = build_split_dataset(test_data_keys)
val_dataset = build_split_dataset(val_data_keys)

In [None]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pad a list of CTC examples into one training batch.

    Audio inputs and label sequences have unrelated lengths, so they are padded
    separately: inputs with the processor's feature extractor, labels with its
    tokenizer. Padded label positions are replaced by -100 so the loss ignores them.
    """

    # Processor supplying `.feature_extractor.pad` (inputs) and `.tokenizer.pad` (labels).
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both pad calls (True = pad to the longest in the batch).
    padding: Union[bool, str] = True
    # Optional maximum length for input_values / for labels.
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    # Optional "pad to a multiple of N" for input_values / for labels.
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate dicts with "input_values" and "labels" into padded "pt" tensors.

        Returns:
            The padded feature-extractor batch with an added "labels" tensor in
            which padding positions are -100.
        """
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly. This is what the deprecated
        # `with processor.as_target_processor(): processor.pad(...)` dispatched to.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace label padding with -100 so those positions are ignored by the loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        return batch

In [None]:
# Per-batch collator: pads inputs and labels separately to the longest example.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
from datasets import load_metric

# Word error rate metric used by compute_metrics.
# NOTE(review): datasets.load_metric is deprecated (removed in datasets>=3.0);
# evaluate.load("wer") from the `evaluate` library is the replacement.
wer_metric = load_metric("wer", trust_remote_code=True)
# NOTE(review): WER splits decoded strings on whitespace. Whether this equals a
# phoneme error rate depends on how vocab.json and the "|" word delimiter map to
# spaces in the decoded text. TODO confirm; the original author was unsure WER works here.

In [None]:
import numpy as np
def compute_metrics(pred):
    """Compute WER between greedy CTC decodes and the reference labels.

    Args:
        pred: the prediction object passed in by the Trainer. `pred.predictions`
            holds logits (argmax over the last axis is taken). `pred.label_ids`
            holds label ids in which -100 marks padding (set by the data collator).

    Returns:
        {"wer": <value returned by wer_metric.compute>}
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Map -100 (ignored positions) to the pad token so the labels can be decoded.
    # np.where builds a new array, so the Trainer's label_ids are not mutated.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # Do not collapse repeated tokens in the references: consecutive identical
    # labels there are real, not CTC duplicates.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}

In [None]:
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC

# Start from the pretrained wav2vec2-base config and adapt it for CTC output.
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")

# The CTC head needs exactly one output per tokenizer id. Deriving the size from
# the tokenizer, instead of hard-coding 64 (61 phonemes + 3 special tokens),
# keeps them in sync if vocab.json changes or the tokenizer adds special tokens.
config.vocab_size = len(processor.tokenizer)
config.ctc_loss_reduction = "mean"
# Wav2Vec2ForCTC uses pad_token_id as the CTC blank index.
config.pad_token_id = processor.tokenizer.pad_token_id

# Save the config into the training output folder (same name as the Hub repo,
# which TrainingArguments uses as output_dir), then load the model with it.
config.save_pretrained(repo_name)
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base", config=config)

In [None]:
# Keep the convolutional feature encoder frozen during fine-tuning.
# freeze_feature_encoder() replaces the deprecated freeze_feature_extractor().
model.freeze_feature_encoder()

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Output directory doubles as the Hub repo name; checkpoints are pushed there.
    output_dir=repo_name,
    push_to_hub=True,
    save_total_limit=2,
    # Batching / schedule (author's notes: originally batch size 32 and 30 epochs).
    per_device_train_batch_size=4,
    group_by_length=True,
    num_train_epochs=4,
    # Optimiser settings.
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    # Precision / memory.
    fp16=False,
    gradient_checkpointing=True,
    # no_cuda=True,  # uncomment to force CPU training
    # Evaluate, checkpoint and log every 500 steps.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy`; switch once the installed version supports it.
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=500,
)

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    # Monitor on the validation split (built above but previously unused). The
    # test split stays held out for a final evaluation, e.g.
    # trainer.evaluate(test_dataset).
    eval_dataset=val_dataset,
    # The feature extractor is saved alongside checkpoints.
    # NOTE(review): newer transformers versions rename this kwarg to `processing_class`.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Fine-tune. Evaluation, logging and checkpointing follow training_args (every
# 500 steps); push_to_hub=True there also uploads checkpoints.
trainer.train()

In [None]:
# Trainer.push_to_hub's first positional parameter is the *commit message*, not
# the repo id. The target repo comes from training_args (output_dir / hub_model_id,
# i.e. repo_name here), so passing repo_name positionally only mislabelled the commit.
trainer.push_to_hub(commit_message="End of training")