In [None]:
# SMS Scam Detection - Deep Learning Models
# ==========================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import time
import re
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import Adam
import optuna
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    matthews_corrcoef, roc_auc_score, average_precision_score,
    confusion_matrix, classification_report, roc_curve, precision_recall_curve
)
from tqdm.notebook import tqdm
import joblib
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

# Plot appearance: seaborn "whitegrid" theme and a 12x8 default figure size
sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (12, 8)

# Fetch the NLTK data packages (silently)
for _nltk_package in ('punkt', 'stopwords'):
    nltk.download(_nltk_package, quiet=True)

# Seed NumPy and PyTorch (CPU, plus every CUDA device when available)
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Run on the GPU when one is visible, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Switch to the project root; the relative paths below are resolved from it
project_dir = '/content/drive/MyDrive/sms-scam-detection'
os.chdir(project_dir)

data_dir = "data/processed/"
model_dir = "models/deep_learning/"
results_dir = "results/"

# Make sure every output directory exists before anything is written
os.makedirs(model_dir, exist_ok=True)
for _results_subdir in ("metrics", "visualizations"):
    os.makedirs(os.path.join(results_dir, _results_subdir), exist_ok=True)

class Tokenizer:
    """Whitespace tokenizer that maps words to integer ids.

    Id 0 is reserved for ``<PAD>`` and id 1 for ``<UNK>``. The most frequent
    words seen by :meth:`fit` receive ids starting at 2, so the vocabulary
    never holds more than ``max(vocab_size, 2)`` entries.
    """

    PAD_IDX = 0
    UNK_IDX = 1

    def __init__(self, vocab_size=10000, max_seq_length=100):
        """
        Args:
            vocab_size: Maximum number of vocabulary entries, including the
                two reserved tokens.
            max_seq_length: Length every sequence is truncated/padded to by
                :meth:`pad_sequences`.
        """
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.word_to_idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx_to_word = {0: '<PAD>', 1: '<UNK>'}
        self.word_counts = {}

    @staticmethod
    def _split(text):
        """Split ``text`` on whitespace; non-string input (None, NaN) yields no tokens."""
        return text.split() if isinstance(text, str) else []

    def fit(self, texts):
        """Count the words in ``texts`` and rebuild the vocabulary.

        Counts accumulate across calls, but the word/id tables are rebuilt
        from scratch each time so that a second call cannot leave stale or
        duplicated ids behind.
        """
        for text in texts:
            for word in self._split(text):
                self.word_counts[word] = self.word_counts.get(word, 0) + 1

        self.word_to_idx = {'<PAD>': self.PAD_IDX, '<UNK>': self.UNK_IDX}
        self.idx_to_word = {self.PAD_IDX: '<PAD>', self.UNK_IDX: '<UNK>'}

        # Clamp at zero: a negative slice bound (vocab_size < 2) would keep
        # almost every word instead of none.
        free_slots = max(0, self.vocab_size - 2)
        # sorted() is stable, so equally frequent words keep first-seen order.
        ranked = sorted(self.word_counts.items(), key=lambda item: item[1], reverse=True)

        for idx, (word, _) in enumerate(ranked[:free_slots], start=2):
            self.word_to_idx[word] = idx
            self.idx_to_word[idx] = word

        print(f"Vocabulary size: {len(self.word_to_idx)}")

    def texts_to_sequences(self, texts):
        """Convert each text to a list of word ids; unknown words map to ``<UNK>``."""
        lookup = self.word_to_idx.get
        unk = self.UNK_IDX
        return [[lookup(word, unk) for word in self._split(text)] for text in texts]

    def pad_sequences(self, sequences):
        """Truncate or right-pad every sequence to ``max_seq_length`` with ``<PAD>``."""
        limit = self.max_seq_length
        padded_sequences = []
        for sequence in sequences:
            clipped = sequence[:limit]
            padded_sequences.append(clipped + [self.PAD_IDX] * (limit - len(clipped)))
        return padded_sequences

    def save(self, filepath):
        """Save the tokenizer to a JSON file, creating parent directories if needed."""
        tokenizer_data = {
            'vocab_size': self.vocab_size,
            'max_seq_length': self.max_seq_length,
            'word_to_idx': self.word_to_idx,
            'idx_to_word': {int(k): v for k, v in self.idx_to_word.items()},
            'word_counts': self.word_counts
        }

        # os.makedirs('') raises, so only create a directory when the path has one.
        parent_dir = os.path.dirname(filepath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f)

        print(f"Tokenizer saved to {filepath}")

    @classmethod
    def load(cls, filepath):
        """Load a tokenizer previously written by :meth:`save`."""
        with open(filepath, 'r', encoding='utf-8') as f:
            tokenizer_data = json.load(f)

        tokenizer = cls(tokenizer_data['vocab_size'], tokenizer_data['max_seq_length'])
        tokenizer.word_to_idx = tokenizer_data['word_to_idx']
        # JSON object keys are always strings; restore the integer ids.
        tokenizer.idx_to_word = {int(k): v for k, v in tokenizer_data['idx_to_word'].items()}
        tokenizer.word_counts = tokenizer_data['word_counts']

        print(f"Tokenizer loaded from {filepath}")
        return tokenizer

class TextDataset(Dataset):
    """Tensor-backed dataset that pairs token-id sequences with their labels."""

    def __init__(self, sequences, labels):
        # Token ids are held as int64 and labels as float32 tensors.
        id_tensor = torch.tensor(sequences, dtype=torch.long)
        label_tensor = torch.tensor(labels, dtype=torch.float)
        self.sequences = id_tensor
        self.labels = label_tensor

    def __len__(self):
        """Number of samples (leading dimension of the sequence tensor)."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        sample = self.sequences[idx]
        target = self.labels[idx]
        return sample, target

class TextCNN(nn.Module):
    """CNN-based text classifier with dropout, batch norm and an L2 penalty.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each word embedding.
        max_seq_length: Accepted for interface compatibility; not used by
            the network itself.
        num_filters: Output channels of every convolution.
        filter_sizes: Heights (in tokens) of the convolution kernels.
        num_classes: Size of the output layer (1 logit by default).
        dropout: Base dropout rate; the embedding and pooled-feature
            dropouts use 0.5x and 0.7x of it.
        l2_reg: Coefficient applied by :meth:`l2_penalty`.
    """

    def __init__(self, vocab_size, embedding_dim, max_seq_length, num_filters=128,
                 filter_sizes=(3, 4, 5), num_classes=1, dropout=0.5, l2_reg=0.01):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.l2_reg = l2_reg

        self.embedding_dropout = nn.Dropout(dropout * 0.5)

        # Each kernel spans the full embedding width, so a convolution only
        # slides along the token axis.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_dim))
            for filter_size in filter_sizes
        ])

        self.batch_norms = nn.ModuleList([
            nn.BatchNorm2d(num_filters) for _ in filter_sizes
        ])

        self.dropout1 = nn.Dropout(dropout * 0.7)
        self.dropout2 = nn.Dropout(dropout)

        pooled_dim = num_filters * len(filter_sizes)
        hidden_dim = pooled_dim // 2
        self.fc1 = nn.Linear(pooled_dim, hidden_dim)
        self.fc1_bn = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for token ids ``x``.

        ``x`` is expected to be a ``(batch, seq_len)`` tensor of token ids
        with ``seq_len`` at least as large as the biggest filter size.
        """
        x = self.embedding(x)
        x = self.embedding_dropout(x)
        x = x.unsqueeze(1)  # add the single input channel expected by Conv2d

        conv_outputs = []
        for conv, bn in zip(self.convs, self.batch_norms):
            h = F.relu(bn(conv(x)))
            # Max over all remaining time steps, then drop the two size-1 axes.
            h = F.max_pool2d(h, (h.size(2), 1)).squeeze(3).squeeze(2)
            conv_outputs.append(h)

        x = torch.cat(conv_outputs, 1)
        x = self.dropout1(x)

        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = self.dropout2(x)

        x = self.fc2(x)
        return x

    def l2_penalty(self):
        """Return ``l2_reg`` times the sum of squares of all parameters.

        The squares are summed directly rather than via ``torch.norm(...) ** 2``,
        which avoids a needless square root (and differentiating through it
        for zero-valued parameters such as freshly initialised norm biases).
        """
        squared_sum = sum(param.pow(2).sum() for param in self.parameters())
        return self.l2_reg * squared_sum

class BiLSTMClassifier(nn.Module):
    """BiLSTM-based text classifier with dropout, batch norm and an L2 penalty.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each word embedding.
        hidden_dim: Hidden size of the LSTM per direction.
        num_layers: Requested number of stacked LSTM layers (capped at 2).
        num_classes: Size of the output layer (1 logit by default).
        dropout: Base dropout rate; the embedding and pooled-feature
            dropouts use 0.3x and 0.7x of it.
        l2_reg: Coefficient applied by :meth:`l2_penalty`.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim=128, num_layers=2,
                 num_classes=1, dropout=0.5, l2_reg=0.01):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.l2_reg = l2_reg
        self.hidden_dim = hidden_dim

        self.embedding_dropout = nn.Dropout(dropout * 0.3)

        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=min(num_layers, 2),
            bidirectional=True,
            # Inter-layer dropout is only requested for a stacked LSTM.
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        self.dropout1 = nn.Dropout(dropout * 0.7)
        self.dropout2 = nn.Dropout(dropout)

        lstm_output_dim = hidden_dim * 2  # forward + backward directions
        hidden_fc_dim = hidden_dim

        self.fc1 = nn.Linear(lstm_output_dim, hidden_fc_dim)
        self.fc1_bn = nn.BatchNorm1d(hidden_fc_dim)
        self.fc2 = nn.Linear(hidden_fc_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for token ids ``x``.

        The features fed to the classifier head are the LSTM outputs averaged
        over every time step (padding positions included).

        NOTE(review): this block previously concatenated the final hidden
        states onto the mean-pooled features and then sliced the result back
        to ``2 * hidden_dim``, which discarded those states again. The dead
        computation is removed here; outputs are unchanged. Actually using
        the final states would require ``fc1`` to accept ``4 * hidden_dim``
        inputs, which would not match weights saved from the current layout.
        """
        x = self.embedding(x)
        x = self.embedding_dropout(x)

        output, _ = self.lstm(x)

        x = torch.mean(output, dim=1)
        x = self.dropout1(x)

        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = self.dropout2(x)

        x = self.fc2(x)
        return x

    def l2_penalty(self):
        """Return ``l2_reg`` times the sum of squares of all parameters.

        The squares are summed directly rather than via ``torch.norm(...) ** 2``,
        which avoids a needless square root (and differentiating through it
        for zero-valued parameters).
        """
        squared_sum = sum(param.pow(2).sum() for param in self.parameters())
        return self.l2_reg * squared_sum

def get_weighted_loss_function(train_df):
    """Create a BCE-with-logits loss whose positive class is up-weighted.

    The positive weight is ``n_negative / n_positive`` computed from the
    ``label`` column of ``train_df`` (positives are the rows summed by
    ``train_df['label'].sum()``).

    Args:
        train_df: Training frame with a numeric/boolean ``label`` column.

    Returns:
        ``nn.BCEWithLogitsLoss`` with ``pos_weight`` placed on the
        module-level ``device``.

    Raises:
        ValueError: If the frame contains no positive samples, in which case
            the ratio is undefined.
    """
    n_samples = len(train_df)
    n_positive = train_df['label'].sum()
    n_negative = n_samples - n_positive

    if n_positive <= 0:
        raise ValueError(
            "Cannot derive a positive class weight: the training data "
            "contains no positive samples."
        )

    # Plain Python float: keeps the tensor below from inheriting float64
    # from a NumPy scalar.
    weight_ratio = float(n_negative) / float(n_positive)

    print(f"Class imbalance - Negative: {n_negative}, Positive: {n_positive}, Ratio: {weight_ratio:.2f}")
    print(f"Setting positive class weight to {weight_ratio:.2f}")

    pos_weight = torch.tensor([weight_ratio], dtype=torch.float32, device=device)
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def create_balanced_sampler(y_train):
    """Create a weighted random sampler that balances classes within batches.

    Every sample is weighted by the inverse frequency of its class, so each
    class carries the same total sampling weight. Labels may be any
    array-like (NumPy array, pandas Series, list) and need not be the
    contiguous integers ``0..K-1``.

    Args:
        y_train: One label per training sample.

    Returns:
        ``WeightedRandomSampler`` drawing ``len(y_train)`` samples per epoch.
    """
    labels = np.asarray(y_train).ravel()

    # class_index[i] is the position of labels[i] among the sorted unique
    # labels, so it can index the per-class weights whatever the label values.
    _, class_index, class_counts = np.unique(
        labels, return_inverse=True, return_counts=True
    )
    per_class_weight = 1.0 / class_counts

    samples_weight = torch.from_numpy(per_class_weight[class_index.reshape(-1)]).float()

    return WeightedRandomSampler(samples_weight, len(samples_weight))

def _make_trial_loaders(train_dataset, val_dataset, batch_size, use_balanced_sampling, balanced_sampler):
    """Build the train/validation loaders used by one Optuna trial."""
    if use_balanced_sampling:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=balanced_sampler)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def _run_optuna_trial(trial, model, train_loader, val_loader, criterion,
                      learning_rate, l2_reg, max_epochs, patience=3):
    """Train ``model`` for one trial and return its best validation MCC.

    Training stops early after ``patience`` epochs without improvement, and
    the trial is pruned when Optuna asks for it.
    """
    # NOTE(review): l2_reg is applied twice - as Adam's weight_decay and via
    # model.l2_penalty() in the loss. Kept as-is so the search space keeps
    # its current meaning; confirm this is intended.
    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=l2_reg)

    best_val_mcc = -1
    stale_epochs = 0

    for epoch in range(max_epochs):
        model.train()

        for inputs, labels in train_loader:
            # A trailing batch of one sample cannot be normalised by
            # BatchNorm1d in training mode, so it is skipped.
            if inputs.size(0) < 2:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # squeeze(-1) drops only the class axis; a bare squeeze() would
            # also collapse a batch axis of size 1.
            outputs = model(inputs).squeeze(-1)

            loss = criterion(outputs, labels) + model.l2_penalty()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs).squeeze(-1)
                val_preds.extend((torch.sigmoid(outputs) > 0.5).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_mcc = matthews_corrcoef(val_labels, val_preds)

        if val_mcc > best_val_mcc:
            best_val_mcc = val_mcc
            stale_epochs = 0
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            break

        trial.report(val_mcc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return best_val_mcc


def objective_cnn(trial, train_dataset, val_dataset, vocab_size, max_seq_length, weighted_criterion, balanced_sampler, max_epochs=8):
    """Objective function for Optuna CNN optimization.

    Samples the hyper-parameters, trains a ``TextCNN`` for at most
    ``max_epochs`` epochs and returns the best validation MCC reached.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial.
    """
    batch_size = trial.suggest_categorical('batch_size', [16, 32])
    embedding_dim = trial.suggest_categorical('embedding_dim', [50, 100])
    num_filters = trial.suggest_categorical('num_filters', [64, 128])
    dropout = trial.suggest_float('dropout', 0.3, 0.7)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 5e-3, log=True)
    l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-2, log=True)
    use_balanced_sampling = trial.suggest_categorical('use_balanced_sampling', [True, False])

    train_loader, val_loader = _make_trial_loaders(
        train_dataset, val_dataset, batch_size, use_balanced_sampling, balanced_sampler
    )

    model = TextCNN(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        max_seq_length=max_seq_length,
        num_filters=num_filters,
        dropout=dropout,
        l2_reg=l2_reg
    ).to(device)

    return _run_optuna_trial(
        trial, model, train_loader, val_loader, weighted_criterion,
        learning_rate, l2_reg, max_epochs
    )


def objective_bilstm(trial, train_dataset, val_dataset, vocab_size, weighted_criterion, balanced_sampler, max_epochs=8):
    """Objective function for Optuna BiLSTM optimization.

    Samples the hyper-parameters, trains a ``BiLSTMClassifier`` for at most
    ``max_epochs`` epochs and returns the best validation MCC reached.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial.
    """
    batch_size = trial.suggest_categorical('batch_size', [16, 32])
    embedding_dim = trial.suggest_categorical('embedding_dim', [50, 100])
    hidden_dim = trial.suggest_categorical('hidden_dim', [64, 128])
    num_layers = trial.suggest_int('num_layers', 1, 2)
    dropout = trial.suggest_float('dropout', 0.3, 0.7)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 5e-3, log=True)
    l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-2, log=True)
    use_balanced_sampling = trial.suggest_categorical('use_balanced_sampling', [True, False])

    train_loader, val_loader = _make_trial_loaders(
        train_dataset, val_dataset, batch_size, use_balanced_sampling, balanced_sampler
    )

    model = BiLSTMClassifier(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_classes=1,
        dropout=dropout,
        l2_reg=l2_reg
    ).to(device)

    return _run_optuna_trial(
        trial, model, train_loader, val_loader, weighted_criterion,
        learning_rate, l2_reg, max_epochs
    )

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, model_name):
    """Train the model with early stopping and return training statistics.

    Early stopping monitors validation MCC with a patience of 3 epochs; the
    weights of the best epoch are restored before returning.

    Args:
        model: Network producing one logit per sample. If it exposes
            ``l2_penalty()``, that term is added to both train and
            validation losses.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        model_name: Name used in progress messages.

    Returns:
        ``(model, training_stats)`` where ``training_stats`` maps metric
        names to per-epoch lists.
    """
    print(f"Training {model_name} on {device} with early stopping...")

    training_stats = {
        'train_loss': [], 'train_acc': [], 'train_f1': [], 'train_mcc': [],
        'val_loss': [], 'val_acc': [], 'val_f1': [], 'val_mcc': [],
        'val_prec': [], 'val_rec': []
    }

    best_val_mcc = -1
    patience_counter = 0
    patience = 3
    best_model_state = None
    has_l2_penalty = hasattr(model, 'l2_penalty')

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        train_seen = 0
        train_preds = []
        train_labels = []

        for inputs, labels in tqdm(train_loader, desc="Training"):
            # A trailing batch of one sample cannot be normalised by
            # BatchNorm1d in training mode, so it is skipped.
            if inputs.size(0) < 2:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # squeeze(-1) drops only the class axis; a bare squeeze() would
            # also collapse a batch axis of size 1.
            outputs = model(inputs).squeeze(-1)

            loss = criterion(outputs, labels)
            if has_l2_penalty:
                loss = loss + model.l2_penalty()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            train_seen += inputs.size(0)
            train_preds.extend((torch.sigmoid(outputs) > 0.5).cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        # Average over the samples actually trained on this epoch.
        train_loss = train_loss / max(train_seen, 1)
        train_acc = accuracy_score(train_labels, train_preds)
        train_f1 = f1_score(train_labels, train_preds)
        train_mcc = matthews_corrcoef(train_labels, train_preds)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs).squeeze(-1)

                loss = criterion(outputs, labels)
                if has_l2_penalty:
                    loss = loss + model.l2_penalty()

                val_loss += loss.item() * inputs.size(0)
                val_preds.extend((torch.sigmoid(outputs) > 0.5).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds)
        val_mcc = matthews_corrcoef(val_labels, val_preds)
        val_prec = precision_score(val_labels, val_preds)
        val_rec = recall_score(val_labels, val_preds)

        # Update training statistics
        training_stats['train_loss'].append(train_loss)
        training_stats['train_acc'].append(train_acc)
        training_stats['train_f1'].append(train_f1)
        training_stats['train_mcc'].append(train_mcc)
        training_stats['val_loss'].append(val_loss)
        training_stats['val_acc'].append(val_acc)
        training_stats['val_f1'].append(val_f1)
        training_stats['val_mcc'].append(val_mcc)
        training_stats['val_prec'].append(val_prec)
        training_stats['val_rec'].append(val_rec)

        # Early stopping check
        if val_mcc > best_val_mcc:
            best_val_mcc = val_mcc
            patience_counter = 0
            # Clone every tensor: state_dict() values share storage with the
            # live parameters, so a shallow dict copy would keep changing as
            # training continues and the "best" weights would be lost.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best validation MCC: {best_val_mcc:.4f}")
        else:
            patience_counter += 1

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, Train MCC: {train_mcc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val MCC: {val_mcc:.4f}")
        print(f"Val Precision: {val_prec:.4f}, Val Recall: {val_rec:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation MCC: {best_val_mcc:.4f}")

    return model, training_stats

def evaluate_model(model, test_loader, criterion, device):
    """Evaluate the model on the test set.

    Predictions use a 0.5 threshold on the sigmoid of the model's logits.

    Args:
        model: Network producing one logit per sample.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with the average loss, accuracy, precision, recall, F1, MCC,
        ROC AUC, PR AUC, confusion matrix, classification report (as a
        dict) and the raw probabilities, predictions and labels.
    """
    print("Evaluating model on test set...")

    model.eval()

    test_loss = 0.0
    test_preds = []
    test_probs = []
    test_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            inputs, labels = inputs.to(device), labels.to(device)
            # squeeze(-1) drops only the class axis; a bare squeeze() would
            # turn a single-sample batch into a 0-d tensor that can neither
            # be compared with its labels nor iterated by extend() below.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)

            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(int)

            test_probs.extend(probs)
            test_preds.extend(preds)
            test_labels.extend(labels.cpu().numpy())

    test_loss = test_loss / len(test_loader.dataset)

    test_acc = accuracy_score(test_labels, test_preds)
    test_prec = precision_score(test_labels, test_preds)
    test_rec = recall_score(test_labels, test_preds)
    test_f1 = f1_score(test_labels, test_preds)
    test_mcc = matthews_corrcoef(test_labels, test_preds)
    test_roc_auc = roc_auc_score(test_labels, test_probs)
    test_pr_auc = average_precision_score(test_labels, test_probs)
    test_cm = confusion_matrix(test_labels, test_preds)
    test_report = classification_report(test_labels, test_preds, output_dict=True)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {test_prec:.4f}")
    print(f"Test Recall: {test_rec:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")
    print(f"Test MCC: {test_mcc:.4f}")
    print(f"Test ROC AUC: {test_roc_auc:.4f}")
    print(f"Test PR AUC: {test_pr_auc:.4f}")
    print(f"Test Confusion Matrix:\n{test_cm}")
    print(f"Test Classification Report:\n{classification_report(test_labels, test_preds)}")

    return {
        'test_loss': test_loss,
        'test_acc': test_acc,
        'test_prec': test_prec,
        'test_rec': test_rec,
        'test_f1': test_f1,
        'test_mcc': test_mcc,
        'test_roc_auc': test_roc_auc,
        'test_pr_auc': test_pr_auc,
        'test_cm': test_cm,
        'test_report': test_report,
        'test_probs': test_probs,
        'test_preds': test_preds,
        'test_labels': test_labels
    }

def visualize_training(training_stats, model_name):
    """Plot per-epoch training curves and save them as PNG files.

    Two figures are written to the ``visualizations`` folder under
    ``results_dir`` and shown: a 1x4 grid comparing train/validation loss,
    accuracy, F1 and MCC, and a single plot of validation precision and
    recall.
    """
    # (stats suffix, train legend, validation legend, y-axis label, title)
    panels = (
        ('loss', 'Train Loss', 'Validation Loss', 'Loss', 'Training and Validation Loss'),
        ('acc', 'Train Accuracy', 'Validation Accuracy', 'Score', 'Training and Validation Accuracy'),
        ('f1', 'Train F1', 'Validation F1', 'Score', 'Training and Validation F1 Score'),
        ('mcc', 'Train MCC', 'Validation MCC', 'Score', 'Training and Validation MCC'),
    )

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(24, 6))

    for axis, panel in zip((ax1, ax2, ax3, ax4), panels):
        metric, train_legend, val_legend, y_label, title = panel
        axis.plot(training_stats[f'train_{metric}'], label=train_legend)
        axis.plot(training_stats[f'val_{metric}'], label=val_legend)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'visualizations', f"{model_name}_training_progress.png"))
    plt.show()

    # Second figure: validation precision against recall over the epochs
    plt.figure(figsize=(10, 6))
    for stats_key, legend in (('val_prec', 'Validation Precision'),
                              ('val_rec', 'Validation Recall')):
        plt.plot(training_stats[stats_key], label=legend)
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Validation Precision and Recall')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'visualizations', f"{model_name}_precision_recall.png"))
    plt.show()

def visualize_evaluation(eval_results, model_name):
    """Plot the confusion matrix, ROC curve and precision-recall curve.

    ``eval_results`` must provide ``test_cm``, ``test_labels``,
    ``test_probs``, ``test_roc_auc`` and ``test_pr_auc``. Each figure is
    saved to the ``visualizations`` folder under ``results_dir`` and shown.
    """

    def finish_figure(suffix):
        # Shared tail of every figure: tidy the layout, write the PNG, display.
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, 'visualizations', f"{model_name}_{suffix}.png"))
        plt.show()

    def frame_unit_axes(x_label, y_label, title, legend_loc):
        # Axis limits, labels, legend and grid shared by the two curve plots.
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend(loc=legend_loc)
        plt.grid(True, linestyle='--', alpha=0.7)

    # Confusion matrix heat map
    class_names = ['Legitimate', 'Scam']
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        eval_results['test_cm'],
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names
    )
    plt.title(f'{model_name} - Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    finish_figure('confusion_matrix')

    # ROC curve with the chance diagonal
    plt.figure(figsize=(8, 6))
    fpr, tpr, _ = roc_curve(eval_results['test_labels'], eval_results['test_probs'])
    roc_legend = f'ROC curve (AUC = {eval_results["test_roc_auc"]:.4f})'
    plt.plot(fpr, tpr, lw=2, label=roc_legend)
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    frame_unit_axes('False Positive Rate', 'True Positive Rate',
                    f'{model_name} - ROC Curve', 'lower right')
    finish_figure('roc_curve')

    # Precision-recall curve
    plt.figure(figsize=(8, 6))
    precision, recall, _ = precision_recall_curve(eval_results['test_labels'], eval_results['test_probs'])
    pr_legend = f'PR curve (AP = {eval_results["test_pr_auc"]:.4f})'
    plt.step(recall, precision, where='post', lw=2, label=pr_legend)
    frame_unit_axes('Recall', 'Precision',
                    f'{model_name} - Precision-Recall Curve', 'lower left')
    finish_figure('pr_curve')

def _clean_sms_text(text):
    """Normalise one raw message; non-string input yields an empty string.

    Lower-cases the text, strips URLs, e-mail-like tokens, standalone runs of
    10+ digits and non-ASCII characters, then collapses whitespace. Used for
    both the training frames and the ad-hoc example messages so the two
    preprocessing paths cannot drift apart.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'\S+@\S+', '', text)
    text = re.sub(r'\b\d{10,}\b', '', text)
    text = re.sub(r'[^\x00-\x7F]+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def _make_data_loaders(train_dataset, val_dataset, test_dataset, batch_size,
                       use_balanced_sampling, balanced_sampler, label):
    """Build (train, val, test) loaders; train uses the sampler or shuffling."""
    if use_balanced_sampling:
        print(f"Using balanced sampling for {label} training...")
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=balanced_sampler)
    else:
        print(f"Using normal sampling for {label} training...")
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


def _train_and_evaluate(model, optimizer, criterion, loaders, num_epochs, tag, label):
    """Train, save, visualise and test one model.

    ``tag`` is the lowercase identifier used for file names and passed to
    train_model / the visualisation functions; ``label`` is the name used in
    progress messages. Returns (model, training_stats, eval_results), with the
    wall-clock training time stored under eval_results['training_time'].
    """
    train_loader, val_loader, test_loader = loaders

    print(model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {trainable:,}")

    start_time = time.time()
    model, training_stats = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs, device, tag
    )
    training_time = time.time() - start_time
    print(f"{label} training completed in {training_time:.2f} seconds")

    model_path = os.path.join(model_dir, f'{tag}_model.pt')
    torch.save(model.state_dict(), model_path)
    print(f"{label} model saved to {model_path}")

    visualize_training(training_stats, tag)
    eval_results = evaluate_model(model, test_loader, criterion, device)
    eval_results['training_time'] = training_time
    visualize_evaluation(eval_results, tag)
    return model, training_stats, eval_results


def _results_row(display_name, training_stats, eval_results):
    """Return a one-row DataFrame of last-epoch train/val and test metrics."""
    return pd.DataFrame({
        'Model': [display_name],
        'Training Time (s)': [eval_results['training_time']],
        'Train Accuracy': [training_stats['train_acc'][-1]],
        'Val Accuracy': [training_stats['val_acc'][-1]],
        'Test Accuracy': [eval_results['test_acc']],
        'Train F1': [training_stats['train_f1'][-1]],
        'Val F1': [training_stats['val_f1'][-1]],
        'Test F1': [eval_results['test_f1']],
        'Train MCC': [training_stats['train_mcc'][-1]],
        'Val MCC': [training_stats['val_mcc'][-1]],
        'Test MCC': [eval_results['test_mcc']],
        'Test Precision': [eval_results['test_prec']],
        'Test Recall': [eval_results['test_rec']],
        'Test ROC AUC': [eval_results['test_roc_auc']],
        'Test PR AUC': [eval_results['test_pr_auc']],
    })


def _plot_model_comparison(all_models_df):
    """Save and show grouped metric bars and a training-time bar chart."""
    metrics = ['Test F1', 'Test MCC', 'Test ROC AUC', 'Test PR AUC']
    melted_df = pd.melt(all_models_df, id_vars=['Model'], value_vars=metrics, var_name='Metric', value_name='Score')

    plt.figure(figsize=(14, 8))
    sns.barplot(x='Model', y='Score', hue='Metric', data=melted_df)
    plt.title('Performance Comparison Across All Models')
    plt.ylim(0, 1)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'visualizations', 'all_models_performance_comparison.png'))
    plt.show()

    # Training time comparison
    plt.figure(figsize=(12, 6))
    sns.barplot(x='Model', y='Training Time (s)', data=all_models_df)
    plt.title('Training Time Comparison')
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'visualizations', 'all_models_training_time.png'))
    plt.show()


def _print_example_predictions(named_models, tokenizer, messages):
    """Print a Scam/Legitimate verdict from each model for each message."""
    for model_name, model in named_models:
        print(f"\n=== {model_name} Predictions ===")
        model.eval()

        for i, text in enumerate(messages):
            # Same cleaning as the training data, then tokenize and pad
            cleaned_text = _clean_sms_text(text)
            sequence = tokenizer.texts_to_sequences([cleaned_text])
            padded_sequence = tokenizer.pad_sequences(sequence)
            input_tensor = torch.tensor(padded_sequence, dtype=torch.long).to(device)

            with torch.no_grad():
                output = model(input_tensor).squeeze()
                # Model output is treated as a single logit; sigmoid gives P(scam)
                prob = torch.sigmoid(output).item()
                prediction = "Scam" if prob > 0.5 else "Legitimate"

            print(f"\nExample {i+1}: {text}")
            print(f"Prediction: {prediction} (Confidence: {prob:.4f})")


def main():
    """Main execution function.

    Pipeline:
      1. Load train/val/test CSVs from ``data_dir`` and report class balance.
      2. Ensure a 'cleaned_text' column exists, then fit and save the tokenizer.
      3. Tune TextCNN and BiLSTM hyperparameters with Optuna (10 trials each)
         and persist both studies.
      4. Retrain each model with its best parameters, save the weights, and
         produce training/evaluation plots.
      5. Write a comparison CSV (merged with baseline ML results when that
         file exists), plot the comparison, and print predictions for a few
         example messages.
    """
    # Load data
    train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))
    val_df = pd.read_csv(os.path.join(data_dir, "val.csv"))
    test_df = pd.read_csv(os.path.join(data_dir, "test.csv"))
    frames = (train_df, val_df, test_df)

    print(f"Loaded data: Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    # Print class distribution
    print("\nClass Distribution:")
    for name, df in [("Training", train_df), ("Validation", val_df), ("Test", test_df)]:
        print(f"{name} Set:")
        print(df['label'].value_counts(normalize=True) * 100)

    # Calculate imbalance ratio
    train_neg_count = (train_df['label'] == 0).sum()
    train_pos_count = (train_df['label'] == 1).sum()
    imbalance_ratio = train_neg_count / train_pos_count
    print(f"\nImbalance ratio (negative:positive): {imbalance_ratio:.2f}:1")

    # Ensure cleaned_text column exists
    if 'cleaned_text' not in train_df.columns:
        print("Adding cleaned_text column...")
        for df in frames:
            df['cleaned_text'] = df['message'].apply(_clean_sms_text)

    # read_csv parses empty fields as NaN, so a message that was cleaned down
    # to "" comes back as a float; hand the tokenizer strings only.
    for df in frames:
        df['cleaned_text'] = df['cleaned_text'].fillna('')

    # Prepare data for deep learning models
    vocab_size = 10000
    max_seq_length = 100
    num_epochs = 10

    # Create and fit tokenizer (training split only)
    tokenizer = Tokenizer(vocab_size=vocab_size, max_seq_length=max_seq_length)
    tokenizer.fit(train_df['cleaned_text'].tolist())

    # Save tokenizer
    tokenizer_path = os.path.join(model_dir, 'tokenizer.json')
    tokenizer.save(tokenizer_path)

    def encode(df):
        # Tokenize and pad one split
        return tokenizer.pad_sequences(tokenizer.texts_to_sequences(df['cleaned_text'].tolist()))

    X_train_seq = encode(train_df)
    X_val_seq = encode(val_df)
    X_test_seq = encode(test_df)

    # Get labels
    y_train = train_df['label'].values
    y_val = val_df['label'].values
    y_test = test_df['label'].values

    # Create datasets
    train_dataset = TextDataset(X_train_seq, y_train)
    val_dataset = TextDataset(X_val_seq, y_val)
    test_dataset = TextDataset(X_test_seq, y_test)

    print(f"Prepared datasets - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Class imbalance handling
    print("\n===== Class Imbalance Handling =====")
    weighted_criterion = get_weighted_loss_function(train_df)
    balanced_sampler = create_balanced_sampler(y_train)

    # Hyperparameter optimization
    print("Optimizing CNN hyperparameters with Optuna...")
    study_cnn = optuna.create_study(direction='maximize', pruner=optuna.pruners.MedianPruner())
    study_cnn.optimize(
        lambda trial: objective_cnn(trial, train_dataset, val_dataset, vocab_size, max_seq_length, weighted_criterion, balanced_sampler),
        n_trials=10
    )

    print("Best CNN parameters:", study_cnn.best_params)
    print("Best CNN validation MCC score:", study_cnn.best_value)

    print("Optimizing BiLSTM hyperparameters with Optuna...")
    study_bilstm = optuna.create_study(direction='maximize', pruner=optuna.pruners.MedianPruner())
    study_bilstm.optimize(
        lambda trial: objective_bilstm(trial, train_dataset, val_dataset, vocab_size, weighted_criterion, balanced_sampler),
        n_trials=10
    )

    print("Best BiLSTM parameters:", study_bilstm.best_params)
    print("Best BiLSTM validation MCC score:", study_bilstm.best_value)

    # Save studies
    optuna_dir = os.path.join(results_dir, 'optuna_results')
    os.makedirs(optuna_dir, exist_ok=True)
    joblib.dump(study_cnn, os.path.join(optuna_dir, 'cnn_study.pkl'))
    joblib.dump(study_bilstm, os.path.join(optuna_dir, 'bilstm_study.pkl'))

    # Train and evaluate CNN with best parameters
    best_cnn_params = study_cnn.best_params
    cnn_loaders = _make_data_loaders(
        train_dataset, val_dataset, test_dataset,
        best_cnn_params['batch_size'],
        best_cnn_params.get('use_balanced_sampling', False),
        balanced_sampler, 'CNN',
    )

    cnn_model = TextCNN(
        vocab_size=vocab_size,
        embedding_dim=best_cnn_params['embedding_dim'],
        max_seq_length=max_seq_length,
        num_filters=best_cnn_params['num_filters'],
        filter_sizes=(3, 4, 5),
        num_classes=1,
        dropout=best_cnn_params['dropout']
    ).to(device)
    cnn_optimizer = Adam(cnn_model.parameters(), lr=best_cnn_params['learning_rate'])

    cnn_model, cnn_training_stats, cnn_eval_results = _train_and_evaluate(
        cnn_model, cnn_optimizer, weighted_criterion, cnn_loaders, num_epochs, 'cnn', 'CNN'
    )

    # Train and evaluate BiLSTM with best parameters
    best_bilstm_params = study_bilstm.best_params
    bilstm_loaders = _make_data_loaders(
        train_dataset, val_dataset, test_dataset,
        best_bilstm_params['batch_size'],
        best_bilstm_params.get('use_balanced_sampling', False),
        balanced_sampler, 'BiLSTM',
    )

    bilstm_model = BiLSTMClassifier(
        vocab_size=vocab_size,
        embedding_dim=best_bilstm_params['embedding_dim'],
        hidden_dim=best_bilstm_params['hidden_dim'],
        num_layers=best_bilstm_params['num_layers'],
        num_classes=1,
        dropout=best_bilstm_params['dropout'],
        l2_reg=best_bilstm_params['l2_reg']
    ).to(device)
    bilstm_optimizer = Adam(
        bilstm_model.parameters(),
        lr=best_bilstm_params['learning_rate'],
        weight_decay=best_bilstm_params['l2_reg'],
    )

    bilstm_model, bilstm_training_stats, bilstm_eval_results = _train_and_evaluate(
        bilstm_model, bilstm_optimizer, weighted_criterion, bilstm_loaders, num_epochs, 'bilstm', 'BiLSTM'
    )

    # Save results
    cnn_results_df = _results_row('TextCNN', cnn_training_stats, cnn_eval_results)
    bilstm_results_df = _results_row('BiLSTM', bilstm_training_stats, bilstm_eval_results)

    # Combine results. Each per-model frame is indexed 0, so renumber the rows
    # to avoid duplicate index labels in the frame handed to sorting/plotting.
    dl_results = pd.concat([cnn_results_df, bilstm_results_df], ignore_index=True)

    # Load baseline ML results if available
    baseline_ml_results_path = os.path.join(results_dir, 'metrics', 'baseline_ml_results.csv')
    if os.path.exists(baseline_ml_results_path):
        baseline_df = pd.read_csv(baseline_ml_results_path)
        print("\nBaseline ML results loaded.")
        all_models_df = pd.concat([baseline_df, dl_results], ignore_index=True)
    else:
        print("\nNo baseline ML results found.")
        all_models_df = dl_results

    # Save combined results
    all_models_df.to_csv(os.path.join(results_dir, 'metrics', 'all_models_comparison.csv'), index=False)

    print("\nAll Models Comparison:")
    print(all_models_df[['Model', 'Test MCC', 'Test F1', 'Test ROC AUC', 'Test PR AUC', 'Training Time (s)']].sort_values('Test MCC', ascending=False))

    # Visualize comparison
    if len(all_models_df) > 1:
        _plot_model_comparison(all_models_df)

    # Test example predictions
    example_messages = [
        "Congratulations! You've won a ugx100000 gift card. Click here to claim: www.example.com",
        "Your account has been suspended. Please verify your identity by sending your PIN to this number.",
        "Hi, just checking if we're still meeting for lunch tomorrow at 12?",
        "URGENT: Your payment of ugx55000 has been processed. If this was not you, call immediately: 1-800-555-1234",
        "Your package will be delivered tomorrow between 10am and 2pm. No signature required."
    ]

    print("\n=== Testing Models on Example Messages ===")
    _print_example_predictions(
        [('TextCNN', cnn_model), ('BiLSTM', bilstm_model)],
        tokenizer,
        example_messages,
    )

    print("\nDeep learning models training and evaluation completed successfully!")
    print(f"Results saved to: {results_dir}")

# Entry point: run the full pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()