# Research aside - Hallucination detection & LLM explainability - Part 2

NB: Loading the model may require a T4 GPU and a Colab high-RAM runtime.


## Setup - Install dependencies

In [None]:
!pip install datasets # &> /dev/null
!pip install sentence-transformers # &> /dev/null

## Declare LLMAgent object

We use a single handler class to hold the model and our token attributes to avoid repeating boilerplate.

This object has some additional functionality inside `test_on_triviaQA()` to automatically evaluate answers based on cosine similarity, supported by an attribute on LLMAgent containing an embedding model.

LLM used: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

Embedding model used: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1

In [None]:
"""
Python script containing explainability and visualisation
utilities for large language models.
"""

import io
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from PIL import Image
import json

# Configure the root logger at INFO level so the progress messages that
# LLMAgent emits via logging.info(...) are shown.
logging.basicConfig(level=logging.INFO)


class LLMAgent:
    """
    Class encapsulating explainability and visualisation
    utilities for large language models.

    An instance holds a causal language model with its processor, the
    chat-template strings used to build prompts, and a sentence-embedding
    model used by check_answer() to compare generated and reference answers.
    """

    def __init__(
        self,
        hf_api_key: str = None,
        embedding_model_id: str = "mixedbread-ai/mxbai-embed-large-v1",
        embedding_instruction: str = "Represent this sentence for searching relevant passages:",
        embedding_sameness_threshold: float = 0.67,
        model_id: str = "microsoft/Phi-3-mini-4k-instruct",
        use_gpu: bool = False,
        prompt_suffix="<|end|>\n",
        user_prompt_start="<|user|>\n",
        assistant_prompt_start="<|assistant|>\n",
        system_prompt_start="<|system|>\n",
        system_prompt="You are a helpful AI assistant that provides concise answers.\n",
        end_token="<|end|>",
        eot_token="<|endoftext|>",
    ):
        """
        Initialise the LLMAgent object.

        Loads the embedding model (on CPU), then the language model and its
        processor, and stores the prompt-template strings.

        Args:
            hf_api_key: str: Hugging Face access token (None for public models).
            embedding_model_id: str: Sentence-transformers model used by check_answer().
            embedding_instruction: str: Instruction prepended to the query before embedding.
            embedding_sameness_threshold: float: Cosine similarity above which two
                answers count as the same.
            model_id: str: Hugging Face id of the causal language model.
            use_gpu: bool: Use CUDA if True; otherwise MPS when available, else CPU.
            prompt_suffix, user_prompt_start, assistant_prompt_start,
            system_prompt_start, system_prompt: str: Chat-template pieces used to
                build the formatted prompt.
            end_token, eot_token: str: End-of-turn / end-of-text token strings;
                trailing occurrences are stripped from generated responses.
        """
        logging.info("LLMAgent initialising.")
        self.embedding_model = SentenceTransformer(embedding_model_id, device="cpu")
        self.embedding_instruction = embedding_instruction
        self.embedding_sameness_threshold = embedding_sameness_threshold
        logging.info(
            "LLMAgent downloading/ensuring presence of large language model: %s.",
            model_id,
        )
        if use_gpu:
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            # torch.has_mps is deprecated and only reports build-time support;
            # is_available() also checks that MPS can actually be used.
            self.device = "mps"
        else:
            self.device = "cpu"
        logging.info("LLMAgent sending model to device: %s", self.device)
        self.processor = AutoProcessor.from_pretrained(model_id, token=hf_api_key)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # NOTE: torch_dtype=torch.bfloat16 can be passed here to reduce memory.
            token=hf_api_key,
            output_hidden_states=True,
            output_attentions=True,
            do_sample=True,
        ).to(self.device)
        self.prompt_suffix = prompt_suffix
        self.user_prompt_start = user_prompt_start
        self.assistant_prompt_start = assistant_prompt_start
        self.end_token = end_token
        self.eot_token = eot_token
        self.system_prompt_start = system_prompt_start
        self.system_prompt = system_prompt
        logging.info("LLMAgent initialised.")

    def _format_prompt(self, prompt, response_prefix=None):
        """Wrap *prompt* in the system/user/assistant chat template, optionally
        appending *response_prefix* after the assistant-turn marker."""
        formatted_prompt = (
            f"{self.system_prompt_start}{self.system_prompt}{self.prompt_suffix}"
            f"{self.user_prompt_start}{prompt}{self.prompt_suffix}"
            f"{self.assistant_prompt_start}"
        )
        if response_prefix:
            formatted_prompt += response_prefix
        return formatted_prompt

    @staticmethod
    def _sampling_kwargs(temperature):
        """Return generate() kwargs: greedy decoding for temperature <= 0,
        otherwise sampling at that temperature. Temperature is omitted for
        greedy decoding, where it is ignored and only triggers a warning."""
        if temperature > 0.0:
            return {"do_sample": True, "temperature": temperature}
        return {"do_sample": False}

    def _stop_token_ids(self):
        """Collect token ids that can terminate a response: the model's
        generation-config EOS id(s) plus the single-token ids of end_token
        and eot_token."""
        stop_ids = set()
        generation_config = getattr(self.model, "generation_config", None)
        eos_token_id = getattr(generation_config, "eos_token_id", None)
        if isinstance(eos_token_id, int):
            stop_ids.add(eos_token_id)
        elif eos_token_id is not None:
            stop_ids.update(eos_token_id)
        for token in (self.end_token, self.eot_token):
            token_ids = self.processor.encode(token, add_special_tokens=False)
            if len(token_ids) == 1:
                stop_ids.add(token_ids[0])
        return stop_ids

    @staticmethod
    def _strip_trailing_tokens(token_ids, stop_ids):
        """Return *token_ids* with any trailing run of ids in *stop_ids* removed."""
        end = len(token_ids)
        while end > 0 and int(token_ids[end - 1]) in stop_ids:
            end -= 1
        return token_ids[:end]

    @staticmethod
    def _figure_to_png_bytes(transparent=None):
        """Lay out the current matplotlib figure, render it to PNG bytes and
        close it. transparent=None defers to matplotlib's rcParams default."""
        plt.tight_layout()
        buffer = io.BytesIO()
        plt.savefig(buffer, format="png", transparent=transparent)
        plt.close()
        return buffer.getvalue()

    @staticmethod
    def _draw_activation_tile(ax, feature_map, label):
        """Draw one activation vector on *ax* as a zero-padded square heatmap
        with *label* overlaid in the centre."""
        n_activations = len(feature_map)
        heatmap_size = int(np.ceil(np.sqrt(n_activations)))
        padded_activations = np.pad(
            feature_map, (0, heatmap_size**2 - n_activations), mode="constant"
        )
        sns.heatmap(
            padded_activations.reshape(heatmap_size, heatmap_size),
            cmap="mako_r",
            cbar=False,
            linecolor="lightgrey",
            linewidths=0.2,
            xticklabels=False,
            yticklabels=False,
            ax=ax,
        )
        ax.text(
            0.5,
            0.5,
            label,
            fontsize=80,
            color="white",
            ha="center",
            va="center",
            alpha=0.6,
            transform=ax.transAxes,
            weight="bold",
        )

    def generate_with_response_dict(
        self, prompt: str, max_tokens: int = 200, temperature=0.0
    ):
        """
        Method to inference on the model and return the full response dict,
        containing attentions, hidden states, and the generated response.

        Args:
            prompt: str: The prompt to generate a response for.
            max_tokens: int: The maximum number of tokens to generate.
            temperature: float: Sampling temperature; <= 0 means greedy decoding.

        Returns:
            dict: The full response dict containing attentions, hidden states, and the generated response.
        """
        logging.info("LLMAgent generating response for: %s.", prompt)
        inputs = self.processor(self._format_prompt(prompt), return_tensors="pt").to(
            self.device
        )
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            return_dict_in_generate=True,
            # The visualise_* methods read hidden states; request them explicitly
            # rather than relying on the model config.
            output_hidden_states=True,
            **self._sampling_kwargs(temperature),
        )
        logging.info("LLMAgent response generated.")
        return outputs

    def visualise_average_activations(self, outputs):
        """
        Method to visualise average activations per layer as a heatmap.

        Args:
            outputs: dict: Output of generate() with hidden states.

        Returns:
            bytes: PNG image bytes (rows: token positions, columns: layers).
        """
        logging.info("LLMAgent visualising average activations for sequence.")
        num_layers = len(outputs.hidden_states[0])
        # hidden_states holds one tuple per generation step: the first covers
        # the whole prompt, later ones a single new token. Concatenate them per
        # layer to get one mean activation per processed position.
        average_activations = torch.stack(
            [
                torch.cat(
                    [
                        step_states[layer][0].mean(dim=-1)
                        for step_states in outputs.hidden_states
                    ]
                )
                for layer in range(num_layers)
            ],
            dim=1,
        )
        # The last generated token is never fed back through the model, so only
        # label the positions that actually have hidden states.
        tokens = [
            self.processor.decode(input_token)
            for input_token in outputs.sequences[0][: average_activations.shape[0]]
        ]
        figsize_x = max(12, num_layers * 0.8)
        figsize_y = max(8, len(tokens) * 0.3)

        plt.figure(figsize=(figsize_x, figsize_y))
        sns.heatmap(
            average_activations.detach().cpu().numpy(),
            cmap="mako_r",
            xticklabels=[f"Layer {i}" for i in range(num_layers)],
            yticklabels=tokens,
            linecolor="lightgrey",
            linewidths=0.2,
            cbar=True,
        )
        plt.title("Average activation per layer per token")
        image_bytes = self._figure_to_png_bytes()
        logging.info("LLMAgent visualised average activations for sequence.")
        return image_bytes

    def generate_with_probability(
        self,
        prompt,
        response_prefix=None,
        max_tokens=200,
        temperature=0.0,
        round_to=6,
    ):
        """
        Method to generate a response with information about
        token probabilities.

        Args:
            prompt: str: The prompt to generate a response for.
            response_prefix: str: String to prefix the LLM's response.
            max_tokens: int: The maximum number of tokens to generate.
            temperature: float: The temperature to use for sampling.
            round_to: int: The number of decimal places to round to.

        Returns:
            str: The generated response (trailing stop tokens removed).
            float: The total probability of the generated response.
            float: The average probability of each token in the generated response
                (0.0 if the response is empty).
            list[tuple]: The individual token probabilities.
            dict: The full generation output.
        """
        formatted_prompt = self._format_prompt(prompt, response_prefix)
        logging.info("LLMAgent generating response with probability for: %s.", prompt)
        inputs = self.processor(formatted_prompt, return_tensors="pt").to(self.device)
        generate_output = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            output_hidden_states=True,
            **self._sampling_kwargs(temperature),
        )
        prompt_length = inputs["input_ids"].shape[1]
        generated_sequence = generate_output.sequences[0, prompt_length:].cpu().numpy()
        # Remove only trailing end-of-turn/end-of-text tokens: generation may stop
        # at max_tokens without emitting any, so never drop a fixed count.
        generated_sequence = self._strip_trailing_tokens(
            generated_sequence, self._stop_token_ids()
        )
        response = self.processor.batch_decode(
            [generated_sequence],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )[0]
        logging.info("  LLMAgent generated response: '%s'", response)
        # scores[i] holds the (processed) logits used to choose generated token i.
        logits = torch.stack(generate_output.scores, dim=1).cpu()
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = [
            log_probs[0, position, token_id].item()
            for position, token_id in enumerate(generated_sequence)
        ]
        total_probability_for_gen = round(np.exp(sum(token_log_probs)), round_to)
        individual_token_probs = [
            (self.processor.decode(token_id), round(np.exp(token_log_prob), round_to))
            for token_id, token_log_prob in zip(generated_sequence, token_log_probs)
        ]
        if individual_token_probs:
            average_token_probability = round(
                sum(token_prob for _, token_prob in individual_token_probs)
                / len(individual_token_probs),
                round_to,
            )
        else:
            # Only stop tokens were generated: avoid dividing by zero.
            average_token_probability = 0.0
        return (
            response,
            total_probability_for_gen,
            average_token_probability,
            individual_token_probs,
            generate_output,
        )

    def test_on_triviaQA(self, filename="triviaQA.ndjson", n=100, middle_layer=16):
        """
        Method to test hallucination detection methods on the
        TriviaQA dataset. Streams results into an ndjson file.

        Args:
            filename: str: The filename to append the results to.
            n: int: The number of samples to test on.
            middle_layer: int: Hidden-state index used for the
                "middle_layer_*" activation fields.
        """
        logging.info(
            "LLMAgent testing hallucination detection on TriviaQA dataset with %s samples.",
            n,
        )
        dataset = load_dataset("trivia_qa", "rc", split="train", streaming=True)
        # Encode without special tokens so a tokenizer-added BOS token is not
        # mistaken for the assistant-turn marker.
        assistant_token = self.processor.encode(
            self.assistant_prompt_start, add_special_tokens=False
        )[0]
        # zip() stops after n samples (or earlier if the stream is exhausted).
        for sample_index, entry in zip(range(n), dataset):
            logging.info(
                "LLMAgent processing TriviaQA sample %s of %s.", sample_index + 1, n
            )
            question = entry["question"]
            answer = entry["answer"]["value"]
            (
                response,
                total_probability,
                average_token_probability,
                _,
                generate_output,
            ) = self.generate_with_probability(question, max_tokens=15)
            is_same, max_similarity = self.check_answer(response, [answer])
            sequence = generate_output.sequences[0].tolist()
            if assistant_token in sequence:
                # Token immediately before the assistant-turn marker.
                prompt_token_index = sequence.index(assistant_token) - 1
            else:
                # Marker not found: fall back to the final prompt token.
                prompt_token_index = generate_output.hidden_states[0][0].shape[1] - 1
            row = {
                "question": question,
                "answer": answer,
                "response": response,
                "correct": is_same,
                "similarity": max_similarity,
                "total_probability": total_probability,
                # bool() converts numpy booleans, which json cannot serialise.
                "total_probability_predicts": bool(total_probability > 0.5),
                "average_token_probability": average_token_probability,
                # Key spelling kept as-is for compatibility with existing result files.
                "average_token_propability_predicts": bool(
                    average_token_probability > 0.75
                ),
                "both_metrics_predict": bool(
                    (total_probability + average_token_probability) / 2 > 0.625
                ),
                "middle_layer_activations_prompt": self._numeric_activations(
                    generate_output, middle_layer, prompt_token_index
                ),
                "middle_layer_activations_response": self._numeric_activations(
                    generate_output, middle_layer, -1
                ),
                "final_layer_activations_prompt": self._numeric_activations(
                    generate_output, -1, prompt_token_index
                ),
                "final_layer_activations_response": self._numeric_activations(
                    generate_output, -1, -1
                ),
            }
            with open(filename, "a", encoding="utf-8") as file:
                file.write(json.dumps(row) + "\n")

    def _numeric_activations(self, outputs, layer, token):
        """Return the hidden state at (*layer*, *token*) as a JSON-friendly list."""
        return self.visualise_activation_map_at_layer_at_token(
            outputs, layer, token, return_numeric_state=True
        ).tolist()

    def check_answer(self, query, comparisons):
        """
        Method to check if an output answer is close
        enough to any provided comparison answers.

        Args:
            query: str: The first sentence.
            comparisons: list[str]: Comparison sentences (must not be empty).

        Returns:
            bool: True if the sentence matches any of the comparisons.
            float: The max similarity score.

        Raises:
            ValueError: If comparisons is empty.
        """
        if not comparisons:
            raise ValueError("check_answer requires at least one comparison.")
        if not query.startswith(self.embedding_instruction):
            query = self.embedding_instruction + query
        embeddings = self.embedding_model.encode([query, *comparisons])
        similarities = cos_sim(embeddings[0], embeddings[1:])[0].tolist()
        max_similarity = max(similarities)
        # Some comparison exceeds the threshold exactly when the best one does.
        return max_similarity > self.embedding_sameness_threshold, max_similarity

    def visualise_and_stack_layers(self, outputs, alpha=0.15, gap=100):
        """
        Method to visualise the activations of all layers and
        stack them into a single image.

        Args:
            outputs: dict: The outputs of the model.
            alpha: float: The alpha value for the image.
            gap: int: The gap between layers.

        Returns:
            bytes: The image bytes.
        """
        logging.info("LLMAgent visualising and stacking activations for sequence.")
        num_layers = len(outputs.hidden_states[0])
        logging.info("LLMAgent visualising activations for %s layers.", num_layers)
        layer_activation_images_bytes = [
            self.visualise_layer_activations(outputs, layer=i) for i in range(num_layers)
        ]
        logging.info("LLMAgent stacking activation images.")
        # Force RGBA so the alpha channel indexed below always exists.
        images = np.array(
            [
                Image.open(io.BytesIO(image_bytes))
                .convert("RGBA")
                .resize((100, 100), Image.LANCZOS)
                for image_bytes in layer_activation_images_bytes
            ]
        )
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection="3d")
        fig.patch.set_alpha(0)
        ax.patch.set_alpha(0)
        num_images, height, width, _ = images.shape
        # The pixel grid is the same for every layer image.
        x, y = np.meshgrid(np.arange(width), np.arange(height))

        for i, img in enumerate(images):
            z = np.full_like(x, i * gap)
            facecolors = img / 255.0
            facecolors[..., 3] *= alpha
            ax.plot_surface(
                x, y, z, rstride=1, cstride=1, facecolors=facecolors, shade=False
            )

        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        ax.set_zlim(0, num_images * gap)
        ax.view_init(elev=30, azim=30)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        return self._figure_to_png_bytes()

    def visualise_layer_activations(self, outputs, layer=0):
        """
        Method to visualise the per-neuron activations for a given layer.

        Args:
            outputs: dict: The outputs of the model.
            layer: int: The layer to visualise.

        Returns:
            bytes: The image bytes.
        """
        logging.info(
            "LLMAgent visualising activations for layer %s for sequence.", layer
        )
        tokens = [
            self.processor.decode(input_token) for input_token in outputs.sequences[0]
        ]
        # One row per processed position: all prompt tokens from the first
        # generation step, then one row per subsequent step.
        layer_states = (
            torch.cat(
                [step_states[layer][0] for step_states in outputs.hidden_states], dim=0
            )
            .detach()
            .cpu()
            .numpy()
        )
        grid_size = int(np.ceil(np.sqrt(len(layer_states))))
        plt.figure(figsize=(100, 100))
        plt.gca().patch.set_alpha(0)

        for idx, feature_map in enumerate(layer_states):
            ax = plt.subplot(grid_size, grid_size, idx + 1)
            self._draw_activation_tile(ax, feature_map, tokens[idx])

        image_bytes = self._figure_to_png_bytes(transparent=True)
        logging.info("LLMAgent visualised activations for layer %s.", layer)
        return image_bytes

    def visualise_activation_map_at_layer_at_token(
        self, outputs, layer, token, return_numeric_state=True
    ):
        """
        Method to visualise the activation map at a given layer
        and token.

        Args:
            outputs: dict: The outputs of the model.
            layer: int: The layer to visualise.
            token: int: Index of the processed position to visualise (negative
                indices count from the last position that has hidden states).
            return_numeric_state: bool: Whether to return the raw state instead of image bytes.

        Returns:
            bytes: The image bytes. | numpy.ndarray: The numeric activation state.
        """
        logging.info(
            "LLMAgent visualising activations for layer %s and token %s.",
            layer,
            token,
        )
        layer_states = torch.cat(
            [step_states[layer][0] for step_states in outputs.hidden_states], dim=0
        )
        token_feature_map = layer_states[token].detach().cpu().numpy()
        if return_numeric_state:
            return token_feature_map

        # Only label positions that have hidden states (the final generated
        # token is never fed back), so negative indices pick the matching token.
        tokens = [
            self.processor.decode(input_token)
            for input_token in outputs.sequences[0][: layer_states.shape[0]]
        ]
        plt.figure(figsize=(10, 10))
        plt.gca().patch.set_alpha(0)
        ax = plt.subplot(1, 1, 1)
        self._draw_activation_tile(ax, token_feature_map, tokens[token])

        image_bytes = self._figure_to_png_bytes(transparent=True)
        logging.info(
            "LLMAgent visualised activations for layer %s and token %s.", layer, token
        )
        return image_bytes

# Instantiate object and load model

In [None]:
from IPython.display import clear_output, display

# Drop the reference to any previously loaded agent before building a new one,
# so re-running this cell does not hold two models at once.
agent = None
agent = LLMAgent(
    use_gpu=True,
    system_prompt="You are a helpful AI assistant that provides concise answers.",
)

# Generate 3D model state cube

In [None]:
# Run one prompt, then render every layer's activations as a stacked 3D cube.
response_dict = agent.generate_with_response_dict(
    "What is the capital of Australia?"
)
image_bytes = agent.visualise_and_stack_layers(outputs=response_dict)
image = Image.open(io.BytesIO(image_bytes))
clear_output()
display(image)


# Generate 2D layer state for the whole sequence

In [None]:
# Run one prompt, then tile the per-token activations of layer 16.
response_dict = agent.generate_with_response_dict(
    "What is the capital of Australia?"
)
image_bytes = agent.visualise_layer_activations(outputs=response_dict, layer=16)
image = Image.open(io.BytesIO(image_bytes))
clear_output()
display(image)

# Generate 2D layer state for a single token

In [None]:
response_dict = agent.generate_with_response_dict("What is the capital of Australia?")

# return_numeric_state defaults to True (raw activation array); ask for the
# rendered PNG so Image.open() below receives image bytes.
image_bytes = agent.visualise_activation_map_at_layer_at_token(
    response_dict, 16, 15, return_numeric_state=False
)

image = Image.open(io.BytesIO(image_bytes))
clear_output()
display(image)