In [13]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact, IntSlider, Layout
import ast

class AttentionVisualizer:
    """
    Load attention weights from a text file and visualize them as heatmaps.

    Every non-empty line of the file is parsed as a Python literal and turned
    into a NumPy array.  Each array is indexed as ``[0, head, query, key]``, so
    it must be 4-D; only index 0 of the leading axis is ever read.
    """

    def __init__(self, file_path):
        """
        Initialize the attention visualizer with the file path.

        Args:
            file_path (str): Path to the attention weights file

        Raises:
            ValueError: If the file contains no data rows, or if the first row
                is not a 4-D array.
        """
        self.file_path = file_path
        self.attention_data = self._load_attention_data()
        self.num_rows = len(self.attention_data)

        # Fail early with a readable message instead of an IndexError below.
        if self.num_rows == 0:
            raise ValueError(f"No attention data found in {file_path!r}")

        # Head count and sequence length are taken from the first row only.
        first_tensor = self.attention_data[0]
        if first_tensor.ndim != 4:
            raise ValueError(
                "Expected each row to be a 4-D array indexed as "
                f"[0, head, query, key]; first row has shape {first_tensor.shape}"
            )
        self.num_heads = first_tensor.shape[1]
        self.seq_len = first_tensor.shape[2]

        print("Loaded attention data:")
        print(f"  - Number of rows: {self.num_rows}")
        print(f"  - Number of attention heads: {self.num_heads}")
        print(f"  - Sequence length: {self.seq_len}")

    def _load_attention_data(self):
        """
        Load and parse the attention data from the file.

        Returns:
            list: One NumPy array per non-empty line, in file order.
        """
        attention_tensors = []

        with open(self.file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines carry no data.
                    continue
                # literal_eval only accepts Python literals, so the file
                # content is never executed as code.
                attention_tensors.append(np.array(ast.literal_eval(line)))

        return attention_tensors

    def _get_tick_positions(self, seq_len, tick_interval=50):
        """
        Generate tick positions at specified intervals.

        Args:
            seq_len (int): Length of the sequence
            tick_interval (int): Interval between ticks (must be positive)

        Returns:
            tuple: (tick_positions, tick_labels); both lists are empty when
                seq_len is not positive.

        Raises:
            ValueError: If tick_interval is not positive.
        """
        if tick_interval <= 0:
            raise ValueError(
                f"tick_interval must be a positive integer, got {tick_interval}"
            )
        if seq_len <= 0:
            # Nothing to label; avoids indexing into an empty list below.
            return [], []

        tick_positions = list(range(0, seq_len, tick_interval))
        # Always include the last position if it's not already included
        if tick_positions[-1] != seq_len - 1:
            tick_positions.append(seq_len - 1)

        tick_labels = [str(pos) for pos in tick_positions]
        return tick_positions, tick_labels

    def _draw_heatmap(self, ax, attention_matrix, cmap, tick_interval):
        """Draw one 2-D attention matrix on ``ax`` with sparse, cell-centred ticks."""
        # Derive ticks from the matrix actually being drawn, so rows whose
        # size differs from the first row are still labelled correctly.
        num_queries, num_keys = attention_matrix.shape
        x_positions, x_labels = self._get_tick_positions(num_keys, tick_interval)
        y_positions, y_labels = self._get_tick_positions(num_queries, tick_interval)

        # Seaborn's own tick labels are disabled; custom ones are set below.
        sns.heatmap(attention_matrix,
                    ax=ax,
                    cmap=cmap,
                    cbar=True,
                    square=True,
                    xticklabels=False,
                    yticklabels=False,
                    cbar_kws={'label': 'Attention Weight'})

        # Heatmap cell i spans [i, i + 1]; +0.5 puts the tick on its centre.
        ax.set_xticks([pos + 0.5 for pos in x_positions])
        ax.set_xticklabels(x_labels)
        ax.set_yticks([pos + 0.5 for pos in y_positions])
        ax.set_yticklabels(y_labels)
        ax.set_xlabel('Key Position')

    def plot_attention_heatmap(self, row_idx=0, head_idx=0, figsize=(10, 8), cmap='Blues', tick_interval=50):
        """
        Plot attention heatmap for a specific row and head.

        Prints a message and returns without plotting when either index is
        past the end of the loaded data.

        Args:
            row_idx (int): Row index to visualize
            head_idx (int): Attention head index to visualize
            figsize (tuple): Figure size for the plot
            cmap (str): Colormap for the heatmap
            tick_interval (int): Interval between axis ticks (default: 50)
        """
        if row_idx >= self.num_rows:
            print(f"Row index {row_idx} out of range. Max row index: {self.num_rows - 1}")
            return

        if head_idx >= self.num_heads:
            print(f"Head index {head_idx} out of range. Max head index: {self.num_heads - 1}")
            return

        # Extract the attention matrix for the specified row and head
        attention_matrix = self.attention_data[row_idx][0, head_idx, :, :]

        fig, ax = plt.subplots(figsize=figsize)
        self._draw_heatmap(ax, attention_matrix, cmap, tick_interval)

        ax.set_title(f'Attention Heatmap - Row {row_idx}, Head {head_idx}')
        ax.set_ylabel('Query Position')
        fig.tight_layout()
        plt.show()

    def interactive_heatmap(self, figsize=(12, 10), cmap='gray', tick_interval=50):
        """
        Create an interactive heatmap with sliders for row and head selection.

        Args:
            figsize (tuple): Figure size for the plot
            cmap (str): Colormap for the heatmap
            tick_interval (int): Interval between axis ticks (default: 50)
        """
        def plot_heatmap(row_idx, head_idx):
            # Redraw with the slider values and the fixed styling arguments.
            self.plot_attention_heatmap(row_idx, head_idx, figsize, cmap, tick_interval)

        # Create sliders covering every valid row / head index
        row_slider = IntSlider(
            value=0,
            min=0,
            max=self.num_rows - 1,
            step=1,
            description='Row:',
            layout=Layout(width='400px')
        )

        head_slider = IntSlider(
            value=0,
            min=0,
            max=self.num_heads - 1,
            step=1,
            description='Head:',
            layout=Layout(width='400px')
        )

        # Create interactive widget
        interact(plot_heatmap, row_idx=row_slider, head_idx=head_slider)

    def compare_heads(self, row_idx=0, heads_to_compare=None, figsize=(15, 5), tick_interval=50, cmap='gray'):
        """
        Compare multiple attention heads side by side for a given row.

        Head indices past the last head are reported and skipped; no empty
        panel is drawn for them.  If no valid head remains, nothing is plotted.

        Args:
            row_idx (int): Row index to visualize
            heads_to_compare (list): List of head indices to compare. If None, shows first 3 heads.
            figsize (tuple): Figure size for the plot
            tick_interval (int): Interval between axis ticks (default: 50)
            cmap (str): Colormap for the heatmaps (default: 'gray')
        """
        if heads_to_compare is None:
            heads_to_compare = list(range(min(3, self.num_heads)))

        if row_idx >= self.num_rows:
            print(f"Row index {row_idx} out of range. Max row index: {self.num_rows - 1}")
            return

        # Filter before creating the figure so invalid heads leave no blank axes.
        valid_heads = []
        for head_idx in heads_to_compare:
            if head_idx >= self.num_heads:
                print(f"Head index {head_idx} out of range. Max head index: {self.num_heads - 1}")
            else:
                valid_heads.append(head_idx)

        if not valid_heads:
            # plt.subplots cannot create a grid with zero columns.
            print("No valid attention heads to compare.")
            return

        # squeeze=False always yields a 2-D axes array, even for one panel.
        fig, axes = plt.subplots(1, len(valid_heads), figsize=figsize, squeeze=False)

        for i, (ax, head_idx) in enumerate(zip(axes[0], valid_heads)):
            attention_matrix = self.attention_data[row_idx][0, head_idx, :, :]
            self._draw_heatmap(ax, attention_matrix, cmap, tick_interval)

            ax.set_title(f'Head {head_idx}')
            if i == 0:
                # The y label is shown only on the left-most panel.
                ax.set_ylabel('Query Position')

        fig.suptitle(f'Attention Heads Comparison - Row {row_idx}')
        fig.tight_layout()
        plt.show()

    def get_attention_stats(self, row_idx=0, head_idx=0, sparsity_threshold=0.01):
        """
        Get statistics for a specific attention matrix.

        Args:
            row_idx (int): Row index
            head_idx (int): Head index
            sparsity_threshold (float): Weights strictly below this value count
                as "sparse" (default: 0.01)

        Returns:
            dict: Statistics about the attention matrix ('mean', 'std', 'min',
                'max', 'sparsity', 'entropy'), or None if either index is past
                the end of the loaded data.
        """
        if row_idx >= self.num_rows or head_idx >= self.num_heads:
            return None

        attention_matrix = self.attention_data[row_idx][0, head_idx, :, :]

        stats = {
            'mean': np.mean(attention_matrix),
            'std': np.std(attention_matrix),
            'min': np.min(attention_matrix),
            'max': np.max(attention_matrix),
            # Fraction of entries below the threshold.
            'sparsity': np.sum(attention_matrix < sparsity_threshold) / attention_matrix.size,
            # Entropy along the last axis, averaged over the remaining axis;
            # the 1e-10 offset keeps log() finite for zero weights.
            'entropy': -np.sum(attention_matrix * np.log(attention_matrix + 1e-10), axis=-1).mean()
        }

        return stats

# Usage example, kept inside a string literal so that none of it is executed.
# NOTE: this string is the last expression of the cell, so the notebook echoes
# its repr as the cell output.
"""
# Initialize the visualizer
visualizer = AttentionVisualizer('path/to/your/attention_weights.txt')

# Plot a specific attention head with custom tick interval
visualizer.plot_attention_heatmap(row_idx=0, head_idx=0, tick_interval=50)

# Use interactive visualization with custom tick interval
visualizer.interactive_heatmap(tick_interval=50)

# Compare multiple heads with custom tick interval
visualizer.compare_heads(row_idx=0, heads_to_compare=[0, 1, 2], tick_interval=50)

# Get statistics
stats = visualizer.get_attention_stats(row_idx=0, head_idx=0)
print(stats)
"""

"\n# Initialize the visualizer\nvisualizer = AttentionVisualizer('path/to/your/attention_weights.txt')\n\n# Plot a specific attention head with custom tick interval\nvisualizer.plot_attention_heatmap(row_idx=0, head_idx=0, tick_interval=50)\n\n# Use interactive visualization with custom tick interval\nvisualizer.interactive_heatmap(tick_interval=50)\n\n# Compare multiple heads with custom tick interval\nvisualizer.compare_heads(row_idx=0, heads_to_compare=[0, 1, 2], tick_interval=50)\n\n# Get statistics\nstats = visualizer.get_attention_stats(row_idx=0, head_idx=0)\nprint(stats)\n"

In [14]:
# Build the visualizer from the attention dump; the path is relative to the
# notebook's working directory.
visualizer = AttentionVisualizer(file_path='attentions.txt')


Loaded attention data:
  - Number of rows: 1
  - Number of attention heads: 16
  - Sequence length: 388


In [None]:
# Open the slider-driven heatmap viewer using the method's default arguments.
visualizer.interactive_heatmap()

interactive(children=(IntSlider(value=0, description='Row:', layout=Layout(width='400px'), max=0), IntSlider(v…