# Klasifikasi Surah Al-Qur'an - Contrastive

### Fine-Tuning Wav2Vec

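Notebook ini melakukan fine-tuning model Wav2Vec2 base (`facebook/wav2vec2-base`) dengan triplet loss (contrastive learning), menggunakan dua learning rate berbeda: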
1. Wav2Vec base dilatih dengan LR = 5e-6.
2. Projection head dilatih dengan LR = 1e-4.

--- 
## 1. Import Pustaka & Pengaturan Awal

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast
from transformers import Wav2Vec2Model
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
from tqdm.notebook import tqdm
import audiomentations as A
import warnings
from collections import defaultdict

# Hide every UserWarning so the notebook output stays readable.
warnings.simplefilter('ignore', category=UserWarning)

--- 
## 2. Memuat Dataset (Dengan Reciter-Wise Split & Metadata Penuh)

In [None]:
DATASET_ROOT = "audio_data_processed"

# Every sub-directory of the dataset root holds the recordings of one reciter (qari).
all_qaris = sorted(
    name
    for name in os.listdir(DATASET_ROOT)
    if os.path.isdir(os.path.join(DATASET_ROOT, name))
)
print(f"Found {len(all_qaris)} unique reciters.")

# Split by reciter so that no voice appears in both training and test data:
# ~70% of the reciters are used for training; the remaining ~30% are halved
# and only one half is kept as the test set (the other half is discarded).
train_qaris, test_val_qaris = train_test_split(all_qaris, random_state=42, test_size=0.30)
test_qaris, _ = train_test_split(test_val_qaris, random_state=42, test_size=0.50)
print(f"Training reciters ({len(train_qaris)}): {train_qaris}")
print(f"Test reciters ({len(test_qaris)}): {test_qaris}")

def get_files_and_metadata(qari_list, root=None):
    """Collect the ``.mp3`` recordings of the given reciters and parse their ids.

    Files are looked up as ``<root>/<qari>/<file_id>.mp3``. The file id must be
    exactly six digits: the first three are used as the surah label and the
    last three as the ayah label. Any other file is skipped.

    Args:
        qari_list: reciter folder names (sub-directories of ``root``).
        root: dataset root directory; defaults to the module-level
            ``DATASET_ROOT`` when ``None``.

    Returns:
        ``(file_paths, metadata)``: two parallel lists. Each metadata dict has
        the keys ``reciter``, ``surah``, ``ayah`` and ``full_ayah_id``
        (``"<surah>-<ayah>"``).
    """
    if root is None:
        root = DATASET_ROOT
    file_paths = []
    metadata = []
    for qari_folder in qari_list:
        qari_path = os.path.join(root, qari_folder)
        # os.listdir order is filesystem-dependent; sort it so the file order
        # (and everything derived from it) is reproducible across machines.
        for filename in sorted(os.listdir(qari_path)):
            if not filename.endswith(".mp3"):
                continue
            file_id = filename.split('.')[0]
            if len(file_id) != 6 or not file_id.isdigit():
                continue
            surah_label, ayah_label = file_id[:3], file_id[3:]
            file_paths.append(os.path.join(qari_path, filename))
            metadata.append({
                "reciter": qari_folder,
                "surah": surah_label,
                "ayah": ayah_label,
                "full_ayah_id": f"{surah_label}-{ayah_label}",
            })
    return file_paths, metadata

X_train_paths, train_metadata = get_files_and_metadata(train_qaris)
X_test_paths, test_metadata = get_files_and_metadata(test_qaris)

# The classification target is the surah parsed from each file id.
y_train_surah_labels = [entry['surah'] for entry in train_metadata]
y_test_surah_labels = [entry['surah'] for entry in test_metadata]

# Fit the encoder on the union of train and test surahs so both map into the
# same integer space.
all_labels_for_encoder = [*y_train_surah_labels, *y_test_surah_labels]
label_encoder = LabelEncoder().fit(all_labels_for_encoder)
class_names = label_encoder.classes_
NUM_CLASSES = len(class_names)

y_train = label_encoder.transform(y_train_surah_labels)
y_test = label_encoder.transform(y_test_surah_labels)

# TripletDataset reads the encoded label from each training metadata record.
for meta, encoded in zip(train_metadata, y_train):
    meta['encoded_surah_label'] = encoded

print(f"\nUkuran data latih: {len(X_train_paths)} file")
print(f"Ukuran data uji: {len(X_test_paths)} file")
print(f"Jumlah kelas (surah): {NUM_CLASSES}")

--- 
## 3. Preprocessing & Data Augmentation

In [None]:
SAMPLE_RATE = 16000
DURATION = 5  # seconds of audio kept per clip


class AudioDataset(Dataset):
    """Load audio files as fixed-length mono waveforms at ``target_sr``.

    Each item is ``(waveform, label)``. ``waveform`` is a 1-D tensor of exactly
    ``target_sr * duration_s`` samples: longer clips keep only their first
    ``duration_s`` seconds, and shorter ones are right-padded with zeros.
    Multi-channel audio is averaged to mono and resampled when its rate
    differs from ``target_sr``. When ``is_train`` is true, random Gaussian
    noise, time-stretch and pitch-shift augmentation is applied before the
    length is fixed.

    If a file cannot be loaded, an error is printed and an all-zero waveform
    with label ``-1`` is returned instead of raising.
    """

    def __init__(self, paths, labels, target_sr=SAMPLE_RATE, duration_s=DURATION, is_train=False):
        self.paths = paths
        self.labels = labels
        self.target_sr = target_sr
        self.num_samples = self.target_sr * duration_s
        self.is_train = is_train
        # Resample transforms keyed by source sample rate. Constructing a
        # Resample module builds its filter kernel, so build one per distinct
        # rate and reuse it instead of rebuilding it for every item.
        self._resamplers = {}
        if self.is_train:
            self.augment = A.Compose([
                A.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
                A.TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
                A.PitchShift(min_semitones=-4, max_semitones=4, p=0.5)
            ])

    def __len__(self):
        return len(self.paths)

    def _get_resampler(self, orig_sr):
        """Return a cached resampler from ``orig_sr`` to ``self.target_sr``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.target_sr)
            self._resamplers[orig_sr] = resampler
        return resampler

    def _fix_length(self, waveform):
        """Truncate or right-zero-pad a ``(channels, time)`` tensor to ``num_samples``."""
        if waveform.shape[1] > self.num_samples:
            return waveform[:, :self.num_samples]
        pad_size = self.num_samples - waveform.shape[1]
        return torch.nn.functional.pad(waveform, (0, pad_size))

    def __getitem__(self, idx):
        path = self.paths[idx]
        label = self.labels[idx]
        try:
            waveform, sr = torchaudio.load(path)
        except Exception as e:
            # Best effort: keep the batch going with a silent clip and label -1.
            print(f"\nError saat memuat file {path}: {e}")
            return torch.zeros(self.num_samples), -1
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != self.target_sr:
            waveform = self._get_resampler(sr)(waveform)
        if self.is_train:
            samples_np = waveform.numpy().squeeze()
            augmented = self.augment(samples=samples_np, sample_rate=self.target_sr)
            # Guarantee a contiguous float32 array: torch.from_numpy rejects
            # negative strides, and the model weights expect float32 input.
            augmented = np.ascontiguousarray(augmented, dtype=np.float32)
            waveform = torch.from_numpy(augmented).unsqueeze(0)
        return self._fix_length(waveform).squeeze(0), label

--- 
## 4. Model & Arsitektur (Contrastive - Fine-Tuned)

In [None]:
# Use the GPU when one is available; models and batches are moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Model akan berjalan di: {device.type}")

# Pretrained Wav2Vec2 backbone. The same instance is wrapped again later for evaluation.
wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
wav2vec_model = wav2vec_model.to(device)

class ContrastiveModel(nn.Module):
    """Backbone plus MLP projection head that produces embeddings for triplet loss.

    The backbone's ``last_hidden_state`` is mean-pooled over dimension 1 (the
    time/frame axis for ``Wav2Vec2Model``) and passed through
    ``Linear -> ReLU -> Linear``.

    Args:
        base_model: module whose ``forward(x)`` returns an object with a
            ``last_hidden_state`` attribute (e.g. ``Wav2Vec2Model``).
        hidden_size: width of ``last_hidden_state``. If ``None``, it is read
            from ``base_model.config.hidden_size`` when present, else 768
            (the wav2vec2-base width the original code hard-coded).
        projection_dim: hidden width of the projection head.
        embedding_dim: size of the output embedding.

    Parameter names (``wav2vec.*``, ``projection.0.*``, ``projection.2.*``)
    are unchanged, so existing checkpoints still load.
    """

    def __init__(self, base_model, hidden_size=None, projection_dim=256, embedding_dim=128):
        super().__init__()
        self.wav2vec = base_model
        if hidden_size is None:
            config = getattr(base_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, embedding_dim),
        )

    def forward(self, x):
        # assumes x is a batch of raw waveforms, (batch, samples) — as produced by AudioDataset
        hidden_states = self.wav2vec(x).last_hidden_state
        # Average over frames. NOTE: frames computed from zero padding are
        # averaged in as well, since no attention mask is available here.
        pooled_output = hidden_states.mean(dim=1)
        return self.projection(pooled_output)

# Training model: the shared backbone wrapped with a new projection head, placed on `device`.
contrastive_net = ContrastiveModel(base_model=wav2vec_model)
contrastive_net = contrastive_net.to(device)

--- 
## 5. Implementasi & Pelatihan

### 5.1 Definisi Triplet Dataset (Logika Hard Mining)

In [None]:
class TripletDataset(Dataset):
    """Wrap a waveform dataset so each item is an (anchor, positive, negative) triplet.

    Sampling rules:
      * positive: the same ayah (``full_ayah_id``) from a different reciter if
        possible, else another file of the same ayah, else the anchor itself;
      * negative: a different surah from the same reciter if possible (a harder
        negative), else any file of a different surah.

    Args:
        paths: audio file paths, parallel to ``metadata``.
        metadata: dicts with ``reciter``, ``surah``, ``full_ayah_id`` and
            ``encoded_surah_label`` keys.
        dataset_class: called as ``dataset_class(paths, labels, is_train=...)``;
            indexing it must return ``(waveform, label)``.
        is_train: forwarded to ``dataset_class`` (enables augmentation there).
    """

    def __init__(self, paths, metadata, dataset_class, is_train=False):
        self.paths = paths
        self.metadata = metadata
        self.encoded_surah_labels = np.array([m['encoded_surah_label'] for m in metadata])
        self.dataset = dataset_class(self.paths, self.encoded_surah_labels, is_train=is_train)
        # Index lookups used by the positive/negative samplers.
        self.ayah_to_indices = defaultdict(list)
        self.reciter_surah_to_indices = defaultdict(lambda: defaultdict(list))
        self.surah_to_indices = defaultdict(list)
        for i, meta in enumerate(self.metadata):
            self.ayah_to_indices[meta['full_ayah_id']].append(i)
            self.reciter_surah_to_indices[meta['reciter']][meta['surah']].append(i)
            self.surah_to_indices[meta['surah']].append(i)
        self.all_surahs = list(self.surah_to_indices.keys())

    def __len__(self):
        return len(self.paths)

    def _find_positive(self, anchor_meta, anchor_index):
        """Pick an index with the same ayah, preferring a different reciter."""
        anchor_ayah_id = anchor_meta['full_ayah_id']
        anchor_reciter = anchor_meta['reciter']
        candidate_indices = self.ayah_to_indices.get(anchor_ayah_id, [])
        different_reciters = [idx for idx in candidate_indices if self.metadata[idx]['reciter'] != anchor_reciter]
        if different_reciters:
            return np.random.choice(different_reciters)
        same_reciter_diff_file = [idx for idx in candidate_indices if idx != anchor_index]
        if same_reciter_diff_file:
            return np.random.choice(same_reciter_diff_file)
        # No other recording of this ayah: reuse the anchor file.
        return anchor_index

    def _find_negative(self, anchor_meta, anchor_index):
        """Pick an index from a different surah, preferring the anchor's reciter.

        Raises:
            ValueError: if the dataset contains no surah other than the anchor's.
        """
        anchor_surah = anchor_meta['surah']
        anchor_reciter = anchor_meta['reciter']
        reciter_surahs = self.reciter_surah_to_indices.get(anchor_reciter, {})
        different_surahs = [surah for surah in reciter_surahs if surah != anchor_surah]
        if different_surahs:
            neg_surah = np.random.choice(different_surahs)
            return np.random.choice(reciter_surahs[neg_surah])
        # Fall back to any reciter. Choose directly among the other surahs
        # (same uniform distribution as rejection sampling) so a single-surah
        # dataset fails loudly instead of looping forever.
        other_surahs = [surah for surah in self.all_surahs if surah != anchor_surah]
        if not other_surahs:
            raise ValueError(
                f"Cannot sample a negative for surah {anchor_surah!r}: "
                "the dataset contains no other surah."
            )
        neg_surah = np.random.choice(other_surahs)
        return np.random.choice(self.surah_to_indices[neg_surah])

    def __getitem__(self, index):
        anchor_meta = self.metadata[index]
        anchor_waveform, _ = self.dataset[index]
        positive_index = self._find_positive(anchor_meta, index)
        positive_waveform, _ = self.dataset[positive_index]
        negative_index = self._find_negative(anchor_meta, index)
        negative_waveform, _ = self.dataset[negative_index]
        return anchor_waveform, positive_waveform, negative_waveform

### 5.2 Menjalankan Eksperimen (Contrastive Loss - Fine-Tuned)

In [None]:
# Hyperparameters
BATCH_SIZE = 24
NUM_WORKERS = 8
PROJECTION_LR = 1e-4
WAV2VEC_LR = 5e-6
NUM_EPOCHS = 100
torch.backends.cudnn.benchmark = True

# Mixed precision (autocast + GradScaler) and pinned memory are only used on
# CUDA. Previously they were hard-wired to 'cuda', which only produced
# fallback warnings on a CPU-only machine; now CPU runs use full precision.
USE_AMP = device.type == 'cuda'

# --- Training Setup ---
triplet_train_dataset = TripletDataset(X_train_paths, train_metadata, AudioDataset, is_train=True)
train_loader = DataLoader(
    triplet_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=USE_AMP,  # pinned host memory only speeds up host->GPU copies
)
if len(train_loader) == 0:
    # Otherwise the per-epoch average below would divide by zero.
    raise ValueError("Training set is empty; no triplets to train on.")

print("Training DataLoader is ready.")

# --- Model & Optimizer --- #
# Two parameter groups: the pretrained backbone gets a small learning rate,
# the new projection head a larger one.
optimizer = optim.Adam([
    {'params': contrastive_net.wav2vec.parameters(), 'lr': WAV2VEC_LR},
    {'params': contrastive_net.projection.parameters(), 'lr': PROJECTION_LR},
])

# --- Loss & Scaler ---
contrastive_loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)
# With enabled=False the scaler is a pass-through (plain backward/step).
scaler = GradScaler(device=device.type, enabled=USE_AMP)

print("\nStarting Training (Contrastive Loss - Fine-Tuned)...")

# --- Early Stopping Variables ---
# NOTE(review): there is no validation loader, so early stopping and
# checkpoint selection monitor the (augmented, randomly sampled) training loss.
patience = 10
epochs_no_improve = 0
best_train_loss = float('inf')
model_save_path = "contrastive_model.pth"

for epoch in range(NUM_EPOCHS):
    contrastive_net.train()
    total_train_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: N/A")

    for anchor, positive, negative in progress_bar:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=device.type, enabled=USE_AMP):
            embed_a = contrastive_net(anchor)
            embed_p = contrastive_net(positive)
            embed_n = contrastive_net(negative)
            loss = contrastive_loss_fn(embed_a, embed_p, embed_n)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_train_loss += batch_loss
        progress_bar.set_description(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {batch_loss:.4f}")

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"   -> Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

    # --- Early Stopping & Saving Logic ---
    if avg_train_loss < best_train_loss:
        best_train_loss = avg_train_loss
        torch.save(contrastive_net.state_dict(), model_save_path)
        print(f"   -> New best model saved to {model_save_path} with loss: {best_train_loss:.4f}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"   -> No improvement for {epochs_no_improve} epoch(s). Best loss remains {best_train_loss:.4f}")

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs without improvement.")
        break

print(f"\nPelatihan contrastive selesai.")
print(f"Model terbaik disimpan di: {model_save_path}")

--- 
## 6. Evaluasi pada Data Uji (Test Set)

### 6.1 Memuat Bobot Model Terbaik

In [None]:
EVALUATION_MODEL_PATH = "contrastive_model.pth"

# NOTE: this wraps the same `wav2vec_model` instance as `contrastive_net`, so
# loading the checkpoint also overwrites the backbone weights inside
# `contrastive_net` (only the projection head is a separate copy).
evaluation_model = ContrastiveModel(wav2vec_model).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# weights_only restricts unpickling to tensors/containers, which is all a
# state_dict needs.
evaluation_model.load_state_dict(
    torch.load(EVALUATION_MODEL_PATH, map_location=device, weights_only=True)
)

print(f"Model weights from {EVALUATION_MODEL_PATH} loaded successfully for evaluation.")

### 6.2 Ekstraksi Embedding & Klasifikasi

In [None]:
def extract_embeddings(model, paths, labels, dataset_class):
    """Embed every file in ``paths`` with ``model`` and stack the results.

    Files are loaded through ``dataset_class(paths, labels, is_train=False)``
    (no augmentation) in unshuffled batches of ``BATCH_SIZE``. The labels each
    batch yields are ignored.

    Returns:
        A NumPy float32 array with one embedding row per file, in ``paths`` order.
    """
    model.eval()
    # Autocast only on CUDA, matching the device the batches are sent to.
    use_amp = device.type == 'cuda'
    dataset = dataset_class(paths, labels, is_train=False)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    embeddings = []
    with torch.no_grad():
        for waveforms, _ in tqdm(loader, desc="Mengekstrak Embedding"):
            waveforms = waveforms.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                embeds = model(waveforms)
            # Under CUDA autocast the Linear head emits float16; upcast so the
            # k-NN / MLP classifiers receive full-precision features.
            embeddings.append(embeds.float().cpu().numpy())

    return np.vstack(embeddings)

print("Mengekstrak embedding final dari data latih dan uji...")
# Embed the train and test splits with the best checkpoint (no augmentation).
X_train_embed = extract_embeddings(
    model=evaluation_model, paths=X_train_paths, labels=y_train, dataset_class=AudioDataset
)
X_test_embed = extract_embeddings(
    model=evaluation_model, paths=X_test_paths, labels=y_test, dataset_class=AudioDataset
)
print("Ekstraksi embedding selesai.")

### 6.3. Klasifikasi dengan k-NN (Baseline)

In [None]:
print("Melatih Classifier k-NN (Baseline)...")
# Majority vote of the 5 nearest training embeddings under cosine distance.
knn_classifier = KNeighborsClassifier(metric='cosine', n_neighbors=5)
y_pred_knn = knn_classifier.fit(X_train_embed, y_train).predict(X_test_embed)
print("Pelatihan k-NN selesai.")

### 6.4. Klasifikasi dengan MLP (Pembanding)

In [None]:
print("Melatih Classifier MLP (Pembanding)...")
# Two hidden layers (256, 128), trained on the frozen embeddings with sklearn's
# built-in early stopping (patience of 10 iterations).
mlp_classifier = MLPClassifier(
    random_state=42,
    hidden_layer_sizes=(256, 128),
    max_iter=500,
    early_stopping=True,
    n_iter_no_change=10,
)
y_pred_mlp = mlp_classifier.fit(X_train_embed, y_train).predict(X_test_embed)
print("Pelatihan MLP selesai.")

--- 
## 7. Hasil Evaluasi Kinerja Model

In [None]:
print("="*50)
print(f"HASIL EVALUASI MODEL DARI: {EVALUATION_MODEL_PATH}")
print("="*50)

# Report every encoded class explicitly. Without `labels`, classification_report
# infers the classes from y_test ∪ y_pred and raises ValueError whenever that
# set is smaller than `class_names` (e.g. a surah missing from the test
# reciters that is also never predicted).
report_labels = np.arange(len(class_names))

for title, y_pred in (
    ("k-Nearest Neighbors (k-NN)", y_pred_knn),
    ("Multilayer Perceptron (MLP)", y_pred_mlp),
):
    print(f"\n--- Kinerja {title} ---")
    print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
    print("\nLaporan Klasifikasi:")
    print(classification_report(
        y_test, y_pred, labels=report_labels, target_names=class_names, zero_division=0
    ))