In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. under src/ and scripts/) are picked up before each
# cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict
from dataclasses import dataclass
from hmac import new
import json
from pathlib import Path
from typing import Optional, assert_never, Dict, List, Tuple
from torch import Tensor
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

import pandas as pd
import pyrallis
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
)

from scripts.create_slurm_file import run_slurm
from src.consts import (
    FILTERATIONS,
    MODEL_SIZES_PER_ARCH_TO_MODEL_ID,
    PATHS,
)
from src.datasets.download_dataset import load_dataset, load_splitted_counter_fact
from src.datasets.download_dataset import load_knowns_pd, get_hit_dataset
from src.logit_utils import get_last_token_logits, logits_to_probs
from src.models.model_interface import get_model_interface
from src.types import DATASETS
from src.types import MODEL_ARCH, SPLIT, DatasetArgs, TModelID
from src.utils.setup_models import get_tokenizer_and_model
from src.utils.slurm import submit_job
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np

Not using causal_conv1d
Not using causal_conv1d
Not using causal_conv1d


In [3]:

@dataclass
class Args:
    """Configuration for the info-flow knockout analysis in this notebook.

    Every setting is an annotated dataclass field, so it can be passed to the
    constructor and shows up in ``repr`` / ``dataclasses.fields``. The first
    eight fields keep their original positional order; the sampling/window
    settings that used to be plain (un-annotated) class attributes are
    declared after them so existing positional calls are unaffected.
    """

    # model_arch: MODEL_ARCH = MODEL_ARCH.MINIMAL_MAMBA2_new
    model_arch: MODEL_ARCH = MODEL_ARCH.MAMBA1
    model_size: str = "2.8B"
    dataset_args: DatasetArgs = pyrallis.field(
        default=DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all"), is_mutable=True
    )
    filteration: str = FILTERATIONS.all_correct
    _batch_size: int = 16  # Adjust based on GPU memory
    output_file: Optional[Path] = None
    with_slurm: bool = False
    output_dir: Optional[Path] = None

    # Previously un-annotated class attributes. Without an annotation the
    # dataclass decorator ignores them, so they were not constructor arguments
    # and the list/dict below were single objects shared by all instances.
    temperature: float = 1
    top_k: int = 0
    top_p: float = 1
    window_size: int = 9
    # Mutable defaults use the same pyrallis.field(..., is_mutable=True) form
    # that dataset_args already relies on.
    prompt_indices: List[int] = pyrallis.field(
        default=[1, 2, 3, 4, 5], is_mutable=True
    )
    knockout_map: Dict[str, List[str]] = pyrallis.field(
        default={
            'last': ['last', 'first', "subject", "relation"],
            'subject': ['context', 'subject'],
        },
        is_mutable=True,
    )

    @property
    def batch_size(self) -> int:
        """Effective batch size: 1 for the MINIMAL_MAMBA2 variants, else ``_batch_size``."""
        if (
            self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2
            or self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2_new
        ):
            return 1
        return self._batch_size

    @property
    def model_id(self) -> TModelID:
        """Model id looked up in MODEL_SIZES_PER_ARCH_TO_MODEL_ID by architecture and size."""
        return MODEL_SIZES_PER_ARCH_TO_MODEL_ID[self.model_arch][self.model_size]


In [4]:
# Notebook-wide configuration instance. get_top5_knockout reads it as a global
# and the driver loop at the bottom overwrites its model_arch / model_size /
# window_size in place, so its contents depend on which cells have run.
args = Args()

In [89]:
def get_top5_knockout(
    block_source,
    block_target,
):
    """Load the saved top-output entries for one source/target knockout pair.

    The file location is built from the notebook-level ``args`` (model id,
    dataset name, window size) plus ``block_source`` / ``block_target``.
    For every row of the loaded JSON, the first element of the value stored
    under label ``0`` is taken (presumably the top-ranked prediction -- TODO
    confirm against the writer of ``outputs.json``).

    Returns:
        DataFrame with columns ``token``, ``prob`` and ``token_id``, filled
        from positions 1, 2 and 0 of that element respectively.
    """
    outputs_path = (
        PATHS.OUTPUT_DIR /
        f'{args.model_id}/info_flow_test_top_outputs/ds={args.dataset_args.dataset_name}/ws={args.window_size}'
        f'/block_{block_source}_target_{block_target}/outputs.json'
    )
    raw = pd.read_json(outputs_path)

    # Column name -> position inside each entry, in output column order.
    layout = (('token', 1), ('prob', 2), ('token_id', 0))

    collected = defaultdict(list)
    for _, record in raw.iterrows():
        entry = record[0][0]
        for column, position in layout:
            collected[column].append(entry[position])

    return pd.DataFrame(collected)

def calc_correct(merged):
    """Add a boolean ``correct`` column to ``merged`` and return it.

    A row counts as correct when its ``token_id`` equals the first token id
    the "EleutherAI/gpt-neox-20b" tokenizer produces for ``target_true``.
    The frame is modified in place; the correct/incorrect counts are printed.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tok.pad_token_id = tok.eos_token_id

    def _matches_target(row):
        predicted_id = row['token_id']
        first_target_id = tok(row['target_true'])['input_ids'][0]
        return predicted_id == first_target_id

    merged['correct'] = merged.apply(_matches_target, axis=1)

    hits = merged['correct']
    print(f"Correct: {hits.sum()}, Incorrect: {(~hits).sum()}")
    return merged

def parse_outputs(p):
    """Read a per-window outputs JSON file and regroup it by metric.

    The file is expected to map ``window -> {metric -> output}``. For every
    metric, the outputs of all windows are collected in file order and turned
    into a DataFrame that is then transposed, so each window becomes one
    column.

    Note: the window keys themselves are not kept; the columns of each
    returned frame are positional (0..n-1, in the order the windows appear
    in the file).

    Args:
        p: Path-like object exposing ``open('r')`` (e.g. ``pathlib.Path``).

    Returns:
        dict mapping metric name -> transposed DataFrame of that metric.
    """
    # Use a context manager so the handle is closed deterministically instead
    # of being left for the garbage collector.
    with p.open('r') as fh:
        data = json.load(fh)

    per_metric = defaultdict(list)
    for outputs in data.values():
        for metric, output in outputs.items():
            per_metric[metric].append(output)

    return {
        metric: pd.DataFrame.from_dict(rows).T
        for metric, rows in per_metric.items()
    }

In [93]:
def _interpolate_frames(values, frames_per_transition):
    """Linearly interpolate between consecutive rows of a 2-D array.

    Each of the ``T - 1`` transitions gets ``frames_per_transition`` frames
    (starting exactly at the source row); the last ``frames_per_transition``
    frames hold the final row. The result has ``T * frames_per_transition``
    rows.
    """
    num_timesteps, num_points = values.shape
    smooth = np.zeros((num_timesteps * frames_per_transition, num_points))

    # Column vector of blend factors 0, 1/F, ..., (F-1)/F shared by all transitions.
    alphas = (np.arange(frames_per_transition) / frames_per_transition)[:, None]
    for step in range(num_timesteps - 1):
        lo = step * frames_per_transition
        smooth[lo:lo + frames_per_transition] = (
            (1 - alphas) * values[step] + alphas * values[step + 1]
        )

    # Hold the final window for one more transition's worth of frames.
    smooth[-frames_per_transition:] = values[-1]
    return smooth


def create_info_flow_animation(
    title: str,
    output_path: Path,
    prompts: pd.DataFrame,
    probs_per_window: pd.DataFrame,
    correct_per_window: pd.DataFrame,
    point_size=15,
    frames_per_transition=20,  # Number of intermediate frames between original timesteps
    seconds_per_transition=1,  # Time in seconds for each transition
):
    """Render an animated scatter of base vs. knockout probability and save it.

    Each point is one prompt: x is its ``true_prob`` from ``prompts``, y is
    its probability in the current window, smoothly interpolated between
    consecutive windows. Points are green when ``correct_per_window`` is
    truthy for that prompt/window and red otherwise; the title shows the
    current window number and accuracy.

    Args:
        title: First line of the plot title.
        output_path: Destination file passed to ``Animation.save``.
        prompts: Frame with a ``true_prob`` column, row-aligned with the
            rows of ``probs_per_window`` / ``correct_per_window``.
        probs_per_window: One row per prompt, one column per window; the
            first column label must be convertible to ``int``.
        correct_per_window: Same layout as ``probs_per_window``.
        point_size: Marker size passed to ``scatter``.
        frames_per_transition: Interpolated frames between two windows.
        seconds_per_transition: Duration of one window-to-window transition.
    """
    # Parameters
    first_frame = int(probs_per_window.columns[0])
    interval = (seconds_per_transition * 1000) / frames_per_transition
    # int() truncates to 0 when a transition lasts more seconds than it has
    # frames; clamp so the writer always receives a usable frame rate.
    frames_per_second = max(1, int(frames_per_transition / seconds_per_transition))

    # Original data
    correct = np.array(correct_per_window).T  # shape (num_windows, num_points)
    probs_aligned = np.array(probs_per_window).T  # shape (num_windows, num_points)
    base_probs = np.array(prompts["true_prob"])  # shape (num_points,)
    num_points = probs_aligned.shape[1]
    num_timesteps = probs_aligned.shape[0]
    accuracy = correct.sum(axis=1) / num_points

    # Interpolation to create smoother transitions
    new_num_timesteps = num_timesteps * frames_per_transition
    probs_smooth = _interpolate_frames(probs_aligned, frames_per_transition)

    # Create figure and scatter plot
    fig, ax = plt.subplots()
    colors = np.where(correct, "green", "red")  # Green for True, Red for False
    scat = ax.scatter(
        base_probs, probs_smooth[0], s=point_size, c=colors[0], edgecolor="k"
    )

    ax.set_xlabel("Base Probability")
    ax.set_ylabel("Knockout Probability")

    # Update function for animation
    def update(frame):
        scat.set_offsets(np.column_stack((base_probs, probs_smooth[frame])))
        if frame % frames_per_transition == 0:
            # A new window starts: refresh fill colours and title.
            current_og_frame = frame // frames_per_transition
            # set_facecolor (not set_color) so the black marker outline
            # requested via edgecolor="k" is not overwritten.
            scat.set_facecolor(colors[int(current_og_frame)])
            ax.set_title(
                f"{title}\nProbabilities - Window {first_frame + current_og_frame}/{first_frame + num_timesteps-1} - Accuracy: {accuracy[current_og_frame]*100:.1f}%"
            )
        return (scat,)

    # Create the animation
    ani = FuncAnimation(
        fig, update, frames=new_num_timesteps, interval=interval, blit=True
    )

    # Close the figure so the notebook does not also show a static copy,
    # then write the animation to disk.
    plt.close(fig)
    ani.save(output_path, writer="imagemagick", fps=frames_per_second)

In [94]:
def all_combined(args: Args):
    """Create one info-flow GIF per (source, target) knockout pair for ``args``.

    For each pair, the ``info_flow_v7`` ``outputs.json`` is parsed and its
    ``hit`` / ``true_probs`` tables are animated against the base
    probabilities of the hit dataset. GIFs are written to
    ``PATHS.RESULTS_DIR / 'info_flow_animation'``.
    """
    model_name = args.model_id.split('/')[1]

    # The prompts depend only on the model and dataset, not on the knockout
    # pair, so load them once instead of once per pair.
    prompts = get_hit_dataset(args.model_id, args.dataset_args).reset_index(drop=True)

    animation_dir = PATHS.RESULTS_DIR / 'info_flow_animation'
    animation_dir.mkdir(parents=True, exist_ok=True)

    for block_source in ['last', 'subject']:
        for block_target in ['last']:
            outputs = parse_outputs(
                PATHS.OUTPUT_DIR
                / f"{args.model_id}/info_flow_v7/ds={args.dataset_args.dataset_name}/ws={args.window_size}/block_{block_source}_target_{block_target}/outputs.json"
            )

            output_path = (
                animation_dir
                / f"{model_name}_{args.window_size}_{block_source}_{block_target}.gif"
            )
            title = f"{model_name} - ws={args.window_size} - source={block_source} - target={block_target}"
            create_info_flow_animation(
                title,
                output_path,
                prompts=prompts,
                correct_per_window=outputs['hit'],
                probs_per_window=outputs['true_probs'],
            )

In [None]:
# Render the animations for every (architecture, size) pair and window size.
# The shared `args` object is updated in place before each run.
window_sizes = [9]
for model_arch, model_size in (
    (MODEL_ARCH.MAMBA1, "130M"),
    (MODEL_ARCH.MAMBA1, "1.4B"),
    (MODEL_ARCH.MAMBA1, "2.8B"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "130M"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "1.3B"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "2.7B"),
):
    args.model_arch, args.model_size = model_arch, model_size
    for window_size in window_sizes:
        args.window_size = window_size
        print(f"Creating animation for {args.model_id} with window size {args.window_size}")
        all_combined(args)

MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba-130M-hf with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.
MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba-1.4B-hf with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba-2.8B-hf with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.
MovieWriter imagemagick unavailable; using Pillow instead.
MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba2-130M with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba2-1.3b with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.
MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba2-2.7B with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.
MovieWriter imagemagick unavailable; using Pillow instead.
