# Research aside - Hallucination detection & LLM explainability

NB: Loading the model may require a T4 GPU and a Colab high-RAM runtime.


## Setup - Install dependencies

In [None]:
# Install the Hugging Face `datasets` package (imported below as `load_dataset`).
# The trailing comment holds a shell redirect that, if uncommented, hides pip's output.
!pip install datasets # &> /dev/null

## Declare LLMAgent object

We use a single handler class to hold the model and our token attributes to avoid repeating boilerplate.

In [None]:
"""
Python script containing explainability and visualisation
utilities for large language models.
"""

import io
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from datasets import load_dataset

# Configure the root logger at INFO level so the progress messages that
# LLMAgent emits through logging.info() are shown.
logging.basicConfig(level=logging.INFO)


class LLMAgent:
    """
    Class encapsulating explainability and visualisation
    utilities for large language models.

    The agent holds a causal language model, its processor/tokenizer and the
    chat-template fragments used to build prompts, and offers helpers to:

    * generate a response together with the raw ``generate`` output,
    * render per-layer average hidden-state activations as a PNG heatmap,
    * generate a response together with per-token probabilities,
    * run the probability-based generation over TriviaQA samples.
    """

    def __init__(
        self,
        hf_api_key: str = None,
        model_id: str = "microsoft/Phi-3-mini-4k-instruct",
        use_gpu: bool = False,
        prompt_suffix="<|end|>\n",
        user_prompt_start="<|user|>\n",
        assistant_prompt_start="<|assistant|>\n",
        system_prompt_start="<|system|>\n",
        system_prompt="You are a helpful AI assistant.\n",
        end_token="<|end|>",
        eot_token="<|endoftext|>",
    ):
        """
        Initialise the LLMAgent object.

        Args:
            hf_api_key: str: Optional Hugging Face token forwarded to ``from_pretrained``.
            model_id: str: Identifier of the model and processor to load.
            use_gpu: bool: Prefer CUDA. If CUDA is unavailable the agent falls
                back to MPS (when available) or the CPU instead of failing.
            prompt_suffix: str: Text appended after the system and user turns.
            user_prompt_start: str: Marker that opens the user turn.
            assistant_prompt_start: str: Marker that opens the assistant turn.
            system_prompt_start: str: Marker that opens the system turn.
            system_prompt: str: The system instruction text.
            end_token: str: Text of the end-of-turn token, stripped from responses.
            eot_token: str: Text of the end-of-text token, stripped from responses.
        """
        logging.info(
            "LLMAgent downloading/ensuring presence of large language model: %s.",
            model_id,
        )
        self.device = self._select_device(use_gpu)
        logging.info("LLMAgent sending model to device: %s", self.device)
        self.processor = AutoProcessor.from_pretrained(model_id, token=hf_api_key)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # Optional: load in reduced precision to save memory.
            # torch_dtype=torch.bfloat16,
            token=hf_api_key,
            output_hidden_states=True,
            output_attentions=True,
            do_sample=True,
        ).to(self.device)
        self.prompt_suffix = prompt_suffix
        self.user_prompt_start = user_prompt_start
        self.assistant_prompt_start = assistant_prompt_start
        self.end_token = end_token
        self.eot_token = eot_token
        self.system_prompt_start = system_prompt_start
        self.system_prompt = system_prompt
        logging.info("LLMAgent initialisaed.")

    @staticmethod
    def _select_device(use_gpu: bool) -> str:
        """Pick the torch device name: CUDA if requested and usable, else MPS, else CPU."""
        if use_gpu:
            if torch.cuda.is_available():
                return "cuda"
            logging.warning(
                "LLMAgent: use_gpu=True but CUDA is not available; falling back."
            )
        # `torch.has_mps` is deprecated; query the backend module instead and
        # tolerate torch builds that do not ship it at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _sampling_kwargs(temperature) -> dict:
        """Build the sampling arguments for ``generate``; greedy when temperature is not positive."""
        if temperature is not None and temperature > 0.0:
            return {"do_sample": True, "temperature": temperature}
        # Temperature is a sampling-only parameter, so it is not forwarded
        # when decoding greedily.
        return {"do_sample": False}

    def _format_prompt(self, prompt: str, response_prefix: str = None) -> str:
        """Assemble the system, user and assistant turns, optionally pre-filling the response."""
        formatted_prompt = (
            f"{self.system_prompt_start}{self.system_prompt}{self.prompt_suffix}"
            f"{self.user_prompt_start}{prompt}{self.prompt_suffix}"
            f"{self.assistant_prompt_start}"
        )
        if response_prefix:
            formatted_prompt = formatted_prompt + response_prefix
        return formatted_prompt

    def _strip_trailing_special_tokens(self, token_ids: list) -> list:
        """Drop trailing ids that decode to the configured end / end-of-text tokens."""
        stop_tokens = {self.end_token, self.eot_token}
        end = len(token_ids)
        while (
            end > 0
            and self.processor.decode(token_ids[end - 1]).strip() in stop_tokens
        ):
            end -= 1
        return token_ids[:end]

    def generate_with_response_dict(
        self, prompt: str, max_tokens: int = 200, temperature=0.0
    ):
        """
        Method to inference on the model and return the full response dict,
        containing attentions, hidden states, and the generated response.

        Args:
            prompt: str: The prompt to generate a response for.
            max_tokens: int: The maximum number of tokens to generate.
            temperature: float: Sampling temperature; values <= 0 decode greedily.

        Returns:
            dict: The full response dict containing attentions, hidden states, and the generated response.
        """
        logging.info("LLMAgent generating response for: %s.", prompt)
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.processor(formatted_prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            return_dict_in_generate=True,
            **self._sampling_kwargs(temperature),
        )
        logging.info("LLMAgent response generated.")
        return outputs

    def visualise_average_activations(self, outputs):
        """
        Method to visualise average activations per layer as a heatmap.

        Args:
            outputs: The object returned by ``generate_with_response_dict``; its
                ``sequences`` and ``hidden_states`` attributes are read.

        Returns:
            bytes: The rendered heatmap encoded as a PNG image.

        Raises:
            ValueError: If ``outputs`` carries no hidden states.
        """
        logging.info("LLMAgent visualising average activations for sequence.")
        hidden_states = getattr(outputs, "hidden_states", None)
        if not hidden_states:
            raise ValueError(
                "outputs contains no hidden states; generate with hidden states enabled."
            )

        # hidden_states holds one entry per generation step, each a sequence of
        # per-layer tensors. Collect the per-position means for every layer
        # and concatenate once at the end instead of re-concatenating per step.
        per_layer_chunks = [
            [layer_state.squeeze(0).mean(dim=-1)] for layer_state in hidden_states[0]
        ]
        for step_states in hidden_states[1:]:
            for chunks, layer_state in zip(per_layer_chunks, step_states):
                chunks.append(layer_state.squeeze(0).mean(dim=-1))
        average_activations = torch.stack(
            [torch.cat(chunks) for chunks in per_layer_chunks], dim=1
        )

        # The final generated token is never fed back through the model, so
        # the sequence can be longer than the number of activation rows; keep
        # only the labels that have a row.
        tokens = [
            self.processor.decode(input_token) for input_token in outputs.sequences[0]
        ][: average_activations.shape[0]]

        layer_count = len(hidden_states[0])
        figsize_x = max(12, layer_count * 0.8)
        figsize_y = max(8, len(tokens) * 0.3)

        figure = plt.figure(figsize=(figsize_x, figsize_y))
        try:
            sns.heatmap(
                # .float() keeps reduced-precision tensors convertible to numpy.
                average_activations.detach().float().cpu().numpy(),
                cmap="mako_r",
                xticklabels=[f"Layer {i}" for i in range(layer_count)],
                yticklabels=tokens,
                linecolor="lightgrey",
                linewidths=0.2,
                cbar=True,
            )
            plt.title("Average activation per layer per token")
            plt.tight_layout()
            buffer = io.BytesIO()
            plt.savefig(buffer, format="png")
        finally:
            # Always release the figure, even when plotting fails.
            plt.close(figure)
        image_bytes = buffer.getvalue()
        logging.info("LLMAgent visualised average activations for sequence.")
        return image_bytes

    def generate_with_probability(
        self,
        prompt,
        response_prefix=None,
        max_tokens=200,
        temperature=0.0,
        round_to=6,
    ):
        """
        Method to generate a response with information about
        token probabilities.

        Trailing end / end-of-text tokens are removed from the response and
        from the probability statistics; all other generated tokens are kept,
        including when generation stops because ``max_tokens`` was reached.

        Args:
            prompt: str: The prompt to generate a response for.
            response_prefix: str: String to prefix the LLM's response.
            max_tokens: int: The maximum number of tokens to generate.
            temperature: float: The temperature to use for sampling.
            round_to: int: The number of decimal places to round to.

        Returns:
            str: The generated response.
            float: The total probability of the generated response.
            float: The average probability of each token in the generated
                response (0.0 when the response contains no tokens).
            list[tuple]: The individual token probabilities.
        """
        formatted_prompt = self._format_prompt(prompt, response_prefix)
        logging.info("LLMAgent generating response with probability for: %s.", prompt)
        inputs = self.processor(formatted_prompt, return_tensors="pt").to(self.device)
        generate_output = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            **self._sampling_kwargs(temperature),
        )
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = generate_output.sequences[0, prompt_length:].tolist()
        # Only strip tokens that really are stop tokens; blindly dropping a
        # fixed number of tokens would truncate answers cut off by max_tokens.
        generated_ids = self._strip_trailing_special_tokens(generated_ids)
        response = self.processor.batch_decode(
            [generated_ids],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )[0]
        logging.info("  LLMAgent generated response: '%s'", response)

        # scores holds one logits tensor per generated step; stack them along
        # a new step dimension and normalise in float32.
        logits = torch.stack(generate_output.scores, dim=1).float().cpu()
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = [
            log_probs[0, step, token_id].item()
            for step, token_id in enumerate(generated_ids)
        ]
        total_probability_for_gen = round(
            float(np.exp(sum(token_log_probs))), round_to
        )
        individual_token_probs = [
            (self.processor.decode(token_id), round(float(np.exp(log_prob)), round_to))
            for token_id, log_prob in zip(generated_ids, token_log_probs)
        ]
        if individual_token_probs:
            average_token_probability = round(
                sum(token_prob for _, token_prob in individual_token_probs)
                / len(individual_token_probs),
                round_to,
            )
        else:
            # No content tokens were generated; avoid dividing by zero.
            average_token_probability = 0.0
        return (
            response,
            total_probability_for_gen,
            average_token_probability,
            individual_token_probs,
        )

    def test_on_triviaQA(self, n=100):
        """
        Method to test hallucination detection methods on the
        TriviaQA dataset.

        Args:
            n: int: The number of samples to test on.

        Returns:
            pd.DataFrame: The results of the test, one row per sample with the
                columns question, answer, response, total_probability,
                average_token_probability and individual_token_probs.
        """
        logging.info(
            "LLMAgent testing hallucination detection on TriviaQA dataset with %s samples.",
            n,
        )
        column_headers = [
            "question",
            "answer",
            "response",
            "total_probability",
            "average_token_probability",
            "individual_token_probs",
        ]
        dataset = load_dataset("trivia_qa", "rc", split="train", streaming=True)
        iterator = iter(dataset)
        # Collect plain dicts and build the frame once; growing a DataFrame
        # with concat inside the loop copies all previous rows every time.
        rows = []
        for i in range(n):
            entry = next(iterator, None)
            if entry is None:
                logging.warning(
                    "LLMAgent: TriviaQA stream ended after %s of %s samples.", i, n
                )
                break
            logging.info("LLMAgent processing TriviaQA sample %s of %s.", i + 1, n)
            question = entry["question"]
            answer = entry["answer"]["value"]
            (
                response,
                total_probability,
                average_token_probability,
                individual_token_probs,
            ) = self.generate_with_probability(question)
            rows.append(
                {
                    "question": question,
                    "answer": answer,
                    "response": response,
                    "total_probability": total_probability,
                    "average_token_probability": average_token_probability,
                    "individual_token_probs": individual_token_probs,
                }
            )
        return pd.DataFrame(rows, columns=column_headers)


# Instantiate object and load model

In [None]:
# System instruction that asks the model for short, factual answers.
concise_system_prompt = "You are a helpful AI assistant that provides correct information as concisely as possible.\n"

# Create the agent with GPU use requested and the custom system prompt.
lm = LLMAgent(use_gpu=True, system_prompt=concise_system_prompt)

# Show generation with probabilities

Generating a completion from a Transformers backend with full logit information.

In [None]:
from IPython.display import clear_output

# Ask one question and unpack the four values returned by the agent.
generation = lm.generate_with_probability("What is the capital of Australia?")
response, total_probability, average_token_probability, individual_token_probs = (
    generation
)

# Clear the cell output produced during generation before printing the summary.
clear_output()

print(f"Response generated:\n{response}")
print(f"Total response probability:\n{total_probability}")
print(f"Average token probability:\n{average_token_probability}")
print("Individual token probabilities:")
# One line per token, with the token text padded to a 20-character column.
for token_text, token_probability in individual_token_probs:
    print(f"{token_text.ljust(20)} : {token_probability}")

# Run tests on the TriviaQA dataset

In [None]:
# Destination file for the per-question results table.
results_csv_path = "triviaQA_results.csv"

# Run the agent over TriviaQA samples with the default sample count,
# then write the table to CSV without the index column.
results = lm.test_on_triviaQA()
results.to_csv(results_csv_path, index=False)

# Generate hidden layer activation heatmaps

In [None]:
from IPython.display import Image, display

# Location where the rendered heatmap is written and then read back for display.
heatmap_path = "./average_activations.png"

# Generate a response, then render its per-layer activations as PNG bytes.
generated_response = lm.generate_with_response_dict("What is the capital of Australia?")
image_bytes = lm.visualise_average_activations(generated_response)

with open(heatmap_path, "wb") as heatmap_file:
    heatmap_file.write(image_bytes)

display(Image(heatmap_path))