In [None]:
# ============================================================
# 0) Setup
# ============================================================
import torch, random
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel

# ---- Config ----
SEED       = 123
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID   = "gpt2"
SPLIT      = "train"
SEQ_LEN    = 1024
BATCH_SIZE = 16
MAX_STEPS  = 1100
TOP_K      = 250  # report top-k most sensitive elements

random.seed(SEED)
torch.manual_seed(SEED)

# ============================================================
# 1) Model & Tokenizer
# ============================================================
tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token = tok.eos_token
model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(DEVICE)
model.train()  # need gradients

# ============================================================
# 2) Dataset & Chunking
# ============================================================
wiki = load_dataset("wikitext", "wikitext-103-raw-v1", split=SPLIT)

def chunk_generator():
    cache = []
    for doc in wiki:
        cache.extend(tok(doc["text"]).input_ids)
        while len(cache) >= SEQ_LEN + 1:
            win, cache = cache[:SEQ_LEN+1], cache[SEQ_LEN+1:]  # non-overlap
            yield win[:-1], win[1:]

def get_batch(gen, bs=BATCH_SIZE):
    buf = []
    for x, _ in gen:
        buf.append(x)
        if len(buf) == bs:
            yield torch.tensor(buf, device=DEVICE)
            buf = []

# ============================================================
# 3) Gradient Scan
# ============================================================
param_dict  = {n: p for n, p in model.named_parameters() if p.requires_grad}
running_max = {n: torch.zeros_like(p, device="cpu") for n, p in param_dict.items()}

for step, inp in enumerate(get_batch(chunk_generator()), 1):
    model.zero_grad(set_to_none=True)
    loss = model(inp, labels=inp).loss
    loss.backward()
    for name, p in param_dict.items():
        running_max[name] = torch.maximum(
            running_max[name],
            p.grad.detach().abs().to("cpu")
        )
    if step >= MAX_STEPS:
        break

# ============================================================
# 4) Find Top-K Sensitive Coordinates
# ============================================================
candidates = []
for name, rm in running_max.items():
    k_local = min(TOP_K, rm.numel())
    if k_local > 0:
        vals, idxs = torch.topk(rm.view(-1), k_local)
        for v, flat in zip(vals, idxs):
            coord = torch.unravel_index(flat, rm.shape)
            candidates.append((v.item(), name, coord))

candidates.sort(key=lambda t: t[0], reverse=True)
topk_entries = candidates[:TOP_K]

print(f"\nTop-{TOP_K} most sensitive tensor elements:")
for rank, (val, name, coord) in enumerate(topk_entries, 1):
    print(f"  #{rank}: {name}{tuple(map(int,coord))}  |grad|={val:.3e}")



Top-250 most sensitive tensor elements:
  #1: transformer.wte.weight(2488, 496)  |grad|=5.235e+00
  #2: transformer.wte.weight(837, 496)  |grad|=4.481e+00
  #3: transformer.wte.weight(198, 496)  |grad|=3.247e+00
  #4: transformer.wte.weight(11, 496)  |grad|=2.898e+00
  #5: transformer.wte.weight(34315, 496)  |grad|=2.669e+00
  #6: transformer.wte.weight(796, 496)  |grad|=2.517e+00
  #7: transformer.wte.weight(764, 496)  |grad|=2.499e+00
  #8: transformer.wte.weight(13, 496)  |grad|=2.427e+00
  #9: transformer.wte.weight(220, 496)  |grad|=2.242e+00
  #10: transformer.wte.weight(2488, 430)  |grad|=2.207e+00
  #11: transformer.wte.weight(198, 430)  |grad|=2.154e+00
  #12: transformer.wte.weight(31, 496)  |grad|=2.059e+00
  #13: transformer.wte.weight(5187, 496)  |grad|=2.046e+00
  #14: transformer.wte.weight(49063, 496)  |grad|=2.029e+00
  #15: transformer.wte.weight(27583, 496)  |grad|=2.001e+00
  #16: transformer.wte.weight(837, 430)  |grad|=1.993e+00
  #17: transformer.wte.weight(6645

In [9]:
# ============================================================
# 0) Setup & Config
# ============================================================
import random
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---- Experiment knobs ----
SEED       = 123
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID   = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"   # or "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

SPLIT      = "train"
SEQ_LEN    = 1024
BATCH_SIZE = 16
MAX_STEPS  = 100
TOP_K      = 250

# ---- Filter list for second top-K (local ranks) ----
FILTER_ENABLED = True
NUM_LAYERS = 28   # Qwen1.5B / Qwen2.5-1.5B architecture
FILTER_PARAM_NAMES = [
    # Embeddings
    "model.embed_tokens.weight",

    # Layer norms: input LN
    *[f"model.layers.{i}.input_layernorm.weight" for i in range(NUM_LAYERS)],
    *[f"model.layers.{i}.input_layernorm.bias"   for i in range(NUM_LAYERS)],

    # Layer norms: post-attention LN
    *[f"model.layers.{i}.post_attention_layernorm.weight" for i in range(NUM_LAYERS)],
    *[f"model.layers.{i}.post_attention_layernorm.bias"   for i in range(NUM_LAYERS)],

    # Final RMSNorm
    "model.norm.weight"
]


# Reproducibility
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ============================================================
# 1) Model & Tokenizer
# ============================================================
print("Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID)

if tok.pad_token is None:
    tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float16,
    device_map={"": DEVICE},
)

model.gradient_checkpointing_enable()
model.train()  # gradients required for scan

print(f"Model loaded on {DEVICE} with dtype {next(model.parameters()).dtype}")

# ============================================================
# 2) Dataset & Chunking
# ============================================================
print("Loading dataset...")
wiki = load_dataset("wikitext", "wikitext-103-raw-v1", split=SPLIT)

def chunk_generator():
    """Yield (inp, labels) windows of size SEQ_LEN."""
    cache = []
    for doc in wiki:
        ids = tok(doc["text"]).input_ids
        cache.extend(ids)
        while len(cache) >= SEQ_LEN + 1:
            window = cache[:SEQ_LEN+1]
            cache = cache[SEQ_LEN+1:]
            yield window[:-1], window[1:]

def get_batch(gen, bs=BATCH_SIZE):
    """Stack windows into batches."""
    buf = []
    for inp, _ in gen:
        buf.append(inp)
        if len(buf) == bs:
            yield torch.tensor(buf, dtype=torch.long, device=DEVICE)
            buf = []

# ============================================================
# 3) Parameter Dictionary
# ============================================================
param_dict = {name: p for name, p in model.named_parameters() if p.requires_grad}

print(f"Total tracked tensors: {len(param_dict)}")
print(f"Total elements: {sum(p.numel() for p in param_dict.values()):,}")

running_max = {
    name: torch.zeros_like(p, device="cpu", dtype=torch.float32)
    for name, p in param_dict.items()
}

# ============================================================
# 4) Gradient Scan
# ============================================================
print("Starting gradient scan...")
stream = get_batch(chunk_generator(), bs=BATCH_SIZE)

for step, inp in enumerate(stream, 1):
    model.zero_grad(set_to_none=True)
    out = model(inp, labels=inp)
    loss = out.loss
    loss.backward()

    for name, p in param_dict.items():
        if p.grad is not None:
            grad_abs = p.grad.detach().abs().to("cpu", torch.float32)
            running_max[name] = torch.maximum(running_max[name], grad_abs)

    print(f"[step {step}] loss={loss.item():.4f}")

    if step >= MAX_STEPS:
        print("Reached MAX_STEPS — stopping scan.")
        break

print("Gradient scan complete.\n")

# ============================================================
# 5) Build Global TOP-K
# ============================================================
print(f"Selecting global Top-{TOP_K} by |grad|...")

candidates = []
for name, rm in running_max.items():
    rm_flat = rm.view(-1)
    k_local = min(TOP_K, rm_flat.numel())
    vals, idxs = torch.topk(rm_flat, k_local)
    for v, flat_idx in zip(vals, idxs):
        coord = torch.unravel_index(flat_idx, rm.shape)
        candidates.append((v.item(), name, coord))

# Sort globally
candidates.sort(key=lambda x: x[0], reverse=True)

ranked_global = [
    {
        "global_rank": i + 1,
        "value": val,
        "name": name,
        "coord": coord,
    }
    for i, (val, name, coord) in enumerate(candidates)
]

global_topk = ranked_global[:TOP_K]

# ============================================================
# 6) Filtered TOP-K (local + global ranks)
# ============================================================
def allowed(name):
    return not FILTER_ENABLED or name not in FILTER_PARAM_NAMES

filtered_list = [item for item in ranked_global if allowed(item["name"])]
filtered_topk = filtered_list[:TOP_K]

for i, item in enumerate(filtered_topk, 1):
    item["local_rank"] = i

# ============================================================
# 7) Print Results
# ============================================================
print(f"\n=== Global Top-{TOP_K} most sensitive parameters ===")
for item in global_topk:
    coord = "(" + ", ".join(str(int(x)) for x in item["coord"]) + ")"
    print(f"  global #{item['global_rank']:3d}: {item['name']}{coord}  |grad|={item['value']:.3e}")

print(f"\n=== Filtered Top-{TOP_K} (local + global ranks) ===")
print("Filter:", FILTER_PARAM_NAMES if FILTER_ENABLED else "None")
for item in filtered_topk:
    coord = "(" + ", ".join(str(int(x)) for x in item["coord"]) + ")"
    print(
        f"  local #{item['local_rank']:3d} | global #{item['global_rank']:3d}:  "
        f"{item['name']}{coord}  |grad|={item['value']:.3e}"
    )


Loading tokenizer...
Loading model...
Model loaded on cuda with dtype torch.float16
Loading dataset...
Total tracked tensors: 339
Total elements: 1,777,088,000
Starting gradient scan...
[step 1] loss=3.7354
[step 2] loss=3.6691
[step 3] loss=3.7742
[step 4] loss=3.7177
[step 5] loss=3.6300
[step 6] loss=3.6891
[step 7] loss=3.8463
[step 8] loss=3.3649
[step 9] loss=3.7301
[step 10] loss=3.6542
[step 11] loss=3.7277
[step 12] loss=3.7875
[step 13] loss=3.5501
[step 14] loss=3.7428
[step 15] loss=3.7047
[step 16] loss=3.2469
[step 17] loss=3.9071
[step 18] loss=3.4903
[step 19] loss=3.3947
[step 20] loss=3.5306
[step 21] loss=3.7135
[step 22] loss=3.6956
[step 23] loss=4.0239
[step 24] loss=3.3613
[step 25] loss=3.6670
[step 26] loss=3.4049
[step 27] loss=3.8173
[step 28] loss=4.0070
[step 29] loss=3.6343
[step 30] loss=3.9703
[step 31] loss=3.7230
[step 32] loss=3.7623
[step 33] loss=3.8226
[step 34] loss=3.6987
[step 35] loss=3.5843
[step 36] loss=3.7525
[step 37] loss=3.8072
[step 38]