In [1]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import re
from collections import defaultdict

from PIL import Image, ImageDraw, ImageFont

from src.consts import GRAPHS_ORDER, MODEL_SIZES_PER_ARCH_TO_MODEL_ID, PATHS
from src.types import DATASETS, MODEL_ARCH, DatasetArgs

In [3]:
# Parameters
# Dataset configuration: COUNTER_FACT with splits="all".
# NOTE(review): `ds` is not referenced by any of the cells visible below —
# confirm it is still needed.
ds = DatasetArgs(
    name=DATASETS.COUNTER_FACT,
    splits="all",
)

In [5]:
# Suffix of the `info_flow<suffix>` results directory that is read for each
# architecture; both architectures currently share the same suffix.
MODEL_TO_VERSION = dict.fromkeys(
    (MODEL_ARCH.MAMBA1, MODEL_ARCH.MINIMAL_MAMBA2_new),
    "_debug_10_last_windows",
)

In [None]:
def display_all_heatmaps():
    """Tile the ``knockout_target=last`` info-flow plots into one grid image.

    For every ``(model, size)`` pair in ``GRAPHS_ORDER`` the directory
    ``PATHS.OUTPUT_DIR / "<model_id>/info_flow<version>"`` (version suffix
    looked up in the module-level ``MODEL_TO_VERSION``) is searched
    recursively for PNG files whose path matches
    ``.../ds=<dataset>/ws=<window_size>/results_ws=<n>_knockout_target=last.png``.

    The matching plots are laid out with one row per window size (ascending,
    numerically) and one column per model that has at least one match.  Each
    tile is captioned ``"<model name> (ws=<window_size>)"`` and resized to the
    dimensions of the first plot found.  A model without a plot for some
    window size leaves that grid position blank.

    The result is written to
    ``PATHS.RESULTS_DIR / "combined_debug_10_last_windows.png"``; the target
    directory is created if missing.  When no plot matches at all, a message
    is printed and nothing is written.

    Returns:
        None.
    """
    pattern = re.compile(
        r"/state-spaces/(?P<model_id>[\w\.-]+)/info_flow(?P<version>[\w]+)/ds=(?P<dataset>[\w_]+)/ws=(?P<window_size>\d+)/results_ws=\d+_knockout_target=last.png"
    )

    # model_id -> window size (digit string, as captured) -> list of PNG paths.
    # Paths are collected instead of open images so that no file handle stays
    # open longer than the paste that needs it.
    models_ws = defaultdict(lambda: defaultdict(list))
    for model, size in GRAPHS_ORDER:
        model_id = MODEL_SIZES_PER_ARCH_TO_MODEL_ID[model][size]
        model_dir = PATHS.OUTPUT_DIR / f"{model_id}/info_flow{MODEL_TO_VERSION[model]}"
        # Sorted so that "the first plot" below does not depend on filesystem order.
        for file in sorted(model_dir.rglob("*.png")):
            match = pattern.search(str(file))
            if match is None:
                continue
            models_ws[model_id][match.group("window_size")].append(file)

    if not models_ws:
        print("No images found for knockout_target=last.")
        return

    # Window sizes are digit strings; sort numerically so "9" precedes "15".
    ws_opts = sorted(
        {window_size for files_by_ws in models_ws.values() for window_size in files_by_ws},
        key=int,
    )

    # Every tile is drawn at the dimensions of the first plot found.
    first_model_files = next(iter(models_ws.values()))
    sample_path = next(iter(first_model_files.values()))[0]
    with Image.open(sample_path) as sample_img:
        img_width, img_height = sample_img.size
    padding = 10
    title_height = 30  # Vertical space reserved above each tile for its caption

    cell_width = img_width + padding
    cell_height = img_height + title_height + padding

    combined_image = Image.new(
        "RGB",
        (len(models_ws) * cell_width, len(ws_opts) * cell_height),
        "white",
    )
    draw = ImageDraw.Draw(combined_image)

    for row_idx, window_size in enumerate(ws_opts):
        y_offset = row_idx * cell_height
        for col_idx, (model_id, files_by_ws) in enumerate(models_ws.items()):
            # `.get` avoids inserting empty entries into the defaultdict.
            files = files_by_ws.get(window_size, [])
            if not files:
                continue  # No plot for this model / window size: leave the tile blank
            x_offset = col_idx * cell_width
            # Only the first plot per (model, window size) is shown.
            with Image.open(files[0]) as img:
                combined_image.paste(
                    img.resize((img_width, img_height)),
                    (x_offset, y_offset + title_height),
                )
            title_text = f"{model_id.split('/')[1]} (ws={window_size})"
            draw.text((x_offset, y_offset), title_text, fill="black")

    output_path = PATHS.RESULTS_DIR / "combined_debug_10_last_windows.png"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    combined_image.save(output_path)


# Build the combined grid and write it under PATHS.RESULTS_DIR
display_all_heatmaps()

In [16]:
# Window sizes to keep, stored as digit strings because they are compared
# against the `window_size` text captured from each file path.
include_window_sizes = [str(ws) for ws in (9, 15)]
# Suffix of the `info_flow<suffix>` results directory read for each architecture.
MODEL_TO_VERSION = dict.fromkeys(
    (MODEL_ARCH.MAMBA1, MODEL_ARCH.MINIMAL_MAMBA2_new),
    "_v6",
)


def display_all_heatmaps():
    """Write one combined info-flow grid per knockout target.

    For each knockout target (``"last"``, ``"relation"``, ``"subject"``) the
    PNG files below ``PATHS.OUTPUT_DIR / "<model_id>/info_flow<version>"`` of
    every ``(model, size)`` pair in ``GRAPHS_ORDER`` are matched against
    ``.../ws=<window_size>/results_ws=<n>_knockout_target=<target>.png``.
    Only window sizes listed in the module-level ``include_window_sizes``
    are kept; the version suffix comes from the module-level
    ``MODEL_TO_VERSION``.

    Each grid has one row per window size (ascending, numerically) and one
    column per model with at least one match.  Tiles are captioned
    ``"<model name> (ws=<window_size>)"`` and resized to the dimensions of the
    first plot found for that target.  Each grid is saved as
    ``PATHS.RESULTS_DIR / "combined_info_flows_target=<target>.png"``; the
    directory is created if missing.  A target without any matching plot is
    reported with a message and skipped.

    NOTE: an earlier cell defines a function with this same name; running
    this cell rebinds the name to this version.

    Returns:
        None.
    """
    # List every model's PNG files once: the listing does not depend on the
    # knockout target.  Sorted so the plot picked per tile is deterministic.
    model_files = []
    for model, size in GRAPHS_ORDER:
        model_id = MODEL_SIZES_PER_ARCH_TO_MODEL_ID[model][size]
        model_dir = PATHS.OUTPUT_DIR / f"{model_id}/info_flow{MODEL_TO_VERSION[model]}"
        model_files.append((model_id, sorted(model_dir.rglob("*.png"))))

    padding = 10
    title_height = 30  # Vertical space reserved above each tile for its caption

    for knockout_target in [
        "last",
        "relation",
        "subject",
    ]:
        pattern = re.compile(
            rf"/state-spaces/(?P<model_id>[\w\.-]+)/info_flow(?P<version>[\w]*)/ds=(?P<dataset>[\w_]+)/ws=(?P<window_size>\d+)/results_ws=\d+_knockout_target={knockout_target}.png"
        )

        # model_id -> window size (digit string) -> list of PNG paths.
        # Files are opened only when pasted, so filtered-out plots are never opened.
        models_ws = defaultdict(lambda: defaultdict(list))
        for model_id, files in model_files:
            for file in files:
                match = pattern.search(str(file))
                if match is None:
                    continue
                window_size = match.group("window_size")
                if window_size not in include_window_sizes:
                    continue
                models_ws[model_id][window_size].append(file)

        if not models_ws:
            print(f"No images found for knockout_target={knockout_target}.")
            continue

        # Window sizes are digit strings; sort numerically so "9" precedes "15".
        ws_opts = sorted(
            {window_size for files_by_ws in models_ws.values() for window_size in files_by_ws},
            key=int,
        )

        # Every tile is drawn at the dimensions of the first plot found.
        first_model_files = next(iter(models_ws.values()))
        sample_path = next(iter(first_model_files.values()))[0]
        with Image.open(sample_path) as sample_img:
            img_width, img_height = sample_img.size

        cell_width = img_width + padding
        cell_height = img_height + title_height + padding

        combined_image = Image.new(
            "RGB",
            (len(models_ws) * cell_width, len(ws_opts) * cell_height),
            "white",
        )
        draw = ImageDraw.Draw(combined_image)

        for row_idx, window_size in enumerate(ws_opts):
            y_offset = row_idx * cell_height
            for col_idx, (model_id, files_by_ws) in enumerate(models_ws.items()):
                # `.get` avoids inserting empty entries into the defaultdict.
                files = files_by_ws.get(window_size, [])
                if not files:
                    continue  # No plot for this model / window size: leave the tile blank
                x_offset = col_idx * cell_width
                # Only the first plot per (model, window size) is shown.
                with Image.open(files[0]) as img:
                    combined_image.paste(
                        img.resize((img_width, img_height)),
                        (x_offset, y_offset + title_height),
                    )
                title_text = f"{model_id.split('/')[1]} (ws={window_size})"
                draw.text((x_offset, y_offset), title_text, fill="black")

        output_path = PATHS.RESULTS_DIR / f"combined_info_flows_target={knockout_target}.png"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        combined_image.save(output_path)


# Write one combined grid PNG per knockout target under PATHS.RESULTS_DIR
display_all_heatmaps()

In [27]:
def display_all_heatmaps_for_ws():
    """Write a two-column grid of info-flow plots per window size and target.

    For every window size in ``["9", "15"]`` and every knockout target
    (``"last"``, ``"relation"``, ``"subject"``), the PNG files below
    ``PATHS.OUTPUT_DIR / "<model_id>/info_flow<version>"`` of each
    ``(model, size)`` pair in ``GRAPHS_ORDER`` are matched against
    ``.../ws=<window_size>/results_ws=<n>_knockout_target=<target>.png``
    (version suffix from the module-level ``MODEL_TO_VERSION``).

    Models with at least one match are placed left-to-right, top-to-bottom in
    a grid that is two tiles wide.  Each tile shows the model's first matching
    plot, resized to the dimensions of the first plot found for that
    combination, under the caption ``"<model name> (ws=<window_size>)"``.
    The grid is saved as
    ``PATHS.RESULTS_DIR / "combined_info_flows_ws_<ws>_target=<target>.png"``;
    the directory is created if missing.

    A combination without any matching plot is reported with a message and
    skipped; the remaining combinations are still processed.

    Returns:
        None.
    """
    # List every model's PNG files once: the listing depends neither on the
    # window size nor on the target.  Sorted so the chosen plot is deterministic.
    model_files = []
    for model, size in GRAPHS_ORDER:
        model_id = MODEL_SIZES_PER_ARCH_TO_MODEL_ID[model][size]
        model_dir = PATHS.OUTPUT_DIR / f"{model_id}/info_flow{MODEL_TO_VERSION[model]}"
        model_files.append((model_id, sorted(model_dir.rglob("*.png"))))

    padding = 10
    title_height = 30  # Vertical space reserved above each tile for its caption
    num_cols = 2  # Tiles per grid row

    for window_size in ["9", "15"]:
        for knockout_target in [
            "last",
            "relation",
            "subject",
        ]:
            pattern = re.compile(
                rf"/state-spaces/(?P<model_id>[\w\.-]+)/info_flow(?P<version>[\w]*)/ds=(?P<dataset>[\w_]+)/ws=(?P<window_size>\d+)/results_ws=\d+_knockout_target={knockout_target}.png"
            )

            # model_id -> PNG paths matching this window size and target.
            # Paths are stored (not open images) so no file handle is leaked.
            models_ws = defaultdict(list)
            for model_id, files in model_files:
                for file in files:
                    match = pattern.search(str(file))
                    if match is None or match.group("window_size") != window_size:
                        continue
                    models_ws[model_id].append(file)

            if not models_ws:
                print(f"No images found for ws={window_size}, knockout_target={knockout_target}.")
                # Skip only this combination; the others may still have plots.
                continue

            # Every tile is drawn at the dimensions of the first plot found.
            sample_path = next(iter(models_ws.values()))[0]
            with Image.open(sample_path) as sample_img:
                img_width, img_height = sample_img.size

            cell_width = img_width + padding
            cell_height = img_height + title_height + padding
            num_rows = -(-len(models_ws) // num_cols)  # Ceiling division

            combined_image = Image.new(
                "RGB",
                (num_cols * cell_width, num_rows * cell_height),
                "white",
            )
            draw = ImageDraw.Draw(combined_image)

            for idx, (model_id, files) in enumerate(models_ws.items()):
                row_idx, col_idx = divmod(idx, num_cols)
                x_offset = col_idx * cell_width
                y_offset = row_idx * cell_height
                # Only the model's first matching plot is shown.
                with Image.open(files[0]) as img:
                    combined_image.paste(
                        img.resize((img_width, img_height)),
                        (x_offset, y_offset + title_height),
                    )
                title_text = f"{model_id.split('/')[1]} (ws={window_size})"
                draw.text((x_offset, y_offset), title_text, fill="black")

            output_path = (
                PATHS.RESULTS_DIR / f"combined_info_flows_ws_{window_size}_target={knockout_target}.png"
            )
            output_path.parent.mkdir(parents=True, exist_ok=True)
            combined_image.save(output_path)


# Write one combined grid PNG per (window size, knockout target) pair
display_all_heatmaps_for_ws()

In [6]:
# Suffix of the `info_flow<suffix>` results directory read for each architecture.
# An earlier cell assigns the same "_v6" mapping; it is repeated here, so this
# cell does not rely on that one having been run.
MODEL_TO_VERSION = dict.fromkeys(
    (MODEL_ARCH.MAMBA1, MODEL_ARCH.MINIMAL_MAMBA2_new),
    "_v6",
)


def display_all_heatmaps_for_ws():
    """Write a two-column grid per window size, knockout target and metric.

    For every window size in ``["9", "15"]``, knockout target (``"last"``,
    ``"relation"``, ``"subject"``) and metric (``"accuracy"``,
    ``"norm_change"``), the PNG files below
    ``PATHS.OUTPUT_DIR / "<model_id>/info_flow<version>"`` of each
    ``(model, size)`` pair in ``GRAPHS_ORDER`` are matched against
    ``.../ws=<window_size>/results_for_multi_plot/knockout_target=<target>/<metric>.png``
    (version suffix from the module-level ``MODEL_TO_VERSION``).

    Models with at least one match are placed left-to-right, top-to-bottom in
    a grid that is two tiles wide.  Each tile shows the model's first matching
    plot, resized to the dimensions of the first plot found for that
    combination; no captions are drawn.  The grid is saved as
    ``PATHS.RESULTS_DIR / "combined_info_flows" /
    "ws_<ws>_target=<target>_metric=<metric>.png"``; the directory is created
    if missing.

    A combination without any matching plot is reported with a message and
    skipped; the remaining combinations are still processed.

    NOTE: an earlier cell defines a function with this same name; running
    this cell rebinds the name to this version.

    Returns:
        None.
    """
    results_dir_name = "results_for_multi_plot"

    # List every model's PNG files once: the listing does not depend on the
    # window size, target or metric.  Sorted so the chosen plot is deterministic.
    model_files = []
    for model, size in GRAPHS_ORDER:
        model_id = MODEL_SIZES_PER_ARCH_TO_MODEL_ID[model][size]
        model_dir = PATHS.OUTPUT_DIR / f"{model_id}/info_flow{MODEL_TO_VERSION[model]}"
        model_files.append((model_id, sorted(model_dir.rglob("*.png"))))

    padding = 10
    num_cols = 2  # Tiles per grid row

    for window_size in ["9", "15"]:
        for knockout_target in [
            "last",
            "relation",
            "subject",
        ]:
            for metric in ["accuracy", "norm_change"]:
                pattern = re.compile(
                    rf"/state-spaces/(?P<model_id>[\w\.-]+)/info_flow(?P<version>[\w]*)/ds=(?P<dataset>[\w_]+)/ws=(?P<window_size>\d+)/{results_dir_name}/knockout_target={knockout_target}/{metric}.png"
                )

                # model_id -> PNG paths matching this window size / target / metric.
                # Paths are stored (not open images) so no file handle is leaked.
                models_ws = defaultdict(list)
                for model_id, files in model_files:
                    for file in files:
                        match = pattern.search(str(file))
                        if match is None or match.group("window_size") != window_size:
                            continue
                        models_ws[model_id].append(file)

                if not models_ws:
                    print(
                        f"No images found for ws={window_size}, "
                        f"knockout_target={knockout_target}, metric={metric}."
                    )
                    # Skip only this combination; the others may still have plots.
                    continue

                # Every tile is drawn at the dimensions of the first plot found.
                sample_path = next(iter(models_ws.values()))[0]
                with Image.open(sample_path) as sample_img:
                    img_width, img_height = sample_img.size

                # No caption strip in this variant: a cell is the tile plus padding.
                cell_width = img_width + padding
                cell_height = img_height + padding
                num_rows = -(-len(models_ws) // num_cols)  # Ceiling division

                combined_image = Image.new(
                    "RGB",
                    (num_cols * cell_width, num_rows * cell_height),
                    "white",
                )

                for idx, (model_id, files) in enumerate(models_ws.items()):
                    row_idx, col_idx = divmod(idx, num_cols)
                    # Only the model's first matching plot is shown.
                    with Image.open(files[0]) as img:
                        combined_image.paste(
                            img.resize((img_width, img_height)),
                            (col_idx * cell_width, row_idx * cell_height),
                        )

                base_path = PATHS.RESULTS_DIR / "combined_info_flows"
                base_path.mkdir(parents=True, exist_ok=True)
                combined_image.save(
                    base_path / f"ws_{window_size}_target={knockout_target}_metric={metric}.png"
                )


# Write one combined grid PNG per (window size, knockout target, metric)
# combination under PATHS.RESULTS_DIR / "combined_info_flows"
display_all_heatmaps_for_ws()