In [1]:
# Development convenience: with mode 2, IPython re-imports edited modules
# (e.g. the project's `src.*` package) before running each cell.
%load_ext autoreload
%autoreload 2


In [4]:

import functools
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from tqdm.auto import tqdm
from transformers import AutoTokenizer

import src.plots as plots
from src.consts import GRAPHS_ORDER, MODEL_SIZES_PER_ARCH_TO_MODEL_ID, PATHS
from src.utils.logit_utils import decode_tokens
from src.types import DATASETS, MODEL_ARCH, DatasetArgs




In [5]:
# Parameters


# Heatmap output version per model architecture. The value is appended to
# "heatmap" to form the per-model directory name (e.g. "heatmap_v6") that
# `display_all_heatmaps` searches for PNG files.
MODEL_TO_HEATMAP_VERSION = {
    MODEL_ARCH.MAMBA1: "_v6",
    MODEL_ARCH.MINIMAL_MAMBA2_new: "_v6",
}

# Dataset selection: `DATASETS.COUNTER_FACT`, all splits.
# NOTE(review): `ds` is not referenced by any cell visible in this notebook
# export — confirm whether it is still needed.
ds = DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all")


In [6]:
# Tokenizer used to split each prompt into the token labels passed to the plots.
# NOTE(review): presumably this GPT-NeoX tokenizer matches the one used by the
# plotted models — confirm.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Reuse the EOS id as the padding id; the rendering loop tokenizes with
# `padding=True`, which presumably needs a pad token to be defined.
tokenizer.pad_token_id = tokenizer.eos_token_id


In [7]:
# Cache mapping model_id -> filtered results DataFrame. It is filled lazily by
# the heatmap-rendering loop the first time a model is encountered.
models_data = {}


In [8]:
# Inspection only: show which models currently have cached results.
models_data.keys()


dict_keys([])

In [10]:
# Plot variants to render. Each key is the filename suffix appended to
# "idx=<prompt_idx>" when the PNG is saved; each value is the plotting callable
# invoked with the heatmap data. Commented-out entries are alternative variants
# that are currently disabled.
plot_suffix_to_function = {
    # '_simple': plots.plot_simple_heatmap,
    '_simple_diff_fixed_0.3': functools.partial(plots.simple_diff_fixed, fixed_diff=0.3),
    # '_minimal_title_simple_diff_fixed_0.3': functools.partial(plots.simple_diff_fixed, fixed_diff=0.3, minimal_title=True),
    # '_simple_diff_fixed_0.2': functools.partial(plots.simple_diff_fixed, fixed_diff=0.2),
    # '_robust': plots.plot_heatmap_robust,
    # '_robust_diff': plots.plot_heatmap_robust_diff,
    # '_diff_symlog': plots.plot_heatmap_diff_symlog,
}


In [37]:


# Render one PNG per enabled plot variant next to every saved heatmap matrix.
#
# Heatmap matrices are looked up under PATHS.OUTPUT_DIR with the layout
#   .../state-spaces/<model_id>/heatmap_v<N>/ds=<dataset>/ws=<window>/idx=<prompt>.npy
# and only files of version "_v6" are plotted. Files that do not match the
# layout are skipped.
pattern = r"/state-spaces/(?P<model_id>[\w\.-]+)/heatmap(?P<version>_v\d+)/ds=(?P<dataset>[\w_]+)/ws=(?P<window_size>\d+)/idx=(?P<prompt_idx>\d+)\.npy"

for p in tqdm(list(PATHS.OUTPUT_DIR.rglob('*.npy'))):
    # The pattern is written with "/" separators, so match against the POSIX
    # form of the path; this keeps it working on Windows as well.
    match = re.search(pattern, p.as_posix())
    if not match:
        continue

    details = match.groupdict()
    if details['version'] != '_v6':
        continue

    model_id = details['model_id']
    window_size = details['window_size']  # kept as the string captured from the path
    prompt_idx = int(details['prompt_idx'])

    if model_id not in models_data:
        print(f"fetching data for {model_id}")
        results_dir = (
            PATHS.OUTPUT_DIR
            / 'state-spaces'
            / model_id
            / "data_construction"
            / f"ds={details['dataset']}"
        )
        attn_res = pd.read_parquet(results_dir / "entire_results_attention.parquet")
        original_res = pd.read_parquet(results_dir / "entire_results_original.parquet")

        # Keep only the rows where both result sets agree on "hit" and it is True.
        mask = (original_res["hit"] == attn_res["hit"]) & (attn_res["hit"] == True)
        # NOTE(review): the cached rows come from entire_results_original.parquet,
        # so `true_prob` below is read from that file — confirm this is the
        # intended source.
        # NOTE(review): the cache is keyed by model_id only although the files
        # depend on the dataset too; fine while a single dataset is on disk.
        models_data[model_id] = original_res[mask]

    data = models_data[model_id]
    prompt = data.loc[prompt_idx, "prompt"]
    true_word = data.loc[prompt_idx, "target_true"]
    base_prob = data.loc[prompt_idx, "true_prob"]

    tokens = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = tokens.input_ids
    toks = decode_tokens(tokenizer, input_ids[0])
    last_tok = toks[-1]
    # Mark the final token with "*" in the token labels; `last_tok` keeps the
    # unmarked text.
    toks[-1] = toks[-1] + "*"

    prob_mat = np.load(p)
    for plot_suffix, plot_func in plot_suffix_to_function.items():
        fig, _ = plot_func(
            prob_mat=prob_mat,
            model_id=model_id,
            window_size=window_size,
            last_tok=last_tok,
            base_prob=base_prob,
            true_word=true_word,
            toks=toks,
        )
        # Save through the returned figure rather than pyplot's implicit
        # "current figure", so the file always contains the plot just created.
        fig.savefig(p.parent / f"idx={prompt_idx}{plot_suffix}.png", bbox_inches="tight")
        # Close explicitly: many figures are created in this loop.
        plt.close(fig)
        
        
    

  0%|          | 0/107 [00:00<?, ?it/s]

In [24]:
def display_all_heatmaps(suffix):
    """Combine per-prompt heatmap PNGs into side-by-side comparison grids.

    For every size tier ('small', 'medium', 'large') and every window size in
    (1, 5, 9), the PNGs named ``idx=<prompt_id><suffix>.png`` of that tier's
    models are pasted onto one white canvas: one row per prompt (ascending
    prompt id) and one column per model (in ``GRAPHS_ORDER`` order). All images
    are resized to the largest width/height found. Each grid is written to
    ``PATHS.RESULTS_DIR / "combined_heatmaps" / suffix / "ws=<ws>_<tier>.png"``.

    Args:
        suffix: Plot-variant suffix of the PNG files to collect, e.g.
            ``"_simple_diff_fixed_0.3"`` (a key of ``plot_suffix_to_function``).

    Returns:
        None. A tier/window-size combination with no matching PNG is skipped
        and produces no output file.
    """
    glob_pattern = f"idx=*{suffix}.png"
    # The suffix is literal text (it may contain "."), so escape it. Matching the
    # whole file name also rejects files whose suffix merely ends with `suffix`,
    # which the looser glob above would let through.
    name_regex = re.compile(rf"idx=(\d+){re.escape(suffix)}\.png")
    ws_regex = re.compile(r"ws=(\d+)")

    padding = 10
    title_height = 0  # no per-image title strip is drawn

    for i, size_cat in enumerate(['small', 'medium', 'large']):
        # Consecutive pairs of GRAPHS_ORDER entries form one tier.
        # NOTE(review): assumes GRAPHS_ORDER is laid out as two (arch, size)
        # entries per tier — confirm against src.consts.
        requested_models = GRAPHS_ORDER[2*i:2*(i+1)]

        for requested_ws in [1, 5, 9]:
            # prompt id -> {column index -> image}
            rows = defaultdict(dict)
            img_width = 0
            img_height = 0

            for col, (model, size) in enumerate(requested_models):
                model_id = MODEL_SIZES_PER_ARCH_TO_MODEL_ID[model][size]
                model_dir = (
                    PATHS.OUTPUT_DIR / f"{model_id}/heatmap{MODEL_TO_HEATMAP_VERSION[model]}"
                )
                for file in model_dir.rglob(glob_pattern):
                    ws_match = ws_regex.search(str(file))
                    name_match = name_regex.fullmatch(file.name)
                    if ws_match is None or name_match is None:
                        continue
                    if int(ws_match.group(1)) != requested_ws:
                        continue

                    # Copy the pixels and close the file right away so that
                    # handles do not pile up while scanning many images.
                    with Image.open(file) as opened:
                        img = opened.copy()

                    img_width = max(img_width, img.width)
                    img_height = max(img_height, img.height)
                    # Keyed by column so each model keeps its own column even
                    # when a prompt is missing for one of the other models.
                    rows[int(name_match.group(1))][col] = img

            # Compose once, after every model of the tier has been scanned.
            if not rows:
                continue  # nothing found; avoid writing an empty canvas

            col_width = img_width + padding
            row_height = img_height + title_height + padding
            canvas_width = len(requested_models) * col_width
            canvas_height = len(rows) * row_height

            combined_image = Image.new("RGB", (canvas_width, canvas_height), "white")

            # Sort prompts numerically so the layout does not depend on the
            # order in which the filesystem returned the files.
            for row, prompt_id in enumerate(sorted(rows)):
                y_offset = row * row_height
                for col, img in rows[prompt_id].items():
                    combined_image.paste(
                        img.resize((img_width, img_height)),
                        (col * col_width, y_offset + title_height),
                    )

            base_dir = PATHS.RESULTS_DIR / "combined_heatmaps" / suffix
            base_dir.mkdir(exist_ok=True, parents=True)
            combined_image.save(base_dir / f"ws={requested_ws}_{size_cat}.png")


# Build the combined comparison grids for every enabled plot variant.
for suffix in tqdm(list(plot_suffix_to_function)):
    display_all_heatmaps(suffix)


  0%|          | 0/2 [00:00<?, ?it/s]