diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index ed7fb6c489..5234c8e1f3 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -16,8 +16,7 @@ from typing import List, Optional import torch -from llama_models.llama3_1.api.datatypes import Message -from termcolor import cprint +from llama_models.llama3_1.api.datatypes import Message, Role from transformers import AutoModelForCausalLM, AutoTokenizer from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse @@ -97,8 +96,8 @@ PROMPT_INSTRUCTIONS = """ Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation: -- First line must read 'safe' or 'unsafe'. -- If unsafe, a second line must include a comma-separated list of violated categories.""" + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories.""" PROMPT_TEMPLATE = Template( @@ -143,6 +142,7 @@ def __init__( if excluded_categories is None: excluded_categories = [] + assert len(excluded_categories) == 0 or all( x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" @@ -221,7 +221,6 @@ def get_shield_response(self, response: str) -> ShieldResponse: raise ValueError(f"Unexpected response: {response}") async def run(self, messages: List[Message]) -> ShieldResponse: - if self.disable_input_check and messages[-1].role == Role.user.value: return ShieldResponse( shield_type=BuiltinShield.llama_guard, is_violation=False @@ -254,9 +253,6 @@ async def run(self, messages: List[Message]) -> ShieldResponse: response = self.tokenizer.decode( generated_tokens[0], skip_special_tokens=True ) - - response = response.strip() + response = response.strip() shield_response = self.get_shield_response(response) - - cprint(f"Final Llama Guard response {shield_response}", color="magenta") return shield_response