In [26]:
# %pip install transformers datasets peft opacus wandb

import os, math, time, random, csv, gc
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Iterable, Optional, Dict
from contextlib import nullcontext

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)
import wandb
from peft import LoraConfig, get_peft_model
from opacus.accountants.rdp import RDPAccountant

# ----------------- W&B -----------------
os.environ["WANDB_ENTITY"]  = "bi000050-university-of-minnesota"
os.environ["WANDB_PROJECT"] = "userdp_ft"
os.environ["WANDB_MODE"]    = "online"
wandb.login(key="0d32ee09cbe7bae59ddac577f978a4f31cc4b559")
USE_WANDB = True

def set_seed(seed=42):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ----------------- Config -----------------
@dataclass
class CFG:
    base_model: str = "gpt2"
    use_8bit: bool = False

    # LoRA
    use_lora: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_targets: Tuple[str, ...] = ("c_attn", "c_fc", "c_proj")

    # Data
    dataset_name: str = "yahma/alpaca-cleaned"
    max_train_samples: int = 4000
    max_eval_samples: int = 800
    max_seq_len: int = 256

    # DP level
    privacy_level: str = "user"   # "user" or "record"
    user_lam: float = 5.0         # synthetic users: Poisson(lam) items/user

    # Train
    epochs: int = 1
    micro_batch: int = 4
    lr: float = 1e-4
    weight_decay: float = 0.0
    amp: bool = True              # autocast

    # DP hyperparams
    clip_norm: float = 10
    noise_multiplier: float = 1
    delta: float = 1e-5

    # Methods to compare
    methods: Tuple[str, ...] = ("dpsgd", "dpdisfom")

    # DISFOM
    rho_hat: float = 1
    box_lo: Optional[float] = None
    box_hi: Optional[float] = None

    # Logging
    log_csv: str = "dp_llm_compare_userlvl.csv"
    save_root: str = "dp_llm_runs"

CFG = CFG()
Path(CFG.save_root).mkdir(exist_ok=True)

# ----------------- Data (format → user_id → tokenize) -----------------
def format_alpaca(e):
    instr = e.get("instruction",""); inp = e.get("input",""); out = e.get("output","")
    if inp:
        prompt = f"### Instruction:\n{instr}\n\n### Input:\n{inp}\n\n### Response:\n"
    else:
        prompt = f"### Instruction:\n{instr}\n\n### Response:\n"
    return {"text": prompt + out}

def build_raw_dataset():
    raw = load_dataset(CFG.dataset_name)
    raw = raw.map(format_alpaca, remove_columns=raw["train"].column_names)
    def take_first(ds, n): return ds.select(range(min(n, len(ds))))
    train_raw = take_first(raw["train"], CFG.max_train_samples)
    eval_raw  = take_first(raw["train"].select(range(CFG.max_train_samples, len(raw["train"]))), CFG.max_eval_samples)
    return DatasetDict({"train": train_raw, "validation": eval_raw})

def attach_synthetic_user_ids(ds, lam=5.0, seed=1234):
    import numpy as np
    rng = np.random.default_rng(seed)
    n = len(ds)
    user_ids, uid, i = [], 0, 0
    while i < n:
        k = max(1, int(rng.poisson(lam)))
        for _ in range(k):
            if i >= n: break
            user_ids.append(uid); i += 1
        uid += 1
    return ds.add_column("user_id", user_ids[:n]), uid  # uid = #users

raw_ds = build_raw_dataset()
train_with_uid, U_total = attach_synthetic_user_ids(raw_ds["train"], lam=CFG.user_lam, seed=2025)
val_with_uid, _        = attach_synthetic_user_ids(raw_ds["validation"], lam=CFG.user_lam, seed=2026)

tokenizer = AutoTokenizer.from_pretrained(CFG.base_model, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tok(batch):
    x = tokenizer(batch["text"], padding="max_length", truncation=True, max_length=CFG.max_seq_len)
    x["labels"] = x["input_ids"].copy()
    return x

tok_train = train_with_uid.map(tok, batched=True, remove_columns=["text"])
tok_val   = val_with_uid.map(tok,   batched=True, remove_columns=["text"])
print("Train cols:", tok_train.column_names)
print("Val cols:", tok_val.column_names)

core_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

class KeepUserID:
    def __init__(self, base): self.base = base
    def __call__(self, features):
        user_ids = [f["user_id"] for f in features]
        core = [{k:v for k,v in f.items() if k in ("input_ids","attention_mask","labels")} for f in features]
        batch = self.base(core)
        batch["user_id"] = torch.tensor(user_ids, dtype=torch.long)
        return batch

collator = KeepUserID(core_collator)

train_loader = DataLoader(
    tok_train, batch_size=CFG.micro_batch, shuffle=True, drop_last=True,
    collate_fn=collator, pin_memory=(device=="cuda"), num_workers=2, persistent_workers=True
)
eval_loader  = DataLoader(
    tok_val, batch_size=CFG.micro_batch, shuffle=False, drop_last=False,
    collate_fn=collator, pin_memory=(device=="cuda"), num_workers=2, persistent_workers=True
)
print("Train N:", len(train_loader.dataset), "Eval N:", len(eval_loader.dataset), "Users:", U_total)

# ----------------- Model -----------------
def build_model():
    load_kwargs = {}
    if CFG.use_8bit:
        load_kwargs.update(dict(load_in_8bit=True, device_map="auto"))
    model = AutoModelForCausalLM.from_pretrained(CFG.base_model, **load_kwargs)

    if CFG.use_lora:
        lcfg = LoraConfig(
            r=CFG.lora_r, lora_alpha=CFG.lora_alpha, lora_dropout=CFG.lora_dropout,
            target_modules=list(CFG.lora_targets), bias="none", task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lcfg)
        print("[Model] LoRA enabled.")
    else:
        for p in model.parameters(): p.requires_grad = True
        print("[Model] LoRA disabled. Training full model.")

    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    # ensure checkpointing has a grad-capturing input
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        model.get_input_embeddings().weight.requires_grad_(True)

    model.to(device)
    model.train()
    return model

def params_that_train(model):
    return [p for p in model.parameters() if p.requires_grad]

def pack_params(params: Iterable[torch.Tensor]) -> Tuple[torch.Tensor, List[Tuple[int,...]]]:
    flats, shapes = [], []
    for p in params:
        shapes.append(tuple(p.shape))
        flats.append(p.detach().reshape(-1))
    return torch.cat(flats), shapes

def unpack_to_params(vec: torch.Tensor, params: Iterable[torch.Tensor], shapes: List[Tuple[int,...]]):
    offset = 0
    for p, shp in zip(params, shapes):
        n = int(torch.tensor(shp).prod().item())
        with torch.no_grad():
            p.copy_(vec[offset:offset+n].view(shp))
        offset += n

# ----------------- Loss utils (AMP new API) -----------------
AMP = CFG.amp and (device == "cuda")

def per_example_losses_from_tensors(model, input_ids, labels, attention_mask, amp_enabled: bool):
    amp_ctx = torch.amp.autocast("cuda", enabled=amp_enabled) if (torch.cuda.is_available() and amp_enabled) else nullcontext()
    with amp_ctx:
        out = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        logits = out.logits
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = labels[:, :-1].contiguous()
        shift_mask   = attention_mask[:, :-1].contiguous().float()
        logp = torch.nn.functional.log_softmax(shift_logits, dim=-1)
        nll = torch.nn.functional.nll_loss(
            logp.view(-1, logp.size(-1)),
            shift_labels.view(-1),
            reduction="none"
        ).view(shift_labels.shape)
        nll = nll * shift_mask
        denom = shift_mask.sum(dim=1).clamp_min(1.0)
        per_ex = nll.sum(dim=1) / denom      # [b]
    if not per_ex.requires_grad:
        raise RuntimeError("Chunk losses have no grad; ensure model.train() and no no_grad around forward.")
    return per_ex

def evaluate_ppl(model, loader):
    model.eval()
    total, count = 0.0, 0
    amp_ctx = torch.amp.autocast("cuda", enabled=AMP) if device=="cuda" else nullcontext()
    with torch.no_grad(), amp_ctx:
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            labels    = batch["labels"].to(device)
            attention = batch["attention_mask"].to(device)
            out = model(input_ids=input_ids, labels=labels, attention_mask=attention)
            total += float(out.loss.item()) * input_ids.size(0)
            count += input_ids.size(0)
    mean = total / max(1, count)
    ppl = math.exp(mean) if mean < 50 else float("inf")
    model.train()
    return mean, ppl

# ----------------- DP oracles (chunked forward+backward) -----------------
def dp_recordlevel_oracle(model, batch, clip_C: float, noise_sigma: float, chunk_size: int = 2):
    params = params_that_train(model)
    input_ids = batch["input_ids"].to(device)
    labels    = batch["labels"].to(device)
    attn      = batch["attention_mask"].to(device)
    B = int(input_ids.size(0))
    per_grads = []

    for start in range(0, B, chunk_size):
        end = min(start + chunk_size, B)
        ids_sl  = input_ids[start:end]
        labs_sl = labels[start:end]
        attn_sl = attn[start:end]

        # fresh graph per chunk
        losses_sl = per_example_losses_from_tensors(model, ids_sl, labs_sl, attn_sl, AMP)
        s = int(losses_sl.shape[0])

        for j in range(s):
            model.zero_grad(set_to_none=True)
            losses_sl[j].backward(retain_graph=(j < s-1))
            g_j = [(p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)) for p in params]
            per_grads.append(g_j)

    clipped = []
    for g in per_grads:
        flat, _ = pack_params(g)
        norm = flat.norm(2)
        scale = min(1.0, clip_C / (norm + 1e-12))
        clipped.append([gi * scale for gi in g])

    avg = [sum(gi_list)/B for gi_list in zip(*clipped)]
    std = noise_sigma * clip_C / B
    noisy = [g + torch.randn_like(g) * std for g in avg]
    return noisy, B

def dp_userlevel_oracle(model, batch, clip_C_user: float, noise_sigma: float, chunk_size: int = 2):
    params = params_that_train(model)
    input_ids = batch["input_ids"].to(device)
    labels    = batch["labels"].to(device)
    attn      = batch["attention_mask"].to(device)
    user_ids  = batch["user_id"].to(device)
    B = int(input_ids.size(0))

    per_grads, per_users = [], []

    for start in range(0, B, chunk_size):
        end = min(start + chunk_size, B)
        ids_sl  = input_ids[start:end]
        labs_sl = labels[start:end]
        attn_sl = attn[start:end]
        uid_sl  = user_ids[start:end]

        losses_sl = per_example_losses_from_tensors(model, ids_sl, labs_sl, attn_sl, AMP)
        s = int(losses_sl.shape[0])

        for j in range(s):
            model.zero_grad(set_to_none=True)
            losses_sl[j].backward(retain_graph=(j < s-1))
            g_j = [(p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)) for p in params]
            per_grads.append(g_j)
            per_users.append(int(uid_sl[j].item()))

    uniq_users = sorted(set(per_users))
    user_grad: Dict[int, List[torch.Tensor]] = {u: [torch.zeros_like(p) for p in params] for u in uniq_users}
    for gi, u in zip(per_grads, per_users):
        for j, gij in enumerate(gi):
            user_grad[u][j] += gij

    clipped_users, clip_hits = [], 0
    for u in uniq_users:
        gu = user_grad[u]
        flat, _ = pack_params(gu)
        norm = flat.norm(2)
        scale = min(1.0, clip_C_user / (norm + 1e-12))
        if scale < 1.0: clip_hits += 1
        clipped_users.append([gij * scale for gij in gu])

    U_B = len(uniq_users)
    avg = [sum(gs[j] for gs in clipped_users) / U_B for j in range(len(params))]
    std = noise_sigma * clip_C_user / U_B
    noisy = [g + torch.randn_like(g) * std for g in avg]
    return noisy, U_B, clip_hits / max(1, U_B)

# ----------------- Optimizer steps -----------------
def apply_weight_decay(params, wd):
    if wd == 0: return
    with torch.no_grad():
        for p in params: p.add_(-wd * p)

def step_dpsgd(model, lr, private_grad):
    with torch.no_grad():
        for p, g in zip(params_that_train(model), private_grad):
            p.add_(-lr * g)

def step_dpnsgd(model, lr, private_grad, eps=1e-12):
    gvec = torch.cat([g.view(-1) for g in private_grad])
    norm = gvec.norm(2).clamp_min(eps)
    scale = 1.0 / norm
    with torch.no_grad():
        for p, g in zip(params_that_train(model), private_grad):
            p.add_(-lr * (g * scale))

# DISFOM prox with φ(z)=(ρ̂/2)||z||_1^2
def tau_by_bisection(u: torch.Tensor, rho_hat: float, max_iter: int = 80, tol: float = 1e-10) -> float:
    lo, hi = 0.0, float(u.abs().max().item())
    def rhs(tau: float) -> float:
        return float(torch.nn.functional.relu(u.abs() - tau).sum().item())
    for _ in range(max_iter):
        mid = 0.5*(lo+hi)
        Rmid = mid - rho_hat * rhs(mid)
        if abs(Rmid) <= tol: return mid
        if Rmid > 0: hi = mid
        else:        lo = mid
    return 0.5*(lo+hi)

def prox_disfom_unconstrained(u: torch.Tensor, rho_hat: float) -> torch.Tensor:
    tau = tau_by_bisection(u, rho_hat)
    return torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)

def step_dpdisfom(model, lr, private_grad, rho_hat, box_lo=None, box_hi=None):
    params = params_that_train(model)
    with torch.no_grad():
        xk_vec, shapes = pack_params([p.data for p in params])
        g_vec, _ = pack_params(private_grad)
        u = (-lr) * g_vec
        if box_lo is None and box_hi is None:
            y = prox_disfom_unconstrained(u, rho_hat)
            x_next = xk_vec + y
        else:
            l = torch.full_like(xk_vec, box_lo) if box_lo is not None else torch.full_like(xk_vec, -float("inf"))
            ubox = torch.full_like(xk_vec, box_hi) if box_hi is not None else torch.full_like(xk_vec,  float("inf"))
            l_disp, u_disp = l - xk_vec, ubox - xk_vec
            def rhs_box(tau: float) -> float:
                z = torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)
                z = torch.max(torch.min(z, u_disp), l_disp)
                return float(z.abs().sum().item())
            lo, hi = 0.0, float(u.abs().max().item())
            for _ in range(80):
                mid = 0.5*(lo+hi)
                Rmid = mid - rho_hat * rhs_box(mid)
                if abs(Rmid) <= 1e-10: 
                    tau = mid; break
                if Rmid > 0: hi = mid
                else:        lo = mid
            else:
                tau = 0.5*(lo+hi)
            z = torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)
            y = torch.max(torch.min(z, u_disp), l_disp)
            x_next = xk_vec + y
        unpack_to_params(x_next, params, shapes)

# ----------------- Train loop -----------------
def run_one_method(method: str, train_loader: DataLoader, eval_loader: DataLoader, U_total_users: int):
    model = build_model()
    assert any(p.requires_grad for p in model.parameters()), "No trainable parameters"
    params = params_that_train(model)
    lr = CFG.lr

    acct = RDPAccountant()

    run = None
    if USE_WANDB:
        run = wandb.init(
            project=os.environ["WANDB_PROJECT"],
            entity=os.environ.get("WANDB_ENTITY"),
            name=f"{CFG.privacy_level}-{method}-{int(time.time())}",
            group=f"{CFG.privacy_level}-optimizer-compare",
            config={
                "privacy_level": CFG.privacy_level,
                "base_model": CFG.base_model, "method": method, "epochs": CFG.epochs,
                "clip": CFG.clip_norm, "noise": CFG.noise_multiplier, "delta": CFG.delta,
                "rho_hat": CFG.rho_hat if method=="dpdisfom" else None,
                "lr": lr, "batch_size": CFG.micro_batch, "max_seq_len": CFG.max_seq_len,
                "train_N": len(train_loader.dataset), "eval_N": len(eval_loader.dataset),
                "U_total": U_total_users
            },
            reinit=True,
        )
        wandb.define_metric("train/iter")
        wandb.define_metric("train/loss_iter", step_metric="train/iter")
        wandb.define_metric("epoch")
        wandb.define_metric("metrics/loss_epoch", step_metric="epoch")
        wandb.define_metric("metrics/epsilon", step_metric="epoch")

    iter_losses, global_step = [], 0

    for ep in range(CFG.epochs):
        pbar = tqdm(train_loader, desc=f"{CFG.privacy_level}:{method} | epoch {ep+1}/{CFG.epochs}")
        for batch in pbar:
            for k in batch: batch[k] = batch[k].to(device)

            if CFG.privacy_level == "record":
                g_priv, B = dp_recordlevel_oracle(model, batch, CFG.clip_norm, CFG.noise_multiplier, chunk_size=2)
                sample_rate = CFG.micro_batch / max(1, len(train_loader.dataset))
                clip_frac_users = None
            else:
                g_priv, U_B, clip_frac_users = dp_userlevel_oracle(model, batch, CFG.clip_norm, CFG.noise_multiplier, chunk_size=2)
                sample_rate = U_B / max(1, U_total_users)

            amp_ctx = torch.amp.autocast("cuda", enabled=AMP) if device=="cuda" else nullcontext()
            with torch.no_grad(), amp_ctx:
                out = model(input_ids=batch["input_ids"], labels=batch["labels"], attention_mask=batch["attention_mask"])
                batch_loss = float(out.loss.detach().cpu().item())

            apply_weight_decay(params, CFG.weight_decay)
            if method == "dpsgd":
                step_dpsgd(model, lr, g_priv)
            elif method == "dpdisfom":
                step_dpdisfom(model, lr, g_priv, CFG.rho_hat, CFG.box_lo, CFG.box_hi)
            else:
                raise ValueError(method)

            acct.step(noise_multiplier=CFG.noise_multiplier, sample_rate=sample_rate)

            for p in params: p.grad = None
            del g_priv, out
            gc.collect()
            if torch.cuda.is_available() and global_step % 50 == 0:
                torch.cuda.empty_cache(); torch.cuda.ipc_collect()

            iter_losses.append(batch_loss)
            global_step += 1
            log_dict = {"train/iter": global_step, "train/loss_iter": batch_loss}
            if CFG.privacy_level == "user" and clip_frac_users is not None:
                log_dict["dp/clip_frac_users"] = clip_frac_users
                log_dict["dp/sample_rate_users"] = sample_rate
            pbar.set_postfix(loss=f"{batch_loss:.4f}")
            if run is not None:
                wandb.log(log_dict, step=global_step)

        eval_loss, eval_ppl = evaluate_ppl(model, eval_loader)
        eps_now = acct.get_epsilon(delta=CFG.delta)
        print(f"[{CFG.privacy_level}:{method}] epoch {ep+1}: eval_loss={eval_loss:.4f} ppl={eval_ppl:.2f}  epsilon={eps_now:.3f}")
        if run is not None:
            wandb.log({"epoch": ep+1,
                       "metrics/loss_epoch": eval_loss,
                       "metrics/ppl": eval_ppl,
                       "metrics/epsilon": eps_now}, step=global_step)

    tag = f"{CFG.privacy_level}_{CFG.base_model.replace('/','_')}_{method}"
    outdir = Path(CFG.save_root) / tag
    outdir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(outdir.as_posix())
    if run is not None: run.finish()
    return iter_losses

# ----------------- Run all methods -----------------
log_path = Path(CFG.log_csv)
new_file = not log_path.exists()
loss_traces = {}

with open(log_path, "a", newline="") as f:
    wr = csv.writer(f)
    if new_file:
        wr.writerow(["timestamp","privacy_level","base_model","method","epochs","clip","noise","delta",
                     "train_N","eval_N","U_total","rho_hat","lr","last_eval_loss","last_eval_ppl","last_epsilon"])

    for m in CFG.methods:
        print(f"\n=== Running {CFG.privacy_level}:{m} ===")
        iter_losses = run_one_method(m, train_loader, eval_loader, U_total_users=U_total)
        loss_traces[m] = iter_losses
        wr.writerow([int(time.time()), CFG.privacy_level, CFG.base_model, m, CFG.epochs,
                     CFG.clip_norm, CFG.noise_multiplier, CFG.delta,
                     len(train_loader.dataset), len(eval_loader.dataset), U_total,
                     CFG.rho_hat if m=="dpdisfom" else "",
                     CFG.lr, None, None, None])

print(f"Wrote {CFG.log_csv}")

import matplotlib.pyplot as plt
plt.figure()
for m, losses in loss_traces.items():
    plt.plot(range(1, len(losses)+1), losses, label=m)
plt.xlabel("Iteration"); plt.ylabel("Training loss (per-batch)")
plt.title(f"Loss vs Iteration ({CFG.privacy_level}-level DP)")
plt.legend(); plt.savefig(f"loss_{CFG.privacy_level}_dp.png"); plt.close()


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /users/2/bi000050/.netrc


Device: cuda
Train cols: ['user_id', 'input_ids', 'attention_mask', 'labels']
Val cols: ['user_id', 'input_ids', 'attention_mask', 'labels']
Train N: 4000 Eval N: 800 Users: 778

=== Running user:dpsgd ===
[Model] LoRA disabled. Training full model.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


user:dpsgd | epoch 1/1:   0%|          | 0/1000 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[user:dpsgd] epoch 1: eval_loss=13.3493 ppl=627405.58  epsilon=1.227


0,1
dp/clip_frac_users,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dp/sample_rate_users,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁
metrics/epsilon,▁
metrics/loss_epoch,▁
metrics/ppl,▁
train/iter,▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▇▇▇█████
train/loss_iter,▁▁▁▁▁▂▂▂▃▃▃▃▄▄▄▅▅▆▆▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇█▇█▇██

0,1
dp/clip_frac_users,1.0
dp/sample_rate_users,0.00514
epoch,1.0
metrics/epsilon,1.22679
metrics/loss_epoch,13.34935
metrics/ppl,627405.58314
train/iter,1000.0
train/loss_iter,11.02115



=== Running user:dpdisfom ===
[Model] LoRA disabled. Training full model.


user:dpdisfom | epoch 1/1:   0%|          | 0/1000 [00:00<?, ?it/s]



[user:dpdisfom] epoch 1: eval_loss=2.9055 ppl=18.27  epsilon=1.227


0,1
dp/clip_frac_users,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dp/sample_rate_users,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁
metrics/epsilon,▁
metrics/loss_epoch,▁
metrics/ppl,▁
train/iter,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██
train/loss_iter,▅▅▂▄▄▂▁▁▄▃▃▂▃▃▃▂▂▂▄▄▄▆▄▂▅▃▃▅▅▂▁▃▂▃▂▁▅▃▄█

0,1
dp/clip_frac_users,1.0
dp/sample_rate_users,0.00514
epoch,1.0
metrics/epsilon,1.22665
metrics/loss_epoch,2.90545
metrics/ppl,18.27348
train/iter,1000.0
train/loss_iter,3.28609


Wrote dp_llm_compare_userlvl.csv


In [3]:
# %% imports & setup
import os, math, time, random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Iterable, Optional, Dict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, PeftModel
from opacus.accountants.rdp import RDPAccountant

import evaluate
rouge = evaluate.load("rouge")
bleu  = evaluate.load("sacrebleu")

import wandb

# --- W&B env (set your key via env var for security) ---
os.environ.setdefault("WANDB_ENTITY",  "bi000050-university-of-minnesota")
os.environ.setdefault("WANDB_PROJECT", "userdp_ft")
os.environ.setdefault("WANDB_MODE",    "online")  # or "offline"
wandb_api_key = os.environ.get("WANDB_API_KEY", "0d32ee09cbe7bae59ddac577f978a4f31cc4b559")
wandb.login(key=wandb_api_key)
USE_WANDB = True

# (optional) reduce fragmentation
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


[34m[1mwandb[0m: Currently logged in as: [33mbi000050[0m ([33mbi000050-university-of-minnesota[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /users/2/bi000050/.netrc


Device: cuda


In [4]:
# %% config (OOM-safe defaults; you can scale up later)
@dataclass
class CFG:
    # choose scenario: "matched_epsilon_record" | "matched_epsilon_user" | "noise_stress"
    scenario: str = "matched_epsilon_record"

    # model / LoRA
    base_model: str = "gpt2"   # switch to "gpt2" after validating
    use_8bit: bool = False
    use_lora: bool = False            # LoRA drastically cuts memory
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_targets: Tuple[str, ...] = ("c_attn", "c_fc", "c_proj")

    # data
    dataset_name: str = "yahma/alpaca-cleaned"
    max_train_samples: int = 2000
    max_eval_samples: int = 500
    max_seq_len: int = 256          # reduce if memory tight (e.g., 192 or 128)

    # training
    epochs: int = 10
    micro_batch: int = 4            # small to allow per-example grads
    lr_sgd: float = 1e-4
    lr_disfom: float = 1e-4
    weight_decay: float = 0.0

    # DP privacy knobs
    clip_norm: float = 5
    noise_multiplier: float = 3   # used in noise_stress scenario
    delta: float = 1e-5
    target_epsilon: float = 1     # used in matched_epsilon_* scenarios

    # DISFOM
    rho_hat: float = 1
    box_lo: Optional[float] = None
    box_hi: Optional[float] = None

    # noise_stress sweep
    # stress_sigmas: Tuple[float, ...] = (0.6, 0.8, 1.0, 1.2)

    # IO
    save_root: str = "dp_llm_runs"
    log_csv: str = "dp_llm_compare.csv"

CFG = CFG()
Path(CFG.save_root).mkdir(exist_ok=True)
CFG


CFG(scenario='matched_epsilon_record', base_model='gpt2', use_8bit=False, use_lora=False, lora_r=16, lora_alpha=32, lora_dropout=0.05, lora_targets=('c_attn', 'c_fc', 'c_proj'), dataset_name='yahma/alpaca-cleaned', max_train_samples=2000, max_eval_samples=500, max_seq_len=256, epochs=10, micro_batch=4, lr_sgd=0.0001, lr_disfom=0.0001, weight_decay=0.0, clip_norm=5, noise_multiplier=3, delta=1e-05, target_epsilon=1, rho_hat=1, box_lo=None, box_hi=None, save_root='dp_llm_runs', log_csv='dp_llm_compare.csv')

In [5]:
# %% dataset, tokenization, optional synthetic user IDs (for user-level DP)
def format_alpaca(e):
    instr = e.get("instruction",""); inp = e.get("input",""); out = e.get("output","")
    if inp:
        prompt = f"### Instruction:\n{instr}\n\n### Input:\n{inp}\n\n### Response:\n"
    else:
        prompt = f"### Instruction:\n{instr}\n\n### Response:\n"
    return {"text": prompt + out}

def build_dataset_and_loaders(add_user_ids: bool = False):
    raw = load_dataset(CFG.dataset_name)
    raw = raw.map(format_alpaca, remove_columns=raw["train"].column_names)

    def take_first(ds, n): return ds.select(range(min(n, len(ds))))
    train_raw = take_first(raw["train"], CFG.max_train_samples)
    eval_raw  = take_first(raw["train"].select(range(CFG.max_train_samples, len(raw["train"]))), CFG.max_eval_samples)
    dataset = DatasetDict({"train": train_raw, "validation": eval_raw})

    total_users = None
    if add_user_ids:
        def add_synth_user(ds, lam=5, seed=42):
            rng = np.random.default_rng(seed)
            n = len(ds); user_ids=[]; uid=0; i=0
            while i<n:
                k = max(1, int(rng.poisson(lam)))
                for _ in range(k):
                    if i>=n: break
                    user_ids.append(uid); i+=1
                uid+=1
            return ds.add_column("user_id", user_ids[:n]), uid
        dataset["train"], total_users = add_synth_user(dataset["train"], lam=5, seed=42)
        dataset["validation"], _ = add_synth_user(dataset["validation"], lam=5, seed=43)

    tok = AutoTokenizer.from_pretrained(CFG.base_model, use_fast=True)
    if tok.pad_token is None: tok.pad_token = tok.eos_token

    def tok_map(batch):
        x = tok(batch["text"], padding="max_length", truncation=True, max_length=CFG.max_seq_len)
        x["labels"] = x["input_ids"].copy()
        if "user_id" in batch:
            x["user_id"] = batch["user_id"]
        return x

    tokd = dataset.map(tok_map, batched=True)

    base_collator = DataCollatorForLanguageModeling(tok, mlm=False)

    class KeepUserID:
        def __init__(self, base): self.base = base
        def __call__(self, features):
            keep_uid = "user_id" in features[0]
            uids = [f["user_id"] for f in features] if keep_uid else None
            feats = [{k:v for k,v in f.items() if k not in ("text","user_id")} for f in features]
            batch = self.base(feats)
            if keep_uid:
                batch["user_id"] = torch.tensor(uids, dtype=torch.long)
            return batch

    collator = KeepUserID(base_collator)

    train_loader = DataLoader(tokd["train"], batch_size=CFG.micro_batch,
                              shuffle=True, drop_last=True, collate_fn=collator)
    eval_loader  = DataLoader(tokd["validation"], batch_size=CFG.micro_batch,
                              shuffle=False, drop_last=False, collate_fn=collator)
    return dataset, train_loader, eval_loader, tok, total_users

need_user = (CFG.scenario == "matched_epsilon_user")
dataset, train_loader, eval_loader, tokenizer, TOTAL_USERS = build_dataset_and_loaders(add_user_ids=need_user)
print("Train N:", len(train_loader.dataset), "Eval N:", len(eval_loader.dataset), "Users:", TOTAL_USERS)


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Train N: 2000 Eval N: 500 Users: None


In [6]:
# %% model builder (LoRA on/off + optional gradient checkpointing)
def build_model():
    load_kwargs = {}
    if CFG.use_8bit:
        load_kwargs.update(dict(load_in_8bit=True, device_map="auto"))
    model = AutoModelForCausalLM.from_pretrained(CFG.base_model, **load_kwargs)
    if CFG.use_lora:
        lcfg = LoraConfig(r=CFG.lora_r, lora_alpha=CFG.lora_alpha, lora_dropout=CFG.lora_dropout,
                          target_modules=list(CFG.lora_targets), bias="none", task_type="CAUSAL_LM")
        model = get_peft_model(model, lcfg)
        print("[Model] LoRA enabled.")
    else:
        for p in model.parameters(): p.requires_grad = True
        print("[Model] LoRA disabled (full fine-tune).")

    # (optional) gradient checkpointing for larger models
    try:
        model.config.use_cache = False
        model.gradient_checkpointing_enable()
        print("[Model] Gradient checkpointing enabled.")
    except Exception as e:
        print("[Model] Gradient checkpointing not available:", e)

    model.to(device).train()
    return model

def params_that_train(model):
    return [p for p in model.parameters() if p.requires_grad]


In [7]:
# %% low-memory helpers: no GPU cat, CPU clip/avg/noise
def zero_like_params(params: Iterable[torch.Tensor], device="cpu"):
    return [torch.zeros_like(p, device=device, dtype=p.dtype) for p in params]

def l2norm_list(params_list: List[torch.Tensor]) -> float:
    s = 0.0
    for t in params_list:
        s += float(t.float().pow(2).sum().item())
    return s ** 0.5

def scale_inplace(lst: List[torch.Tensor], scale: float):
    for i in range(len(lst)):
        lst[i].mul_(scale)

def add_inplace(dst: List[torch.Tensor], src: List[torch.Tensor], alpha: float = 1.0):
    for i in range(len(dst)):
        dst[i].add_(src[i], alpha=alpha)

def clone_to_cpu(lst: List[torch.Tensor]) -> List[torch.Tensor]:
    return [t.detach().to("cpu", non_blocking=True).clone() for t in lst]


In [8]:
# %% DP gradient oracles (record-level & user-level) -- MEMORY SAFE
def dp_gradient_oracle_record(model, batch, clip_C: float, noise_sigma: float) -> List[torch.Tensor]:
    """
    Per-example grads via re-forward (no retain_graph), clip/avg/noise on CPU.
    """
    model_device = next(model.parameters()).device
    params = params_that_train(model)

    ids  = batch["input_ids"]
    labs = batch["labels"]
    amsk = batch["attention_mask"]
    B = ids.size(0)

    sum_grad_cpu = zero_like_params(params, device="cpu")

    for i in range(B):
        model.zero_grad(set_to_none=True)
        out = model(
            input_ids=ids[i:i+1].to(model_device, non_blocking=True),
            labels=labs[i:i+1].to(model_device, non_blocking=True),
            attention_mask=amsk[i:i+1].to(model_device, non_blocking=True),
        )
        out.loss.backward()
        g_i_gpu = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
        g_i = clone_to_cpu(g_i_gpu)

        norm = l2norm_list(g_i)
        scale = min(1.0, clip_C / (norm + 1e-12))
        scale_inplace(g_i, scale)

        add_inplace(sum_grad_cpu, g_i, alpha=1.0 / B)
        model.zero_grad(set_to_none=True)

    std = noise_sigma * clip_C / B
    noisy_avg_cpu = [g + torch.randn_like(g, device="cpu") * std for g in sum_grad_cpu]
    private_grad = [g.to(model_device, non_blocking=True) for g in noisy_avg_cpu]
    return private_grad

def dp_gradient_oracle_user(model, batch, clip_C_user: float, noise_sigma: float):
    """
    User-level DP: sum by user on CPU, clip per-user, average users, add noise (CPU).
    Returns (private_grad_on_device, U_B).
    """
    model_device = next(model.parameters()).device
    params = params_that_train(model)

    ids  = batch["input_ids"]
    labs = batch["labels"]
    amsk = batch["attention_mask"]
    uids = batch["user_id"]
    B = ids.size(0)

    uniq = torch.unique(uids).tolist()
    user_sum = {int(u): zero_like_params(params, device="cpu") for u in uniq}

    for i in range(B):
        u = int(uids[i].item())
        model.zero_grad(set_to_none=True)
        out = model(
            input_ids=ids[i:i+1].to(model_device, non_blocking=True),
            labels=labs[i:i+1].to(model_device, non_blocking=True),
            attention_mask=amsk[i:i+1].to(model_device, non_blocking=True),
        )
        out.loss.backward()
        g_i_gpu = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
        g_i = clone_to_cpu(g_i_gpu)
        add_inplace(user_sum[u], g_i, alpha=1.0)
        model.zero_grad(set_to_none=True)

    clipped_users = []
    for u in uniq:
        g_u = user_sum[u]
        nrm = l2norm_list(g_u)
        scale = min(1.0, clip_C_user / (nrm + 1e-12))
        scale_inplace(g_u, scale)
        clipped_users.append(g_u)

    U_B = len(uniq)
    avg_cpu = zero_like_params(params, device="cpu")
    for g_u in clipped_users:
        add_inplace(avg_cpu, g_u, alpha=1.0 / U_B)

    std = noise_sigma * clip_C_user / U_B
    noisy_avg_cpu = [g + torch.randn_like(g, device="cpu") * std for g in avg_cpu]
    private_grad = [g.to(model_device, non_blocking=True) for g in noisy_avg_cpu]
    return private_grad, U_B


In [9]:
# %% updates: weight decay, DP-SGD, DP-DISFOM (prox on displacement, CPU-safe)
def apply_weight_decay(params, wd):
    if wd == 0: return
    with torch.no_grad():
        for p in params: p.add_(-wd * p)

def step_dpsgd(model, lr, private_grad):
    with torch.no_grad():
        for p, g in zip(params_that_train(model), private_grad):
            p.add_(-lr * g)

# DISFOM prox φ(z)=(ρ̂/2)||z||_1^2 applied to displacement u = -lr·ĝ
def tau_by_bisection(u: torch.Tensor, rho_hat: float, max_iter: int = 80, tol: float = 1e-12) -> float:
    lo, hi = 0.0, float(u.abs().max().item())
    def rhs(tau: float) -> float:
        return float(torch.nn.functional.relu(u.abs() - tau).sum().item())
    for _ in range(max_iter):
        mid = 0.5*(lo+hi)
        R = mid - rho_hat * rhs(mid)
        if abs(R) <= tol: return mid
        if R > 0: hi = mid
        else:     lo = mid
    return 0.5*(lo+hi)

def step_dpdisfom(model, lr, private_grad, rho_hat, box_lo=None, box_hi=None):
    """
    CPU-friendly prox: flatten on CPU to avoid peak GPU memory.
    """
    params = params_that_train(model)
    dev = next(model.parameters()).device
    with torch.no_grad():
        # move to CPU
        xk_list = [p.data.detach().cpu().clone() for p in params]
        g_list  = [g.detach().cpu().clone() for g in private_grad]

        xk_vec = torch.cat([t.view(-1) for t in xk_list])
        g_vec  = torch.cat([t.view(-1) for t in g_list])
        u = (-lr) * g_vec

        if box_lo is None and box_hi is None:
            tau = tau_by_bisection(u, rho_hat)
            y = torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)
            x_next = xk_vec + y
        else:
            l = torch.full_like(xk_vec, box_lo) if box_lo is not None else torch.full_like(xk_vec, -float("inf"))
            ubox = torch.full_like(xk_vec, box_hi) if box_hi is not None else torch.full_like(xk_vec,  float("inf"))
            l_disp, u_disp = l - xk_vec, ubox - xk_vec
            def rhs_box(tau: float) -> float:
                z = torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)
                z = torch.max(torch.min(z, u_disp), l_disp)
                return float(z.abs().sum().item())
            lo, hi = 0.0, float(u.abs().max().item())
            for _ in range(80):
                mid = 0.5*(lo+hi)
                R = mid - rho_hat * rhs_box(mid)
                if abs(R) <= 1e-12: tau = mid; break
                if R > 0: hi = mid
                else:     lo = mid
            else:
                tau = 0.5*(lo+hi)
            z = torch.sign(u) * torch.nn.functional.relu(u.abs() - tau)
            y = torch.max(torch.min(z, u_disp), l_disp)
            x_next = xk_vec + y

        # write back to device params
        offset = 0
        for p in params:
            n = p.numel()
            chunk = x_next[offset:offset+n].view(p.shape).to(dev)
            p.copy_(chunk)
            offset += n


In [10]:
# %% eval: PPL + ROUGE/BLEU
def eval_perplexity(model, eval_loader):
    model.eval(); total, count = 0.0, 0
    with torch.no_grad():
        for batch in eval_loader:
            for k in batch: batch[k] = batch[k].to(device)
            out = model(**batch)
            total += float(out.loss.item()) * batch["input_ids"].size(0)
            count += batch["input_ids"].size(0)
    mean = total / max(1, count)
    ppl = math.exp(mean) if mean < 50 else float("inf")
    model.train()
    return mean, ppl

def split_prompt_and_ref(txt: str):
    key = "### Response:\n"
    if key in txt:
        i = txt.index(key)
        return txt[:i+len(key)], txt[i+len(key):].strip()
    return txt, ""

def collect_eval_prompts_and_refs(dataset_validation, max_items=300):
    prompts, refs = [], []
    m = min(len(dataset_validation), max_items)
    for i in range(m):
        txt = dataset_validation[i]["text"]
        p, r = split_prompt_and_ref(txt)
        prompts.append(p); refs.append(r)
    return prompts, refs

@torch.no_grad()
def generate_responses(model, tok, prompts, max_new_tokens=128, temperature=0.0, top_p=1.0):
    outs=[]
    for p in prompts:
        inputs = tok(p, return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_new_tokens=max_new_tokens,
                             do_sample=(temperature>0), temperature=temperature, top_p=top_p,
                             pad_token_id=tok.eos_token_id)
        full = tok.decode(ids[0], skip_special_tokens=True)
        outs.append(full[len(p):].strip() if full.startswith(p) else full.strip())
    return outs

def eval_text_metrics(preds, refs):
    r = rouge.compute(predictions=preds, references=refs)
    b = bleu.compute(predictions=preds, references=[[x] for x in refs])
    return {**r, "bleu": b["score"]}


In [11]:
# %% privacy helper: σ for target ε (record-style search)
def sigma_for_epsilon_record(target_eps, delta, q, T, lo=0.3, hi=5.0):
    for _ in range(25):
        mid = 0.5*(lo+hi)
        acct = RDPAccountant()
        for _ in range(T): acct.step(noise_multiplier=mid, sample_rate=q)
        eps = acct.get_epsilon(delta=delta)
        if eps > target_eps: lo = mid
        else:                hi = mid
    return hi


In [12]:
# %% training loop (scenario-aware)
def train_once(method: str, sigma: float, train_loader, eval_loader,
               privacy_mode: str, total_users: Optional[int] = None) -> Dict[str, float]:
    assert privacy_mode in ("record","user")
    model = build_model()
    params = params_that_train(model)
    lr = CFG.lr_disfom if method == "dpdisfom" else CFG.lr_sgd

    acct = RDPAccountant()

    run = None
    if USE_WANDB:
        run = wandb.init(
            project=os.environ["WANDB_PROJECT"],
            entity=os.environ.get("WANDB_ENTITY"),
            name=f"{method}-{privacy_mode}-sigma{sigma}-{int(time.time())}",
            group=f"{CFG.scenario}",
            config={
                "base_model": CFG.base_model, "method": method,
                "epochs": CFG.epochs, "clip": CFG.clip_norm,
                "sigma": sigma, "delta": CFG.delta,
                "rho_hat": CFG.rho_hat if method=="dpdisfom" else None,
                "lr": lr, "batch": CFG.micro_batch, "seq_len": CFG.max_seq_len,
                "privacy_mode": privacy_mode, "use_lora": CFG.use_lora
            },
            reinit=True,
        )
        wandb.define_metric("train/iter")
        wandb.define_metric("train/loss_iter", step_metric="train/iter")
        wandb.define_metric("epoch")
        wandb.define_metric("metrics/epsilon", step_metric="epoch")
        wandb.define_metric("metrics/loss_epoch", step_metric="epoch")

    global_step = 0
    for ep in range(CFG.epochs):
        pbar = tqdm(train_loader, desc=f"{method} [{privacy_mode}] σ={sigma} | epoch {ep+1}/{CFG.epochs}")
        for batch in pbar:
            for k in batch: batch[k] = batch[k].to(device)

            if privacy_mode == "record":
                g_priv = dp_gradient_oracle_record(model, batch, CFG.clip_norm, sigma)
                sample_rate = CFG.micro_batch / len(train_loader.dataset)
            else:
                g_priv, U_B = dp_gradient_oracle_user(model, batch, CFG.clip_norm, sigma)
                if total_users is None or total_users == 0:
                    raise ValueError("total_users must be provided for user-level DP.")
                sample_rate = U_B / total_users

            # pre-update loss for logging (approx same step)
            with torch.no_grad():
                out = model(input_ids=batch["input_ids"], labels=batch["labels"], attention_mask=batch["attention_mask"])
                batch_loss = float(out.loss.detach().cpu().item())

            apply_weight_decay(params, CFG.weight_decay)
            if method == "dpsgd":
                step_dpsgd(model, lr, g_priv)
            elif method == "dpdisfom":
                step_dpdisfom(model, lr, g_priv, CFG.rho_hat, CFG.box_lo, CFG.box_hi)
            else:
                raise ValueError(method)

            acct.step(noise_multiplier=sigma, sample_rate=sample_rate)

            global_step += 1
            pbar.set_postfix(loss=f"{batch_loss:.4f}")
            if run is not None:
                wandb.log({"train/iter": global_step, "train/loss_iter": batch_loss}, step=global_step)

        eval_loss, eval_ppl = eval_perplexity(model, eval_loader)
        eps_now = acct.get_epsilon(delta=CFG.delta)
        if run is not None:
            wandb.log({"epoch": ep+1,
                       "metrics/loss_epoch": eval_loss,
                       "metrics/ppl": eval_ppl,
                       "metrics/epsilon": eps_now}, step=global_step)

    # save
    tag = f"{CFG.base_model.replace('/','_')}_{method}_{privacy_mode}_sigma{sigma}"
    outdir = Path(CFG.save_root)/tag
    outdir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(outdir.as_posix())

    # final eval
    prompts, refs = collect_eval_prompts_and_refs(dataset["validation"], max_items=300)
    model.eval()
    preds = generate_responses(model, tokenizer, prompts, max_new_tokens=128, temperature=0.0)
    text_metrics = eval_text_metrics(preds, refs)
    ppl_loss, ppl = eval_perplexity(model, eval_loader)

    if run is not None:
        wandb.log({
            "final/rouge1": text_metrics["rouge1"],
            "final/rouge2": text_metrics["rouge2"],
            "final/rougeL": text_metrics["rougeL"],
            "final/bleu":   text_metrics["bleu"],
            "final/ppl_loss": ppl_loss,
            "final/ppl": ppl
        })
        run.finish()

    return {"model_dir": outdir.as_posix(),
            "rougeL": text_metrics["rougeL"],
            "bleu": text_metrics["bleu"],
            "ppl": ppl, "ppl_loss": ppl_loss}


In [13]:
# %% orchestrators
def run_matched_epsilon_record():
    print("== Matched-ε (Record-level DP) ==")
    steps_per_epoch = len(train_loader)
    T = CFG.epochs * steps_per_epoch
    q = CFG.micro_batch / len(train_loader.dataset)
    sigma = sigma_for_epsilon_record(CFG.target_epsilon, CFG.delta, q, T, lo=0.3, hi=5.0)
    print(f"ε_target={CFG.target_epsilon} -> σ≈{sigma:.3f} (q={q:.6f}, T={T})")

    res = {}
    for method in ("dpsgd", "dpdisfom"):
    # for method in ("dpdisfom",):
        res[method] = train_once(method, sigma, train_loader, eval_loader, privacy_mode="record")
        print(method, res[method])
    return res, sigma

def run_matched_epsilon_user():
    assert TOTAL_USERS is not None, "Rebuild dataset with add_user_ids=True."
    print("== Matched-ε (User-level DP) ==")
    steps_per_epoch = len(train_loader)
    T = CFG.epochs * steps_per_epoch
    # expected distinct users per batch ~ micro_batch (Poisson(5) gives many-user batches); conservative:
    q_user = min(1.0, CFG.micro_batch / max(1, TOTAL_USERS))
    sigma = sigma_for_epsilon_record(CFG.target_epsilon, CFG.delta, q_user, T, lo=0.3, hi=5.0)
    print(f"ε_target={CFG.target_epsilon} -> σ≈{sigma:.3f} (q_user≈{q_user:.6f}, T={T}, users={TOTAL_USERS})")

    res = {}
    for method in ("dpsgd", "dpdisfom"):
    # for method in ("dpdisfom",):
        res[method] = train_once(method, sigma, train_loader, eval_loader, privacy_mode="user", total_users=TOTAL_USERS)
        print(method, res[method])
    return res, sigma

def run_noise_stress():
    print("== Noise-Stress (σ sweep) ==")
    table = []
    for sigma in CFG.stress_sigmas:
        for method in ("dpsgd", "dpdisfom"):
            res = train_once(method, sigma, train_loader, eval_loader, privacy_mode="record")
            print(f"sigma={sigma:.2f}", method, res)
            table.append({"sigma": sigma, "method": method, **res})
    return table


In [None]:
# %% execute chosen scenario
if CFG.scenario == "matched_epsilon_record":
    results, sigma_star = run_matched_epsilon_record()
elif CFG.scenario == "matched_epsilon_user":
    results, sigma_star = run_matched_epsilon_user()
elif CFG.scenario == "noise_stress":
    stress_table = run_noise_stress()
else:
    raise ValueError("Unknown CFG.scenario")


== Matched-ε (Record-level DP) ==




ε_target=1 -> σ≈1.005 (q=0.002000, T=5000)
[Model] LoRA disabled (full fine-tune).
[Model] Gradient checkpointing enabled.


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


dpsgd [record] σ=1.00489399433136 | epoch 1/10:   0%|          | 0/500 [00:00<?, ?it/s]

