In [2]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before each cell
# executes, so edits to the project's `src` / `scripts` packages are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [1]:
from collections import defaultdict
from dataclasses import dataclass
from hmac import new
import json
from pathlib import Path
from typing import Optional, assert_never, Dict, List, Tuple
from torch import Tensor
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

import pandas as pd
import pyrallis
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
)

from scripts.create_slurm_file import run_slurm
from src.consts import (
    FILTERATIONS,
    MODEL_SIZES_PER_ARCH_TO_MODEL_ID,
    PATHS,
)
from src.datasets.download_dataset import load_dataset, load_splitted_counter_fact
from src.datasets.download_dataset import load_knowns_pd
from src.logit_utils import get_last_token_logits, logits_to_probs
from src.models.model_interface import get_model_interface
from src.types import DATASETS
from src.types import MODEL_ARCH, SPLIT, DatasetArgs, TModelID
from src.utils.setup_models import get_tokenizer_and_model
from src.utils.slurm import submit_job

Not using causal_conv1d
Not using causal_conv1d
Not using causal_conv1d


In [3]:
@dataclass
class Args:
    """Configuration for this notebook's model/dataset selection and sampling options.

    All options are declared as annotated dataclass fields so that they can be
    overridden through the constructor (``Args(window_size=9)``) and appear in
    ``repr``/``==``. Un-annotated assignments inside a ``@dataclass`` body are
    plain class attributes that the decorator ignores.

    Properties:
        batch_size: Effective batch size; forced to 1 for the MINIMAL_MAMBA2
            architectures, otherwise ``_batch_size``.
        model_id: Model identifier looked up in
            ``MODEL_SIZES_PER_ARCH_TO_MODEL_ID`` by architecture and size.
    """

    # model_arch: MODEL_ARCH = MODEL_ARCH.MINIMAL_MAMBA2_new
    model_arch: MODEL_ARCH = MODEL_ARCH.MAMBA1
    model_size: str = "2.8B"
    dataset_args: DatasetArgs = pyrallis.field(
        default=DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all"), is_mutable=True
    )
    filteration: str = FILTERATIONS.all_correct
    _batch_size: int = 16  # Adjust based on GPU memory
    output_file: Optional[Path] = None
    with_slurm: bool = False
    # Declared before the newly annotated fields below so that the positional order of the
    # fields that already existed (… with_slurm, output_dir) is unchanged.
    output_dir: Optional[Path] = None

    # Sampling / windowing options.
    temperature: float = 1
    top_k: int = 0
    top_p: float = 1
    window_size: int = 5
    # Mutable defaults go through pyrallis.field(..., is_mutable=True), the same mechanism
    # already used for `dataset_args`, instead of being a single list/dict object shared by
    # the class and all instances.
    # NOTE(review): confirm whether pyrallis copies the default shallowly or deeply.
    prompt_indices: List[int] = pyrallis.field(default=[1, 2, 3, 4, 5], is_mutable=True)
    knockout_map: Dict[str, List[str]] = pyrallis.field(
        default={
            'last': ['last', 'first', "subject", "relation"],
            'subject': ['context', 'subject'],
        },
        is_mutable=True,
    )

    @property
    def batch_size(self) -> int:
        """Return 1 for the MINIMAL_MAMBA2 variants, otherwise the configured ``_batch_size``."""
        return (
            1
            if (
                self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2
                or self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2_new
            )
            else self._batch_size
        )

    @property
    def model_id(self) -> TModelID:
        """Return the model id registered for ``model_arch`` / ``model_size``."""
        return MODEL_SIZES_PER_ARCH_TO_MODEL_ID[self.model_arch][self.model_size]


In [4]:
# Instantiate the configuration with its defaults (no overrides are passed).
args = Args()

In [5]:

# Directory holding the data-construction results for the configured model and dataset.
results_dir = (
    PATHS.OUTPUT_DIR
    / args.model_id
    / "data_construction"
    / f"ds={args.dataset_args.dataset_name}"
)

# NOTE(review): the names look swapped relative to the files they hold — `original_res` is read
# from the "attention" file and `attn_res` from the "original" file. This mirrors the original
# cell, which unpacked the reads in [attention=True, attention=False] order. The mask below is
# symmetric, but `data` therefore takes its rows from the "original" file. Confirm the intent.
original_res = pd.read_parquet(results_dir / "entire_results_attention.parquet")
attn_res = pd.read_parquet(results_dir / "entire_results_original.parquet")

# Keep only the rows whose `hit` value agrees between the two files and is True.
mask = (original_res["hit"] == attn_res["hit"]) & (attn_res["hit"] == True)
data = attn_res[mask]


In [6]:
# Load the outputs.json of the `info_flow_test_top_outputs_5_last_windows` run
# (ds=counter_fact, ws=9, block_last_target_last) into a DataFrame. The path is hard-coded to
# the state-spaces/mamba-2.8B-hf model rather than derived from `args`.
# NOTE(review): presumably one row per prompt and one column per knockout window — confirm
# against the script that produced the file.
probs = pd.read_json(PATHS.OUTPUT_DIR / 'state-spaces/mamba-2.8B-hf/info_flow_test_top_outputs_5_last_windows/ds=counter_fact/ws=9/block_last_target_last/outputs.json')

In [7]:
from email.policy import default


def get_top5_knockout(outputs_path=None):
    """Return the first-listed output token of every row of a knockout ``outputs.json``.

    Despite the function's name, only ONE entry per row is kept: the first element of the
    cell in column ``0``. Each such entry is indexed as ``[token_id, token, prob]``.
    NOTE(review): presumably the entries are ordered by descending probability, so this is
    the top-1 prediction — confirm against the script that writes ``outputs.json``.

    Args:
        outputs_path: Path of the JSON file to read. When ``None`` (the default) the
            mamba-2.8B-hf / counter_fact / ws=9 / block_last_target_last file under
            ``PATHS.OUTPUT_DIR`` is used, as before.

    Returns:
        A DataFrame with columns ``token``, ``prob`` and ``token_id`` (in that order) and
        one row per row of the JSON table. The columns are present even when the file has
        no rows.
    """
    if outputs_path is None:
        outputs_path = PATHS.OUTPUT_DIR / 'state-spaces/mamba-2.8B-hf/info_flow_test_top_outputs/ds=counter_fact/ws=9/block_last_target_last/outputs.json'

    outputs = pd.read_json(outputs_path)

    # row[0] is the cell in column 0; its first element is the entry that is kept.
    first_entries = [row[0][0] for _, row in outputs.iterrows()]

    # Build the frame once from complete columns instead of growing a dict row by row.
    return pd.DataFrame(
        {
            'token': [entry[1] for entry in first_entries],
            'prob': [entry[2] for entry in first_entries],
            'token_id': [entry[0] for entry in first_entries],
        }
    )
# Build the token / prob / token_id table from the default outputs.json and display it.
top_outputs = get_top5_knockout()
top_outputs

Unnamed: 0,token,prob,token_id
0,Microsoft,0.960056,9664
1,English,0.999680,4383
2,French,0.987940,5112
3,French,0.784016,5112
4,Google,0.998711,5559
...,...,...,...
805,Microsoft,0.994226,9664
806,K,0.220406,611
807,IBM,0.983616,21314
808,Antar,0.994120,31913


In [8]:
def load_dataset(model_id='state-spaces/mamba-2.8B-hf', dataset_name='counter_fact'):
    """Load the data-construction results and keep the rows that are hits in both result files.

    Reads ``entire_results_attention.parquet`` and ``entire_results_original.parquet`` from
    ``PATHS.OUTPUT_DIR / model_id / "data_construction" / f"ds={dataset_name}"``.

    NOTE(review): this function shadows the ``load_dataset`` imported from
    ``src.datasets.download_dataset`` at the top of the notebook.
    NOTE(review): the local names look swapped — ``original_res`` holds the "attention" file
    and ``attn_res`` the "original" file, because the reads are unpacked in
    [attention=True, attention=False] order. Behaviour is kept as-is: the row mask is
    symmetric, and the returned rows come from the "original" file. Confirm the intent.

    Args:
        model_id: Model directory under ``PATHS.OUTPUT_DIR`` (default: the mamba-2.8B-hf
            path that was previously hard-coded).
        dataset_name: Value used in the ``ds=<name>`` directory (default ``counter_fact``).

    Returns:
        The selected rows with a fresh 0..n-1 index.
    """
    from src.consts import PATHS

    results_dir = PATHS.OUTPUT_DIR / model_id / "data_construction" / f"ds={dataset_name}"

    original_res, attn_res = [
        pd.read_parquet(
            results_dir
            / f"entire_results_{'attention' if attention else 'original'}.parquet"
        )
        for attention in [True, False]
    ]

    # Rows whose `hit` value agrees between the two files and is True.
    mask = (original_res["hit"] == attn_res["hit"]) & (attn_res["hit"] == True)
    return attn_res[mask].reset_index(drop=True)

def calc_correct(merged):
    """Flag rows whose predicted token id equals the first token id of ``target_true``.

    Loads the ``EleutherAI/gpt-neox-20b`` tokenizer, adds a boolean ``correct`` column to
    ``merged`` IN PLACE (True where ``token_id`` equals the first ``input_ids`` entry of the
    tokenized ``target_true``), prints the number of True rows followed by the number of
    False rows, and returns the same frame.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tok.pad_token_id = tok.eos_token_id

    def _matches_first_target_token(record):
        # Compare the predicted id with the first token id of the tokenized target string.
        predicted_id = record['token_id']
        return predicted_id == tok(record['target_true'])['input_ids'][0]

    merged['correct'] = merged.apply(_matches_first_target_token, axis=1)

    is_correct = merged['correct']
    print(is_correct.sum())
    print((~merged['correct']).sum())
    # filtered = merged[merged['correct']]
    return merged

In [9]:
# Load the filtered dataset rows and the per-row knockout predictions; both use the default
# (hard-coded mamba-2.8B-hf / counter_fact) locations.
ds = load_dataset()
top5 = get_top5_knockout()


In [10]:
# Merge on index: row i of `ds` is paired with row i of `top5`, then `calc_correct` adds the
# boolean `correct` column and prints the True / False counts.
# NOTE(review): this assumes both frames list the same examples in the same order — confirm.
merged = (
    ds.merge(top5, left_index=True, right_index=True)
    .pipe(calc_correct)
    )

762
48


In [11]:
# Inspect the `true_prob` column of the merged frame (used as `base_probs` in the plots below).
merged['true_prob']

0      0.726621
1      0.529818
2      0.748545
3      0.184752
4      0.757935
         ...   
805    0.425025
806    0.134802
807    0.523892
808    0.820605
809    0.606582
Name: true_prob, Length: 810, dtype: float64

In [12]:
# Inspect the boolean `correct` column added by `calc_correct`.
merged['correct']

0      True
1      True
2      True
3      True
4      True
       ... 
805    True
806    True
807    True
808    True
809    True
Name: correct, Length: 810, dtype: bool

In [13]:
# Display the probability table loaded from outputs.json (columns are labelled 51–55 in the
# recorded output).
probs

Unnamed: 0,51,52,53,54,55
0,0.938020,0.844909,0.845702,0.753998,0.960056
1,0.883308,0.908491,0.966957,0.306658,0.999680
2,0.567339,0.481568,0.800203,0.291283,0.987940
3,0.211515,0.170901,0.232790,0.017412,0.784016
4,0.857686,0.968239,0.936548,0.858015,0.998711
...,...,...,...,...,...
805,0.737593,0.677925,0.767271,0.423829,0.994226
806,0.053069,0.039376,0.109175,0.027217,0.220406
807,0.906816,0.866823,0.735856,0.498270,0.983616
808,0.841295,0.855963,0.965536,0.899193,0.994120


In [14]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np

In [157]:
# Animated scatter: x = base probability of the true target (`merged['true_prob']`),
# y = probability taken from one `probs` column per frame; points are coloured by
# `merged['correct']`.
point_size = 30
interval = 100  # delay between frames, in milliseconds
cmap = 'viridis'  # not used below: point colours are passed explicitly

correct = merged['correct']  # Boolean column
probs_aligned = probs.loc[:, 51:55]  # Probability columns (label-based slice, inclusive)

correct = np.array(correct)
probs_aligned = np.array(probs_aligned.T)  # one row per frame, one column per point
base_probs = np.array(merged['true_prob'])
num_points = probs_aligned.shape[1]
num_timesteps = probs_aligned.shape[0]
print(num_points, num_timesteps)

# Create figure and scatter plot
fig, ax = plt.subplots()
colors = np.where(correct, 'green', 'red')  # Green for True, Red for False
scat = ax.scatter(base_probs, probs_aligned[0], s=point_size, c=colors, edgecolor='k')

# Label the axes with what is actually plotted (matches the smoothed animation cell).
ax.set_xlabel("Base Probability")
ax.set_ylabel("Knockout Probability")
# Both axes show probabilities, so pin them to [0, 1] (plus a small margin). Otherwise the
# view is autoscaled from the first frame only, and points updated via set_offsets can move
# outside it.
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)

# Update function for animation
def update(frame):
    """Move every point to its y-value for `frame` and refresh the title."""
    scat.set_offsets(np.column_stack((base_probs, probs_aligned[frame])))
    ax.set_title(f"Point Movement by Probabilities - Frame {frame+1}/{num_timesteps}")
    return scat, 

# Create the animation
ani = FuncAnimation(fig, update, frames=num_timesteps, interval=interval, blit=True)

# ani.save('movement_animation.gif', writer='imagemagick', fps=1000/interval)
# Display the animation in Jupyter Notebook
plt.close(fig)
HTML(ani.to_jshtml())



810 5


In [162]:
# Ad-hoc inspection of the interpolated frame count.
# NOTE(review): `new_num_timesteps` is only assigned in a LATER cell of this notebook, so this
# cell raises NameError when the notebook is run top-to-bottom on a fresh kernel. The recorded
# value (241) also does not match the later formula num_timesteps * frames_per_transition
# (5 * 60 = 300), i.e. it is stale output from an earlier version of that cell.
new_num_timesteps

241

In [23]:
# Label of the first column of `probs`.
probs.columns[0]

51

In [24]:
# Number of columns in `probs`.
len(probs.columns)

5

In [33]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Smoothed version of the scatter animation: linearly interpolates between consecutive
# `probs` columns so the points glide from one window to the next.

# Parameters
first_frame = probs.columns[0]  # label of the first probability column
point_size = 30
interval = 1000 / 60  # 60 fps, interval in milliseconds
cmap = 'viridis'  # not used below: point colours are passed explicitly
frames_per_transition = 60  # Number of intermediate frames between original timesteps

# Original data
correct = np.array(merged['correct'])  # Boolean column
window_probs = probs.loc[:, first_frame:]  # Probability columns
window_labels = list(window_probs.columns)  # actual column labels, used for the title
probs_aligned = np.array(window_probs.T)  # shape: (num_timesteps, num_points)
base_probs = np.array(merged['true_prob'])
num_points = probs_aligned.shape[1]
num_timesteps = probs_aligned.shape[0]

# Interpolation to create smoother transitions
new_num_timesteps = (num_timesteps) * frames_per_transition
probs_smooth = np.zeros((new_num_timesteps, num_points))

# Interpolation weights 0, 1/f, ..., (f-1)/f as a column so they broadcast over the points.
alphas = (np.arange(frames_per_transition) / frames_per_transition)[:, None]
for i in range(num_timesteps - 1):
    segment = slice(i * frames_per_transition, (i + 1) * frames_per_transition)
    probs_smooth[segment] = (1 - alphas) * probs_aligned[i] + alphas * probs_aligned[i + 1]

# Add the final frame to the end, frames_per_transition times
probs_smooth[-frames_per_transition:] = probs_aligned[-1]


# Create figure and scatter plot
fig, ax = plt.subplots()
colors = np.where(correct, 'green', 'red')  # Green for True, Red for False
scat = ax.scatter(base_probs, probs_smooth[0], s=point_size, c=colors, edgecolor='k')

ax.set_xlabel("Base Probability")
ax.set_ylabel("Knockout Probability")
# Both axes show probabilities, so pin them to [0, 1] (plus a small margin). Otherwise the
# view is autoscaled from the first frame only, and points updated via set_offsets can move
# outside it.
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)

# Update function for animation
def update(frame):
    """Move the points to interpolated frame `frame`; refresh the title at window boundaries."""
    scat.set_offsets(np.column_stack((base_probs, probs_smooth[frame])))
    if frame % frames_per_transition == 0:
        original_frame = frame // frames_per_transition
        # Use the real column labels: the last window is window_labels[-1], not
        # first_frame + num_timesteps (which is one past the end).
        ax.set_title(
            f"Point Movement by Probabilities - Window {window_labels[original_frame]}/{window_labels[-1]}"
        )
    return scat,

# Create the animation
ani = FuncAnimation(fig, update, frames=new_num_timesteps, interval=interval, blit=True)

# Display the animation in Jupyter Notebook
plt.close(fig)
# HTML(ani.to_jshtml())

In [34]:
# Save the animation as a GIF. The Pillow writer is requested explicitly: the recorded run shows
# that 'imagemagick' was unavailable and matplotlib fell back to Pillow with a warning, so
# naming it directly gives the same result without depending on what is installed.
# fps equals frames_per_transition, so each window-to-window transition plays for one second.
ani.save('movement_animation.gif', writer='pillow', fps=frames_per_transition)

MovieWriter imagemagick unavailable; using Pillow instead.
