<a href="https://colab.research.google.com/github/erikrosen01/LLM-saliency-map/blob/main/LLM_saliency_map.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook creates and visualizes gradient-based saliency maps for LLMs, using GPT-2 as the example model.

In [1]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from IPython.display import display, clear_output
import ipywidgets as widgets

## Main class

In [18]:
class InteractiveSaliencyNotebook:
    """Interactive ipywidgets UI for exploring GPT-2 token saliency.

    The UI has two tabs, 'Generate Mode' (sample a continuation for a prompt)
    and 'Analyze Text Mode' (load user-supplied text). In both, every token of
    the current text is rendered as a clickable button. Clicking toggles the
    token's selection; 'Analyze Selected Tokens' then plots, for each input
    token, the norm of the gradient of the selected tokens' final hidden-state
    norm with respect to that token's input embedding.
    """

    def __init__(self):
        """Initialize GPT-2 model and tokenizer, then build the widgets."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        # Set requires_grad for embedding layer
        for param in self.model.transformer.wte.parameters():
            param.requires_grad = True

        # Track selected tokens and their positions
        self.selected_tokens = {}  # {position: token_text}

        # Text currently shown as token buttons; empty until text is
        # generated or loaded.
        self.current_text = ""

        # Buttons created by the most recent create_token_buttons() call, kept
        # so a selection can be cleared without rebuilding (or regenerating)
        # the text.
        self._token_buttons = []

        # Initialize widgets
        self.setup_widgets()

    def setup_widgets(self):
        """Create and arrange IPython widgets.

        The analyze/clear buttons and both output areas are single widget
        instances that are placed in both tabs.
        """
        # Tab widget for switching between modes
        self.mode_tabs = widgets.Tab()

        # Generate mode widgets
        self.prompt_area = widgets.Textarea(
            description='Prompt:',
            placeholder='Enter your prompt here...',
            layout=widgets.Layout(width='90%', height='100px')
        )

        self.generate_button = widgets.Button(
            description='Generate Response',
            button_style='primary',
            layout=widgets.Layout(width='200px')
        )
        self.generate_button.on_click(self.generate_and_display)

        # Analyze text mode widgets
        self.text_area = widgets.Textarea(
            description='Text:',
            placeholder='Enter text to analyze...',
            layout=widgets.Layout(width='90%', height='100px')
        )

        self.load_text_button = widgets.Button(
            description='Load Text',
            button_style='primary',
            layout=widgets.Layout(width='200px')
        )
        self.load_text_button.on_click(self.load_and_display_text)

        # Common buttons
        self.analyze_button = widgets.Button(
            description='Analyze Selected Tokens',
            button_style='success',
            layout=widgets.Layout(width='200px')
        )
        self.analyze_button.on_click(self.analyze_selected_tokens)

        self.clear_button = widgets.Button(
            description='Clear Selection',
            button_style='warning',
            layout=widgets.Layout(width='200px')
        )
        self.clear_button.on_click(self.clear_selection)

        # Output areas
        self.token_output = widgets.Output()
        self.viz_output = widgets.Output()

        # Create tab contents
        generate_tab = widgets.VBox([
            self.prompt_area,
            widgets.HBox([self.generate_button, self.analyze_button, self.clear_button]),
            self.token_output,
            self.viz_output
        ])

        analyze_tab = widgets.VBox([
            self.text_area,
            widgets.HBox([self.load_text_button, self.analyze_button, self.clear_button]),
            self.token_output,
            self.viz_output
        ])

        # Set up tabs
        self.mode_tabs.children = [generate_tab, analyze_tab]
        self.mode_tabs.set_title(0, 'Generate Mode')
        self.mode_tabs.set_title(1, 'Analyze Text Mode')

        # Main container
        self.main_container = self.mode_tabs

    def load_and_display_text(self, _):
        """Load the text-area contents and display them as token buttons.

        Any previous selection and visualization are discarded, because the
        stored positions refer to the tokens of the previous text.
        """
        text = self.text_area.value
        with self.token_output:
            clear_output()
            if text:
                self.selected_tokens.clear()
                self.current_text = text
                display(self.create_token_buttons(text))
        if text:
            with self.viz_output:
                clear_output()

    def generate_response(self, prompt, max_length=100):
        """Sample a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: Text to continue.
            max_length: Upper bound passed to ``generate``; it counts the
                prompt tokens as well as the generated ones.

        Returns:
            The decoded continuation without the prompt. An empty string is
            returned when the prompt alone already reaches ``max_length``
            tokens, since there is no room left to generate.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        prompt_length = input_ids.shape[1]
        if prompt_length >= max_length:
            return ""

        output_ids = self.model.generate(
            input_ids,
            # Single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7
        )

        # Decode only the tokens after the prompt. Slicing the decoded string
        # by len(prompt) is wrong whenever decoding does not reproduce the
        # prompt character-for-character.
        new_token_ids = output_ids[0][prompt_length:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True)

    def generate_and_display(self, _):
        """Generate a response for the prompt and display token buttons.

        Any previous selection and visualization are discarded, because the
        stored positions refer to the tokens of the previous text.
        """
        prompt = self.prompt_area.value
        with self.token_output:
            clear_output()
            if prompt:
                response = self.generate_response(prompt)
                self.selected_tokens.clear()
                self.current_text = prompt + response
                print("Generated response:", response, "\n")
                display(self.create_token_buttons(self.current_text))
        if prompt:
            with self.viz_output:
                clear_output()

    def compute_saliency(self, text, target_tokens_with_positions):
        """Compute combined saliency scores for multiple target tokens.

        For every valid target position, the norm of the final hidden state at
        that position is differentiated with respect to the input embeddings;
        the per-token L2 norms of those gradients are summed over all targets.

        Args:
            text: Text to tokenize and analyze.
            target_tokens_with_positions: Mapping ``{position: token_text}``.
                Only the positions are used; positions outside the tokenized
                text are ignored.

        Returns:
            Tuple ``(tokens, combined_saliency, last_target_position,
            target_positions)``: the decoded tokens, a 1-D numpy array with
            one score per token, the largest valid target position, and the
            set of valid target positions.

        Raises:
            ValueError: If no target position lies inside the tokenized text.
        """
        input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device)
        sequence_length = input_ids.shape[1]

        valid_positions = sorted(
            position for position in target_tokens_with_positions
            if 0 <= position < sequence_length
        )
        if not valid_positions:
            raise ValueError("No valid target tokens found in input text")

        # Detach so the embeddings become a leaf tensor: gradients are then
        # taken with respect to the embedding vectors themselves.
        embeddings = self.model.transformer.wte(input_ids).detach().requires_grad_(True)

        # One forward pass serves every target position.
        hidden_states = self.model.transformer(inputs_embeds=embeddings).last_hidden_state

        combined = torch.zeros(
            sequence_length, dtype=embeddings.dtype, device=embeddings.device
        )
        final_index = len(valid_positions) - 1
        for index, position in enumerate(valid_positions):
            # autograd.grad returns the gradient directly and leaves the
            # .grad buffers of the model parameters untouched. The graph is
            # kept alive until the last target has been processed.
            (embedding_grad,) = torch.autograd.grad(
                hidden_states[0, position].norm(),
                embeddings,
                retain_graph=index < final_index,
            )
            combined += embedding_grad[0].norm(dim=-1)

        combined_saliency = combined.detach().cpu().numpy()
        tokens = [self.tokenizer.decode(token_id) for token_id in input_ids[0].tolist()]
        return tokens, combined_saliency, valid_positions[-1], set(valid_positions)

    def create_token_buttons(self, text, max_width=800):
        """Create selectable buttons for each token of ``text``.

        Args:
            text: Text to tokenize.
            max_width: Estimated pixel budget per row, used to decide where to
                wrap the buttons onto a new row.

        Returns:
            A ``VBox`` holding one ``HBox`` per row of token buttons.
        """
        token_ids = self.tokenizer.encode(text, return_tensors='pt')[0].tolist()
        token_buttons = []

        for pos, token_id in enumerate(token_ids):
            token_text = self.tokenizer.decode(token_id)
            button = widgets.Button(
                description=token_text,
                layout=widgets.Layout(width='auto', margin='2px')
            )

            if pos in self.selected_tokens:
                button.button_style = 'info'

            # Default arguments bind the current token/position to this button.
            button.on_click(lambda b, t=token_text, p=pos: self.toggle_token_selection(b, t, p))
            token_buttons.append(button)

        # Remember the buttons so clear_selection() can reset their styles.
        self._token_buttons = token_buttons

        # Arrange buttons in rows
        rows = []
        current_row = []
        current_width = 0

        for button in token_buttons:
            # Rough width estimate: 8 px per character plus 20 px of padding.
            button_width = len(button.description) * 8 + 20
            if current_width + button_width > max_width and current_row:
                rows.append(widgets.HBox(current_row))
                current_row = []
                current_width = 0
            current_row.append(button)
            current_width += button_width

        if current_row:
            rows.append(widgets.HBox(current_row))

        return widgets.VBox(rows)

    def toggle_token_selection(self, button, token_text, position):
        """Toggle token selection state and update the button highlight."""
        if position in self.selected_tokens:
            del self.selected_tokens[position]
            button.button_style = ''
        else:
            self.selected_tokens[position] = token_text
            button.button_style = 'info'

    def clear_selection(self, _):
        """Clear all selected tokens and the current visualization.

        The existing buttons are un-highlighted in place; the displayed text
        is left as it is (no new response is generated).
        """
        self.selected_tokens.clear()
        for button in self._token_buttons:
            button.button_style = ''
        with self.viz_output:
            clear_output()

    def analyze_selected_tokens(self, _):
        """Analyze saliency for all selected tokens and plot the result."""
        with self.viz_output:
            clear_output()
            if not self.selected_tokens:
                print("Please select at least one token to analyze")
                return

            try:
                tokens, scores, last_position, target_positions = self.compute_saliency(
                    self.current_text,
                    self.selected_tokens
                )
                # Truncate tokens and scores up to the last selected token
                tokens = tokens[:last_position + 1]
                scores = scores[:last_position + 1]
                fig = self.visualize_saliency(tokens, scores, target_positions)
                plt.show()
                # Release the figure so repeated analyses do not pile up
                # open figures.
                plt.close(fig)

            except ValueError as e:
                print(f"Error: {e}")

    def visualize_saliency(self, tokens, scores, target_positions, chunk_size=25):
        """Create visualization for multiple target tokens.

        Non-target scores are min-max normalized to [0, 1]; target tokens are
        always drawn at 1.0 with a blue outline.

        Args:
            tokens: Token strings to draw, one box per token.
            scores: Raw saliency score per token (same length as ``tokens``).
            target_positions: Positions of the selected tokens.
            chunk_size: Number of tokens drawn per row.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If there are no tokens or ``chunk_size`` is below 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be at least 1")
        if len(tokens) == 0:
            raise ValueError("No tokens to visualize")

        scores = np.asarray(scores, dtype=float)

        # Create mask for target positions
        target_mask = np.zeros(len(scores), dtype=bool)
        for pos in target_positions:
            if 0 <= pos < len(target_mask):
                target_mask[pos] = True

        # Normalize scores excluding target positions. Starting from all ones
        # keeps target tokens at maximum activation, including the case where
        # every token is a target.
        normalized_scores = np.ones_like(scores)
        non_target_scores = scores[~target_mask]
        if non_target_scores.size > 0:
            min_score = non_target_scores.min()
            max_score = non_target_scores.max()
            # The small epsilon avoids division by zero when all scores match.
            normalized_scores[~target_mask] = (
                (non_target_scores - min_score) / (max_score - min_score + 1e-10)
            )

        # Split into chunks
        num_chunks = (len(tokens) + chunk_size - 1) // chunk_size
        fig, axes = plt.subplots(num_chunks, 1,
                               figsize=(15, 2 * num_chunks),
                               squeeze=False)

        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_size
            end_idx = min(start_idx + chunk_size, len(tokens))
            ax = axes[chunk_idx, 0]

            for i, idx in enumerate(range(start_idx, end_idx)):
                score = normalized_scores[idx]

                rect = Rectangle((i, 0), 1, 1, facecolor=plt.cm.YlOrRd(score))
                ax.add_patch(rect)

                # Only highlight the specifically selected positions
                if target_mask[idx]:
                    rect.set_edgecolor('blue')
                    rect.set_linewidth(2)

                ax.text(i + 0.5, 0.5, f"{tokens[idx]}\n{score:.2f}",
                       ha='center', va='center',
                       color='white' if score > 0.5 else 'black')

            ax.set_xlim(0, chunk_size)
            ax.set_ylim(0, 1)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        # Set main title - only show "Token Saliency Analysis"
        fig.suptitle('Token Saliency Analysis', fontsize=14)

        # Lay out the token rows first: tight_layout cannot handle the
        # manually positioned colorbar axes added below.
        plt.tight_layout(rect=[0, 0, 0.9, 0.98])

        # Add colorbar
        sm = plt.cm.ScalarMappable(cmap=plt.cm.YlOrRd,
                                   norm=plt.Normalize(vmin=0.0, vmax=1.0))
        sm.set_array([])
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        fig.colorbar(sm, cax=cbar_ax, label='Saliency Score')

        return fig

    def display(self):
        """Display the interactive interface"""
        display(self.main_container)

In [21]:
def create_generator():
    """Build the saliency UI, bring the first tab to the front, and show it.

    The first tab (index 0) is the one used for generating text from a
    prompt. The created InteractiveSaliencyNotebook is returned so the caller
    can keep a handle on it.
    """
    notebook_ui = InteractiveSaliencyNotebook()
    tabs = notebook_ui.mode_tabs
    tabs.selected_index = 0
    notebook_ui.display()
    return notebook_ui

def create_analyzer():
    """Build the saliency UI, bring the second tab to the front, and show it.

    The second tab (index 1) is the one used for analyzing text the user
    supplies. The created InteractiveSaliencyNotebook is returned so the
    caller can keep a handle on it.
    """
    notebook_ui = InteractiveSaliencyNotebook()
    tabs = notebook_ui.mode_tabs
    tabs.selected_index = 1
    notebook_ui.display()
    return notebook_ui

## Use case 1
Prompt the model, then analyze its response.

In [23]:
# Build and show the UI with the generate tab selected; the instance is kept
# in `analyzer` so it stays reachable from later cells.
analyzer = create_generator()

Tab(children=(VBox(children=(Textarea(value='', description='Prompt:', layout=Layout(height='100px', width='90…

## Use case 2
Enter your own text and analyze it with a saliency map computed from the model.

In [25]:
# Build and show a second UI with the analyze-text tab selected. This rebinds
# `analyzer`, so the name now refers to this new instance.
analyzer = create_analyzer()

Tab(children=(VBox(children=(Textarea(value='', description='Prompt:', layout=Layout(height='100px', width='90…