 import os
-from typing import Any, Optional, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 # set TOKENIZERS_PARALLELISM to false to avoid warnings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import time
 
 import torch
 from sae_vis.data_fetching_fns import get_feature_data
-from sae_vis.data_storing_fns import FeatureVisParams, to_str_tokens
+from sae_vis.data_storing_fns import FeatureVisParams
 from tqdm import tqdm
 
 import numpy as np
 
 from matplotlib import colors
 
+OUT_OF_RANGE_TOKEN = "<|outofrange|>"
+
 BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list(
     "bg_color_map", ["white", "darkorange"]
 )
@@ -47,8 +49,9 @@ def __init__(
         n_features_at_a_time: int = 1024,
         buffer_tokens_left: int = 8,
         buffer_tokens_right: int = 8,
-        # start_batch
-        start_batch: int = 0,
+        # start and end batch
+        start_batch_inclusive: int = 0,
+        end_batch_inclusive: Optional[int] = None,
     ):
         self.sae_path = sae_path
         if init_session:
@@ -60,7 +63,8 @@ def __init__(
         self.buffer_tokens_right = buffer_tokens_right
         self.n_batches_to_sample_from = n_batches_to_sample_from
         self.n_prompts_to_select = n_prompts_to_select
-        self.start_batch = start_batch
+        self.start_batch = start_batch_inclusive
+        self.end_batch = end_batch_inclusive
 
         # Deal with file structure
         if not os.path.exists(neuronpedia_parent_folder):
@@ -107,6 +111,32 @@ def get_tokens(
     def round_list(self, to_round: list[float]):
         return list(np.round(to_round, 3))
 
+    def to_str_tokens_safe(
+        self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor]
+    ):
117+ """
118+ does to_str_tokens, except handles out of range
119+ """
+        vocab_max_index = self.model.cfg.d_vocab - 1
+        # Deal with the int case separately
+        if isinstance(tokens, int):
+            if tokens > vocab_max_index:
+                return OUT_OF_RANGE_TOKEN
+            return vocab_dict[tokens]
+
+        # If the tokens are a (possibly nested) list, turn them into a tensor
+        if isinstance(tokens, list):
+            tokens = torch.tensor(tokens)
+
+        # Get flattened list of tokens
+        str_tokens = [
+            (vocab_dict[t] if t <= vocab_max_index else OUT_OF_RANGE_TOKEN)
+            for t in tokens.flatten().tolist()
+        ]
+
+        # Reshape
+        return np.reshape(str_tokens, tokens.shape).tolist()
+
     def run(self):
         """
         Generate the Neuronpedia outputs.
@@ -137,6 +167,16 @@ def run(self):
         feature_idx = np.array_split(feature_idx, n_subarrays)
         feature_idx = [x.tolist() for x in feature_idx]
 
+        print(f"==== Starting at batch: {self.start_batch}")
+        if self.end_batch is not None:
+            print(f"==== Ending at batch: {self.end_batch}")
+
+        if self.start_batch > len(feature_idx) + 1:
+            print(
+                f"Start batch {self.start_batch} is greater than number of batches ({len(feature_idx)}) + 1, exiting"
+            )
+            exit()
+
         # write dead into file so we can create them as dead in Neuronpedia
         skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
         skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)})
@@ -166,16 +206,20 @@ def run(self):
         }
         # pad with blank tokens to the actual vocab size
         for i in range(len(vocab_dict), self.model.cfg.d_vocab):
-            vocab_dict[i] = " "
+            vocab_dict[i] = OUT_OF_RANGE_TOKEN
 
         with torch.no_grad():
             feature_batch_count = 0
             for features_to_process in tqdm(feature_idx):
                 feature_batch_count = feature_batch_count + 1
 
                 if feature_batch_count < self.start_batch:
-                    print(f"Skipping batch: {feature_batch_count}")
+                    # print(f"Skipping batch - it's before start_batch: {feature_batch_count}")
+                    continue
+                if self.end_batch is not None and feature_batch_count > self.end_batch:
+                    # print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
                     continue
+
                 print(f"Doing batch: {feature_batch_count}")
 
                 feature_vis_params = FeatureVisParams(
@@ -255,11 +299,11 @@ def run(self):
                     # feature.left_tables_data.correlated_features_pearson
                     # )
 
-                    feature_output["neg_str"] = to_str_tokens(
+                    feature_output["neg_str"] = self.to_str_tokens_safe(
                         vocab_dict, feature.middle_plots_data.bottom10_token_ids
                     )
                     feature_output["neg_values"] = bottom10_logits
-                    feature_output["pos_str"] = to_str_tokens(
+                    feature_output["pos_str"] = self.to_str_tokens_safe(
                         vocab_dict, feature.middle_plots_data.top10_token_ids
                     )
                     feature_output["pos_values"] = top10_logits
@@ -320,11 +364,13 @@ def run(self):
                             negContribs = []
                             for i in range(len(sd.token_ids)):
                                 strs.append(
-                                    to_str_tokens(vocab_dict, sd.token_ids[i])
+                                    self.to_str_tokens_safe(
+                                        vocab_dict, sd.token_ids[i]
+                                    )
                                 )
                                 posContrib = {}
                                 posTokens = [
-                                    to_str_tokens(vocab_dict, j)
+                                    self.to_str_tokens_safe(vocab_dict, j)
                                     for j in sd.top5_token_ids[i]
                                 ]
                                 if len(posTokens) > 0:
@@ -335,7 +381,7 @@ def run(self):
                                 posContribs.append(posContrib)
                                 negContrib = {}
                                 negTokens = [
-                                    to_str_tokens(vocab_dict, j)
+                                    self.to_str_tokens_safe(vocab_dict, j)
                                     for j in sd.bottom5_token_ids[i]
                                 ]
                                 if len(negTokens) > 0:
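
For reference, a minimal standalone sketch of the behaviour the new to_str_tokens_safe helper is meant to have: token ids at or above d_vocab map to OUT_OF_RANGE_TOKEN instead of raising a KeyError, and nested input shapes are preserved. The method is rewritten here as a free function with a hypothetical five-token vocab (the real implementation reads d_vocab from self.model.cfg), so this is illustrative only, not part of the diff.

import numpy as np
import torch

OUT_OF_RANGE_TOKEN = "<|outofrange|>"

def to_str_tokens_safe(vocab_dict, tokens, d_vocab):
    # Ids above the last valid vocab index map to the out-of-range placeholder.
    vocab_max_index = d_vocab - 1
    if isinstance(tokens, int):
        return OUT_OF_RANGE_TOKEN if tokens > vocab_max_index else vocab_dict[tokens]
    if isinstance(tokens, list):
        tokens = torch.tensor(tokens)
    str_tokens = [
        vocab_dict[t] if t <= vocab_max_index else OUT_OF_RANGE_TOKEN
        for t in tokens.flatten().tolist()
    ]
    # Reshape back to the (possibly nested) shape of the input.
    return np.reshape(str_tokens, tokens.shape).tolist()

# Hypothetical five-token vocab for illustration only.
vocab_dict = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
print(to_str_tokens_safe(vocab_dict, [[0, 2], [4, 9]], d_vocab=5))
# -> [['a', 'c'], ['e', '<|outofrange|>']]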
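
Similarly, a small sketch (with made-up values, not part of the diff) of the inclusive batch window introduced by start_batch_inclusive / end_batch_inclusive: batch counts are 1-indexed, and both ends of the window are processed.

start_batch, end_batch = 3, 5  # hypothetical window for illustration
processed = []
for feature_batch_count in range(1, 11):  # pretend there are 10 feature batches
    if feature_batch_count < start_batch:
        continue  # before the window
    if end_batch is not None and feature_batch_count > end_batch:
        continue  # after the window
    processed.append(feature_batch_count)
print(processed)  # [3, 4, 5]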