In [None]:
from datasets import load_dataset, concatenate_datasets
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import numpy as np

# 1. Load datasets
print("🔄 Loading datasets...")
# Names handed one by one to load_dataset() below.
dataset_names = [
    "imoxto/prompt-injection-cleaned",
    "JasperLS/prompt-injections",
    "rubend18/ChatGPT-Jailbreak-Prompts",
]

# Keep exactly one split per source: "train" when the loaded mapping has it,
# otherwise whichever split the mapping lists first.
datasets = []
for name in dataset_names:
    ds = load_dataset(name)
    split = "train" if "train" in ds else list(ds.keys())[0]
    datasets.append(ds[split])

# 2. Normalize labels (0 = clean, 1 = prompt injection)
def normalize(example):
    """Collapse an example's label to binary: 0 = clean, 1 = prompt injection.

    Any non-zero integer label counts as an injection. A label that is
    missing, ``None``, or otherwise not convertible to an integer also maps
    to 1, so uncertain rows are treated as malicious instead of raising.

    Args:
        example: Mutable mapping for one row; its ``"label"`` entry is
            overwritten in place.

    Returns:
        The same ``example`` object, with ``example["label"]`` set to 0 or 1.
    """
    try:
        label = int(example.get("label", 1))
    except (TypeError, ValueError, OverflowError):
        # The .get() default only covers a missing key; a present-but-unusable
        # value (None, non-numeric string, NaN/inf) must fail closed as well.
        label = 1
    example["label"] = 1 if label != 0 else 0
    return example

# Merge every loaded source into one dataset, then binarize labels per row.
# NOTE(review): normalize() reads "label" and tokenize() reads "text" —
# confirm each source dataset actually provides both columns.
all_data = concatenate_datasets(datasets).map(normalize)

# 3. Tokenization
print("🔠 Tokenizing...")
tokenizer = DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-uncased",
)

def tokenize(example):
    """Run the module-level ``tokenizer`` over ``example["text"]``.

    The call requests ``padding="max_length"`` and truncation with
    ``max_length=256``. Returns whatever the tokenizer call produces.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,
    }
    return tokenizer(example["text"], **encode_options)

# Tokenize in batches, then carve out a 20% test split with a fixed seed.
tokenized = all_data.map(tokenize, batched=True).train_test_split(
    test_size=0.2,
    seed=42,
)
# Expose only the listed columns, formatted as "torch".
tokenized.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)

# 4. Model
print("🧠 Loading model...")
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
)

# 5. Metrics
def compute_metrics(eval_pred):
    """Score one evaluation pass.

    Args:
        eval_pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores, arg-maxed along axis 1).

    Returns:
        Dict with ``"accuracy"``, ``"f1"``, ``"precision"`` and ``"recall"``;
        the last three use ``average='binary'``.
    """
    y_true = eval_pred.label_ids
    y_pred = np.argmax(eval_pred.predictions, axis=1)
    # The fourth element of the result is not needed here.
    precision, recall, f1 = precision_recall_fscore_support(
        y_true, y_pred, average='binary'
    )[:3]
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }

# 6. Training
# NOTE(review): recent transformers releases appear to rename
# `evaluation_strategy` to `eval_strategy` — confirm against the installed
# version before upgrading.
training_args = TrainingArguments(
    # Where outputs and logs are written.
    output_dir="./firewall_prompt_injection_model",
    logging_dir="./logs",
    save_total_limit=1,
    # Schedule.
    num_train_epochs=4,
    evaluation_strategy="epoch",
    # Optimisation hyper-parameters.
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

# Wire the model, the two splits, and the metric callback into one Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    compute_metrics=compute_metrics,
)

# 7. Train!
print("🚀 Training...")
trainer.train()

# 8. Evaluate
print("📊 Evaluation...")
predictions = trainer.predict(tokenized["test"])
# Compare the stored test labels with the arg-max of the predicted scores.
true_labels = tokenized["test"]["label"]
predicted_labels = np.argmax(predictions.predictions, axis=1)
print(classification_report(true_labels, predicted_labels))
