# Thesis Demo: Full Replication (k-NN & MLP)

**Author:** Muhammad Rafie Hamizan

This notebook replicates the full evaluation methodology using both **k-NN** and **MLP** classifiers.

### Methodology
1. **Feature Extraction**: Extracts embeddings from the **Full Training Set** (Reference) and **Full Test Set** (Query) using the fine-tuned Wav2Vec 2.0 model.
2. **k-NN Evaluation**: Classifies each test embedding by majority vote among its 5 nearest training embeddings, measured with cosine distance.
3. **MLP Evaluation**: Trains a neural network classifier on the extracted training embeddings and predicts the test set.
4. **Comparative Visualization**: Projects the test-set embeddings to 2-D with t-SNE and colors them by ground truth, k-NN prediction, and MLP prediction.

In [None]:
import os
import warnings

import torch
import torch.nn as nn
import torchaudio
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Model
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, accuracy_score
from tqdm.notebook import tqdm

# Configuration
# Relative paths: both are resolved against the notebook's working directory.
DATASET_ROOT = "audio_data_processed"
MODEL_PATH = "contrastive_model.pth"
BATCH_SIZE = 32

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Running on: {DEVICE}")

### 1. Data Pipeline (Reciter-Wise Split)

In [None]:
# --- 1. Define Reciter Splits ---
# Every sub-directory of DATASET_ROOT is treated as one reciter. The listing is
# sorted so the seeded splits below are reproducible across machines.
all_qaris = [d for d in sorted(os.listdir(DATASET_ROOT)) if os.path.isdir(os.path.join(DATASET_ROOT, d))]
if not all_qaris:
    raise ValueError(f"No reciter folders found in '{DATASET_ROOT}'")

# 70% of reciters for training; the remaining 30% is halved into test / validation.
train_qaris, test_val_qaris = train_test_split(all_qaris, test_size=0.30, random_state=42)
# The second half is bound to a real name rather than `_`: at notebook top level,
# assigning `_` shadows IPython's "last output" variable. It is not used below.
test_qaris, val_qaris = train_test_split(test_val_qaris, test_size=0.50, random_state=42)

print(f"Training Reciters ({len(train_qaris)}): {train_qaris}")
print(f"Test Reciters ({len(test_qaris)}): {test_qaris}")

# --- 2. Helper to get files ---
def get_files_and_metadata(qari_list, root=None):
    """Collect the audio file paths and surah labels for a list of reciters.

    Only files ending in ``.mp3`` whose name before the first ``.`` is exactly
    six digits are kept; the first three digits become the surah label.
    Entries of ``qari_list`` that are not directories under ``root`` are skipped.

    Args:
        qari_list: Iterable of reciter folder names (relative to ``root``).
        root: Dataset root directory. Defaults to the notebook-level
            ``DATASET_ROOT`` when omitted.

    Returns:
        ``(file_paths, surah_labels)``: two parallel lists, ordered by the
        order of ``qari_list`` and then by sorted filename.
    """
    root = DATASET_ROOT if root is None else root
    file_paths = []
    surah_labels = []

    for qari_folder in qari_list:
        qari_path = os.path.join(root, qari_folder)
        # isdir (not exists): a stray regular file would make listdir raise.
        if not os.path.isdir(qari_path):
            continue
        # os.listdir order is arbitrary; sort so sample order (and everything
        # derived from it downstream) is reproducible.
        for filename in sorted(os.listdir(qari_path)):
            if not filename.endswith(".mp3"):
                continue
            file_id = filename.split('.')[0]
            if len(file_id) != 6 or not file_id.isdigit():
                continue
            file_paths.append(os.path.join(qari_path, filename))
            surah_labels.append(file_id[:3])
    return file_paths, surah_labels

# --- 3. Load File Paths ---
print("\nScanning files...")
X_train_paths, y_train_raw = get_files_and_metadata(train_qaris)
X_test_paths, y_test_raw = get_files_and_metadata(test_qaris)

n_train_files, n_test_files = len(X_train_paths), len(X_test_paths)
print(f"Training Files: {n_train_files}")
print(f"Test Files:     {n_test_files}")

# --- 4. Encode Labels ---
# The encoder is fitted on the union of both splits so that a surah present in
# only one split still receives an integer id.
label_encoder = LabelEncoder().fit(y_train_raw + y_test_raw)
y_train_encoded, y_test_encoded = (
    label_encoder.transform(raw_labels) for raw_labels in (y_train_raw, y_test_raw)
)
class_names = label_encoder.classes_

In [None]:
class AudioDataset(Dataset):
    """Fixed-length mono waveform dataset.

    Each item is loaded with ``torchaudio.load``, averaged down to one channel,
    resampled to ``target_sr`` when needed, and then cropped or zero-padded on
    the right to exactly ``target_sr * duration_s`` samples.

    Args:
        paths: Sequence of audio file paths.
        labels: Sequence of labels, index-aligned with ``paths``.
        target_sr: Sample rate in Hz that every clip is converted to.
        duration_s: Clip length in seconds after cropping / padding.

    Returns (per item):
        ``(waveform, label)`` where ``waveform`` is a 1-D tensor of length
        ``num_samples``. If the file cannot be loaded, a silent clip and the
        sentinel label ``FAILED_LABEL`` (-1) are returned and a warning is
        emitted; callers should drop such items.
    """

    # Sentinel label marking an item whose audio could not be loaded.
    FAILED_LABEL = -1

    def __init__(self, paths, labels, target_sr=16000, duration_s=5):
        self.paths = paths
        self.labels = labels
        self.target_sr = target_sr
        # int(...) keeps the sample count usable for slicing when duration_s is fractional.
        self.num_samples = int(round(target_sr * duration_s))
        # One Resample transform per source rate, built lazily and reused,
        # instead of rebuilding it for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        try:
            waveform, sr = torchaudio.load(path)
        except Exception as exc:  # not a bare except: keep Ctrl-C / SystemExit working
            warnings.warn(
                f"Could not load {path!r} ({exc!r}); "
                f"returning silence with label {self.FAILED_LABEL}."
            )
            # Must be 1-D like the success path, otherwise batch collation fails.
            return torch.zeros(self.num_samples), self.FAILED_LABEL

        # Down-mix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sr != self.target_sr:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sr, self.target_sr)
                self._resamplers[sr] = resampler
            waveform = resampler(waveform)

        # Crop long clips, right-pad short ones with zeros.
        if waveform.shape[1] > self.num_samples:
            waveform = waveform[:, :self.num_samples]
        else:
            pad_size = self.num_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_size))
        return waveform.squeeze(0), self.labels[idx]

### 2. Model Loading & Feature Extraction

In [None]:
class ContrastiveModel(nn.Module):
    """Backbone plus projection head that produces 128-d embeddings.

    ``forward`` averages the backbone's ``last_hidden_state`` over dim 1 and
    passes the pooled vector through Linear(768, 256) -> ReLU -> Linear(256, 128).
    Attribute names (``wav2vec``, ``projection``) and layer order determine the
    state-dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self, base_model):
        super().__init__()
        self.wav2vec = base_model
        head_layers = [nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 128)]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, x):
        hidden_states = self.wav2vec(x).last_hidden_state
        # Mean-pool over dim 1, then project.
        return self.projection(hidden_states.mean(dim=1))

# Fail fast: check for the checkpoint before downloading / instantiating the
# base model, so a missing file does not cost a model download first.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Missing {MODEL_PATH}")

wav2vec_base = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
model = ContrastiveModel(wav2vec_base).to(DEVICE)
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Inference only from here on: disable dropout and similar train-time behaviour.
model.eval()

def extract_embeddings(paths, labels, desc):
    """Run the notebook-level ``model`` over audio files and collect embeddings.

    Items whose label is the failed-load sentinel ``-1`` (produced by
    ``AudioDataset`` when a file cannot be read) are dropped, so they are not
    treated as a real class by the downstream classifiers.

    Args:
        paths: Sequence of audio file paths.
        labels: Integer labels, index-aligned with ``paths``.
        desc: Text shown on the progress bar and in messages.

    Returns:
        ``(embeddings, labels_out)``: a 2-D NumPy array with one row per kept
        sample and the matching 1-D label array.

    Raises:
        ValueError: If ``paths`` is empty or no sample could be embedded.
    """
    if len(paths) == 0:
        raise ValueError(f"No audio files to embed for '{desc}'.")

    dataset = AudioDataset(paths, labels)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=2)
    embeddings, labels_out = [], []
    skipped = 0

    model.eval()  # idempotent; guards against the model having been left in train mode
    with torch.no_grad():
        for waveforms, labs in tqdm(loader, desc=desc):
            keep = labs != -1
            if not bool(keep.all()):
                skipped += int((~keep).sum())
                waveforms, labs = waveforms[keep], labs[keep]
                if labs.numel() == 0:
                    continue
            waveforms = waveforms.to(DEVICE)
            embeds = model(waveforms)
            embeddings.append(embeds.cpu().numpy())
            labels_out.append(labs.numpy())

    if skipped:
        print(f"[{desc}] Skipped {skipped} file(s) that could not be loaded.")
    if not embeddings:
        raise ValueError(f"No embeddings could be extracted for '{desc}'.")
    return np.vstack(embeddings), np.concatenate(labels_out)

# Reference set: embeddings of the training reciters.
print("Extracting Training Data...")
X_train_embed, y_train_labels = extract_embeddings(
    paths=X_train_paths, labels=y_train_encoded, desc="Train"
)

# Query set: embeddings of the held-out test reciters.
print("Extracting Test Data...")
X_test_embed, y_test_labels = extract_embeddings(
    paths=X_test_paths, labels=y_test_encoded, desc="Test"
)

### 3. Classifier 1: k-Nearest Neighbors (k-NN)

In [None]:
print("Training k-NN...")
# 5 nearest neighbours under cosine distance, fitted on the training embeddings.
knn = KNeighborsClassifier(n_neighbors=5, metric='cosine')
knn.fit(X_train_embed, y_train_labels)
y_pred_knn = knn.predict(X_test_embed)

acc_knn = accuracy_score(y_test_labels, y_pred_knn)
print(f"\n>>> k-NN Accuracy: {acc_knn:.4f}")
# Pass the full label set explicitly: without `labels`, classification_report
# infers classes from y_true/y_pred and raises if their count differs from
# len(target_names), e.g. when a surah is absent from the test split.
print(classification_report(
    y_test_labels,
    y_pred_knn,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
))

### 4. Classifier 2: Multilayer Perceptron (MLP)

In [None]:
print("Training MLP (on extracted embeddings)...")
# Two hidden layers (256, 128), trained on the extracted training embeddings;
# random_state fixes the weight initialisation for reproducibility.
mlp = MLPClassifier(hidden_layer_sizes=(256, 128), max_iter=500, random_state=42)
mlp.fit(X_train_embed, y_train_labels)
y_pred_mlp = mlp.predict(X_test_embed)

acc_mlp = accuracy_score(y_test_labels, y_pred_mlp)
print(f"\n>>> MLP Accuracy: {acc_mlp:.4f}")
# Pass the full label set explicitly: without `labels`, classification_report
# infers classes from y_true/y_pred and raises if their count differs from
# len(target_names), e.g. when a surah is absent from the test split.
print(classification_report(
    y_test_labels,
    y_pred_mlp,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
))

### 5. Comparative Visualization (t-SNE)
We visualize the Test Set embeddings colored by: **Ground Truth**, **k-NN Prediction**, and **MLP Prediction**.

In [None]:
print("Running t-SNE projection on Test Data...")
# t-SNE needs perplexity < n_samples; clamp so a small test set still projects.
tsne_perplexity = min(30, max(1, len(X_test_embed) - 1))
tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity)
X_test_2d = tsne.fit_transform(X_test_embed)

# Setup Plotting Data
df_viz = pd.DataFrame({
    'x': X_test_2d[:, 0],
    'y': X_test_2d[:, 1],
    'Ground Truth': [class_names[i] for i in y_test_labels],
    'k-NN Pred': [class_names[i] for i in y_pred_knn],
    'MLP Pred': [class_names[i] for i in y_pred_mlp]
})

# (hue column, panel title) for each of the three side-by-side panels.
panels = [
    ('Ground Truth', "Ground Truth Labels\n(Target)"),
    ('k-NN Pred', f"k-NN Predictions\n(Accuracy: {acc_knn:.4f})"),
    ('MLP Pred', f"MLP Predictions\n(Accuracy: {acc_mlp:.4f})"),
]
# A fixed hue order gives every class the same colour in all panels. Without it
# seaborn derives the levels (and so the colours) separately from each column,
# so the same surah could be drawn in different colours from panel to panel.
hue_order = list(class_names)

fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)

for ax, (hue_column, panel_title) in zip(axes, panels):
    sns.scatterplot(
        ax=ax, data=df_viz, x='x', y='y',
        hue=hue_column, hue_order=hue_order, palette='husl',
        s=50, alpha=0.7, legend=False,
    )
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.set_xlabel("t-SNE dimension 1")
    ax.set_ylabel("t-SNE dimension 2")

fig.tight_layout()
fig.savefig("comparison_tsne_knn_mlp.png", dpi=300)
plt.show()

print("Visualization saved as 'comparison_tsne_knn_mlp.png'.")