In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
from collections import defaultdict
from dataclasses import dataclass
from hmac import new
import json
from pathlib import Path
from typing import Optional, assert_never, Dict, List, Tuple
from torch import Tensor
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

import pandas as pd
import pyrallis
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
)

from scripts.create_slurm_file import run_slurm
from src.consts import (
    FILTERATIONS,
    MODEL_SIZES_PER_ARCH_TO_MODEL_ID,
    PATHS,
)
from src.datasets.download_dataset import load_dataset, load_splitted_counter_fact
from src.datasets.download_dataset import load_knowns_pd, get_hit_dataset
from src.logit_utils import get_last_token_logits, logits_to_probs
from src.models.model_interface import get_model_interface
from src.types import DATASETS
from src.types import MODEL_ARCH, SPLIT, DatasetArgs, TModelID
from src.utils.setup_models import get_tokenizer_and_model
from src.utils.slurm import submit_job
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np

In [35]:

@dataclass
class Args:
    # model_arch: MODEL_ARCH = MODEL_ARCH.MINIMAL_MAMBA2_new
    model_arch: MODEL_ARCH = MODEL_ARCH.MAMBA1
    model_size: str = "2.8B"
    dataset_args: DatasetArgs = pyrallis.field(
        default=DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all"), is_mutable=True
    )
    filteration: str = FILTERATIONS.all_correct
    _batch_size: int = 16  # Adjust based on GPU memory
    output_file: Optional[Path] = None
    with_slurm: bool = False
    temperature = 1
    top_k = 0
    top_p = 1
    window_size = 9
    prompt_indices = [1,2,3,4,5]
    knockout_map = {'last': ['last', 'first', "subject", "relation"], 
                    'subject': ['context', 'subject']}

    output_dir: Optional[Path] = None

    @property
    def batch_size(self) -> int:
        return (
            1
            if (
                self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2
                or self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2_new
            )
            else self._batch_size
        )

    @property
    def model_id(self) -> TModelID:
        return MODEL_SIZES_PER_ARCH_TO_MODEL_ID[self.model_arch][self.model_size]


In [36]:
args = Args()

In [37]:
probs = pd.read_json(PATHS.OUTPUT_DIR / f'{args.model_id}/info_flow_v6/ds={args.dataset_args.dataset_name}/ws={args.window_size}/block_last_target_last/outputs.json')

In [38]:
probs.shape

(810, 56)

In [None]:
def get_top5_knockout():

    # Open and read the JSON file
    data = pd.read_json(PATHS.OUTPUT_DIR / f'{args.model_id}/info_flow_test_top_outputs/ds={args.dataset_args.dataset_name}/ws={args.window_size}/block_last_target_last/outputs.json')

    # Print the data
    # print(data)
    new_df = defaultdict(list)
    for i, row in data.iterrows():
        row = row[0][0]
        # print(row)
        new_df['token'].append(row[1])
        new_df['prob'].append(row[2])
        new_df['token_id'].append(row[0])

    new_df = pd.DataFrame(new_df)
    return new_df
def calc_correct(merged):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    merged['correct'] = merged.apply(lambda row: row['token_id'] == tokenizer(row['target_true'])['input_ids'][0], axis=1)
    print(f"Correct: {merged['correct'].sum()}, Incorrect: {(~merged['correct']).sum()}")
    # filtered = merged[merged['correct']]
    return merged

In [40]:
# merge on inded
merged = pd.merge(
    left=get_hit_dataset(args.model_id, args.dataset_args).reset_index(drop=True),
    right=get_top5_knockout(),
    left_index=True,
    right_index=True,
).pipe(calc_correct)

Correct: 762, Incorrect: 48


In [60]:
merged

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,original_idx,split,true_prob,max_prob,hit,pred,token,prob,token_id,correct
0,{} is a product of,,{} is a product of,Windows XP Media Center Edition is a product of,P178,Q312,Q2283,Microsoft,Apple,Windows XP Media Center Edition,9209,train1,0.726621,0.726621,True,Microsoft,Microsoft,0.960056,9664,True
1,"In {}, the language spoken is",In,", the language spoken is","In United Kingdom, the language spoken is",P37,Q1412,Q1860,English,Finnish,United Kingdom,18566,train1,0.529818,0.529818,True,English,English,0.999680,4383,True
2,{} is a native speaker of,,{} is a native speaker of,Henry de Montherlant is a native speaker of,P103,Q1860,Q150,French,English,Henry de Montherlant,6916,train1,0.748545,0.748545,True,French,French,0.987940,5112,True
3,The native language of {} is,The native language of,is,The native language of Olga Georges-Picot is,P103,Q7737,Q150,French,Russian,Olga Georges-Picot,140,train1,0.184752,0.184752,True,French,French,0.784016,5112,True
4,{} is owned by,,{} is owned by,Google Marketing Platform is owned by,P127,Q183,Q95,Google,Germany,Google Marketing Platform,2720,train1,0.757935,0.757935,True,Google,Google,0.998711,5559,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
805,{} is created by,,{} is created by,Microsoft InfoPath is created by,P178,Q312,Q2283,Microsoft,Apple,Microsoft InfoPath,15035,test,0.425025,0.425025,True,Microsoft,Microsoft,0.994226,9664,True
806,{} worked in the city of,,{} worked in the city of,Pavlo Skoropadskyi worked in the city of,P937,Q84,Q1899,Kiev,London,Pavlo Skoropadskyi,18585,test,0.134802,0.134802,True,K,K,0.220406,611,True
807,"{}, created by",,"{}, created by","IBM 3790, created by",P176,Q66,Q37156,IBM,Boeing,IBM 3790,19782,test,0.523892,0.523892,True,IBM,IBM,0.983616,21314,True
808,{} belongs to the continent of,,{} belongs to the continent of,Riiser-Larsen Ice Shelf belongs to the contine...,P30,Q48,Q51,Antarctica,Asia,Riiser-Larsen Ice Shelf,9105,test,0.820605,0.820605,True,Antar,Antar,0.994120,31913,True


In [55]:
def create_info_flow_animation(
    model_name:str,
    probs, 
    merged,
    point_size = 30,
    frames_per_transition = 20,  # Number of intermediate frames between original timesteps
    
):
    # Parameters
    first_frame = probs.columns[0]
    interval = 1000 / frames_per_transition

    # Original data
    correct = np.array(merged['correct'])  # Boolean column
    probs_aligned = np.array(probs.loc[:, first_frame:].T)  # Probability columns
    base_probs = np.array(merged['true_prob'])
    num_points = probs_aligned.shape[1]
    num_timesteps = probs_aligned.shape[0]

    # Interpolation to create smoother transitions
    new_num_timesteps = (num_timesteps) * frames_per_transition
    probs_smooth = np.zeros((new_num_timesteps, num_points))

    for i in range(num_timesteps - 1):
        start_frame = probs_aligned[i]
        end_frame = probs_aligned[i + 1]
        for j in range(frames_per_transition):
            alpha = j / frames_per_transition  # Linear interpolation factor
            probs_smooth[i * frames_per_transition + j] = (1 - alpha) * start_frame + alpha * end_frame

    # Add the final frame to the end, frames_per_transition times
    probs_smooth[-frames_per_transition:] = probs_aligned[-1]


    # Create figure and scatter plot
    fig, ax = plt.subplots()
    colors = np.where(correct, 'green', 'red')  # Green for True, Red for False
    scat = ax.scatter(base_probs, probs_smooth[0], s=point_size, c=colors, edgecolor='k')

    ax.set_xlabel("Base Probability")
    ax.set_ylabel("Knockout Probability")

    # Update function for animation
    def update(frame):
        scat.set_offsets(np.column_stack((base_probs, probs_smooth[frame])))
        if frame % frames_per_transition == 0:
            original_frame = frame // frames_per_transition
            ax.set_title(f"{model_name} - Point Movement by Probabilities - Window {first_frame + original_frame}/{first_frame + num_timesteps}")
        return scat,

    # Create the animation
    ani = FuncAnimation(fig, update, frames=new_num_timesteps, interval=interval, blit=True)

    # Display the animation in Jupyter Notebook
    plt.close(fig)

    base_path = PATHS.RESULTS_DIR / 'info_flow_animation'
    base_path.mkdir(parents=True, exist_ok=True)

    ani.save(base_path / f"{model_name}_{args.window_size}.gif", writer='imagemagick', fps=frames_per_transition)
    # HTML(ani.to_jshtml())

In [61]:
def all_combined(args: Args):
    probs = pd.read_json(
        PATHS.OUTPUT_DIR
        / f"{args.model_id}/info_flow_v6/ds={args.dataset_args.dataset_name}/ws={args.window_size}/block_last_target_last/outputs.json"
    )
    merged = (
        get_hit_dataset(args.model_id, args.dataset_args)
        .reset_index(drop=True)
        .pipe(lambda df: df.assign(correct=True))
    )
    # merged = pd.merge(
    #     left=get_hit_dataset(args.model_id, args.dataset_args).reset_index(drop=True),
    #     right=get_top5_knockout(),
    #     left_index=True,
    #     right_index=True,
    # ).pipe(calc_correct)
    create_info_flow_animation(args.model_id.split("/")[1], probs, merged)

In [62]:
window_sizes = [9]
for model_arch, model_size in [
    # (MODEL_ARCH.MAMBA1, "130M"),
    # (MODEL_ARCH.MAMBA1, "1.4B"),
    # (MODEL_ARCH.MAMBA1, "2.8B"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "130M"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "1.3B"),
    (MODEL_ARCH.MINIMAL_MAMBA2_new, "2.7B"),
]:
    args.model_arch = model_arch
    args.model_size = model_size
    for window_size in window_sizes:
        args.window_size = window_size

        print(f"Creating animation for {args.model_id} with window size {args.window_size}")
        all_combined(args)

MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba2-130M with window size 9
Creating animation for state-spaces/mamba2-1.3b with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.


Creating animation for state-spaces/mamba2-2.7B with window size 9


MovieWriter imagemagick unavailable; using Pillow instead.
