In [2]:
!nvidia-smi

Mon Jul 21 14:14:13 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   43C    P0              37W / 250W |  13436MiB / 16384MiB |      8%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [3]:
import torch

# Device selection: pin this process to GPU index DEVICE_NUM when CUDA is
# available, otherwise run on the CPU and mark DEVICE_NUM as -1.
DEVICE_NUM = 6       # GPU index used by this notebook
ADDITIONAL_GPU = 1   # NOTE(review): not referenced by any visible cell

has_cuda = torch.cuda.is_available()
if not has_cuda:
    device = torch.device("cpu")
    DEVICE_NUM = -1
else:
    # After set_device, a bare "cuda" device string resolves to this GPU.
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")

print(f"INFO: Using device - {device}:{DEVICE_NUM}")

INFO: Using device - cuda:6


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np
import pandas as pd
from scipy import stats
import time, copy, itertools, os, random

# ===================================================================
# 0. Seeding for reproducibility
# ===================================================================
def set_seed(seed):
    """Seed every RNG this notebook uses and force deterministic cuDNN.

    Seeds Python's `random`, NumPy's legacy global RNG, and PyTorch on the CPU,
    the current CUDA device and all CUDA devices.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

SEED = 42
set_seed(SEED)
print(f"Global random seed set to {SEED}")

# ===================================================================
# 1. Experiment configuration
# ===================================================================
CONFIG = dict(
    run_training=True,              # NOTE(review): not read by any visible code
    model_save_dir="saved_models",  # cache dir for models and ES partitions
    num_runs=3,                     # independent repetitions (mean ± 95% CI)
    epochs=30,                      # original / retrained model training epochs
    unlearn_epochs=10,              # epochs for every unlearning method
    batch_size=256,
    lr=0.1,                         # SGD lr for original / retrained models
    unlearn_lr=0.01,                # lr for most unlearning methods
    unlearn_lr_neggrad=1e-4,        # NegGrad uses its own, smaller step
    momentum=0.9,
    weight_decay=5e-4,
    forget_set_size=3000,           # samples per ES partition
    device="cuda" if torch.cuda.is_available() else "cpu",
    l1_lambda=1e-5,                 # L1 penalty weight for l1-sparse
    neggrad_plus_alpha=0.2,         # weight of the forget-loss term in NegGrad+
    salun_sparsity=0.5,             # fraction of parameters frozen by SalUn's mask
    scrub_alpha=0.5,                # retain-CE vs forget-KL mixing weight in SCRUB
)
print(f"Using device: {CONFIG['device']}")

# ===================================================================
# 2. Model and data helpers
# ===================================================================
def get_model():
    """Return a freshly initialised (no pretrained weights) 10-class
    torchvision ResNet-18 placed on CONFIG["device"]."""
    net = models.resnet18(weights=None, num_classes=10)
    return net.to(CONFIG["device"])

def train_model(model, train_loader, epochs, lr, is_unlearning=False):
    """Train `model` in place with SGD on cross-entropy.

    Args:
        model: network to optimise; its parameters are updated in place.
        train_loader: iterable of (inputs, targets) batches.
        epochs: number of passes over `train_loader`.
        lr: SGD learning rate. Momentum and weight decay come from CONFIG.
        is_unlearning: if True the learning rate is held constant. Otherwise a
            CosineAnnealingLR(T_max=epochs) schedule is stepped once per epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr,
                          momentum=CONFIG["momentum"],
                          weight_decay=CONFIG["weight_decay"])
    if is_unlearning:
        scheduler = None
    else:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    model.train()
    for epoch in range(1, epochs + 1):
        started = time.time()
        for inputs, targets in train_loader:
            inputs = inputs.to(CONFIG["device"])
            targets = targets.to(CONFIG["device"])
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        if scheduler is not None:
            scheduler.step()
        print(f"    Epoch {epoch}/{epochs} completed in {time.time()-started:.2f}s")

def evaluate_model(model, loader):
    """Return the top-1 accuracy of `model` on `loader`, in percent.

    Switches the model to eval mode and runs without autograd. An empty
    loader raises ZeroDivisionError.
    """
    model.eval()
    n_seen, n_right = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(CONFIG["device"])
            targets = targets.to(CONFIG["device"])
            hits = model(inputs).argmax(1).eq(targets)
            n_seen += targets.size(0)
            n_right += hits.sum().item()
    return 100 * n_right / n_seen

# ===================================================================
# 3. ES (Entanglement Score) partitioning
# ===================================================================
def create_es_partitions(original_model, train_dataset, current_seed):
    """Split training indices into three forget-set candidates by embedding
    distance from the global mean embedding.

    Embeddings are the outputs of every child of `original_model` except the
    last one (for the torchvision ResNet-18 used here, everything up to and
    including global average pooling), flattened per sample. Samples are
    ranked by squared L2 distance to the mean embedding, furthest first.

    Args:
        original_model: trained network. Building the extractor switches its
            shared submodules to eval mode; their original train/eval flags
            are restored before returning.
        train_dataset: dataset whose order defines the returned indices.
            NOTE(review): if it applies random augmentation, the embeddings
            and hence the partitions are noisy and not reproducible.
        current_seed: seed for the DataLoader generator / worker init.

    Returns:
        dict mapping "Low ES", "Medium ES", "High ES" to numpy index arrays
        covering ranks [0, fs), [fs, 2fs) and [2fs, 3fs), where
        fs = CONFIG["forget_set_size"].
    """
    print("\nCreating ES partitions...")
    t0 = time.time()
    # The Sequential shares its child modules with original_model, so .eval()
    # below flips them too. Remember each flag to undo that side effect.
    saved_modes = [(mod, mod.training) for mod in original_model.modules()]
    extractor = nn.Sequential(*list(original_model.children())[:-1]).eval()
    g = torch.Generator(); g.manual_seed(current_seed)
    loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=False,
                        worker_init_fn=lambda i: set_seed(current_seed+i), generator=g)
    emb = []
    try:
        with torch.no_grad():
            for i, (x, _) in enumerate(loader):
                if (i+1) % 40 == 0:
                    print(f"      Batch {i+1}/{len(loader)}")
                # flatten(1) keeps the batch dimension even for a final batch of
                # size 1, where .squeeze() would drop it and break torch.cat.
                emb.append(extractor(x.to(CONFIG["device"])).flatten(1).cpu())
    finally:
        for mod, was_training in saved_modes:
            mod.training = was_training
    emb = torch.cat(emb, 0)
    dists = ((emb - emb.mean(0))**2).sum(1)
    idx = dists.argsort(descending=True).numpy()
    fs = CONFIG["forget_set_size"]
    # NOTE(review): all three tiers are consecutive slices from the *furthest*
    # end of the ranking, so "High ES" (ranks 2fs..3fs) is still far from the
    # mean rather than the samples closest to it. Confirm this matches the
    # intended ES construction.
    parts = {"Low ES": idx[:fs], "Medium ES": idx[fs:2*fs], "High ES": idx[2*fs:3*fs]}
    print(f"ES partitions created in {time.time()-t0:.2f}s.")
    return parts

# ===================================================================
# 4. Unlearning algorithms
# ===================================================================
class RelabelDataset(Dataset):
    """Wrap an (img, label) dataset and replace each label with a fixed,
    random *wrong* class.

    Each sample i gets an offset drawn once from [1, num_classes - 1]. Its new
    label is (true_label + offset) % num_classes. This is uniform over the
    other classes and never equals the true label. Unlike re-drawing on a
    collision inside __getitem__, it stays the same across epochs.

    Args:
        ds: indexable dataset returning (img, label) with an integer label.
        num_classes: number of classes; must be >= 2 for a wrong label to exist.
    """
    def __init__(self, ds, num_classes=10):
        if num_classes < 2:
            raise ValueError(f"num_classes must be >= 2 to pick a wrong label, got {num_classes}")
        self.ds = ds
        self.n = num_classes
        # Drawn once from torch's global RNG, so it is seeded by set_seed().
        self.offsets = torch.randint(1, self.n, (len(ds),)).tolist()

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        img, lbl = self.ds[i]
        return img, (int(lbl) + self.offsets[i]) % self.n

def unlearn_finetune(orig, retain_loader, cfg):
    """Fine-tune baseline: keep training a copy of `orig` on the retain set only.

    Runs train_model for cfg["unlearn_epochs"] epochs at the constant rate
    cfg["unlearn_lr"] (no scheduler). `orig` itself is not modified.
    """
    student = copy.deepcopy(orig)
    train_model(student, retain_loader,
                epochs=cfg["unlearn_epochs"], lr=cfg["unlearn_lr"],
                is_unlearning=True)
    return student

def unlearn_neggrad(orig, forget_loader, cfg):
    """NegGrad: gradient *ascent* on the forget set.

    Minimises the negated cross-entropy over `forget_loader` with plain SGD
    (lr = cfg["unlearn_lr_neggrad"], no momentum) for cfg["unlearn_epochs"]
    epochs. Returns the updated copy; `orig` is not modified.
    """
    model = copy.deepcopy(orig)
    ce = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=cfg["unlearn_lr_neggrad"])
    model.train()
    for _epoch in range(cfg["unlearn_epochs"]):
        for inputs, targets in forget_loader:
            inputs, targets = inputs.to(cfg["device"]), targets.to(cfg["device"])
            sgd.zero_grad()
            ascent_loss = -ce(model(inputs), targets)
            ascent_loss.backward()
            sgd.step()
    return model

def unlearn_l1_sparse(orig, retain_loader, cfg):
    """l1-sparse unlearning: fine-tune a copy of `orig` on the retain set with
    an L1 penalty on all parameters.

    Per-batch loss: CE(retain) + cfg["l1_lambda"] * sum_p ||p||_1, optimised
    with SGD (lr = cfg["unlearn_lr"], momentum = cfg["momentum"]) for
    cfg["unlearn_epochs"] epochs. `orig` is not modified.

    Momentum is read from `cfg` like every other hyper-parameter here. The
    global CONFIG value is used only as a fallback when `cfg` has no
    "momentum" key.
    """
    m = copy.deepcopy(orig)
    crit = nn.CrossEntropyLoss()
    momentum = cfg["momentum"] if "momentum" in cfg else CONFIG["momentum"]
    opt = optim.SGD(m.parameters(), lr=cfg["unlearn_lr"], momentum=momentum)
    m.train()
    for _ in range(cfg["unlearn_epochs"]):
        for x, y in retain_loader:
            x, y = x.to(cfg["device"]), y.to(cfg["device"])
            opt.zero_grad()
            l1 = sum(p.abs().sum() for p in m.parameters())
            loss = crit(m(x), y) + cfg["l1_lambda"] * l1
            loss.backward()
            opt.step()
    return m

def unlearn_neggrad_plus(orig, retain_loader, forget_loader, cfg):
    """NegGrad+: descend on the retain loss while ascending on the forget loss.

    Each forget batch is paired with the next retain batch. The retain loader
    is wrapped in itertools.cycle, restarted every epoch. A copy of `orig`
    minimises
        CE(retain) - cfg["neggrad_plus_alpha"] * CE(forget)
    with plain SGD at cfg["unlearn_lr"] for cfg["unlearn_epochs"] epochs.
    `orig` is not modified.
    """
    model = copy.deepcopy(orig)
    ce = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=cfg["unlearn_lr"])
    model.train()
    for _epoch in range(cfg["unlearn_epochs"]):
        # zip pulls the forget batch first, then the retain batch, and stops as
        # soon as the forget loader is exhausted.
        paired = zip(forget_loader, itertools.cycle(retain_loader))
        for (fx, fy), (rx, ry) in paired:
            rx, ry = rx.to(cfg["device"]), ry.to(cfg["device"])
            fx, fy = fx.to(cfg["device"]), fy.to(cfg["device"])
            sgd.zero_grad()
            retain_term = ce(model(rx), ry)
            forget_term = ce(model(fx), fy)
            (retain_term - cfg["neggrad_plus_alpha"] * forget_term).backward()
            sgd.step()
    return model

def unlearn_scrub(orig, retain_loader, forget_loader, cfg):
    """SCRUB-style unlearning with a frozen copy of `orig` as the teacher.

    Each forget batch is paired with the next (cycled) retain batch. The
    student, a copy of `orig`, minimises
        (1 - a) * CE(student(retain), y_retain)
        - a * KLDivLoss(log_softmax(student(forget)), softmax(teacher(forget)))
    with a = cfg["scrub_alpha"] (batchmean reduction), using plain SGD at
    cfg["unlearn_lr"] for cfg["unlearn_epochs"] epochs.
    NOTE: this is one combined objective. The SCRUB paper alternates separate
    max- and min-steps.

    The teacher is frozen (requires_grad=False) and queried under
    torch.no_grad(). It is never optimised, so building its autograd graph
    only wasted memory/compute and accumulated stale .grad tensors on it.
    Student gradients are unaffected by this.
    """
    m = copy.deepcopy(orig)
    t_model = copy.deepcopy(orig).eval()
    t_model.requires_grad_(False)
    crit = nn.CrossEntropyLoss()
    kld = nn.KLDivLoss(reduction="batchmean")
    opt = optim.SGD(m.parameters(), lr=cfg["unlearn_lr"])
    m.train()
    for _ in range(cfg["unlearn_epochs"]):
        r_iter = iter(itertools.cycle(retain_loader))
        for fx, _ in forget_loader:
            rx, ry = next(r_iter)
            rx, ry, fx = rx.to(cfg["device"]), ry.to(cfg["device"]), fx.to(cfg["device"])
            opt.zero_grad()
            retain_loss = crit(m(rx), ry)
            student_logp = F.log_softmax(m(fx), 1)
            with torch.no_grad():
                teacher_p = F.softmax(t_model(fx), 1)
            loss = (1-cfg["scrub_alpha"]) * retain_loss \
                - cfg["scrub_alpha"] * kld(student_logp, teacher_p)
            loss.backward()
            opt.step()
    return m

def unlearn_random_label(orig, forget_set, cfg):
    """Random-label baseline: train a copy of `orig` on the forget set with
    its labels replaced by RelabelDataset (a class different from the true one).

    Uses a constant rate cfg["unlearn_lr"] for cfg["unlearn_epochs"] epochs, with
    a loader that reshuffles every epoch. `orig` is not modified.
    """
    relabeled = RelabelDataset(forget_set)
    shuffled_loader = DataLoader(relabeled, batch_size=cfg["batch_size"], shuffle=True)
    model = copy.deepcopy(orig)
    train_model(model, shuffled_loader, cfg["unlearn_epochs"], cfg["unlearn_lr"], True)
    return model

def unlearn_salun(orig, forget_set, cfg):
    """SalUn-style unlearning: random-label training on the forget set,
    restricted to the most forget-salient parameter entries.

    1. Saliency: for each parameter, sum |grad| of the forget-set
       cross-entropy over all forget batches (model in eval mode).
    2. Mask: keep entries whose saliency is strictly above the k-th smallest
       value, k = int(num_entries * cfg["salun_sparsity"]). Roughly the top
       (1 - sparsity) fraction is updated (fewer if values tie at the
       threshold); all other entries get a zero gradient.
    3. Train the copy on RelabelDataset(forget_set) with masked SGD
       (lr = cfg["unlearn_lr"], momentum = cfg["momentum"]) for
       cfg["unlearn_epochs"] epochs.

    Returns the unlearned copy; `orig` is not modified.

    NOTE(review): this sums per-batch |grad| and trains on the forget set
    only. The SalUn paper thresholds |grad| of the total forget loss and also
    includes a retain-set term. Confirm the simplification is intended.
    """
    m = copy.deepcopy(orig)
    crit = nn.CrossEntropyLoss()

    # --- 1. gradient saliency ---------------------------------------------
    # Eval mode: BatchNorm uses its running statistics and does NOT update
    # them with forget-set batches during this measurement pass. It also
    # stops the saliency depending on whatever train/eval flags `orig`
    # happened to carry.
    sal = [torch.zeros_like(p) for p in m.parameters()]
    f_loader = DataLoader(forget_set, batch_size=cfg["batch_size"])
    m.eval()
    for x, y in f_loader:
        x, y = x.to(cfg["device"]), y.to(cfg["device"])
        m.zero_grad()
        crit(m(x), y).backward()
        for i, p in enumerate(m.parameters()):
            if p.grad is not None:
                sal[i] += p.grad.abs()
    m.zero_grad()

    # --- 2. binary mask ---------------------------------------------------
    flat = torch.cat([s.flatten() for s in sal])
    # torch.kthvalue needs 1 <= k <= numel.
    k = min(int(len(flat) * cfg["salun_sparsity"]), len(flat))
    if k < 1:
        masks = [torch.ones_like(s) for s in sal]  # nothing to freeze
    else:
        th, _ = torch.kthvalue(flat, k)
        masks = [(s > th).float() for s in sal]

    # --- 3. masked random-label training ----------------------------------
    loader = DataLoader(RelabelDataset(forget_set), batch_size=cfg["batch_size"], shuffle=True)
    momentum = cfg["momentum"] if "momentum" in cfg else CONFIG["momentum"]
    opt = optim.SGD(m.parameters(), lr=cfg["unlearn_lr"], momentum=momentum)
    m.train()
    for _ in range(cfg["unlearn_epochs"]):
        for x, y in loader:
            x, y = x.to(cfg["device"]), y.to(cfg["device"])
            opt.zero_grad()
            loss = crit(m(x), y)
            loss.backward()
            for i, p in enumerate(m.parameters()):
                if p.grad is not None:
                    p.grad *= masks[i]
            opt.step()
    return m

# ===================================================================
# 5. MIA: black-box benchmarks
# ===================================================================
class black_box_benchmarks:
    def __init__(self, s_tr, s_te, t_tr, t_te, num_classes):
        self.k = num_classes
        self.s_tr_out, self.s_tr_lab = s_tr
        self.s_te_out, self.s_te_lab = s_te
        self.t_tr_out, self.t_tr_lab = t_tr
        self.t_te_out, self.t_te_lab = t_te

        self.s_tr_corr = (self.s_tr_out.argmax(1)==self.s_tr_lab).astype(int)
        self.s_te_corr = (self.s_te_out.argmax(1)==self.s_te_lab).astype(int)
        self.t_tr_corr = (self.t_tr_out.argmax(1)==self.t_tr_lab).astype(int)
        self.t_te_corr = (self.t_te_out.argmax(1)==self.t_te_lab).astype(int)

        self.s_tr_conf = self.s_tr_out[np.arange(len(self.s_tr_lab)), self.s_tr_lab]
        self.s_te_conf = self.s_te_out[np.arange(len(self.s_te_lab)), self.s_te_lab]
        self.t_tr_conf = self.t_tr_out[np.arange(len(self.t_tr_lab)), self.t_tr_lab]
        self.t_te_conf = self.t_te_out[np.arange(len(self.t_te_lab)), self.t_te_lab]

        self.s_tr_entr = self._entr(self.s_tr_out)
        self.s_te_entr = self._entr(self.s_te_out)
        self.t_tr_entr = self._entr(self.t_tr_out)
        self.t_te_entr = self._entr(self.t_te_out)

        self.s_tr_m_entr = self._m_entr(self.s_tr_out, self.s_tr_lab)
        self.s_te_m_entr = self._m_entr(self.s_te_out, self.s_te_lab)
        self.t_tr_m_entr = self._m_entr(self.t_tr_out, self.t_tr_lab)
        self.t_te_m_entr = self._m_entr(self.t_te_out, self.t_te_lab)

    def _log(self,p,eps=1e-30): return -np.log(np.maximum(p,eps))
    def _entr(self,p): return (p*self._log(p)).sum(1)
    def _m_entr(self,p,l):
        lp = self._log(p); rp = 1-p; lrp = self._log(rp)
        mp = p.copy(); mp[np.arange(l.size),l]=rp[np.arange(l.size),l]
        mlp=lrp.copy(); mlp[np.arange(l.size),l]=lp[np.arange(l.size),l]
        return (mp*mlp).sum(1)

    def _thre(self, tr, te):
        vals = np.concatenate((tr,te)); best_acc=0; best_t=0
        for v in vals:
            acc = 0.5*( (tr>=v).mean() + (te<v).mean() )
            if acc>best_acc: best_acc, best_t = acc, v
        return best_t

    def _via_corr(self):
        acc = 0.5*(self.t_tr_corr.mean() + (1-self.t_te_corr).mean())
        return acc

    def _via_feat(self, tr, te, Ttr, Tte):
        t_mem = t_non = 0
        for c in range(self.k):
            thr = self._thre(tr[self.s_tr_lab==c], te[self.s_te_lab==c])
            t_mem  += (Ttr[self.t_tr_lab==c] >= thr).sum()
            t_non  += (Tte[self.t_te_lab==c] <  thr).sum()
        return 0.5*(t_mem/len(Ttr) + t_non/len(Tte))

    def run(self):
        return {
            "correctness": self._via_corr(),
            "confidence" : self._via_feat(self.s_tr_conf, self.s_te_conf,
                                          self.t_tr_conf, self.t_te_conf),
            "entropy"    : self._via_feat(-self.s_tr_entr, -self.s_te_entr,
                                          -self.t_tr_entr, -self.t_te_entr),
            "m_entropy"  : self._via_feat(-self.s_tr_m_entr, -self.s_te_m_entr,
                                          -self.t_tr_m_entr, -self.t_te_m_entr)
        }

def collect_performance(loader, model, device):
    """Run `model` over `loader` in eval mode without autograd.

    Returns:
        (probs, labels): numpy arrays on the CPU. `probs` holds the softmax
        over dim 1 of the model outputs and `labels` the matching targets,
        both in loader order.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs.to(device))
            prob_chunks.append(F.softmax(logits, 1).cpu())
            label_chunks.append(targets.cpu())
    probs = torch.cat(prob_chunks).numpy()
    labels = torch.cat(label_chunks).numpy()
    return probs, labels

def calculate_mia_score(model, retain_loader_train, retain_loader_test,
                        forget_loader, test_loader):
    """Confidence-based membership-inference score (balanced attack accuracy).

    Per-class thresholds are fitted with retain_loader_train as members and
    test_loader as non-members. They are then evaluated with
    retain_loader_test as members and forget_loader as non-members. Only the
    "confidence" attack result is returned.
    """
    dev = CONFIG["device"]
    shadow_in = collect_performance(retain_loader_train, model, dev)
    shadow_out = collect_performance(test_loader, model, dev)
    target_in = collect_performance(retain_loader_test, model, dev)
    target_out = collect_performance(forget_loader, model, dev)
    attacks = black_box_benchmarks(shadow_in, shadow_out, target_in, target_out, 10)
    return attacks.run()["confidence"]  # only the scalar confidence score is used

# ===================================================================
# 6. Main experiment loop
# ===================================================================
def _seeded_loader(dataset, shuffle, seed):
    """DataLoader (CONFIG batch size) whose generator is seeded with `seed`
    and whose workers, if any, are seeded with seed + worker_id."""
    g = torch.Generator(); g.manual_seed(seed)
    return DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=shuffle,
                      worker_init_fn=lambda i: set_seed(seed + i), generator=g)


def _load_trusted(path):
    """torch.load for a file this notebook wrote itself.

    The ES partition file holds a dict of numpy arrays, which PyTorch >= 2.6
    refuses to unpickle under its default weights_only=True. The file is
    produced by this script, so full unpickling is acceptable here. Never
    point this at untrusted files. Older PyTorch without the keyword falls
    back to a plain load.
    """
    import inspect
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def _evaluate_all(model, retain_loader, retain_eval, forget_eval, te_loader):
    """Return (forget acc %, retain acc %, test acc %, confidence MIA) for `model`."""
    # NOTE(review): MIA shadow members come from the augmented, shuffled
    # retain_loader, while target members use the un-augmented retain_eval.
    # Confirm this is the intended protocol.
    return (evaluate_model(model, forget_eval),
            evaluate_model(model, retain_eval),
            evaluate_model(model, te_loader),
            calculate_mia_score(model, retain_loader, retain_eval, forget_eval, te_loader))


def _record(res, method, es, f_acc, r_acc, t_acc, mia):
    """Append one run's forget/retain/test accuracy and MIA score to `res`."""
    for key, val in zip("FRTM", (f_acc, r_acc, t_acc, mia)):
        res[method][es][key].append(val)


def _ci95(xs):
    """Format run values as 'mean' (one run) or 'mean ± 95% t-interval half-width'."""
    xs = np.array(xs); mu = xs.mean()
    if len(xs) < 2:
        return f"{mu:.3f}"
    return f"{mu:.3f} ± {(stats.sem(xs)*stats.t.ppf(0.975,len(xs)-1)):.3f}"


def _print_results(res, methods, es_levels):
    """7. Summarise results: one mean ± 95% CI table per ES level."""
    print(f"\n{'='*20} Final Results {'='*20}")
    for es in es_levels:
        print(f"\n--- Results for {es} ---")
        rows = [{"Method": m,
                 "Forget Acc": _ci95(res[m][es]["F"]),
                 "Retain Acc": _ci95(res[m][es]["R"]),
                 "Test Acc"  : _ci95(res[m][es]["T"]),
                 "MIA"       : _ci95(res[m][es]["M"])} for m in methods]
        print(pd.DataFrame(rows).to_string(index=False))


def main():
    """Run the ES-stratified unlearning benchmark on CIFAR-10.

    For each of CONFIG["num_runs"] runs, the function:
      - trains (or loads) an original ResNet-18;
      - splits the training set into Low/Medium/High ES forget sets;
      - retrains from scratch on each retain set;
      - applies every unlearning method;
      - records forget/retain/test accuracy and the confidence MIA score.

    Models and partitions are cached under CONFIG["model_save_dir"] and
    reused when present. Finally it prints a mean ± 95% CI table per ES level.
    """
    sd = CONFIG["model_save_dir"]; os.makedirs(sd, exist_ok=True)
    norm = transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
    tf_train = transforms.Compose([
        transforms.RandomCrop(32,4), transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), norm,
    ])
    tf_test = transforms.Compose([transforms.ToTensor(), norm])
    DATASET = "../../data"
    tr_ds = datasets.CIFAR10(root=DATASET, train=True, download=False, transform=tf_train)
    # Same training images without augmentation. Used for the ES embeddings
    # and to *measure* forget/retain accuracy and MIA. Measuring on randomly
    # cropped/flipped views made those metrics noisy, and the random
    # transforms drew from the global RNG during evaluation.
    tr_ds_eval = datasets.CIFAR10(root=DATASET, train=True, download=False, transform=tf_test)
    te_ds = datasets.CIFAR10(root=DATASET, train=False, download=False, transform=tf_test)

    te_loader = _seeded_loader(te_ds, False, SEED+9999)

    es_levels = ["Low ES", "Medium ES", "High ES"]
    methods = ["Retrain","Original","Fine-tune","L1-sparse","NegGrad","NegGrad+",
               "SCRUB","SalUn","Random-label"]
    res = {m:{es:{"F":[],"R":[],"T":[],"M":[]} for es in es_levels} for m in methods}

    for run in range(CONFIG["num_runs"]):
        print(f"\n{'='*20} Starting Run {run+1}/{CONFIG['num_runs']} {'='*20}")
        orig = get_model()
        orig_pth = f"{sd}/run_{run}_original_model.pth"
        part_pth = f"{sd}/run_{run}_es_partitions.pth"

        if os.path.exists(orig_pth):
            print("\n[LOADING] original model"); orig.load_state_dict(torch.load(orig_pth,map_location=CONFIG["device"]))
        else:
            print("\n[TRAINING] original model")
            tr_loader = _seeded_loader(tr_ds, True, SEED+run)
            train_model(orig, tr_loader, CONFIG["epochs"], CONFIG["lr"])
            torch.save(orig.state_dict(), orig_pth)

        if os.path.exists(part_pth):
            parts = _load_trusted(part_pth)
        else:
            # Also covers a session that stopped after saving the original
            # model but before the partitions were written.
            parts = create_es_partitions(orig, tr_ds_eval, SEED+run); torch.save(parts, part_pth)

        for es, forget_idx in parts.items():
            print(f"\n--- Processing ES Level: {es} ---")
            retain_idx = np.setdiff1d(np.arange(len(tr_ds)), forget_idx, assume_unique=True)
            r_set, f_set = Subset(tr_ds, retain_idx), Subset(tr_ds, forget_idx)
            r_set_eval, f_set_eval = Subset(tr_ds_eval, retain_idx), Subset(tr_ds_eval, forget_idx)

            # Loaders. Seeds follow the original per-level scheme.
            base = SEED + run + ord(es[0])
            retain_loader = _seeded_loader(r_set, True, base)
            retain_eval = _seeded_loader(r_set_eval, False, base + 1000)
            # forget_loader feeds the unlearning methods (augmented, as before).
            # forget_eval is the un-augmented view used for every metric.
            forget_loader = _seeded_loader(f_set, False, base + 2000)
            forget_eval = _seeded_loader(f_set_eval, False, base + 3000)

            # Retrain-from-scratch reference.
            rt_pth = f"{sd}/run_{run}_{es.replace(' ','')}_retrained.pth"
            retr = get_model()
            if os.path.exists(rt_pth):
                print(f"\n[LOADING] retrained model for {es}"); retr.load_state_dict(torch.load(rt_pth,map_location=CONFIG["device"]))
            else:
                print(f"\n[TRAINING] retrained model for {es}")
                train_model(retr, retain_loader, CONFIG["epochs"], CONFIG["lr"])
                torch.save(retr.state_dict(), rt_pth)

            print("\nEvaluating retrained model...")
            r_acc, r_ret, r_test, r_mia = _evaluate_all(retr, retain_loader, retain_eval, forget_eval, te_loader)
            print(f"  Retrain Accs -> F:{r_acc:.2f}% R:{r_ret:.2f}% T:{r_test:.2f}%  MIA:{r_mia:.3f}")
            _record(res, "Retrain", es, r_acc, r_ret, r_test, r_mia)

            # Called immediately within this iteration, so the late-bound
            # closures see this level's loaders.
            unlearn = {
                "Original"    : lambda: copy.deepcopy(orig),
                "Fine-tune"   : lambda: unlearn_finetune(orig, retain_loader, CONFIG),
                "L1-sparse"   : lambda: unlearn_l1_sparse(orig, retain_loader, CONFIG),
                "NegGrad"     : lambda: unlearn_neggrad(orig, forget_loader, CONFIG),
                "NegGrad+"    : lambda: unlearn_neggrad_plus(orig, retain_loader, forget_loader, CONFIG),
                "SCRUB"       : lambda: unlearn_scrub(orig, retain_loader, forget_loader, CONFIG),
                "SalUn"       : lambda: unlearn_salun(orig, f_set, CONFIG),
                "Random-label": lambda: unlearn_random_label(orig, f_set, CONFIG),
            }

            print("\nApplying and evaluating unlearning methods...")
            for m_name, fn in unlearn.items():
                upth = f"{sd}/run_{run}_{es.replace(' ','')}_{m_name}_unlearned.pth"
                if os.path.exists(upth):
                    print(f"    > [LOADING] {m_name}"); u_model = get_model(); u_model.load_state_dict(torch.load(upth,map_location=CONFIG["device"]))
                else:
                    print(f"    > [TRAINING] {m_name}"); u_model = fn(); torch.save(u_model.state_dict(), upth)

                u_f, u_r, u_t, u_m = _evaluate_all(u_model, retain_loader, retain_eval, forget_eval, te_loader)
                print(f"      - {m_name}  F:{u_f:.2f}% R:{u_r:.2f}% T:{u_t:.2f}%  MIA:{u_m:.3f}")
                _record(res, m_name, es, u_f, u_r, u_t, u_m)

    _print_results(res, methods, es_levels)

if __name__ == "__main__":
    main()


Global random seed set to 42
Using device: cuda


[TRAINING] original model
    Epoch 1/30 completed in 32.62s
    Epoch 2/30 completed in 31.31s
    Epoch 3/30 completed in 31.98s
    Epoch 4/30 completed in 31.47s
    Epoch 5/30 completed in 31.81s
    Epoch 6/30 completed in 31.55s
    Epoch 7/30 completed in 31.53s
    Epoch 8/30 completed in 32.11s
    Epoch 9/30 completed in 31.29s
    Epoch 10/30 completed in 31.61s
    Epoch 11/30 completed in 31.38s
    Epoch 12/30 completed in 31.79s
    Epoch 13/30 completed in 31.55s
    Epoch 14/30 completed in 31.29s
    Epoch 15/30 completed in 32.22s
    Epoch 16/30 completed in 31.52s
    Epoch 17/30 completed in 31.18s
    Epoch 18/30 completed in 31.62s
    Epoch 19/30 completed in 31.87s
    Epoch 20/30 completed in 31.46s
    Epoch 21/30 completed in 31.25s
    Epoch 22/30 completed in 30.78s
    Epoch 23/30 completed in 31.21s
    Epoch 24/30 completed in 31.48s
    Epoch 25/30 completed in 31.39s
    Epoch 26/30 completed in 31.3