# Generating global explanations of an LLM-as-a-Judge using the GloVE algorithm


In [1]:
from risk_policy_distillation.datasets.prompt_response_dataset import (
    PromptResponseDataset,
)
from risk_policy_distillation.datasets.abs_dataset import AbstractDataset
from risk_policy_distillation.models.explainers.local_explainers.lime import LIME
from risk_policy_distillation.models.explainers.local_explainers.shap_vals import SHAP
from risk_policy_distillation.models.guardians.guardian import Guardian
from risk_policy_distillation.pipeline.clusterer import Clusterer
from risk_policy_distillation.pipeline.concept_extractor import Extractor
from risk_policy_distillation.pipeline.pipeline import Pipeline
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


INFO 10-30 12:07:41 [__init__.py:216] Automatically detected platform cpu.


#### Create a dataset

To explain the LLM-as-a-Judge you need to provide a dataset. The [AbstractDataset](../src/models/datasets/abs_dataset.py) class provides a wrapper for the dataframe you want to explain. Use [PromptDataset](../src/models/datasets/prompt_dataset.py) or [PromptResponseDataset](../src/models/datasets/prompt_response_dataset.py) depending on whether your dataframe contains only prompts or prompt-response pairs. You can also create a custom dataset by subclassing `AbstractDataset`.

You have to provide a config with the column-name mapping. Additional parameters: _flip_labels_ indicates whether the dataframe's labels should be flipped in the preprocessing step (e.g. for BeaverTails, where labels indicate that the content is safe rather than harmful); _split_ indicates whether a train-val-test split should be performed during preprocessing.


In [2]:
# Load BeaverTails and take the "330k_train" split as a pandas DataFrame.
# `Dataset.to_pandas()` is the public conversion API; it avoids reaching into
# the underlying Arrow table via `ds.data[...].table`.
ds = load_dataset("PKU-Alignment/BeaverTails")
dataframe = ds["330k_train"].to_pandas()

# Sample a handful of rows for a quick demo run. A fixed random_state keeps the
# sampled rows (and therefore the resulting explanations) reproducible.
dataframe = dataframe.sample(3, random_state=42)

# Column mapping and preprocessing options for the dataset wrapper.
data_config = {
    "dataset_name": "BeaverTails",
    "index_col": "",  # if there is no index column it will be automatically populated
    "prompt_col": "prompt",
    "response_col": "response",
    "label_col": "is_safe",
    # BeaverTails' `is_safe` column marks *safe* content, while the guardian
    # config uses 1 = "harmful", so labels must be flipped during preprocessing.
    # NOTE(review): assumes PromptResponseDataset reads the `flip_labels` key as
    # described in the markdown above — confirm against the dataset class.
    "flip_labels": True,
}

dataset = PromptResponseDataset(dataframe=dataframe, config=data_config)

### Create a guardian config


In [3]:
# Task configuration for the guardian (LLM-as-a-Judge).
# `labels`, `label_names` and `output_labels` appear to be parallel lists
# (index i of each describes the same class: 0/"harmless"/"no",
# 1/"harmful"/"yes") — assumed from their ordering; confirm against Guardian.
guardian_config = {
    "task": "Harm detection",
    "criterion": "Harm",
    # This definition is forwarded to the judge model, so it must read as a
    # well-formed sentence.
    "criterion_definition": "The message is harmful by common-sense.",
    "labels": [0, 1],
    "label_names": ["harmless", "harmful"],
    "output_labels": ["no", "yes"],
}

##### Policy Distillation uses Granite Guardian as the LLM-as-a-Judge, and other Large Language Models (LLMs) at various stages of the pipeline. It therefore requires access to an inference endpoint (or some other way to call) each of these models.


In [None]:
from openai import OpenAI
from typing import Dict, List, Union, Any, Optional
import dataclasses
from multiprocessing.pool import ThreadPool
from tqdm.autonotebook import tqdm
import json


def run_parallel(func, items, concurrency_limit: int = 10, verbose=True):
    """Apply ``func`` to every element of ``items`` using a thread pool.

    Results come back in the same order as ``items`` (``Pool.imap`` preserves
    input order). A tqdm progress bar is shown unless ``verbose`` is false.

    Args:
        func: Callable applied to each item; runs in worker threads.
        items: Iterable of inputs. Sized inputs (lists, tuples, ...) give the
            progress bar a known total; other iterables such as generators are
            accepted too and get an open-ended bar.
        concurrency_limit: Number of worker threads.
        verbose: Whether to display the progress bar.

    Returns:
        List of ``func(item)`` results, in input order. An exception raised by
        ``func`` is re-raised in the calling thread.
    """
    from collections.abc import Sized

    # len() only exists on sized containers; don't crash on generators.
    total = len(items) if isinstance(items, Sized) else None
    with ThreadPool(processes=concurrency_limit) as pool:
        # list() drains every result before the pool is torn down on exit.
        return list(tqdm(pool.imap(func, items), total=total, disable=not verbose))


@dataclasses.dataclass(kw_only=True)
class TextGenerationInferenceOutput:

    prediction: Union[str, List[Dict[str, Any]]]
    logprobs: Optional[Dict[str, float]] = None


class RITSInferenceEngine:
    """Chat client for models served behind the RITS OpenAI-compatible gateway.

    The endpoint URL is derived from the last path component of
    ``model_name_or_path`` (lower-cased, with dots replaced by dashes), and the
    full ``model_name_or_path`` is sent as the ``model`` field of every request.

    Args:
        model_name_or_path: Model identifier, e.g.
            ``"ibm-granite/granite-guardian-3.3-8b"``.
        parameters: Extra keyword arguments forwarded to every
            ``chat.completions.create`` call (e.g. ``temperature``,
            ``logprobs``). The dict is copied, so later changes to the caller's
            dict have no effect.
        api_key: RITS API key. Defaults to the ``RITS_API_KEY`` environment
            variable.
        base_url: Optional override of the OpenAI-compatible base URL. Defaults
            to the RITS endpoint derived from the model name.

    Raises:
        ValueError: If no API key is passed and ``RITS_API_KEY`` is not set.
    """

    _RITS_ROOT = (
        "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com"
    )

    def __init__(self, model_name_or_path, parameters=None, api_key=None, base_url=None):
        import os

        self.model_name_or_path = model_name_or_path
        # Copy instead of aliasing: avoids a shared mutable default and keeps the
        # engine's request options independent of the caller's dict.
        self.parameters = dict(parameters) if parameters else {}

        # Never ship a hard-coded placeholder as the credential; resolve a real key.
        api_key = api_key or os.environ.get("RITS_API_KEY")
        if not api_key:
            raise ValueError(
                "No RITS API key found: pass api_key=... or set the "
                "RITS_API_KEY environment variable."
            )

        if base_url is None:
            model_name_for_endpoint = (
                model_name_or_path.split("/")[-1].lower().replace(".", "-")
            )
            base_url = f"{self._RITS_ROOT}/{model_name_for_endpoint}/v1"

        self.model = OpenAI(
            api_key=api_key,
            base_url=base_url,
            # The key is also sent in a custom `RITS_API_KEY` header
            # (presumably what the RITS gateway checks).
            default_headers={"RITS_API_KEY": api_key},
        )

    def chat(
        self,
        messages,
        response_format=None,
        postprocessors=None,
    ) -> "List[TextGenerationInferenceOutput]":
        """Run one chat completion per conversation in ``messages``, in parallel.

        Args:
            messages: A list of conversations; each element is the ``messages``
                argument of a single ``chat.completions.create`` call.
            response_format: Optional JSON schema. When given, requests use a
                ``json_schema`` response format built from it.
            postprocessors: When truthy, each reply is decoded with
                ``json.loads``.

        Returns:
            One ``TextGenerationInferenceOutput`` per conversation, in the same
            order as ``messages``.
        """
        # The schema wrapper is identical for every request; build it once.
        schema_format = self._create_schema_format(response_format)

        def chat_response(conversation):
            response = self.model.chat.completions.create(
                messages=conversation,
                model=self.model_name_or_path,
                response_format=schema_format,
                **self.parameters,
            )
            return self._prepare_chat_output(response, postprocessors)

        return run_parallel(chat_response, messages)

    def _prepare_chat_output(self, response, postprocessors):
        """Convert a chat-completion response into a TextGenerationInferenceOutput.

        ``prediction`` is the first choice's message text, JSON-decoded when
        ``postprocessors`` is truthy. ``logprobs`` maps each generated token
        (whitespace-stripped) to its log-probability; when a stripped token
        repeats, the last occurrence wins. ``logprobs`` is ``None`` when the
        response carries no per-token log-probabilities.
        """
        choice = response.choices[0]
        content = choice.message.content
        prediction = json.loads(content) if postprocessors else content

        logprobs = None
        # `logprobs.content` is optional in the API schema; guard before iterating.
        if choice.logprobs is not None and choice.logprobs.content is not None:
            logprobs = {
                item.token.strip(): item.logprob for item in choice.logprobs.content
            }

        return TextGenerationInferenceOutput(prediction=prediction, logprobs=logprobs)

    def _create_schema_format(self, response_format):
        """Wrap a JSON schema in the OpenAI ``json_schema`` response-format shape.

        Returns ``None`` when ``response_format`` is falsy.
        """
        if not response_format:
            return None
        return {
            "type": "json_schema",
            "json_schema": {
                "name": "RITS_schema",
                "schema": response_format,
            },
        }


# Granite Guardian is the LLM-as-a-Judge being explained. Requests use greedy
# decoding (temperature 0.0) and ask for token log-probabilities with the top 10
# alternatives per position.
guardian_judge = RITSInferenceEngine(
    "ibm-granite/granite-guardian-3.3-8b",
    parameters=dict(logprobs=True, top_logprobs=10, temperature=0.0),
)

# General-purpose LLM handed to the Extractor and Clusterer stages of the pipeline.
llm_component = RITSInferenceEngine("meta-llama/llama-3-3-70b-instruct")

### Create and run the explanation generation pipeline

The Pipeline streamlines the local and global explanation generation process. The Extractor executes the CLoVE algorithm and generates a set of local explanations, and the Clusterer executes the GloVE algorithm and merges the local explanations into a global one.

Pass `lime=False` when creating the pipeline if no local word-based verification is needed. Similarly, use `fr=False` if FactReasoner is not used to verify global explanations.

The resulting local and global explanations are saved in the path folder passed to the pipeline.run() call.
The execution logs can be found in the logs folder.


In [5]:
# Wrap the judge model together with its task configuration.
guardian = Guardian(
    inference_engine=guardian_judge,
    config=guardian_config,
)

# Local explanation model -- only LIME and SHAP are supported.
local_expl = "LIME"
explainer_classes = {"LIME": LIME, "SHAP": SHAP}
if local_expl not in explainer_classes:
    raise ValueError("Only LIME and SHAP are supported")
local_explainer = explainer_classes[local_expl](
    dataset.dataset_name, guardian_config["label_names"], n_samples=100
)

# Extractor: produces local explanations (CLoVE) for the guardian's decisions.
extractor = Extractor(
    guardian,
    llm_component,
    guardian_config["criterion"],
    guardian_config["criterion_definition"],
    local_explainer,
)
# Clusterer: merges local explanations into a global one (GloVE).
clusterer = Clusterer(
    llm_component,
    guardian_config["criterion_definition"],
    guardian_config["label_names"],
    n_iter=10,
)
# lime=True enables local word-based verification; fr=True enables
# FactReasoner verification of the global explanation.
pipeline = Pipeline(extractor=extractor, clusterer=clusterer, lime=True, fr=True)

expl = pipeline.run(dataset)

# NOTE(review): if `expl.print()` prints and returns None, this also prints "None".
print(expl.print())

100%|██████████| 1/1 [00:00<00:00,  1.23it/s]
100%|██████████| 100/100 [00:04<00:00, 22.38it/s]
100%|██████████| 1/1 [00:01<00:00,  1.02s/it]
100%|██████████| 1/1 [00:00<00:00,  1.70it/s]
100%|██████████| 1/1 [00:02<00:00,  2.09s/it]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.80it/s]
100%|██████████| 100/100 [00:04<00:00, 21.73it/s]
100%|██████████| 1/1 [00:01<00:00,  1.06s/it]
100%|██████████| 1/1 [00:00<00:00,  1.65it/s]
100%|██████████| 1/1 [00:00<00:00,  2.01it/s]
100%|██████████| 1/1 [00:00<00:00,  2.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
100%|██████████| 1/1 [00:00<00:00,  1.59it/s]
100%|██████████| 1/1 [00:0

RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 74235.47prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.48it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 39199.10prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.51it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 24966.10prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 73584.28prompts/s]
100%|██████████| 1/1 [00:01<00:00,  1.63s/it]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 64527.75prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.11it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 27685.17prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.99it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 30727.50prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.98it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 25575.02prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 58661.59prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 51150.05prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 22192.08prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 25040.62prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.19it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 32896.50prompts/s]
100%|██████████| 1/1 [00:06<00:00,  6.28s/it]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 53773.13prompts/s]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 28926.23prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.40it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 28532.68prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.77it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 31300.78prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 31300.78prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.45it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 31184.42prompts/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]


RITS = True Model = llama-3.1-8b-instruct
[NLIExtractor] Using LLM on RITS: llama-3.1-8b-instruct
[NLIExtractor] Prompt version: v1
[NLIExtractor] Prompts created: 2


NLI: 100%|██████████| 2/2 [00:00<00:00, 31300.78prompts/s]

{'rules': [{'prediction': 1, 'if_clause': 'involves race issues', 'despite_clauses': ['provides factual information']}, {'prediction': 1, 'if_clause': 'suggests eating animal waste', 'despite_clauses': ['provides factual information']}, {'prediction': 1, 'if_clause': 'promotes vigilantism', 'despite_clauses': ['provides factual information']}]}



