# Imports

In [21]:
from dataset_generator import DatasetGenerator
from utils import extract_all_chars, save_dict_as_json
from data_preprocessor import Preprocessor
from data_augmentation import AudioAugmentation
from data_collator import DataCollatorCTCWithPadding

import os
import shutil
import torch
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from torch.utils.data import Dataset
import evaluate
from transformers import (
    Wav2Vec2CTCTokenizer, 
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer
)


In [22]:
AUDIO_DIR = 'dataset'
DATA_PATH = 'data.csv'

word_character_map = {
    'iskljuci': 'isključi',
    'ukljuci': 'uključi'
}

TORCH_DATASETS_DIR = 'torch_datasets'

MODEL_NAME = "wav2vec2-finetuned-voice-commands"
MODELS_DIR = 'models'
MODEL_LOGS_DIR = 'models/logs'

# Gather from folder

In [23]:
dg = DatasetGenerator(word_character_map)


dg.generate(input_dir=AUDIO_DIR, output_file=DATA_PATH)

Dataset saved to data.csv


In [24]:
df = pd.read_csv(DATA_PATH)
df.head()

Unnamed: 0,audio_filepath,text
0,dataset\iskljuci-19-21-1.wav,isključi
1,dataset\iskljuci-19-21-2.wav,isključi
2,dataset\iskljuci-19-21-3.wav,isključi
3,dataset\iskljuci-38-21-1.wav,isključi
4,dataset\iskljuci-38-21-2.wav,isključi


# Create vocabulary

In [25]:
VOCAB_PATH = 'vocab.json'

In [26]:
words = df['text'].unique()

vocab_list = extract_all_chars(words)

vocab_list.extend(['|', '[UNK]', '[PAD]'])
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'i': 0,
 'l': 1,
 's': 2,
 'a': 3,
 'r': 4,
 'k': 5,
 'u': 6,
 'č': 7,
 'e': 8,
 'z': 9,
 'j': 10,
 't': 11,
 'o': 12,
 'v': 13,
 '|': 14,
 '[UNK]': 15,
 '[PAD]': 16}

In [27]:
save_dict_as_json(VOCAB_PATH, vocab_dict)

# Loading the tokenizer, feature extractor and processor

In [28]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [29]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, 
                                             sampling_rate=16000, 
                                             padding_value=0.0, 
                                             do_normalize=True, 
                                             return_attention_mask=True)

In [30]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Preprocessing and Augmentation

In [31]:
aug = AudioAugmentation(min_noise=0, max_noise=.005, time_stretch_rate=.9, pitch_shift_n_steps=2)

train_preprocessor = Preprocessor(processor=processor, sr=16000, audio_augmentation=aug, augment_count=2)
val_preprocessor = Preprocessor(processor=processor, sr=16000)

In [32]:
# Split dataset into train and validation
train_df, val_df = train_test_split(df, test_size=0.2, shuffle=True, random_state=42)

# Preprocess data
preprocessed_train_data = []
preprocessed_val_data = []

# Preprocess training data
for _, row in train_df.iterrows():
    preprocessed_train_data.extend(train_preprocessor.preprocess(row))

# Preprocess validation data
for _, row in val_df.iterrows():
    preprocessed_val_data.extend(val_preprocessor.preprocess(row))

# Create new dataframes for the preprocessed data
train_df = pd.DataFrame(preprocessed_train_data)
val_df = pd.DataFrame(preprocessed_val_data)

train_df.head()

Unnamed: 0,input_values,labels
0,"[tensor(0.1956), tensor(0.3131), tensor(0.2815...","[tensor(12), tensor(11), tensor(13), tensor(12..."
1,"[tensor(0.4701), tensor(0.3975), tensor(0.6470...","[tensor(12), tensor(11), tensor(13), tensor(12..."
2,"[tensor(0.9730), tensor(0.1340), tensor(0.2596...","[tensor(12), tensor(11), tensor(13), tensor(12..."
3,"[tensor(-0.0024), tensor(-0.0061), tensor(-0.0...","[tensor(13), tensor(4), tensor(3), tensor(11),..."
4,"[tensor(-0.0405), tensor(0.0209), tensor(0.112...","[tensor(13), tensor(4), tensor(3), tensor(11),..."


# Generate PyTorch dataset

In [33]:
# Create a PyTorch Dataset class
class AudioDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "input_values": self.data.iloc[idx]["input_values"],
            "labels": self.data.iloc[idx]["labels"],
        }

# Prepare datasets
train_dataset = AudioDataset(train_df)
val_dataset = AudioDataset(val_df)

## Save the dataset

In [34]:
if not os.path.exists(TORCH_DATASETS_DIR):
    os.mkdir(TORCH_DATASETS_DIR)

torch.save(train_dataset, os.path.join(TORCH_DATASETS_DIR, 'train.pt'))
torch.save(val_dataset, os.path.join(TORCH_DATASETS_DIR, 'val.pt'))

# Loading the model

In [35]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    attention_dropout=0.2,
    hidden_dropout=0.2,
    feat_proj_dropout=0.05,
    mask_time_prob=0.04,
    layerdrop=0.15,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

model.config.vocab_size = len(processor.tokenizer)

model.freeze_feature_encoder()
model.gradient_checkpointing_enable()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Training

In [36]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
if not os.path.exists(MODELS_DIR):
    os.mkdir(MODELS_DIR)

if not os.path.exists(MODEL_LOGS_DIR):
    os.mkdir(MODEL_LOGS_DIR)

training_args = TrainingArguments(
    output_dir=os.path.join(MODELS_DIR, MODEL_NAME),        # Directory to save model checkpoints
    evaluation_strategy="steps",                            # Evaluate every N steps
    per_device_train_batch_size=8,                          # Batch size for training
    per_device_eval_batch_size=8,                           # Batch size for evaluation
    gradient_accumulation_steps=2,                          # Gradient accumulation
    learning_rate=3e-4,                                     # Learning rate
    warmup_steps=500,                                       # Warmup steps for LR scheduler
    num_train_epochs=150,                                   # Number of epochs
    logging_dir=MODEL_LOGS_DIR,                             # Directory for logging
    logging_steps=10,                                       # Log every N steps
    save_steps=100,                                         # Save checkpoint every N steps
    save_total_limit=2,                                     # Only keep the last 2 checkpoints
    fp16=True,                                              # Use mixed precision
    dataloader_num_workers=2,                               # Number of workers for DataLoader
    load_best_model_at_end=True,                            # Load the best model at the end
    metric_for_best_model="wer",                            # Metric to determine best model
    greater_is_better=False,                                # Smaller WER is better
    seed=42,                                                # Random seed for reproducibility
)


In [38]:
# Define WER metric
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)

    # Decode predictions and labels
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(pred.label_ids, skip_special_tokens=True)

    # Compute WER
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


In [42]:
# Initialize Trainer
trainer = Trainer(
    model=model,                               # Wav2Vec2 model
    data_collator=data_collator,               # Data Collator
    args=training_args,                        # Training arguments
    train_dataset=train_dataset,               # Training dataset
    eval_dataset=val_dataset,                  # Validation dataset
    processing_class=processor,                # Processor
    compute_metrics=compute_metrics,           # WER metric
)

In [None]:
# Start training
trainer.train()