From 138b92ae69ec2eab4a90f9f894e81b379dc38bd7 Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 21:26:03 -0700
Subject: [PATCH 1/4] llama_guard inference fix

---
 llama_toolchain/safety/shields/llama_guard.py | 40 ++++++++++++++++---
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94be0e06cd..790ff4def0 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -85,8 +85,8 @@
 
 PROMPT_INSTRUCTIONS = """
 Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
-- First line must read 'safe' or 'unsafe'.
-- If unsafe, a second line must include a comma-separated list of violated categories."""
+ - First line must read 'safe' or 'unsafe'.
+ - If unsafe, a second line must include a comma-separated list of violated categories."""
 
 
 PROMPT_TEMPLATE = Template(
@@ -240,9 +240,39 @@ async def run(self, messages: List[Message]) -> ShieldResponse:
             response = self.tokenizer.decode(
                 generated_tokens[0], skip_special_tokens=True
             )
-
-            response = response.strip()
+            cprint(f" Llama Guard response {response}", color="magenta")
+            response = response.strip()
             shield_response = self.get_shield_response(response)
-
             cprint(f"Final Llama Guard response {shield_response}", color="magenta")
             return shield_response
+
+
+
+
+
+        '''if self.disable_input_check and messages[-1].role == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role == "assistant":
+            return ShieldResponse(is_violation=False)
+        else:
+            prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=50,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0
+            )
+            generated_tokens = output.sequences[:, prompt_len:]
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
+            return shield_response'''

From 16fe0e45940068f12cd0afef5046df4db3bb752f Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 21:59:57 -0700
Subject: [PATCH 2/4] clean up and add license

---
 llama_toolchain/safety/shields/llama_guard.py | 46 ++-----
 1 file changed, 4 insertions(+), 42 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94f111fb75..12aadc3057 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,4 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
@@ -16,8 +15,7 @@
 from typing import List, Optional
 
 import torch
-from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
+from llama_models.llama3_1.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -112,7 +110,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -131,7 +129,7 @@ def __init__(
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -141,8 +139,6 @@ def __init__(
 
         assert model_dir is not None, "Llama Guard model_dir is None"
 
-        if excluded_categories is None:
-            excluded_categories = []
         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
@@ -221,8 +217,7 @@ def get_shield_response(self, response: str) -> ShieldResponse:
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-
-        if self.disable_input_check and messages[-1].role == Role.user.value:
+        if self.disable_input_check and messages[-1].role == Role.user.value: 
             return ShieldResponse(
                 shield_type=BuiltinShield.llama_guard, is_violation=False
             )
@@ -254,39 +249,6 @@ async def run(self, messages: List[Message]) -> ShieldResponse:
             response = self.tokenizer.decode(
                 generated_tokens[0], skip_special_tokens=True
             )
-            cprint(f" Llama Guard response {response}", color="magenta")
             response = response.strip()
             shield_response = self.get_shield_response(response)
-            cprint(f"Final Llama Guard response {shield_response}", color="magenta")
             return shield_response
-
-
-
-
-
-        '''if self.disable_input_check and messages[-1].role == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == "assistant":
-            return ShieldResponse(is_violation=False)
-        else:
-            prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=50,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0
-            )
-            generated_tokens = output.sequences[:, prompt_len:]
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
-            return shield_response'''

From ab8a220faa7dd70fa9f0ec98b8568e1c5662fe0e Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 22:03:05 -0700
Subject: [PATCH 3/4] add missing license part

---
 llama_toolchain/safety/shields/llama_guard.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 12aadc3057..d6154f027d 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,3 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
@@ -217,7 +218,7 @@ def get_shield_response(self, response: str) -> ShieldResponse:
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role == Role.user.value: 
+        if self.disable_input_check and messages[-1].role == Role.user.value:
             return ShieldResponse(
                 shield_type=BuiltinShield.llama_guard, is_violation=False
             )

From ab829b055736e87173b6748000cfa0525469e2e6 Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 22:09:44 -0700
Subject: [PATCH 4/4] revert excluded cat defaults

---
 llama_toolchain/safety/shields/llama_guard.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index d6154f027d..5234c8e1f3 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -111,7 +111,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = [],
+        excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -130,7 +130,7 @@ def __init__(
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = [],
+        excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -140,6 +140,9 @@ def __init__(
 
         assert model_dir is not None, "Llama Guard model_dir is None"
 
+        if excluded_categories is None:
+            excluded_categories = []
+
         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
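
Note on [PATCH 2/4] and [PATCH 4/4]: the change from `excluded_categories: List[str] = None` to `= []` and back touches Python's mutable-default-argument pitfall. A `[]` default is evaluated once, when the function is defined, and shared across every call, which is why the final patch restores the `None` default plus the in-body `if excluded_categories is None:` check. A minimal standalone sketch of the difference; the `Shield` class below is a hypothetical stand-in for `LlamaGuardShield`, not toolchain code:

    from typing import List, Optional


    class Shield:
        def __init__(self, excluded_categories: Optional[List[str]] = None):
            # A `= []` default would be created once at definition time and
            # shared by every instance constructed without the argument;
            # normalizing None -> [] here gives each call a fresh list.
            if excluded_categories is None:
                excluded_categories = []
            self.excluded_categories = excluded_categories


    a = Shield()
    b = Shield()
    a.excluded_categories.append("S1")
    print(a.excluded_categories)  # ['S1']
    print(b.excluded_categories)  # [] -- each instance keeps its own list

With the shared `[]` default from [PATCH 2/4], the second print would also show `['S1']`, silently leaking one shield's exclusions into another instance.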