# WIP: Validating the output of quantized and unquantized Llama Guard

This notebook shows how to validate Llama Guard's performance on a given dataset. The dataset is a JSONL file in which each line has the format:
```
{
    "prompt": "user_input",
    "generation": "model_response",
    "label": "good/bad",
    "unsafe_content": ["O1"],
    "idx": 0
}
```
The `label` field is `good` if the content is considered safe and `bad` if it is considered unsafe. Each entry must also have an `idx` field, which is copied unchanged into the records sent for evaluation.

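The script reads the entries from the file (currently only the first 200 lines) and runs them through Llama Guard, loaded either from Hugging Face or from a local PyTorch checkpoint. It then reports the average precision computed from the resulting logprobs.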

Note: for inputs that the model classifies as safe, the probability of the unsafe (positive) class is approximated with a simple heuristic, $1 - P(\text{safe})$. This avoids having to get from the model the probability of a token it did not select.


In [1]:
import itertools
import json
import numbers
import time
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple
# from examples.llama_guard.inference import llm_eval, standard_llm_eval, AgentType

import numpy as np
from sklearn.metrics import average_precision_score


In [2]:
%run Inference.ipynb

In [3]:
%run Pytorch_Inference.ipynb

In [4]:
class Type(Enum):
    """Inference backend selected in ``run_validation``.

    ``HF`` dispatches to ``llm_eval``; any other member (``PYTORCH``)
    dispatches to ``pytorch_llm_eval`` with a local checkpoint directory.
    """

    HF = "HF"
    PYTORCH = "PYTORCH"

def format_prompt(entry, agent_type: AgentType):
    """Convert one parsed dataset entry into the prompt record passed to the evaluators.

    Args:
        entry: A parsed JSONL row. It must contain ``"prompt"``, ``"label"``,
            ``"unsafe_content"`` and ``"idx"``, plus ``"generation"`` when
            ``agent_type`` is not ``AgentType.USER``.
        agent_type: ``AgentType.USER`` to classify the user prompt alone; any
            other value also includes the model generation.

    Returns:
        A dict with the conversation turns under ``"prompt"`` followed by the
        entry's ``agent_type``, ``label``, ``unsafe_content`` and ``idx``.

    Raises:
        KeyError: If ``entry`` is missing one of the required keys.
    """
    # For user-side checks only the user's turn is sent; otherwise the model
    # response is appended as the second turn.
    if agent_type == AgentType.USER:
        turns = [entry["prompt"]]
    else:
        turns = [entry["prompt"], entry["generation"]]

    return {
        "prompt": turns,
        "agent_type": agent_type,
        "label": entry["label"],
        "unsafe_content": entry["unsafe_content"],
        "idx": entry["idx"],
    }

def validate_agent_type(value):
    """Coerce ``value`` to an ``AgentType`` member.

    Args:
        value: An ``AgentType`` member or one of its raw values.

    Returns:
        The matching ``AgentType`` member.

    Raises:
        ValueError: If ``value`` matches no ``AgentType``; the message lists
            the accepted values.
    """
    try:
        return AgentType(value)
    except ValueError as err:
        # Replace the terse Enum error with the list of valid choices, chaining
        # so the original lookup failure is still visible in the traceback.
        valid_values = [agent_type.value for agent_type in AgentType]
        raise ValueError(f"Invalid AgentType. Choose from: {valid_values}") from err

def parse_logprobs(prompts):
    """Score evaluated prompts as unsafe-class probabilities and return their average precision.

    The first entry of each ``prompt["logprobs"]`` is either a bare
    log-probability or an indexable item whose element ``[1]`` is the
    log-probability. It is presumably the first generated token (``safe`` /
    ``unsafe``); TODO confirm against the evaluator.

    - If ``prompt["result"]`` contains ``"unsafe"``, the exponentiated value is
      used directly as the unsafe (positive-class) score.
    - Otherwise the result is treated as safe, and the score is ``1 - P(safe)``.
      This heuristic avoids asking the model for a token it did not select.

    Args:
        prompts: Evaluated prompt records, each with ``"logprobs"``,
            ``"result"`` and ``"label"`` keys (``"bad"`` = positive class).

    Returns:
        The average precision of the unsafe scores against the binary labels.
    """
    positive_class_probs = []
    binary_labels = []
    for prompt in prompts:
        first_logprob = prompt["logprobs"][0]
        # Any real scalar (Python float/int, numpy floating) is the logprob itself;
        # ``isinstance(..., float)`` alone rejected e.g. np.float32 and crashed below.
        if not isinstance(first_logprob, numbers.Real):
            first_logprob = first_logprob[1]
        prob = np.exp(first_logprob)

        if "unsafe" in prompt["result"]:
            positive_class_probs.append(prob)
        else:
            # Heuristic: 1 - P(safe) instead of querying the non-selected token.
            positive_class_probs.append(1 - prob)
        binary_labels.append(1 if prompt["label"] == "bad" else 0)

    return average_precision_score(binary_labels, positive_class_probs)

def run_validation(
    jsonl_file_path,
    agent_type,
    type: Type,
    load_in_8bit: bool = True,
    logprobs: bool = True,
    max_entries: Optional[int] = 200,
    pytorch_model_path: str = "../../../llama/models/guard-llama/",
    save_outputs: bool = False,
):
    """Run Llama Guard over a JSONL dataset and report average precision.

    Args:
        jsonl_file_path: Path to the JSONL dataset (one entry per line).
        agent_type: ``AgentType`` member or raw value, checked with
            ``validate_agent_type``.
        type: ``Type.HF`` evaluates with ``llm_eval``; any other value uses
            ``pytorch_llm_eval`` with ``pytorch_model_path``.
        load_in_8bit: Forwarded to ``llm_eval`` (HF backend only).
        logprobs: Forwarded to the evaluator. When true, average precision is
            computed from the returned logprobs.
        max_entries: Read at most this many lines of the file. The default of
            200 keeps the previous hard-coded limit; ``None`` reads the whole file.
        pytorch_model_path: Checkpoint directory for the PyTorch backend.
        save_outputs: When true, write per-prompt results (JSONL, without
            logprobs) and the run statistics (JSON) next to the input file.

    Returns:
        Dict of run statistics. It contains ``"average_precision"`` when
        ``logprobs`` is true.
    """
    input_file_path = Path(jsonl_file_path)
    agent_type = validate_agent_type(agent_type)

    prompts = _load_prompts(input_file_path, agent_type, max_entries)

    # The evaluators' return values are unused. They are expected to add
    # "result"/"logprobs" to each prompt record in place, which parse_logprobs reads.
    start = time.perf_counter()
    if type is Type.HF:
        llm_eval(prompts, load_in_8bit, logprobs)
    else:
        pytorch_llm_eval(prompts, pytorch_model_path, logprobs)
    end = time.perf_counter()
    print(f"evaluation executed in {end - start} seconds")

    parsed_results = {}
    if logprobs:
        average_precision = parse_logprobs(prompts)
        parsed_results["average_precision"] = average_precision
        print(f"average precision {average_precision:.2%}")

    if save_outputs:
        _write_outputs(prompts, parsed_results, input_file_path, agent_type, type, load_in_8bit)

    return parsed_results


def _load_prompts(input_file_path, agent_type, max_entries):
    """Read up to ``max_entries`` lines of a JSONL file and format each as a prompt record."""
    prompts: List[dict] = []
    with open(input_file_path, "r", encoding="utf-8") as f:
        # islice caps how many lines are read; None means no cap.
        lines = f if max_entries is None else itertools.islice(f, max_entries)
        for line in lines:
            if not line.strip():
                continue  # blank lines carry no entry and would break json.loads
            prompts.append(format_prompt(json.loads(line), agent_type))
    return prompts


def _json_default(obj):
    """``json.dumps`` fallback: Enums by value, numpy scalars as Python scalars."""
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


def _write_outputs(prompts, parsed_results, input_file_path, agent_type, backend, load_in_8bit):
    """Write per-prompt results (JSONL, logprobs dropped) and run stats (JSON) next to the input file."""
    # No ':' in the timestamp so the generated file names are valid on every OS.
    current_date = datetime.now().strftime("%Y-%m-%d-%H%M")
    quantization = "load_in_8bit" if load_in_8bit else "noquant"
    stem = f"{input_file_path.stem}_{agent_type.value}_{backend.value}_{quantization}"
    directory = input_file_path.parent

    with open(directory / f"{stem}_output_{current_date}.jsonl", "w", encoding="utf-8") as out_file:
        for prompt in prompts:
            # Build a filtered copy instead of popping so the caller's records stay intact.
            record = {key: value for key, value in prompt.items() if key != "logprobs"}
            out_file.write(json.dumps(record, default=_json_default) + "\n")

    with open(directory / f"{stem}_stats_{current_date}.json", "w", encoding="utf-8") as stats_file:
        json.dump(parsed_results, stats_file, indent=4, default=_json_default)

In [5]:
# Input datasets (JSONL, one entry per line in the format described above).
prompts_file = "valid_prompts_6cat_1122.jsonl"
responses_file = "valid_responses_6cat_1122.jsonl"  # currently unused; presumably for non-USER (agent) runs

# Classify user prompts with the Hugging Face backend (8-bit and logprobs are the defaults).
run_validation(prompts_file, AgentType.USER, type=Type.HF)
# run_validation(prompts_file, AgentType.USER, type=Type.PYTORCH)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



evaluation executed in 155.58132243156433 seconds
average precision 95.25%
