# Activations to Visualization

This notebook contains code for loading and visualizing the activations (hidden states) of an LLM during output generation. The code assumes the generation results are stored in a `.pt` file on your Google Drive.


The model currently being tested is the 7B-parameter [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/) model.

## Description of the saved file

The input prompt, the generated output and the hidden states are saved together in a PyTorch `.pt` file named `{input_prompt}.pt`.

To load the file, use:

`data = torch.load("{input_prompt}.pt", map_location=torch.device('cpu'))`
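The file name is derived from the prompt. Judging from the example used later (`What_is_4_+_2?.pt`), spaces in the prompt appear to be replaced with underscores. The sketch below assumes that convention; `activation_path_for` is a hypothetical helper, not part of the original code.

```python
import os

# Hypothetical helper: assumes the saving code replaced spaces with underscores,
# as in the example file "What_is_4_+_2?.pt".
def activation_path_for(prompt, root="/content/drive/MyDrive/llm/activations"):
    """Return the expected .pt path for a raw prompt string."""
    return os.path.join(root, prompt.replace(" ", "_") + ".pt")

# activation_path_for("What is 4 + 2?")
# -> "/content/drive/MyDrive/llm/activations/What_is_4_+_2?.pt"
```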

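Once loaded, the file is a dictionary whose fields can be accessed as follows:
```
prompt = data['prompt']                          # the input prompt
hidden_states = data['hidden_states']            # per-step hidden states (see below)
output_sequence = data['sequences'][0]           # token ids: the prompt followed by the generated tokens
output = data['output'].split("Response:")[1]    # decoded text after the "Response:" marker
```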
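To check a freshly loaded file against this layout, a short summary helper can report the prompt, the number of generation steps, and the tensor shapes of the first and last step. This is a sketch: `summarize_generation` is a hypothetical name, and it only assumes the keys shown above plus tensors that expose `.shape`.

```python
def summarize_generation(data):
    """Collect a few facts about a loaded activations file (hypothetical helper)."""
    hidden_states = data["hidden_states"]
    first_step, last_step = hidden_states[0], hidden_states[-1]
    return {
        "prompt": data["prompt"],
        # Hugging Face generate() stores one entry per generated token
        "n_steps": len(hidden_states),
        # embedding output + decoder layers
        "n_layers": len(first_step),
        "first_step_shape": tuple(first_step[0].shape),
        "last_step_shape": tuple(last_step[0].shape),
    }

# Usage after torch.load(...):
# for key, value in summarize_generation(data).items():
#     print(f"{key}: {value}")
```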

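Assuming the file was written from a Hugging Face `generate()` call with `output_hidden_states=True` (as the `sequences` and `hidden_states` keys suggest), the hidden states are nested tuples of tensors rather than a single array. `hidden_states[step][layer]` is a tensor of shape `(num_beams, n_positions, hidden_size)`, so the overall layout is:

```
hidden_states shape: (n_generated_tokens, n_layers, num_beams, n_positions, hidden_size)

n_generated_tokens : one entry per generation step; the first step processes the whole prompt in a single forward pass
n_layers           : 33, the embedding output plus 32 decoder layers
num_beams          : 1, the number of beams used during generation
n_positions        : n_input_tokens for the first step, then 1 for every later step (with the default key/value cache)
hidden_size        : 4096, from the model config
```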
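Following the layout above, a single activation vector is `hidden_states[step][layer][beam, position]`. The sketch below picks the last position processed at a step (for the first step that is the final prompt token; for later steps it is the single newly fed token) and computes a per-layer L2 norm profile. `activation_at` and `layer_profile` are hypothetical helpers that work on CPU torch tensors or NumPy arrays.

```python
import numpy as np

def activation_at(hidden_states, step, layer, beam=0, position=-1):
    """Return one hidden_size-long activation vector (hypothetical helper).

    hidden_states[step][layer] has shape (num_beams, n_positions, hidden_size);
    position=-1 selects the last position processed at that step.
    """
    return hidden_states[step][layer][beam, position]

def layer_profile(hidden_states, step, beam=0):
    """L2 norm of the last-position activation for every layer at one step."""
    return np.array([
        float(np.linalg.norm(np.asarray(activation_at(hidden_states, step, layer, beam))))
        for layer in range(len(hidden_states[step]))
    ])
```

A profile like `layer_profile(data['hidden_states'], step=0)` gives 33 numbers that can be plotted against the layer index.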

### Mount Google Drive

In [2]:
from google.colab import drive

drive.mount('/content/drive')
!ls '/content/drive/MyDrive/llm'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
activations  models
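With the drive mounted, the prompts that already have saved activations can be listed from the `activations` folder. A small standard-library sketch; `list_saved_prompts` is hypothetical, and it assumes every `.pt` file in the folder follows the `{input_prompt}.pt` naming.

```python
import glob
import os

def list_saved_prompts(activations_dir="/content/drive/MyDrive/llm/activations"):
    """Return the prompt part of every .pt file name in activations_dir, sorted."""
    paths = glob.glob(os.path.join(activations_dir, "*.pt"))
    return sorted(os.path.splitext(os.path.basename(p))[0] for p in paths)
```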


### Import Packages

In [29]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.animation import FuncAnimation
import matplotlib.colors as mcolors
from matplotlib.colors import LogNorm
import torch
from tqdm import tqdm

### Load Activations (.pt file)

In [30]:
input_prompt = "What_is_4_+_2?"

activation_path = "/content/drive/MyDrive/llm/activations/"+str(input_prompt)+".pt"
data = torch.load(activation_path, map_location=torch.device('cpu'))

# from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
# tokenizer = LlamaTokenizer.from_pretrained('decapoda-research/llama-7b-hf')
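Right after loading, a quick consistency check can catch a file saved from a different generation setup. The sketch assumes the Hugging Face `generate()` layout: `sequences[0]` holds the prompt ids followed by one id per generation step, the first step's tensors span every prompt position, and later steps feed one position each. `check_consistency` is a hypothetical helper.

```python
def check_consistency(data):
    """Return (n_prompt_tokens, n_generated_tokens) or raise if the layout disagrees."""
    hidden_states = data["hidden_states"]
    sequence = data["sequences"][0]
    n_generated = len(hidden_states)
    # First step: shape (num_beams, n_prompt_tokens, hidden_size)
    n_prompt = hidden_states[0][0].shape[1]
    # Every later step should feed exactly one new position (cached generation)
    for step, layers in enumerate(hidden_states[1:], start=1):
        if layers[0].shape[1] != 1:
            raise ValueError(f"step {step} has {layers[0].shape[1]} positions, expected 1")
    if len(sequence) != n_prompt + n_generated:
        raise ValueError(
            f"sequence length {len(sequence)} != {n_prompt} prompt + {n_generated} generated"
        )
    return n_prompt, n_generated

# n_prompt, n_generated = check_consistency(data)
```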

### Define Visualization Options & Utility Functions


In [33]:
import os

def generate_bitmap_animation(data, tokenizer=None):
  hidden_states = data['hidden_states']

  # generate() stores one hidden-state entry per generated token, so the token shown
  # in each frame is taken from the last len(hidden_states) ids of the output sequence
  generated_ids = data['sequences'][0][-len(hidden_states):].tolist()

  all_images = []
  vmin, vmax = np.inf, -np.inf

  # iterate through the hidden states of every generation step
  for token_hidden_states in tqdm(hidden_states):
    # Flatten every layer (num_beams x n_positions x hidden_size) into one 1-D array,
    # in layer -> beam -> position -> hidden order
    activations = np.concatenate([
        layer.detach().cpu().float().numpy().ravel() for layer in token_hidden_states
    ])

    # Determine the side of the smallest square image that holds every activation
    image_size = int(np.ceil(np.sqrt(activations.size)))

    # Use a float image: uint8 would truncate the activations and wrap negative values
    img = np.zeros((image_size, image_size), dtype=np.float32)

    # Fill in the raw activation values; trailing unused pixels stay zero
    img.flat[:activations.size] = activations
    all_images.append(img)

    # Update the global vmin and vmax
    vmin = min(vmin, float(activations.min()))
    vmax = max(vmax, float(activations.max()))

  # LogNorm needs vmin > 0, but hidden states are signed, so use a symmetric log norm;
  # linthresh=1.0 (the range kept linear around zero) is an arbitrary choice
  norm = mcolors.SymLogNorm(linthresh=1.0, vmin=vmin, vmax=vmax)

  # Frame titles: decode each generated token if a tokenizer is given, else show its id
  if tokenizer is not None:
    titles = [f"Token {i + 1}: {tokenizer.decode([t])!r}" for i, t in enumerate(generated_ids)]
  else:
    titles = [f"Token {i + 1}: id {t}" for i, t in enumerate(generated_ids)]

  fig = plt.figure(figsize=(20, 20))

  # Define the update function for the animation
  def update(frame):
    fig.clf()
    ax = fig.add_subplot()
    ax.imshow(all_images[frame], cmap='viridis', norm=norm)
    ax.set_title(titles[frame], fontsize=30)
    ax.axis('off')

  # Create the animation with one frame per generation step
  ani = FuncAnimation(fig, update, frames=len(all_images), interval=250)

  # Save the animation as an MP4 file in the visualizations folder (needs ffmpeg)
  output_dir = "/content/drive/MyDrive/llm/visualizations/"
  os.makedirs(output_dir, exist_ok=True)
  output_file = data["prompt"].replace(' ', '_') + ".mp4"
  output_path = os.path.join(output_dir, output_file)
  ani.save(output_path, dpi=100, writer="ffmpeg")

  # Close the figure so the notebook does not also display a static copy
  plt.close(fig)

  print("visualization successfully saved: '" + output_path + "'")

def generate_visuals(data, visual_type="bitmap_animation", **kwargs):
  if visual_type == "bitmap_animation":
    generate_bitmap_animation(data, **kwargs)
  else:
    raise ValueError(f"Unknown visual_type: {visual_type!r}")
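Each animation frame packs every activation of one generation step into the smallest square that fits them, so the frame size follows directly from the tensor shapes. For a later step with 33 layers, 1 beam, 1 position and 4096 hidden units there are 33 × 4096 = 135,168 values, and the square side is ⌈√135168⌉ = 368 (367² = 134,689 is too small). A sketch of that arithmetic, using exact integer `math.isqrt` in place of the floating-point `np.ceil(np.sqrt(...))`; `frame_side` is a hypothetical helper.

```python
import math

def frame_side(n_layers, num_beams, n_positions, hidden_size):
    """Side length of the square image that holds one step's activations."""
    n_values = n_layers * num_beams * n_positions * hidden_size
    # Exact integer ceil(sqrt(n)), valid for n >= 1
    return math.isqrt(n_values - 1) + 1

print(frame_side(33, 1, 1, 4096))   # 368
print(frame_side(33, 1, 20, 4096))  # 1645, for a hypothetical 20-token prompt
```

Because the first step covers every prompt position, the first frame is much larger than the rest; `imshow` rescales each frame to the same figure size, so individual pixels in that frame are far smaller.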

### Run Visual Generation & Save Output

In [34]:
generate_visuals(data, visual_type="bitmap_animation")

10it [00:02,  4.59it/s]


visualization succesfully saved: '/content/drive/MyDrive/llm/visualizations/What_is_4_+_2?.mp4'
