# Security String Classification Training with Unsloth

This notebook demonstrates step-by-step training of a security classification model using Unsloth and Llama 3.1. The model classifies strings as either "Secret" or "Non-sensitive" based on their context in issue reports.

## Overview
- **Task**: Binary classification of security-sensitive strings
- **Model**: Llama 3.1 8B with LoRA fine-tuning
- **Framework**: Unsloth for efficient training
- **Data**: CSV format with candidate strings and issue reports

## Step 1: Environment Setup and Imports

First, let's import all necessary libraries and set up the environment for Unsloth.

In [None]:
import pandas as pd
import numpy as np
import os

# Set environment for Unsloth
# NOTE(review): the flag is set before `unsloth` is imported below; its exact
# effect is defined by Unsloth — confirm it is still required.
os.environ["UNSLOTH_IS_PRESENT"] = "1"

from tqdm import tqdm
import torch
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import Dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Set random seed for reproducibility
# Only PyTorch's global RNG is seeded in this cell; the LoRA setup and the
# TrainingArguments later in the notebook pass their own seed (3407).
torch.manual_seed(69420)

# Quick sanity report of the runtime this notebook is executing on
print("✅ Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Step 2: Data Loading and Exploration

Load the training, validation, and test datasets and explore their structure.

In [None]:
# Read the three dataset splits from disk into module-level DataFrames
df_train, df_val, df_test = (
    pd.read_csv("../Data/train.csv"),
    pd.read_csv("../Data/val.csv"),
    pd.read_csv("../Data/test.csv"),
)

# Report the (rows, columns) shape of every split
for shape_line in (
    f"Train shape: {df_train.shape}",
    f"Val shape: {df_val.shape}",
    f"Test shape: {df_test.shape}",
):
    print(shape_line)

# Peek at the first two training rows
print("\n📊 Training data sample:")
print(df_train.head(2))

# Column names followed by the dtype pandas inferred for each column
print("\n📋 Column information:")
print(df_train.columns.tolist())
print(f"\nData types:\n{df_train.dtypes}")

## Step 3: Model Configuration

Configure the Llama 3.1 model with optimal settings for security classification.

In [None]:
# ============================================
# Model Configuration (following finetune_balanced.py style)
# ============================================
# These three settings are module-level on purpose: `max_seq_length` is reused
# later by the SFTTrainer and by the tokenizer call in `predict_single`.
max_seq_length = 1024  # Context window size
dtype = None  # Auto-detect optimal dtype
load_in_4bit = True  # Use 4-bit quantization for memory efficiency

print("🔧 Model Configuration:")
print(f"Max sequence length: {max_seq_length}")
print(f"4-bit quantization: {load_in_4bit}")

# Load pre-quantized Llama 3.1 model
# Returns both the model and its tokenizer; every later cell uses these two
# module-level names.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

print("✅ Base model loaded successfully!")

## Step 4: LoRA Configuration

Add LoRA (Low-Rank Adaptation) layers for efficient fine-tuning.

In [None]:
# Add LoRA adapters for efficient fine-tuning
# `model` is rebound to the adapter-wrapped model, so all later cells train,
# save and run inference on the wrapped object.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank - balance between efficiency and performance
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,  # LoRA scaling parameter
    lora_dropout=0,  # No dropout for stability
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's optimized checkpointing
    random_state=3407,  # same value as the `seed` given to TrainingArguments later
    use_rslora=False,
    loftq_config=None,
)

# NOTE: the values printed below are hard-coded literals, not read back from
# the call above — keep them in sync by hand when changing the configuration.
print("✅ LoRA adapters added successfully!")
print(f"📊 LoRA Configuration:")
print(f"  - Rank (r): 16")
print(f"  - Alpha: 16") 
print(f"  - Dropout: 0")
print(f"  - Target modules: 7 attention/MLP layers")

## Step 5: Data Preprocessing

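Build a context window around each candidate string and convert the numeric labels to text labels. A regex-based `preprocess` cleaner is also defined in this cell, but it is not applied to the data here.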

In [None]:
import re

def preprocess(text):
    """Delete noisy artefacts from an issue report and return the cleaned text.

    Each (pattern, flags) pair below is applied in order with ``re.sub`` and
    every match is replaced by the empty string. The order is significant:
    later patterns run on the output of earlier ones (for example dotted
    package names are stripped before URLs and file paths are looked for).
    """
    removal_steps = (
        # quote characters and the box-drawing bar
        (r'[\'"\│]', 0),
        # `ls -l` style directory listing lines
        (r'drwx[-\s]*\d+\s+\w+\s+\w+\s+\d+\s+\w+\s+\d+\s+[0-9a-fA-F-]+.*', 0),
        # fenced shell snippets
        (r'```shell([^`]+)```', re.IGNORECASE),
        (r'```Shell\s*"([^"]*)"\s*```', re.IGNORECASE),
        # collapsed "Saved game" attachments
        (r'<details><summary>Saved game</summary>\n\n```(.*?)```', 0),
        # dotted identifiers such as package / module names
        (r'(\w+\.)+\w+', 0),
        # Java stack-trace frames
        (r'at\s[\w.$]+\.([\w]+)\(([^:]+:\d+)\)', 0),
        # URLs carrying a fragment, then any remaining URLs
        (r'https?://[^\s#]+#[A-Za-z0-9\-\=\+]+', re.IGNORECASE),
        (r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', 0),
        # 40-hex-digit commit ids
        (r'commit[ ]?(?:id)?[ ]?[:]?[ ]?([0-9a-f]{40})\b', re.IGNORECASE),
        # slash-led file paths (two passes with different patterns)
        (r"/[\w/. :-]+", 0),
        (r'(/[^/\s]+)+', 0),
        # hash-like values introduced by a keyword
        (r'sha256\s*[:]?[=]?\s*[a-fA-F0-9]{64}', 0),
        (r'git-tree-sha1\s*=\s*[a-fA-F0-9]+', 0),
        (r'build-id\s*[:]?[=]?\s*([a-fA-F0-9]+)', 0),
        # GUID triples and comma-separated hex triples
        (r'GUIDs:\s+([0-9a-fA-F-]+\s+[0-9a-fA-F-]+\s+[0-9a-fA-F-]+)', 0),
        (r'([0-9a-fA-F-]+\s*,\s*[0-9a-fA-F-]+\s*,\s*[0-9a-fA-F-]+)', 0),
        # anything enclosed in angle brackets
        (r'<([^>]+)>', 0),
        # keyword-introduced UUIDs and 0x hex literals
        (r'(?:UUID|GUID|version|id)[\\=:"\'\s]*\b[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}\b', re.IGNORECASE),
        (r'(?:data|address|id)[\\=:"\'\s]*\b0x[0-9a-fA-F]+\b', re.IGNORECASE),
        # timestamped screenshot file names
        (r'Screenshot_(\d{4}[_-]\d{2}[_-]\d{2}[_-]\d{2}[_-]\d{2}[_-]\d{2}[_-]\d{2}[_-]\w+)', re.IGNORECASE),
    )

    cleaned = text
    for pattern, flags in removal_steps:
        cleaned = re.sub(pattern, '', cleaned, flags=flags)
    return cleaned

def create_context_window(text, target_string, window_size=200):
    """Return the slice of ``text`` surrounding the first ``target_string``.

    The slice spans up to ``window_size`` characters before the match, the
    match itself, and up to ``window_size`` characters after it, clipped to
    the bounds of ``text``. Returns ``None`` when ``target_string`` does not
    occur in ``text``.
    """
    hit = text.find(target_string)
    if hit == -1:
        return None

    left = max(0, hit - window_size)
    right = min(len(text), hit + len(target_string) + window_size)
    return text[left:right]

# Derive the model inputs for every split, in place
print("🔄 Preprocessing data...")
for split_df in (df_train, df_val, df_test):
    # Context window around the candidate string inside the issue text
    # (None for rows where the candidate string is not found in the text)
    split_df['modified_text'] = split_df.apply(
        lambda row: create_context_window(row['text'], row['candidate_string']),
        axis=1,
    )
    # Numeric labels become the text labels the model is trained to emit
    split_df['label'] = split_df['label'].replace({0: 'Non-sensitive', 1: 'Secret'})

print("✅ Data preprocessing complete!")
print(f"📊 Label distribution in training data:")
print(df_train['label'].value_counts())

## Step 6: Prompt Template Design

Create the prompt template using ChatML format for consistent training and inference.

In [None]:
def format_prompt(candidate_string, issue_report, label=None):
    """Build the ChatML-style prompt for one example.

    With ``label`` set, the prompt ends with the assistant turn containing
    the label (training format). With ``label=None`` the prompt stops right
    after the assistant header so the model can generate the answer
    (inference format).
    """
    system_prompt = """You are a security auditor or classifier specialized in identifying and categorizing sensitive secrets from issue reports. Classify the given candidate string as either "Non-sensitive" or "Secret" based on its context.

A "Secret" includes sensitive information such as: 
- API keys and secrets (e.g., `sk_test_ABC123`)  
- Private and secret keys (e.g., private SSH keys, private cryptographic keys)  
- Authentication keys and tokens (e.g., `Bearer <token>`)  
- Database connection strings with credentials (e.g., `mongodb://user:password@host:port`)  
- Passwords, usernames, and any other private information that should not be shared openly.  

A "Non-sensitive" string is not considered secret and can be shared openly. This includes:  
- Public keys of any form (e.g., public SSH keys)  
- Non-sensitive configuration values or identifiers  
- Actual-looking keys that are clearly marked as dummy/test (e.g., with comments like '# dummy key' or variable names like 'test_key')  
- Strings that just look random or patterned but are not actually secrets (e.g., `xyz123`, 'xxxx', `abc123`, `EXAMPLE_KEY`, `token_value`)  
- Strings that are clearly placeholders or redacted text (e.g., 'XXXXXXXX', '[REDACTED]', '[TRUNCATED]')  
- **Obfuscated or masked values (e.g., '****', '****123', 'abc...xyz')**  

These are always considered **"Non-sensitive"**, even if they appear in a sensitive context.

Reply with only the classification: "Non-sensitive" or "Secret"."""

    user_prompt = f"""Classify the given candidate string based on its role in the provided issue report.

candidate_string: {candidate_string}
issue_report: {issue_report}"""

    # System turn, user turn and the assistant header are common to both formats
    prompt = (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    if label is None:
        # Inference format: leave the assistant turn open
        return prompt

    # Training format: close the assistant turn with the expected answer
    return f"{prompt}{label}<|im_end|>"

def formatting_prompts_func(examples):
    """Turn a batch of examples into a list of training prompts.

    ``examples`` is a column-oriented mapping; the "candidate_string",
    "modified_text" and "label" columns are read in lockstep and each row is
    rendered with ``format_prompt`` in its training (labelled) form.
    """
    rows = zip(
        examples["candidate_string"],
        examples["modified_text"],
        examples["label"],
    )
    return [
        format_prompt(candidate, report, answer)
        for candidate, report, answer in rows
    ]

# Render one training prompt from the first training row as a smoke test
sample_row = df_train.iloc[0]
sample_prompt = format_prompt(
    sample_row['candidate_string'],
    sample_row['modified_text'][:200] + "...",  # shortened issue text for display
    sample_row['label'],
)

print("✅ Prompt template created!")
print("\n📝 Sample formatted prompt:")
# Show at most the first 500 characters of the rendered prompt
if len(sample_prompt) > 500:
    print(sample_prompt[:500] + "...")
else:
    print(sample_prompt)

## Step 7: Dataset Preparation

Prepare datasets for training using HuggingFace's Dataset format.

In [None]:
# Alias the preprocessed frames under the names used by the rest of the notebook
X_train, X_eval, X_test = df_train, df_val, df_test
y_true = X_test['label']

# Only these columns are handed to the trainer
model_columns = ["candidate_string", "modified_text", "label"]

# Wrap the training and validation frames as HuggingFace datasets
train_dataset = Dataset.from_pandas(X_train[model_columns])
eval_dataset = Dataset.from_pandas(X_eval[model_columns])

print("✅ Datasets prepared!")
print(f"📊 Dataset sizes:")
print(f"  - Training: {len(train_dataset):,} examples")
print(f"  - Validation: {len(eval_dataset):,} examples")
print(f"  - Test: {len(X_test):,} examples")

# Schema of the training dataset as seen by `datasets`
print(f"\n📋 Training dataset features: {train_dataset.features}")

## Step 8: Training Configuration

Set up the training arguments and SFT trainer with optimal hyperparameters.

In [None]:
# ============================================
# Training Configuration (Optimized for Lower VRAM Usage)
# ============================================
# Checkpoints are written here during training; the save cell later reuses
# this path for the final adapters and (with a "_merged" suffix) the merged model.
output_dir = "../models/llama-3.1-ft-unsloth"

# Initialize trainer with memory-optimized arguments
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # NOTE(review): the datasets built in this notebook have no "text" column;
    # prompts come from `formatting_func` below — presumably that takes
    # precedence over `dataset_text_field`; confirm against the installed TRL.
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    formatting_func=formatting_prompts_func,
    args=TrainingArguments(
        per_device_train_batch_size=1,  # Reduced from 4 to 1
        per_device_eval_batch_size=1,   # Reduced from 2 to 1
        gradient_accumulation_steps=8,  # Increased from 2 to 8 to maintain effective batch size
        warmup_steps=20,
        num_train_epochs=5,
        max_steps=-1,
        learning_rate=1e-4,
        # Exactly one of fp16 / bf16 is enabled, chosen by hardware support
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=500,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,  # same value as `random_state` in the LoRA setup
        output_dir=output_dir,
        save_strategy="epoch",
        save_total_limit=1,
        eval_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        dataloader_drop_last=True,
        label_smoothing_factor=0.1,
        # Memory optimization settings
        dataloader_pin_memory=False,     # Disable pin memory to save VRAM
        gradient_checkpointing=True,     # Enable gradient checkpointing
        torch_compile=False,             # Disable torch compile for memory savings
        ddp_find_unused_parameters=False, # Optimize DDP
    ),
)

# NOTE: the summary below is made of hard-coded literals; update it by hand
# whenever the TrainingArguments above change.
print("✅ Memory-optimized training configuration complete!")
print(f"🎯 Training parameters (VRAM optimized):")
print(f"  - Batch size: 1 (per device) - REDUCED for memory saving")
print(f"  - Gradient accumulation: 8 steps - INCREASED to maintain effective batch size")
print(f"  - Effective batch size: 1 × 8 = 8")
print(f"  - Learning rate: 1e-4")
print(f"  - Max epochs: 5")
print(f"  - Optimizer: AdamW 8-bit")
print(f"  - Save strategy: epoch")
print(f"  - Eval strategy: epoch")
print(f"  - Gradient checkpointing: ENABLED for memory savings")
print(f"  - Pin memory: DISABLED for VRAM optimization")

In [None]:
# ============================================
# Additional Memory Optimization Techniques
# ============================================

# Optional: Further reduce sequence length if still using too much memory
# Uncomment the next lines if you need even more memory savings
# max_seq_length = 512  # Reduce from 1024 to 512
# print(f"⚠️  Sequence length reduced to {max_seq_length} for memory optimization")

if torch.cuda.is_available():
    # Drop cached allocator blocks before measuring
    torch.cuda.empty_cache()
    print("🧹 Cleared CUDA cache")

    # Snapshot of current usage, converted from bytes to GB (1024**3 bytes)
    bytes_per_gb = 1024**3
    current_memory = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_memory = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"📊 Current VRAM usage before training:")
    print(f"  - Allocated: {current_memory:.2f} GB")
    print(f"  - Reserved: {reserved_memory:.2f} GB")

    # Capacity of device 0 and what is left after the reserved pool
    gpu_props = torch.cuda.get_device_properties(0)
    total_memory = gpu_props.total_memory / bytes_per_gb
    available_memory = total_memory - reserved_memory
    print(f"  - Total GPU memory: {total_memory:.2f} GB")
    print(f"  - Available memory: {available_memory:.2f} GB")

    # Warn when more than 12 GB is already reserved
    if reserved_memory > 12:
        for advice_line in (
            "⚠️  HIGH MEMORY USAGE DETECTED!",
            "💡 Consider these additional optimizations:",
            "   1. Reduce max_seq_length to 512 or 256",
            "   2. Use DeepSpeed ZeRO if available",
            "   3. Enable more aggressive gradient checkpointing",
            "   4. Consider using 8-bit AdamW optimizer",
        ):
            print(advice_line)

print("✅ Memory optimization setup complete!")

## Step 9: Memory Monitoring and Training

Monitor GPU memory usage and start the training process.

In [None]:
# Show current memory stats
# NOTE: unlike the previous cell, this cell calls torch.cuda without an
# is_available() guard, so it requires a CUDA device.
gpu_stats = torch.cuda.get_device_properties(0)
# Baseline taken before training; the results cell subtracts it from the
# post-training reading to estimate the memory attributable to training.
# NOTE(review): max_memory_reserved() is presumably the peak since process
# start rather than the current value — confirm against the PyTorch docs.
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f"🖥️  GPU Information:")
print(f"  - GPU: {gpu_stats.name}")
print(f"  - Max memory: {max_memory} GB")
print(f"  - Reserved before training: {start_gpu_memory} GB")
print(f"  - Available memory: {max_memory - start_gpu_memory:.1f} GB")

# Start training
print("\n🚀 Starting training...")
# The returned object's `.metrics` dict is read by the results cell
trainer_stats = trainer.train()

print("✅ Training completed!")

## Step 10: Training Results and Memory Analysis

Analyze training results and memory usage statistics.

In [None]:
# Memory readings after training, in GB, relative to the pre-training baseline
bytes_per_gb = 1024 ** 3
used_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gb, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

# Wall-clock training time as reported by the trainer
train_runtime_seconds = trainer_stats.metrics['train_runtime']

print(f"📊 Training Statistics:")
print(f"  - Training time: {train_runtime_seconds} seconds")
print(f"  - Training time: {round(train_runtime_seconds/60, 2)} minutes")
print(f"  - Peak reserved memory: {used_memory} GB")
print(f"  - Peak memory for training: {used_memory_for_lora} GB")
print(f"  - Peak memory % of max: {used_percentage}%")
print(f"  - Training memory % of max: {lora_percentage}%")

# Dump every trainer metric whose name carries the train_ prefix
print(f"\n🎯 Training Metrics:")
train_metrics = {
    key: value
    for key, value in trainer_stats.metrics.items()
    if key.startswith('train_')
}
for key, value in train_metrics.items():
    print(f"  - {key}: {value}")

## Step 11: Model Saving

Save the trained model in multiple formats for different use cases.

In [None]:
# Save the model
# Model and tokenizer are both written to `output_dir`, the same directory
# the trainer was configured to use for its checkpoints.
print("💾 Saving model...")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Save as merged 16bit model for easier inference
# Written to a sibling directory named `output_dir` + "_merged".
print("💾 Saving merged model...")
model.save_pretrained_merged(output_dir + "_merged", tokenizer, save_method="merged_16bit")

print("✅ Model saving complete!")
print(f"📁 Models saved to:")
print(f"  - LoRA adapters: {output_dir}")
print(f"  - Merged model: {output_dir}_merged")

## Step 12: Comprehensive Model Testing

Now let's run a comprehensive evaluation on the test set to get detailed performance metrics, similar to the evaluation notebook.

In [None]:
# Switch the fine-tuned model into Unsloth's inference mode before predicting
FastLanguageModel.for_inference(model)


def format_prompt_for_inference(candidate_string, issue_report):
    """Format an inference prompt that matches the training prompt exactly.

    Delegates to ``format_prompt`` with no label, which yields the same
    system and user turns the model was fine-tuned on and leaves the
    assistant turn open for generation. Keeping a single source for the
    prompt text prevents the training and inference system prompts from
    drifting apart.

    Args:
        candidate_string: The string to classify.
        issue_report: The issue text (context window) the string came from.

    Returns:
        The prompt string, ending right after the assistant header.
    """
    return format_prompt(candidate_string, issue_report)

def extract_label(model_response):
    """Map raw model output to a class label.

    Returns "Secret" when that exact (case-sensitive) word appears anywhere
    in the response, and "Non-sensitive" for everything else.
    """
    return "Secret" if "Secret" in model_response else "Non-sensitive"

def predict_single(candidate_string, issue_report, model, tokenizer):
    """Classify one candidate string with the fine-tuned model.

    Args:
        candidate_string: The string to classify.
        issue_report: The issue text (context window) the string came from.
        model: The model to generate with; inputs are moved to "cuda".
        tokenizer: The tokenizer paired with ``model``.

    Returns:
        A ``(predicted_label, model_response)`` tuple: the label produced by
        ``extract_label`` and the decoded text of the newly generated tokens.
    """
    # Format prompt for inference
    test_prompt = format_prompt_for_inference(candidate_string, issue_report)

    # Tokenize (prompts longer than max_seq_length are truncated)
    inputs = tokenizer(
        test_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length
    )

    # Move to GPU
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    prompt_token_count = inputs['input_ids'].shape[-1]

    # A pad id of 0 is valid, so test for None explicitly instead of using `or`
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Greedy decoding; no `temperature` is passed because it has no effect
    # when do_sample=False
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=False
        )

    # The output sequence starts with the prompt tokens; decode only what was
    # generated after them. Slicing by token count stays correct even when the
    # decoded text differs from the prompt string (added special tokens,
    # truncation), which character-based slicing does not.
    generated_ids = outputs[0][prompt_token_count:]
    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    predicted_label = extract_label(model_response)
    return predicted_label, model_response

print("✅ Prediction functions defined!")
print("  - format_prompt_for_inference: Formats prompts for inference")
print("  - extract_label: Extracts classification labels from model responses")
print("  - predict_single: Makes single predictions with the trained model")

In [None]:

def predict_batch(test_df, model, tokenizer, batch_size=8):
    """Predict a label for every row of ``test_df``.

    Rows are visited in chunks of ``batch_size`` (which only sets the
    progress-bar granularity; each row is still predicted individually with
    ``predict_single``).

    Args:
        test_df: DataFrame with "candidate_string" and "modified_text" columns.
        model: Model passed through to ``predict_single``.
        tokenizer: Tokenizer passed through to ``predict_single``.
        batch_size: Number of rows per progress-bar step.

    Returns:
        A list of label strings, one per row and in row order. Rows whose
        prediction raised are recorded as "Non-sensitive" and reported in
        the printed error summary.
    """
    y_pred = []
    errors = []

    print(f"🔄 Running batch predictions on {len(test_df):,} examples...")

    for i in tqdm(range(0, len(test_df), batch_size), desc="Predicting"):
        batch = test_df.iloc[i:i+batch_size]

        for idx, row in batch.iterrows():
            try:
                # predict_single returns (label, raw_response); only the label
                # belongs in y_pred so it can be compared with the true labels
                predicted_label, _raw_response = predict_single(
                    row["candidate_string"],
                    row["modified_text"],
                    model,
                    tokenizer
                )
            except Exception as e:
                # Best effort: keep going so one bad row does not abort the run
                errors.append(f"Error at index {idx}: {e}")
                predicted_label = "Non-sensitive"  # Default prediction
            y_pred.append(predicted_label)

    if errors:
        print(f"⚠️  {len(errors)} errors occurred during prediction:")
        for error in errors[:3]:  # Show first 3 errors
            print(f"  - {error}")
        if len(errors) > 3:
            print(f"  - ... and {len(errors) - 3} more errors")

    return y_pred

print("✅ Batch prediction function defined!")

# def predict_batch(test_df, model, tokenizer, batch_size=32):
#     y_pred, errors = [], []
#     print(f"🔄 Running batch predictions on {len(test_df):,} examples...")

#     for i in tqdm(range(0, len(test_df), batch_size), desc="Predicting"):
#         batch = test_df.iloc[i:i+batch_size]
#         prompts = [
#             format_prompt_for_inference(row["candidate_string"], row["modified_text"])
#             for _, row in batch.iterrows()
#         ]
        
#         try:
#             # Tokenize as a batch
#             inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to("cuda")

#             with torch.no_grad():
#                 outputs = model.generate(
#                     **inputs,
#                     max_new_tokens=5,   # only need short outputs
#                     do_sample=False,    # deterministic
#                     temperature=0.0,
#                     pad_token_id=tokenizer.eos_token_id,
#                     eos_token_id=tokenizer.eos_token_id
#                 )

#             decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
#             for resp in decoded:
#                 y_pred.append(extract_label(resp))

#         except Exception as e:
#             print(e)
#             for idx in batch.index:
#                 errors.append(f"Error at index {idx}: {e}")
#                 y_pred.append("Non-sensitive")

#     if errors:
#         print(f"⚠️ {len(errors)} errors occurred during prediction.")
#     return y_pred


In [None]:
# Predict every test row and collect the matching ground-truth labels
print("🚀 Running comprehensive evaluation on test set...")

y_pred = predict_batch(X_test, model, tokenizer)
y_true_test = X_test['label'].tolist()

print(f"✅ Evaluation completed!")
print(f"📊 Prediction Summary:")
print(f"  - Total predictions: {len(y_pred):,}")
print(f"  - Unique predicted labels: {set(y_pred)}")
print(f"  - True label distribution: {X_test['label'].value_counts().to_dict()}")

# Fraction of positions where prediction and truth agree (0 when there are no predictions)
correct_predictions = sum(true == pred for true, pred in zip(y_true_test, y_pred))
quick_accuracy = correct_predictions / len(y_pred) if y_pred else 0
print(f"  - Quick accuracy: {quick_accuracy:.3f} ({correct_predictions}/{len(y_pred)})")

In [None]:
# ============================================
# Detailed Performance Metrics
# ============================================
from sklearn.metrics import precision_recall_fscore_support

# Values reported by the summary cell; they stay at 0.0 when there is nothing to score
accuracy = precision_avg = recall_avg = f1_avg = 0.0

if len(y_pred) > 0:
    divider = "=" * 50
    print("\n" + divider)
    print("📈 DETAILED PERFORMANCE METRICS")
    print(divider)

    # sklearn's text report over the string labels
    print("\n📊 Classification Report:")
    print(classification_report(y_true_test, y_pred))

    # Per-class precision / recall / F1 / support, ordered by sorted true label
    labels = sorted(set(y_true_test))
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true_test,
        y_pred,
        labels=labels,
    )

    print(f"\n🏷️  Per-Class Detailed Metrics:")
    for label, class_precision, class_recall, class_f1, class_support in zip(
        labels, precision, recall, f1, support
    ):
        print(f"\n  {label.upper()}:")
        print(f"    - Precision: {class_precision:.3f}")
        print(f"    - Recall: {class_recall:.3f}")
        print(f"    - F1-score: {class_f1:.3f}")
        print(f"    - Support: {class_support:,}")

    # Binary encoding used by the remaining metrics: "Secret" -> 1, anything else -> 0
    def map_func(x):
        return int(x == "Secret")

    y_true_mapped = np.array([map_func(item) for item in y_true_test])
    y_pred_mapped = np.array([map_func(item) for item in y_pred])

    # Overall accuracy plus support-weighted precision / recall / F1
    overall_accuracy = accuracy_score(y_true=y_true_mapped, y_pred=y_pred_mapped)
    precision_overall, recall_overall, f1_overall, _ = precision_recall_fscore_support(
        y_true_mapped, y_pred_mapped, average='weighted'
    )

    # Publish the values the performance-summary cell reads
    accuracy, precision_avg, recall_avg, f1_avg = (
        overall_accuracy,
        precision_overall,
        recall_overall,
        f1_overall,
    )

    print(f"\n🎯 Overall Performance:")
    print(f"  - Overall Accuracy: {overall_accuracy:.3f}")
    print(f"  - Weighted Precision: {precision_overall:.3f}")
    print(f"  - Weighted Recall: {recall_overall:.3f}")
    print(f"  - Weighted F1-Score: {f1_overall:.3f}")

    # Accuracy restricted to the rows of each true class (skipped for absent classes)
    for label_val, name in ((0, "Non-sensitive"), (1, "Secret")):
        class_mask = y_true_mapped == label_val
        if class_mask.any():
            label_accuracy = accuracy_score(
                y_true=y_true_mapped[class_mask],
                y_pred=y_pred_mapped[class_mask],
            )
            print(f"  - Accuracy for {name}: {label_accuracy:.3f}")

else:
    # The four summary values keep their 0.0 defaults from above
    print("❌ Cannot calculate metrics - no valid predictions made.")

In [None]:
# ============================================
# Confusion Matrix and Error Analysis
# ============================================
# Relies on y_true_mapped / y_pred_mapped (0 = Non-sensitive, 1 = Secret)
# computed by the detailed-metrics cell.

if len(y_pred) > 0:
    print("\n" + "="*50)
    print("🔍 CONFUSION MATRIX & ERROR ANALYSIS")
    print("="*50)

    # Confusion Matrix
    # labels=[0, 1] pins the matrix to 2x2 even when only one class occurs in
    # the data; without it the matrix can be 1x1 and the indexing below fails.
    cm = confusion_matrix(y_true=y_true_mapped, y_pred=y_pred_mapped, labels=[0, 1])
    print("\n📊 Confusion Matrix:")
    print("Predicted →")
    print("Actual ↓     Non-sens  Secret")
    print(f"Non-sens      {cm[0,0]:6d}   {cm[0,1]:6d}")
    print(f"Secret        {cm[1,0]:6d}   {cm[1,1]:6d}")

    # Calculate derived metrics
    # Plain ints (not numpy integers) so later cells can serialize them to JSON
    tn, fp, fn, tp = (int(count) for count in cm.ravel())
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # Recall for Secret
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # Recall for Non-sensitive

    print(f"\n📊 Additional Metrics:")
    print(f"  - True Positives (Secret correctly identified): {tp:,}")
    print(f"  - True Negatives (Non-sensitive correctly identified): {tn:,}")
    print(f"  - False Positives (Non-sensitive labeled as Secret): {fp:,}")
    print(f"  - False Negatives (Secret labeled as Non-sensitive): {fn:,}")
    print(f"  - Sensitivity (Secret recall): {sensitivity:.3f}")
    print(f"  - Specificity (Non-sensitive recall): {specificity:.3f}")

    # Error Analysis
    print(f"\n🔍 Error Breakdown:")
    false_positives = [(true, pred, idx) for idx, (true, pred) in enumerate(zip(y_true_test, y_pred))
                      if true == 'Non-sensitive' and pred == 'Secret']
    false_negatives = [(true, pred, idx) for idx, (true, pred) in enumerate(zip(y_true_test, y_pred))
                      if true == 'Secret' and pred == 'Non-sensitive']

    print(f"  - False Positives: {len(false_positives):,} (Non-sensitive → Secret)")
    print(f"  - False Negatives: {len(false_negatives):,} (Secret → Non-sensitive)")

    # Show sample errors (at most three of each kind)
    if false_negatives:
        print(f"\n❌ Sample False Negatives (Security Risk):")
        for i, (true, pred, idx) in enumerate(false_negatives[:3]):
            candidate = X_test.iloc[idx]['candidate_string']
            print(f"  {i+1}. Candidate: {candidate[:80]}{'...' if len(candidate) > 80 else ''}")
            print(f"     True: {true} → Predicted: {pred}")

    if false_positives:
        print(f"\n⚠️  Sample False Positives (Usability Impact):")
        for i, (true, pred, idx) in enumerate(false_positives[:3]):
            candidate = X_test.iloc[idx]['candidate_string']
            print(f"  {i+1}. Candidate: {candidate[:80]}{'...' if len(candidate) > 80 else ''}")
            print(f"     True: {true} → Predicted: {pred}")

    # Risk Assessment
    # A rate is reported whenever the class has any samples, so a clean result
    # (zero errors) is shown explicitly instead of leaving the section empty.
    print(f"\n🚨 Risk Assessment:")
    if (tp + fn) > 0:
        fn_rate = fn / (tp + fn)
        if fn_rate < 0.05:
            print(f"  ✅ LOW SECURITY RISK: False negative rate = {fn_rate:.3f}")
        elif fn_rate < 0.10:
            print(f"  ⚠️  MODERATE SECURITY RISK: False negative rate = {fn_rate:.3f}")
        else:
            print(f"  ❌ HIGH SECURITY RISK: False negative rate = {fn_rate:.3f}")

    if (tn + fp) > 0:
        fp_rate = fp / (tn + fp)
        if fp_rate < 0.05:
            print(f"  ✅ LOW USABILITY IMPACT: False positive rate = {fp_rate:.3f}")
        elif fp_rate < 0.10:
            print(f"  ⚠️  MODERATE USABILITY IMPACT: False positive rate = {fp_rate:.3f}")
        else:
            print(f"  ❌ HIGH USABILITY IMPACT: False positive rate = {fp_rate:.3f}")

else:
    print("❌ Cannot perform error analysis - no valid predictions made.")

In [None]:
# ============================================
# Performance Summary and Results Saving
# ============================================
# Reads accuracy / precision_avg / recall_avg / f1_avg from the detailed-metrics
# cell and, when available, tn / fp / fn / tp from the confusion-matrix cell.

if len(y_pred) > 0:
    print("\n" + "="*50)
    print("📋 PERFORMANCE SUMMARY")
    print("="*50)

    # Import required modules
    import json
    from datetime import datetime

    # Create performance summary.
    # Every value is converted to a built-in int/float: the confusion-matrix
    # counts are numpy integers, which json.dump cannot serialize.
    performance_summary = {
        'total_predictions': len(y_pred),
        'successful_predictions': len([p for p in y_pred if p in ['Secret', 'Non-sensitive']]),
        'failed_predictions': len([p for p in y_pred if p not in ['Secret', 'Non-sensitive']]),
        'accuracy': float(accuracy),
        'precision': float(precision_avg),
        'recall': float(recall_avg),
        'f1_score': float(f1_avg),
        # Counts default to 0 when the confusion-matrix cell has not been run
        'false_negatives': int(fn) if 'fn' in locals() else 0,
        'false_positives': int(fp) if 'fp' in locals() else 0,
        'true_positives': int(tp) if 'tp' in locals() else 0,
        'true_negatives': int(tn) if 'tn' in locals() else 0
    }

    print(f"✅ Performance Overview:")
    print(f"  - Total Test Samples: {performance_summary['total_predictions']:,}")
    print(f"  - Successful Predictions: {performance_summary['successful_predictions']:,}")
    print(f"  - Failed Predictions: {performance_summary['failed_predictions']:,}")
    print(f"  - Accuracy: {performance_summary['accuracy']:.3f}")
    print(f"  - Precision: {performance_summary['precision']:.3f}")
    print(f"  - Recall: {performance_summary['recall']:.3f}")
    print(f"  - F1-Score: {performance_summary['f1_score']:.3f}")

    # Model quality assessment
    print(f"\n🎯 Model Quality Assessment:")
    if accuracy >= 0.95:
        print("  ✅ EXCELLENT performance (≥95% accuracy)")
    elif accuracy >= 0.90:
        print("  ✅ GOOD performance (≥90% accuracy)")
    elif accuracy >= 0.80:
        print("  ⚠️  ACCEPTABLE performance (≥80% accuracy)")
    else:
        print("  ❌ POOR performance (<80% accuracy)")

    # Save detailed results to CSV
    # Saving is best-effort: a failure is reported but does not abort the cell.
    try:
        print(f"\n💾 Saving Results...")

        # Prepare detailed results: one record per test row
        detailed_results = []
        for i, (true_label, pred_label) in enumerate(zip(y_true_test, y_pred)):
            result = {
                'index': i,
                'candidate_string': X_test.iloc[i]['candidate_string'],
                'true_label': true_label,
                'predicted_label': pred_label,
                'correct': true_label == pred_label,
                'error_type': 'Correct' if true_label == pred_label else
                             'False Positive' if true_label == 'Non-sensitive' and pred_label == 'Secret' else
                             'False Negative' if true_label == 'Secret' and pred_label == 'Non-sensitive' else
                             'Other Error'
            }
            detailed_results.append(result)

        # Save to CSV (file names carry a timestamp so reruns do not overwrite)
        results_df = pd.DataFrame(detailed_results)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"training_test_results_{timestamp}.csv"
        results_df.to_csv(results_file, index=False)
        print(f"  ✅ Detailed results saved to: {results_file}")

        # Save performance summary
        summary_file = f"training_performance_summary_{timestamp}.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(performance_summary, f, indent=2)
        print(f"  ✅ Performance summary saved to: {summary_file}")

        # Show error samples summary
        error_samples = results_df[~results_df['correct']]
        if len(error_samples) > 0:
            print(f"\n🔍 Error Samples Preview:")
            print(f"  - Total Errors: {len(error_samples):,}")
            print(f"  - False Negatives: {len(error_samples[error_samples['error_type'] == 'False Negative']):,}")
            print(f"  - False Positives: {len(error_samples[error_samples['error_type'] == 'False Positive']):,}")
            print(f"  - Other Errors: {len(error_samples[error_samples['error_type'] == 'Other Error']):,}")

    except Exception as e:
        print(f"❌ Error saving results: {e}")

    print(f"\n🎉 TRAINING AND TESTING COMPLETE!")
    print(f"Model saved at: {output_dir}")
    print(f"Ready for deployment or further evaluation.")

else:
    print("❌ No valid predictions to summarize.")
    print("⚠️  Consider debugging the model inference process.")