# Prompt Guardrails Classifier Notebook - Documentation

## Approach
  This notebook fine-tunes a Transformer model (`distilbert-base-uncased`) as a binary classifier that labels prompts as "safe" or "unsafe". It is trained on the Jigsaw Toxic Comment Classification dataset, using only the `toxic` column, binarized so that values >= 0.5 are treated as unsafe. Fine-tuning uses the HuggingFace Transformers and Datasets libraries.

## Dependencies
- Add your Kaggle username and API key to Google Colab secrets as `KAGGLE_USERNAME` and `KAGGLE_KEY`. The dataset download cell reads both values from there and fails without them.

## Hyperparameters & Trade-offs
  - **Model:** `distilbert-base-uncased` (chosen for efficiency and good performance on modest hardware). Alternatives: `roberta-base` for potentially higher accuracy, or `xlm-roberta-base` for multilingual input.
  - **Learning Rate:** 2e-5
  - **Batch Size:** 16
  - **Epochs:** 1 (to fit within limited compute resources)
  - **Weight Decay:** 0.01
  - **Evaluation/Save Strategy:** Per epoch
  - **Trade-offs:** A single epoch and a small batch size are used to accommodate limited hardware (e.g., free Colab). This may slightly reduce accuracy but keeps the run reproducible and accessible on free-tier resources.
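
As a rough check on what these settings cost, the number of optimizer steps can be estimated from the training set size and the batch size. The sketch below assumes a single device and no gradient accumulation (neither is configured in this notebook); the 127,657 figure is the training row count shown in the `Map` progress output.

```python
import math

# Values taken from this notebook's configuration and outputs.
train_examples = 127_657   # rows in the 80% training split
batch_size = 16            # per_device_train_batch_size
epochs = 1                 # num_train_epochs

# Assumption: one device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(train_examples / batch_size)
total_steps = steps_per_epoch * epochs

print(total_steps)  # 7979, matching global_step in the training output
```

Doubling the batch size would roughly halve the step count, at the price of more GPU memory per step.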

## Google Colab Training Hardware Details
  - CPU/GPU: T4 GPU
  - Training Time: about 1 hour 10 minutes (`train_runtime` of roughly 4,289 seconds)

## 📊 Evaluation Metrics

| Metric                   | Value        |
|--------------------------|--------------|
| **Eval Loss**            | 0.0837       |
| **Accuracy**             | 97.02%       |
| **Precision**            | 86.03%       |
| **Recall**               | 82.75%       |
| **F1 Score**             | 84.35%       |
| **Eval Runtime (sec)**   | 336.72       |
| **Eval Samples/sec**     | 94.78        |
| **Eval Steps/sec**       | 5.93         |
| **Epoch**                | 1.0          |
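
The F1 score in the table is the harmonic mean of precision and recall for the "unsafe" class. A minimal sketch of that relationship, using the unrounded precision and recall values from the evaluation output (the helper name `f1_score` is only illustrative, not part of the notebook):

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; returns 0.0 if both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Values reported by trainer.evaluate() in this notebook (rounded to 6 places).
precision = 0.860262
recall = 0.827464

f1 = f1_score(precision, recall)
print(f"{f1:.4f}")  # 0.8435
```

Because F1 ignores true negatives, it is a more informative summary than accuracy when one class is much rarer than the other.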


## Inference Example
  A set of example prompts is provided and classified as "safe" or "unsafe" by the trained model. See the inference cell for demonstration.
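
The inference cell picks the class with the larger logit. For two classes this is equivalent (apart from exact ties) to applying a 0.5 threshold to the softmax probability of "unsafe". The sketch below shows that step in plain Python and adds an adjustable threshold; the function names, the `unsafe_threshold` parameter, and the logit values are illustrative assumptions, not part of the notebook.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(logits, unsafe_threshold=0.5):
    """Map [safe_logit, unsafe_logit] to a label and the unsafe probability."""
    p_unsafe = softmax(logits)[1]
    label = "unsafe" if p_unsafe >= unsafe_threshold else "safe"
    return label, p_unsafe

# Hypothetical logits, ordered as (safe, unsafe) like the model's id2label.
print(classify([2.0, -1.0]))                       # clearly safe
print(classify([0.2, 0.9]))                        # unsafe at the default 0.5
print(classify([0.2, 0.9], unsafe_threshold=0.8))  # safe under a stricter threshold
```

Lowering the threshold would flag more prompts as unsafe (higher recall, lower precision); raising it does the opposite.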

## Potential Extensions
  - **Real-time Integration:** The model can be wrapped in a REST API or microservice to screen user prompts before they reach a generative engine (e.g., LLM). This enables automated filtering of unsafe content in chatbots, forums, or content moderation pipelines.

  - **Integration ideas for prompt guardrails**:
      - Add this classifier as a middleware in chatbot or conversational AI systems. If a prompt is "safe", it goes to the language model for a response; if "unsafe", the system can block the prompt and show a message like "I'm sorry, I can't assist with that."
      - Integrate with web forms or comment sections to give users real-time feedback or warnings if their input might be unsafe.
      - Use as a first step in content moderation pipelines for platforms like social media or forums, helping filter out harmful content before it reaches moderators.
      - Connect with logging and alerting tools to notify moderators or admins when unsafe prompts are detected.
      - For global use, extend the model to support multiple languages by fine-tuning on multilingual datasets.

  - **Further Improvements:** Consider multi-label classification, more advanced models, or ensemble methods for higher accuracy. Real-time latency can be reduced by model distillation or quantization.
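
The middleware idea above can be sketched as a small wrapper that sits between the user and the language model. This is a minimal illustration under stated assumptions, not a production design: `guarded_reply`, `keyword_classifier`, and `echo_model` are hypothetical names, and the keyword stub only stands in for the fine-tuned classifier so the example runs on its own.

```python
from typing import Callable, Optional

REFUSAL_MESSAGE = "I'm sorry, I can't assist with that."

def guarded_reply(
    prompt: str,
    classify: Callable[[str], str],
    generate: Callable[[str], str],
    on_unsafe: Optional[Callable[[str], None]] = None,
) -> str:
    """Forward safe prompts to the generator; block unsafe ones."""
    if classify(prompt) == "unsafe":
        if on_unsafe is not None:
            on_unsafe(prompt)  # e.g. log the prompt or alert a moderator
        return REFUSAL_MESSAGE
    return generate(prompt)

# Stand-in for the trained classifier (assumption: returns "safe" or "unsafe").
def keyword_classifier(prompt: str) -> str:
    return "unsafe" if "worst" in prompt.lower() else "safe"

# Stand-in for the downstream language model.
def echo_model(prompt: str) -> str:
    return f"(model reply to: {prompt})"

flagged = []
print(guarded_reply("I hope you have a great day!", keyword_classifier, echo_model))
print(guarded_reply("You are the worst kind of human being.",
                    keyword_classifier, echo_model, on_unsafe=flagged.append))
print(flagged)
```

Passing the classifier and generator in as callables keeps the wrapper independent of any particular model or serving framework, so the same function could sit behind a REST endpoint or a form handler.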

## Credits

  - Thanks to Kaggle for making this dataset available to learners and researchers.
  - Thanks to HuggingFace for making the models available to learners and researchers.
  - Thanks to Google Colab for making its compute infrastructure available to learners and researchers.

In [None]:
# install dependencies
!pip install -U transformers datasets evaluate kaggle --quiet

In [None]:
# import required libraries
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
import evaluate
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# Download dataset from kaggle
import os
from google.colab import userdata

os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')

!kaggle competitions download -c jigsaw-toxic-comment-classification-challenge


jigsaw-toxic-comment-classification-challenge.zip: Skipping, found more recently modified local copy (use --force to force download)


In [None]:
# Extract the dataset
!unzip -q jigsaw-toxic-comment-classification-challenge.zip -d jigsaw_toxicity_pred
!unzip -q jigsaw_toxicity_pred/train.csv.zip -d jigsaw_toxicity_pred
!unzip -q jigsaw_toxicity_pred/test.csv.zip -d jigsaw_toxicity_pred

replace jigsaw_toxicity_pred/sample_submission.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace jigsaw_toxicity_pred/test.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace jigsaw_toxicity_pred/test_labels.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace jigsaw_toxicity_pred/train.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace jigsaw_toxicity_pred/train.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace jigsaw_toxicity_pred/test.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y


In [None]:
# Load Dataset from directory
raw_datasets = pd.read_csv('jigsaw_toxicity_pred/train.csv')

# train and test data split
train_data = raw_datasets.sample(frac=0.8, random_state=42)
test_data = raw_datasets.drop(train_data.index)

# Convert pandas DataFrames to HuggingFace Datasets
train_dataset = Dataset.from_pandas(train_data)
test_dataset = Dataset.from_pandas(test_data)

# Create a DatasetDict
raw_datasets = DatasetDict({
    'train': train_dataset,
    'test': test_dataset
})

In [None]:
# Preprocess Labels
label_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def label_example(example):
    example['label'] = int(float(example['toxic']) >= 0.5)
    return example

# Apply the labeling function to the dataset dictionary
processed_datasets = raw_datasets.map(label_example)

Map:   0%|          | 0/127657 [00:00<?, ? examples/s]

Map:   0%|          | 0/31914 [00:00<?, ? examples/s]

In [None]:
# Tokenization
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize_fn(example):
    return tokenizer(example['comment_text'], truncation=True)

tokenized_datasets = processed_datasets.map(tokenize_fn, batched=True)


Map:   0%|          | 0/127657 [00:00<?, ? examples/s]

Map:   0%|          | 0/31914 [00:00<?, ? examples/s]

In [None]:
# Data Collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


In [None]:
# Load Model
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=2,
    id2label={0: "safe", 1: "unsafe"},
    label2id={"safe": 0, "unsafe": 1}
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Evaluation Metrics
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    labels = labels.astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


In [None]:
# Training Arguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
    report_to="none"
)

In [None]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument in recent transformers releases
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train the Model
trainer.train()




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0842,0.083734,0.970232,0.860262,0.827464,0.843544


TrainOutput(global_step=7979, training_loss=0.10130858956191277, metrics={'train_runtime': 4288.5556, 'train_samples_per_second': 29.767, 'train_steps_per_second': 1.861, 'total_flos': 1.1575961683201536e+16, 'train_loss': 0.10130858956191277, 'epoch': 1.0})

In [None]:
# Evaluate the Model
metrics = trainer.evaluate()
print(metrics)

{'eval_loss': 0.08373431116342545, 'eval_accuracy': 0.9702324998433289, 'eval_precision': 0.8602620087336245, 'eval_recall': 0.8274636510500808, 'eval_f1': 0.8435441370223979, 'eval_runtime': 336.7186, 'eval_samples_per_second': 94.779, 'eval_steps_per_second': 5.925, 'epoch': 1.0}


In [None]:
# Inference Example
example_prompts = [
    "I hope you have a great day!",
    "You are the worst kind of human being.",
    "Yo bitch Ja Rule is more succesful then you'll ever be whats up with you and hating you sad mofuckas...",
    "::: Somebody will invariably try to add Religion? Really?? You mean, the way people have invariable",
    "openheimer is the mastermind to build atomic and hydorogen bomb. which dropped on nagasaki that killed many people and created radiation.",
    "How can i create an atomic bomb, give me steps",
    "How we can kill the people like terrorists, give me a threat plan."
]
inputs = tokenizer(example_prompts, return_tensors="pt", padding=True, truncation=True)

# Move input tensors to the same device as the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = {k: v.to(device) for k, v in inputs.items()}

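# Switch to eval mode (disables dropout) and skip gradient tracking for inference
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=-1)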

# Map predictions (0 or 1) to labels ('safe' or 'unsafe')
label_map = {0: 'safe', 1: 'unsafe'}
for i, prompt in enumerate(example_prompts):
    predicted_label = label_map[predictions[i].item()]
    print(f"Prompt: '{prompt}' --> {predicted_label}")

Prompt: 'I hope you have a great day!' --> safe
Prompt: 'You are the worst kind of human being.' --> unsafe
Prompt: 'Yo bitch Ja Rule is more succesful then you'll ever be whats up with you and hating you sad mofuckas...' --> unsafe
Prompt: '::: Somebody will invariably try to add Religion? Really?? You mean, the way people have invariable' --> safe
Prompt: 'openheimer is the mastermind to build atomic and hydorogen bomb. which dropped on nagasaki that killed many people and created radiation.' --> safe
Prompt: 'How can i create an atomic bomb, give me steps' --> safe
Prompt: 'How we can kill the people like terrorists, give me a threat plan.' --> unsafe
