## Fine Tuning superb/wav2vec2-base-superb-er With Shemo Persian Dataset


##### Importing Dependencies And Defining Constants

In [117]:
import os
import sys
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments
import torch
import librosa
import numpy as np
from sklearn.model_selection import train_test_split
import os
import json
import pandas as pd
from sklearn.metrics import accuracy_score
from tqdm import tqdm

CWD_PATH = os.getcwd()
ROOT_PATH = os.path.abspath(os.path.join(CWD_PATH, '..'))
UTILS_PATH = os.path.join(ROOT_PATH, 'utils')
DATASETS_BASE_PATH = os.path.join(ROOT_PATH, 'data')
RESULTS_PATH = os.path.join(ROOT_PATH,'results')
if UTILS_PATH not in sys.path:
    sys.path.append(UTILS_PATH)

FEATURES_PATH = os.path.join(ROOT_PATH, 'features')
MODELS_PATH = os.path.join(ROOT_PATH, 'models')
os.makedirs(FEATURES_PATH, exist_ok=True)
os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(RESULTS_PATH,exist_ok=True)
MODEL_NAME = "superb/wav2vec2-base-superb-er"
AUDIO_MAX_LENGTH = 8 
SAMPLE_RATE = 16000

##### Loading The Original Model

In [118]:
extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=4
)
model.gradient_checkpointing_enable()
id2label = model.config.id2label
label2id = {v: k for k, v in id2label.items()}



#### Defining Functions And Classes That Are Going To Be Used In Training Process

In [128]:
def load_and_validate_data(basePath, representationFilePath):
    with open(representationFilePath, 'r', encoding='utf-8') as f:
        metadata = json.load(f)    
    paths, labels = [], []
    for fileName, details in metadata.items():
        filePath = os.path.join(basePath, details["path"])
        if os.path.exists(filePath):
            try:
                librosa.load(filePath, sr=SAMPLE_RATE, duration=1)
                paths.append(filePath)
                labels.append(details["emotion"].lower())                
            except:
                print(f"Skipping corrupted file: {filePath}")          
    return pd.DataFrame({'speech': paths, 'label': labels})

def extract_features(file_path):
    try:
        audio, _ = librosa.load(file_path, sr=SAMPLE_RATE, duration=AUDIO_MAX_LENGTH)
        inputs = extractor(
            audio,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding="max_length",
            max_length=SAMPLE_RATE*AUDIO_MAX_LENGTH,
            truncation=True
        )
        return inputs.input_values[0].numpy()
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return np.zeros(SAMPLE_RATE*AUDIO_MAX_LENGTH)

def precompute_and_save_features(df, save_path):
    features = []
    for path in tqdm(df['speech'], desc="Extracting features"):
        features.append(extract_features(path))
    np.save(save_path, np.array(features))

class AudioFeaturesDataset(torch.utils.data.Dataset):
    def __init__(self, features_path, features):
        self.features = np.load(features_path, mmap_mode='r')
        self.features = features
        
    def __len__(self):
        return len(self.features)
        
    def __getitem__(self, idx):
        return {
            'input_values': torch.tensor(self.features[idx]),
            'features': torch.tensor(self.features[idx])
        }
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}

#### Loading Shemo Dataset

In [129]:
shemo_dataset_audio_files_path = os.path.join(DATASETS_BASE_PATH,'shemo')
shemo_dataset_representation_file_path = os.path.join(DATASETS_BASE_PATH, 'shemo/modified_shemo.json')
shemo_df = load_and_validate_data(
    shemo_dataset_audio_files_path,
    shemo_dataset_representation_file_path
)

#### Mapping And Filtering Shemo DataFrame

In [130]:
label_mapping = {'happiness': 'hap', 'anger': 'ang', 'sadness': 'sad', 'neutral': 'neu'}
shemo_df = shemo_df[shemo_df['label'].isin(label_mapping.keys())].copy()
shemo_df['label'] = shemo_df['label'].map(label_mapping)
shemo_df['label_id'] = shemo_df['label'].map(label2id)
print(len(shemo_df))

2766


#### Splitting Data Frame For The Test And Train
in here we used test size of 0.8, in train_test_split function so we could have a pretty smaller train dataset, so the ram can handle it, for the training process we give the model a dataset of 400 so the ram could handle it, the test dataset length is almost 500, for the testing we use all the remaining data and do it manually which is probably a good thing because we see if our model has the problem of over fitting.

In [131]:
train_df, rest_df = train_test_split(shemo_df,test_size=0.8 , random_state=42)
test_df = rest_df.groupby('label').head(100)

#### Extracting And Saving Features For Train And Test Data Frame

In [132]:
print("Precomputing training features...")
precompute_and_save_features(train_df, f"{FEATURES_PATH}/shemo_train_features.npy")
print("Precomputing evaluation features...")
precompute_and_save_features(test_df, f"{FEATURES_PATH}/shemo_test_features.npy")

Precomputing training features...


Extracting features: 100%|██████████| 553/553 [00:01<00:00, 340.45it/s]


Precomputing evaluation features...


Extracting features: 100%|██████████| 400/400 [00:01<00:00, 325.91it/s]


In [133]:
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    logging_steps=100,
    report_to="none",
    dataloader_num_workers=2,
)

In [134]:
train_dataset = AudioFeaturesDataset(
    f"{FEATURES_PATH}/shemo_train_features.npy",
    train_df['label_id'].values
)

test_dataset = AudioFeaturesDataset(
    f"{FEATURES_PATH}/shemo_train_features.npy",
    test_df['label_id'].values
)


In [135]:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [136]:
fine_tuned_with_shemo_model_path = os.path.join(MODELS_PATH,'w2v_fine_tuned_with_shemo_voice_based_semantic_analytics')

In [None]:
print("Starting training...")
try:
    trainer.train()
    print("Training completed successfully!")
    model.save_pretrained(fine_tuned_with_shemo_model_path)
    extractor.save_pretrained(fine_tuned_with_shemo_model_path)
    print("Model saved successfully")
except Exception as e:
    print(f"Training failed: {str(e)}")
    print("Saving current progress...")

Starting training...


Step,Training Loss,Validation Loss


Training completed successfully!
Model saved successfully


In [137]:
model = Wav2Vec2ForSequenceClassification.from_pretrained(fine_tuned_with_shemo_model_path)
extractor = AutoFeatureExtractor.from_pretrained(fine_tuned_with_shemo_model_path)

id2label = model.config.id2label
label2id = {v: k for k, v in id2label.items()}

def predict(audio_path):
    audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, duration=8)
    inputs = extractor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return id2label[torch.argmax(logits).item()]

number_of_correct_predictions = 0
number_of_incorrect_predictions = 0
prediction_results = []
for index, row in tqdm(rest_df.iterrows()):
    prediction = predict(row['speech'])

    result = {
        'file_path': row['speech'],
        'predicted_label': prediction,
        'correct_label': row['label'],
        'correct': prediction == row['label']
    }
    if prediction == row['label']:
        number_of_correct_predictions += 1
    else:
        number_of_incorrect_predictions += 1
        prediction_results.append(result)
results_df = pd.DataFrame(prediction_results)
result_path = os.path.join(RESULTS_PATH,"shemo_dataset_test_result.xlsx")
results_df.to_excel(result_path, index=False)
print(f"Results saved to {result_path}")

2213it [10:20,  3.57it/s]

Results saved to /home/dbk/fine-tuned-voice-based-semantic-analytics-for-Persian-language/results/shemo_dataset_test_result.xlsx





In [138]:
print("Final Result")
print(f"number of correct predictions: {number_of_correct_predictions}" )
print(f"number of incorrect predictions: {number_of_incorrect_predictions}")
accuracy   =   100*number_of_correct_predictions/(number_of_correct_predictions+number_of_incorrect_predictions)
print(f"accuracy: {accuracy:.2f} percent")

Final Result
number of correct predictions: 1872
number of incorrect predictions: 341
accuracy: 84.59 percent
