# Advanced SMS Spam Detection with DistilBERT

This notebook implements a comprehensive SMS spam detection system using advanced preprocessing techniques and DistilBERT fine-tuning. The implementation follows best practices for mobile deployment and includes hyperparameter optimization.

## 1. Setup and Dependencies

First, we'll install and import all required libraries for our SMS spam detection pipeline.

In [1]:
# Install required packages
!pip install transformers torch pandas numpy scikit-learn optuna nltk spacy emoji regex tqdm matplotlib seaborn



In [2]:
# Import required libraries
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # Import AdamW from torch.optim instead of transformers
from transformers import (
    DistilBertTokenizer,
    DistilBertModel,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    roc_auc_score, confusion_matrix
)
import optuna
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
import emoji
import logging
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Download required NLTK data
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\bdcalling123\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\bdcalling123\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\bdcalling123\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\bdcalling123\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

## 2. Load and Explore SMS Data

We'll load the SMS dataset and perform initial exploration to understand its characteristics.

In [3]:
# Load the SMS dataset
df = pd.read_csv("Datasets/spam (1).csv", encoding='latin1')

# Basic data exploration
print("Dataset shape:", df.shape)
print("\nFirst few rows:")
print(df.head(10))

# Class distribution
print("\nClass distribution:")
print(df['v1'].value_counts(normalize=True))

# Basic statistics
print("\nMessage length statistics:")
df['length'] = df['v2'].str.len()
print(df.groupby('v1')['length'].describe())

Dataset shape: (5572, 5)

First few rows:
     v1                                                 v2 Unnamed: 2  \
0   ham  Go until jurong point, crazy.. Available only ...        NaN   
1   ham                      Ok lar... Joking wif u oni...        NaN   
2  spam  Free entry in 2 a wkly comp to win FA Cup fina...        NaN   
3   ham  U dun say so early hor... U c already then say...        NaN   
4   ham  Nah I don't think he goes to usf, he lives aro...        NaN   
5  spam  FreeMsg Hey there darling it's been 3 week's n...        NaN   
6   ham  Even my brother is not like to speak with me. ...        NaN   
7   ham  As per your request 'Melle Melle (Oru Minnamin...        NaN   
8  spam  WINNER!! As a valued network customer you have...        NaN   
9  spam  Had your mobile 11 months or more? U R entitle...        NaN   

  Unnamed: 3 Unnamed: 4  
0        NaN        NaN  
1        NaN        NaN  
2        NaN        NaN  
3        NaN        NaN  
4        NaN        NaN 

## 3. Text Preprocessing Pipeline

Implementing a comprehensive preprocessing pipeline specifically designed for SMS data.

In [4]:
class SMSPreprocessor:
    def __init__(self):
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))
        
        # Common SMS abbreviations
        self.sms_dict = {
            "u": "you", "ur": "your", "2": "to", "4": "for",
            "b4": "before", "gr8": "great", "luv": "love",
            "msg": "message", "txt": "text", "asap": "as soon as possible"
        }
    
    def preprocess(self, text):
        """
        Complete preprocessing pipeline for SMS messages
        """
        # Stage 1: Initial Text Cleaning
        text = self._initial_cleaning(text)
        
        # Stage 2: SMS-Specific Preprocessing
        text = self._sms_specific_preprocessing(text)
        
        # Stage 3: Advanced Linguistic Processing
        text = self._linguistic_processing(text)
        
        return text
    
    def _initial_cleaning(self, text):
        # Convert to lowercase
        text = text.lower()
        
        # Replace URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '<url>', text, flags=re.MULTILINE)
        
        # Replace email addresses
        text = re.sub(r'\S+@\S+', '<email>', text)
        
        # Replace phone numbers
        text = re.sub(r'\+?[1-9][0-9\-\(\)]{8,}', '<phone>', text)
        
        # Handle emoji
        text = emoji.replace_emoji(text, '<emoji>')
        
        # Normalize whitespace
        text = ' '.join(text.split())
        
        return text
    
    def _sms_specific_preprocessing(self, text):
        # Expand abbreviations
        words = text.split()
        words = [self.sms_dict.get(word, word) for word in words]
        text = ' '.join(words)
        
        # Handle repeated characters (e.g., "sooooo" → "so")
        text = re.sub(r'(.)\1{2,}', r'\1\1', text)
        
        # Normalize numbers
        text = re.sub(r'\$\d+(\.\d{2})?', '<money>', text)
        text = re.sub(r'\d+%', '<percentage>', text)
        
        return text
    
    def _linguistic_processing(self, text):
        # Tokenization
        tokens = word_tokenize(text)
        
        # Selective stop word removal (keep potentially important ones)
        important_stops = {'not', 'no', 'never', 'none', 'free', 'click'}
        tokens = [token for token in tokens if token not in self.stop_words or token in important_stops]
        
        # Light lemmatization (only for verbs and nouns)
        tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        
        return ' '.join(tokens)

# Initialize preprocessor
preprocessor = SMSPreprocessor()

# Apply preprocessing to the dataset
df['processed_text'] = df['v2'].apply(preprocessor.preprocess)

# Show some examples
print("Original vs Processed Text Examples:")
for i in range(3):
    print(f"\nOriginal: {df['v2'].iloc[i]}")
    print(f"Processed: {df['processed_text'].iloc[i]}")

Original vs Processed Text Examples:

Original: Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
Processed: go jurong point , crazy .. available bugis n great world la e buffet .. cine got amore wat ..

Original: Ok lar... Joking wif u oni...
Processed: ok lar .. joking wif oni ..

Original: Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
Processed: free entry wkly comp win fa cup final tkts 21st may 2005. text fa 87121 receive entry question ( std text rate ) & c 's apply 0 < phone > over18 's


## 4. Build DistilBERT Model

Now we'll create our custom model architecture using DistilBERT as the base model.

In [5]:
class SMSSpamClassifier(nn.Module):
    def __init__(self, dropout_rate=0.3):
        super(SMSSpamClassifier, self).__init__()
        
        # Load pre-trained DistilBERT
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        
        # Freeze some layers for efficiency (optional)
        for param in self.distilbert.parameters():
            param.requires_grad = False
        
        # Unfreeze the last 2 transformer layers
        for layer in self.distilbert.transformer.layer[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
        
        # Classification head
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.distilbert.config.hidden_size, 2)
        
        # Ensure model parameters are float32
        self.to(torch.float32)
        
        # Attention visualization layer
        self.attention_weights = None
    
    def forward(self, input_ids, attention_mask):
        # Get DistilBERT outputs
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        )
        
        # Get the [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]
        
        # Store attention weights for visualization
        self.attention_weights = outputs.attentions[-1]  # Last layer's attention
        
        # Apply dropout and classification
        x = self.dropout(cls_output)
        logits = self.classifier(x)
        
        return logits
    
    def get_attention_weights(self):
        return self.attention_weights

# Initialize tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = SMSSpamClassifier().to(device)

print("Model architecture:")
print(model)

Model architecture:
SMSSpamClassifier(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False

## 5. Training Configuration

Set up the training parameters, loss function, optimizer, and learning rate scheduler.

In [6]:
# Create PyTorch Dataset
class SMSDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Tokenize text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # Ensure float32 type for input tensors
        return {
            'input_ids': encoding['input_ids'].flatten().to(torch.long),  # Use long for input_ids
            'attention_mask': encoding['attention_mask'].flatten().to(torch.float32),  # Use float32 for attention_mask
            'label': torch.tensor(label, dtype=torch.long)  # Use long for labels
        }

# Prepare data splits
X = df['processed_text'].values
y = (df['v1'] == 'spam').astype(int).values

# Create stratified train, validation, and test splits
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.15, stratify=y, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.15/0.85, stratify=y_temp, random_state=42
)

# Create datasets
train_dataset = SMSDataset(X_train, y_train, tokenizer)
val_dataset = SMSDataset(X_val, y_val, tokenizer)
test_dataset = SMSDataset(X_test, y_test, tokenizer)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Training configuration
config = {
    'learning_rate': 2e-5,
    'num_epochs': 3,
    'warmup_steps': 0.1,  # 10% of total steps
    'weight_decay': 0.01,
    'gradient_clipping': 1.0
}

# Calculate total steps for warmup
total_steps = len(train_loader) * config['num_epochs']
warmup_steps = int(total_steps * config['warmup_steps'])

# Initialize optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Initialize loss function with class weights to handle imbalance
# Ensure class weights are float32
class_weights = torch.tensor(
    [1.0, (y_train == 0).sum() / (y_train == 1).sum()],
    device=device,
    dtype=torch.float32  # Explicitly set dtype to float32
)
criterion = nn.CrossEntropyLoss(weight=class_weights)

print("Training configuration:")
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(val_dataset)}")
print(f"Number of test examples: {len(test_dataset)}")
print(f"Number of training steps: {total_steps}")
print(f"Number of warmup steps: {warmup_steps}")
print(f"Class weights: {class_weights.cpu().numpy()}")

Training configuration:
Number of training examples: 3900
Number of validation examples: 836
Number of test examples: 836
Number of training steps: 732
Number of warmup steps: 73
Class weights: [1.       6.456979]


## 6. Model Training and Validation

Implement the training loop with early stopping and evaluation.

In [7]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler, clip_value=1.0):
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []
    
    progress_bar = tqdm(train_loader, desc='Training')
    for batch in progress_bar:
        # Get batch data
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        scheduler.step()
        
        # Update metrics
        total_loss += loss.item()
        predictions.extend(outputs.argmax(dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
        
        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(true_labels, predictions)
    
    return avg_loss, accuracy

def evaluate(model, val_loader, criterion):
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    avg_loss = total_loss / len(val_loader)
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='binary'
    )
    
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels
    }

# Training loop with early stopping
best_val_loss = float('inf')
early_stopping_patience = 3
early_stopping_counter = 0
training_history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'val_f1': []
}

for epoch in range(config['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{config['num_epochs']}")
    
    # Training phase
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer,
        scheduler, config['gradient_clipping']
    )
    
    # Validation phase
    val_metrics = evaluate(model, val_loader, criterion)
    
    # Update training history
    training_history['train_loss'].append(train_loss)
    training_history['train_acc'].append(train_acc)
    training_history['val_loss'].append(val_metrics['loss'])
    training_history['val_acc'].append(val_metrics['accuracy'])
    training_history['val_f1'].append(val_metrics['f1'])
    
    # Print epoch results
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.4f}")
    print(f"Val Precision: {val_metrics['precision']:.4f} | Val Recall: {val_metrics['recall']:.4f}")
    print(f"Val F1: {val_metrics['f1']:.4f}")
    
    # Early stopping
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        torch.save(model.state_dict(), 'best_model.pt')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

print("\nTraining completed!")


Epoch 1/3


Training: 100%|██████████| 244/244 [05:38<00:00,  1.39s/it, loss=0.0170]
Training: 100%|██████████| 244/244 [05:38<00:00,  1.39s/it, loss=0.0170]
Evaluating: 100%|██████████| 27/27 [00:34<00:00,  1.27s/it]



Train Loss: 0.2406 | Train Acc: 0.9121
Val Loss: 0.0457 | Val Acc: 0.9880
Val Precision: 0.9397 | Val Recall: 0.9732
Val F1: 0.9561

Epoch 2/3

Epoch 2/3


Training: 100%|██████████| 244/244 [05:48<00:00,  1.43s/it, loss=0.0028]
Training: 100%|██████████| 244/244 [05:48<00:00,  1.43s/it, loss=0.0028]
Evaluating: 100%|██████████| 27/27 [00:36<00:00,  1.36s/it]



Train Loss: 0.1204 | Train Acc: 0.9882
Val Loss: 0.0455 | Val Acc: 0.9928
Val Precision: 0.9649 | Val Recall: 0.9821
Val F1: 0.9735

Epoch 3/3

Epoch 3/3


Training: 100%|██████████| 244/244 [05:33<00:00,  1.37s/it, loss=0.0020]
Training: 100%|██████████| 244/244 [05:33<00:00,  1.37s/it, loss=0.0020]
Evaluating: 100%|██████████| 27/27 [00:34<00:00,  1.26s/it]

Train Loss: 0.0835 | Train Acc: 0.9918
Val Loss: 0.0456 | Val Acc: 0.9928
Val Precision: 0.9649 | Val Recall: 0.9821
Val F1: 0.9735

Training completed!





## 7. Hyperparameter Optimization

Use Optuna for systematic hyperparameter tuning.

In [None]:
def objective(trial):
    # Hyperparameters to optimize
    params = {
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [8, 16, 32]),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),
        'weight_decay': trial.suggest_float('weight_decay', 0.01, 0.1),
        'warmup_steps': trial.suggest_float('warmup_steps', 0.1, 0.3)
    }
    
    # Create model with trial parameters
    model = SMSSpamClassifier(dropout_rate=params['dropout_rate']).to(device)
    
    # Create data loaders with trial batch size
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params['batch_size'], shuffle=False)
    
    # Training setup
    total_steps = len(train_loader) * config['num_epochs']
    warmup_steps = int(total_steps * params['warmup_steps'])
    
    optimizer = AdamW(
        model.parameters(),
        lr=params['learning_rate'],
        weight_decay=params['weight_decay']
    )
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training loop
    best_val_f1 = 0
    for epoch in range(config['num_epochs']):
        # Training phase
        train_loss, _ = train_epoch(
            model, train_loader, criterion, optimizer,
            scheduler, config['gradient_clipping']
        )
        
        # Validation phase
        val_metrics = evaluate(model, val_loader, criterion)
        
        # Report intermediate value
        trial.report(val_metrics['f1'], epoch)
        
        # Handle pruning based on the intermediate value
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        
        best_val_f1 = max(best_val_f1, val_metrics['f1'])
    
    return best_val_f1

# Create study object and specify the direction is 'maximize' the F1 score
study = optuna.create_study(direction='maximize')

# Run hyperparameter optimization
print("Running hyperparameter optimization...")
study.optimize(objective, n_trials=10)

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

# Plot optimization history
optuna.visualization.plot_optimization_history(study)
optuna.visualization.plot_param_importances(study)

[I 2025-10-15 01:21:00,957] A new study created in memory with name: no-name-285c3961-b602-4af3-bba9-5fba5d6d8b6d


Running hyperparameter optimization...


Training: 100%|██████████| 122/122 [05:20<00:00,  2.63s/it, loss=0.0624]
Training: 100%|██████████| 122/122 [05:20<00:00,  2.63s/it, loss=0.0624]
Evaluating: 100%|██████████| 27/27 [00:29<00:00,  1.08s/it]
Evaluating: 100%|██████████| 27/27 [00:29<00:00,  1.08s/it]
Training: 100%|██████████| 122/122 [04:40<00:00,  2.30s/it, loss=0.0129]
Training: 100%|██████████| 122/122 [04:40<00:00,  2.30s/it, loss=0.0129]
Evaluating: 100%|██████████| 27/27 [00:33<00:00,  1.23s/it]
Evaluating: 100%|██████████| 27/27 [00:33<00:00,  1.23s/it]
Training: 100%|██████████| 122/122 [05:19<00:00,  2.61s/it, loss=0.0063]
Training: 100%|██████████| 122/122 [05:19<00:00,  2.61s/it, loss=0.0063]
Evaluating: 100%|██████████| 27/27 [00:33<00:00,  1.25s/it]
[I 2025-10-15 01:37:59,018] Trial 0 finished with value: 0.9691629955947136 and parameters: {'learning_rate': 2.0471231850010034e-05, 'batch_size': 32, 'dropout_rate': 0.3756686765825775, 'weight_decay': 0.010838916812871116, 'warmup_steps': 0.11386301952033767}

## 8. Evaluation and Visualization

Calculate metrics and visualize model performance.

In [None]:
# Load best model
model.load_state_dict(torch.load('best_model.pt'))

# Evaluate on test set
test_metrics = evaluate(model, test_loader, criterion)

# Print final metrics
print("\nTest Set Metrics:")
print(f"Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Precision: {test_metrics['precision']:.4f}")
print(f"Recall: {test_metrics['recall']:.4f}")
print(f"F1 Score: {test_metrics['f1']:.4f}")

# Plot confusion matrix
plt.figure(figsize=(8, 6))
cm = confusion_matrix(test_metrics['true_labels'], test_metrics['predictions'])
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Plot training history
plt.figure(figsize=(12, 4))

# Plot losses
plt.subplot(1, 2, 1)
plt.plot(training_history['train_loss'], label='Training Loss')
plt.plot(training_history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot metrics
plt.subplot(1, 2, 2)
plt.plot(training_history['train_acc'], label='Train Accuracy')
plt.plot(training_history['val_acc'], label='Val Accuracy')
plt.plot(training_history['val_f1'], label='Val F1')
plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.legend()

plt.tight_layout()
plt.show()

# Attention visualization for a sample message
def visualize_attention(model, tokenizer, text):
    # Tokenize and prepare input
    encoding = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=128
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Get model prediction and attention weights
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        attention = model.get_attention_weights()
    
    # Get prediction
    pred = outputs.argmax(dim=1).item()
    
    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    # Plot attention heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention[0][0].cpu(),
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.title(f'Attention Visualization (Predicted: {"spam" if pred == 1 else "ham"})')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Visualize attention for a sample spam message
sample_text = X_test[y_test == 1][0]  # Get a spam message
print("Sample text:", sample_text)
visualize_attention(model, tokenizer, sample_text)