In [5]:
from matplotlib.colors import Normalize
import os
import json
import torch
import numpy as np

def calculate_attention_values(attention_matrix, input_tokens, output_tokens, image_token_start_idx, image_token_end_idx,
                               text_start_offset=1, text_end_trim=3):
    """Summarize, per generated token, the attention paid to image vs. prompt-text tokens.

    For output token ``i`` the row ``len(input_tokens) - 1 + i`` of ``attention_matrix``
    is read and two column slices of that row are averaged:

    * image columns: ``[image_token_start_idx, image_token_end_idx)``
    * text columns:  ``[image_token_end_idx + text_start_offset, len(input_tokens) - text_end_trim)``

    Args:
        attention_matrix: 2-D scores, either a ``torch.Tensor`` (any float dtype or
            device) or anything ``np.asarray`` accepts. Presumably rows are query
            positions and columns key positions — TODO confirm against the code
            that writes ``attention_weights/*.pt``.
        input_tokens: Prompt tokens; only ``len(input_tokens)`` is used.
        output_tokens: Generated tokens; one result dict is produced per token.
        image_token_start_idx: First image-token column (inclusive).
        image_token_end_idx: End of the image-token columns (exclusive).
        text_start_offset: Columns skipped after the image block before the text
            slice begins (default 1, the previously hard-coded value).
        text_end_trim: Trailing prompt columns excluded from the text slice
            (default 3, the previously hard-coded value).

    Returns:
        A list with one dict per output token holding ``id``, ``output_token``,
        ``average_text_tokens``, ``average_image_tokens``,
        ``average_image_to_text_ratio``, ``text_tokens`` and ``image_tokens``
        (the last two are the raw numpy slices). An average is NaN when its slice
        is empty. The ratio is NaN when the text average is zero or not finite.
    """
    if isinstance(attention_matrix, torch.Tensor):
        # np.mean on a raw torch.Tensor forwards numpy-style kwargs to Tensor.mean,
        # which raises a TypeError on some torch versions, and numpy cannot hold
        # bfloat16. Convert once to a float numpy array up front.
        attention_matrix = attention_matrix.detach().to(torch.float32).cpu().numpy()
    attention_matrix = np.asarray(attention_matrix, dtype=np.float64)

    prompt_len = len(input_tokens)
    # NOTE(review): these offsets presumably skip a separator token right after the
    # image block and trailing chat-template tokens at the end of the prompt — TODO confirm.
    text_start = image_token_end_idx + text_start_offset
    text_end = prompt_len - text_end_trim

    average_attention_scores = []
    for i, output_token in enumerate(output_tokens):
        # Row used for output token i; presumably the position whose next-token
        # prediction is output_tokens[i] — TODO confirm against the dump format.
        output_relevancy = attention_matrix[prompt_len - 1 + i, :]
        image_tokens = output_relevancy[image_token_start_idx:image_token_end_idx]
        text_tokens = output_relevancy[text_start:text_end]

        # Empty slices give an explicit NaN instead of numpy's "Mean of empty slice" warning.
        average_text_tokens = float(text_tokens.mean()) if text_tokens.size else float("nan")
        average_image_tokens = float(image_tokens.mean()) if image_tokens.size else float("nan")

        # The ratio is undefined when there is no (finite, non-zero) text attention.
        if np.isfinite(average_text_tokens) and average_text_tokens != 0:
            ratio = average_image_tokens / average_text_tokens
        else:
            ratio = float("nan")

        average_attention_scores.append({
            "id": i,
            "output_token": output_token,
            "average_text_tokens": average_text_tokens,
            "average_image_tokens": average_image_tokens,
            "average_image_to_text_ratio": ratio,
            "text_tokens": text_tokens,
            "image_tokens": image_tokens
        })

    return average_attention_scores

def calculate_image_text_ratio(attention_scores, normalize=True):
    """Return the mean image-to-text attention ratio across output tokens.

    Args:
        attention_scores: Iterable of per-token dicts, each with an
            ``"average_image_to_text_ratio"`` entry (the records produced by
            ``calculate_attention_values``).
        normalize: Kept for interface compatibility; currently has no effect.

    Returns:
        The arithmetic mean of the per-token ratios (NaN, with a numpy
        warning, when ``attention_scores`` is empty).
    """
    per_token_ratios = np.array(
        [entry["average_image_to_text_ratio"] for entry in attention_scores]
    )
    return np.mean(per_token_ratios)

def process_data(file_names, base):
  """Load every run's results and attention dumps and compute image/text attention ratios.

  For each ``key -> file_name`` pair, reads ``<base>/<file_name>/results.json`` (a list of
  per-problem results). For the i-th result it also reads ``attention_weights/<i>.pt``
  and ``attention/<i>.json`` from the same run directory.

  Args:
      file_names: Mapping from a run label to a run directory name under ``base``.
      base: Root directory containing the run directories.

  Returns:
      ``{key: {"items": [...], "average_image_to_text_ratio": ...}}``.
      Each item dict holds:
        * the per-token attention results and the cleaned input/output tokens;
        * the image-token span and the item's ``average_image_text_ratio``;
        * the ``problem`` metadata.
      The run-level value is the mean of its items' ``average_image_text_ratio``,
      or NaN for a run with no items.

  Raises:
      FileNotFoundError: If a results or attention file is missing.
      ValueError: If an item's input tokens contain no ``image_`` tokens.
  """
  processed_data = {}
  for key, file_name in file_names.items():
    run_dir = os.path.join(base, file_name)
    with open(os.path.join(run_dir, "results.json"), "r") as file:
      results = json.load(file)

    items = [_process_item(run_dir, i, result) for i, result in enumerate(results)]
    item_ratios = [item['average_image_text_ratio'] for item in items]
    processed_data[key] = {
      'items': items,
      # Mean over problems of each problem's mean per-token image/text ratio.
      'average_image_to_text_ratio': np.mean(item_ratios) if item_ratios else float('nan'),
    }

  # Return only after every run has been processed.
  return processed_data


def _process_item(run_dir, index, result):
  """Load the attention dump for one problem of a run and build its item record.

  Args:
      run_dir: The run directory (``<base>/<file_name>``).
      index: Position of the problem in ``results.json``. It is also the file stem of
          ``attention_weights/<index>.pt`` and ``attention/<index>.json``.
      result: The problem's entry from ``results.json``.

  Returns:
      The item dict stored at ``processed_data[key]['items'][index]``.
  """
  # NOTE(review): torch.load unpickles the file — only load trusted result dumps.
  average_attention_matrix = torch.load(
    os.path.join(run_dir, "attention_weights", f"{index}.pt"),
    map_location=torch.device('cpu'))
  with open(os.path.join(run_dir, "attention", f"{index}.json"), "r") as file:
    attention_description = json.load(file)

  # Remove the byte-level BPE markers for a leading space ('Ġ') and newline ('Ċ').
  output_tokens = [token.replace('Ġ', '').replace('Ċ', '') for token in attention_description["output_tokens"]]
  input_tokens = [token.replace('Ġ', '').replace('Ċ', '') for token in attention_description["input_tokens"]]

  # Assumes the image tokens form one contiguous run inside the prompt.
  image_token_indices = [pos for pos, token in enumerate(input_tokens) if token.startswith("image_")]
  if not image_token_indices:
    raise ValueError(f"no 'image_' tokens in input_tokens of item {index} in {run_dir}")
  image_token_start_idx = image_token_indices[0]
  image_token_end_idx = image_token_indices[-1] + 1  # exclusive end of the image block
  # Prompt text after the image block. It skips the one token right after the block,
  # the same offset the text slice in calculate_attention_values uses.
  main_input_text_tokens = input_tokens[image_token_end_idx + 1:]

  attention_results = calculate_attention_values(
    average_attention_matrix, input_tokens, output_tokens,
    image_token_start_idx, image_token_end_idx)
  image_text_ratio = calculate_image_text_ratio(attention_results)

  return {
    'attention_results': attention_results,
    'input_tokens': input_tokens,
    'output_tokens': output_tokens,
    'image_token_start_idx': image_token_start_idx,
    'image_token_end_idx': image_token_end_idx,
    'average_image_text_ratio': image_text_ratio,
    'main_input_text_tokens': main_input_text_tokens,
    'predicted_answer': result["predicted_answer"],
    'problem': {
      "problem_text": result["problem_text"],
      "choices": result["choices"],
      "answer": result["answer"],
      "image": f"{result['image_id']}.png",
      "predicted_answer": result["predicted_answer"],
      "extracted_answer": result["extracted_answer"],
      "problem_type_graph": result["problem_type_graph"],
      "problem_type_goal": result["problem_type_goal"],
    },
  }


In [6]:
import matplotlib.pyplot as plt

def plot_average_image_to_text_ratios(processed_data):
  """Draw a bar chart of each run's ``average_image_to_text_ratio``.

  Args:
      processed_data: Mapping from run label to a dict containing an
          ``'average_image_to_text_ratio'`` entry (the structure built by
          ``process_data``). Bars follow the mapping's iteration order.
  """
  labels, heights = [], []
  for label, run in processed_data.items():
    labels.append(label)
    heights.append(run['average_image_to_text_ratio'])

  fig, ax = plt.subplots(figsize=(10, 6))
  ax.bar(labels, heights, color='skyblue')
  ax.set_xlabel('File Names')
  ax.set_ylabel('Average Image to Text Ratio')
  ax.set_title('Average Image to Text Ratio for Different File Names')
  # Slant the run labels so long names do not overlap.
  plt.xticks(rotation=45, ha='right')
  fig.tight_layout()
  plt.show()

In [7]:
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib.cm as cm
from matplotlib.colors import Normalize


def visualize_word_map(idx, processed_data, activations):
  """Render problem ``idx``'s output tokens for selected runs, shaded by image/text ratio.

  Each output token is drawn as a text box colored with ``cm.Oranges``. The color
  comes from that token's ``average_image_to_text_ratio`` entry in
  ``item['attention_results']``. Ratios are min-max normalized over all selected
  runs together, so colors are comparable between the figures.

  Args:
      idx: Index of the problem within each run's ``items`` list.
      processed_data: Mapping run label -> ``{'items': [...]}`` as built by ``process_data``.
      activations: Mapping run label -> flag. Runs whose flag equals 1 are drawn;
          labels missing from the mapping are skipped.
  """
  # Collect the per-token ratios of every selected run.
  selected = []
  for key, data in processed_data.items():
    if activations.get(key) == 1:
      item = data['items'][idx]
      token_ratios = np.array(
        [float(r['average_image_to_text_ratio']) for r in item['attention_results']],
        dtype=float)
      selected.append((item, token_ratios))

  if not selected:
    return

  # Shared color scale across all selected runs. Non-finite ratios are left out of
  # the range. NaN values are drawn with the colormap's "bad" color.
  all_ratios = np.concatenate([ratios for _, ratios in selected])
  finite = all_ratios[np.isfinite(all_ratios)]
  vmin, vmax = (finite.min(), finite.max()) if finite.size else (0.0, 1.0)
  norm = Normalize(vmin=vmin, vmax=vmax)

  # Plot one figure per selected run.
  cmap = cm.Oranges
  for item, token_ratios in selected:
    fig, ax = plt.subplots(figsize=(3, 2))

    x_coord = 0
    y_coord = 0.5
    for i, (output_token, normalized_value) in enumerate(zip(item["output_tokens"], norm(token_ratios))):
      color = cmap(normalized_value)
      ax.text(x_coord, y_coord, output_token, ha='left', va='center', fontsize=30,
              bbox=dict(facecolor=color, edgecolor='none', boxstyle='round,pad=0.3'))
      # Rough horizontal advance proportional to the token's character count.
      x_coord += (len(output_token) / 6)
      if (i + 1) % 20 == 0:
        # Wrap to a new line every 20 tokens.
        x_coord = 0
        y_coord -= 0.5

    # Remove axes
    ax.axis('off')

    # Adjust layout
    plt.tight_layout()
    plt.show()

In [10]:
# Root directory with one sub-directory per evaluation run. It is a relative path,
# so it resolves against the kernel's current working directory.
base = "results/files/"

# Run label -> run directory name under `base`. The labels become the x-axis
# categories of the ratio bar chart.
file_names = {
  "baseline_direct": "results_geometry_3k_baseline_direct_20240601_211922",
  "baseline_cot": "results_geometry_3k_baseline_cot_20240601_194854",
  "finetuned_cot": "results_geometry_3k_finetuned_20240601_201028",
  "finetuned_logic": "results_geometry_3k_finetuned_logic_20240601_222942",
  "finetuned_peano": "results_geometry_3k_finetuned_peano_20240601_213131",
}

# Per-run flag for the token word map: 1 draws the run, 0 skips it.
word_map_show = {
  label: int(label in {"baseline_cot", "finetuned_cot"})
  for label in file_names
}

processed_data = process_data(file_names=file_names, base=base)
plot_average_image_to_text_ratios(processed_data=processed_data)

# Which problem (position in each run's results.json) to visualize.
idx = 0
visualize_word_map(idx=idx, processed_data=processed_data, activations=word_map_show)

/Users/josephtey/Projects/multimodal-reasoning


FileNotFoundError: [Errno 2] No such file or directory: 'results/files/results_geometry_3k_baseline_direct_20240601_211922/attention_weights/0.pt'