In [1]:
import sys
sys.path.append('..')

In [2]:
"""
Full example of a QA pipeline with:
- Open-source toxic/offensive-content checks (via a Hugging Face pipeline)
- Zero-shot domain-relevance classification (defined, but not currently used for routing)
- Naive prompt-injection detection (phrase matching)
- Guardrails (RAIL schema) for post-generation validation
- LlamaIndex RAG setup
- Customizable CONFIG
"""

from typing import Annotated, List, Dict, Any
from typing_extensions import TypedDict

# ---------------------------
#  CONFIG - CUSTOMIZE HERE
# ---------------------------
CONFIG: Dict[str, Any] = {
    # Offensive/Toxic Detection
    "OFFENSIVE_MODEL_NAME": "cardiffnlp/twitter-roberta-base-offensive",
    "OFFENSIVE_THRESHOLD": 0.7,  # Score above which text is treated as "offensive"

    # Zero-Shot Classification
    "ZSC_MODEL_NAME": "facebook/bart-large-mnli",
    "DOMAIN_LABELS": ["agentic AI", "autonomous systems", "decision-making", "LLM tools"],
    "DOMAIN_THRESHOLD": 0.8,  # Minimum top-label score for a query to count as in-domain

    # Prompt Injection: phrases whose (case-insensitive) presence flags a query
    "SUSPICIOUS_PHRASES": [
        "ignore previous instructions",
        "follow my instructions instead",
        "system role",
        "assistant role",
        "developer mode",
        "jailbreak",
        "override",
        "bypass"
    ],

    # LLM / Embedding settings
    "LLAMA_MODEL": "llama3.2",     # Example model if using Ollama
    "LLAMA_TEMPERATURE": 0.3,
    "LLAMA_MAX_TOKENS": 200,
    "LLAMA_TOP_P": 0.9,
    "EMBED_MODEL_NAME": "nomic-ai/modernbert-embed-base",  # Hugging Face embedding model
    "EMBED_CACHE_DIR": "./hf_cache",                        # Local cache for embedding weights

    # Other pipeline settings
    "DATA_PATH": "../data",        # Directory to load PDFs from
    "FILE_EXTENSIONS": [".pdf"],   # Required file extensions
    "RECURSIVE_LOAD": True
}

# ------------------------------------------------------
#  GUARDRAILS RAIL SCHEMA (instead of older YAML rules)
# ------------------------------------------------------
# Minimal output schema: a single string field "answer" that is meant to be capped
# at 600 characters and to reject a short list of unsafe terms; on failure the
# guard is configured with on-fail="reask".
# NOTE(review): `max_length="600"` and the `<disallowed-strings>` child do not look
# like standard guardrails RAIL validator syntax (validators are usually declared
# via a `format=`/`validators=` attribute) — confirm against the installed
# guardrails version; unrecognised markup may be silently ignored.
# NOTE(review): because <output> wraps a named field, the guard presumably expects
# JSON shaped like {"answer": "..."} rather than bare text — verify.
MY_RAIL_SCHEMA = """
<rail version="0.1">

<output>
    <string
        name="answer"
        description="A concise answer about agentic AI, free from unsafe or out-of-domain content."
        on-fail="reask"
        max_length="600"
    >
        <!-- Disallow certain unsafe terms -->
        <disallowed-strings strings="bomb,attack,weapon,explosive,harm" />
    </string>
</output>

</rail>
"""

In [3]:
# ------------------------------------------------------
#  IMPORTS
# ------------------------------------------------------
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline
)

import guardrails as gd
from guardrails import Guard

# llama_index & supporting libs
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, PromptTemplate
from llama_index.core.llms import ChatMessage
from langgraph.graph import StateGraph

# If you have a local "utils.py" with display_graph_image, try importing:
# Optional helper for rendering the compiled graph; HAS_UTILS records whether it is available.
try:
    from utils import display_graph_image
except ImportError:
    HAS_UTILS = False
else:
    HAS_UTILS = True


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'guardrails_api_client.api.service_health_api'

In [None]:
# ---------------------------
#   GUARDRAILS SETUP (RAIL)
# ---------------------------
# Build a Guard from the RAIL XML string in MY_RAIL_SCHEMA (output format + constraints).
# `review_response` uses this guard to validate the assistant's final answer.
# NOTE(review): newer guardrails releases may deprecate `from_rail_string` in favour of
# `for_rail_string` — confirm against the installed version.
guard = Guard.from_rail_string(MY_RAIL_SCHEMA)

In [None]:
# ---------------------------
#   OFFENSIVE / TOXIC CHECK
# ---------------------------
# Load the offensive-content classifier once and wrap it in a text-classification pipeline.
_offensive_source = CONFIG["OFFENSIVE_MODEL_NAME"]
offensive_tokenizer = AutoTokenizer.from_pretrained(_offensive_source)
offensive_model = AutoModelForSequenceClassification.from_pretrained(_offensive_source)
# return_all_scores=True asks for a score for every label, not only the top one.
offensive_pipeline = pipeline(
    task="text-classification",
    tokenizer=offensive_tokenizer,
    model=offensive_model,
    return_all_scores=True,
)

def is_offensive(text: str,
                 threshold: float = CONFIG["OFFENSIVE_THRESHOLD"],
                 offensive_labels=("offensive",)) -> bool:
    """Return True when the classifier scores ``text`` as offensive above ``threshold``.

    Labels are compared exactly (case-insensitively) against ``offensive_labels``.
    A substring test such as ``"offensive" in label`` would also match a
    complementary label like ``"non-offensive"``, so confidently benign text
    would be flagged.

    Args:
        text: Text to classify. ``None``/blank text is never offensive.
        threshold: Strict lower bound on the offensive label's score.
        offensive_labels: Label names that count as offensive. If the model only
            exposes generic names (e.g. ``LABEL_1``), pass those instead —
            presumably ``("label_1",)``; confirm against ``model.config.id2label``.
    """
    if not text or not text.strip():
        return False

    wanted = {label.lower() for label in offensive_labels}
    # truncation=True keeps inputs longer than the model's maximum length from erroring.
    results = offensive_pipeline(text, truncation=True)

    # return_all_scores=True nests per-label scores as [[{...}, ...]] for one string;
    # top_k=None returns them flat as [{...}, ...]. Accept either shape.
    if results and isinstance(results[0], list):
        label_scores = results[0]
    else:
        label_scores = results

    return any(
        item["label"].lower() in wanted and item["score"] > threshold
        for item in label_scores
    )


In [None]:
# ---------------------------
#   ZERO-SHOT CLASSIFICATION
# ---------------------------
# Load the NLI model used for zero-shot domain classification.
_zsc_source = CONFIG["ZSC_MODEL_NAME"]
zsc_tokenizer = AutoTokenizer.from_pretrained(_zsc_source)
zsc_model = AutoModelForSequenceClassification.from_pretrained(_zsc_source)
zsc_pipeline = pipeline(
    task="zero-shot-classification",
    tokenizer=zsc_tokenizer,
    model=zsc_model,
)

def is_domain_relevant(user_query: str,
                       candidate_labels=None,
                       threshold: float = CONFIG["DOMAIN_THRESHOLD"],
                       multi_label: bool = True) -> bool:
    """Return True if ``user_query`` scores at least ``threshold`` for any domain label.

    With ``multi_label=False`` the pipeline normalises scores so they sum to 1
    across the candidate labels. The top score then reflects which label wins,
    not whether the query is on-topic at all: an unrelated query still pushes one
    of four labels high. ``multi_label=True`` (the default) scores each label
    independently (entailment vs. contradiction), which is what a relevance
    threshold needs.

    Args:
        user_query: Query text. ``None``/blank queries are not relevant.
        candidate_labels: Domain labels; defaults to ``CONFIG["DOMAIN_LABELS"]``.
        threshold: Minimum top score (inclusive) to count as relevant.
        multi_label: Passed through to the zero-shot pipeline.
    """
    if candidate_labels is None:
        candidate_labels = CONFIG["DOMAIN_LABELS"]
    if not user_query or not user_query.strip() or not candidate_labels:
        return False

    result = zsc_pipeline(user_query, candidate_labels, multi_label=multi_label)
    return max(result["scores"]) >= threshold

In [17]:
# ---------------------------
#   PROMPT INJECTION CHECK
# ---------------------------
def detect_prompt_injection(user_query: str, phrases=None) -> bool:
    """Return True if ``user_query`` contains any known prompt-injection phrase.

    Matching is a case-insensitive substring search after collapsing runs of
    whitespace (spaces, tabs, line breaks) to single spaces, so extra spacing
    cannot slip a phrase past the check. This is a naive heuristic: it is easy
    to evade with paraphrases, and broad phrases such as "override" can cause
    false positives.

    Args:
        user_query: Raw user text. ``None`` or empty input is never flagged.
        phrases: Phrases to look for; defaults to ``CONFIG["SUSPICIOUS_PHRASES"]``.
            Blank entries are skipped, because an empty phrase would match everything.
    """
    if not user_query:
        return False
    if phrases is None:
        phrases = CONFIG["SUSPICIOUS_PHRASES"]

    normalized_query = " ".join(user_query.casefold().split())
    for phrase in phrases:
        normalized_phrase = " ".join(phrase.casefold().split())
        if normalized_phrase and normalized_phrase in normalized_query:
            return True
    return False

# ---------------------------
#   SAFETY CHECK WRAPPER
# ---------------------------
def is_safe_input(user_query: str) -> bool:
    """Return True when the query passes both the injection and offensive-content checks.

    The cheap phrase scan runs first; the classifier only runs if the scan passes.
    """
    flagged = detect_prompt_injection(user_query) or is_offensive(user_query)
    return not flagged

In [None]:
# ---------------------------
#   LLM & EMBEDDINGS
# ---------------------------
# Embedding model used to build the vector index; settings come from CONFIG when present.
embed_model = HuggingFaceEmbedding(
    model_name=CONFIG.get("EMBED_MODEL_NAME", "nomic-ai/modernbert-embed-base"),
    # NOTE: trust_remote_code=True executes Python shipped with the model repo —
    # only use it with model repositories you trust.
    trust_remote_code=True,
    cache_folder=CONFIG.get("EMBED_CACHE_DIR", "./hf_cache"),
)

# The llama_index Ollama wrapper has no `max_tokens`/`top_p` fields; extra
# generation options are forwarded to the Ollama server via `additional_kwargs`.
# Ollama calls the generation-length cap `num_predict`.
llm = Ollama(
    model=CONFIG["LLAMA_MODEL"],
    temperature=CONFIG["LLAMA_TEMPERATURE"],
    additional_kwargs={
        "num_predict": CONFIG["LLAMA_MAX_TOKENS"],
        "top_p": CONFIG["LLAMA_TOP_P"],
    },
)

# Register both as global LlamaIndex defaults (used by the index and query engine).
Settings.llm = llm
Settings.embed_model = embed_model

# ---------------------------
#   INDEX & QUERY ENGINE
# ---------------------------
# Load every file under DATA_PATH with a required extension (recursing if configured),
# embed the documents into a vector index, and expose a streaming query engine over it.
loader = SimpleDirectoryReader(
    recursive=CONFIG["RECURSIVE_LOAD"],
    required_exts=CONFIG["FILE_EXTENSIONS"],
    input_dir=CONFIG["DATA_PATH"],
)
docs = loader.load_data()

index = VectorStoreIndex.from_documents(documents=docs, show_progress=True)
query_engine = index.as_query_engine(streaming=True)

In [19]:
# ---------------------------
#   PROMPT TEMPLATES
# ---------------------------
# QA prompt for the response synthesizer; {context_str} and {query_str} are filled in by LlamaIndex.
_qa_prompt_parts = [
    "You are an expert assistant providing concise, accurate, and polite answers related to agentic AI only. ",
    "If the user's query is unrelated or unsafe, politely decline to respond and provide no further information. ",
    "Focus solely on agentic AI and its concepts.\n\n",
    "Context from documents:\n{context_str}\n",
    "User query:\n{query_str}\n\n",
    "Answer concisely and strictly within the context of agentic AI:",
]
qa_prompt_tmpl_str = "".join(_qa_prompt_parts)
qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)

# Replace the query engine's default text-QA template with the one above.
query_engine.update_prompts({"response_synthesizer:text_qa_template": qa_prompt_tmpl})


In [20]:
# ---------------------------
#   STATE & GRAPH
# ---------------------------
class State(TypedDict):
    """Graph state shared by all nodes.

    Every key a node returns must be declared here. LangGraph only tracks
    declared keys, so an undeclared update (such as the routing decision or the
    rewritten query) never reaches later nodes or edge functions.
    """
    messages: Annotated[List[tuple], "Chat history: (role, text)"]
    context: Annotated[str, "Context retrieved from documents"]
    next_node: Annotated[str, "Routing decision written by route_query"]
    rewritten_query: Annotated[str, "Retrieval-optimised query written by rewrite_query"]

# Builder for the LangGraph state machine over `State`; nodes and edges are registered in a later cell.
graph_builder = StateGraph(State)

# ---------------------------
#   GRAPH NODES
# ---------------------------
def route_query(state: State):
    """
    Route the query based on safety ONLY; domain checks are removed.

    Messages may arrive as ``(role, text)`` tuples or as message objects exposing
    ``type``/``content`` (e.g. LangChain's ``HumanMessage``). Objects are converted
    to tuples for two reasons: the safety check then sees the real text (previously
    they produced an empty query and skipped the check), and downstream nodes,
    which read ``message[1]``, can handle them.

    Returns:
        ``{"messages": <normalized history>, "next_node": "unsafe_input" | "rewrite_query"}``
    """
    normalized = []
    for msg in state.get("messages") or []:
        if isinstance(msg, tuple):
            normalized.append(msg)
        else:
            normalized.append(
                (getattr(msg, "type", "user"), str(getattr(msg, "content", "")))
            )

    last_message = normalized[-1] if normalized else ()
    user_query = str(last_message[1]) if len(last_message) >= 2 else ""

    # Check for unsafe input (prompt injection, offensive, etc.)
    if not is_safe_input(user_query):
        print("Step: route_query - Unsafe input detected.")
        return {"messages": normalized, "next_node": "unsafe_input"}

    # Safe input continues to the normal flow.
    return {"messages": normalized, "next_node": "rewrite_query"}

def rewrite_query(state: State):
    """Ask the LLM for a retrieval-friendly version of the latest message.

    If the LLM returns nothing usable, the original query is kept, so retrieval
    still targets the user's question. The original code fell back to the literal
    "Failed to rewrite query.", which was then used as the retrieval query.

    Returns:
        ``{"messages": ..., "rewritten_query": ...}``: a new history list with a
        ``("system", "Rewritten query: ...")`` entry appended, plus the rewritten text.
    """
    messages = state["messages"]
    last_message = messages[-1]
    # Accept both (role, text) tuples and message objects exposing `.content`.
    if isinstance(last_message, tuple):
        user_query = last_message[1]
    else:
        user_query = getattr(last_message, "content", "")

    rewrite_prompt = (
        "Rewrite the query below to ensure it is concise, precise, and optimized "
        "for retrieval within the domain of agentic AI:\n\n"
        f"Original query: {user_query}\n\nRewritten query:"
    )
    rewritten_response = llm.chat([ChatMessage(role="user", content=rewrite_prompt)])
    content = rewritten_response.message.content if rewritten_response.message else None
    # `content` can be None or blank; fall back to the original query in that case.
    rewritten_query = (content or "").strip() or user_query

    print("Step: rewrite_query")
    return {
        "messages": [*messages, ("system", f"Rewritten query: {rewritten_query}")],
        "rewritten_query": rewritten_query,
    }

def rag_node(state: State):
    """Query the document index and store the streamed result as ``context``.

    The retrieval query is ``state["rewritten_query"]`` when present, otherwise
    the latest user message. The original fallback, ``messages[-1]``, is the
    ``"Rewritten query: ..."`` system note, prefix included.

    NOTE(review): `query_engine` runs a response synthesizer (see its text-QA
    template), so ``context`` holds LLM-generated text rather than the raw
    retrieved chunks. Confirm this is intended.

    Returns:
        ``{"messages": ..., "context": ...}``: a new history list with a
        ``("system", "Context retrieved: ...")`` entry appended, plus the context text.
    """
    messages = state["messages"]

    user_query = state.get("rewritten_query") or ""
    if not user_query:
        # Most recent user turn, as a (role, text) tuple or a message object.
        for msg in reversed(messages):
            if isinstance(msg, tuple):
                if len(msg) >= 2 and msg[0] in ("user", "human"):
                    user_query = str(msg[1])
                    break
            elif getattr(msg, "type", None) == "human":
                user_query = str(getattr(msg, "content", ""))
                break

    docs_result = query_engine.query(user_query)
    # A streaming engine returns a response exposing `response_gen`; otherwise use its text.
    response_gen = getattr(docs_result, "response_gen", None)
    full_context = "".join(response_gen) if response_gen is not None else str(docs_result)

    print("Step: rag_node - Context retrieved.")
    return {
        "messages": [*messages, ("system", f"Context retrieved: {full_context}")],
        "context": full_context,
    }

def chatbot(state: State):
    """Answer the user's question with the LLM, grounded on ``state["context"]``.

    The question is the most recent user message: role ``"user"`` or ``"human"``,
    either as a ``(role, text)`` tuple or as a message object with
    ``type``/``content``. ``state["messages"][-1]`` cannot be used, because
    ``rewrite_query`` and ``rag_node`` append ``"system"`` entries first. The last
    message is therefore the retrieved-context dump, not the question.

    Returns:
        ``{"messages": ...}``: a new history list with the assistant reply appended.
    """
    context = state.get("context", "")
    messages = state["messages"]

    user_query = ""
    for msg in reversed(messages):
        if isinstance(msg, tuple):
            if len(msg) >= 2 and msg[0] in ("user", "human"):
                user_query = str(msg[1])
                break
        elif getattr(msg, "type", None) == "human":
            user_query = str(getattr(msg, "content", ""))
            break

    final_prompt = (
        f"Using the following context:\n{context}\n\n"
        f"Answer the following query concisely and politely, strictly within the domain of agentic AI:\n{user_query}"
    )

    response = llm.chat([ChatMessage(role="user", content=final_prompt)])
    response_content = (response.message.content or "") if response.message else ""
    # Never hand an empty/None answer to review_response.
    if not response_content.strip():
        response_content = "Failed to generate a meaningful response."

    print("Step: chatbot")
    return {"messages": [*messages, ("assistant", response_content)]}

def review_response(state: State):
    """
    Post-generation checks on the final assistant message:
      1) Offensive detection.
      2) Guardrails validation against the RAIL schema loaded into ``guard``.

    Any failure replaces the final message with a refusal. Validation fails
    closed: if guardrails raises, the unvalidated answer is not shown.

    Returns:
        ``{"messages": ...}``: a new history list whose last entry is the approved
        (or replacement) assistant message. Earlier entries are unchanged.
    """
    import json  # only needed to wrap the answer in the schema's JSON shape

    messages = list(state["messages"])
    final_msg_role, final_msg_text = messages[-1]

    if final_msg_role != "assistant":
        return {"messages": messages}

    final_msg_text = final_msg_text or ""

    # (1) Offensive check
    if is_offensive(final_msg_text, threshold=CONFIG["OFFENSIVE_THRESHOLD"]):
        print("Step: review_response - Final response flagged as offensive.")
        messages[-1] = ("assistant", "I’m sorry, I cannot provide that answer.")
        return {"messages": messages}

    guard_refusal = (
        "assistant",
        "I'm sorry, I cannot provide that answer under these guidelines."
    )

    # (2) Guardrails RAIL validation.
    # `guard(...)` expects an LLM callable as its first argument, so already-
    # generated text is validated with `guard.parse`. The schema declares a named
    # "answer" field, so the text is wrapped as {"answer": ...} JSON.
    # num_reasks=0 because there is no LLM to re-ask.
    try:
        outcome = guard.parse(json.dumps({"answer": final_msg_text}), num_reasks=0)
    except Exception as exc:  # fail closed: do not show an answer that could not be validated
        print(f"Step: review_response - Guardrails validation failed: {exc}")
        messages[-1] = guard_refusal
        return {"messages": messages}

    # Newer guardrails return a ValidationOutcome (validated_output / validation_passed);
    # older releases return the validated value itself, or None on failure.
    validated_output = getattr(outcome, "validated_output", outcome)
    passed = getattr(outcome, "validation_passed", validated_output is not None)
    if isinstance(validated_output, dict):
        final_answer = validated_output.get("answer", "")
    elif validated_output is None:
        final_answer = ""
    else:
        final_answer = str(validated_output)

    if not passed or not final_answer:
        print("Step: review_response - Guardrails validation failed.")
        messages[-1] = guard_refusal
        return {"messages": messages}

    print("Step: review_response - Final response approved.")
    messages[-1] = ("assistant", final_answer)
    return {"messages": messages}

def unsafe_input(state: State):
    """Append a fixed refusal to the history for input rejected by the safety check."""
    print("Step: unsafe_input")
    history = state["messages"]
    refusal = "Your query or content appears unsafe or disallowed. I’m unable to process it."
    history.append(("assistant", refusal))
    return {"messages": history}

def out_of_domain(state: State):
    """Append a fixed "agentic AI only" refusal to the history (node currently not registered)."""
    print("Step: out_of_domain")
    history = state["messages"]
    refusal = "I'm sorry, I can only answer questions strictly related to agentic AI."
    history.append(("assistant", refusal))
    return {"messages": history}

In [None]:
# ---------------------------
#   BUILD THE STATE GRAPH
# ---------------------------
# Start from a fresh builder so re-running this cell does not re-register nodes.
graph_builder = StateGraph(State)

graph_builder.add_node("route_query", route_query)
graph_builder.add_node("rewrite_query", rewrite_query)
graph_builder.add_node("rag", rag_node)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("review_response", review_response)
graph_builder.add_node("unsafe_input", unsafe_input)
# Domain routing is currently disabled. To re-enable it, register
# graph_builder.add_node("out_of_domain", out_of_domain), have route_query emit
# "out_of_domain", and add that key to the path map below.

# route_query writes its decision to state["next_node"].
graph_builder.add_conditional_edges(
    "route_query",
    lambda state: state["next_node"],
    {
        "rewrite_query": "rewrite_query",
        "unsafe_input": "unsafe_input",
    }
)
graph_builder.add_edge("rewrite_query", "rag")
graph_builder.add_edge("rag", "chatbot")
graph_builder.add_edge("chatbot", "review_response")

graph_builder.set_entry_point("route_query")
# Make both terminal nodes end the run explicitly.
graph_builder.set_finish_point("review_response")
graph_builder.set_finish_point("unsafe_input")
graph = graph_builder.compile()

if HAS_UTILS:
    display_graph_image(graph)

In [None]:
from langchain_core.messages import HumanMessage
# Inputs are (role, text) tuples, matching State["messages"]; the nodes read message[1],
# which a LangChain message object does not support.
# The graph is compiled without a checkpointer, so a thread_id would have no effect —
# each invoke starts from this input alone.
final_state = graph.invoke(
    {"messages": [("user", "what is the weather in sf")]}
)

In [None]:
# Inspect the state returned by the preceding graph.invoke call (off-topic weather query).
final_state

In [None]:
# Inputs are (role, text) tuples, matching State["messages"]; the nodes read message[1],
# which a LangChain message object does not support.
# No checkpointer is compiled in, so this run does not see the previous conversation.
final_state = graph.invoke(
    {"messages": [("user", "what is bomb")]}
)

In [None]:
# Inspect the state returned by the preceding graph.invoke call (potentially unsafe query).
final_state

In [None]:
# Inputs are (role, text) tuples, matching State["messages"]; the nodes read message[1],
# which a LangChain message object does not support.
# No checkpointer is compiled in, so this run does not see the previous conversation.
final_state = graph.invoke(
    {"messages": [("user", "what is difference between agents AI and workflows")]}
)

In [None]:
# Inspect the state returned by the preceding graph.invoke call (in-domain agentic-AI query).
final_state