# Explainable NSFW prompt detection

Use SHAP to inspect comma-separated prompt chunks and see which tokens push the classifier toward NSFW or SFW.


In [None]:
import sys
from collections.abc import Sequence
from pathlib import Path

import numpy as np
import shap
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Use NumPy's 1.25-era print formatting for anything displayed below.
np.set_printoptions(legacy="1.25")

# `data.data` is imported from the repository root, which is the parent of the
# current working directory when this notebook is run from evaluate/.
project_root = Path.cwd().resolve().parent
root_entry = str(project_root)
if root_entry not in sys.path:
    # Append (rather than prepend) so existing entries keep their priority.
    sys.path.append(root_entry)

from data.data import DatasetLoader

In [None]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "JeremyFeng/nsfw-prompt-detection"

# Class index <-> label name. The reverse map is derived from `id2label` so the
# two dictionaries cannot drift apart.
id2label = {0: "sfw", 1: "nsfw"}
label2id = {label: idx for idx, label in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)

# Inference only: put the network in eval mode (e.g. dropout layers stop
# dropping). As the last expression of the cell this also displays the model.
model.eval()


2025-11-26 00:43:25.970376: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-26 00:43:26.025249: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-26 00:43:27.112636: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): Layer

In [None]:
# One shared loader instance; only its `preprocess` method is used here.
preprocessor = DatasetLoader()


def normalize_prompt(text: str) -> str:
    """Clean a raw prompt with the project's ``DatasetLoader.preprocess``.

    Routing every prompt through the same preprocessing keeps the text the
    classifier scores consistent with what the notebook prints.

    Args:
        text: Raw prompt text to clean.

    Returns:
        The prompt as returned by ``preprocessor.preprocess`` (described by the
        notebook as the same rules used during training).
    """
    cleaned = preprocessor.preprocess(text)
    return cleaned


def predict_nsfw(texts: Sequence[str]) -> np.ndarray:
    """Return the NSFW probability for each prompt.

    Every prompt is normalized with ``normalize_prompt``, tokenized as a single
    padded batch (truncated to 512 tokens) and scored by ``model``. The value
    reported per prompt is the softmax probability of class index 1, which the
    notebook's ``id2label`` mapping names ``"nsfw"``.

    Args:
        texts: Prompts to score for NSFW likelihood. Any sized sequence of
            strings is accepted (list, tuple, or a NumPy array of strings).

    Returns:
        1-D NumPy array of NSFW probabilities aligned with the input order.
        An empty input yields an empty ``float32`` array.
    """
    # Nothing to score: return early instead of handing the tokenizer and model
    # a zero-length batch. `len()` is used rather than truthiness because the
    # truth value of a multi-element NumPy array is ambiguous.
    if len(texts) == 0:
        # float32 presumably matches the model's default-precision outputs.
        return np.empty(0, dtype=np.float32)

    normalized = [normalize_prompt(t) for t in texts]
    encodings = tokenizer(
        normalized,
        padding=True,  # pad to the longest prompt so the batch is rectangular
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    # Inference only; skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encodings).logits

    # Column 1 of the per-row softmax is the "nsfw" class.
    probs = torch.softmax(logits, dim=1)[:, 1]
    return probs.detach().cpu().numpy()


In [None]:
# Swap in your own prompt here, written as comma-separated tokens.
raw_prompt = "masterpiece, best quality, nude, woman, beach, sunset, smile"
# The scorer and the SHAP explainer both take a batch, so wrap the prompt.
prompts = [raw_prompt]

# Show the text exactly as the classifier will see it.
normalized_prompt = normalize_prompt(raw_prompt)
print("Normalized prompt:", normalized_prompt)

# One probability comes back per prompt; take the only entry.
nsfw_prob = predict_nsfw(prompts)[0]
print(f"NSFW probability: {nsfw_prob:.3f}")


Normalized prompt: masterpiece, best quality, nude, woman, beach, sunset, smile
NSFW probability: 0.988


## SHAP explanation: each token is one comma-separated chunk

In [None]:
# Split prompts with a regex instead of the model tokenizer: the pattern matches
# a "," or "?" plus any trailing whitespace, so each comma-separated chunk is
# masked (and attributed) as a unit.
masker = shap.maskers.Text(tokenizer=r"[?,]\s*")

# Explain the NSFW probability returned by `predict_nsfw`.
explainer = shap.Explainer(predict_nsfw, masker=masker)
shap_values = explainer(prompts, fixed_context=1)

# NOTE(review): grouping_threshold=0.0 presumably keeps chunks from being merged
# in the rendered text plot; confirm against the SHAP docs.
shap.plots.text(shap_values, grouping_threshold=0.0)
