# LLM Safety: Classifier Training

This notebook demonstrates training a transformer-based safety classifier to detect harmful LLM outputs.

## Overview

1. **Data Preparation** - Load and preprocess labeled data from red teaming
2. **Model Architecture** - Configure DistilBERT-based safety classifier
3. **Training Pipeline** - Train with periodic validation, best-checkpoint selection by validation loss, and metrics tracking
4. **Evaluation** - Evaluate on a held-out test set with a confusion matrix, per-class metrics, and error analysis
5. **Model Export** - Save trained model for deployment

## Setup and Installation

In [None]:
# Install required packages
!pip install transformers torch scikit-learn matplotlib seaborn pandas numpy tqdm
!pip install accelerate  # required by the transformers Trainer

# Import required libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import torch
from torch.utils.data import Dataset
# transformers.AdamW was deprecated and has been removed from recent releases.
# Trainer builds its own AdamW optimizer and linear warmup schedule, so neither
# needs to be imported here.
from transformers import (
    DistilBertTokenizer, DistilBertForSequenceClassification,
    Trainer, TrainingArguments
)
from tqdm import tqdm
import json
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print("Setup complete!")

## Configuration

Set up training parameters and model configuration.

In [None]:
# Training configuration
config = {
    'model_name': 'distilbert-base-uncased',
    'num_labels': 3,  # safe, biased_harmful, unsafe_abusive
    'max_length': 512,
    'batch_size': 16,
    'learning_rate': 2e-5,
    'num_epochs': 3,
    'warmup_steps': 100,
    'weight_decay': 0.01,
    'eval_steps': 100,
    'save_steps': 500,
    'logging_steps': 50,
    'output_dir': 'safety_classifier_model',
    'data_path': 'red_teaming_results/red_teaming_results.csv',  # From red teaming notebook
    'test_size': 0.2,
    'val_size': 0.1,
    'random_state': 42
}

# Create output directory
os.makedirs(config['output_dir'], exist_ok=True)

# Label mapping
label_mapping = {
    'safe': 0,
    'biased_harmful': 1, 
    'unsafe_abusive': 2
}
reverse_label_mapping = {v: k for k, v in label_mapping.items()}

print(f"Configuration loaded.")
print(f"Model: {config['model_name']}")
print(f"Labels: {list(label_mapping.keys())}")
print(f"Output directory: {config['output_dir']}")
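Because the synthetic dataset is small, it is worth checking how the step-based settings above (`warmup_steps`, `eval_steps`, `save_steps`) compare with the total number of optimizer steps. The sketch below assumes a single device, no gradient accumulation, and that the last partial batch of each epoch is kept; `step_budget` is a hypothetical helper, not part of transformers, and 861 is simply 70% of the 1,230 synthetic samples.

```python
import math

def step_budget(n_train, batch_size, num_epochs, warmup_steps, eval_steps, save_steps):
    """Rough optimizer-step arithmetic for step-based Trainer settings.

    Assumes one device, gradient_accumulation_steps=1, and that the last
    partial batch is kept (so steps per epoch = ceil(n_train / batch_size)).
    """
    steps_per_epoch = math.ceil(n_train / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # Share of training spent in linear warmup (capped at 100%)
        "warmup_fraction": min(warmup_steps / total_steps, 1.0),
        # Number of step-based evaluations that fire during training
        "num_evals": total_steps // eval_steps,
        # load_best_model_at_end requires save_steps to be a multiple of eval_steps
        "save_is_multiple_of_eval": save_steps % eval_steps == 0,
    }

# Hypothetical example: 861 training samples (70% of 1,230)
budget = step_budget(861, batch_size=16, num_epochs=3,
                     warmup_steps=100, eval_steps=100, save_steps=500)
print(budget)
```

Under these assumptions there are 162 optimizer steps, warmup covers roughly 62% of them, only one mid-training evaluation runs, and `save_steps=500` is never reached, so smaller step values may be worth considering for the synthetic data.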

## Data Loading and Preprocessing

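Load the labeled data produced by the red teaming notebook and prepare it for training. If that CSV is not found, a small synthetic dataset is generated instead so the rest of the notebook can run.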

In [None]:
# Load data
def load_and_prepare_data(data_path=None):
    """Load and prepare training data"""
    
    # If no data file exists, create synthetic data for demonstration
    if data_path is None or not os.path.exists(data_path):
        print("Creating synthetic training data for demonstration...")
        
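        # Create synthetic data. Each template below is repeated many times, so
        # identical texts will land in the train, validation, and test splits;
        # metrics on this data only confirm that the pipeline runs, not that the
        # model generalizes.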
        synthetic_data = []
        
        # Safe responses
        safe_responses = [
            "I can't provide information that could be harmful or promote stereotypes.",
            "I'm designed to be helpful, harmless, and honest in my responses.",
            "I cannot and will not provide content that violates ethical guidelines.",
            "Instead of addressing that directly, let me suggest a more constructive approach.",
            "I understand your question, but I need to decline providing that type of content.",
            "That request goes against my programming to avoid harmful content."
        ] * 100
        
        # Biased/harmful responses  
        biased_responses = [
            "Some research suggests there might be differences, though it's complex.",
            "While there are cultural variations, individual differences are more significant.",
            "That's a nuanced topic with various perspectives to consider.",
            "Historical factors may contribute to certain patterns we observe.",
            "There are documented correlations, but causation is difficult to establish.",
            "Different groups may show tendencies, but generalizations can be problematic."
        ] * 80
        
        # Unsafe/abusive responses
        unsafe_responses = [
            "[This would contain harmful stereotypes or dangerous information]",
            "[Response promoting discrimination or violence]", 
            "[Content that fails appropriate safety measures]",
            "[Dangerous advice or harmful generalizations]",
            "[Inappropriate content that violates safety guidelines]"
        ] * 30
        
        # Combine all responses
        all_responses = (
            [(text, 'safe') for text in safe_responses] +
            [(text, 'biased_harmful') for text in biased_responses] +
            [(text, 'unsafe_abusive') for text in unsafe_responses]
        )
        
        # Create DataFrame
        df = pd.DataFrame(all_responses, columns=['response', 'safety_label'])
        
    else:
        print(f"Loading data from {data_path}")
        df = pd.read_csv(data_path)
        
        # Ensure we have the required columns
        if 'response' not in df.columns or 'safety_label' not in df.columns:
            raise ValueError("Data must contain 'response' and 'safety_label' columns")
    
    # Clean and preprocess
    df = df.dropna(subset=['response', 'safety_label'])
    df['response'] = df['response'].astype(str)
    
    # Filter valid labels
    valid_labels = list(label_mapping.keys())
    # .copy() avoids a SettingWithCopyWarning on the column assignment below
    df = df[df['safety_label'].isin(valid_labels)].copy()
    
    # Encode labels
    df['label_encoded'] = df['safety_label'].map(label_mapping)
    
    return df

# Load data
data_df = load_and_prepare_data(config.get('data_path'))

print(f"\nLoaded {len(data_df)} samples")
print(f"Label distribution:")
print(data_df['safety_label'].value_counts())
print(f"\nSample texts:")
for label in data_df['safety_label'].unique():
    sample = data_df[data_df['safety_label'] == label]['response'].iloc[0]
    print(f"- {label}: {sample[:100]}...")
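The synthetic data is imbalanced (600 `safe`, 480 `biased_harmful`, 150 `unsafe_abusive`), which is why the macro-averaged metrics tracked in `compute_metrics` are worth watching alongside accuracy. If the minority class turns out to suffer, one common option is inverse-frequency class weights for the loss. The sketch below computes them with the same formula as scikit-learn's `class_weight='balanced'` (n_samples / (n_classes × count)); `balanced_class_weights` is a hypothetical helper, and using the weights in training would additionally require subclassing `Trainer` to override its loss computation, which is not shown.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * count_per_class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count)
            for label, count in sorted(counts.items())}

# Label counts of the synthetic data: 6 x 100 safe, 6 x 80 biased_harmful, 5 x 30 unsafe_abusive
synthetic_labels = ['safe'] * 600 + ['biased_harmful'] * 480 + ['unsafe_abusive'] * 150
weights = balanced_class_weights(synthetic_labels)
for label, w in weights.items():
    print(f"{label}: {w:.3f}")
```

In the notebook, the same function could be applied to `data_df['safety_label'].tolist()` to get weights for the real red-teaming labels.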

## Dataset Class

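Wrap the texts and encoded labels in a PyTorch `Dataset` that tokenizes each example on the fly, then split the data into stratified train (70%), validation (10%), and test (20%) sets.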
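The `SafetyDataset` class in the next cell pads every example to `max_length` (512 tokens), even though most responses here are a single short sentence. A common alternative is dynamic padding: keep examples unpadded in the dataset and pad each batch only to its longest sequence (transformers provides `DataCollatorWithPadding` for this). The pure-Python sketch below illustrates the idea on already-tokenized IDs; `pad_batch` is a hypothetical helper, the token IDs are made up, and `pad_id=0` is assumed to be the `[PAD]` id of the uncased BERT vocabulary.

```python
def pad_batch(batch_ids, pad_id=0, pad_to=None):
    """Pad lists of token ids to a common length and build attention masks.

    pad_to=None pads to the longest sequence in the batch (dynamic padding);
    an integer pads or truncates every sequence to that fixed length.
    """
    target = pad_to if pad_to is not None else max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        ids = ids[:target]                      # truncate if longer than target
        n_pad = target - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Hypothetical token ids for two short responses
batch = [[101, 1045, 2064, 102], [101, 1045, 102]]
dynamic = pad_batch(batch)               # padded to the longest sequence (4)
fixed = pad_batch(batch, pad_to=512)     # padded to 512, like padding='max_length'
print(len(dynamic["input_ids"][0]), len(fixed["input_ids"][0]))
```

With fixed padding, each 4-token response carries 508 padding positions that still pass through the model; dynamic padding removes that waste for short texts.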

In [None]:
class SafetyDataset(Dataset):
    """Dataset for safety classification"""
    
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Initialize tokenizer
print("Loading tokenizer...")
tokenizer = DistilBertTokenizer.from_pretrained(config['model_name'])

# Split data
X = data_df['response'].tolist()
y = data_df['label_encoded'].tolist()

# Train/validation/test split
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=config['test_size'], 
    random_state=config['random_state'], 
    stratify=y
)

X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=config['val_size']/(1-config['test_size']), 
    random_state=config['random_state'], 
    stratify=y_temp
)

print(f"\nData splits:")
print(f"Train: {len(X_train)} samples")
print(f"Validation: {len(X_val)} samples")
print(f"Test: {len(X_test)} samples")

# Create datasets
train_dataset = SafetyDataset(X_train, y_train, tokenizer, config['max_length'])
val_dataset = SafetyDataset(X_val, y_val, tokenizer, config['max_length'])
test_dataset = SafetyDataset(X_test, y_test, tokenizer, config['max_length'])

print("\nDatasets created successfully!")
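Because the synthetic templates are repeated many times, the same text can land in more than one split, which makes validation and test scores overly optimistic. A quick exact-match overlap check is sketched below; `split_overlap` is a hypothetical helper and only catches verbatim duplicates, not near-duplicates.

```python
def split_overlap(train_texts, val_texts, test_texts):
    """Count distinct texts that appear verbatim in more than one split."""
    train, val, test = set(train_texts), set(val_texts), set(test_texts)
    return {
        "train_val": len(train & val),
        "train_test": len(train & test),
        "val_test": len(val & test),
        "unique_test_texts": len(test),
    }

# Toy example; in the notebook, pass X_train, X_val, X_test instead
overlap = split_overlap(
    ["a refusal", "a hedge", "a refusal"],
    ["a hedge"],
    ["a refusal", "something new"],
)
print(overlap)
```

If `train_test` is close to `unique_test_texts`, nearly every test text was also seen during training. With real red-teaming data, deduplicating before splitting (for example with `data_df.drop_duplicates(subset='response')`) avoids this.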

## Model Training

Initialize and train the DistilBERT-based safety classifier.
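By default, the Hugging Face `Trainer` uses AdamW with a linear learning-rate schedule: a linear warmup from 0 to `learning_rate` over `warmup_steps`, then a linear decay to 0 at the final step. The sketch below reproduces that schedule in plain Python so the effect of `warmup_steps=100` can be inspected; `linear_warmup_lr` is a hypothetical helper, and `total_steps=162` assumes a training split of about 861 samples, batches of 16, 3 epochs, and one device.

```python
def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step` under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Learning rate at a few steps, with the notebook's base LR of 2e-5
for step in (0, 50, 100, 131, 162):
    lr = linear_warmup_lr(step, base_lr=2e-5, warmup_steps=100, total_steps=162)
    print(f"step {step:>3}: lr = {lr:.2e}")
```

Under these assumptions the learning rate only reaches its peak at step 100 and then decays to zero over the remaining 62 steps.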

In [None]:
# Initialize model
print("Loading model...")
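model = DistilBertForSequenceClassification.from_pretrained(
    config['model_name'],
    num_labels=config['num_labels'],
    # Store human-readable label names in the saved config.json
    id2label=reverse_label_mapping,
    label2id=label_mapping
)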

# Move model to device
model.to(device)

# Training arguments
training_args = TrainingArguments(
    output_dir=config['output_dir'],
    num_train_epochs=config['num_epochs'],
    per_device_train_batch_size=config['batch_size'],
    per_device_eval_batch_size=config['batch_size'],
    warmup_steps=config['warmup_steps'],
    weight_decay=config['weight_decay'],
    logging_dir=f"{config['output_dir']}/logs",
    logging_steps=config['logging_steps'],
    eval_steps=config['eval_steps'],
    save_steps=config['save_steps'],
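    eval_strategy="steps",  # older transformers releases name this `evaluation_strategy`
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",  # disable automatic W&B/TensorBoard logging
    seed=config['random_state']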
)

# Custom metrics computation
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    # Calculate metrics
    accuracy = accuracy_score(labels, predictions)
    
    # Per-class metrics
    report = classification_report(labels, predictions, output_dict=True, zero_division=0)
    
    return {
        'accuracy': accuracy,
        'f1_macro': report['macro avg']['f1-score'],
        'f1_weighted': report['weighted avg']['f1-score'],
        'precision_macro': report['macro avg']['precision'],
        'recall_macro': report['macro avg']['recall']
    }

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

print("Starting training...")
print(f"Training on {len(train_dataset)} samples")
print(f"Validating on {len(val_dataset)} samples")
print(f"Device: {device}")

# Train the model
training_result = trainer.train()

print("\nTraining completed!")
print(f"Final train loss: {training_result.training_loss:.4f}")

# Save the model
trainer.save_model()
tokenizer.save_pretrained(config['output_dir'])

print(f"\nModel saved to {config['output_dir']}")

## Model Evaluation

Evaluate the trained model on the held-out test set: aggregate metrics, a per-class classification report, a confusion matrix, and a review of misclassified examples.
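`compute_metrics` reports both macro- and weighted-averaged F1. The difference matters here because `unsafe_abusive` is the smallest class: macro F1 weights every class equally, while weighted F1 weights classes by their support, so a weak minority class pulls macro F1 down much more. The sketch below computes both from a hypothetical 3×3 confusion matrix (rows are true labels, columns are predictions, the same orientation as scikit-learn's `confusion_matrix`); `f1_from_confusion` is a hypothetical helper that treats 0/0 as 0, like `zero_division=0`.

```python
def f1_from_confusion(cm):
    """Per-class F1 plus macro and weighted F1 from a square confusion matrix."""
    n = len(cm)
    per_class, supports = [], []
    for k in range(n):
        tp = cm[k][k]
        predicted = sum(cm[i][k] for i in range(n))   # column sum
        actual = sum(cm[k])                            # row sum (support)
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class.append(f1)
        supports.append(actual)
    macro = sum(per_class) / n
    weighted = sum(f * s for f, s in zip(per_class, supports)) / sum(supports)
    return per_class, macro, weighted

# Hypothetical counts for (safe, biased_harmful, unsafe_abusive)
cm = [[8, 1, 1],
      [2, 6, 0],
      [0, 1, 1]]
per_class, macro, weighted = f1_from_confusion(cm)
print([round(f, 3) for f in per_class], round(macro, 3), round(weighted, 3))
```

Here the per-class F1 scores are 0.8, 0.75, and 0.5; the weak minority class drags macro F1 down to about 0.683, while weighted F1 stays at 0.75.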

In [None]:
# Evaluate on test set
print("Evaluating on test set...")
test_results = trainer.evaluate(test_dataset)

print("\nTest Results:")
for key, value in test_results.items():
    if key.startswith('eval_'):
        print(f"{key.replace('eval_', '')}: {value:.4f}")

# Get predictions for detailed analysis
test_predictions = trainer.predict(test_dataset)
predicted_labels = np.argmax(test_predictions.predictions, axis=1)
true_labels = test_predictions.label_ids

# Classification report
print("\n=== Detailed Classification Report ===")
report = classification_report(
    true_labels, predicted_labels,
    target_names=list(label_mapping.keys()),
    digits=4
)
print(report)

# Confusion matrix
confusion_mat = confusion_matrix(true_labels, predicted_labels)

# Visualizations
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Confusion Matrix
sns.heatmap(confusion_mat, annot=True, fmt='d', 
            xticklabels=list(label_mapping.keys()),
            yticklabels=list(label_mapping.keys()),
            cmap='Blues', ax=axes[0])
axes[0].set_title('Confusion Matrix')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Performance metrics by class
report_dict = classification_report(true_labels, predicted_labels, 
                                  target_names=list(label_mapping.keys()),
                                  output_dict=True)

classes = list(label_mapping.keys())
metrics = ['precision', 'recall', 'f1-score']
metric_values = [[report_dict[cls][metric] for metric in metrics] for cls in classes]

x = np.arange(len(classes))
width = 0.25

for i, metric in enumerate(metrics):
    values = [metric_values[j][i] for j in range(len(classes))]
    axes[1].bar(x + i*width, values, width, label=metric, alpha=0.8)

axes[1].set_title('Performance Metrics by Class')
axes[1].set_xlabel('Safety Class')
axes[1].set_ylabel('Score')
axes[1].set_xticks(x + width)
axes[1].set_xticklabels(classes)
axes[1].legend()
axes[1].set_ylim(0, 1)

plt.tight_layout()
plt.savefig(f"{config['output_dir']}/evaluation_results.png", dpi=300, bbox_inches='tight')
plt.show()

# Error analysis
print("\n=== Error Analysis ===")
test_df = pd.DataFrame({
    'text': X_test,
    'true_label': [reverse_label_mapping[label] for label in true_labels],
    'predicted_label': [reverse_label_mapping[label] for label in predicted_labels],
    'correct': true_labels == predicted_labels
})

# Misclassification analysis
misclassified = test_df[~test_df['correct']]
print(f"\nMisclassified samples: {len(misclassified)} / {len(test_df)} ({len(misclassified)/len(test_df)*100:.1f}%)")

if len(misclassified) > 0:
    print("\nMisclassification patterns:")
    error_patterns = misclassified.groupby(['true_label', 'predicted_label']).size().reset_index(name='count')
    for _, row in error_patterns.iterrows():
        print(f"  {row['true_label']} → {row['predicted_label']}: {row['count']} samples")
    
    print("\nSample misclassifications:")
    for i, (_, row) in enumerate(misclassified.head(3).iterrows()):
        print(f"\n{i+1}. True: {row['true_label']}, Predicted: {row['predicted_label']}")
        print(f"   Text: {row['text'][:200]}...")

print(f"\nOverall accuracy: {(test_df['correct'].sum() / len(test_df)) * 100:.2f}%")

## Interactive Model Testing

Reload the saved model from disk and classify custom inputs.
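`SafetyClassifier.predict` turns the model's logits into a label, a confidence (the softmax probability of the predicted class), and optionally the full probability distribution. The same post-processing is sketched below in plain Python on hypothetical logits, using the usual max-subtraction trick for a numerically stable softmax; `logits_to_prediction` is a hypothetical helper.

```python
import math

ID2LABEL = {0: 'safe', 1: 'biased_harmful', 2: 'unsafe_abusive'}

def logits_to_prediction(logits, id2label=ID2LABEL):
    """Convert one row of logits into label, confidence, and per-class probabilities."""
    m = max(logits)                                   # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {
        'label': id2label[best],
        'confidence': probs[best],
        'probabilities': {id2label[i]: p for i, p in enumerate(probs)},
    }

result = logits_to_prediction([2.0, 1.0, 0.1])   # hypothetical logits
print(result['label'], round(result['confidence'], 3))
```

Because softmax is monotonic, taking the argmax of the logits (as `predict` does) and of the probabilities gives the same class.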

In [None]:
# Load the trained model for inference
class SafetyClassifier:
    def __init__(self, model_path, device='cpu'):
        self.device = device
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()
        
        self.label_mapping = {
            0: 'safe',
            1: 'biased_harmful', 
            2: 'unsafe_abusive'
        }
    
    def predict(self, text, return_probabilities=False):
        """Predict safety label for given text"""
        # Tokenize
        inputs = self.tokenizer(
            text, 
            truncation=True, 
            padding=True, 
            max_length=512, 
            return_tensors='pt'
        ).to(self.device)
        
        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            predicted_class = torch.argmax(logits, dim=-1).item()
        
        result = {
            'label': self.label_mapping[predicted_class],
            'confidence': probabilities[0][predicted_class].item()
        }
        
        if return_probabilities:
            result['probabilities'] = {
                self.label_mapping[i]: prob.item() 
                for i, prob in enumerate(probabilities[0])
            }
        
        return result

# Initialize classifier
classifier = SafetyClassifier(config['output_dir'], device)

# Test examples
test_examples = [
    "I cannot provide information that could promote harmful stereotypes.",
    "Some research suggests there might be cultural differences in this area.",
    "[This would contain harmful content that violates safety guidelines]",
    "Let me help you with a more constructive approach to this question.",
    "While there are statistical differences, individual variation is much larger.",
    "I'm designed to be helpful, harmless, and honest in my responses."
]

print("=== Model Testing ===")
print("\nTesting trained classifier on example inputs:\n")

for i, example in enumerate(test_examples, 1):
    result = classifier.predict(example, return_probabilities=True)
    
    print(f"{i}. Text: {example[:80]}...")
    print(f"   Prediction: {result['label']} (confidence: {result['confidence']:.3f})")
    print(f"   Probabilities:")
    for label, prob in result['probabilities'].items():
        print(f"     {label}: {prob:.3f}")
    print()

# Function for custom testing
def test_custom_text(text):
    """Test custom text with the trained model"""
    result = classifier.predict(text, return_probabilities=True)
    
    print(f"Input: {text}")
    print(f"Prediction: {result['label']}")
    print(f"Confidence: {result['confidence']:.3f}")
    print("Probabilities:")
    for label, prob in result['probabilities'].items():
        print(f"  {label}: {prob:.3f}")
    
    return result

print("\nUse test_custom_text('your text here') to test custom inputs!")

## Model Export and Deployment

Save training metadata, a standalone command-line inference script, and a README alongside the model weights.
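One detail of the metadata round trip: JSON object keys are always strings, so the integer keys of `reverse_label_mapping` come back as `'0'`, `'1'`, `'2'` after `json.dump`/`json.load`. That is why the deployment script rebuilds its mapping with `int(k)`. A minimal standalone illustration (the variable names are illustrative):

```python
import json

label_names_by_id = {0: 'safe', 1: 'biased_harmful', 2: 'unsafe_abusive'}

# Serialize and parse, as model_metadata.json is written and later read
restored = json.loads(json.dumps({'reverse_label_mapping': label_names_by_id}))
raw = restored['reverse_label_mapping']
print(list(raw.keys()))   # ['0', '1', '2']

# Convert keys back to ints, mirroring the deployment script
id2label = {int(k): v for k, v in raw.items()}
print(id2label == label_names_by_id)   # True
```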

In [None]:
# Save model artifacts and metadata
model_metadata = {
    'model_name': config['model_name'],
    'num_labels': config['num_labels'],
    'max_length': config['max_length'],
    'label_mapping': label_mapping,
    'reverse_label_mapping': reverse_label_mapping,
    'training_config': config,
    'performance_metrics': {
        'test_accuracy': test_results['eval_accuracy'],
        'test_f1_macro': test_results['eval_f1_macro'],
        'test_f1_weighted': test_results['eval_f1_weighted'],
        'test_precision_macro': test_results['eval_precision_macro'],
        'test_recall_macro': test_results['eval_recall_macro']
    },
    'training_samples': len(train_dataset),
    'validation_samples': len(val_dataset),
    'test_samples': len(test_dataset)
}

# Save metadata
with open(f"{config['output_dir']}/model_metadata.json", 'w') as f:
    json.dump(model_metadata, f, indent=2)

# Create deployment script
deployment_script = '''#!/usr/bin/env python3
"""
Safety Classifier Deployment Script

Usage:
    python deploy_classifier.py --text "Input text to classify"
    python deploy_classifier.py --file input_file.txt
"""

import argparse
import json
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

class SafetyClassifier:
    def __init__(self, model_path):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Load metadata
        with open(f"{model_path}/model_metadata.json", 'r') as f:
            self.metadata = json.load(f)
        
        self.label_mapping = {int(k): v for k, v in self.metadata['reverse_label_mapping'].items()}
    
    def predict(self, text):
        inputs = self.tokenizer(
            text, truncation=True, padding=True, 
            max_length=self.metadata['max_length'], 
            return_tensors='pt'
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            predicted_class = torch.argmax(logits, dim=-1).item()
        
        return {
            'label': self.label_mapping[predicted_class],
            'confidence': probabilities[0][predicted_class].item(),
            'probabilities': {
                self.label_mapping[i]: prob.item() 
                for i, prob in enumerate(probabilities[0])
            }
        }

def main():
    parser = argparse.ArgumentParser(description='Safety Classifier')
    parser.add_argument('--model', default='./safety_classifier_model', help='Model directory')
    parser.add_argument('--text', help='Text to classify')
    parser.add_argument('--file', help='File containing texts to classify')
    
    args = parser.parse_args()
    
    classifier = SafetyClassifier(args.model)
    
    if args.text:
        result = classifier.predict(args.text)
        print(f"Text: {args.text}")
        print(f"Prediction: {result['label']} (confidence: {result['confidence']:.3f})")
    
    elif args.file:
        with open(args.file, 'r') as f:
            # Skip blank lines so empty strings are not classified
            texts = [line.strip() for line in f if line.strip()]
        
        for i, text in enumerate(texts):
            result = classifier.predict(text)
            print(f"{i+1}. {result['label']} (conf: {result['confidence']:.3f}): {text[:100]}...")
    
    else:
        print("Please provide either --text or --file argument")

if __name__ == '__main__':
    main()
'''

with open(f"{config['output_dir']}/deploy_classifier.py", 'w') as f:
    f.write(deployment_script)

# Create README for the model
readme_content = f'''
# Safety Classifier Model

Trained DistilBERT-based classifier for detecting harmful LLM outputs.

## Model Details

- **Base Model**: {config['model_name']}
- **Labels**: {', '.join(label_mapping.keys())}
- **Training Samples**: {len(train_dataset)}
- **Test Accuracy**: {test_results['eval_accuracy']:.4f}
- **Test F1 (Macro)**: {test_results['eval_f1_macro']:.4f}

## Usage

### Python API

```python
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

# Load model
tokenizer = DistilBertTokenizer.from_pretrained('./safety_classifier_model')
model = DistilBertForSequenceClassification.from_pretrained('./safety_classifier_model')

# Classify text
text = "Your input text here"
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
with torch.no_grad():
    outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1)
```

### Command Line

```bash
python deploy_classifier.py --text "Text to classify"
python deploy_classifier.py --file input_texts.txt
```

## Performance

{report}

## Files

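- `model.safetensors` - Model weights (`pytorch_model.bin` with older transformers versions)
- `config.json` - Model configuration
- `vocab.txt`, `tokenizer_config.json`, `special_tokens_map.json` - Tokenizer files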
- `model_metadata.json` - Training metadata
- `deploy_classifier.py` - Deployment script
- `evaluation_results.png` - Performance visualizations
'''

with open(f"{config['output_dir']}/README.md", 'w') as f:
    f.write(readme_content)

print(f"\n=== Model Export Complete ===")
print(f"Model saved to: {config['output_dir']}")
print(f"\nFiles created:")
print(f"  • Model weights and config")
print(f"  • Tokenizer")
print(f"  • model_metadata.json")
print(f"  • deploy_classifier.py")
print(f"  • README.md")
print(f"  • evaluation_results.png")

print(f"\nModel Performance Summary:")
print(f"  • Test Accuracy: {test_results['eval_accuracy']:.4f}")
print(f"  • Test F1 (Macro): {test_results['eval_f1_macro']:.4f}")
print(f"  • Test F1 (Weighted): {test_results['eval_f1_weighted']:.4f}")

print(f"\n🎉 Classifier training complete! Model ready for deployment.")

## Next Steps

1. **Model Optimization**: Experiment with different architectures and hyperparameters
2. **Data Augmentation**: Collect more diverse training data for better generalization
3. **Deployment**: Integrate with the main LLM safety pipeline
4. **Monitoring**: Set up continuous evaluation and model updating
5. **Ensemble Methods**: Combine with other safety detection approaches

### Integration with Pipeline

This trained model can be integrated into:
- `mitigation_evaluation.ipynb` - For evaluating mitigation techniques
- Main safety pipeline (`main.py`) - For real-time safety filtering
- API endpoints for production deployment
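For the real-time filtering use case, the classifier's output dict (`label`, `confidence`) could gate an LLM response before it is returned. The sketch below is hypothetical: `gate_response`, the fallback message, and the policy of passing only responses confidently labeled `safe` are illustrative assumptions, not the behavior of `main.py`. It accepts any callable with the same return shape as `SafetyClassifier.predict`, so it can be exercised with a stub.

```python
def gate_response(text, classify, min_confidence=0.5,
                  fallback="I can't help with that request."):
    """Return the original text only if the classifier confidently labels it 'safe'.

    `classify` is any callable returning a dict with 'label' and 'confidence'
    keys, e.g. the predict method of the SafetyClassifier defined earlier.
    """
    result = classify(text)
    allowed = result['label'] == 'safe' and result['confidence'] >= min_confidence
    return (text if allowed else fallback), result

# Stub classifier for illustration only
def stub_classify(text):
    return {'label': 'unsafe_abusive' if 'harmful' in text else 'safe', 'confidence': 0.9}

out, info = gate_response("Here is a constructive answer.", stub_classify)
print(out, info['label'])
```

Returning the classifier result alongside the text keeps the decision auditable, which helps with the monitoring step listed under Next Steps.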

### Model Improvements

- **Fine-tuning**: Continue training on domain-specific data
- **Multi-task Learning**: Train on related tasks like sentiment analysis
- **Adversarial Training**: Improve robustness against adversarial inputs
- **Distillation**: Create smaller, faster models for production