# Running Llama Guard inference

This notebook is intended to show how to run Llama Guard inference over a large dataset of prompts for testing.

In [21]:
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from transformers import AutoTokenizer, AutoModelForCausalLM

from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
# import generation 


# Defining the main functions

The `AgentType` enum selects whose messages Llama Guard should classify: `User` (the user's prompt) or `Agent` (the agent's response).

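`llm_eval` loads the Llama Guard tokenizer and model, builds a Llama Guard prompt for each input conversation, generates the model's verdict, and returns the decoded results (optionally with per-token log-probabilities).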

In [22]:
class AgentType(Enum):
    """Whose messages in a conversation Llama Guard should classify.

    Members are passed to ``build_prompt`` as its ``agent_type`` argument.
    NOTE(review): llama_recipes' prompt_format_utils may define an
    equivalent enum of its own; confirm the two stay in sync.
    """

    # Classify the agent's (assistant's) response.
    AGENT = "Agent"
    # Classify the user's prompt.
    USER = "User"

def llm_eval(
    prompts,
    load_in_8bit=True,
    logprobs=False,
    model_id: str = "meta-llama/LlamaGuard-7b",
    max_new_tokens: int = 10,
) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:
    """Run Llama Guard on every conversation in ``prompts``.

    The tokenizer and model are loaded once. For each entry, a Llama Guard
    prompt is built with ``build_prompt`` and a short completion (the
    model's verdict, e.g. ``"safe"``) is generated and decoded.

    Args:
        prompts: List of dicts, each with a ``"prompt"`` key (the
            conversation turns as strings) and an ``"agent_type"`` key
            (an ``AgentType``). Each dict is updated in place with a
            ``"result"`` key and, when ``logprobs`` is True, a
            ``"logprobs"`` key.
        load_in_8bit: Forwarded to ``from_pretrained`` to load 8-bit weights.
        logprobs: If True, also collect ``(token_id, logprob)`` for every
            generated token.
        model_id: Hugging Face model id or local path to load.
        max_new_tokens: Maximum number of tokens generated per prompt.

    Returns:
        ``(results, result_logprobs)``: ``results[i]`` is the decoded output
        for ``prompts[i]``; ``result_logprobs`` is a parallel list of
        per-token ``(token_id, logprob)`` pairs, or None when ``logprobs``
        is False.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, load_in_8bit=load_in_8bit, device_map="auto"
    )

    results: List[str] = []
    all_logprobs: List[List[Tuple[int, float]]] = []

    for prompt in prompts:
        formatted_prompt = build_prompt(
            prompt["agent_type"],
            LLAMA_GUARD_CATEGORY,
            create_conversation(prompt["prompt"]),
        )

        # Place inputs wherever device_map="auto" put the model rather than
        # hard-coding "cuda".
        model_inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
        prompt_len = model_inputs["input_ids"].shape[-1]
        output = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=0,
            return_dict_in_generate=True,
            output_scores=logprobs,
        )

        # generate() returns prompt + completion; keep only the new tokens.
        generated_tokens = output.sequences[:, prompt_len:]

        if logprobs:
            # normalize_logits=True turns the raw per-step scores into
            # log-probabilities of the chosen tokens.
            transition_scores = model.compute_transition_scores(
                output.sequences, output.scores, normalize_logits=True
            )
            # .item() yields plain Python int/float, matching the declared
            # Tuple[int, float] (0-d numpy arrays were stored previously).
            temp_logprobs: List[Tuple[int, float]] = [
                (tok.item(), score.item())
                for tok, score in zip(generated_tokens[0], transition_scores[0])
            ]
            all_logprobs.append(temp_logprobs)
            prompt["logprobs"] = temp_logprobs

        result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        prompt["result"] = result
        results.append(result)

    return results, (all_logprobs if logprobs else None)

In [23]:
def main():
    """Run Llama Guard on a few placeholder conversations and print verdicts.

    Each entry pairs a list of conversation turns with the ``AgentType``
    whose messages should be classified: ``USER`` for a lone user prompt,
    ``AGENT`` for conversations that end with an agent response.
    """
    # Keys are strings; values are either the list of turns or an AgentType.
    prompts: List[Dict[str, Any]] = [
        {
            "prompt": ["<Sample user prompt>"],
            "agent_type": AgentType.USER
        },
        {
            "prompt": ["<Sample user prompt>", "<Sample agent response>"],
            "agent_type": AgentType.AGENT
        },
        {
            "prompt": ["<Sample user prompt>",
                       "<Sample agent response>",
                       "<Sample user reply>",
                       "<Sample agent response>"],
            "agent_type": AgentType.AGENT
        }
    ]

    # llm_eval returns (results, logprobs); logprobs is None with the
    # default logprobs=False, so it is discarded.
    results, _ = llm_eval(prompts)

    for prompt, result in zip(prompts, results):
        print(prompt['prompt'])
        print(f"> {result}")
        print("\n==================================\n")

In [24]:
# Run the demo only when this code executes as __main__ without a __file__
# (as in an interactive kernel). Importing it from another notebook, or
# running it as a .py script (which defines __file__), skips main().
if __name__ == '__main__':
    if '__file__' not in globals():
        main()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

['<Sample user prompt>']
> safe


['<Sample user prompt>', '<Sample agent response>']
> safe


['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
> safe


