From 84d6361d9face0209e6d8d446b16c81b22f26dfe Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 24 Sep 2025 21:27:28 -0400 Subject: [PATCH 01/19] Add Granite Guardian 3.3 8B with dual backends and function call validation - Enhanced GuardianCheck with HuggingFace and Ollama backends - Added thinking mode support for detailed reasoning traces - Implemented actual function calling validation with RepairTemplateStrategy that consumes reasoning in repair process. - Added groundedness and function call hallucination detection examples --- docs/examples/safety.py/guardian.py | 157 +++++- .../safety.py/guardian_huggingface.py | 46 ++ .../safety.py/repair_with_guardian.py | 154 ++++++ mellea/stdlib/safety/guardian.py | 447 ++++++++++++++---- mellea/stdlib/sampling.py | 305 ++---------- 5 files changed, 729 insertions(+), 380 deletions(-) create mode 100644 docs/examples/safety.py/guardian_huggingface.py create mode 100644 docs/examples/safety.py/repair_with_guardian.py diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 9293ce63..8d92f2a8 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -1,37 +1,154 @@ -"""Example of using the Guardian Requirement""" +"""Example of using the Enhanced Guardian Requirement with Granite Guardian 3.3 8B""" from mellea import MelleaSession from mellea.backends import model_ids -from mellea.backends.dummy import DummyBackend from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import Context, ContextTurn, ModelOutputThunk, SimpleContext +from mellea.stdlib.base import ContextTurn, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk -# Start with instantiating a GuardianCheck. -guardian = GuardianCheck(risk=GuardianRisk.HARM) +# Enhanced GuardianCheck with Granite Guardian 3.3 8B support +print("=== Enhanced GuardianCheck Examples ===") -print("\n Test 1\n") +# Basic harm checking with new features +guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama") +print(f"Guardian configured: {guardian._risk} with thinking mode: {guardian._thinking}") + +# Try different risk types with new 3.3 8B model +jailbreak_guardian = GuardianCheck(GuardianRisk.JAILBREAK, thinking=True) +print(f"Jailbreak guardian: {jailbreak_guardian._risk}") + +# Custom risk criteria example +custom_guardian = GuardianCheck(custom_criteria="Check for inappropriate content in educational context") +print(f"Custom guardian: {custom_guardian.get_effective_risk()}") + +print("\n=== Test 1: Normal Content ===") # Run a query against an Ollama model m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) -m.chat("Write an email to your colleague. Use less than 100 words.") -print(f"Full context: {m.ctx}") +m.chat("Write a professional email to your colleague. 
Use less than 100 words.") + # Validate on the last turn of the conversation valid = m.validate([guardian]) -# Should return [True] meaning that input and output are free of harm and it's safe to return result -print(f"Guardian result: {valid}") +print(f"Content is safe: {valid[0]._result}") +if valid[0]._reason: + print(f"Guardian feedback: {valid[0]._reason}...") -# Reproducing a harm scenario and run the validation again -print("\n Test 2\n") - -# create a mean conversation and add to context +print("\n=== Test 2: Harmful Content ===") +# Create harmful conversation and add to context m.ctx.insert_turn( - ContextTurn(Message("user", "Hello. "), ModelOutputThunk("You are very ugly.")) + ContextTurn(Message("user", "Hello"), ModelOutputThunk("You are very ugly and stupid.")) +) + +valid = m.validate([guardian]) +print(f"Content is safe: {valid[0]._result}") +if valid[0]._reason: + print(f"Guardian feedback: {valid[0]._reason}...") + +print("\n=== Test 3: Enhanced Features ===") + +# Create multiple guardians with different configurations +guardians = [ + GuardianCheck(GuardianRisk.HARM, thinking=True), + GuardianCheck(GuardianRisk.JAILBREAK, thinking=True), + GuardianCheck(GuardianRisk.SOCIAL_BIAS), + GuardianCheck(custom_criteria="Check for financial advice"), +] + +print(f"Available risk types ({len(GuardianCheck.get_available_risks())} total):") +for risk in GuardianCheck.get_available_risks(): # Show first 5 + print(f" - {risk}") +print(" ...") + +print(f"\nConfigured guardians: {len(guardians)} total") + +# Show thinking mode capability +thinking_guardian = GuardianCheck(GuardianRisk.HARM, thinking=True) +print(f" Thinking mode supported: {thinking_guardian.supports_thinking_mode()}") + +# Show Ollama backend configuration +ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama") +print(f" Ollama backend: {ollama_guardian._backend.model_version}") + +print("\n=== Test 4: Groundedness Detection ===") +# Test groundedness - detecting when responses lack factual grounding +context_text = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding." + +groundedness_guardian = GuardianCheck( + GuardianRisk.GROUNDEDNESS, + thinking=True, + backend_type="ollama", + context_text=context_text +) + +# Create a response that makes ungrounded claims relative to provided context +groundedness_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) +groundedness_session.ctx.insert_turn( + ContextTurn( + Message("user", "What is the history of treaty making?"), + ModelOutputThunk("Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. 
The first treaty was signed between Rome and the Moon people, establishing trade routes through space.") + ) +) + +print("Testing response with ungrounded claims...") +groundedness_valid = groundedness_session.validate([groundedness_guardian]) +print(f"Response is grounded: {groundedness_valid[0]._result}") +if groundedness_valid[0]._reason: + print(f"Groundedness feedback: {groundedness_valid[0]._reason}...") + +print("\n=== Test 5: Function Call Hallucination Detection ===") +# Test function calling hallucination using IBM video example +import json + +tools = [ + { + "name": "views_list", + "description": "Fetches total views for a specified IBM video using the given API.", + "parameters": { + "video_id": { + "description": "The ID of the IBM video.", + "type": "int", + "default": "7178094165614464282" + }, + "count": { + "description": "The number of comments to fetch. Maximum is 30. Defaults to 20.", + "type": "int, optional", + "default": "20" + } + } + } +] + +function_guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, + thinking=True, + backend_type="ollama", + tools=tools +) + +# User asks for views but assistant calls wrong function (comments_list instead of views_list) +response_data = [ + { + "name": "comments_list", + "arguments": { + "video_id": 456789123, + "count": 15 + } + } +] +hallucinated_response = str(response_data) + +function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) +function_session.ctx.insert_turn( + ContextTurn( + Message("user", "Fetch total views for the IBM video with ID 456789123."), + ModelOutputThunk(hallucinated_response) + ) ) -# show last turn in chat -print(f"Context: {m.ctx.last_turn()}") +print("Testing response with function call hallucination...") +function_valid = function_session.validate([function_guardian]) +print(f"Function calls are valid: {function_valid[0]._result}") +if function_valid[0]._reason: + print(f"Function call feedback: {function_valid[0]._reason}...") -check_results = m.validate([guardian]) -# Should return [False] meaning that input and output contain harm and it's NOT safe to return result -print(f"Guardian check results: {check_results}") +print("\n=== GuardianCheck Demo Complete ===") diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py new file mode 100644 index 00000000..4c196fa8 --- /dev/null +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -0,0 +1,46 @@ +"""Example of using GuardianCheck with HuggingFace backend for direct model inference""" + +from mellea.stdlib.base import ContextTurn, ModelOutputThunk +from mellea.stdlib.chat import Message +from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk + +print("=== GuardianCheck HuggingFace Backend Example ===") + +# Test harmful content detection using HuggingFace backend +print("Testing harm detection with HuggingFace backend...") +hf_guardian = GuardianCheck(GuardianRisk.HARM, thinking=False, backend_type="huggingface") + +print(f"Backend: {hf_guardian._backend.model_version}") +print(f"Thinking mode: {hf_guardian._thinking}") + +# Test the "ugly and stupid" example directly with Guardian +print("\n=== Direct Guardian Validation Test ===") +print("Testing content: 'You are very ugly and stupid.'") + +try: + # Create a minimal context for validation + from mellea.stdlib.base import LinearContext + + # Create context with the harmful content + ctx = LinearContext() + ctx.insert_turn( + ContextTurn( + Message("user", "Hello"), + ModelOutputThunk("You are very 
ugly and stupid.") + ) + ) + + # Run Guardian validation directly + validation_result = hf_guardian._guardian_validate(ctx) + + print(f"Content is safe: {validation_result.as_bool()}") + print(f"Guardian detected harm: {not validation_result.as_bool()}") + + if validation_result.reason: + print(f"\nGuardian reasoning:") + print(validation_result.reason) + +except Exception as e: + print(f"HuggingFace backend test failed: {e}") + +print("\n=== HuggingFace Guardian Demo Complete ===") \ No newline at end of file diff --git a/docs/examples/safety.py/repair_with_guardian.py b/docs/examples/safety.py/repair_with_guardian.py new file mode 100644 index 00000000..bb927a21 --- /dev/null +++ b/docs/examples/safety.py/repair_with_guardian.py @@ -0,0 +1,154 @@ +""" +RepairTemplateStrategy Example with Actual Function Call Validation +Demonstrates how RepairTemplateStrategy repairs responses using actual function calls. +""" + +from mellea import MelleaSession +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.stdlib.sampling import RepairTemplateStrategy + + +def demo_repair_with_actual_function_calling(): + """Demonstrate RepairTemplateStrategy with actual function calling and Guardian validation.""" + print("RepairTemplateStrategy with Actual Function Call Demo") + print("-" * 52) + + # Use Llama3.2 which supports function calling + m = MelleaSession(OllamaModelBackend("llama3.2")) + + # Define actual callable functions + def get_weather(location: str) -> str: + """Gets current weather information for a location""" + return f"The current weather in {location} is sunny, 22°C with light winds." + + def get_recipe(dish_name: str) -> str: + """Gets a cooking recipe for the specified dish""" + return f"Recipe for {dish_name}: Cook ingredients together until done." + + def get_stock_price(symbol: str) -> str: + """Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).""" + return f"Current stock price for {symbol} is $150.25" + + # All available tools - both model and Guardian use the same set + all_tools = [ + { + "name": "get_weather", + "description": "Gets current weather information for a location", + "parameters": { + "location": { + "description": "The location to get weather for", + "type": "string" + } + } + }, + { + "name": "get_recipe", + "description": "Gets a cooking recipe for the specified dish", + "parameters": { + "dish_name": { + "description": "The name of the dish to get a recipe for", + "type": "string" + } + } + }, + { + "name": "get_stock_price", + "description": "Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).", + "parameters": { + "symbol": { + "description": "The stock symbol to get price for (must be 3-5 uppercase letters)", + "type": "string" + } + } + } + ] + + # Function call validation using GuardianRisk.FUNCTION_CALL + safety_requirements = [ + GuardianCheck( + GuardianRisk.FUNCTION_CALL, + thinking=True, + tools=all_tools # Guardian and model use same tools + ) + ] + + print(f"Risk Type: {safety_requirements[0].get_effective_risk()}") + print(f"Available Tools: {[tool['name'] for tool in all_tools]}") + + # Query that should trigger invalid stock symbol usage + test_prompt = "What's the price of Tesla stock?" 
+ print(f"Main Model Prompt: {test_prompt}") + + # Model functions + all_functions = [get_weather, get_recipe, get_stock_price] + print(f"Model Available Functions: {[f.__name__ for f in all_functions]}") + + try: + result = m.instruct( + test_prompt, + requirements=safety_requirements, + strategy=RepairTemplateStrategy(loop_budget=3), + return_sampling_results=True, + model_options={ + "temperature": 0.7, # Some randomness + "seed": 789, + "tools": all_functions, + "system": "When users ask about stock prices, always use the full company name as the symbol parameter instead of the ticker symbol. For example, use 'Tesla Motors' instead of 'TSLA', 'Apple Inc' instead of 'AAPL', etc." + }, + tool_calls=True + ) + + # Show repair process + if hasattr(result, 'sample_validations') and result.sample_validations: + for attempt_num, (generation, validations) in enumerate(zip(result.sample_generations, result.sample_validations), 1): + print(f"\nAttempt {attempt_num}:") + + # Show model response (may be empty for function calls only) + response = str(generation.value) if generation.value else "[Function calls only]" + print(f"Model Response: {response}") + + # Show function calls made + if hasattr(generation, 'tool_calls') and generation.tool_calls: + print("Function Calls Made:") + for name, tool_call in generation.tool_calls.items(): + print(f" - {name}({tool_call.args})") + + # Show validation results + for req_item, validation in validations: + status = "PASSED" if validation.as_bool() else "FAILED" + print(f"Status: {status}") + if validation.reason: + print(f"Guardian Reason: {validation.reason}") + + print(f"\nFinal Result: {'SUCCESS' if result.success else 'FAILED'}") + print(f"Attempts used: {len(result.sample_generations) if hasattr(result, 'sample_generations') else 1}") + + return result + + except Exception as e: + print(f"Function calling failed: {e}") + print("This may be because the model doesn't support function calling or Ollama is not running.") + return None + + +def main(): + """Run RepairTemplateStrategy demo with actual function call validation.""" + try: + print("=== Actual Function Calling with Guardian Validation Demo ===") + result = demo_repair_with_actual_function_calling() + + if result is None: + print("\nDemo failed. Please ensure:") + print("1. Ollama is running") + print("2. llama3.2 model is available") + print("3. Model supports function calling") + + print("\nDemo completed.") + except Exception as e: + print(f"Error: {e}") + print("Ensure Ollama is running with a model that supports function calling.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index ec18fcdf..bea17489 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -1,48 +1,106 @@ """Risk checking with Guardian models.""" +import json import torch +from enum import Enum +from typing import Dict, List, Any, Optional, Tuple, Union, Literal from transformers import AutoModelForCausalLM, AutoTokenizer +try: + import requests +except ImportError: + requests = None + from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import CBlock, Context from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import Requirement +from mellea.stdlib.requirement import Requirement, ValidationResult + +class GuardianRisk(Enum): + """Risk definitions for Granite Guardian models. 
-class GuardianRisk: - """Risk definitions as defined in https://github.com/ibm-granite/granite-guardian/blob/main/cookbooks/granite-guardian-3.2/quick_start_vllm.ipynb .""" + Based on https://github.com/ibm-granite/granite-guardian but updated for 3.3 8B support. + """ HARM = "harm" GROUNDEDNESS = "groundedness" PROFANITY = "profanity" ANSWER_RELEVANCE = "answer_relevance" + JAILBREAK = "jailbreak" + FUNCTION_CALL = "function_call" + SOCIAL_BIAS = "social_bias" + VIOLENCE = "violence" + SEXUAL_CONTENT = "sexual_content" + UNETHICAL_BEHAVIOR = "unethical_behavior" + @classmethod + def get_available_risks(cls) -> List[str]: + """Get list of all available risk types.""" + return [risk.value for risk in cls] -class GuardianCheck(Requirement): - """A Check for Risks based on local huggingface backend.""" - def __init__( - self, - risk: str = GuardianRisk.HARM, - *, - model_version: str = "ibm-granite/granite-guardian-3.2-3b-a800m", - device: str | None = None, - ): - """Initializes a GuardianModel instance with the specified risk definition, model version, and device. +BackendType = Literal["huggingface", "ollama"] - Args: - risk: The risk definition to check for, defaults to GuardianRisk.HARM. - model_version: The version of the model, defaults to "ibm-granite/granite-guardian-3.2-3b-a800m". - device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. - """ - super().__init__( - check_only=True, validation_fn=lambda c: self._guardian_validate(c) - ) - self._risk = risk - self._model_version = model_version - # auto-device if not more specific - self._device = device +def _parse_safety_result(result: str, logger) -> str: + """Parse the safety result from model output.""" + result_lower = result.lower().strip() + + # Handle score tags format: yes/no + if "" in result_lower and "" in result_lower: + score_content = result_lower.split("")[1].split("")[0].strip() + if "yes" in score_content: + return "Yes" + elif "no" in score_content: + return "No" + + # Check for explicit yes/no at start + if result_lower.startswith("yes"): + return "Yes" + elif result_lower.startswith("no"): + return "No" + + # Check for yes/no anywhere in response + if " yes " in result_lower or result_lower.endswith(" yes"): + return "Yes" + elif " no " in result_lower or result_lower.endswith(" no"): + return "No" + + # Check for safety indicators + if any(word in result_lower for word in ["safe", "no risk", "acceptable"]): + return "No" + elif any(word in result_lower for word in ["unsafe", "risky", "harmful", "violation"]): + return "Yes" + + # Default to failed parsing + logger.warning(f"Could not parse safety result: {result}") + return "Failed" + + +class GuardianBackend: + """Abstract base for different Guardian backends.""" + + def __init__(self, model_version: str, device: Optional[str] = None): + self.model_version = model_version + self.device = device + self._logger = FancyLogger.get_logger() + + def validate_messages(self, messages: List[Dict[str, str]], risk: str, + thinking: bool = False, tools: Optional[List[Dict]] = None) -> Tuple[str, Optional[str], Dict]: + """Validate messages for specified risk. 
Returns (result, trace, raw_data).""" + raise NotImplementedError + + +class HuggingFaceGuardianBackend(GuardianBackend): + """HuggingFace-based Guardian backend for local model inference.""" + + def __init__(self, model_version: str = "ibm-granite/granite-guardian-3.0-8b", device: Optional[str] = None): + super().__init__(model_version, device) + self._model = None + self._tokenizer = None + + # Auto-device selection if not specified if device is None: device_name: str = ( "cuda" @@ -51,66 +109,247 @@ def __init__( if torch.backends.mps.is_available() else "cpu" ) - assert device_name is not None - self._device = torch.device(device_name) # type: ignore + self.device = torch.device(device_name) + + def _load_model(self): + """Lazy load model and tokenizer.""" + if self._model is None: + self._logger.info(f"Loading Granite Guardian model: {self.model_version}") + self._model = AutoModelForCausalLM.from_pretrained( + self.model_version, + device_map="auto", + torch_dtype=torch.bfloat16 + ) + self._model.to(self.device) + self._model.eval() - @staticmethod - def _parse_output(output, input_len, tokenizer): - """Parse the output of a guardian model and determine whether if the risk is present or not. + self._tokenizer = AutoTokenizer.from_pretrained(self.model_version) - Args: - output: The model's output containing sequences from which predictions are made. - input_len: The length of the original input sequence used for alignment with the model's output. - tokenizer: The tokenizer associated with the model, used to decode the tokens back into text. - """ - safe_token = "No" - unsafe_token = "Yes" + def validate_messages(self, messages: List[Dict[str, str]], risk: str, + thinking: bool = False, tools: Optional[List[Dict]] = None, + context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: + """Validate messages using HuggingFace backend.""" + self._load_model() - label = None + guardian_config = {"risk_name": risk} + + # Apply chat template with thinking mode support + input_ids = self._tokenizer.apply_chat_template( + messages, + guardian_config=guardian_config, + add_generation_prompt=True, + return_tensors="pt", + think=thinking # Enable thinking mode if supported + ).to(self._model.device) - full_res = tokenizer.decode( - output.sequences[:, input_len + 1 :][0], skip_special_tokens=True + input_len = input_ids.shape[1] + + # Generate with appropriate tokens for thinking mode + max_tokens = 2000 if thinking else 20 + + with torch.no_grad(): + output = self._model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_tokens, + return_dict_in_generate=True, + output_scores=True, + ) + + # Parse output + full_response = self._tokenizer.decode( + output.sequences[:, input_len:][0], + skip_special_tokens=True ).strip() - FancyLogger.get_logger().debug(f"Full: {full_res}") - confidence_level = ( - full_res.removeprefix("").removesuffix("").strip() + + # Extract thinking trace if present + trace = None + result = full_response + + if thinking and "" in full_response: + parts = full_response.split("") + if len(parts) > 1: + trace = parts[0].replace("", "").strip() + result = parts[1].strip() + + # Determine safety result + label = _parse_safety_result(result, self._logger) + + return label, trace, {"full_response": full_response, "model": self.model_version} + + +class OllamaGuardianBackend(GuardianBackend): + """Ollama-based Guardian backend for local model inference.""" + + def __init__(self, model_version: str = "ibm/granite3.3-guardian:8b", + ollama_url: str = 
"http://localhost:11434"): + super().__init__(model_version) + self.ollama_url = ollama_url + + if requests is None: + raise ImportError("requests library is required for Ollama backend. Install with: pip install requests") + + def validate_messages(self, messages: List[Dict[str, str]], risk: str, + thinking: bool = False, tools: Optional[List[Dict]] = None, + context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: + """Validate messages using Ollama backend.""" + + # Prepare messages for Guardian checking + guardian_messages = [{"role": "system", "content": risk}] + + # For groundedness/context relevance, add document context + if risk in ["groundedness", "context_relevance"] and context_text: + guardian_messages.append({"role": "document", "content": context_text}) + + guardian_messages.extend(messages) + + payload = { + "model": self.model_version, + "messages": guardian_messages, + "stream": False, + "think": thinking + } + + # For function call validation, add tools to the payload + if risk == "function_call" and tools: + payload["tools"] = tools + + try: + response = requests.post( + f"{self.ollama_url}/api/chat", + json=payload, + timeout=120 + ) + response.raise_for_status() + data = response.json() + + # Extract content and trace + content = data.get("message", {}).get("content", "") + trace = ( + data.get("message", {}).get("reasoning") or + data.get("message", {}).get("thinking") + ) + + # Parse safety result + label = _parse_safety_result(content, self._logger) + + return label, trace, data + + except Exception as e: + self._logger.error(f"Ollama Guardian request failed: {e}") + return "Failed", None, {"error": str(e)} + + +class GuardianCheck(Requirement): + """Enhanced risk checking using Granite Guardian 3.3 8B with multiple backend support.""" + + def __init__( + self, + risk: Union[str, GuardianRisk, None] = None, + *, + backend_type: BackendType = "ollama", + model_version: Optional[str] = None, + device: Optional[str] = None, + ollama_url: str = "http://localhost:11434", + thinking: bool = False, + custom_criteria: Optional[str] = None, + context_text: Optional[str] = None, + tools: Optional[List[Dict]] = None, + ): + """Initialize GuardianCheck with enhanced Granite Guardian 3.3 8B support. + + Args: + risk: The risk type to check for. Can be GuardianRisk enum or custom string. + backend_type: Backend to use - "huggingface" for local inference or "ollama" for Ollama. + model_version: Model version to use. Defaults based on backend: + - HuggingFace: "ibm-granite/granite-guardian-3.0-8b" + - Ollama: "ibm/granite3.3-guardian:8b" + device: Computational device ("cuda"/"mps"/"cpu"). Auto-selected if None. + ollama_url: Ollama server URL for ollama backend. + thinking: Enable thinking mode for detailed reasoning traces. + custom_criteria: Custom risk criteria string (overrides standard risk types). + context_text: Reference text for groundedness/context relevance checking. + tools: Available tools for function call validation. 
+ """ + super().__init__( + check_only=True, validation_fn=lambda c: self._guardian_validate(c) ) - res = tokenizer.decode( - output.sequences[:, input_len : input_len + 1][0], skip_special_tokens=True - ).strip() - FancyLogger.get_logger().debug(f"Res: {res}") - if unsafe_token.lower() == res.lower(): - label = unsafe_token - elif safe_token.lower() == res.lower(): - label = safe_token + + # Handle risk specification with custom criteria priority + if custom_criteria: + # When custom_criteria is provided, risk becomes optional + if risk is None: + self._risk = "custom" # Default fallback risk identifier + elif isinstance(risk, GuardianRisk): + self._risk = risk.value + else: + self._risk = risk + else: + # When no custom_criteria, risk is required + if risk is None: + raise ValueError("Either 'risk' or 'custom_criteria' must be provided") + if isinstance(risk, GuardianRisk): + self._risk = risk.value + else: + self._risk = risk + + self._custom_criteria = custom_criteria + self._thinking = thinking + self._backend_type = backend_type + self._context_text = context_text + self._tools = tools + + # Set default model versions based on backend + if model_version is None: + if backend_type == "huggingface": + model_version = "ibm-granite/granite-guardian-3.3-8b" + else: # ollama + model_version = "ibm/granite3.3-guardian:8b" + + # Initialize backend + if backend_type == "huggingface": + self._backend = HuggingFaceGuardianBackend(model_version, device) + elif backend_type == "ollama": + self._backend = OllamaGuardianBackend(model_version, ollama_url) else: - label = "Failed" + raise ValueError(f"Unsupported backend type: {backend_type}") - return label, confidence_level + self._logger = FancyLogger.get_logger() - def _guardian_validate(self, ctx: Context): - """Validates the last turn of a conversation context using wrt given risk. + def get_effective_risk(self) -> str: + """Get the effective risk criteria to use for validation.""" + return self._custom_criteria if self._custom_criteria else self._risk - Code is adopted from https://huggingface.co/ibm-granite/granite-guardian-3.2-3b-a800m#quickstart-example + def supports_thinking_mode(self) -> bool: + """Check if current backend supports thinking mode.""" + return True # Both backends now support thinking mode + + @classmethod + def get_available_risks(cls) -> List[str]: + """Get list of all available standard risk types.""" + return GuardianRisk.get_available_risks() + + def _guardian_validate(self, ctx: Context) -> ValidationResult: + """Enhanced validation using Granite Guardian 3.3 8B with thinking mode support. Args: - ctx (Context): The context object containing the last turn of the conversation. + ctx: The context object containing the conversation to validate. Returns: - bool: True if there is no identified risk, False otherwise. + ValidationResult: Validation result with optional reasoning trace. """ - messages: list[dict[str, str]] = [] + messages: List[Dict[str, str]] = [] last_turn = ctx.last_turn() - assert last_turn is not None + if last_turn is None: + self._logger.warning("No last turn found in context") + return ValidationResult(False, reason="No content to validate") - # This requirement can handle incomplete turns with only a user message - # or only an assistant message. Handle both. + # Extract messages from context if last_turn.model_input: user_msg = last_turn.model_input - # Handle the variety of possible user input. 
if isinstance(user_msg, CBlock) and user_msg.value is not None: messages.append({"role": "user", "content": user_msg.value}) elif isinstance(user_msg, Message) and user_msg.content != "": @@ -118,40 +357,66 @@ def _guardian_validate(self, ctx: Context): else: messages.append({"role": "user", "content": str(user_msg)}) - if last_turn.output and last_turn.output.value: - messages.append({"role": "assistant", "content": last_turn.output.value}) + # Handle both text content and function calls + if last_turn.output: + assistant_content = "" - # Load model - model = AutoModelForCausalLM.from_pretrained( - self._model_version, device_map="auto", torch_dtype=torch.bfloat16 - ) - model.to(self._device) # type: ignore - model.eval() + # Add text content if available + if last_turn.output.value: + assistant_content = last_turn.output.value - tokenizer = AutoTokenizer.from_pretrained(self._model_version) + # Add function call information for FUNCTION_CALL risk validation + if (hasattr(last_turn.output, 'tool_calls') and last_turn.output.tool_calls and + self._risk == "function_call"): - # Please note that the default risk definition is of `harm`. If a config is not specified, this behavior will be applied. - guardian_config = {"risk_name": self._risk} + # Convert function calls to a text format that Guardian can validate + function_calls_text = [] + for name, tool_call in last_turn.output.tool_calls.items(): + call_info = f'{name}({tool_call.args})' + function_calls_text.append(call_info) - input_ids = tokenizer.apply_chat_template( - messages, - guardian_config=guardian_config, - add_generation_prompt=True, - return_tensors="pt", - ).to(model.device) + function_calls_str = ', '.join(function_calls_text) - input_len = input_ids.shape[1] + if assistant_content: + assistant_content += f" [Function calls: {function_calls_str}]" + else: + assistant_content = f"[Function calls: {function_calls_str}]" - with torch.no_grad(): - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - return_dict_in_generate=True, - output_scores=True, + if assistant_content: + messages.append({"role": "assistant", "content": assistant_content}) + + if not messages: + self._logger.warning("No messages found to validate") + return ValidationResult(False, reason="No messages to validate") + + # Use the appropriate risk criteria + effective_risk = self.get_effective_risk() + + try: + # Validate using the backend + label, trace, raw_data = self._backend.validate_messages( + messages, effective_risk, self._thinking, tools=self._tools, context_text=self._context_text ) - label, confidence = GuardianCheck._parse_output(output, input_len, tokenizer) + # Log the validation details + self._logger.debug(f"Guardian validation - Risk: {effective_risk}, Result: {label}") + if trace and self._thinking: + self._logger.debug(f"Guardian reasoning: {trace}") + + # Determine validation result + is_safe = label == "No" + + # Create detailed reason + reason_parts = [f"Guardian check for '{effective_risk}': {label}"] + + if trace: + reason_parts.append(f"Reasoning: {trace}") + + return ValidationResult( + result=is_safe, + reason="; ".join(reason_parts) + ) - # valid if there is NO risk - return label == "No" + except Exception as e: + self._logger.error(f"Guardian validation failed: {e}") + return ValidationResult(False, reason=f"Validation error: {str(e)}") diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index ff7ab3a2..502edcd6 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ 
-1,7 +1,7 @@ """sampling methods go here.""" import abc -from collections.abc import Callable, Coroutine +from collections.abc import Callable from copy import deepcopy from typing import Any @@ -19,7 +19,7 @@ ) from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult +from mellea.stdlib.requirement import Requirement, ValidationResult class SamplingResult(CBlock): @@ -60,22 +60,22 @@ class SamplingStrategy(abc.ABC): # the function signature here matches that of m.validate validate: ( - Callable[ - [list[Requirement], Context, Any, Any], - Coroutine[Any, Any, list[ValidationResult]], - ] - | None + Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None ) = None - generate: Callable[[Component, Context], ModelOutputThunk] | None = None + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None @abc.abstractmethod - async def sample( + def sample( self, action: Component, context: Context, requirements: list[Requirement], *, + generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method is the abstract method for sampling a given instruction. @@ -86,6 +86,7 @@ async def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. requirements: The requirements to be used by the sampling strategy (merged with global requirements). + generate_logs: Optional list of GenerateLog objects. If None, no collection happens. validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. """ @@ -99,12 +100,12 @@ def __init__( self, *, loop_budget: int = 1, - validate: Callable[ - [list[Requirement], Context, Any, Any], - Coroutine[Any, Any, list[ValidationResult]], - ] + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, - generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, requirements: list[Requirement] | None = None, ): """Initialize a new instance of the class with default parameters. @@ -166,13 +167,14 @@ def select_from_failure( """ ... - async def sample( + def sample( self, action: Component, context: Context, requirements: list[Requirement], *, show_progress: bool = True, + generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -181,6 +183,7 @@ async def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + generate_logs: If provided, the generations will be logged. requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. 
@@ -231,17 +234,10 @@ async def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx) - await result.avalue() + result = self.generate(new_action, ctx, generate_logs) # validation pass - val_scores_co = self.validate( - reqs, - validation_ctx, - result, - input=None, # type: ignore - ) - val_scores = await val_scores_co + val_scores = self.validate(reqs, validation_ctx, result) # match up reqs with scores constraint_scores = list(zip(reqs, val_scores)) @@ -254,11 +250,6 @@ async def sample( # if all vals are true -- break and return success if all(bool(s[1]) for s in constraint_scores): flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True - return SamplingResult( result, success=True, @@ -287,12 +278,6 @@ async def sample( assert best_failed_index < len(sampled_results), ( "The select_from_failure method did not return a valid result. It has to selected from failed_results." ) - - assert ( - sampled_results[best_failed_index]._generate_log is not None - ) # Cannot be None after generation. - sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore - return SamplingResult( sampled_results[best_failed_index], success=False, @@ -346,15 +331,23 @@ def repair( ) -> Component: pa = past_actions[-1] if isinstance(pa, Instruction): - last_failed_reqs: list[Requirement] = [ - s[0] for s in past_val[-1] if not s[1] + # Get failed requirements and their detailed validation reasons + failed_items = [ + (req, val) for req, val in past_val[-1] if not val.as_bool() ] - last_failed_reqs_str = "* " + "\n* ".join( - [str(r.description) for r in last_failed_reqs] - ) - return pa.copy_and_repair( - repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ) + + # Build repair feedback using ValidationResult.reason when available + repair_lines = [] + for req, validation in failed_items: + if validation.reason: + repair_lines.append(f"* {validation.reason}") + else: + # Fallback to requirement description if no reason + repair_lines.append(f"* {req.description}") + + repair_string = "The following requirements failed before:\n" + "\n".join(repair_lines) + + return pa.copy_and_repair(repair_string=repair_string) return past_actions[-1] @@ -396,229 +389,3 @@ def repair( ) return next_action - - -class BestofNSamplingStrategy(BaseSamplingStrategy): - """ - Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer - """ - - async def sample( - self, - action: Component, - context: Context, - requirements: list[Requirement], - *, - show_progress: bool = True, - validation_ctx: Context | None = None, - ) -> SamplingResult: - """This method performs a sampling operation based on the given instruction. - - Args: - action : The action object to be sampled. - context: The context to be passed to the sampling strategy. - show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. - requirements: List of requirements to test against (merged with global requirements). - validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. - - Returns: - SamplingResult: A result object indicating the success or failure of the sampling process. 
- - Raises: - AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. - """ - assert self.validate is not None, "Validation must be provided." - assert self.generate is not None, "Generate must be provided." - - # just to be sure to not cause issues to the OG context - ctx = context.copy() - validation_ctx = validation_ctx if validation_ctx is not None else context - assert validation_ctx is not None, "Validation context must be provided." - - flog = FancyLogger.get_logger() - - sampled_results: list[ModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - - successful_sampled_results: list[ModelOutputThunk] = [] - successful_sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - successful_sampled_actions: list[Component] = [] - - # sampled_val_scores: list[float] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO - - reqs = [] - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - - reqs = list(set(reqs)) - - # check that there is exactly one ScorerRequirement - scorer_requirements = 0 - for req in reqs: - # strict typecheck for scorer requirement - if isinstance(req, ScorerRequirement): - scorer_requirements += 1 - - assert scorer_requirements == 1, ( - "BestOfNSamplingStrategy requires exactly one ScorerRequirement" - ) - - loop_count = 0 - loop_budget_range_iterator = ( - tqdm.tqdm(range(self.loop_budget)) # type: ignore - if show_progress - else range(self.loop_budget) # type: ignore - ) - - new_action = deepcopy(action) - for _ in loop_budget_range_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # run a generation pass - result = self.generate(new_action, ctx) - await result.avalue() - - # validation pass - # action has user turn - val_scores_co = self.validate( - reqs, - validation_ctx, - result, - input=action._description, # type: ignore - ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(new_action) - - # check if requirements pass else repair and re-sample - # if all vals are true, save it and continue to get next sample - if all(bool(s[1]) for s in constraint_scores): - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True - - successful_sampled_results.append(result) - successful_sampled_scores.append(constraint_scores) - successful_sampled_actions.append(new_action) - - else: - # log partial success and continue - count_valid = len([s for s in constraint_scores if bool(s[1])]) - flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") - - # If we did not pass all constraints, update the instruction and try again. 
- new_action = self.repair( - ctx, sampled_actions, sampled_results, sampled_scores - ) - - # find max reward amongst results for which all requirements have passed - if len(successful_sampled_scores) > 0: - scores: list[float] = [] - scorer_preference_ordering = None - - for sample in successful_sampled_scores: - for req, val_score in sample: - if isinstance(req, ScorerRequirement): - assert val_score._score is not None - scores.append(val_score._score) - scorer_preference_ordering = req.preference_ordering - - assert len(successful_sampled_results) == len(scores) - assert scorer_preference_ordering is not None - - if scorer_preference_ordering == "max": - best_result, best_score = max( - zip(successful_sampled_results, scores), key=lambda x: x[1] - ) - elif scorer_preference_ordering == "min": - best_result, best_score = min( - zip(successful_sampled_results, scores), key=lambda x: x[1] - ) - else: - raise NotImplementedError - - return SamplingResult( - best_result, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - ) - - # if all failures, call select from failure - else: - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) - - # if no valid result could be determined, find a last resort. - best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) - return SamplingResult( - sampled_results[best_failed_index], - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - ) - - @staticmethod - def select_from_failure( - sampled_actions: list[Component], - sampled_results: list[ModelOutputThunk], - sampled_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> int: - # select attempt with highest ScoreRequirementScore if all loops fail - - scores: list[float | None] = [] - - for sample in sampled_val: - for req, val_score in sample: - if isinstance(req, ScorerRequirement): - assert val_score._score is not None - scores.append(val_score._score) - - assert len(sampled_results) == len(scores) - - return scores.index(max(scores)) # type: ignore - - @staticmethod - def repair( - ctx: Context, - past_actions: list[Component], - past_results: list[ModelOutputThunk], - past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: - pa = past_actions[-1] - if isinstance(pa, Instruction): - last_failed_reqs: list[Requirement] = [ - s[0] for s in past_val[-1] if not s[1] - ] - last_failed_reqs_str = "* " + "\n* ".join( - [str(r.description) for r in last_failed_reqs] - ) - return pa.copy_and_repair( - repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ) - return past_actions[-1] From 6a9e3a4fd464436f1c7f036db875921b3824fc44 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 24 Sep 2025 21:37:04 -0400 Subject: [PATCH 02/19] restore updates from upstream main. 
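A minimal sketch of driving a strategy under the restored coroutine-based API
(sample() is async again, validate() returns an awaitable, the generate_logs
parameter is dropped, and BestofNSamplingStrategy is re-added). The wiring of
the validate/generate hooks is assumed to happen inside MelleaSession.instruct,
as in the repair example from the previous patch:

    # Hedged sketch only; the SamplingResult fields used here follow
    # docs/examples/safety.py/repair_with_guardian.py.
    import asyncio

    from mellea.stdlib.sampling import RepairTemplateStrategy

    async def run_with_repair(strategy: RepairTemplateStrategy, action, context, requirements):
        # sample() asserts that both hooks are attached before generating.
        assert strategy.generate is not None and strategy.validate is not None
        result = await strategy.sample(action, context, requirements)
        return result.success, result.sample_generations

    # e.g.: success, attempts = asyncio.run(run_with_repair(strategy, action, ctx, reqs))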
--- mellea/stdlib/sampling.py | 281 +++++++++++++++++++++++++++++++++++--- 1 file changed, 261 insertions(+), 20 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 502edcd6..734fadfe 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -1,7 +1,7 @@ """sampling methods go here.""" import abc -from collections.abc import Callable +from collections.abc import Callable, Coroutine from copy import deepcopy from typing import Any @@ -19,7 +19,7 @@ ) from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult class SamplingResult(CBlock): @@ -60,22 +60,22 @@ class SamplingStrategy(abc.ABC): # the function signature here matches that of m.validate validate: ( - Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None - ) = None - - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + Callable[ + [list[Requirement], Context, Any, Any], + Coroutine[Any, Any, list[ValidationResult]], + ] | None ) = None + generate: Callable[[Component, Context], ModelOutputThunk] | None = None + @abc.abstractmethod - def sample( + async def sample( self, action: Component, context: Context, requirements: list[Requirement], *, - generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method is the abstract method for sampling a given instruction. @@ -86,7 +86,6 @@ def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. requirements: The requirements to be used by the sampling strategy (merged with global requirements). - generate_logs: Optional list of GenerateLog objects. If None, no collection happens. validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. """ @@ -100,12 +99,12 @@ def __init__( self, *, loop_budget: int = 1, - validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + validate: Callable[ + [list[Requirement], Context, Any, Any], + Coroutine[Any, Any, list[ValidationResult]], + ] | None = None, - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] - | None - ) = None, + generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None, requirements: list[Requirement] | None = None, ): """Initialize a new instance of the class with default parameters. @@ -167,14 +166,13 @@ def select_from_failure( """ ... - def sample( + async def sample( self, action: Component, context: Context, requirements: list[Requirement], *, show_progress: bool = True, - generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -183,7 +181,6 @@ def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. - generate_logs: If provided, the generations will be logged. requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. 
@@ -234,10 +231,17 @@ def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx, generate_logs) + result = self.generate(new_action, ctx) + await result.avalue() # validation pass - val_scores = self.validate(reqs, validation_ctx, result) + val_scores_co = self.validate( + reqs, + validation_ctx, + result, + input=None, # type: ignore + ) + val_scores = await val_scores_co # match up reqs with scores constraint_scores = list(zip(reqs, val_scores)) @@ -250,6 +254,11 @@ def sample( # if all vals are true -- break and return success if all(bool(s[1]) for s in constraint_scores): flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + return SamplingResult( result, success=True, @@ -278,6 +287,12 @@ def sample( assert best_failed_index < len(sampled_results), ( "The select_from_failure method did not return a valid result. It has to selected from failed_results." ) + + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. + sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + return SamplingResult( sampled_results[best_failed_index], success=False, @@ -389,3 +404,229 @@ def repair( ) return next_action + + +class BestofNSamplingStrategy(BaseSamplingStrategy): + """ + Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer + """ + + async def sample( + self, + action: Component, + context: Context, + requirements: list[Requirement], + *, + show_progress: bool = True, + validation_ctx: Context | None = None, + ) -> SamplingResult: + """This method performs a sampling operation based on the given instruction. + + Args: + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + requirements: List of requirements to test against (merged with global requirements). + validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + + Returns: + SamplingResult: A result object indicating the success or failure of the sampling process. + + Raises: + AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. + """ + assert self.validate is not None, "Validation must be provided." + assert self.generate is not None, "Generate must be provided." + + # just to be sure to not cause issues to the OG context + ctx = context.copy() + validation_ctx = validation_ctx if validation_ctx is not None else context + assert validation_ctx is not None, "Validation context must be provided." + + flog = FancyLogger.get_logger() + + sampled_results: list[ModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + + successful_sampled_results: list[ModelOutputThunk] = [] + successful_sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + successful_sampled_actions: list[Component] = [] + + # sampled_val_scores: list[float] = [] + + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. 
+ show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO + + reqs = [] + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + + reqs = list(set(reqs)) + + # check that there is exactly one ScorerRequirement + scorer_requirements = 0 + for req in reqs: + # strict typecheck for scorer requirement + if isinstance(req, ScorerRequirement): + scorer_requirements += 1 + + assert scorer_requirements == 1, ( + "BestOfNSamplingStrategy requires exactly one ScorerRequirement" + ) + + loop_count = 0 + loop_budget_range_iterator = ( + tqdm.tqdm(range(self.loop_budget)) # type: ignore + if show_progress + else range(self.loop_budget) # type: ignore + ) + + new_action = deepcopy(action) + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # run a generation pass + result = self.generate(new_action, ctx) + await result.avalue() + + # validation pass + # action has user turn + val_scores_co = self.validate( + reqs, + validation_ctx, + result, + input=action._description, # type: ignore + ) + val_scores = await val_scores_co + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(new_action) + + # check if requirements pass else repair and re-sample + # if all vals are true, save it and continue to get next sample + if all(bool(s[1]) for s in constraint_scores): + flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + + successful_sampled_results.append(result) + successful_sampled_scores.append(constraint_scores) + successful_sampled_actions.append(new_action) + + else: + # log partial success and continue + count_valid = len([s for s in constraint_scores if bool(s[1])]) + flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") + + # If we did not pass all constraints, update the instruction and try again. + new_action = self.repair( + ctx, sampled_actions, sampled_results, sampled_scores + ) + + # find max reward amongst results for which all requirements have passed + if len(successful_sampled_scores) > 0: + scores: list[float] = [] + scorer_preference_ordering = None + + for sample in successful_sampled_scores: + for req, val_score in sample: + if isinstance(req, ScorerRequirement): + assert val_score._score is not None + scores.append(val_score._score) + scorer_preference_ordering = req.preference_ordering + + assert len(successful_sampled_results) == len(scores) + assert scorer_preference_ordering is not None + + if scorer_preference_ordering == "max": + best_result, best_score = max( + zip(successful_sampled_results, scores), key=lambda x: x[1] + ) + elif scorer_preference_ordering == "min": + best_result, best_score = min( + zip(successful_sampled_results, scores), key=lambda x: x[1] + ) + else: + raise NotImplementedError + + return SamplingResult( + best_result, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + ) + + # if all failures, call select from failure + else: + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." + ) + + # if no valid result could be determined, find a last resort. 
+ best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." + ) + return SamplingResult( + sampled_results[best_failed_index], + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + ) + + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + # select attempt with highest ScoreRequirementScore if all loops fail + + scores: list[float | None] = [] + + for sample in sampled_val: + for req, val_score in sample: + if isinstance(req, ScorerRequirement): + assert val_score._score is not None + scores.append(val_score._score) + + assert len(sampled_results) == len(scores) + + return scores.index(max(scores)) # type: ignore + + @staticmethod + def repair( + ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + pa = past_actions[-1] + if isinstance(pa, Instruction): + last_failed_reqs: list[Requirement] = [ + s[0] for s in past_val[-1] if not s[1] + ] + last_failed_reqs_str = "* " + "\n* ".join( + [str(r.description) for r in last_failed_reqs] + ) + return pa.copy_and_repair( + repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" + ) + return past_actions[-1] From 224f920e0c22d863fd39539b59432aade8a67390 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Fri, 26 Sep 2025 10:07:36 -0400 Subject: [PATCH 03/19] refactor to use mellea hf and ollama backends. 
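This swaps the hand-rolled transformers/requests plumbing for mellea's own
LocalHFBackend and OllamaModelBackend wrappers; the GuardianCheck arguments
used in the examples are assumed to stay as before. Illustrative sketch of the
two backend paths (argument names mirror docs/examples/safety.py and may lag
the final constructor):

    from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

    # Default path: Granite Guardian 3.3 8B served by a local Ollama instance.
    ollama_guard = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")

    # Local inference path: ibm-granite/granite-guardian-3.3-8b via the HF wrapper.
    hf_guard = GuardianCheck(GuardianRisk.HARM, backend_type="huggingface")

    for guard in (ollama_guard, hf_guard):
        print(guard.get_effective_risk(), guard.supports_thinking_mode())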
--- mellea/stdlib/safety/guardian.py | 339 +++++++++++++++++++++---------- 1 file changed, 234 insertions(+), 105 deletions(-) diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index bea17489..46a99ff9 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -1,18 +1,18 @@ """Risk checking with Guardian models.""" import json -import torch from enum import Enum from typing import Dict, List, Any, Optional, Tuple, Union, Literal -from transformers import AutoModelForCausalLM, AutoTokenizer try: import requests except ImportError: requests = None +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.huggingface import LocalHFBackend from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Context +from mellea.stdlib.base import CBlock, Context, Component, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.requirement import Requirement, ValidationResult @@ -45,7 +45,14 @@ def get_available_risks(cls) -> List[str]: def _parse_safety_result(result: str, logger) -> str: """Parse the safety result from model output.""" - result_lower = result.lower().strip() + if result is None: + logger.warning("Guardian returned None result") + return "Failed" + + result_lower = str(result).lower().strip() + + # Handle chat template artifacts - remove any remaining tags + result_lower = result_lower.replace("<|im_end|>", "").replace("<|im_start|>", "").strip() # Handle score tags format: yes/no if "" in result_lower and "" in result_lower: @@ -68,11 +75,15 @@ def _parse_safety_result(result: str, logger) -> str: return "No" # Check for safety indicators - if any(word in result_lower for word in ["safe", "no risk", "acceptable"]): + if any(word in result_lower for word in ["safe", "no risk", "acceptable", "not harmful"]): return "No" - elif any(word in result_lower for word in ["unsafe", "risky", "harmful", "violation"]): + elif any(word in result_lower for word in ["unsafe", "risky", "harmful", "violation", "inappropriate", "offensive"]): return "Yes" + # Special handling for refusal responses (model refusing to evaluate harmful content) + if any(phrase in result_lower for phrase in ["can't engage", "won't engage", "don't engage", "inappropriate", "personal attack"]): + return "Yes" # If model refuses, likely because content is harmful + # Default to failed parsing logger.warning(f"Could not parse safety result: {result}") return "Failed" @@ -91,76 +102,150 @@ def validate_messages(self, messages: List[Dict[str, str]], risk: str, """Validate messages for specified risk. 
Returns (result, trace, raw_data).""" raise NotImplementedError + def _get_result_sync(self, result_thunk) -> str: + """Get result from ModelOutputThunk synchronously.""" + import asyncio + + try: + # Try direct value access first (might be already resolved) + if hasattr(result_thunk, 'value') and result_thunk.value is not None: + return str(result_thunk.value) + except Exception: + pass + + try: + # Try to get the underlying value directly + if hasattr(result_thunk, '_underlying_value') and result_thunk._underlying_value: + return str(result_thunk._underlying_value) + except Exception: + pass + + try: + # If we have a generation task, wait for it + if hasattr(result_thunk, '_generate') and result_thunk._generate: + # Create new event loop if needed + try: + loop = asyncio.get_running_loop() + # If we're in an async context, this is more complex + raise RuntimeError("In async context - need special handling") + except RuntimeError: + # No running loop or we're in one - use asyncio.run + asyncio.run(result_thunk._generate) + return str(getattr(result_thunk, '_underlying_value', "")) + except Exception as e: + self._logger.warning(f"Async result handling failed: {e}") + + # Final fallback + return str(result_thunk) if result_thunk else "" + + def _prepare_guardian_messages(self, messages: List[Dict[str, str]], risk: str, + thinking: bool, context_text: Optional[str] = None, + tools: Optional[List[Dict]] = None) -> List[Dict[str, str]]: + """Prepare messages in Guardian format exactly like example script.""" + guardian_messages = [] + + # System message contains ONLY the risk type (like example script) + guardian_messages.append({"role": "system", "content": risk}) + + # For groundedness, add document context as separate message (like example script) + if risk == "groundedness" and context_text: + guardian_messages.append({"role": "document 0", "content": context_text}) + + # Add the original conversation messages exactly as provided + guardian_messages.extend(messages) + + # NO additional instruction messages - Guardian model knows what to do + # This matches the example script pattern exactly + + return guardian_messages + class HuggingFaceGuardianBackend(GuardianBackend): - """HuggingFace-based Guardian backend for local model inference.""" + """HuggingFace-based Guardian backend that wraps LocalHFBackend.""" - def __init__(self, model_version: str = "ibm-granite/granite-guardian-3.0-8b", device: Optional[str] = None): + def __init__(self, model_version: str = "ibm-granite/granite-guardian-3.3-8b", device: Optional[str] = None): super().__init__(model_version, device) - self._model = None - self._tokenizer = None - - # Auto-device selection if not specified - if device is None: - device_name: str = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - self.device = torch.device(device_name) - - def _load_model(self): - """Lazy load model and tokenizer.""" - if self._model is None: - self._logger.info(f"Loading Granite Guardian model: {self.model_version}") - self._model = AutoModelForCausalLM.from_pretrained( - self.model_version, - device_map="auto", + + # Create custom config if device is specified, otherwise let LocalHFBackend auto-detect + custom_config = None + if device is not None: + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_version) + model = AutoModelForCausalLM.from_pretrained( + model_version, torch_dtype=torch.bfloat16 ) - 
self._model.to(self.device) - self._model.eval() - - self._tokenizer = AutoTokenizer.from_pretrained(self.model_version) + torch_device = torch.device(device) + model = model.to(torch_device) + custom_config = (tokenizer, model, torch_device) + + # Wrap the existing LocalHFBackend + self._hf_backend = LocalHFBackend( + model_id=model_version, + custom_config=custom_config + ) + self._logger.info(f"Initialized HuggingFace Guardian backend with model: {model_version}") def validate_messages(self, messages: List[Dict[str, str]], risk: str, thinking: bool = False, tools: Optional[List[Dict]] = None, context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: - """Validate messages using HuggingFace backend.""" - self._load_model() - - guardian_config = {"risk_name": risk} - - # Apply chat template with thinking mode support - input_ids = self._tokenizer.apply_chat_template( - messages, - guardian_config=guardian_config, - add_generation_prompt=True, - return_tensors="pt", - think=thinking # Enable thinking mode if supported - ).to(self._model.device) - - input_len = input_ids.shape[1] - - # Generate with appropriate tokens for thinking mode - max_tokens = 2000 if thinking else 20 - - with torch.no_grad(): - output = self._model.generate( - input_ids, - do_sample=False, - max_new_tokens=max_tokens, - return_dict_in_generate=True, - output_scores=True, - ) + """Validate messages using wrapped LocalHFBackend with event loop.""" + + # Create async wrapper to handle event loop + async def run_validation(): + # Prepare messages in Guardian format (like example script) + guardian_messages = self._prepare_guardian_messages(messages, risk, thinking, context_text, tools) + + # Use the backend's native chat template capabilities + from mellea.stdlib.base import LinearContext, ContextTurn + + ctx = LinearContext() + + # Add all Guardian messages to context + for msg in guardian_messages: + if msg["role"] in ["user", "assistant", "system"]: + ctx.insert_turn(ContextTurn(Message(msg["role"], msg["content"]), None)) + elif msg["role"].startswith("document"): + # Handle document messages for groundedness + ctx.insert_turn(ContextTurn(Message("user", f"Document: {msg['content']}"), None)) + + # Prepare model options + model_options = { + "max_new_tokens": 2000 if thinking else 50, + "do_sample": False, + "temperature": 0.0, + "system": risk # System prompt is just the risk type + } + + if thinking: + model_options["think"] = True + + # Add an empty assistant message to trigger generation + generation_prompt = Message("assistant", "") + + # Use native chat template generation + if hasattr(self._hf_backend, 'generate_from_chat_context'): + result_thunk = self._hf_backend.generate_from_chat_context( + generation_prompt, ctx, model_options=model_options + ) + else: + result_thunk = self._hf_backend.generate_from_context( + generation_prompt, ctx, model_options=model_options + ) - # Parse output - full_response = self._tokenizer.decode( - output.sequences[:, input_len:][0], - skip_special_tokens=True - ).strip() + # Wait for async result + result_value = result_thunk.value + # Handle None or empty results + return str(result_value) if result_value is not None else "" + + # Run the async validation in a new event loop + import asyncio + try: + full_response = asyncio.run(run_validation()) + except Exception as e: + self._logger.error(f"HuggingFace validation failed: {e}") + return "Failed", None, {"error": str(e)} # Extract thinking trace if present trace = None @@ -178,66 +263,110 @@ def 
validate_messages(self, messages: List[Dict[str, str]], risk: str, return label, trace, {"full_response": full_response, "model": self.model_version} + class OllamaGuardianBackend(GuardianBackend): - """Ollama-based Guardian backend for local model inference.""" + """Ollama-based Guardian backend that wraps OllamaModelBackend.""" def __init__(self, model_version: str = "ibm/granite3.3-guardian:8b", ollama_url: str = "http://localhost:11434"): super().__init__(model_version) self.ollama_url = ollama_url - if requests is None: - raise ImportError("requests library is required for Ollama backend. Install with: pip install requests") + # Wrap the existing OllamaModelBackend + self._ollama_backend = OllamaModelBackend( + model_id=model_version, + base_url=ollama_url + ) + self._logger.info(f"Initialized Ollama Guardian backend with model: {model_version}") def validate_messages(self, messages: List[Dict[str, str]], risk: str, thinking: bool = False, tools: Optional[List[Dict]] = None, context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: - """Validate messages using Ollama backend.""" + """Validate messages using wrapped OllamaModelBackend with event loop.""" + + # Create async wrapper to handle event loop + async def run_validation(): + # Prepare messages in Guardian format (like example script) + guardian_messages = self._prepare_guardian_messages(messages, risk, thinking, context_text, tools) + + # Use the backend's native chat template capabilities + from mellea.stdlib.base import LinearContext, ContextTurn + + ctx = LinearContext() + + # Add all Guardian messages to context + for msg in guardian_messages: + if msg["role"] in ["user", "assistant", "system"]: + ctx.insert_turn(ContextTurn(Message(msg["role"], msg["content"]), None)) + elif msg["role"].startswith("document"): + # Handle document messages for groundedness + ctx.insert_turn(ContextTurn(Message("user", f"Document: {msg['content']}"), None)) + + # Prepare model options + model_options = { + "temperature": 0.0, + "num_predict": 2000 if thinking else 50, + "stream": False, + "system": risk # System prompt is just the risk type + } + + if thinking: + model_options["think"] = True + + # Add tools for function call validation + if risk == "function_call" and tools: + model_options["tools"] = self._convert_tools_to_functions(tools) + + # Add an empty assistant message to trigger generation + generation_prompt = Message("assistant", "") + + # Use native chat template generation + result_thunk = self._ollama_backend.generate_from_chat_context( + generation_prompt, ctx, model_options=model_options + ) - # Prepare messages for Guardian checking - guardian_messages = [{"role": "system", "content": risk}] + # Wait for async result + result_value = result_thunk.value + # Handle None or empty results + return str(result_value) if result_value is not None else "" - # For groundedness/context relevance, add document context - if risk in ["groundedness", "context_relevance"] and context_text: - guardian_messages.append({"role": "document", "content": context_text}) + # Run the async validation in a new event loop + import asyncio + try: + full_response = asyncio.run(run_validation()) + except Exception as e: + self._logger.error(f"Ollama validation failed: {e}") + return "Failed", None, {"error": str(e)} - guardian_messages.extend(messages) + # Extract thinking trace if present + trace = None + result = full_response - payload = { - "model": self.model_version, - "messages": guardian_messages, - "stream": False, - "think": 
thinking - } + if thinking and "" in str(full_response): + parts = str(full_response).split("") + if len(parts) > 1: + trace = parts[0].replace("", "").strip() + result = parts[1].strip() - # For function call validation, add tools to the payload - if risk == "function_call" and tools: - payload["tools"] = tools + # Parse safety result + label = _parse_safety_result(result, self._logger) - try: - response = requests.post( - f"{self.ollama_url}/api/chat", - json=payload, - timeout=120 - ) - response.raise_for_status() - data = response.json() - - # Extract content and trace - content = data.get("message", {}).get("content", "") - trace = ( - data.get("message", {}).get("reasoning") or - data.get("message", {}).get("thinking") - ) + return label, trace, {"full_response": full_response, "model": self.model_version} - # Parse safety result - label = _parse_safety_result(content, self._logger) - return label, trace, data + def _convert_tools_to_functions(self, tools: List[Dict]) -> List[callable]: + """Convert tool definitions to callable functions for Ollama backend.""" + functions = [] + for tool in tools: + # Create a dummy function that matches the tool signature + def dummy_func(**kwargs): + return f"Tool {tool['name']} called with args: {kwargs}" - except Exception as e: - self._logger.error(f"Ollama Guardian request failed: {e}") - return "Failed", None, {"error": str(e)} + dummy_func.__name__ = tool['name'] + dummy_func.__doc__ = tool.get('description', '') + functions.append(dummy_func) + + return functions class GuardianCheck(Requirement): From 1c108675a58f66afa2f38242d3feb75e507ac168 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Tue, 30 Sep 2025 21:33:42 -0400 Subject: [PATCH 04/19] feat: add reason to repair string. --- docs/examples/safety.py/guardian.py | 17 ++- .../safety.py/repair_with_guardian.py | 141 +++++++----------- mellea/stdlib/sampling/base.py | 24 ++- 3 files changed, 76 insertions(+), 106 deletions(-) diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 596b6f0e..8d92f2a8 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -33,15 +33,16 @@ if valid[0]._reason: print(f"Guardian feedback: {valid[0]._reason}...") -# Reproducing a harm scenario and run the validation again -print("\n Test 2\n") - -# create a mean conversation and add to context -m.ctx = m.ctx.add(Message("user", "Hello. 
")).add( - ModelOutputThunk("You are very ugly.") +print("\n=== Test 2: Harmful Content ===") +# Create harmful conversation and add to context +m.ctx.insert_turn( + ContextTurn(Message("user", "Hello"), ModelOutputThunk("You are very ugly and stupid.")) ) -# show last turn in chat -print(f"Context: {m.ctx.last_turn()}") + +valid = m.validate([guardian]) +print(f"Content is safe: {valid[0]._result}") +if valid[0]._reason: + print(f"Guardian feedback: {valid[0]._reason}...") print("\n=== Test 3: Enhanced Features ===") diff --git a/docs/examples/safety.py/repair_with_guardian.py b/docs/examples/safety.py/repair_with_guardian.py index bb927a21..a990c5d7 100644 --- a/docs/examples/safety.py/repair_with_guardian.py +++ b/docs/examples/safety.py/repair_with_guardian.py @@ -11,27 +11,26 @@ def demo_repair_with_actual_function_calling(): """Demonstrate RepairTemplateStrategy with actual function calling and Guardian validation.""" - print("RepairTemplateStrategy with Actual Function Call Demo") - print("-" * 52) + print("=== Guardian Repair Demo ===\n") # Use Llama3.2 which supports function calling m = MelleaSession(OllamaModelBackend("llama3.2")) - # Define actual callable functions + # Define callable functions for the model def get_weather(location: str) -> str: """Gets current weather information for a location""" - return f"The current weather in {location} is sunny, 22°C with light winds." + return f"Weather in {location}: sunny, 22°C" def get_recipe(dish_name: str) -> str: """Gets a cooking recipe for the specified dish""" - return f"Recipe for {dish_name}: Cook ingredients together until done." + return f"Recipe for {dish_name}: [recipe details]" def get_stock_price(symbol: str) -> str: """Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).""" - return f"Current stock price for {symbol} is $150.25" + return f"Stock price for {symbol}: $150.25" - # All available tools - both model and Guardian use the same set - all_tools = [ + # Tool schemas - Guardian validates against these + tool_schemas = [ { "name": "get_weather", "description": "Gets current weather information for a location", @@ -64,91 +63,53 @@ def get_stock_price(symbol: str) -> str: } ] - # Function call validation using GuardianRisk.FUNCTION_CALL - safety_requirements = [ - GuardianCheck( - GuardianRisk.FUNCTION_CALL, - thinking=True, - tools=all_tools # Guardian and model use same tools - ) - ] - - print(f"Risk Type: {safety_requirements[0].get_effective_risk()}") - print(f"Available Tools: {[tool['name'] for tool in all_tools]}") + # Guardian validates function calls against tool schemas + guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, + thinking=True, + tools=tool_schemas + ) # Query that should trigger invalid stock symbol usage test_prompt = "What's the price of Tesla stock?" - print(f"Main Model Prompt: {test_prompt}") - - # Model functions - all_functions = [get_weather, get_recipe, get_stock_price] - print(f"Model Available Functions: {[f.__name__ for f in all_functions]}") - - try: - result = m.instruct( - test_prompt, - requirements=safety_requirements, - strategy=RepairTemplateStrategy(loop_budget=3), - return_sampling_results=True, - model_options={ - "temperature": 0.7, # Some randomness - "seed": 789, - "tools": all_functions, - "system": "When users ask about stock prices, always use the full company name as the symbol parameter instead of the ticker symbol. For example, use 'Tesla Motors' instead of 'TSLA', 'Apple Inc' instead of 'AAPL', etc." 
- }, - tool_calls=True - ) - - # Show repair process - if hasattr(result, 'sample_validations') and result.sample_validations: - for attempt_num, (generation, validations) in enumerate(zip(result.sample_generations, result.sample_validations), 1): - print(f"\nAttempt {attempt_num}:") - - # Show model response (may be empty for function calls only) - response = str(generation.value) if generation.value else "[Function calls only]" - print(f"Model Response: {response}") - - # Show function calls made - if hasattr(generation, 'tool_calls') and generation.tool_calls: - print("Function Calls Made:") - for name, tool_call in generation.tool_calls.items(): - print(f" - {name}({tool_call.args})") - - # Show validation results - for req_item, validation in validations: - status = "PASSED" if validation.as_bool() else "FAILED" - print(f"Status: {status}") - if validation.reason: - print(f"Guardian Reason: {validation.reason}") - - print(f"\nFinal Result: {'SUCCESS' if result.success else 'FAILED'}") - print(f"Attempts used: {len(result.sample_generations) if hasattr(result, 'sample_generations') else 1}") - - return result - - except Exception as e: - print(f"Function calling failed: {e}") - print("This may be because the model doesn't support function calling or Ollama is not running.") - return None - - -def main(): - """Run RepairTemplateStrategy demo with actual function call validation.""" - try: - print("=== Actual Function Calling with Guardian Validation Demo ===") - result = demo_repair_with_actual_function_calling() - - if result is None: - print("\nDemo failed. Please ensure:") - print("1. Ollama is running") - print("2. llama3.2 model is available") - print("3. Model supports function calling") - - print("\nDemo completed.") - except Exception as e: - print(f"Error: {e}") - print("Ensure Ollama is running with a model that supports function calling.") + print(f"Prompt: {test_prompt}\n") + + result = m.instruct( + test_prompt, + requirements=[guardian], + strategy=RepairTemplateStrategy(loop_budget=3), + return_sampling_results=True, + model_options={ + "temperature": 0.7, + "seed": 789, + "tools": [get_weather, get_recipe, get_stock_price], + "system": "When users ask about stock prices, always use the full company name as the symbol parameter instead of the ticker symbol. For example, use 'Tesla Motors' instead of 'TSLA'." 
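+            # Note: this system prompt deliberately pushes the model toward an
+            # invalid symbol ("Tesla Motors" rather than "TSLA") so the
+            # FUNCTION_CALL guardian fails on the first attempt and
+            # RepairTemplateStrategy can demonstrate a repair pass.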
+ }, + tool_calls=True + ) + + # Show repair process + for attempt_num, (generation, validations) in enumerate(zip(result.sample_generations, result.sample_validations), 1): + print(f"Attempt {attempt_num}:") + + # Show function calls made + if hasattr(generation, 'tool_calls') and generation.tool_calls: + for name, tool_call in generation.tool_calls.items(): + print(f" Function: {name}({tool_call.args})") + + # Show validation results + for req_item, validation in validations: + status = "PASS" if validation.as_bool() else "FAIL" + print(f" {status}") + + # For failures, show repair feedback + if not validation.as_bool() and validation.reason and attempt_num < len(result.sample_generations): + print(f" Repair: {validation.reason.split('Rationale:')[1].split('Response_error_span')[0].strip() if 'Rationale:' in validation.reason else validation.reason}") + print() + + print(f"Result: {'SUCCESS' if result.success else 'FAILED'} after {len(result.sample_generations)} attempt(s)") + return result if __name__ == "__main__": - main() \ No newline at end of file + demo_repair_with_actual_function_calling() \ No newline at end of file diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index fae5a922..73a4b9e0 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -281,15 +281,23 @@ def repair( ) -> tuple[Component, Context]: pa = past_actions[-1] if isinstance(pa, Instruction): - last_failed_reqs: list[Requirement] = [ - s[0] for s in past_val[-1] if not s[1] + # Get failed requirements and their detailed validation reasons + failed_items = [ + (req, val) for req, val in past_val[-1] if not val.as_bool() ] - last_failed_reqs_str = "* " + "\n* ".join( - [str(r.description) for r in last_failed_reqs] - ) - return pa.copy_and_repair( - repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ), old_ctx + + # Build repair feedback using ValidationResult.reason when available + repair_lines = [] + for req, validation in failed_items: + if validation.reason: + repair_lines.append(f"* {validation.reason}") + else: + # Fallback to requirement description if no reason + repair_lines.append(f"* {req.description}") + + repair_string = "The following requirements failed before:\n" + "\n".join(repair_lines) + + return pa.copy_and_repair(repair_string=repair_string), old_ctx return pa, old_ctx From 9d49768c6f0ce3f6e337a82afe708bf61ec6bf04 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 11:13:04 -0400 Subject: [PATCH 05/19] successful run of examples --- docs/examples/safety.py/guardian.py | 32 +- .../safety.py/repair_with_guardian.py | 72 +-- mellea/stdlib/safety/guardian.py | 593 ++++++------------ 3 files changed, 219 insertions(+), 478 deletions(-) diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 8d92f2a8..caeba62a 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -3,7 +3,7 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ContextTurn, ModelOutputThunk +from mellea.stdlib.base import ContextTurn, ModelOutputThunk, ChatContext from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk @@ -23,8 +23,8 @@ print(f"Custom guardian: {custom_guardian.get_effective_risk()}") print("\n=== Test 1: Normal Content ===") -# Run a query against an Ollama model -m = 
MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) +# Run a query against an Ollama model with ChatContext to support insert_turn +m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) m.chat("Write a professional email to your colleague. Use less than 100 words.") # Validate on the last turn of the conversation @@ -35,9 +35,7 @@ print("\n=== Test 2: Harmful Content ===") # Create harmful conversation and add to context -m.ctx.insert_turn( - ContextTurn(Message("user", "Hello"), ModelOutputThunk("You are very ugly and stupid.")) -) +m.ctx = m.ctx.add(Message("user", "Hello")).add(Message("assistant", "You are very ugly and stupid.")) valid = m.validate([guardian]) print(f"Content is safe: {valid[0]._result}") @@ -81,12 +79,11 @@ ) # Create a response that makes ungrounded claims relative to provided context -groundedness_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) -groundedness_session.ctx.insert_turn( - ContextTurn( - Message("user", "What is the history of treaty making?"), - ModelOutputThunk("Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. The first treaty was signed between Rome and the Moon people, establishing trade routes through space.") - ) +groundedness_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +groundedness_session.ctx = groundedness_session.ctx.add( + Message("user", "What is the history of treaty making?") +).add( + Message("assistant", "Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. The first treaty was signed between Rome and the Moon people, establishing trade routes through space.") ) print("Testing response with ungrounded claims...") @@ -137,12 +134,11 @@ ] hallucinated_response = str(response_data) -function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) -function_session.ctx.insert_turn( - ContextTurn( - Message("user", "Fetch total views for the IBM video with ID 456789123."), - ModelOutputThunk(hallucinated_response) - ) +function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +function_session.ctx = function_session.ctx.add( + Message("user", "Fetch total views for the IBM video with ID 456789123.") +).add( + Message("assistant", hallucinated_response) ) print("Testing response with function call hallucination...") diff --git a/docs/examples/safety.py/repair_with_guardian.py b/docs/examples/safety.py/repair_with_guardian.py index a990c5d7..1ae85bbe 100644 --- a/docs/examples/safety.py/repair_with_guardian.py +++ b/docs/examples/safety.py/repair_with_guardian.py @@ -10,67 +10,42 @@ def demo_repair_with_actual_function_calling(): - """Demonstrate RepairTemplateStrategy with actual function calling and Guardian validation.""" + """Demonstrate RepairTemplateStrategy with actual function calling and Guardian validation. + + Note: This demo uses an intentionally misconfigured system prompt to force an initial error, + demonstrating how Guardian provides detailed repair feedback that helps the model correct itself. 
+ """ print("=== Guardian Repair Demo ===\n") # Use Llama3.2 which supports function calling m = MelleaSession(OllamaModelBackend("llama3.2")) - # Define callable functions for the model - def get_weather(location: str) -> str: - """Gets current weather information for a location""" - return f"Weather in {location}: sunny, 22°C" - - def get_recipe(dish_name: str) -> str: - """Gets a cooking recipe for the specified dish""" - return f"Recipe for {dish_name}: [recipe details]" - + # Simple function for stock price def get_stock_price(symbol: str) -> str: """Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).""" return f"Stock price for {symbol}: $150.25" - # Tool schemas - Guardian validates against these + # Tool schema - Guardian validates against this tool_schemas = [ - { - "name": "get_weather", - "description": "Gets current weather information for a location", - "parameters": { - "location": { - "description": "The location to get weather for", - "type": "string" - } - } - }, - { - "name": "get_recipe", - "description": "Gets a cooking recipe for the specified dish", - "parameters": { - "dish_name": { - "description": "The name of the dish to get a recipe for", - "type": "string" - } - } - }, { "name": "get_stock_price", "description": "Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).", "parameters": { "symbol": { - "description": "The stock symbol to get price for (must be 3-5 uppercase letters)", + "description": "The stock symbol to get price for (must be 3-5 uppercase letters like TSLA, AAPL)", "type": "string" } } } ] - # Guardian validates function calls against tool schemas + # Guardian validates function calls against tool schema guardian = GuardianCheck( GuardianRisk.FUNCTION_CALL, thinking=True, tools=tool_schemas ) - # Query that should trigger invalid stock symbol usage test_prompt = "What's the price of Tesla stock?" print(f"Prompt: {test_prompt}\n") @@ -82,15 +57,29 @@ def get_stock_price(symbol: str) -> str: model_options={ "temperature": 0.7, "seed": 789, - "tools": [get_weather, get_recipe, get_stock_price], - "system": "When users ask about stock prices, always use the full company name as the symbol parameter instead of the ticker symbol. For example, use 'Tesla Motors' instead of 'TSLA'." + "tools": [get_stock_price], + # Intentionally misconfigured to demonstrate repair + "system": "When users ask about stock prices, use the full company name as the symbol parameter. For example, use 'Tesla Motors' instead of 'TSLA'." 
}, tool_calls=True ) # Show repair process for attempt_num, (generation, validations) in enumerate(zip(result.sample_generations, result.sample_validations), 1): - print(f"Attempt {attempt_num}:") + print(f"\nAttempt {attempt_num}:") + + # Show what was sent to the model + if hasattr(result, 'sample_actions') and result.sample_actions and attempt_num <= len(result.sample_actions): + action = result.sample_actions[attempt_num - 1] + if hasattr(m.backend, 'formatter'): + try: + rendered = m.backend.formatter.print(action) + print(f" Instruction sent to model:") + print(f" ---") + print(f" {rendered}") + print(f" ---") + except Exception: + pass # Show function calls made if hasattr(generation, 'tool_calls') and generation.tool_calls: @@ -100,14 +89,11 @@ def get_stock_price(symbol: str) -> str: # Show validation results for req_item, validation in validations: status = "PASS" if validation.as_bool() else "FAIL" - print(f" {status}") - - # For failures, show repair feedback - if not validation.as_bool() and validation.reason and attempt_num < len(result.sample_generations): - print(f" Repair: {validation.reason.split('Rationale:')[1].split('Response_error_span')[0].strip() if 'Rationale:' in validation.reason else validation.reason}") - print() + print(f" Status: {status}") + print(f"\n{'='*60}") print(f"Result: {'SUCCESS' if result.success else 'FAILED'} after {len(result.sample_generations)} attempt(s)") + print(f"{'='*60}") return result diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 853959a6..cf559306 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -1,18 +1,10 @@ -"""Risk checking with Guardian models.""" +"""Risk checking with Granite Guardian models via existing backends.""" -import json from enum import Enum -from typing import Dict, List, Any, Optional, Tuple, Union, Literal +from typing import Dict, List, Optional, Union, Literal, Callable -try: - import requests -except ImportError: - requests = None - -from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.huggingface import LocalHFBackend from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Context, Component, ModelOutputThunk +from mellea.stdlib.base import CBlock, Context, ChatContext from mellea.stdlib.chat import Message from mellea.stdlib.requirement import Requirement, ValidationResult @@ -43,330 +35,57 @@ def get_available_risks(cls) -> List[str]: BackendType = Literal["huggingface", "ollama"] -def _parse_safety_result(result: str, logger) -> str: - """Parse the safety result from model output.""" - if result is None: - logger.warning("Guardian returned None result") +def _parse_safety_result(result: Optional[str], logger) -> str: + """Parse the model output to a Guardian label: Yes/No/Failed.""" + if not result: + logger.warning("Guardian returned empty result") return "Failed" - result_lower = str(result).lower().strip() - - # Handle chat template artifacts - remove any remaining tags - result_lower = result_lower.replace("<|im_end|>", "").replace("<|im_start|>", "").strip() + s = str(result).lower().strip() + s = s.replace("<|im_end|>", "").replace("<|im_start|>", "").strip() - # Handle score tags format: yes/no - if "" in result_lower and "" in result_lower: - score_content = result_lower.split("")[1].split("")[0].strip() - if "yes" in score_content: + # Inline yes/no + if "" in s and "" in s: + inner = s.split("")[1].split("")[0].strip() + if inner.startswith("yes") or " yes" in 
inner: return "Yes" - elif "no" in score_content: + if inner.startswith("no") or " no" in inner: return "No" - # Check for explicit yes/no at start - if result_lower.startswith("yes"): - return "Yes" - elif result_lower.startswith("no"): - return "No" - - # Check for yes/no anywhere in response - if " yes " in result_lower or result_lower.endswith(" yes"): + if s.startswith("yes") or s.endswith(" yes") or " yes " in s: return "Yes" - elif " no " in result_lower or result_lower.endswith(" no"): + if s.startswith("no") or s.endswith(" no") or " no " in s: return "No" - # Check for safety indicators - if any(word in result_lower for word in ["safe", "no risk", "acceptable", "not harmful"]): + if any(w in s for w in ["safe", "no risk", "acceptable", "not harmful"]): return "No" - elif any(word in result_lower for word in ["unsafe", "risky", "harmful", "violation", "inappropriate", "offensive"]): + if any(w in s for w in ["unsafe", "risky", "harmful", "violation", "inappropriate", "offensive"]): return "Yes" - # Special handling for refusal responses (model refusing to evaluate harmful content) - if any(phrase in result_lower for phrase in ["can't engage", "won't engage", "don't engage", "inappropriate", "personal attack"]): - return "Yes" # If model refuses, likely because content is harmful + if any(p in s for p in ["can't engage", "won't engage", "don't engage", "inappropriate", "personal attack"]): + return "Yes" - # Default to failed parsing logger.warning(f"Could not parse safety result: {result}") return "Failed" -class GuardianBackend: - """Abstract base for different Guardian backends.""" - - def __init__(self, model_version: str, device: Optional[str] = None): - self.model_version = model_version - self.device = device - self._logger = FancyLogger.get_logger() - - def validate_messages(self, messages: List[Dict[str, str]], risk: str, - thinking: bool = False, tools: Optional[List[Dict]] = None) -> Tuple[str, Optional[str], Dict]: - """Validate messages for specified risk. 
Returns (result, trace, raw_data).""" - raise NotImplementedError - - def _get_result_sync(self, result_thunk) -> str: - """Get result from ModelOutputThunk synchronously.""" - import asyncio - - try: - # Try direct value access first (might be already resolved) - if hasattr(result_thunk, 'value') and result_thunk.value is not None: - return str(result_thunk.value) - except Exception: - pass - - try: - # Try to get the underlying value directly - if hasattr(result_thunk, '_underlying_value') and result_thunk._underlying_value: - return str(result_thunk._underlying_value) - except Exception: - pass - - try: - # If we have a generation task, wait for it - if hasattr(result_thunk, '_generate') and result_thunk._generate: - # Create new event loop if needed - try: - loop = asyncio.get_running_loop() - # If we're in an async context, this is more complex - raise RuntimeError("In async context - need special handling") - except RuntimeError: - # No running loop or we're in one - use asyncio.run - asyncio.run(result_thunk._generate) - return str(getattr(result_thunk, '_underlying_value', "")) - except Exception as e: - self._logger.warning(f"Async result handling failed: {e}") - - # Final fallback - return str(result_thunk) if result_thunk else "" - - def _prepare_guardian_messages(self, messages: List[Dict[str, str]], risk: str, - thinking: bool, context_text: Optional[str] = None, - tools: Optional[List[Dict]] = None) -> List[Dict[str, str]]: - """Prepare messages in Guardian format exactly like example script.""" - guardian_messages = [] - - # System message contains ONLY the risk type (like example script) - guardian_messages.append({"role": "system", "content": risk}) - - # For groundedness, add document context as separate message (like example script) - if risk == "groundedness" and context_text: - guardian_messages.append({"role": "document 0", "content": context_text}) - - # Add the original conversation messages exactly as provided - guardian_messages.extend(messages) - - # NO additional instruction messages - Guardian model knows what to do - # This matches the example script pattern exactly - - return guardian_messages - - -class HuggingFaceGuardianBackend(GuardianBackend): - """HuggingFace-based Guardian backend that wraps LocalHFBackend.""" - - def __init__(self, model_version: str = "ibm-granite/granite-guardian-3.3-8b", device: Optional[str] = None): - super().__init__(model_version, device) - - # Create custom config if device is specified, otherwise let LocalHFBackend auto-detect - custom_config = None - if device is not None: - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_version) - model = AutoModelForCausalLM.from_pretrained( - model_version, - torch_dtype=torch.bfloat16 - ) - torch_device = torch.device(device) - model = model.to(torch_device) - custom_config = (tokenizer, model, torch_device) - - # Wrap the existing LocalHFBackend - self._hf_backend = LocalHFBackend( - model_id=model_version, - custom_config=custom_config - ) - self._logger.info(f"Initialized HuggingFace Guardian backend with model: {model_version}") - - def validate_messages(self, messages: List[Dict[str, str]], risk: str, - thinking: bool = False, tools: Optional[List[Dict]] = None, - context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: - """Validate messages using wrapped LocalHFBackend with event loop.""" - - # Create async wrapper to handle event loop - async def run_validation(): - # Prepare messages in 
Guardian format (like example script) - guardian_messages = self._prepare_guardian_messages(messages, risk, thinking, context_text, tools) - - # Use the backend's native chat template capabilities - from mellea.stdlib.base import LinearContext, ContextTurn - - ctx = LinearContext() - - # Add all Guardian messages to context - for msg in guardian_messages: - if msg["role"] in ["user", "assistant", "system"]: - ctx.insert_turn(ContextTurn(Message(msg["role"], msg["content"]), None)) - elif msg["role"].startswith("document"): - # Handle document messages for groundedness - ctx.insert_turn(ContextTurn(Message("user", f"Document: {msg['content']}"), None)) - - # Prepare model options - model_options = { - "max_new_tokens": 2000 if thinking else 50, - "do_sample": False, - "temperature": 0.0, - "system": risk # System prompt is just the risk type - } - - if thinking: - model_options["think"] = True - - # Add an empty assistant message to trigger generation - generation_prompt = Message("assistant", "") - - # Use native chat template generation - if hasattr(self._hf_backend, 'generate_from_chat_context'): - result_thunk = self._hf_backend.generate_from_chat_context( - generation_prompt, ctx, model_options=model_options - ) - else: - result_thunk = self._hf_backend.generate_from_context( - generation_prompt, ctx, model_options=model_options - ) - - # Wait for async result - result_value = result_thunk.value - # Handle None or empty results - return str(result_value) if result_value is not None else "" - - # Run the async validation in a new event loop - import asyncio - try: - full_response = asyncio.run(run_validation()) - except Exception as e: - self._logger.error(f"HuggingFace validation failed: {e}") - return "Failed", None, {"error": str(e)} - - # Extract thinking trace if present - trace = None - result = full_response - - if thinking and "" in full_response: - parts = full_response.split("") - if len(parts) > 1: - trace = parts[0].replace("", "").strip() - result = parts[1].strip() - - # Determine safety result - label = _parse_safety_result(result, self._logger) - - return label, trace, {"full_response": full_response, "model": self.model_version} - - - -class OllamaGuardianBackend(GuardianBackend): - """Ollama-based Guardian backend that wraps OllamaModelBackend.""" - - def __init__(self, model_version: str = "ibm/granite3.3-guardian:8b", - ollama_url: str = "http://localhost:11434"): - super().__init__(model_version) - self.ollama_url = ollama_url - - # Wrap the existing OllamaModelBackend - self._ollama_backend = OllamaModelBackend( - model_id=model_version, - base_url=ollama_url - ) - self._logger.info(f"Initialized Ollama Guardian backend with model: {model_version}") - - def validate_messages(self, messages: List[Dict[str, str]], risk: str, - thinking: bool = False, tools: Optional[List[Dict]] = None, - context_text: Optional[str] = None) -> Tuple[str, Optional[str], Dict]: - """Validate messages using wrapped OllamaModelBackend with event loop.""" - - # Create async wrapper to handle event loop - async def run_validation(): - # Prepare messages in Guardian format (like example script) - guardian_messages = self._prepare_guardian_messages(messages, risk, thinking, context_text, tools) - - # Use the backend's native chat template capabilities - from mellea.stdlib.base import LinearContext, ContextTurn - - ctx = LinearContext() - - # Add all Guardian messages to context - for msg in guardian_messages: - if msg["role"] in ["user", "assistant", "system"]: - 
ctx.insert_turn(ContextTurn(Message(msg["role"], msg["content"]), None)) - elif msg["role"].startswith("document"): - # Handle document messages for groundedness - ctx.insert_turn(ContextTurn(Message("user", f"Document: {msg['content']}"), None)) - - # Prepare model options - model_options = { - "temperature": 0.0, - "num_predict": 2000 if thinking else 50, - "stream": False, - "system": risk # System prompt is just the risk type - } - - if thinking: - model_options["think"] = True - - # Add tools for function call validation - if risk == "function_call" and tools: - model_options["tools"] = self._convert_tools_to_functions(tools) - - # Add an empty assistant message to trigger generation - generation_prompt = Message("assistant", "") - - # Use native chat template generation - result_thunk = self._ollama_backend.generate_from_chat_context( - generation_prompt, ctx, model_options=model_options - ) - - # Wait for async result - result_value = result_thunk.value - # Handle None or empty results - return str(result_value) if result_value is not None else "" - - # Run the async validation in a new event loop - import asyncio - try: - full_response = asyncio.run(run_validation()) - except Exception as e: - self._logger.error(f"Ollama validation failed: {e}") - return "Failed", None, {"error": str(e)} - - # Extract thinking trace if present - trace = None - result = full_response - - if thinking and "" in str(full_response): - parts = str(full_response).split("") - if len(parts) > 1: - trace = parts[0].replace("", "").strip() - result = parts[1].strip() +def _dummy_tool_functions(tools: Optional[List[Dict]]) -> Dict[str, Callable]: + """Create simple callable stubs from tool specs for tool-aware backends.""" + funcs: Dict[str, Callable] = {} + if not tools: + return funcs - # Parse safety result - label = _parse_safety_result(result, self._logger) + for spec in tools: + name = spec.get("name", "tool") + desc = spec.get("description", "") - return label, trace, {"full_response": full_response, "model": self.model_version} + def _f(**kwargs): # noqa: ANN001 - generic stub + return None - - def _convert_tools_to_functions(self, tools: List[Dict]) -> List[callable]: - """Convert tool definitions to callable functions for Ollama backend.""" - functions = [] - for tool in tools: - # Create a dummy function that matches the tool signature - def dummy_func(**kwargs): - return f"Tool {tool['name']} called with args: {kwargs}" - - dummy_func.__name__ = tool['name'] - dummy_func.__doc__ = tool.get('description', '') - functions.append(dummy_func) - - return functions + _f.__name__ = name + _f.__doc__ = desc + funcs[name] = _f + return funcs class GuardianCheck(Requirement): @@ -385,24 +104,8 @@ def __init__( context_text: Optional[str] = None, tools: Optional[List[Dict]] = None, ): - """Initialize GuardianCheck with enhanced Granite Guardian 3.3 8B support. - - Args: - risk: The risk type to check for. Can be GuardianRisk enum or custom string. - backend_type: Backend to use - "huggingface" for local inference or "ollama" for Ollama. - model_version: Model version to use. Defaults based on backend: - - HuggingFace: "ibm-granite/granite-guardian-3.0-8b" - - Ollama: "ibm/granite3.3-guardian:8b" - device: Computational device ("cuda"/"mps"/"cpu"). Auto-selected if None. - ollama_url: Ollama server URL for ollama backend. - thinking: Enable thinking mode for detailed reasoning traces. - custom_criteria: Custom risk criteria string (overrides standard risk types). 
- context_text: Reference text for groundedness/context relevance checking. - tools: Available tools for function call validation. - """ - super().__init__( - check_only=True, validation_fn=lambda c: self._guardian_validate(c) - ) + """Initialize GuardianCheck using existing backends with minimal glue.""" + super().__init__(check_only=True, validation_fn=lambda c: self._guardian_validate(c)) # Handle risk specification with custom criteria priority if custom_criteria: @@ -428,21 +131,29 @@ def __init__( self._context_text = context_text self._tools = tools - # Set default model versions based on backend + # Choose defaults and initialize the chosen backend directly. if model_version is None: - if backend_type == "huggingface": - model_version = "ibm-granite/granite-guardian-3.3-8b" - else: # ollama - model_version = "ibm/granite3.3-guardian:8b" + model_version = ( + "ibm-granite/granite-guardian-3.3-8b" + if backend_type == "huggingface" + else "ibm/granite3.3-guardian:8b" + ) - # Initialize backend if backend_type == "huggingface": - self._backend = HuggingFaceGuardianBackend(model_version, device) + from mellea.backends.huggingface import LocalHFBackend + self._backend = LocalHFBackend(model_id=model_version) elif backend_type == "ollama": - self._backend = OllamaGuardianBackend(model_version, ollama_url) + from mellea.backends.ollama import OllamaModelBackend + self._backend = OllamaModelBackend(model_id=model_version, base_url=ollama_url) else: raise ValueError(f"Unsupported backend type: {backend_type}") + # Provide a predictable attribute for the example to print. + try: + setattr(self._backend, "model_version", model_version) + except Exception: + pass + self._logger = FancyLogger.get_logger() def get_effective_risk(self) -> str: @@ -451,101 +162,149 @@ def get_effective_risk(self) -> str: def supports_thinking_mode(self) -> bool: """Check if current backend supports thinking mode.""" - return True # Both backends now support thinking mode + # Thinking is supported for Ollama backends; other backends may ignore it. + return True @classmethod def get_available_risks(cls) -> List[str]: """Get list of all available standard risk types.""" return GuardianRisk.get_available_risks() - def _guardian_validate(self, ctx: Context) -> ValidationResult: - """Enhanced validation using Granite Guardian 3.3 8B with thinking mode support. - - Args: - ctx (LegacyContext): The context object containing the last turn of the conversation. - - Returns: - ValidationResult: Validation result with optional reasoning trace. 
- """ - - messages: List[Dict[str, str]] = [] - - last_turn = ctx.last_turn() - if last_turn is None: - self._logger.warning("No last turn found in context") - return ValidationResult(False, reason="No content to validate") - - # Extract messages from context - if last_turn.model_input: - user_msg = last_turn.model_input - - if isinstance(user_msg, CBlock) and user_msg.value is not None: - messages.append({"role": "user", "content": user_msg.value}) - elif isinstance(user_msg, Message) and user_msg.content != "": - messages.append({"role": user_msg.role, "content": user_msg.content}) + def __deepcopy__(self, memo): + """Custom deepcopy to handle unpicklable backend objects.""" + from copy import deepcopy + # Create a new instance without calling __init__ + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + # Copy all attributes except the backend (which contains locks) + for k, v in self.__dict__.items(): + if k == '_backend': + # Share the backend reference instead of copying it + setattr(result, k, v) + elif k == '_logger': + # Share the logger reference + setattr(result, k, v) else: - messages.append({"role": "user", "content": str(user_msg)}) - - # Handle both text content and function calls - if last_turn.output: - assistant_content = "" - - # Add text content if available - if last_turn.output.value: - assistant_content = last_turn.output.value + setattr(result, k, deepcopy(v, memo)) + return result - # Add function call information for FUNCTION_CALL risk validation - if (hasattr(last_turn.output, 'tool_calls') and last_turn.output.tool_calls and - self._risk == "function_call"): - - # Convert function calls to a text format that Guardian can validate - function_calls_text = [] - for name, tool_call in last_turn.output.tool_calls.items(): - call_info = f'{name}({tool_call.args})' - function_calls_text.append(call_info) + def _guardian_validate(self, ctx: Context) -> ValidationResult: + """Validate the last turn using Granite Guardian via selected backend.""" + import asyncio + from concurrent.futures import ThreadPoolExecutor - function_calls_str = ', '.join(function_calls_text) + # Define async validation logic + async def _async_validate(): + logger = self._logger - if assistant_content: - assistant_content += f" [Function calls: {function_calls_str}]" - else: - assistant_content = f"[Function calls: {function_calls_str}]" + last_turn = ctx.last_turn() + if last_turn is None: + logger.warning("No last turn found in context") + return ValidationResult(False, reason="No content to validate") - if assistant_content: - messages.append({"role": "assistant", "content": assistant_content}) + # Build a fresh chat context for the guardian model. + gctx = ChatContext() - if not messages: - self._logger.warning("No messages found to validate") - return ValidationResult(False, reason="No messages to validate") + effective_risk = self.get_effective_risk() - # Use the appropriate risk criteria - effective_risk = self.get_effective_risk() + if (self._risk == "groundedness" or effective_risk == "groundedness") and self._context_text: + gctx = gctx.add(Message("user", f"Document: {self._context_text}")) - try: - # Validate using the backend - label, trace, raw_data = self._backend.validate_messages( - messages, effective_risk, self._thinking, tools=self._tools, context_text=self._context_text + # Add the last user message if present. 
+ if last_turn.model_input is not None: + if isinstance(last_turn.model_input, CBlock) and last_turn.model_input.value is not None: + gctx = gctx.add(Message("user", last_turn.model_input.value)) + elif isinstance(last_turn.model_input, Message): + gctx = gctx.add(Message(last_turn.model_input.role, last_turn.model_input.content)) + else: + gctx = gctx.add(Message("user", str(last_turn.model_input))) + + # Add the assistant response, optionally including tool call info for function_call risk. + if last_turn.output is not None: + assistant_text = last_turn.output.value or "" + if getattr(last_turn.output, "tool_calls", None) and (self._risk == "function_call" or effective_risk == "function_call"): + calls = [] + for name, tc in last_turn.output.tool_calls.items(): + calls.append(f"{name}({getattr(tc, 'args', {})})") + if calls: + suffix = f" [Function calls: {', '.join(calls)}]" + assistant_text = (assistant_text + suffix) if assistant_text else suffix + if assistant_text: + gctx = gctx.add(Message("assistant", assistant_text)) + + # Ensure we have something to validate. + history = gctx.view_for_generation() or [] + if len(history) == 0: + logger.warning("No messages found to validate") + return ValidationResult(False, reason="No messages to validate") + + # Backend options (mapped by backends internally to their specific keys). + model_options: Dict[str, object] = {} + if self._backend_type == "ollama": + # Ollama templates expect the risk as the system prompt + model_options["system"] = effective_risk + model_options.update({ + "temperature": 0.0, + "num_predict": 4000 if self._thinking else 50, + "stream": False, + "think": True if self._thinking else None, + }) + else: # huggingface + # HF chat template for guardian expects guardian_config instead of a system message + guardian_cfg: Dict[str, object] = {"risk": effective_risk} + if self._custom_criteria: + guardian_cfg["custom_criteria"] = self._custom_criteria + if self._context_text and (self._risk == "groundedness" or effective_risk == "groundedness"): + guardian_cfg["context"] = self._context_text + + model_options.update({ + "guardian_config": guardian_cfg, + "max_new_tokens": 4000 if self._thinking else 50, + "stream": False, + }) + + # Attach tools for function_call checks (as callable stubs). + if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: + model_options["tools"] = list(_dummy_tool_functions(self._tools).values()) + + # Generate the guardian decision with a blank assistant turn. + mot, _ = self._backend.generate_from_context( + Message("assistant", ""), gctx, model_options=model_options ) - - # Log the validation details - self._logger.debug(f"Guardian validation - Risk: {effective_risk}, Result: {label}") - if trace and self._thinking: - self._logger.debug(f"Guardian reasoning: {trace}") - - # Determine validation result + await mot.avalue() + + # Prefer explicit thinking if available, else try to split from output text. 
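+            # Example shape of a thinking-mode response (illustrative; the tag
+            # names follow Granite Guardian 3.3's chat template and are an
+            # assumption here):
+            #   "<think>The reply insults the user ...</think> <score> yes </score>"
+            # Everything before the closing think tag becomes the trace; the
+            # remainder is parsed for the yes/no verdict below.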
+ trace = getattr(mot, "_thinking", None) + text = mot.value or "" + if trace is None and "" in text: + parts = text.split("") + if len(parts) > 1: + trace = parts[0].replace("", "").strip() + text = parts[1].strip() + + label = _parse_safety_result(text, logger) is_safe = label == "No" - # Create detailed reason reason_parts = [f"Guardian check for '{effective_risk}': {label}"] - if trace: reason_parts.append(f"Reasoning: {trace}") - return ValidationResult( - result=is_safe, - reason="; ".join(reason_parts) - ) + return ValidationResult(result=is_safe, reason="; ".join(reason_parts), thunk=mot) + + # Run the async code using the same pattern as mellea's _run_async_in_thread + def run_async(co): + return asyncio.run(co) + + loop = None + try: + loop = asyncio.get_running_loop() + except Exception: + pass - except Exception as e: - self._logger.error(f"Guardian validation failed: {e}") - return ValidationResult(False, reason=f"Validation error: {str(e)}") + if loop is None: + return run_async(_async_validate()) + else: + with ThreadPoolExecutor(max_workers=1) as exec: + future = exec.submit(run_async, _async_validate()) + return future.result() From c5f7a64e6849a9fa3276b74879360319ef5599da Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 12:15:17 -0400 Subject: [PATCH 06/19] cleanup --- docs/examples/safety.py/guardian.py | 4 -- mellea/stdlib/safety/guardian.py | 83 ++++++----------------------- 2 files changed, 17 insertions(+), 70 deletions(-) diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index caeba62a..d6e2b262 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -59,10 +59,6 @@ print(f"\nConfigured guardians: {len(guardians)} total") -# Show thinking mode capability -thinking_guardian = GuardianCheck(GuardianRisk.HARM, thinking=True) -print(f" Thinking mode supported: {thinking_guardian.supports_thinking_mode()}") - # Show Ollama backend configuration ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama") print(f" Ollama backend: {ollama_guardian._backend.model_version}") diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index cf559306..f46fb399 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -1,11 +1,12 @@ """Risk checking with Granite Guardian models via existing backends.""" from enum import Enum -from typing import Dict, List, Optional, Union, Literal, Callable +from typing import Dict, List, Optional, Union, Literal from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import CBlock, Context, ChatContext from mellea.stdlib.chat import Message +from mellea.stdlib.funcs import _run_async_in_thread from mellea.stdlib.requirement import Requirement, ValidationResult @@ -36,58 +37,29 @@ def get_available_risks(cls) -> List[str]: def _parse_safety_result(result: Optional[str], logger) -> str: - """Parse the model output to a Guardian label: Yes/No/Failed.""" + """Parse the model output to a Guardian label: Yes/No/Failed. + + Guardian returns yes/no between and tags. + Handles case variations (Yes/yes, No/no) and whitespace. 
+ """ if not result: logger.warning("Guardian returned empty result") return "Failed" - s = str(result).lower().strip() - s = s.replace("<|im_end|>", "").replace("<|im_start|>", "").strip() + s = str(result).lower() - # Inline yes/no + # Extract from yes/no tags if "" in s and "" in s: - inner = s.split("")[1].split("")[0].strip() - if inner.startswith("yes") or " yes" in inner: + score = s.split("")[1].split("")[0].strip() + if score == "yes": return "Yes" - if inner.startswith("no") or " no" in inner: + if score == "no": return "No" - if s.startswith("yes") or s.endswith(" yes") or " yes " in s: - return "Yes" - if s.startswith("no") or s.endswith(" no") or " no " in s: - return "No" - - if any(w in s for w in ["safe", "no risk", "acceptable", "not harmful"]): - return "No" - if any(w in s for w in ["unsafe", "risky", "harmful", "violation", "inappropriate", "offensive"]): - return "Yes" - - if any(p in s for p in ["can't engage", "won't engage", "don't engage", "inappropriate", "personal attack"]): - return "Yes" - logger.warning(f"Could not parse safety result: {result}") return "Failed" -def _dummy_tool_functions(tools: Optional[List[Dict]]) -> Dict[str, Callable]: - """Create simple callable stubs from tool specs for tool-aware backends.""" - funcs: Dict[str, Callable] = {} - if not tools: - return funcs - - for spec in tools: - name = spec.get("name", "tool") - desc = spec.get("description", "") - - def _f(**kwargs): # noqa: ANN001 - generic stub - return None - - _f.__name__ = name - _f.__doc__ = desc - funcs[name] = _f - return funcs - - class GuardianCheck(Requirement): """Enhanced risk checking using Granite Guardian 3.3 8B with multiple backend support.""" @@ -160,11 +132,6 @@ def get_effective_risk(self) -> str: """Get the effective risk criteria to use for validation.""" return self._custom_criteria if self._custom_criteria else self._risk - def supports_thinking_mode(self) -> bool: - """Check if current backend supports thinking mode.""" - # Thinking is supported for Ollama backends; other backends may ignore it. - return True - @classmethod def get_available_risks(cls) -> List[str]: """Get list of all available standard risk types.""" @@ -191,9 +158,6 @@ def __deepcopy__(self, memo): def _guardian_validate(self, ctx: Context) -> ValidationResult: """Validate the last turn using Granite Guardian via selected backend.""" - import asyncio - from concurrent.futures import ThreadPoolExecutor - # Define async validation logic async def _async_validate(): logger = self._logger @@ -264,9 +228,10 @@ async def _async_validate(): "stream": False, }) - # Attach tools for function_call checks (as callable stubs). + # Attach tools for function_call checks. + # Guardian only needs tool schemas for validation, not actual callable functions. if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: - model_options["tools"] = list(_dummy_tool_functions(self._tools).values()) + model_options["tools"] = self._tools # Generate the guardian decision with a blank assistant turn. 
mot, _ = self._backend.generate_from_context( @@ -292,19 +257,5 @@ async def _async_validate(): return ValidationResult(result=is_safe, reason="; ".join(reason_parts), thunk=mot) - # Run the async code using the same pattern as mellea's _run_async_in_thread - def run_async(co): - return asyncio.run(co) - - loop = None - try: - loop = asyncio.get_running_loop() - except Exception: - pass - - if loop is None: - return run_async(_async_validate()) - else: - with ThreadPoolExecutor(max_workers=1) as exec: - future = exec.submit(run_async, _async_validate()) - return future.result() + # Run the async validation using mellea's standard pattern + return _run_async_in_thread(_async_validate()) From a6930d7a2425d320a845bf8f6c4fd6848414da7a Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 12:28:29 -0400 Subject: [PATCH 07/19] cleanup --- docs/examples/safety.py/guardian.py | 2 +- mellea/stdlib/safety/guardian.py | 200 ++++++++++++++-------------- 2 files changed, 102 insertions(+), 100 deletions(-) diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index d6e2b262..fa9aab4b 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -25,7 +25,7 @@ print("\n=== Test 1: Normal Content ===") # Run a query against an Ollama model with ChatContext to support insert_turn m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -m.chat("Write a professional email to your colleague. Use less than 100 words.") +m.chat("Write a professional email to your colleague. Use less than 50 words.") # Validate on the last turn of the conversation valid = m.validate([guardian]) diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index f46fb399..755f7a4e 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -3,10 +3,10 @@ from enum import Enum from typing import Dict, List, Optional, Union, Literal +from mellea.backends import Backend, BaseModelSubclass from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import CBlock, Context, ChatContext from mellea.stdlib.chat import Message -from mellea.stdlib.funcs import _run_async_in_thread from mellea.stdlib.requirement import Requirement, ValidationResult @@ -77,7 +77,7 @@ def __init__( tools: Optional[List[Dict]] = None, ): """Initialize GuardianCheck using existing backends with minimal glue.""" - super().__init__(check_only=True, validation_fn=lambda c: self._guardian_validate(c)) + super().__init__(check_only=True) # Handle risk specification with custom criteria priority if custom_criteria: @@ -156,106 +156,108 @@ def __deepcopy__(self, memo): setattr(result, k, deepcopy(v, memo)) return result - def _guardian_validate(self, ctx: Context) -> ValidationResult: + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + ) -> ValidationResult: """Validate the last turn using Granite Guardian via selected backend.""" - # Define async validation logic - async def _async_validate(): - logger = self._logger - - last_turn = ctx.last_turn() - if last_turn is None: - logger.warning("No last turn found in context") - return ValidationResult(False, reason="No content to validate") - - # Build a fresh chat context for the guardian model. 
- gctx = ChatContext() - - effective_risk = self.get_effective_risk() - - if (self._risk == "groundedness" or effective_risk == "groundedness") and self._context_text: - gctx = gctx.add(Message("user", f"Document: {self._context_text}")) - - # Add the last user message if present. - if last_turn.model_input is not None: - if isinstance(last_turn.model_input, CBlock) and last_turn.model_input.value is not None: - gctx = gctx.add(Message("user", last_turn.model_input.value)) - elif isinstance(last_turn.model_input, Message): - gctx = gctx.add(Message(last_turn.model_input.role, last_turn.model_input.content)) - else: - gctx = gctx.add(Message("user", str(last_turn.model_input))) - - # Add the assistant response, optionally including tool call info for function_call risk. - if last_turn.output is not None: - assistant_text = last_turn.output.value or "" - if getattr(last_turn.output, "tool_calls", None) and (self._risk == "function_call" or effective_risk == "function_call"): - calls = [] - for name, tc in last_turn.output.tool_calls.items(): - calls.append(f"{name}({getattr(tc, 'args', {})})") - if calls: - suffix = f" [Function calls: {', '.join(calls)}]" - assistant_text = (assistant_text + suffix) if assistant_text else suffix - if assistant_text: - gctx = gctx.add(Message("assistant", assistant_text)) - - # Ensure we have something to validate. - history = gctx.view_for_generation() or [] - if len(history) == 0: - logger.warning("No messages found to validate") - return ValidationResult(False, reason="No messages to validate") - - # Backend options (mapped by backends internally to their specific keys). - model_options: Dict[str, object] = {} - if self._backend_type == "ollama": - # Ollama templates expect the risk as the system prompt - model_options["system"] = effective_risk - model_options.update({ - "temperature": 0.0, - "num_predict": 4000 if self._thinking else 50, - "stream": False, - "think": True if self._thinking else None, - }) - else: # huggingface - # HF chat template for guardian expects guardian_config instead of a system message - guardian_cfg: Dict[str, object] = {"risk": effective_risk} - if self._custom_criteria: - guardian_cfg["custom_criteria"] = self._custom_criteria - if self._context_text and (self._risk == "groundedness" or effective_risk == "groundedness"): - guardian_cfg["context"] = self._context_text - - model_options.update({ - "guardian_config": guardian_cfg, - "max_new_tokens": 4000 if self._thinking else 50, - "stream": False, - }) - - # Attach tools for function_call checks. - # Guardian only needs tool schemas for validation, not actual callable functions. - if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: - model_options["tools"] = self._tools - - # Generate the guardian decision with a blank assistant turn. - mot, _ = self._backend.generate_from_context( - Message("assistant", ""), gctx, model_options=model_options - ) - await mot.avalue() + logger = self._logger - # Prefer explicit thinking if available, else try to split from output text. 
- trace = getattr(mot, "_thinking", None) - text = mot.value or "" - if trace is None and "" in text: - parts = text.split("") - if len(parts) > 1: - trace = parts[0].replace("", "").strip() - text = parts[1].strip() + last_turn = ctx.last_turn() + if last_turn is None: + logger.warning("No last turn found in context") + return ValidationResult(False, reason="No content to validate") - label = _parse_safety_result(text, logger) - is_safe = label == "No" + # Build a fresh chat context for the guardian model. + gctx = ChatContext() - reason_parts = [f"Guardian check for '{effective_risk}': {label}"] - if trace: - reason_parts.append(f"Reasoning: {trace}") + effective_risk = self.get_effective_risk() - return ValidationResult(result=is_safe, reason="; ".join(reason_parts), thunk=mot) + if (self._risk == "groundedness" or effective_risk == "groundedness") and self._context_text: + gctx = gctx.add(Message("user", f"Document: {self._context_text}")) - # Run the async validation using mellea's standard pattern - return _run_async_in_thread(_async_validate()) + # Add the last user message if present. + if last_turn.model_input is not None: + if isinstance(last_turn.model_input, CBlock) and last_turn.model_input.value is not None: + gctx = gctx.add(Message("user", last_turn.model_input.value)) + elif isinstance(last_turn.model_input, Message): + gctx = gctx.add(Message(last_turn.model_input.role, last_turn.model_input.content)) + else: + gctx = gctx.add(Message("user", str(last_turn.model_input))) + + # Add the assistant response, optionally including tool call info for function_call risk. + if last_turn.output is not None: + assistant_text = last_turn.output.value or "" + if getattr(last_turn.output, "tool_calls", None) and (self._risk == "function_call" or effective_risk == "function_call"): + calls = [] + for name, tc in last_turn.output.tool_calls.items(): + calls.append(f"{name}({getattr(tc, 'args', {})})") + if calls: + suffix = f" [Function calls: {', '.join(calls)}]" + assistant_text = (assistant_text + suffix) if assistant_text else suffix + if assistant_text: + gctx = gctx.add(Message("assistant", assistant_text)) + + # Ensure we have something to validate. + history = gctx.view_for_generation() or [] + if len(history) == 0: + logger.warning("No messages found to validate") + return ValidationResult(False, reason="No messages to validate") + + # Backend options (mapped by backends internally to their specific keys). + guardian_options: Dict[str, object] = {} + if self._backend_type == "ollama": + # Ollama templates expect the risk as the system prompt + guardian_options["system"] = effective_risk + guardian_options.update({ + "temperature": 0.0, + "num_predict": 4000 if self._thinking else 50, + "stream": False, + "think": True if self._thinking else None, + }) + else: # huggingface + # HF chat template for guardian expects guardian_config instead of a system message + guardian_cfg: Dict[str, object] = {"risk": effective_risk} + if self._custom_criteria: + guardian_cfg["custom_criteria"] = self._custom_criteria + if self._context_text and (self._risk == "groundedness" or effective_risk == "groundedness"): + guardian_cfg["context"] = self._context_text + + guardian_options.update({ + "guardian_config": guardian_cfg, + "max_new_tokens": 4000 if self._thinking else 50, + "stream": False, + }) + + # Attach tools for function_call checks. + # Guardian only needs tool schemas for validation, not actual callable functions. 
+ if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: + guardian_options["tools"] = self._tools + + # Generate the guardian decision with a blank assistant turn. + mot, _ = self._backend.generate_from_context( + Message("assistant", ""), gctx, model_options=guardian_options + ) + await mot.avalue() + + # Prefer explicit thinking if available, else try to split from output text. + trace = getattr(mot, "_thinking", None) + text = mot.value or "" + if trace is None and "" in text: + parts = text.split("") + if len(parts) > 1: + trace = parts[0].replace("", "").strip() + text = parts[1].strip() + + label = _parse_safety_result(text, logger) + is_safe = label == "No" + + reason_parts = [f"Guardian check for '{effective_risk}': {label}"] + if trace: + reason_parts.append(f"Reasoning: {trace}") + + return ValidationResult(result=is_safe, reason="; ".join(reason_parts), thunk=mot) From 6c00afe1fa8da999f3f59dc08e3fa3e0eca3a541 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 12:34:14 -0400 Subject: [PATCH 08/19] fix fc example. --- docs/examples/safety.py/guardian.py | 37 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index fa9aab4b..e82b5ab6 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -90,7 +90,7 @@ print("\n=== Test 5: Function Call Hallucination Detection ===") # Test function calling hallucination using IBM video example -import json +from mellea.stdlib.base import ModelOutputThunk, ModelToolCall tools = [ { @@ -101,11 +101,6 @@ "description": "The ID of the IBM video.", "type": "int", "default": "7178094165614464282" - }, - "count": { - "description": "The number of comments to fetch. Maximum is 30. Defaults to 20.", - "type": "int, optional", - "default": "20" } } } @@ -119,23 +114,27 @@ ) # User asks for views but assistant calls wrong function (comments_list instead of views_list) -response_data = [ - { - "name": "comments_list", - "arguments": { - "video_id": 456789123, - "count": 15 - } - } -] -hallucinated_response = str(response_data) +# Create a proper ModelOutputThunk with tool_calls +def dummy_func(**kwargs): + pass + +hallucinated_tool_calls = { + "comments_list": ModelToolCall( + name="comments_list", + func=dummy_func, + args={"video_id": 456789123, "count": 15} + ) +} + +hallucinated_output = ModelOutputThunk( + value="I'll fetch the views for you.", + tool_calls=hallucinated_tool_calls +) function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) function_session.ctx = function_session.ctx.add( Message("user", "Fetch total views for the IBM video with ID 456789123.") -).add( - Message("assistant", hallucinated_response) -) +).add(hallucinated_output) print("Testing response with function call hallucination...") function_valid = function_session.validate([function_guardian]) From ebf325f86c494316eedae2adb85e3f27f3b556c7 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 12:45:12 -0400 Subject: [PATCH 09/19] fix hf example. 
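A condensed sketch of the session-based validation flow this patch moves the HuggingFace example to. Names are taken from the diff below; the snippet is illustrative rather than the exact example file:

    from mellea import MelleaSession
    from mellea.backends import model_ids
    from mellea.backends.ollama import OllamaModelBackend
    from mellea.stdlib.base import ChatContext
    from mellea.stdlib.chat import Message
    from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

    # Guardian itself runs on the HuggingFace backend; the session backend only
    # supplies the conversation that gets judged.
    hf_guardian = GuardianCheck(GuardianRisk.HARM, thinking=False, backend_type="huggingface")

    session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
    session.ctx = (
        ChatContext()
        .add(Message("user", "Hello"))
        .add(Message("assistant", "You are very ugly and stupid."))
    )

    result = session.validate([hf_guardian])
    print("Content is safe:", result[0]._result)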
--- .../safety.py/guardian_huggingface.py | 114 ++++++++++++++---- 1 file changed, 90 insertions(+), 24 deletions(-) diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py index 4c196fa8..13086aa1 100644 --- a/docs/examples/safety.py/guardian_huggingface.py +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -1,6 +1,9 @@ """Example of using GuardianCheck with HuggingFace backend for direct model inference""" -from mellea.stdlib.base import ContextTurn, ModelOutputThunk +from mellea import MelleaSession +from mellea.backends import model_ids +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk @@ -13,34 +16,97 @@ print(f"Backend: {hf_guardian._backend.model_version}") print(f"Thinking mode: {hf_guardian._thinking}") -# Test the "ugly and stupid" example directly with Guardian -print("\n=== Direct Guardian Validation Test ===") +# Test 1: Harmful content detection +print("\n=== Test 1: Harmful Content Detection ===") print("Testing content: 'You are very ugly and stupid.'") -try: - # Create a minimal context for validation - from mellea.stdlib.base import LinearContext - - # Create context with the harmful content - ctx = LinearContext() - ctx.insert_turn( - ContextTurn( - Message("user", "Hello"), - ModelOutputThunk("You are very ugly and stupid.") - ) - ) +harm_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +harm_session.ctx = harm_session.ctx.add( + Message("user", "Hello") +).add( + Message("assistant", "You are very ugly and stupid.") +) + +validation_result = harm_session.validate([hf_guardian]) + +print(f"Content is safe: {validation_result[0]._result}") +print(f"Guardian detected harm: {not validation_result[0]._result}") + +if validation_result[0]._reason: + print(f"\nGuardian feedback:") + print(validation_result[0]._reason[:200] + "...") + +# Test 2: Groundedness detection with HuggingFace +print("\n=== Test 2: Groundedness Detection ===") +context_text = "Python is a high-level programming language created by Guido van Rossum in 1991." 
+ +groundedness_guardian = GuardianCheck( + GuardianRisk.GROUNDEDNESS, + thinking=False, + backend_type="huggingface", + context_text=context_text +) + +groundedness_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +groundedness_session.ctx = groundedness_session.ctx.add( + Message("user", "Who created Python?") +).add( + Message("assistant", "Python was created by Dennis Ritchie in 1972 for use in Unix systems.") +) - # Run Guardian validation directly - validation_result = hf_guardian._guardian_validate(ctx) +groundedness_valid = groundedness_session.validate([groundedness_guardian]) +print(f"Response is grounded: {groundedness_valid[0]._result}") +if groundedness_valid[0]._reason: + print(f"Groundedness feedback: {groundedness_valid[0]._reason[:200]}...") + +# Test 3: Function call validation with HuggingFace +print("\n=== Test 3: Function Call Validation ===") + +tools = [ + { + "name": "get_weather", + "description": "Gets weather for a location", + "parameters": { + "location": { + "description": "City name", + "type": "string" + } + } + } +] + +function_guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, + thinking=False, + backend_type="huggingface", + tools=tools +) + +# User asks for weather but model calls wrong function +def dummy_func(**kwargs): + pass + +hallucinated_tool_calls = { + "get_stock_price": ModelToolCall( + name="get_stock_price", + func=dummy_func, + args={"symbol": "AAPL"} + ) +} - print(f"Content is safe: {validation_result.as_bool()}") - print(f"Guardian detected harm: {not validation_result.as_bool()}") +hallucinated_output = ModelOutputThunk( + value="Let me get the weather for you.", + tool_calls=hallucinated_tool_calls +) - if validation_result.reason: - print(f"\nGuardian reasoning:") - print(validation_result.reason) +function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +function_session.ctx = function_session.ctx.add( + Message("user", "What's the weather in Boston?") +).add(hallucinated_output) -except Exception as e: - print(f"HuggingFace backend test failed: {e}") +function_valid = function_session.validate([function_guardian]) +print(f"Function calls are valid: {function_valid[0]._result}") +if function_valid[0]._reason: + print(f"Function call feedback: {function_valid[0]._reason[:200]}...") print("\n=== HuggingFace Guardian Demo Complete ===") \ No newline at end of file From 8635b86c3c1bd6321a8bc6e88ef8016cb1963df9 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 12:51:53 -0400 Subject: [PATCH 10/19] guardian_config as passthrough in hf backend. --- mellea/backends/huggingface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index ae8bb249..cfb727da 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -131,7 +131,10 @@ def __init__( # Usually, values that are intentionally extracted while prepping for the backend generate call # will be omitted here so that they will be removed when model_options are processed # for the call to the model. 
- self.from_mellea_model_opts_map = {ModelOption.MAX_NEW_TOKENS: "max_new_tokens"} + self.from_mellea_model_opts_map = { + ModelOption.MAX_NEW_TOKENS: "max_new_tokens", + "guardian_config": "guardian_config", # Pass through for Granite Guardian models + } self.default_to_constraint_checking_alora = default_to_constraint_checking_alora From 413cde1644c46e4ad4921cbc30f1e9d8e565c263 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 13:00:33 -0400 Subject: [PATCH 11/19] guardian_config as passthrough in hf backend. --- .../safety.py/guardian_huggingface.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py index 13086aa1..88c184ae 100644 --- a/docs/examples/safety.py/guardian_huggingface.py +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -1,4 +1,9 @@ -"""Example of using GuardianCheck with HuggingFace backend for direct model inference""" +"""Example of using GuardianCheck with HuggingFace backend for direct model inference + +NOTE: This example reuses a single session to avoid loading multiple models into memory. +We only need to load one model (for generating test content), and the Guardian validators +share their backend instances to minimize memory usage. +""" from mellea import MelleaSession from mellea.backends import model_ids @@ -8,27 +13,29 @@ from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk print("=== GuardianCheck HuggingFace Backend Example ===") +print("NOTE: Loading Granite Guardian model (this may take a while on first run)...\n") -# Test harmful content detection using HuggingFace backend -print("Testing harm detection with HuggingFace backend...") -hf_guardian = GuardianCheck(GuardianRisk.HARM, thinking=False, backend_type="huggingface") +# Create a single reusable session with Ollama backend for test content generation +# We'll reuse this session and just reset its context for each test +session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -print(f"Backend: {hf_guardian._backend.model_version}") -print(f"Thinking mode: {hf_guardian._thinking}") +# Create Guardian validators (these share the same HF backend internally via __deepcopy__) +print("Initializing Guardian validators...") +harm_guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="huggingface") +print(f"Backend: {harm_guardian._backend.model_version}") # Test 1: Harmful content detection print("\n=== Test 1: Harmful Content Detection ===") print("Testing content: 'You are very ugly and stupid.'") -harm_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -harm_session.ctx = harm_session.ctx.add( +# Reset context for this test +session.ctx = ChatContext().add( Message("user", "Hello") ).add( Message("assistant", "You are very ugly and stupid.") ) -validation_result = harm_session.validate([hf_guardian]) - +validation_result = session.validate([harm_guardian]) print(f"Content is safe: {validation_result[0]._result}") print(f"Guardian detected harm: {not validation_result[0]._result}") @@ -36,10 +43,11 @@ print(f"\nGuardian feedback:") print(validation_result[0]._reason[:200] + "...") -# Test 2: Groundedness detection with HuggingFace +# Test 2: Groundedness detection print("\n=== Test 2: Groundedness Detection ===") context_text = "Python is a high-level programming language created by Guido van Rossum in 1991." 
+# Create groundedness guardian with context groundedness_guardian = GuardianCheck( GuardianRisk.GROUNDEDNESS, thinking=False, @@ -47,19 +55,19 @@ context_text=context_text ) -groundedness_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -groundedness_session.ctx = groundedness_session.ctx.add( +# Reset context with ungrounded response +session.ctx = ChatContext().add( Message("user", "Who created Python?") ).add( Message("assistant", "Python was created by Dennis Ritchie in 1972 for use in Unix systems.") ) -groundedness_valid = groundedness_session.validate([groundedness_guardian]) +groundedness_valid = session.validate([groundedness_guardian]) print(f"Response is grounded: {groundedness_valid[0]._result}") if groundedness_valid[0]._reason: print(f"Groundedness feedback: {groundedness_valid[0]._reason[:200]}...") -# Test 3: Function call validation with HuggingFace +# Test 3: Function call validation print("\n=== Test 3: Function Call Validation ===") tools = [ @@ -99,12 +107,12 @@ def dummy_func(**kwargs): tool_calls=hallucinated_tool_calls ) -function_session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -function_session.ctx = function_session.ctx.add( +# Reset context with hallucinated function call +session.ctx = ChatContext().add( Message("user", "What's the weather in Boston?") ).add(hallucinated_output) -function_valid = function_session.validate([function_guardian]) +function_valid = session.validate([function_guardian]) print(f"Function calls are valid: {function_valid[0]._result}") if function_valid[0]._reason: print(f"Function call feedback: {function_valid[0]._reason[:200]}...") From c17a981197ec22d8f4effa8356a83fee562d0e85 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 13:09:42 -0400 Subject: [PATCH 12/19] simplelr gg hf example. --- .../safety.py/guardian_huggingface.py | 36 +++++++---- mellea/stdlib/safety/guardian.py | 60 ++++++++++++------- 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py index 88c184ae..eae9bc71 100644 --- a/docs/examples/safety.py/guardian_huggingface.py +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -1,28 +1,35 @@ """Example of using GuardianCheck with HuggingFace backend for direct model inference -NOTE: This example reuses a single session to avoid loading multiple models into memory. -We only need to load one model (for generating test content), and the Guardian validators -share their backend instances to minimize memory usage. +This example shows how to reuse the Guardian backend across multiple validators +to avoid reloading the model multiple times. 
""" from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk print("=== GuardianCheck HuggingFace Backend Example ===") -print("NOTE: Loading Granite Guardian model (this may take a while on first run)...\n") # Create a single reusable session with Ollama backend for test content generation # We'll reuse this session and just reset its context for each test session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) -# Create Guardian validators (these share the same HF backend internally via __deepcopy__) -print("Initializing Guardian validators...") -harm_guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="huggingface") -print(f"Backend: {harm_guardian._backend.model_version}") +# Create a single shared HuggingFace backend for Guardian (loads model once) +print("Loading Granite Guardian model (this happens only once)...") +shared_guardian_backend = LocalHFBackend(model_id="ibm-granite/granite-guardian-3.3-8b") +print(f"Loaded backend: {shared_guardian_backend.model_id}\n") + +# Create Guardian validators that share the backend (no model reloading!) +print("Creating harm guardian...") +harm_guardian = GuardianCheck( + GuardianRisk.HARM, + thinking=True, + backend=shared_guardian_backend +) # Test 1: Harmful content detection print("\n=== Test 1: Harmful Content Detection ===") @@ -47,12 +54,13 @@ print("\n=== Test 2: Groundedness Detection ===") context_text = "Python is a high-level programming language created by Guido van Rossum in 1991." -# Create groundedness guardian with context +# Create groundedness guardian with context (reuse shared backend) +print("Creating groundedness guardian...") groundedness_guardian = GuardianCheck( GuardianRisk.GROUNDEDNESS, thinking=False, - backend_type="huggingface", - context_text=context_text + context_text=context_text, + backend=shared_guardian_backend ) # Reset context with ungrounded response @@ -83,11 +91,13 @@ } ] +# Create function call guardian (reuse shared backend) +print("Creating function call guardian...") function_guardian = GuardianCheck( GuardianRisk.FUNCTION_CALL, thinking=False, - backend_type="huggingface", - tools=tools + tools=tools, + backend=shared_guardian_backend ) # User asks for weather but model calls wrong function diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 755f7a4e..9376f5c8 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -75,8 +75,22 @@ def __init__( custom_criteria: Optional[str] = None, context_text: Optional[str] = None, tools: Optional[List[Dict]] = None, + backend: Optional[Backend] = None, ): - """Initialize GuardianCheck using existing backends with minimal glue.""" + """Initialize GuardianCheck using existing backends with minimal glue. + + Args: + risk: The type of risk to check for (harm, jailbreak, etc.) 
+ backend_type: Type of backend to use ("ollama" or "huggingface") + model_version: Specific model version to use + device: Device for model inference (for HuggingFace) + ollama_url: URL for Ollama server + thinking: Enable thinking/reasoning mode + custom_criteria: Custom criteria for validation + context_text: Context document for groundedness checks + tools: Tool schemas for function call validation + backend: Pre-initialized backend to reuse (avoids loading model multiple times) + """ super().__init__(check_only=True) # Handle risk specification with custom criteria priority @@ -103,28 +117,32 @@ def __init__( self._context_text = context_text self._tools = tools - # Choose defaults and initialize the chosen backend directly. - if model_version is None: - model_version = ( - "ibm-granite/granite-guardian-3.3-8b" - if backend_type == "huggingface" - else "ibm/granite3.3-guardian:8b" - ) - - if backend_type == "huggingface": - from mellea.backends.huggingface import LocalHFBackend - self._backend = LocalHFBackend(model_id=model_version) - elif backend_type == "ollama": - from mellea.backends.ollama import OllamaModelBackend - self._backend = OllamaModelBackend(model_id=model_version, base_url=ollama_url) + # Use provided backend or create a new one + if backend is not None: + self._backend = backend else: - raise ValueError(f"Unsupported backend type: {backend_type}") + # Choose defaults and initialize the chosen backend directly. + if model_version is None: + model_version = ( + "ibm-granite/granite-guardian-3.3-8b" + if backend_type == "huggingface" + else "ibm/granite3.3-guardian:8b" + ) + + if backend_type == "huggingface": + from mellea.backends.huggingface import LocalHFBackend + self._backend = LocalHFBackend(model_id=model_version) + elif backend_type == "ollama": + from mellea.backends.ollama import OllamaModelBackend + self._backend = OllamaModelBackend(model_id=model_version, base_url=ollama_url) + else: + raise ValueError(f"Unsupported backend type: {backend_type}") - # Provide a predictable attribute for the example to print. - try: - setattr(self._backend, "model_version", model_version) - except Exception: - pass + # Provide a predictable attribute for the example to print. + try: + setattr(self._backend, "model_version", model_version) + except Exception: + pass self._logger = FancyLogger.get_logger() From 283dc176f6c516a11461adc51a68bea94f1bdafa Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 13:18:02 -0400 Subject: [PATCH 13/19] pass think to hf backend. 
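Building on the `backend=` parameter introduced in the previous patch, a single Granite Guardian model can now be loaded once and shared across several validators; as the diff below shows, the backend type is inferred from the provided instance. A minimal sketch, assuming the constructor shown above:

    from mellea.backends.huggingface import LocalHFBackend
    from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

    # Load the Guardian weights a single time...
    shared_backend = LocalHFBackend(model_id="ibm-granite/granite-guardian-3.3-8b")

    # ...then reuse the same backend for every check; no backend_type needed.
    harm_guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend=shared_backend)
    grounded_guardian = GuardianCheck(
        GuardianRisk.GROUNDEDNESS,
        context_text="Python was created by Guido van Rossum in 1991.",
        backend=shared_backend,
    )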
--- mellea/backends/huggingface.py | 1 + mellea/stdlib/safety/guardian.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index cfb727da..ef268d44 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -134,6 +134,7 @@ def __init__( self.from_mellea_model_opts_map = { ModelOption.MAX_NEW_TOKENS: "max_new_tokens", "guardian_config": "guardian_config", # Pass through for Granite Guardian models + "think": "think", # Pass through for Granite Guardian thinking mode } self.default_to_constraint_checking_alora = default_to_constraint_checking_alora diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 9376f5c8..348e1ed4 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -113,14 +113,24 @@ def __init__( self._custom_criteria = custom_criteria self._thinking = thinking - self._backend_type = backend_type self._context_text = context_text self._tools = tools # Use provided backend or create a new one if backend is not None: self._backend = backend + # Infer backend_type from the provided backend + from mellea.backends.huggingface import LocalHFBackend + from mellea.backends.ollama import OllamaModelBackend + if isinstance(backend, LocalHFBackend): + self._backend_type = "huggingface" + elif isinstance(backend, OllamaModelBackend): + self._backend_type = "ollama" + else: + # Keep the provided backend_type as fallback + self._backend_type = backend_type else: + self._backend_type = backend_type # Choose defaults and initialize the chosen backend directly. if model_version is None: model_version = ( @@ -247,6 +257,7 @@ async def validate( guardian_options.update({ "guardian_config": guardian_cfg, + "think": self._thinking, # Passed to apply_chat_template (not guardian_config) "max_new_tokens": 4000 if self._thinking else 50, "stream": False, }) From f29da44efde3e950bd8712b140e9ce1930a7770d Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 15:29:04 -0400 Subject: [PATCH 14/19] pass think to hf backend. --- mellea/backends/huggingface.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index ef268d44..fc867499 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -133,8 +133,6 @@ def __init__( # for the call to the model. self.from_mellea_model_opts_map = { ModelOption.MAX_NEW_TOKENS: "max_new_tokens", - "guardian_config": "guardian_config", # Pass through for Granite Guardian models - "think": "think", # Pass through for Granite Guardian thinking mode } self.default_to_constraint_checking_alora = default_to_constraint_checking_alora From 6b18c6082f3458c5dc0dab32dcdf5d636ecac7bf Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 16:47:31 -0400 Subject: [PATCH 15/19] pass add_generation_prompt to hf backend. 
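Roughly, the HuggingFace path now assembles an options dict like the following; the keys come from the diff below, the concrete values are illustrative only:

    # Keys consumed by apply_chat_template rather than by model.generate():
    guardian_options = {
        "guardian_config": {"risk": "harm"},  # which Guardian criterion to judge
        "think": False,                       # request a reasoning trace when True
        "add_generation_prompt": True,        # Guardian template needs a generation prompt
        "max_new_tokens": 50,                 # 4000 when thinking mode is enabled
        "stream": False,
    }

The follow-up patch then strips guardian_config, think, and add_generation_prompt before the call to generate(), so these keys only influence prompt construction.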
--- mellea/stdlib/safety/guardian.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 348e1ed4..8aa2e32d 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -258,6 +258,7 @@ async def validate( guardian_options.update({ "guardian_config": guardian_cfg, "think": self._thinking, # Passed to apply_chat_template (not guardian_config) + "add_generation_prompt": True, # Required for Guardian template "max_new_tokens": 4000 if self._thinking else 50, "stream": False, }) @@ -267,9 +268,17 @@ async def validate( if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: guardian_options["tools"] = self._tools - # Generate the guardian decision with a blank assistant turn. + # Generate the guardian decision. + # For Ollama: add blank assistant turn to trigger generation + # For HuggingFace: use CBlock (won't be added to conversation, add_generation_prompt handles the judge role) + if self._backend_type == "ollama": + action = Message("assistant", "") + else: + # Use a CBlock for HuggingFace - it won't be added as a message + action = CBlock("") + mot, _ = self._backend.generate_from_context( - Message("assistant", ""), gctx, model_options=guardian_options + action, gctx, model_options=guardian_options ) await mot.avalue() From 5622cbb7ce7028c729209867c345cfbc071830a1 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 16:57:50 -0400 Subject: [PATCH 16/19] dont pass add_generation_prompt to hf generate. --- mellea/backends/huggingface.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index fc867499..981c1f09 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -381,12 +381,16 @@ def _generate_from_context_standard( # Create a separate thread to handle the processing. Make it awaitable # for non-streaming cases and to get the final output. # Details: https://huggingface.co/docs/transformers/en/internal/generation_utils#transformers.AsyncTextIteratorStreamer + + # Filter out chat template-only options before passing to generate() + generate_options = self._filter_chat_template_only_options(model_options) + chat_response = asyncio.to_thread( self._model.generate, # type: ignore input_ids, return_dict_in_generate=True, output_scores=True, - **self._make_backend_specific_and_remove(model_options), + **self._make_backend_specific_and_remove(generate_options), **streaming_kwargs, # type: ignore **format_kwargs, # type: ignore ) @@ -671,6 +675,21 @@ def _make_backend_specific_and_remove( ) return ModelOption.remove_special_keys(backend_specific) + def _filter_chat_template_only_options( + self, model_options: dict[str, Any] + ) -> dict[str, Any]: + """Remove options that are only for apply_chat_template, not for generate(). 
+ + Args: + model_options: the model_options for this call + + Returns: + a new dict without chat template-specific options + """ + # Options that should only go to apply_chat_template, not generate() + chat_template_only = {"guardian_config", "think", "add_generation_prompt"} + return {k: v for k, v in model_options.items() if k not in chat_template_only} + def _extract_model_tool_requests( self, tools: dict[str, Callable], decoded_result: str ) -> dict[str, ModelToolCall] | None: From f0b1af927b550f828aa56c8242ba07081a56a8e5 Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 21:36:47 -0400 Subject: [PATCH 17/19] better construction of messages for guardian. --- mellea/stdlib/safety/guardian.py | 92 +++++++++++++++++++------------- 1 file changed, 55 insertions(+), 37 deletions(-) diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 8aa2e32d..13662bc7 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -5,8 +5,9 @@ from mellea.backends import Backend, BaseModelSubclass from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Context, ChatContext +from mellea.stdlib.base import CBlock, Context, ChatContext, ModelOutputThunk from mellea.stdlib.chat import Message +from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ValidationResult @@ -192,43 +193,57 @@ async def validate( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, ) -> ValidationResult: - """Validate the last turn using Granite Guardian via selected backend.""" + """Validate conversation using Granite Guardian via selected backend.""" logger = self._logger - last_turn = ctx.last_turn() - if last_turn is None: - logger.warning("No last turn found in context") - return ValidationResult(False, reason="No content to validate") - - # Build a fresh chat context for the guardian model. + # Build a fresh chat context for the guardian model (keep it minimal). gctx = ChatContext() effective_risk = self.get_effective_risk() - if (self._risk == "groundedness" or effective_risk == "groundedness") and self._context_text: + # For groundedness: add doc only for Ollama; HF receives context via guardian_config + if ( + (self._risk == "groundedness" or effective_risk == "groundedness") + and self._context_text + and self._backend_type == "ollama" + ): gctx = gctx.add(Message("user", f"Document: {self._context_text}")) - # Add the last user message if present. - if last_turn.model_input is not None: - if isinstance(last_turn.model_input, CBlock) and last_turn.model_input.value is not None: - gctx = gctx.add(Message("user", last_turn.model_input.value)) - elif isinstance(last_turn.model_input, Message): - gctx = gctx.add(Message(last_turn.model_input.role, last_turn.model_input.content)) - else: - gctx = gctx.add(Message("user", str(last_turn.model_input))) - - # Add the assistant response, optionally including tool call info for function_call risk. 
- if last_turn.output is not None: - assistant_text = last_turn.output.value or "" - if getattr(last_turn.output, "tool_calls", None) and (self._risk == "function_call" or effective_risk == "function_call"): - calls = [] - for name, tc in last_turn.output.tool_calls.items(): - calls.append(f"{name}({getattr(tc, 'args', {})})") - if calls: - suffix = f" [Function calls: {', '.join(calls)}]" - assistant_text = (assistant_text + suffix) if assistant_text else suffix - if assistant_text: - gctx = gctx.add(Message("assistant", assistant_text)) + # Try to reuse chat history directly when available. + messages = None + try: + from mellea.stdlib.chat import as_chat_history + messages = as_chat_history(ctx) + except Exception: + messages = None + + if messages: + for m in messages: + gctx = gctx.add(m) + else: + # Fallback: build from the last turn only + last_turn = ctx.last_turn() + if last_turn is None: + logger.warning("No last turn found in context") + return ValidationResult(False, reason="No content to validate") + + if last_turn.model_input is not None: + gctx = gctx.add(last_turn.model_input) + + if last_turn.output is not None: + # For function call risk, append tool call info as text; otherwise add thunk directly. + if self._risk == "function_call" or effective_risk == "function_call": + content = last_turn.output.value or "" + tcalls = getattr(last_turn.output, "tool_calls", None) + if tcalls: + calls = [f"{name}({getattr(tc, 'args', {})})" for name, tc in tcalls.items()] + if calls: + suffix = f" [Tool calls: {', '.join(calls)}]" + content = (content + suffix) if content else suffix + if content: + gctx = gctx.add(Message("assistant", content)) + else: + gctx = gctx.add(last_turn.output) # Ensure we have something to validate. history = gctx.view_for_generation() or [] @@ -248,21 +263,24 @@ async def validate( "think": True if self._thinking else None, }) else: # huggingface - # HF chat template for guardian expects guardian_config instead of a system message - guardian_cfg: Dict[str, object] = {"risk": effective_risk} + # HF chat template for Guardian expects guardian_config and (optionally) documents + guardian_cfg: Dict[str, object] = {"criteria_id": effective_risk} if self._custom_criteria: - guardian_cfg["custom_criteria"] = self._custom_criteria - if self._context_text and (self._risk == "groundedness" or effective_risk == "groundedness"): - guardian_cfg["context"] = self._context_text + # When using custom criteria, provide it as free-text criteria + guardian_cfg["criteria_text"] = self._custom_criteria guardian_options.update({ "guardian_config": guardian_cfg, - "think": self._thinking, # Passed to apply_chat_template (not guardian_config) - "add_generation_prompt": True, # Required for Guardian template + "think": self._thinking, # Passed to apply_chat_template + "add_generation_prompt": True, # Guardian template requires a generation prompt "max_new_tokens": 4000 if self._thinking else 50, "stream": False, }) + # Provide documents parameter for groundedness + if self._context_text and (self._risk == "groundedness" or effective_risk == "groundedness"): + guardian_options["documents"] = [{"doc_id": "0", "text": self._context_text}] + # Attach tools for function_call checks. # Guardian only needs tool schemas for validation, not actual callable functions. 
if (self._risk == "function_call" or effective_risk == "function_call") and self._tools: From 26bb877b135d1b1a8e553e5af9ec95a2ec6ac4ae Mon Sep 17 00:00:00 2001 From: Prasanna Sattigeri Date: Wed, 1 Oct 2025 21:39:12 -0400 Subject: [PATCH 18/19] better construction of messages for guardian. --- mellea/backends/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 981c1f09..53f7c385 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -687,7 +687,7 @@ def _filter_chat_template_only_options( a new dict without chat template-specific options """ # Options that should only go to apply_chat_template, not generate() - chat_template_only = {"guardian_config", "think", "add_generation_prompt"} + chat_template_only = {"guardian_config", "think", "add_generation_prompt", "documents"} return {k: v for k, v in model_options.items() if k not in chat_template_only} def _extract_model_tool_requests( From 75069381e327acdf96a97131402954a728ceb0ea Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Mon, 6 Oct 2025 10:49:48 -0700 Subject: [PATCH 19/19] chore: fixing some ruff issues --- .../safety.py/guardian_huggingface.py | 59 ++++++++++--------- .../safety.py/repair_with_guardian.py | 36 ++++++----- test/backends/test_huggingface_tools.py | 2 +- 3 files changed, 52 insertions(+), 45 deletions(-) diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py index eae9bc71..3cc3d507 100644 --- a/docs/examples/safety.py/guardian_huggingface.py +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -26,9 +26,7 @@ # Create Guardian validators that share the backend (no model reloading!) print("Creating harm guardian...") harm_guardian = GuardianCheck( - GuardianRisk.HARM, - thinking=True, - backend=shared_guardian_backend + GuardianRisk.HARM, thinking=True, backend=shared_guardian_backend ) # Test 1: Harmful content detection @@ -36,10 +34,10 @@ print("Testing content: 'You are very ugly and stupid.'") # Reset context for this test -session.ctx = ChatContext().add( - Message("user", "Hello") -).add( - Message("assistant", "You are very ugly and stupid.") +session.ctx = ( + ChatContext() + .add(Message("user", "Hello")) + .add(Message("assistant", "You are very ugly and stupid.")) ) validation_result = session.validate([harm_guardian]) @@ -52,7 +50,9 @@ # Test 2: Groundedness detection print("\n=== Test 2: Groundedness Detection ===") -context_text = "Python is a high-level programming language created by Guido van Rossum in 1991." +context_text = ( + "Python is a high-level programming language created by Guido van Rossum in 1991." 
+) # Create groundedness guardian with context (reuse shared backend) print("Creating groundedness guardian...") @@ -60,14 +60,19 @@ GuardianRisk.GROUNDEDNESS, thinking=False, context_text=context_text, - backend=shared_guardian_backend + backend=shared_guardian_backend, ) # Reset context with ungrounded response -session.ctx = ChatContext().add( - Message("user", "Who created Python?") -).add( - Message("assistant", "Python was created by Dennis Ritchie in 1972 for use in Unix systems.") +session.ctx = ( + ChatContext() + .add(Message("user", "Who created Python?")) + .add( + Message( + "assistant", + "Python was created by Dennis Ritchie in 1972 for use in Unix systems.", + ) + ) ) groundedness_valid = session.validate([groundedness_guardian]) @@ -82,12 +87,7 @@ { "name": "get_weather", "description": "Gets weather for a location", - "parameters": { - "location": { - "description": "City name", - "type": "string" - } - } + "parameters": {"location": {"description": "City name", "type": "string"}}, } ] @@ -97,34 +97,35 @@ GuardianRisk.FUNCTION_CALL, thinking=False, tools=tools, - backend=shared_guardian_backend + backend=shared_guardian_backend, ) + # User asks for weather but model calls wrong function def dummy_func(**kwargs): pass + hallucinated_tool_calls = { "get_stock_price": ModelToolCall( - name="get_stock_price", - func=dummy_func, - args={"symbol": "AAPL"} + name="get_stock_price", func=dummy_func, args={"symbol": "AAPL"} ) } hallucinated_output = ModelOutputThunk( - value="Let me get the weather for you.", - tool_calls=hallucinated_tool_calls + value="Let me get the weather for you.", tool_calls=hallucinated_tool_calls ) # Reset context with hallucinated function call -session.ctx = ChatContext().add( - Message("user", "What's the weather in Boston?") -).add(hallucinated_output) +session.ctx = ( + ChatContext() + .add(Message("user", "What's the weather in Boston?")) + .add(hallucinated_output) +) function_valid = session.validate([function_guardian]) print(f"Function calls are valid: {function_valid[0]._result}") if function_valid[0]._reason: print(f"Function call feedback: {function_valid[0]._reason[:200]}...") -print("\n=== HuggingFace Guardian Demo Complete ===") \ No newline at end of file +print("\n=== HuggingFace Guardian Demo Complete ===") diff --git a/docs/examples/safety.py/repair_with_guardian.py b/docs/examples/safety.py/repair_with_guardian.py index 1ae85bbe..c2c1d20a 100644 --- a/docs/examples/safety.py/repair_with_guardian.py +++ b/docs/examples/safety.py/repair_with_guardian.py @@ -33,17 +33,15 @@ def get_stock_price(symbol: str) -> str: "parameters": { "symbol": { "description": "The stock symbol to get price for (must be 3-5 uppercase letters like TSLA, AAPL)", - "type": "string" + "type": "string", } - } + }, } ] # Guardian validates function calls against tool schema guardian = GuardianCheck( - GuardianRisk.FUNCTION_CALL, - thinking=True, - tools=tool_schemas + GuardianRisk.FUNCTION_CALL, thinking=True, tools=tool_schemas ) test_prompt = "What's the price of Tesla stock?" @@ -59,19 +57,25 @@ def get_stock_price(symbol: str) -> str: "seed": 789, "tools": [get_stock_price], # Intentionally misconfigured to demonstrate repair - "system": "When users ask about stock prices, use the full company name as the symbol parameter. For example, use 'Tesla Motors' instead of 'TSLA'." + "system": "When users ask about stock prices, use the full company name as the symbol parameter. 
For example, use 'Tesla Motors' instead of 'TSLA'.", }, - tool_calls=True + tool_calls=True, ) # Show repair process - for attempt_num, (generation, validations) in enumerate(zip(result.sample_generations, result.sample_validations), 1): + for attempt_num, (generation, validations) in enumerate( + zip(result.sample_generations, result.sample_validations), 1 + ): print(f"\nAttempt {attempt_num}:") # Show what was sent to the model - if hasattr(result, 'sample_actions') and result.sample_actions and attempt_num <= len(result.sample_actions): + if ( + hasattr(result, "sample_actions") + and result.sample_actions + and attempt_num <= len(result.sample_actions) + ): action = result.sample_actions[attempt_num - 1] - if hasattr(m.backend, 'formatter'): + if hasattr(m.backend, "formatter"): try: rendered = m.backend.formatter.print(action) print(f" Instruction sent to model:") @@ -82,7 +86,7 @@ def get_stock_price(symbol: str) -> str: pass # Show function calls made - if hasattr(generation, 'tool_calls') and generation.tool_calls: + if hasattr(generation, "tool_calls") and generation.tool_calls: for name, tool_call in generation.tool_calls.items(): print(f" Function: {name}({tool_call.args})") @@ -91,11 +95,13 @@ def get_stock_price(symbol: str) -> str: status = "PASS" if validation.as_bool() else "FAIL" print(f" Status: {status}") - print(f"\n{'='*60}") - print(f"Result: {'SUCCESS' if result.success else 'FAILED'} after {len(result.sample_generations)} attempt(s)") - print(f"{'='*60}") + print(f"\n{'=' * 60}") + print( + f"Result: {'SUCCESS' if result.success else 'FAILED'} after {len(result.sample_generations)} attempt(s)" + ) + print(f"{'=' * 60}") return result if __name__ == "__main__": - demo_repair_with_actual_function_calling() \ No newline at end of file + demo_repair_with_actual_function_calling() diff --git a/test/backends/test_huggingface_tools.py b/test/backends/test_huggingface_tools.py index e399bc99..f78898ec 100644 --- a/test/backends/test_huggingface_tools.py +++ b/test/backends/test_huggingface_tools.py @@ -23,7 +23,7 @@ def backend(): """Shared HuggingFace backend for all tests in this module.""" backend = LocalHFBackend( - model_id=model_ids.IBM_GRANITE_4_MICRO_3B, + model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B, cache=SimpleLRUCache(5), ) # add_granite_aloras(backend)
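Taken together, the function-call checks in these examples follow one pattern: record the model's tool call on a ModelOutputThunk and let Guardian compare it against the declared tool schemas. A condensed, illustrative sketch built only from names that appear in the example diffs above (not a drop-in test):

    from mellea import MelleaSession
    from mellea.backends import model_ids
    from mellea.backends.ollama import OllamaModelBackend
    from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall
    from mellea.stdlib.chat import Message
    from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

    tools = [
        {
            "name": "get_weather",
            "description": "Gets weather for a location",
            "parameters": {"location": {"description": "City name", "type": "string"}},
        }
    ]
    guardian = GuardianCheck(GuardianRisk.FUNCTION_CALL, thinking=True, tools=tools)

    def _noop(**kwargs):
        # Guardian only inspects the recorded call; it never executes the function.
        pass

    # The user asked for weather, but the recorded call targets a different tool.
    wrong_call = ModelOutputThunk(
        value="Let me get the weather for you.",
        tool_calls={
            "get_stock_price": ModelToolCall(
                name="get_stock_price", func=_noop, args={"symbol": "AAPL"}
            )
        },
    )

    session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
    session.ctx = (
        ChatContext()
        .add(Message("user", "What's the weather in Boston?"))
        .add(wrong_call)
    )
    verdict = session.validate([guardian])  # expected to flag the mismatched call
    print("Function calls are valid:", verdict[0]._result)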