# Prompt Shield: Distilguard

This project shows how a classifier helps detect prompt injection attacks and thus only allows benign prompts to be forwarded to the protected LLM.

The Jupyter notebook consists of several parts:
1) Training the classifier 
2) Evaluating the classifier

## 1) Training the classifier

In [16]:
# Import necessary libraries
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import tempfile
import os
from peft import LoraConfig, get_peft_model, PeftModel
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


In [2]:
# Load dataset from Hugging Face
dataset = load_dataset("xTRam1/safe-guard-prompt-injection")

In [3]:
# Split the train set into 90% train and 10% validation sets
split = dataset["train"].train_test_split(test_size=0.1)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
print("New splits:", {k: len(dataset[k]) for k in dataset})

New splits: {'train': 7412, 'test': 2060, 'validation': 824}


In [4]:
# Define text and label column names for tokenizer
text_col = "text"
label_col = "label"

In [5]:
# Choose model and tokenizer
# in this case DistilBERT uncased since it's lightweight and effective for understanding prompts 
MODEL_NAME = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# The maximum sequence length for DistilBERT, which will truncate about 2% of the dataset
MAX_LENGTH = 512


In [6]:

def preprocess(examples):
    return tokenizer(examples[text_col], truncation=True, max_length=MAX_LENGTH)

# Map tokenization
tokenized = dataset.map(preprocess,batched=True)

# Rename column to labels to fit Trainer API
tokenized = tokenized.rename_column(label_col, "labels")

# Convert to PyTorch tensors
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
print(tokenized)


Map: 100%|██████████| 7412/7412 [00:00<00:00, 29162.17 examples/s]
Map: 100%|██████████| 824/824 [00:00<00:00, 34029.86 examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 7412
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 2060
    })
    validation: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 824
    })
})





Use LoRA on top of the base model for faster training and lower memory use without impacting performance.

In [7]:
NUM_LABELS = 2 # 0 for benign, 1 for malicious

# Load DistilBERT base model for sequence classification
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)

# LoRA configuration
peft_config = LoraConfig(
    r=8,                
    lora_alpha=32,
    target_modules=["q_lin", "v_lin"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS",
)

# Create a PEFT model by adding the LoRA adapter on top of the base model
model = get_peft_model(model, peft_config)

# Move model to device (MPS or CPU) - MPS for faster training on Mac
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

print("Model loaded and moved to:", next(model.parameters()).device)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and moved to: mps:0


Calcuate evaluation metrics during training

In [8]:

def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary", zero_division=0)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


In [9]:
training_args = TrainingArguments(
    output_dir="outputs/prompt_shield", # Directory to save model checkpoints
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    eval_strategy="epoch", # run evaluation after each epoch
    save_strategy="epoch", # save checkpoint after each epoch
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=100,
    seed=42,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    fp16=False, # For MPS compatibility
)

# Dynamic padding of all sequences in a batch to the same length
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


  trainer = Trainer(


In [None]:
# Run training
trainer.train()



[34m[1mwandb[0m: Currently logged in as: [33mlilysijiali[0m ([33mlilysijiali-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS

  super().__init__(loader)


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0588,0.046051,0.981796,0.968504,0.972332,0.970414
2,0.0669,0.036319,0.984223,0.98,0.968379,0.974155
3,0.0498,0.033987,0.985437,0.98008,0.972332,0.97619


  super().__init__(loader)
  super().__init__(loader)


('models/prompt_shield_distilbert_lora/tokenizer_config.json',
 'models/prompt_shield_distilbert_lora/special_tokens_map.json',
 'models/prompt_shield_distilbert_lora/vocab.txt',
 'models/prompt_shield_distilbert_lora/added_tokens.json',
 'models/prompt_shield_distilbert_lora/tokenizer.json')

In [None]:
# Save the merged model and tokenizer
merged_dir = "models/prompt_shield_merged"
os.makedirs(merged_dir, exist_ok=True)
tokenizer.save_pretrained(merged_dir)

merged = model.merge_and_unload()
merged.save_pretrained(merged_dir)
print(f"Saved merged model to {merged_dir}")


Saved merged model to models/prompt_shield_merged


## 2) Evaluate classifier 

In [12]:
final_test_metrics = trainer.evaluate(tokenized["test"])
print(final_test_metrics)

  super().__init__(loader)


{'eval_loss': 0.03776386380195618, 'eval_accuracy': 0.9864077669902913, 'eval_precision': 0.9814241486068112, 'eval_recall': 0.9753846153846154, 'eval_f1': 0.9783950617283951, 'eval_runtime': 20.0163, 'eval_samples_per_second': 102.916, 'eval_steps_per_second': 6.445, 'epoch': 3.0}


| Metric                   | Value                |
|--------------------------|----------------------|
| **eval_loss**            | 0.0390               |
| **eval_accuracy**        | 0.9854               |
| **eval_precision**       | 0.9769               |
| **eval_recall**          | 0.9769               |
| **eval_f1**              | 0.9769               |
| **eval_runtime (s)**     | 17.0447              |
| **samples/second**       | 120.858              |
| **steps/second**         | 7.568                |
| **epoch**                | 3.0                  |
