In [None]:
!pip install sae-lens transformer-lens plotly pandas numpy scipy tqdm
!pip install sae-lens
!pip install transformer-lens
!pip install plotly
!pip install -U bitsandbytes accelerate

In [None]:
import os

# ========================================
# 0) Environment setup
# ========================================
# These switches are set BEFORE the third-party imports below on purpose:
# tqdm (4.66+) captures its TQDM_* overrides when the module is imported, so
# assigning TQDM_DISABLE after `from tqdm.auto import tqdm` (as this cell did
# originally) had no effect. Setting all three first also guarantees that any
# library which reads them at import time sees the intended values.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TQDM_DISABLE"] = "1"

# The remaining imports intentionally follow the environment configuration.
import gc
import re
import torch
import numpy as np
import requests
from functools import partial
from tqdm.auto import tqdm

import transformers
# Only surface errors from transformers; hide info/warning chatter.
transformers.logging.set_verbosity_error()

from huggingface_hub import login
from sae_lens import SAE, HookedSAETransformer

# Authenticate with the Hugging Face Hub only when a token is supplied through
# the HF_TOKEN environment variable; otherwise continue unauthenticated.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ HF_TOKEN not set, skipping login")
else:
    login(token=HF_TOKEN)
    print("✅ Logged in to Hugging Face")

# Prefer the GPU when CUDA is available, falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# Run a garbage-collection pass before the large model load that follows.
gc.collect()

# ========================================
# 1) Load the model & SAE
# ========================================
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

# Identifiers of the pretrained SAE to attach (release name + hook point).
release = "llama-3-8b-it-res-jh"
sae_id = "blocks.25.hook_resid_post"

print(f"\nLoading model: {MODEL_NAME}")
# Half precision on the GPU, full precision on the CPU.
model = HookedSAETransformer.from_pretrained(
    MODEL_NAME,
    device=device,
    dtype=torch.float32 if device != "cuda" else torch.float16,
)
print("✅ Model loaded!")

print(f"\nLoading SAE: release={release}, sae_id={sae_id}")
# The call is unpacked into three values; the third one is not used here.
sae, cfg_dict, _ = SAE.from_pretrained(release, sae_id, device=device)
print("✅ SAE loaded!")

# Derive the hook settings from the SAE identifier and its config.
HOOK_NAME = sae_id  # "blocks.25.hook_resid_post"
# The layer index is the integer between "blocks." and the next dot -> 25.
LAYER_IDX = int(re.compile(r"blocks\.(\d+)\.").search(sae_id).group(1))
# Fall back to True when the SAE config has no `prepend_bos` attribute.
PREPEND_BOS = getattr(sae.cfg, "prepend_bos", True)

# Print a short summary of the resolved settings, one per line.
print(
    f"   Hook:       {HOOK_NAME}",
    f"   Layer:      {LAYER_IDX}",
    f"   d_in:       {sae.cfg.d_in}",
    f"   d_sae:      {sae.cfg.d_sae}",
    f"   prepend_bos:{PREPEND_BOS}",
    sep="\n",
)