In [None]:
import sys 
sys.path.append("../..")
sys.path.append("..")

from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


# Re-execute the joseph submodules in place so that edits made to them on disk
# since the first import are picked up without restarting the kernel.
reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

# Repeat the star imports so the notebook's names are rebound to the objects
# defined by the freshly reloaded modules.
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# Disable autograd globally for the rest of the notebook.
# NOTE(review): `torch` is never imported explicitly in this cell; presumably it
# arrives through one of the star imports above -- confirm.
torch.set_grad_enabled(False)

# Load Model

In [None]:

# Load the pretrained "gpt2-small" checkpoint through HookedTransformer.
# Checkpoints previously tried in this cell: "pythia-2.8b", "pythia-70m-deduped",
# "tiny-stories-2L-33M", "attn-only-2l". The center_unembed, center_writing_weights
# and refactor_factored_attn_matrices options are left at their defaults.
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Turn on the two optional hook settings before any analysis runs.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)


# Load SAE

In [None]:


# Checkpoint of the sparse autoencoder trained on blocks.10.hook_resid_pre.
# Alternative checkpoint (blocks.5) kept for reference:
# "./artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/final_sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152.pt"
path = (
    "../week_8_jan/artifacts/"
    "sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/"
    "1100001280_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
)
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)
print(sparse_autoencoder.cfg)

# Sanity check: the language-model loss on a short passage is shown as the cell output.
text = (
    "Many important transition points in the history of science have been "
    "moments when science 'zoomed in.' At these points, we develop a "
    "visualization or tool that allows us to see the world in a new level of "
    "detail, and a new field of science develops to study the world through "
    "this lens."
)
model(text, return_type="loss")

In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader
# Build the model, the sparse autoencoder and an activation store together from
# the same checkpoint `path` used above.
# NOTE(review): this rebinds `model` and `sparse_autoencoder` from the earlier
# cells. The set_use_split_qkv_input / set_use_attn_result calls were made on the
# previous `model` object -- confirm whether they must be repeated on this one.
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

## Feature Dashboard generator util

In [None]:
import webbrowser
from IPython.core.display import display, HTML

path_to_html = "../week_8_jan/gpt2_small_features"


def render_feature_dashboard(feature_id):
    """Open the pre-generated HTML dashboard of one feature in a new browser tab.

    The dashboard is looked up as ``{path_to_html}/data_{feature_id:04}.html``.

    Parameters
    ----------
    feature_id : int or str
        Feature index. Anything accepted by ``int()`` works, so string labels
        such as ``"123"`` (the form used for the ``feature`` columns in this
        notebook) resolve to the same file as the integer ``123``.

    Returns
    -------
    None
        The function prints the feature id and either opens the file or prints
        ``"No HTML file found"``.
    """
    import os  # local import: this cell does not import `os` itself

    # Zero-padding with `:04` is only meaningful for integers; a string id would
    # be padded on the wrong side (or rejected on older Pythons).
    feature_id = int(feature_id)
    path = f"{path_to_html}/data_{feature_id:04}.html"

    print(f"Feature {feature_id}")
    if os.path.exists(path):
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")

# for feature in [100,300,400]:
#     render_feature_dashboard(feature)

# Fun Examples

In [None]:
# Three "not only ... but also" sentences. The third is cut off right before the
# expected continuation ("but also one that we expect to continue for quite some time.").
prompt1 = (
    "The war caused not only destruction and death but also generations of hatred between the two communities."
)
prompt2 = (
    "The car not only is economical but also feels good to drive."
)
prompt3 = (
    "This investigation is not only one that is continuing and worldwide,"
)

# Clear hooks first, then check how the model ranks `answer` as the next token of `prompt`.
model.reset_hooks()
prompt, answer = prompt3, "but"
utils.test_prompt(prompt, answer, model)


In [None]:
import joseph
reload(joseph.analysis)
from joseph.analysis import *

# Full version of the third example sentence, evaluated with the SAE spliced in.
prompt3 = "This investigation is not only one that is continuing and worldwide, but also one that we expect to continue for quite some time."
token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt(
    [prompt3],
    model,
    sparse_autoencoder,
    head_idx_override=7,
)
print(token_df.columns)

# Columns shown in the per-token table; the numeric ones get a colour gradient.
filter_cols = [
    "str_tokens",
    "unique_token",
    "context",
    "batch",
    "pos",
    "label",
    "loss",
    "loss_diff",
    "mse_loss",
    "num_active_features",
    "explained_variance",
    "kl_divergence",
    "top_k_features",
]
token_df[filter_cols].style.background_gradient(
    cmap="coolwarm",
    subset=["loss_diff", "mse_loss", "explained_variance", "num_active_features", "kl_divergence"],
)

In [None]:
POS_INTEREST = 12  # token position to inspect, counted from 0

print(token_df.shape)
print(feature_acts.shape)
print(token_df["unique_token"][POS_INTEREST])

# Activations of every feature at the chosen position.
feature_acts_of_interest = feature_acts[POS_INTEREST]
plot_line_with_top_10_labels(feature_acts_of_interest, "", 25)

# The 39 largest activations at that position; report how many are non-zero
# and which feature indices they belong to.
vals, inds = feature_acts_of_interest.topk(39)
print(vals.nonzero().shape)
print(inds)

## Visualize activations over Sentence

In [None]:
top_k_feature_inds = inds
print(feature_acts.shape)

# One row per top-k feature, one column per token of the prompt.
features_acts_by_token_df = pd.DataFrame(
    feature_acts[:, top_k_feature_inds].detach().cpu().T,
    index=[f"feature_{i}" for i in top_k_feature_inds.flatten().tolist()],
    columns=token_df["unique_token"],
)

# Sort by the activation on the token at POS_INTEREST. The column label is looked
# up from `token_df` instead of being hard-coded as ",/12", so it stays correct
# when POS_INTEREST or the prompt changes.
token_of_interest = token_df["unique_token"][POS_INTEREST]
features_acts_by_token_df.sort_values(by=token_of_interest, ascending=False).style.background_gradient(
    cmap="coolwarm", axis=0)

In [None]:
# Order the features by their activation on the token at POS_INTEREST (label looked
# up from `token_df` rather than hard-coded), then plot one line per feature.
px.line(
    features_acts_by_token_df.sort_values(token_df["unique_token"][POS_INTEREST], ascending=False).T,
    title="Top k features by activation",
)

## Metric Development

- For Loop for a bunch of inference + getting SAE activations
- Metrics:
    - Max Contiguous fires 
    - Average Length of Contiguous Fires (given it fired, how many tokens do we expect it to keep firing on)
    - Number of contiguous blocks in any given prompt
    - Std on activation within a set of contiguous fires
    - Profile of firing indexed to first fire in contiguous section of fires
    - Proportion of total activation in prompt of feature / total number of times it fired
    - Token based
        - histogram of tokens on which it fires first
        - histogram of tokens on which it stops firing
    - Other SAE Properties
        - Sparsity
        - b_enc
        - W_dec @ b_dec
        - W_enc @ W_dec


Data:
    - Sequence Data (tokens in a prompt)
    - Per Feature Data 
    - Token (Feature in Data)

In [None]:
# let's measure these for one prompt
token_df, original_cache, cache_reconstructed_query, feature_acts_example = eval_prompt(
    [prompt3], model, sparse_autoencoder, head_idx_override=7
)

# Report the prompt that was actually evaluated and the activations it produced.
# (`prompt` and `feature_acts` are leftovers from earlier cells: `prompt` still
# holds the truncated sentence and `feature_acts` an earlier evaluation.)
print(prompt3)
feature_acts_example.shape

In [None]:
# Indices of the features that are active (> 0) on the token at POS_INTEREST.
# The position comes from the shared constant instead of a repeated literal 12.
features_of_interest = (feature_acts_example[POS_INTEREST] > 0).nonzero().flatten().tolist()
print(features_of_interest)

In [None]:
# 1.0 wherever a feature of interest is active on a token, 0.0 elsewhere.
feature_acts_binary = (feature_acts_example[:, features_of_interest] > 0).float()
# Per feature: the number of tokens of the prompt on which it was active.
num_fires_on_prompt = feature_acts_binary.sum(dim=0)

### Use Multithreading (broken in jupyter)

In [None]:
# import torch
# import multiprocessing
# from itertools import repeat

# def analyze_row(row):
#     in_event = False
#     event_start = 0
#     num_events = 0
#     max_values = []
#     avg_values = []
#     durations = []

#     for i, value in enumerate(row):
#         if value > 0:
#             if not in_event:
#                 in_event = True
#                 event_start = i
#                 num_events += 1
#                 max_value = value
#                 total_value = value
#             else:
#                 max_value = max(max_value, value)
#                 total_value += value
#         else:
#             if in_event:
#                 in_event = False
#                 durations.append(i - event_start)
#                 max_values.append(max_value)
#                 avg_values.append(total_value / (i - event_start))

#     if in_event:
#         durations.append(len(row) - event_start)
#         max_values.append(max_value)
#         avg_values.append(total_value / (len(row) - event_start))

#     return {
#         'num_events': num_events,
#         'max_values': max_values,
#         'avg_values': avg_values,
#         'durations': durations
#     }

# def analyze_events_parallel(tensor, num_processes=None):
#     if num_processes is None:
#         num_processes = multiprocessing.cpu_count()

#     with multiprocessing.Pool(num_processes) as pool:
#         results = pool.map(analyze_row, tensor)

#     return results

import time

import events_experiment_multithreading
reload(events_experiment_multithreading)

# Example usage
# Select the features of interest and transpose so that each row holds one
# feature's activations; the tensor is moved to CPU before being handed to the
# worker pool.
tensor = feature_acts_example[:, features_of_interest].T.cpu()

# Wall-clock timing of the parallel analysis.
# NOTE(review): the section heading marks this path as broken inside Jupyter;
# the single-process `analyze_events` further down is the working alternative.
start_time = time.time()
print(tensor.shape)
results = events_experiment_multithreading.analyze_events_parallel(tensor)
end_time = time.time()
print(f"Time taken: {end_time - start_time}")
print(results)
# One row of results per feature, labelled by feature index.
feature_prompt_df  = pd.DataFrame(results, index=features_of_interest)
feature_prompt_df.head()
# print(feature_prompt_df.shape)
# feature_prompt_df["feature"] = feature_prompt_df.index
# feature_prompt_df.explode('events').sort_values("num_events", ascending=False)

In [None]:
# Inspect the shape of `tensor` as currently bound.
tensor.shape

In [None]:
# Number of features currently selected for analysis.
len(features_of_interest)

### Don't use Multithreading

In [None]:
import time 

def analyze_events(tensor):
    """Summarise contiguous runs of positive values ("events") in each row.

    A row is scanned left to right. An event starts at the first value ``> 0``
    after a non-positive value (or at the start of the row) and ends just before
    the next non-positive value (or at the end of the row).

    Parameters
    ----------
    tensor : 2-D array-like
        Object with a two-element ``.shape`` whose iteration yields rows that
        support ``.tolist()`` (e.g. a torch tensor or numpy array).

    Returns
    -------
    list of dict
        One dict per row with the keys

        - ``num_events``: number of events in the row
        - ``num_firings``: number of positive entries (sum of durations)
        - ``avg_values`` / ``max_values`` / ``durations``: per-event lists
        - ``avg_duration`` / ``max_duration`` / ``avg_max_value``: aggregates
          over the row's events, NaN when the row has no event
        - ``events``: list of per-event dicts with ``avg_value``,
          ``max_value``, ``duration``, ``start_position`` (index of the first
          firing token) and ``final_position`` (index of the first
          non-firing token after the event, or NaN when the event is still
          running at the end of the row)

    Raises
    ------
    AssertionError
        If ``tensor`` is not two-dimensional.
    """
    assert len(tensor.shape) == 2, "tensor must be 2D"
    results = []

    for row in tensor:
        values = row.tolist()
        events = []
        in_event = False
        event_start = 0
        max_value = total_value = 0

        for i, value in enumerate(values):
            if value > 0:
                if not in_event:
                    # first firing token of a new event
                    in_event = True
                    event_start = i
                    max_value = value
                    total_value = value
                else:
                    max_value = max(max_value, value)
                    total_value += value
            elif in_event:
                # first non-firing token after an event closes it; the start and
                # end positions are recorded per event, not per row
                in_event = False
                events.append(_make_event(event_start, i, max_value, total_value, final_position=i))

        if in_event:
            # The event runs to the end of the row, so no closing token was seen:
            # keep final_position missing, which callers treat as "end of prompt".
            events.append(
                _make_event(event_start, len(values), max_value, total_value, final_position=np.nan)
            )

        durations = [event['duration'] for event in events]
        max_values = [event['max_value'] for event in events]
        avg_values = [event['avg_value'] for event in events]

        results.append({
            'num_events': len(events),
            'num_firings': sum(durations),
            'avg_values': avg_values,
            'max_values': max_values,
            'durations': durations,
            'avg_duration': (sum(durations) / len(durations)) if durations else np.nan,
            'max_duration': max(durations) if durations else np.nan,
            'avg_max_value': (sum(max_values) / len(max_values)) if max_values else np.nan,
            'events': events,
        })

    return results


def _make_event(start, stop, max_value, total_value, final_position):
    """Build the record of one event covering positions ``start`` .. ``stop - 1``."""
    duration = stop - start
    return {
        'avg_value': total_value / duration,
        'max_value': max_value,
        'duration': duration,
        'start_position': start,
        'final_position': final_position,
    }

# Example usage


# One row per feature of interest (note the transpose).
tensor = feature_acts_example[:, features_of_interest].T

start_time = time.time()
results = analyze_events(tensor)
end_time = time.time()
print(f"Time taken: {end_time - start_time}")

feature_prompt_df = pd.DataFrame(results, index=features_of_interest)
feature_prompt_df["feature"] = feature_prompt_df.index
# The exploded/sorted view that used to be computed here was neither stored nor
# displayed, so it has been dropped; the per-event frame is built in the next cell.
display(feature_prompt_df.head(10))

In [None]:
# One row per event: explode the per-feature event lists, expand every event
# dict into columns, and carry the feature index along as a string label.
tmp = (
    feature_prompt_df
    .explode("events")
    .apply(lambda record: pd.Series(record["events"]), axis=1)
    .reset_index()
    .rename(columns={"index": "feature"})
)
tmp = tmp.assign(feature=tmp["feature"].astype(str))
px.scatter_matrix(
    tmp,
    dimensions=["avg_value", "max_value", "duration"],
    color="feature",
    title="Event stats for each feature",
    width=1000,
    height=1000,
)

In [None]:
# Number of firing tokens against number of events, one point per feature.
px.scatter(
    feature_prompt_df,
    x="num_firings",
    y="num_events",
    hover_name=feature_prompt_df.index,
)

## Write Loop

In [None]:
# Pull 128*6 batches of tokens from the activation store, shuffle the prompts
# and keep the first 4096*6 of them.
all_tokens_list = []
pbar = tqdm(range(128 * 6))
for _ in pbar:
    all_tokens_list.append(activation_store.get_batch_tokens())
all_tokens = torch.cat(all_tokens_list, dim=0)
print(all_tokens.shape)
all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
tokens = all_tokens[:4096 * 6]
del all_tokens
# Release cached MPS memory only when that backend exists, so the cell also
# runs on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:

n_prompts = 1000
# Random sample of 3000 feature indices. The number of features is read from the
# SAE config rather than from whichever `feature_acts` tensor an earlier cell
# left behind (the loop below rebinds `feature_acts` on every iteration).
features_of_interest = torch.randperm(sparse_autoencoder.cfg.d_sae)[:3000].tolist()
token_dfs = []
event_dfs = []
feature_acts_all = []

for prompt_index in tqdm(range(n_prompts)):
    # Run one prompt at a time, as a batch of size 1.
    prompt_tokens = tokens[prompt_index].unsqueeze(0)

    # Per-token metadata for this prompt, tagged with the prompt it came from.
    token_df = make_token_df(model, prompt_tokens, len_suffix=5, len_prefix=10)
    token_df["prompt_index"] = prompt_index

    (original_logits, original_loss), original_cache = model.run_with_cache(
        prompt_tokens, return_type="both", loss_per_token=True
    )
    # Pad the per-token losses with one NaN so the list length matches token_df.
    token_df['loss'] = original_loss.flatten().tolist() + [np.nan]

    # Encode the activation at the SAE's hook point.
    original_act = original_cache[sparse_autoencoder.cfg.hook_point]
    sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

    # Keep only the sampled features, one row per feature, and summarise their events.
    feature_acts_of_interest = feature_acts[0, :, features_of_interest].T
    results = analyze_events(feature_acts_of_interest)
    events_df = pd.DataFrame(results, index=features_of_interest)
    events_df["feature"] = events_df.index.astype(str)
    events_df["prompt_index"] = prompt_index
    # Drop features that never fired on this prompt.
    events_df = events_df[events_df["num_events"] > 0]

    token_dfs.append(token_df.reset_index(drop=True))
    event_dfs.append(events_df.reset_index(drop=True))
    feature_acts_all.append(feature_acts_of_interest)

# Stack the per-prompt activations along a new leading prompt dimension.
feature_acts_all = torch.stack(feature_acts_all, dim=0)

In [None]:
# Concatenate the per-prompt frames into one token table and one
# (feature, prompt) table.
token_df = pd.concat(token_dfs).reset_index(drop=True)
prompt_event_df = pd.concat(event_dfs).reset_index(drop=True)

# One row per event. Exploded rows keep the index label of their parent row in
# `prompt_event_df`, which is used to copy the feature and prompt index across.
events_df = (
    prompt_event_df
    .explode("events")
    .apply(lambda record: pd.Series(record["events"]), axis=1)
)
events_df["feature"] = events_df.index.map(prompt_event_df["feature"]).astype(str)
events_df["prompt_index"] = events_df.index.map(prompt_event_df["prompt_index"])
#
# tmp["feature"] = tmp.index.map(lambda x: event_df["feature"][x]).astype(str)
    

In [None]:
# Preview the per-(feature, prompt) summary table.
prompt_event_df.head()

In [None]:
# Pairwise view of the per-(feature, prompt) summary statistics, coloured by feature.
px.scatter_matrix(
    prompt_event_df,
    dimensions=["num_events", "num_firings", "avg_duration", "avg_max_value"],
    color="feature",
    title="Event stats for each feature",
    width=1000,
    height=1000,
)

In [None]:
# Aggregate per (feature, prompt), most events first.
prompt_event_agg_df = (
    prompt_event_df
    .groupby(["feature", "prompt_index"])
    .agg({"num_events": "sum", "num_firings": "sum", "avg_duration": "mean"})
    .sort_values("num_events", ascending=False)
    .reset_index()
)
# Average number of firing tokens per event within each (feature, prompt) pair.
prompt_event_agg_df["firings_per_event"] = prompt_event_agg_df["num_firings"].div(
    prompt_event_agg_df["num_events"]
)
px.strip(
    prompt_event_agg_df,
    x="feature",
    y="firings_per_event",
    color="feature",
    title="Firings per event",
    hover_data=["num_events", "num_firings", "avg_duration", "prompt_index"],
).show()




In [None]:
# Number of distinct features that appear in the aggregated table.
prompt_event_agg_df.feature.unique().shape

In [None]:
firings_per_event_by_feature = prompt_event_agg_df.groupby("feature").firings_per_event
mean_firings_per_event = firings_per_event_by_feature.mean().sort_values(ascending=False)
# Keep the std in the same feature order as the mean. Sorting the two series
# independently would pair the i-th largest mean with the i-th largest std,
# i.e. plot points that mix statistics of two different features.
std_firings_per_event = firings_per_event_by_feature.std().reindex(mean_firings_per_event.index)
px.scatter(
    x=mean_firings_per_event.values,
    y=std_firings_per_event.values,
    hover_name=mean_firings_per_event.index,
    marginal_x="histogram",
    marginal_y="histogram",
    labels={"x": "Mean firings per event", "y": "Std firings per event"},
    title="Mean vs Std firings per event",
).show()

In [None]:
# Open dashboards for 20 of the features whose mean firings per event is below 1.3.
# The index holds feature labels as strings, while the dashboard file name is
# built with an integer format spec, so convert before rendering.
for feature in mean_firings_per_event[mean_firings_per_event < 1.3].index[10:30]:
    render_feature_dashboard(int(feature))

In [None]:
## Given some token, let's get the distribution of tokens it began firing on
# For every event, look up the `token_df` row that matches its prompt index and
# start position, and store that row's index label in a new column.
# NOTE(review): `token_df_id_from_prompt_and_pos` is only defined in a later cell
# of this notebook, so on a fresh kernel this cell raises NameError unless that
# cell is run first.
events_df["token_df_id"] = events_df.apply(lambda x: token_df_id_from_prompt_and_pos(x["prompt_index"], x["start_position"]), axis=1)

In [None]:
# Display the per-event table.
events_df

In [None]:
# Attach the token metadata of each event's first firing token: the
# `token_df_id` column of `events_df` is matched against the index of `token_df`.
# Both frames carry a `prompt_index` column, so the right-hand copy gets a suffix.
events_df.join(token_df, on="token_df_id", rsuffix="_token").head()

In [None]:
# Preview the per-(feature, prompt) summary table.
prompt_event_df.head()

In [None]:
# we want to get the token distribution from events.
# Locate feature 22768 within the sampled feature list and attach its activation
# on every token as a new `token_df` column.
feature_idx = features_of_interest.index(22768)
token_df["feature_22768"] = feature_acts_all[:, feature_idx].reshape(-1).tolist()
# token_df["feature_22768_quantile"] = pd.qcut(token_df["feature_22768"], 10, labels=False, duplicates="drop")

# Row labels of the 30 strongest activations, and of the rows just before them.
idxes = token_df.sort_values("feature_22768", ascending=False).index[:30]
idxes_minus_1 = idxes - 1


In [None]:
# FIXME(review): `groupby()` is called without a grouping key. pandas requires
# `by` or `level`, so this cell raises TypeError as written. The intended
# grouping is not recoverable from the notebook -- fill it in or delete the cell.
token_df.groupby()

In [None]:
def token_df_id_from_prompt_and_pos(prompt_index, pos):
    """Return the index label of the first `token_df` row with this prompt index and position."""
    matches = token_df[(token_df["prompt_index"] == prompt_index) & (token_df["pos"] == pos)]
    return matches.index[0]


def str_token_from_prompt_and_pos(prompt_index, pos):
    """Return the `str_tokens` value of the first `token_df` row with this prompt index and position."""
    matches = token_df[(token_df["prompt_index"] == prompt_index) & (token_df["pos"] == pos)]
    return matches.str_tokens.values[0]


token_df_id_from_prompt_and_pos(12, 3)
# str_token_from_prompt_and_pos(12,3)

In [None]:
# Standard deviation of event duration per feature, most variable first.
(
    events_df
    .groupby("feature")[["duration"]]
    .std()
    .sort_values("duration", ascending=False)
)

In [None]:
# start id word_cloud

feature_of_interest = 22768

# step 1. Select the events we care about: every event that lasted exactly 4 tokens.
# The selection is now stored. Previously it was computed and thrown away, and the
# lookups below read a stale `tmp` from an earlier cell that has no
# `prompt_index` column.
selected_events = events_df[events_df.duration == 4]  # [events_df.feature == str(feature_of_interest)]

# step 2. For each selected event, find the token_df rows of its first firing
# token, the token just before it, and its last firing token.
token_df_ids = [
    token_df_id_from_prompt_and_pos(i, j)
    for i, j in zip(selected_events.prompt_index, selected_events.start_position)
]
# NOTE(review): an event that starts at position 0 has no preceding token; the
# lookup for position -1 finds no row and raises IndexError.
minus_one_token_ids = [
    token_df_id_from_prompt_and_pos(i, j)
    for i, j in zip(selected_events.prompt_index, selected_events.start_position - 1)
]
# A missing final_position is treated as the event running to the end of a
# 128-token prompt; subtracting 1 gives the last firing token.
final_token_ids = [
    token_df_id_from_prompt_and_pos(i, j)
    for i, j in zip(selected_events.prompt_index, selected_events.final_position.fillna(128) - 1)
]
minus_one_token_fire = token_df.iloc[minus_one_token_ids].str_tokens.reset_index(drop=True)
first_token_fire = token_df.iloc[token_df_ids].str_tokens.reset_index(drop=True)
final_token_fire = token_df.iloc[final_token_ids].str_tokens.reset_index(drop=True)

# Side-by-side table of the three tokens for every selected event.
tmp = pd.concat([first_token_fire, minus_one_token_fire, final_token_fire], axis=1)
tmp.columns = ["first_token", "minus_one_token", "final_token"]
tmp

# Feature Dashboards

## Generate new features if needed

In [None]:
# Pull 128*6 batches of tokens from the activation store, shuffle the prompts
# and keep the first 4096*6 of them.
all_tokens_list = []
pbar = tqdm(range(128 * 6))
for _ in pbar:
    all_tokens_list.append(activation_store.get_batch_tokens())
all_tokens = torch.cat(all_tokens_list, dim=0)
print(all_tokens.shape)
all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
tokens = all_tokens[:4096 * 6]
del all_tokens
# Release cached MPS memory only when that backend exists, so the cell also
# runs on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
from sae_analysis.visualizer import data_fns, model_fns, html_fns
import importlib

importlib.reload(data_fns)
importlib.reload(html_fns)
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 512
total_batch_size = 4096 * 6
# Every feature index of the SAE, split into 512 equally sized chunks.
# NOTE(review): this chunked list is not passed to get_feature_data below;
# only `interesting_features` is.
feature_idx = torch.arange(sparse_autoencoder.cfg.d_sae).reshape(512, -1).tolist()
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))


# Features ranked by mean firings per event in the analysis above, as integers.
interesting_features = mean_firings_per_event.index.astype(int).to_list()

feature_data = get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,
    hook_point_head_index=None,
    tokens=tokens,
    feature_idx=interesting_features,
    max_batch_size=max_batch_size,
    left_hand_k=3,
    buffer=(5, 5),
    n_groups=10,
    first_group_size=20,
    other_groups_size=5,
    verbose=True,
)



In [None]:

import os

# Write one standalone HTML dashboard per generated feature.
dashboard_dir = "../week_8_jan/gpt2_small_features"
# Create the target directory on first use instead of failing on open().
os.makedirs(dashboard_dir, exist_ok=True)
for test_idx in feature_data.keys():
    html_str = feature_data[test_idx].get_all_html()
    # Explicit UTF-8 so the output does not depend on the platform's default encoding.
    with open(f"{dashboard_dir}/data_{test_idx:04}.html", "w", encoding="utf-8") as f:
        f.write(html_str)

# Further Analysis