Skip to content

Commit 901b888

Browse files
Johnny Lin
authored and committed
Safe to_str_tokens, fix memory issues
1 parent 85d8f57 commit 901b888

File tree

5 files changed

+299
-350
lines changed

5 files changed

+299
-350
lines changed

sae_analysis/neuronpedia_runner.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import os
2-
from typing import Any, Optional, cast
2+
from typing import Any, Dict, List, Optional, Union, cast
33

44
# set TOKENIZERS_PARALLELISM to false to avoid warnings
55
os.environ["TOKENIZERS_PARALLELISM"] = "false"
66
import time
77

88
import torch
99
from sae_vis.data_fetching_fns import get_feature_data
10-
from sae_vis.data_storing_fns import FeatureVisParams, to_str_tokens
10+
from sae_vis.data_storing_fns import FeatureVisParams
1111
from tqdm import tqdm
1212

1313
import numpy as np
@@ -16,6 +16,8 @@
1616

1717
from matplotlib import colors
1818

19+
OUT_OF_RANGE_TOKEN = "<|outofrange|>"
20+
1921
BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list(
2022
"bg_color_map", ["white", "darkorange"]
2123
)
@@ -47,8 +49,9 @@ def __init__(
4749
n_features_at_a_time: int = 1024,
4850
buffer_tokens_left: int = 8,
4951
buffer_tokens_right: int = 8,
50-
# start_batch
51-
start_batch: int = 0,
52+
# start and end batch
53+
start_batch_inclusive: int = 0,
54+
end_batch_inclusive: Optional[int] = None,
5255
):
5356
self.sae_path = sae_path
5457
if init_session:
@@ -60,7 +63,8 @@ def __init__(
6063
self.buffer_tokens_right = buffer_tokens_right
6164
self.n_batches_to_sample_from = n_batches_to_sample_from
6265
self.n_prompts_to_select = n_prompts_to_select
63-
self.start_batch = start_batch
66+
self.start_batch = start_batch_inclusive
67+
self.end_batch = end_batch_inclusive
6468

6569
# Deal with file structure
6670
if not os.path.exists(neuronpedia_parent_folder):
@@ -107,6 +111,32 @@ def get_tokens(
107111
def round_list(self, to_round: list[float]):
    """Round every value in *to_round* to 3 decimal places.

    Returns a plain list of the rounded values (elements are the
    numpy scalars produced by np.round, matching the original).
    """
    rounded = np.round(to_round, 3)
    return list(rounded)
109113

def to_str_tokens_safe(
    self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor]
):
    """
    Convert token id(s) to string tokens via vocab_dict, substituting
    OUT_OF_RANGE_TOKEN for any id outside [0, d_vocab - 1] instead of raising.

    Args:
        vocab_dict: mapping from token id to its string representation.
        tokens: a single id, a (possibly nested) list of ids, or a tensor of ids.

    Returns:
        A single string for an int input; otherwise a (nested) list of
        strings with the same shape as the input.
    """
    vocab_max_index = self.model.cfg.d_vocab - 1

    def _lookup(t: int) -> str:
        # Negative ids are also out of range — vocab_dict has no negative
        # keys, so without this guard they would raise KeyError.
        if t < 0 or t > vocab_max_index:
            return OUT_OF_RANGE_TOKEN
        return vocab_dict[t]

    # Deal with the int case separately
    if isinstance(tokens, int):
        return _lookup(tokens)

    # If the tokens are a (possibly nested) list, turn them into a tensor
    if isinstance(tokens, list):
        tokens = torch.tensor(tokens)

    # Flatten, convert each id, then restore the original shape
    str_tokens = [_lookup(t) for t in tokens.flatten().tolist()]
    return np.reshape(str_tokens, tokens.shape).tolist()
110140
def run(self):
111141
"""
112142
Generate the Neuronpedia outputs.
@@ -137,6 +167,16 @@ def run(self):
137167
feature_idx = np.array_split(feature_idx, n_subarrays)
138168
feature_idx = [x.tolist() for x in feature_idx]
139169

170+
print(f"==== Starting at batch: {self.start_batch}")
171+
if self.end_batch is not None:
172+
print(f"==== Ending at batch: {self.end_batch}")
173+
174+
if self.start_batch > len(feature_idx) + 1:
175+
print(
176+
f"Start batch {self.start_batch} is greater than number of batches + 1 {len(feature_idx)}, exiting"
177+
)
178+
exit()
179+
140180
# write dead into file so we can create them as dead in Neuronpedia
141181
skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
142182
skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)})
@@ -166,16 +206,20 @@ def run(self):
166206
}
167207
# pad with blank tokens to the actual vocab size
168208
for i in range(len(vocab_dict), self.model.cfg.d_vocab):
169-
vocab_dict[i] = " "
209+
vocab_dict[i] = OUT_OF_RANGE_TOKEN
170210

171211
with torch.no_grad():
172212
feature_batch_count = 0
173213
for features_to_process in tqdm(feature_idx):
174214
feature_batch_count = feature_batch_count + 1
175215

176216
if feature_batch_count < self.start_batch:
177-
print(f"Skipping batch: {feature_batch_count}")
217+
# print(f"Skipping batch - it's after start_batch: {feature_batch_count}")
218+
continue
219+
if self.end_batch is not None and feature_batch_count > self.end_batch:
220+
# print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
178221
continue
222+
179223
print(f"Doing batch: {feature_batch_count}")
180224

181225
feature_vis_params = FeatureVisParams(
@@ -255,11 +299,11 @@ def run(self):
255299
# feature.left_tables_data.correlated_features_pearson
256300
# )
257301

258-
feature_output["neg_str"] = to_str_tokens(
302+
feature_output["neg_str"] = self.to_str_tokens_safe(
259303
vocab_dict, feature.middle_plots_data.bottom10_token_ids
260304
)
261305
feature_output["neg_values"] = bottom10_logits
262-
feature_output["pos_str"] = to_str_tokens(
306+
feature_output["pos_str"] = self.to_str_tokens_safe(
263307
vocab_dict, feature.middle_plots_data.top10_token_ids
264308
)
265309
feature_output["pos_values"] = top10_logits
@@ -320,11 +364,13 @@ def run(self):
320364
negContribs = []
321365
for i in range(len(sd.token_ids)):
322366
strs.append(
323-
to_str_tokens(vocab_dict, sd.token_ids[i])
367+
self.to_str_tokens_safe(
368+
vocab_dict, sd.token_ids[i]
369+
)
324370
)
325371
posContrib = {}
326372
posTokens = [
327-
to_str_tokens(vocab_dict, j)
373+
self.to_str_tokens_safe(vocab_dict, j)
328374
for j in sd.top5_token_ids[i]
329375
]
330376
if len(posTokens) > 0:
@@ -335,7 +381,7 @@ def run(self):
335381
posContribs.append(posContrib)
336382
negContrib = {}
337383
negTokens = [
338-
to_str_tokens(vocab_dict, j)
384+
self.to_str_tokens_safe(vocab_dict, j)
339385
for j in sd.bottom5_token_ids[i]
340386
]
341387
if len(negTokens) > 0:

0 commit comments

Comments
 (0)