In [None]:
# ============================================================
# 0) Setup
# ============================================================
import torch, random
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel

# ---- Run configuration ----
SEED = 123  # one seed shared by Python's and PyTorch's RNGs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "gpt2"
SPLIT = "train"
SEQ_LEN = 1024  # tokens per input window
BATCH_SIZE = 16  # windows stacked into each batch
MAX_STEPS = 1100  # the scan stops after this many batches
TOP_K = 250  # how many of the most sensitive elements get reported

# Seed both generators so repeated runs start from the same RNG state.
torch.manual_seed(SEED)
random.seed(SEED)

# ============================================================
# 1) Model & Tokenizer
# ============================================================
tok = AutoTokenizer.from_pretrained(MODEL_ID)
# GPT-2 ships without a pad token; reuse EOS so the tokenizer has one.
tok.pad_token = tok.eos_token
model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(DEVICE)
# Gradients do not depend on train mode: autograd tracks every parameter with
# requires_grad=True regardless of train()/eval(). Train mode would only switch
# on the model's dropout layers, which injects random noise into the per-element
# |grad| statistic measured below. Use eval mode so the scan is deterministic.
model.eval()

# ============================================================
# 2) Dataset & Chunking
# ============================================================
# Raw WikiText-103; SPLIT selects which split is streamed by chunk_generator().
wiki = load_dataset(
    path="wikitext",
    name="wikitext-103-raw-v1",
    split=SPLIT,
)

def chunk_generator():
    """Yield (inputs, labels) token-id lists cut from the tokenized dataset.

    Token ids of consecutive `wiki` rows are concatenated into one rolling
    buffer. Each time the buffer holds at least SEQ_LEN + 1 ids, that many
    are removed from its front (windows never overlap) and emitted as a pair:
    the window without its last id, and the window without its first id.
    """
    window = SEQ_LEN + 1
    pending = []
    for doc in wiki:
        pending += tok(doc["text"]).input_ids
        while len(pending) >= window:
            chunk = pending[:window]
            pending = pending[window:]
            yield chunk[:-1], chunk[1:]

def get_batch(gen, bs=BATCH_SIZE):
    """Group windows from `gen` into tensors of `bs` rows on DEVICE.

    Only the first element of each pair produced by `gen` is kept. A trailing
    group with fewer than `bs` windows is never yielded.
    """
    rows = []
    for inputs, _labels in gen:
        rows.append(inputs)
        if len(rows) != bs:
            continue
        full, rows = rows, []
        yield torch.tensor(full, device=DEVICE)

# ============================================================
# 3) Gradient Scan
# ============================================================
# Track every trainable parameter, and keep a CPU-side tensor of the same shape
# holding the largest |grad| seen so far for each individual element.
param_dict = {n: p for n, p in model.named_parameters() if p.requires_grad}
running_max = {n: torch.zeros_like(p, device="cpu") for n, p in param_dict.items()}

for step, inp in enumerate(get_batch(chunk_generator()), 1):
    # Drop the previous step's gradients so each step is measured on its own.
    model.zero_grad(set_to_none=True)
    # The same ids are passed as labels; the model's LM loss is used as-is.
    loss = model(inp, labels=inp).loss
    loss.backward()
    for name, p in param_dict.items():
        # With set_to_none=True a parameter that took no part in this step's
        # graph has grad None; skip it instead of failing on `.detach()`.
        if p.grad is None:
            continue
        running_max[name] = torch.maximum(
            running_max[name],
            p.grad.detach().abs().to("cpu"),
        )
    if step >= MAX_STEPS:
        break

# ============================================================
# 4) Find Top-K Sensitive Coordinates
# ============================================================
# Collect each tensor's own top-k, then sort the union: the overall top TOP_K
# must be contained in it because no tensor can contribute more than TOP_K.
candidates = []  # (value, parameter name, coordinate tuple of ints)
for name, rm in running_max.items():
    k_local = min(TOP_K, rm.numel())
    if k_local <= 0:
        continue
    vals, idxs = torch.topk(rm.reshape(-1), k_local)
    if rm.dim() == 0:
        # A scalar tensor has a single element with an empty coordinate.
        coords = [()]
    else:
        # One batched unravel per tensor instead of one call per element;
        # `.tolist()` turns the indices into plain Python ints.
        columns = torch.unravel_index(idxs, rm.shape)
        coords = zip(*(col.tolist() for col in columns))
    candidates.extend(
        (v, name, coord) for v, coord in zip(vals.tolist(), coords)
    )

candidates.sort(key=lambda t: t[0], reverse=True)
topk_entries = candidates[:TOP_K]

print(f"\nTop-{TOP_K} most sensitive tensor elements:")
for rank, (val, name, coord) in enumerate(topk_entries, 1):
    print(f"  #{rank}: {name}{tuple(map(int, coord))}  |grad|={val:.3e}")



Top-250 most sensitive tensor elements:
  #1: transformer.wte.weight(2488, 496)  |grad|=5.235e+00
  #2: transformer.wte.weight(837, 496)  |grad|=4.481e+00
  #3: transformer.wte.weight(198, 496)  |grad|=3.247e+00
  #4: transformer.wte.weight(11, 496)  |grad|=2.898e+00
  #5: transformer.wte.weight(34315, 496)  |grad|=2.669e+00
  #6: transformer.wte.weight(796, 496)  |grad|=2.517e+00
  #7: transformer.wte.weight(764, 496)  |grad|=2.499e+00
  #8: transformer.wte.weight(13, 496)  |grad|=2.427e+00
  #9: transformer.wte.weight(220, 496)  |grad|=2.242e+00
  #10: transformer.wte.weight(2488, 430)  |grad|=2.207e+00
  #11: transformer.wte.weight(198, 430)  |grad|=2.154e+00
  #12: transformer.wte.weight(31, 496)  |grad|=2.059e+00
  #13: transformer.wte.weight(5187, 496)  |grad|=2.046e+00
  #14: transformer.wte.weight(49063, 496)  |grad|=2.029e+00
  #15: transformer.wte.weight(27583, 496)  |grad|=2.001e+00
  #16: transformer.wte.weight(837, 430)  |grad|=1.993e+00
  #17: transformer.wte.weight(6645

In [None]:
# ============================================================
# 0) Setup & Config
# ============================================================
import random
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---- Experiment knobs ----
SEED = 123
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
SPLIT = "train"
SEQ_LEN = 750  # tokens per window; kept modest to limit memory use
BATCH_SIZE = 8  # windows per batch; lower it if memory runs out
MAX_STEPS = 100  # the scan stops after this many batches
TOP_K = 250  # number of elements reported in each list

# ---- Scan-time exclusion (normally left off) ----
# When enabled, parameters whose exact name is listed are not tracked at all.
EXCLUDE_ENABLED = False
EXCLUDE_PARAM_NAMES = [
    # e.g. "model.embed_tokens.weight", "lm_head.weight"
]

# ---- Filter for the second ("local") Top-K ----
# When enabled, parameters named here are left out of the filtered list only;
# they still appear in the global list.
FILTER_ENABLED = True
FILTER_PARAM_NAMES = [
    "model.embed_tokens.weight",  # token embeddings
    "lm_head.weight",  # output head
]

# ---- Reproducibility ----
torch.manual_seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ============================================================
# 1) Model & Tokenizer
# ============================================================
print("Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID)

# Make sure a pad token exists: prefer EOS, otherwise fall back to the token
# with id 0.
if tok.pad_token is None:
    tok.pad_token = (
        tok.eos_token
        if tok.eos_token is not None
        else tok.convert_ids_to_tokens(0)
    )

print("Loading model...")
# Half-precision weights, with the whole model placed on DEVICE.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=torch.float16, device_map={"": DEVICE}
)

# Gradient checkpointing is switched on to reduce memory use.
# NOTE(review): Transformers appears to apply checkpointing only while the
# module is in training mode, so train() likely has to stay — confirm.
model.gradient_checkpointing_enable()
model.train()

first_param = next(model.parameters())
print(f"Model loaded on {DEVICE} with dtype {first_param.dtype}")

# ============================================================
# 2) Dataset & Chunking
# ============================================================
print("Loading dataset...")
# Raw WikiText-103; SPLIT selects which split is streamed by chunk_generator().
wiki = load_dataset(
    path="wikitext",
    name="wikitext-103-raw-v1",
    split=SPLIT,
)

def chunk_generator():
    """
    Walk the dataset row by row, appending each non-empty row's token ids to
    a rolling buffer. Whenever the buffer holds at least SEQ_LEN + 1 ids,
    cut that many from its front (windows do not overlap) and yield the pair
    (window without its last id, window without its first id).
    """
    window = SEQ_LEN + 1
    pending = []
    for doc in wiki:
        text = doc["text"]
        if text:
            # NOTE(review): tok(text) uses the tokenizer's default special-token
            # handling, so any token it prepends per call ends up inside the
            # stream — confirm this is intended.
            pending += tok(text).input_ids
            while len(pending) >= window:
                chunk, pending = pending[:window], pending[window:]
                yield chunk[:-1], chunk[1:]

def get_batch(gen, bs=BATCH_SIZE):
    """
    Collect the input half of each pair from `gen` and yield a long tensor
    on DEVICE every time `bs` of them have accumulated. A trailing group
    smaller than `bs` is never yielded.
    """
    rows = []
    for inputs, _labels in gen:
        rows.append(inputs)
        if len(rows) != bs:
            continue
        full, rows = rows, []
        yield torch.tensor(full, dtype=torch.long, device=DEVICE)

# ============================================================
# 3) Parameter Dictionary & Exclusion (for scan)
# ============================================================
def should_exclude_for_scan(name: str) -> bool:
    """Tell whether `name` is left out of param_dict (exact-name match)."""
    return bool(EXCLUDE_ENABLED) and name in EXCLUDE_PARAM_NAMES

# Trainable parameters that survive the optional exclusion list.
param_dict = {
    pname: param
    for pname, param in model.named_parameters()
    if param.requires_grad and not should_exclude_for_scan(pname)
}

print(f"Total tracked parameter tensors: {len(param_dict)}")
total_elems = sum(param.numel() for param in param_dict.values())
print(f"Total tracked elements: {total_elems:,}")

# One float32 CPU tensor per tracked parameter, holding the largest |grad|
# seen so far for each element.
running_max = {
    pname: torch.zeros_like(param, device="cpu", dtype=torch.float32)
    for pname, param in param_dict.items()
}

# ============================================================
# 4) Gradient Scan
# ============================================================
print("Starting gradient scan...")
stream = get_batch(chunk_generator(), bs=BATCH_SIZE)

for step, inp in enumerate(stream, start=1):
    # Fresh gradients for every batch.
    model.zero_grad(set_to_none=True)
    # The batch is passed as its own labels; the model's LM loss is used as-is.
    out = model(inp, labels=inp)
    loss = out.loss
    loss.backward()

    # Fold this step's |grad| into the per-element running maximum (on CPU,
    # in float32). Parameters that received no gradient are left untouched.
    for name, p in param_dict.items():
        if p.grad is not None:
            grad_abs = p.grad.detach().abs().to("cpu", dtype=torch.float32)
            running_max[name] = running_max[name].maximum(grad_abs)

    print(f"  [step {step}] loss = {loss.item():.4f}")
    if step < MAX_STEPS:
        continue
    print(f"Reached MAX_STEPS = {MAX_STEPS}. Stopping scan.")
    break

print("Gradient scan complete.")

# ============================================================
# 5) Build global ranked list (with global ranks)
# ============================================================
print(f"Selecting global Top-{TOP_K} elements by |grad|...")

# Candidate pool: each tensor's own top-k. The overall top TOP_K is always
# contained in this pool because no tensor can contribute more than TOP_K.
candidates = []  # will hold (value, name, coord_tuple of ints)

for name, rm in running_max.items():
    numel = rm.numel()
    if numel == 0:
        continue
    k_local = min(TOP_K, numel)  # top-k within this tensor
    vals, idxs = torch.topk(rm.reshape(-1), k_local)
    if rm.dim() == 0:
        # A scalar tensor has a single element with an empty coordinate.
        coords = [()]
    else:
        # One batched unravel per tensor instead of one call per element;
        # `.tolist()` turns the indices into plain Python ints.
        columns = torch.unravel_index(idxs, rm.shape)
        coords = zip(*(col.tolist() for col in columns))
    candidates.extend(
        (v, name, coord) for v, coord in zip(vals.tolist(), coords)
    )

# Global sort by gradient magnitude (stable, so ties keep insertion order)
candidates.sort(key=lambda t: t[0], reverse=True)

# Attach a rank to every candidate. "global_rank" is the position inside the
# candidate pool: it equals the rank over all tracked elements for the first
# TOP_K positions, but past that point it can be smaller than the true rank,
# since elements beyond a tensor's own top TOP_K are not in the pool.
ranked_global = [
    {
        "global_rank": global_rank,
        "value": val,
        "name": name,
        "coord": coord,
    }
    for global_rank, (val, name, coord) in enumerate(candidates, 1)
]

# Global Top-K
global_topk = ranked_global[:TOP_K]

# ============================================================
# 6) Build filtered Top-K with local + global ranks
# ============================================================
def passes_filter(name: str) -> bool:
    """Return True if this parameter is allowed in the filtered list."""
    if not FILTER_ENABLED:
        return True
    return name not in FILTER_PARAM_NAMES


def _count_strictly_greater(tensors, thresholds, chunk_elems=1 << 24):
    """Count, per threshold, how many elements of `tensors` exceed it.

    Args:
        tensors: iterable of float32 CPU tensors to scan.
        thresholds: 1-D float32 tensor, strictly increasing.
        chunk_elems: elements processed per call, to bound temporary memory.

    Returns:
        Long tensor with one count per threshold, in the same order.
    """
    hist = torch.zeros(thresholds.numel() + 1, dtype=torch.long)
    for tensor in tensors:
        flat = tensor.reshape(-1)
        for start in range(0, flat.numel(), chunk_elems):
            # Bucket b of an element = number of thresholds strictly below it.
            buckets = torch.bucketize(flat[start:start + chunk_elems], thresholds)
            hist += torch.bincount(buckets, minlength=hist.numel())
    # An element in bucket b exceeds thresholds[0..b-1], so the count for
    # thresholds[j] is the sum of hist[j+1:], i.e. a suffix sum shifted by one.
    suffix = hist.flip(0).cumsum(0).flip(0)
    return suffix[1:]


filtered = [item for item in ranked_global if passes_filter(item["name"])]
filtered_topk = filtered[:TOP_K]

# The stored "global_rank" is a position inside a pool that holds at most
# TOP_K entries per tensor. If a filtered-out tensor has more than TOP_K
# elements above a filtered entry, that position understates the entry's rank
# among all tracked elements. Raise it to 1 + (number of tracked elements
# strictly larger), obtained with one pass over running_max. Ranks that were
# already correct are left unchanged by the max().
if filtered_topk:
    distinct_values = sorted({item["value"] for item in filtered_topk})
    larger_counts = _count_strictly_greater(
        running_max.values(),
        torch.tensor(distinct_values, dtype=torch.float32),
    ).tolist()
    larger_by_value = dict(zip(distinct_values, larger_counts))
    for item in filtered_topk:
        item["global_rank"] = max(
            item["global_rank"], larger_by_value[item["value"]] + 1
        )

# Attach local ranks
for local_rank, item in enumerate(filtered_topk, 1):
    item["local_rank"] = local_rank

# ============================================================
# 7) Print Results
# ============================================================
def _element_label(item):
    """Format one entry as 'name(i, j, ...)  |grad|=value'."""
    indices = ", ".join(str(int(c)) for c in item["coord"])
    return f"{item['name']}({indices})  |grad|={item['value']:.3e}"


print(f"\n=== Global Top-{TOP_K} most sensitive tensor elements ===")
for item in global_topk:
    print(f"  global # {item['global_rank']:4d}: {_element_label(item)}")

print(f"\n=== Filtered Top-{TOP_K} (w/ local & global ranks) ===")
print("Filter excludes:", FILTER_PARAM_NAMES if FILTER_ENABLED else "None")
for item in filtered_topk:
    prefix = (
        f"  local # {item['local_rank']:4d} | "
        f"global # {item['global_rank']:4d} : "
    )
    print(prefix + _element_label(item))
