In [None]:
# ============================================================
# 0) One-off setup (installs)
# ============================================================
%env CUDA_LAUNCH_BLOCKING=1
import subprocess, sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q",
     "--upgrade",
     "datasets>=2.18.0", "fsspec>=2023.6.0",
     "pandas>=2.0.0", "sacrebleu>=2.4.0",
     "evaluate>=0.4.2", "rouge-score>=0.1.2",
     "bert-score>=0.3.13", "tabulate>=0.9.0"],
    check=True
)

# ============================================================
# 1) Imports & Config
# ============================================================
import os
import math
import random
import torch
import torch.nn.functional as F
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
import sacrebleu
import evaluate
from tabulate import tabulate

# ---- Experiment knobs ----
SEED = 123
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
MODEL_ID = "gpt2"
SPLIT = "train"                # dataset split used for the gradient scan
SEQ_LEN = 1024                 # tokens per training window
BATCH_SIZE = 16                # windows per gradient-scan batch
MAX_STEPS = 1100               # upper bound on gradient-scan batches
TOP_K = 3                      # how many globally most gradient-sensitive scalars to keep
V_SELECT = "all"               # "all", a single 1-based rank, or a list of ranks
N_TRIALS_PER_CLASS = 5         # random bit choices per bit class per rank
MAX_NEW_TOKENS = 100           # generation length per prompt
MAX_COL_WIDTH = 100            # not referenced by the code in this notebook section

# Decoding strategy: "greedy", "top-k", or "top-p"
DECODING_STRATEGY = "greedy"
TOP_K_SAMPLING = 10            # only meaningful for "top-k"
TOP_P_SAMPLING = 0.9           # only meaningful for "top-p"
TEMPERATURE = 1.0

TEST_PROMPTS = [
    "The weather today is",
    "The patient should take",
    "The bank transfer amount is",
    "The recommended dose for a child is",
    "The evacuation order status is",
]

# Metric toggles
ENABLE_BLEU = True
ENABLE_METEOR = True
ENABLE_BERTSCORE = True
ENABLE_ROUGE = True

# Reproducibility: Python's `random` picks the flipped bit in each trial,
# torch's global RNG governs any sampling done by the model.
random.seed(SEED)
torch.manual_seed(SEED)

# ============================================================
# 2) Model & Tokenizer
# ============================================================
tok = AutoTokenizer.from_pretrained(MODEL_ID)
# GPT-2's tokenizer does not define a pad token by default; reuse EOS.
tok.pad_token = tok.eos_token

model = GPT2LMHeadModel.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
# Both the gradient scan and every generation below run in eval mode.
model.eval()

# ============================================================
# 3) Load & check metrics
# ============================================================
# A disabled metric stays None; score_pair skips metrics that are None.
meteor = berts = rouge = None
if ENABLE_METEOR:
    meteor = evaluate.load("meteor")
if ENABLE_BERTSCORE:
    berts = evaluate.load("bertscore")
if ENABLE_ROUGE:
    rouge = evaluate.load("rouge")

print(f"METEOR loaded: {meteor}")
print(f"BERTScore loaded: {berts}")
print(f"ROUGE loaded:   {rouge}")

# ============================================================
# 4) Data & gradient scan for top-K sensitive coords
# ============================================================
# Raw (untokenised) WikiText-103 text; tokenisation is done by chunk_generator.
wiki = load_dataset(path="wikitext", name="wikitext-103-raw-v1", split=SPLIT)

def chunk_generator():
    """Yield ``(inputs, targets)`` token windows from the concatenated corpus.

    Each record's ``"text"`` in ``wiki`` is tokenised with ``tok`` and appended
    to a running token buffer (no separator token is inserted between
    records). Whenever the buffer holds ``SEQ_LEN + 1`` tokens, that prefix is
    consumed and yielded as ``(window[:-1], window[1:])``: two lists of
    ``SEQ_LEN`` token ids offset by one position. Windows never overlap, and
    tokens left over at the end of the corpus are never yielded.
    """
    span = SEQ_LEN + 1
    buffer = []
    for record in wiki:
        buffer.extend(tok(record["text"]).input_ids)
        consumed = 0
        while len(buffer) - consumed >= span:
            window = buffer[consumed:consumed + span]
            consumed += span
            yield window[:-1], window[1:]
        # Drop the consumed prefix once per record instead of per window.
        del buffer[:consumed]

def get_batch(gen, bs=BATCH_SIZE):
    """Group the input lists yielded by *gen* into ``(bs, seq)`` tensors.

    *gen* must yield ``(inputs, targets)`` pairs; the targets are ignored.
    Each batch is created directly on ``DEVICE``. A trailing partial batch
    (fewer than *bs* items) is dropped.
    """
    pending = []
    for inputs, _targets in gen:
        pending.append(inputs)
        if len(pending) != bs:
            continue
        yield torch.tensor(pending, device=DEVICE)
        pending = []

# Trainable parameters, keyed by name.
param_dict  = {n: p for n, p in model.named_parameters() if p.requires_grad}
# Element-wise running max of |dL/dθ| over all scanned batches. It is kept on
# the CPU so it does not add another weight-sized buffer to device memory.
running_max = {n: torch.zeros_like(p, device="cpu") for n, p in param_dict.items()}

for step, inp in enumerate(get_batch(chunk_generator()), 1):
    model.zero_grad(set_to_none=True)
    # Language-modelling loss with the inputs doubling as labels.
    model(inp, labels=inp).loss.backward()
    for name, p in param_dict.items():
        if p.grad is None:
            # No gradient reached this tensor in this pass; nothing to fold in.
            continue
        rm = running_max[name]
        # In-place update: avoids allocating a fresh weight-sized tensor for
        # every parameter on every step.
        torch.maximum(rm, p.grad.detach().abs().to("cpu"), out=rm)
    if step >= MAX_STEPS:
        break

# The generation phase never needs these gradients; release their memory.
model.zero_grad(set_to_none=True)

# Collect each tensor's own top-K |grad| maxima. The global top-K is always
# contained in the union of the per-tensor top-Ks, so this suffices.
candidates = []
for tensor_name, grad_max in running_max.items():
    per_tensor = min(TOP_K, grad_max.numel())
    if not per_tensor:
        continue
    top = torch.topk(grad_max.view(-1), per_tensor)
    for score, flat_idx in zip(top.values, top.indices):
        # Convert the flat index back into a coordinate within the tensor.
        candidates.append(
            (score.item(), tensor_name, torch.unravel_index(flat_idx, grad_max.shape))
        )

# Stable descending sort by |grad|; keep the global winners.
candidates.sort(key=lambda entry: entry[0], reverse=True)
topk_entries = candidates[:TOP_K]
coords_list  = [(tensor_name, coord) for _, tensor_name, coord in topk_entries]

print(f"\nGlobal Top-{TOP_K} |∂L/∂θ| scalars:")
for position, (grad_abs, tensor_name, coord) in enumerate(topk_entries, start=1):
    coord_ints = tuple(int(c) for c in coord)
    print(f"  #{position}: {tensor_name}{coord_ints}  |grad|={grad_abs:.3e}")

def normalize_v_select(sel, k):
    """Turn the ``V_SELECT`` knob into an explicit list of 1-based ranks.

    Args:
        sel: ``"all"`` for every rank ``1..k``, a single ``int`` rank, or a
            list/tuple of ``int`` ranks.
        k: number of ranked coordinates available; valid ranks are ``1..k``.

    Returns:
        list[int]: the selected ranks, in the order given.

    Raises:
        ValueError: if *sel* has an unsupported form, or a rank is not a
            plain int (bools are rejected) or lies outside ``1..k``. Callers
            index with ``rank - 1``, so a rank of 0 or below would otherwise
            silently select from the end of the list.
    """
    if isinstance(sel, str):
        if sel == "all":
            return list(range(1, k + 1))
        raise ValueError("V_SELECT must be 'all', int, or list[int]")
    if isinstance(sel, int) and not isinstance(sel, bool):
        ranks = [sel]
    elif isinstance(sel, (list, tuple)):
        ranks = list(sel)
    else:
        raise ValueError("V_SELECT must be 'all', int, or list[int]")
    for r in ranks:
        if isinstance(r, bool) or not isinstance(r, int):
            raise ValueError(f"V_SELECT ranks must be ints, got {r!r}")
        if not 1 <= r <= k:
            raise ValueError(f"V_SELECT rank {r} is outside 1..{k}")
    return ranks

# Resolve which of the selected top-K coordinates get bit-flip trials.
ranks_to_test = normalize_v_select(sel=V_SELECT, k=TOP_K)
print(f"\nTesting ranks: {ranks_to_test}")

# ============================================================
# 5) Bit-flip helpers
# ============================================================
def flip_bit(val_tensor: torch.Tensor, bit: int):
    """Return a copy of *val_tensor* with bit number *bit* inverted.

    The raw IEEE-754 bit pattern is reinterpreted as a signed integer of the
    same width, XOR-ed with a single-bit mask, and returned as the original
    floating dtype. Bit 0 is the least-significant mantissa bit; the highest
    bit (31 for float32) is the sign bit.

    The input tensor is never modified. XOR-ing an integer view of it in
    place would also change the caller's tensor, which breaks any later
    "restore the original value" step.

    Args:
        val_tensor: float16, bfloat16, float32 or float64 tensor (any shape).
        bit: bit index in ``[0, width)``, where width is the dtype's bit size.

    Returns:
        A new tensor with the same dtype, shape and device, with the bit
        flipped in every element.

    Raises:
        TypeError: if the dtype is not one of the supported float types.
        ValueError: if *bit* is outside the dtype's width.
    """
    int_views = {
        torch.float16: (torch.int16, 16),
        torch.bfloat16: (torch.int16, 16),
        torch.float32: (torch.int32, 32),
        torch.float64: (torch.int64, 64),
    }
    try:
        int_dtype, width = int_views[val_tensor.dtype]
    except KeyError:
        raise TypeError(f"flip_bit does not support dtype {val_tensor.dtype}") from None
    if not 0 <= bit < width:
        raise ValueError(f"bit must be in [0, {width}), got {bit}")
    # Express the mask as a *signed* value of the target width: for the top
    # bit, 1 << (width - 1) does not fit in the signed integer type.
    mask = 1 << bit
    if bit == width - 1:
        mask -= 1 << width
    flipped = val_tensor.clone()
    bits = flipped.view(int_dtype)  # shares storage with `flipped`
    bits ^= mask
    return flipped

# Bit positions of an IEEE-754 float32, grouped by field:
# bit 31 = sign, bits 30..23 = exponent, bits 22..0 = mantissa (fraction).
# Insertion order (sign, exponent, mantissa) is the order trials run in.
BIT_CLASSES = dict(
    sign=[31],
    exponent=[b for b in range(23, 31)],
    mantissa=[b for b in range(23)],
)

# ============================================================
# 6) Scoring function with try/except
# ============================================================
def edit_distance(a: str, b: str):
    """Return the character-level Levenshtein distance between *a* and *b*.

    Insertions, deletions and substitutions each cost 1. Only two DP rows
    are kept in memory, so space is O(len(b)).
    """
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,                    # delete ch_a
                current[j - 1] + 1,                 # insert ch_b
                previous[j - 1] + (ch_a != ch_b),   # substitute / match
            ))
        previous = current
    return previous[-1]

def score_pair(clean: str, corrupt: str):
    """Compare a corrupted generation against its clean reference.

    The raw and length-normalised character edit distances are always
    returned. Each enabled text metric (BLEU, METEOR, BERTScore F1 and
    ROUGE-1/2/L F1) is added when its computation succeeds. A metric that
    raises is reported with a printed message and left out of the result.

    Args:
        clean: reference text from the unmodified model.
        corrupt: hypothesis text from the bit-flipped model.

    Returns:
        dict mapping metric name to float (BLEU on sacrebleu's 0-100 scale).
    """
    distance = edit_distance(clean, corrupt)
    scores = {
        "EditDist": float(distance),
        # Normalised by the reference length; max(1, ...) guards an empty reference.
        "EditDist_Norm": float(distance / max(1, len(clean))),
    }

    def run_bleu():
        scores["BLEU"] = sacrebleu.corpus_bleu([corrupt], [[clean]]).score

    def run_meteor():
        result = meteor.compute(predictions=[corrupt], references=[clean])
        scores["METEOR"] = float(result["meteor"])

    def run_bertscore():
        result = berts.compute(predictions=[corrupt], references=[clean], lang="en")
        scores["BERTScore_F1"] = float(result["f1"][0])

    def run_rouge():
        result = rouge.compute(
            predictions=[corrupt], references=[clean], use_stemmer=True
        )
        scores["ROUGE1_F1"] = float(result["rouge1"])
        scores["ROUGE2_F1"] = float(result["rouge2"])
        scores["ROUGEL_F1"] = float(result["rougeL"])

    # (label used in the failure message, enabled?, job) in output-column order.
    metric_jobs = (
        ("BLEU", ENABLE_BLEU, run_bleu),
        ("METEOR", ENABLE_METEOR and meteor is not None, run_meteor),
        ("BERTScore", ENABLE_BERTSCORE and berts is not None, run_bertscore),
        ("ROUGE", ENABLE_ROUGE and rouge is not None, run_rouge),
    )
    for label, enabled, job in metric_jobs:
        if not enabled:
            continue
        try:
            job()
        except Exception as err:
            # Best-effort: one failing metric must not abort the whole sweep.
            print(f"{label} compute failed:", err)

    return scores

# ============================================================
# 7) Generation helpers
# ============================================================
class NanInfDetector(LogitsProcessor):
    """Logits processor that records abnormal scores and never modifies them.

    It sets ``flag_dict["had_nan"] = True`` when the scores at any step
    contain NaN or +inf, or when a row is entirely -inf (no token can be
    chosen).

    Isolated -inf values are deliberately NOT flagged. transformers runs
    user-supplied processors after its built-in ones, and built-ins such as
    the ``no_repeat_ngram_size`` filter ban tokens by writing -inf. Treating
    any -inf as corruption would therefore flag healthy logits.
    """

    def __init__(self, flag_dict=None):
        # Caller-owned dict that receives the "had_nan" flag; None disables
        # recording. An empty dict is still a valid target.
        self.flag_dict = flag_dict

    def __call__(self, input_ids, scores):
        if self.flag_dict is not None:
            bad = torch.isnan(scores) | torch.isposinf(scores)
            fully_masked = torch.isneginf(scores).all(dim=-1)
            if bool(bad.any()) or bool(fully_masked.any()):
                self.flag_dict["had_nan"] = True
        return scores

class MaxRepeatGuard(LogitsProcessor):
    """Ban a token once it has been emitted ``max_consecutive`` times in a row.

    The guard only acts on a single sequence (batch size 1) with a non-empty
    context; otherwise the scores are returned untouched. When triggered,
    the repeated token's score is set to -1e9 in place.
    """

    def __init__(self, max_consecutive=6):
        self.max_consecutive = max_consecutive

    def __call__(self, input_ids, scores):
        single_row = input_ids.size(0) == 1
        if not single_row or input_ids.size(1) == 0:
            return scores
        # One host transfer instead of an .item() call per position.
        tokens = input_ids[0].tolist()
        tail = tokens[-1]
        streak = 0
        for token in reversed(tokens):
            if token != tail:
                break
            streak += 1
        if streak >= self.max_consecutive:
            scores[:, tail] = -1e9
        return scores

@torch.no_grad()
def _generate(prompt: str, corrupt: bool = False):
    """Decode a continuation of *prompt* with the configured strategy.

    Clean and corrupted generations use exactly the same decoding setup: the
    same logits processors, the same ``no_repeat_ngram_size`` and, for
    sampling strategies, the same torch seed. Any difference between the two
    outputs is then caused by the flipped weight rather than by the decoder.

    Args:
        prompt: text to continue.
        corrupt: if True, report whether abnormal logits were seen
            (see NanInfDetector); if False the flag is always False.

    Returns:
        (text, had_nan): the decoded continuation without the prompt, and
        the abnormal-logits flag.

    Raises:
        ValueError: if DECODING_STRATEGY is not "greedy", "top-k" or "top-p".
    """
    enc = tok(prompt, return_tensors="pt", return_attention_mask=True)
    ids = enc["input_ids"].to(DEVICE)
    mask = enc["attention_mask"].to(DEVICE)

    if DECODING_STRATEGY == "greedy":
        # Temperature is irrelevant without sampling, so it is not passed.
        gen_kwargs = {"do_sample": False}
    elif DECODING_STRATEGY == "top-k":
        gen_kwargs = {
            "do_sample": True,
            "temperature": TEMPERATURE,
            "top_k": TOP_K_SAMPLING,
        }
    elif DECODING_STRATEGY == "top-p":
        # top_k=0 disables transformers' default top-k filter (50), so this is
        # pure nucleus sampling.
        gen_kwargs = {
            "do_sample": True,
            "temperature": TEMPERATURE,
            "top_p": TOP_P_SAMPLING,
            "top_k": 0,
        }
    else:
        raise ValueError(
            f"Unknown DECODING_STRATEGY {DECODING_STRATEGY!r}; "
            "expected 'greedy', 'top-k' or 'top-p'"
        )

    if gen_kwargs["do_sample"]:
        # Same random stream for every call, so clean and corrupt draws match.
        torch.manual_seed(SEED)

    diag = {"had_nan": False}
    procs = LogitsProcessorList([
        NanInfDetector(flag_dict=diag),
        MaxRepeatGuard(max_consecutive=6),
    ])
    out = model.generate(
        input_ids=ids,
        attention_mask=mask,
        max_new_tokens=MAX_NEW_TOKENS,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,
        logits_processor=procs,
        no_repeat_ngram_size=3,
        **gen_kwargs,
    )
    text = tok.decode(out[0, ids.size(1):], skip_special_tokens=True)
    return text, (diag["had_nan"] if corrupt else False)

def generate_clean(p):
    """Return only the decoded text of a generation for prompt *p* (corrupt=False)."""
    return _generate(p, corrupt=False)[0]


def generate_corrupt(p):
    """Return ``(text, had_nan)`` for a generation under the current weights."""
    return _generate(p, corrupt=True)

# ============================================================
# 8) Run bit-flip trials
# ============================================================
# Reference outputs from the unmodified model, one per prompt.
CLEAN_CACHE = {p: generate_clean(p) for p in TEST_PROMPTS}
rows = []
for rank in ranks_to_test:
    # Ranks are 1-based. Reject 0/negatives instead of silently indexing
    # coords_list from the end.
    if not 1 <= rank <= len(coords_list):
        raise ValueError(f"rank {rank} is outside 1..{len(coords_list)}")
    name, coord = coords_list[rank-1]
    W = param_dict[name]
    orig = W.data[coord].clone()  # pristine copy used to restore the weight
    orig_value = orig.item()
    coord_key = tuple(map(int, coord))
    for bit_class, pool in BIT_CLASSES.items():
        for trial in range(1, N_TRIALS_PER_CLASS + 1):
            bit = random.choice(pool)
            try:
                # Flip a throwaway copy so `orig` stays untouched even if
                # flip_bit modifies its argument in place.
                flipped = flip_bit(orig.clone(), bit)
                W.data[coord] = flipped
                flipped_value = flipped.item()
                for prompt in TEST_PROMPTS:
                    clean_out = CLEAN_CACHE[prompt]
                    corrupt_out, had_nan = generate_corrupt(prompt)
                    scores = score_pair(clean_out, corrupt_out)
                    rows.append({
                        "rank": rank,
                        "tensor": name,
                        "coord": coord_key,
                        "bit_class": bit_class,
                        "bit_index": bit,
                        "orig_value": orig_value,
                        "flipped_value": flipped_value,
                        "trial": trial,
                        "prompt": prompt,
                        "clean": clean_out,
                        "corrupt": corrupt_out,
                        "corrupt_logits_had_nan": had_nan,
                        **scores
                    })
            finally:
                # Restore the exact original value even if flipping,
                # generation or scoring raised.
                W.data[coord] = orig

df = pd.DataFrame(rows)

# ============================================================
# 9) Debug: show final columns & sample
# ============================================================
print("Final DataFrame columns:", list(df.columns))
print(df.head(n=3))
# ============================================================
# 10) Save CSVs into the local clone of the GitHub repo
# ============================================================
# Outputs are written under <repo root>/bitflip_outputs in the local clone of
#   https://github.com/kameshr/llm-sensitivity
# The repo root comes from `git rev-parse --show-toplevel`. If git is missing
# or the notebook is not running inside a git work tree, the current working
# directory is used instead. Nothing is committed or pushed.
try:
    repo_root = os.fsdecode(subprocess.check_output(
        ['git', 'rev-parse', '--show-toplevel'],
        stderr=subprocess.DEVNULL,
    )).strip()
except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: not inside a git work tree.
    # OSError (incl. FileNotFoundError): the git executable is unavailable.
    repo_root = os.getcwd()

OUT_DIR = os.path.join(repo_root, 'bitflip_outputs')
os.makedirs(OUT_DIR, exist_ok=True)

# Per-trial CSV
trial_path = os.path.join(OUT_DIR, 'bitflip_per_trial.csv')
df.to_csv(trial_path, index=False)
print(f"Saved per-trial CSV → {trial_path}")

# Aggregated CSV: mean / median / std of every metric present, computed
# across trials within each (rank, tensor, coord, bit_class, prompt) group.
metric_cols = [
    col
    for col in ('EditDist', 'EditDist_Norm', 'BLEU', 'METEOR',
                'BERTScore_F1', 'ROUGE1_F1', 'ROUGE2_F1', 'ROUGEL_F1')
    if col in df.columns
]
print('Aggregated metrics present:', metric_cols)
if metric_cols:
    # Named aggregation yields flat "<metric>_<stat>" columns directly.
    named_aggs = {
        f"{metric}_{stat}": (metric, stat)
        for metric in metric_cols
        for stat in ('mean', 'median', 'std')
    }
    summary = df.groupby(
        ['rank', 'tensor', 'coord', 'bit_class', 'prompt'],
        as_index=False,
    ).agg(**named_aggs)
    agg_path = os.path.join(OUT_DIR, 'bitflip_aggregated.csv')
    summary.to_csv(agg_path, index=False)
    print(f"Saved aggregated CSV → {agg_path}")
