# Audio Classification using Wav2Vec2 (Nepali Speech)

This notebook presents an end-to-end pipeline for **audio classification using the Wav2Vec2 model**, with a focus on **Nepali speech data**.  
It covers key stages including **dataset exploration, audio preprocessing, model fine-tuning, and evaluation**.

The objective is to demonstrate how self-supervised speech models like Wav2Vec2 can be effectively adapted for **low-resource languages such as Nepali** to support tasks like speech understanding and language technology development.


In [None]:
pip install datasets transformers torch

In [None]:
import librosa
import numpy as np
import librosa.display
import matplotlib.pyplot as plt
import pandas as pd
import glob
import torchaudio
import os
import torch
import numpy as np

In [None]:
from datasets import load_dataset
from datasets import Dataset
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split

In [None]:
from datasets import Dataset, Audio

## Load Audio Dataset
We begin by loading the audio dataset to understand its structure, labels, and sample distribution.

In [None]:
def load_audio_dataset(root_dir, target_sr=16000, extensions=('.wav', '.flac')):
    """Load a folder-per-class audio corpus into a Hugging Face ``Dataset``.

    Every sub-directory of ``root_dir`` is one class. Classes are numbered in
    sorted folder-name order. Each matching audio file is down-mixed to mono,
    resampled to ``target_sr`` and stored as a flat list of samples.

    Args:
        root_dir: Directory whose sub-directories hold each class's audio files.
        target_sr: Sampling rate every waveform is resampled to (default 16 kHz).
        extensions: File suffixes to load; matched case-insensitively.

    Returns:
        ``(dataset, label_map)``. ``dataset`` has an ``"audio"`` column (list of
        floats) and a ``"label"`` column (int). ``label_map`` maps folder name
        to label id. Files that fail to load are reported and skipped.
    """
    # Only directories are classes: stray files in root_dir must not consume a
    # label id or leave gaps in the numbering.
    class_dirs = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )
    label_map = {folder: idx for idx, folder in enumerate(class_dirs)}
    suffixes = tuple(ext.lower() for ext in extensions)

    data = {"audio": [], "label": []}
    # One Resample transform per source rate, reused across files.
    resamplers = {}

    for folder, label_idx in label_map.items():
        folder_path = os.path.join(root_dir, folder)
        # Sorted for a reproducible sample order independent of the filesystem.
        for file_name in sorted(os.listdir(folder_path)):
            if not file_name.lower().endswith(suffixes):
                continue
            file_path = os.path.join(folder_path, file_name)
            try:
                waveform, sample_rate = torchaudio.load(file_path)

                # Down-mix multi-channel audio by averaging the channels.
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0)

                if sample_rate != target_sr:
                    if sample_rate not in resamplers:
                        resamplers[sample_rate] = torchaudio.transforms.Resample(
                            orig_freq=sample_rate, new_freq=target_sr
                        )
                    waveform = resamplers[sample_rate](waveform)

                data["audio"].append(waveform.numpy().flatten().tolist())
                data["label"].append(label_idx)

            except Exception as e:
                # Best effort: report unreadable files and keep going.
                print(f"Error loading {file_path}: {e}")

    return Dataset.from_dict(data), label_map


In [None]:
# Build the Hugging Face dataset from the class-per-folder corpus.
dataset, label_map = load_audio_dataset(
    root_dir='/kaggle/input/audiodataset/Dataset_Arc',
)

# Show which integer id each class folder received.
print("Class to Label Mapping: {}".format(label_map))



## Sample Audio Waveform
To better understand the raw audio signal, we visualize the waveform of a sample audio file from the dataset.


In [None]:
import librosa
import librosa.display


# Example clip to inspect; replace with the path of any audio file in the dataset.
file_path = '/kaggle/input/audiodataset/Dataset_Arc/Eight/eight-2018-05-30T11_28_25.746Z.wav'

# librosa resamples to 16 kHz while loading.
waveform, sample_rate = librosa.load(file_path, sr=16000)

# Quick sanity check of what was loaded.
print(f"Waveform shape: {waveform.shape}")
print(f"Sample rate: {sample_rate}")

if waveform.size == 0:
    print("The waveform is empty. Check the audio file.")
else:
    fig, ax = plt.subplots(figsize=(10, 4))
    librosa.display.waveshow(waveform, sr=sample_rate, ax=ax)
    ax.set_title(f"Waveform for Audio: {file_path}")
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')
    plt.show()


## Load Wav2Vec2Processor

In [None]:
# Processor of the base checkpoint; below it is called on raw audio to produce `input_values`.
processor = Wav2Vec2Processor.from_pretrained(
    pretrained_model_name_or_path="facebook/wav2vec2-base-960h",
)

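## Preprocess the audio in batches

Each waveform is truncated or padded to 32,000 samples (2 seconds at 16 kHz), so every example has the same length.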

In [None]:
max_length = 32000


def preprocess_function(batch):
    """Convert a batch of raw waveforms into fixed-length model inputs.

    Args:
        batch: Mapping with ``"audio"`` (list of waveforms as float lists) and
            ``"label"`` (list of int class ids), as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        Dict with ``"input_values"`` (processor output truncated/padded to
        ``max_length`` samples) and ``"labels"`` (class ids as a tensor).
    """
    waveforms, class_ids = batch["audio"], batch["label"]

    features = processor(
        waveforms,
        sampling_rate=16000,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    return {
        "input_values": features.input_values,
        "labels": torch.tensor(class_ids),
    }

# Batched map: swap the raw "audio"/"label" columns for model-ready features.
dataset_mapped = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["audio", "label"],
)

In [None]:
# Display the preprocessed dataset as this cell's output.
dataset_mapped

In [None]:
# Number of entries in label_map, i.e. the class count discovered from the dataset folders.
len(label_map)

## Train/test split (80/20)

In [None]:
# Random 80/20 split; the fixed seed keeps it reproducible.
dataset_split = dataset_mapped.train_test_split(seed=42, test_size=0.2)
train_dataset, test_dataset = dataset_split["train"], dataset_split["test"]

# Report the size of each split.
for split_name, split in (("Train", train_dataset), ("Test", test_dataset)):
    print(f"{split_name} dataset size: {len(split)}")


## Define evaluation metrics for the model

In [None]:
pip install evaluate


In [None]:
pip install jiwer

In [None]:
import evaluate
# Accuracy metric from the `evaluate` library; compute_metrics delegates to it.
metric = evaluate.load(path="accuracy")

def compute_metrics(eval_pred):
    """Score arg-max predictions with the loaded ``metric``.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.

    Returns:
        Whatever ``metric.compute`` returns (a dict of metric values).
    """
    scores, references = eval_pred
    return metric.compute(
        predictions=np.argmax(scores, axis=-1),
        references=references,
    )


In [None]:
import matplotlib.pyplot as plt
from copy import deepcopy
from transformers import  TrainerCallback

class CustomCallback(TrainerCallback):
    """Record per-epoch train and eval accuracy while ``Trainer.train()`` runs.

    At the end of each epoch on which the Trainer is about to evaluate, this
    callback also evaluates on the *training* set so the two accuracy curves
    can be compared. The eval-set accuracy comes from the Trainer's own
    scheduled evaluation (via ``on_evaluate``). A second, redundant pass over
    the eval set is not run, so eval metrics are logged once per epoch.

    Attributes:
        train_accuracies: ``train_accuracy`` values, one per evaluated epoch.
        eval_accuracies: ``eval_accuracy`` values, one per evaluated epoch.
    """

    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer
        self.train_accuracies = []
        self.eval_accuracies = []
        # Set in on_epoch_end and consumed by the next on_evaluate that reports
        # eval_accuracy. Manual trainer.evaluate() calls outside training are
        # therefore not recorded.
        self._awaiting_eval = False

    def on_epoch_end(self, args, state, control, **kwargs):
        if not control.should_evaluate:
            return None

        # The nested evaluate() below goes through the callback handler, which
        # clears should_evaluate on this shared control object. Returning a copy
        # taken beforehand keeps the Trainer's own scheduled evaluation.
        control_copy = deepcopy(control)

        train_metrics = self._trainer.evaluate(
            eval_dataset=self._trainer.train_dataset, metric_key_prefix="train"
        )
        train_accuracy = train_metrics.get('train_accuracy', None)
        if train_accuracy is not None:
            self.train_accuracies.append(train_accuracy)
        print(f"Train Accuracy: {train_accuracy}")

        self._awaiting_eval = True
        return control_copy

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # The train-set evaluation above also triggers this hook, but its keys
        # are prefixed "train_", so it is ignored here.
        if not self._awaiting_eval or not metrics or 'eval_accuracy' not in metrics:
            return None

        self._awaiting_eval = False
        eval_accuracy = metrics['eval_accuracy']
        self.eval_accuracies.append(eval_accuracy)
        print(f"Eval Accuracy: {eval_accuracy}")
        return None

## Load Wav2Vec2 Model for Audio Classification

In [None]:
from transformers import  Wav2Vec2ForSequenceClassification

# Size the classification head from the folder-derived label_map instead of a
# hard-coded 27, and store the label names in the config so saved checkpoints
# carry them.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h",
    num_labels=len(label_map),
    label2id=label_map,
    id2label={idx: name for name, idx in label_map.items()},
)
# Trade extra compute for lower activation memory during training.
model.gradient_checkpointing_enable()

## Define training arguments

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
training_args = TrainingArguments(
    # Where outputs and logs go; no external reporting or Hub upload.
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=100,
    report_to="none",
    push_to_hub=False,
    # Optimisation schedule.
    learning_rate=3e-5,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=500,
    max_grad_norm=0.5,
    num_train_epochs=10,
    seed=42,
    fp16=True,
    # Batching: effective train batch = 8 per device x 4 accumulation steps.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint every epoch; keep the best model by accuracy.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

In [None]:
# Wire together the model, its data splits, the metric function and the processor.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor,
)

## Model training

In [None]:

# The accuracy-tracking callback needs the trainer so it can run its own evaluations.
custom_callback = CustomCallback(trainer=trainer)
trainer.add_callback(custom_callback)

# Fine-tune; the returned TrainOutput summarises the run.
train = trainer.train()

In [None]:

# Per-epoch accuracies recorded by CustomCallback during training.
train_accuracies = custom_callback.train_accuracies
eval_accuracies = custom_callback.eval_accuracies

# Each series gets its own 1-based epoch axis, so a missing value in one list
# cannot make plot() fail on mismatched x/y lengths.
epochs = range(1, len(train_accuracies) + 1)
eval_epochs = range(1, len(eval_accuracies) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_accuracies, label="Train Accuracy")
ax.plot(eval_epochs, eval_accuracies, label="Eval Accuracy")

ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ax.set_title("Training and Evaluation Accuracy Over Epochs")
ax.legend()
ax.grid(False)
plt.show()

In [None]:
# Loss values logged by the Trainer during training and evaluation.
history = trainer.state.log_history

train_losses = []
train_loss_epochs = []
eval_loss_by_epoch = {}

for log in history:
    # Training logs carry "loss". Record their epoch value so both curves share
    # an epoch x-axis; the previous code plotted against the log index.
    if "loss" in log:
        train_losses.append(log["loss"])
        train_loss_epochs.append(log["epoch"])
    # Keyed by epoch: if an epoch was evaluated more than once (e.g. by an extra
    # evaluation in a callback), keep a single value rather than duplicate points.
    if "eval_loss" in log:
        eval_loss_by_epoch[log["epoch"]] = log["eval_loss"]

epochs = list(eval_loss_by_epoch)
eval_losses = list(eval_loss_by_epoch.values())

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(train_loss_epochs, train_losses, label="Train Loss")
ax.plot(epochs, eval_losses, label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training & Validation Loss")
ax.legend()

fig.tight_layout()
plt.show()


## Save the model

In [None]:
# Save model weights/config and the processor into the same directory so both
# can be reloaded with from_pretrained.
for component in (model, processor):
    component.save_pretrained("./fine_tuned_ASR_model")

In [None]:
import shutil

# Zip the saved model directory into /kaggle/working for download; the archive path is the cell output.
shutil.make_archive(
    base_name="/kaggle/working/fine_tuned_ASR_model",
    format='zip',
    root_dir="./fine_tuned_ASR_model",
)


In [None]:
# Class names in label-id order. This must agree with the sorted class folders
# that load_audio_dataset numbered during training.
class_to_label = {
    name: idx
    for idx, name in enumerate([
        'Eight', 'Five', 'Four', 'Nine', 'One', 'Seven', 'Six', 'Three', 'Two', 'Zero',
        'अ', 'अं', 'अः', 'आ', 'इ', 'ई', 'उ', 'ऊ',
        'ए', 'ऐ', 'ओ', 'औ', 'क', 'ख', 'ग', 'घ', 'ङ',
    ])
}

# Reverse lookup: predicted class id -> class name.
index_to_label = dict(zip(class_to_label.values(), class_to_label.keys()))


## Load the fine-tuned model

In [None]:
# Reload the fine-tuned processor and classifier from the Kaggle input directory.
model_path = "/kaggle/input/fine_tuned_asr_model/transformers/default/1"
processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path=model_path)
model = Wav2Vec2ForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_path)

# Switch to inference mode (the returned module is this cell's displayed output).
model.eval()


## Predict using the fine-tuned model

In [None]:
MAX_LENGTH = 32000


def predict_audio(audio_path, target_sr=16000):
    """Classify one audio file with the loaded fine-tuned model.

    The file is preprocessed like the training data: down-mixed to mono,
    resampled to ``target_sr``, then truncated/padded to ``MAX_LENGTH`` samples
    by ``processor`` before being passed to ``model``.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        target_sr: Sampling rate the model expects (default 16 kHz, as in training).

    Returns:
        The class name from ``index_to_label`` for the arg-max logit, or
        ``"Unknown Label (Out of Range)"`` if that id has no entry.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Down-mix multi-channel audio by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)
        waveform = resampler(waveform)

    samples = waveform.numpy().flatten().tolist()

    input_values = processor(
        samples, return_tensors="pt", sampling_rate=target_sr,
        truncation=True, padding="max_length", max_length=MAX_LENGTH
    ).input_values

    # Put the inputs on the model's device so this also works after the model
    # has been moved to a GPU.
    input_values = input_values.to(model.device)

    with torch.no_grad():
        logits = model(input_values).logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    return index_to_label.get(predicted_id, "Unknown Label (Out of Range)")


In [None]:
# Classify one held-out recording and show the predicted class name.
audio_path = "/kaggle/input/test-datasets/nga.wav"
predicted_label = predict_audio(audio_path)
print("Predicted Label:", predicted_label)


## Confusion Matrix

In [None]:
import numpy as np
import torch
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
import os

# Sanity check: list the files in the saved model directory.
model_dir = '/kaggle/input/fine_tuned_asr_model/transformers/default/1'
print(os.listdir(path=model_dir))


In [None]:
import torch

# Collect the held-out split as tensors: padded input rows and their class ids.
input_rows, label_ids = [], []
for row in test_dataset:
    input_rows.append(torch.tensor(row["input_values"]))
    label_ids.append(row["labels"])

X_test = torch.stack(input_rows)
y_test = torch.tensor(label_ids)



In [None]:

# Run the model over the test set in mini-batches. Each row of X_test is a full
# padded waveform, so a single forward pass over the whole split can exhaust memory.
eval_batch_size = 8
model.eval()

logit_batches = []
with torch.no_grad():
    for start in range(0, len(X_test), eval_batch_size):
        batch = X_test[start:start + eval_batch_size].to(model.device)
        logit_batches.append(model(batch).logits.cpu())

logits = torch.cat(logit_batches)
predicted_labels = torch.argmax(logits, dim=1).numpy()

# Use every class the model can output as the label set, so the matrix is always
# num_classes x num_classes even if a class is missing from the test split.
all_labels = np.arange(logits.shape[-1])
conf_matrix = confusion_matrix(y_test.numpy(), predicted_labels, labels=all_labels)

print("Confusion Matrix:")
print(conf_matrix)

print("Classification Report:")
print(classification_report(y_test.numpy(), predicted_labels, labels=all_labels, zero_division=0))

In [None]:
import seaborn as sns

In [None]:

# Reuse the mapping defined for inference instead of keeping a second hard-coded
# copy that could drift out of sync with it.
class_to_label_mapping = dict(class_to_label)

# Order names by label id so class_names[i] is the name of class i
# (row/column i of the confusion matrix).
class_names = sorted(class_to_label_mapping, key=class_to_label_mapping.get)

print(class_names)


In [None]:
# Annotated heatmap of the confusion matrix. Ticks are numeric label ids; use
# class_names (printed above) to map them back to class names.
# NOTE(review): numeric ticks may be deliberate, because default matplotlib fonts
# may lack Devanagari glyphs. Confirm before switching to class_names.
n_rows, n_cols = conf_matrix.shape

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    ax=ax,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=list(map(str, range(n_cols))),
    yticklabels=list(map(str, range(n_rows))),
)

ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix Heatmap')

plt.show()