In [1]:
# Clone the kuangliu/pytorch-cifar repository and switch into the checkout;
# later cells import its models/ package.
!git clone https://github.com/kuangliu/pytorch-cifar.git
%cd pytorch-cifar

Cloning into 'pytorch-cifar'...
remote: Enumerating objects: 382, done.[K
remote: Total 382 (delta 0), reused 0 (delta 0), pack-reused 382 (from 1)[K
Receiving objects: 100% (382/382), 94.67 KiB | 9.47 MiB/s, done.
Resolving deltas: 100% (182/182), done.
/content/pytorch-cifar


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *

In [3]:
# Import the CIFAR ResNet-18 constructor; fail with an explicit hint when the
# models/ package is not importable from the working directory.
try:
    from models.resnet import ResNet18
except Exception as e:
    # Short guidance for the case where models/ is missing locally.
    raise RuntimeError("models/ 디렉터리가 필요합니다. 깃 레포를 클론하거나 models를 복사하세요.") from e

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

start_epoch = 0  # first epoch index: 0, or the epoch of a resumed checkpoint
best_acc = 0  # highest test accuracy observed so far

In [5]:
# Data
print('==> Preparing data..')

# Evaluation pipeline: tensor conversion followed by per-channel normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Training pipeline: random crop / horizontal flip augmentation in front of the
# same tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Datasets (downloaded into ./data when absent); the training split is created first.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Loaders: shuffled mini-batches of 128 for training, ordered batches of 100 for testing.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Class names used by this notebook, listed in label order 0..9.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

==> Preparing data..


100%|██████████| 170M/170M [00:05<00:00, 29.9MB/s]


In [6]:
# Model
print('==> Building model..')
net = ResNet18().to(device)
if device == 'cuda':
    # On GPU: enable cuDNN benchmark mode and wrap the network in DataParallel.
    cudnn.benchmark = True
    net = torch.nn.DataParallel(net)

# Loss, SGD with momentum and weight decay, and a cosine learning-rate schedule
# whose period (T_max) is 200 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

==> Building model..


In [7]:
import sys, time

# Module-level state read and updated by progress_bar()
TOTAL_BAR_LENGTH = 65.                   # width of the bar, in characters
begin_time = last_time = time.time()     # both clocks start at cell execution time

def format_time(seconds):
    """Render a duration given in seconds as a compact human-readable string.

    Durations of one second or more are written as space-separated hour /
    minute / second fields (e.g. ``'1h 2m 5s'``); fields that would be zero are
    omitted and any fractional second is dropped. Durations below one second
    are written in whole milliseconds (e.g. ``'250ms'``) so that per-step
    timings do not all collapse to ``'0s'``.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string; ``'0s'`` for zero, negative, or sub-millisecond
        input.
    """
    if seconds < 1:
        millis = int(seconds * 1000)
        return f"{millis}ms" if millis > 0 else '0s'

    parts = []
    remaining = seconds
    for span, unit in ((3600, 'h'), (60, 'm'), (1, 's')):
        if remaining >= span:
            amount = int(remaining // span)
            remaining -= amount * span
            parts.append(f"{amount}{unit}")
    return ' '.join(parts) if parts else '0s'

def progress_bar(current, total, msg=None):
    """
    Redraw a single-line textual progress bar on stdout.

    Call it as progress_bar(i, len(loader), 'message'), for example:
        for i, (x, y) in enumerate(train_loader):
            progress_bar(i, len(train_loader), f"loss {loss.item():.3f}")

    The line shows the bar, the time since the previous call ("Step"), the time
    since the last call made with current == 0 ("Tot") and, when given, msg.
    A newline is written once current reaches total - 1.
    """
    global last_time, begin_time
    if current == 0:
        # First step of a pass: restart the total-time clock.
        begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH * current / total)
    pending = int(TOTAL_BAR_LENGTH - filled) - 1
    bar = ''.join(['[', '=' * filled, '>', '.' * pending, ']'])

    now = time.time()
    step_time, last_time = now - last_time, now
    tot_time = now - begin_time

    fields = [f"\r{bar} Step: {format_time(step_time)}", f"Tot: {format_time(tot_time)}"]
    if msg:
        fields.append(f"{msg}")

    # Colab only shows the update immediately when stdout is flushed explicitly.
    sys.stdout.write(" | ".join(fields))
    sys.stdout.flush()
    if current == total - 1:
        sys.stdout.write('\n')

In [8]:
# Training
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Uses the notebook-level ``net``, ``criterion``, ``optimizer`` and ``device``
    and reports the running mean loss and accuracy through ``progress_bar``.

    Args:
        epoch: Epoch index, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    n_batches = len(trainloader)
    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        n_seen += labels.size(0)
        # max(1)[1] is the index of the largest logit, i.e. the predicted class
        n_correct += logits.max(1)[1].eq(labels).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            running_loss/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(step, n_batches, status)


def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Uses the notebook-level ``net``, ``criterion`` and ``device``. When the test
    accuracy exceeds the global ``best_acc``, the weights, accuracy and epoch
    are written to ``./checkpoint/ResNet18.pth`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch index stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir() and
        # matches how the later cells create their output directories.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ResNet18.pth')
        best_acc = acc

In [9]:
# Train and evaluate for 200 epochs. Counting starts at start_epoch (0 for a
# fresh run) so that a resumed run continues its epoch numbering instead of
# silently restarting from 0.
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    test(epoch)
    scheduler.step()


Epoch: 0
Saving..

Epoch: 1
Saving..

Epoch: 2
Saving..

Epoch: 3
Saving..

Epoch: 4
Saving..

Epoch: 5
Saving..

Epoch: 6

Epoch: 7
Saving..

Epoch: 8

Epoch: 9

Epoch: 10
Saving..

Epoch: 11

Epoch: 12

Epoch: 13
Saving..

Epoch: 14

Epoch: 15
Saving..

Epoch: 16

Epoch: 17

Epoch: 18

Epoch: 19
Saving..

Epoch: 20

Epoch: 21
Saving..

Epoch: 22

Epoch: 23

Epoch: 24
Saving..

Epoch: 25

Epoch: 26

Epoch: 27
Saving..

Epoch: 28
Saving..

Epoch: 29

Epoch: 30

Epoch: 31
Saving..

Epoch: 32

Epoch: 33

Epoch: 34

Epoch: 35

Epoch: 36
Saving..

Epoch: 37
Saving..

Epoch: 38

Epoch: 39

Epoch: 40

Epoch: 41

Epoch: 42

Epoch: 43

Epoch: 44

Epoch: 45
Saving..

Epoch: 46
Saving..

Epoch: 47

Epoch: 48

Epoch: 49

Epoch: 50

Epoch: 51

Epoch: 52

Epoch: 53

Epoch: 54

Epoch: 55

Epoch: 56
Saving..

Epoch: 57

Epoch: 58

Epoch: 59

Epoch: 60

Epoch: 61

Epoch: 62

Epoch: 63
Saving..

Epoch: 64

Epoch: 65

Epoch: 66
Saving..

Epoch: 67

Epoch: 68

Epoch: 69

Epoch: 70

Epoch: 71

Epoch: 72


In [10]:
import torch, time, os
from models import resnet  # kuangliu models/ 사용 가정

# 1) Robust state_dict loader (handles 'state_dict'/'net'/'model' keys and the "module." prefix)
def load_state_dict_safely(model, ckpt):
    """Load checkpoint weights into ``model`` and return the model.

    ``ckpt`` may be a dict that stores the weights under ``"state_dict"``,
    ``"net"`` or ``"model"`` (checked in that order); anything else is treated
    as the state_dict itself. A leading ``"module."`` on a key is stripped
    before loading. Loading is strict, so missing or unexpected keys raise.
    """
    from collections import OrderedDict

    # Step 1: locate the state_dict inside the checkpoint.
    state = ckpt  # last resort: assume ckpt already is a state_dict
    if isinstance(ckpt, dict):
        for key in ("state_dict", "net", "model"):
            if key in ckpt:
                state = ckpt[key]
                break

    # Step 2: DataParallel compatibility - drop one leading "module." per key.
    prefix = "module."
    cleaned = OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, tensor)
        for name, tensor in state.items()
    )

    # Step 3: strict load.
    model.load_state_dict(cleaned, strict=True)
    return model

# 2) Utilities for measuring parameter count / sparsity / size / latency
def count_params(model):
    """Count the parameters of ``model`` and how many of them are non-zero.

    Both numbers are taken over the same set of parameters (every tensor
    yielded by ``model.parameters()``, whatever its dtype), so
    ``1 - nnz / total`` is a consistent sparsity figure.

    Returns:
        (total, nnz): total number of elements and number of non-zero elements.
    """
    total = 0
    nnz = 0
    for p in model.parameters():
        total += p.numel()
        nnz += int((p.detach() != 0).sum().item())
    return total, nnz

@torch.no_grad()
def measure_latency(model, device="cuda", img_size=32, bs=256, iters=50):
    """Measure the mean forward-pass latency per image, in milliseconds.

    The model is switched to eval mode and moved to ``device`` (in place), run
    10 times as warm-up, then timed for ``iters`` forward passes on a random
    batch of shape ``(bs, 3, img_size, img_size)``.

    Args:
        model: Module to time.
        device: Device string or ``torch.device``. Any CUDA device (for example
            ``"cuda:0"``) is synchronised around the timed region.
        img_size: Height and width of the random input images.
        bs: Batch size.
        iters: Number of timed forward passes.

    Returns:
        Mean latency in ms per image.
    """
    # Compare the device *type* so "cuda:0" and torch.device objects also
    # trigger synchronisation; otherwise only kernel launches would be timed.
    on_cuda = torch.device(device).type == "cuda"

    model.eval().to(device)
    x = torch.randn(bs, 3, img_size, img_size, device=device)

    # warmup
    for _ in range(10):
        model(x)
    if on_cuda:
        torch.cuda.synchronize()

    elapsed = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        model(x)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed += (time.perf_counter() - t0) / bs
    return (elapsed / iters) * 1000.0  # ms/img

# 3) Run: rebuild ResNet-18, restore the checkpoint, then print the report row
ckpt = torch.load("checkpoint/ResNet18.pth", map_location="cpu")   # <- path to the baseline checkpoint
model = load_state_dict_safely(resnet.ResNet18(), ckpt)            # ResNet-18 for CIFAR-10, weights restored

# Values for the comparison table
total, nnz = count_params(model)
sparsity = 1 - nnz/total
size_MB  = total * 4 / (1024**2)  # assumes FP32 storage (4 bytes per weight)
acc      = ckpt.get("acc", None)
lat_ms   = measure_latency(
    model,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Output (report-table layout); a missing accuracy is shown as 0
print(f"{'model':10s} {'sparsity':>10s} {'acc(%)':>10s} {'params(M)':>12s} {'size(MB)':>10s} {'lat(ms/img)':>12s}")
print("-"*70)
print(f"{'ResNet18':10s} {sparsity:10.3f} { (acc if acc is not None else 0):10.2f} "
      f"{total/1e6:12.2f} {size_MB:10.2f} {lat_ms:12.3f}")

model        sparsity     acc(%)    params(M)   size(MB)  lat(ms/img)
----------------------------------------------------------------------
ResNet18        0.000      95.58        11.17      42.63        0.026


In [14]:
from torch.nn.utils import prune

def magnitude_prune_global(model, sparsity=0.9):
    """Zero the smallest-magnitude Conv2d/Linear weights across the whole model.

    The weights of all ``nn.Conv2d`` and ``nn.Linear`` modules are ranked
    together by absolute value and the fraction ``sparsity`` with the smallest
    magnitude is set to zero. The pruning re-parametrisation is removed
    afterwards, so each module keeps a plain ``weight`` parameter containing
    the zeros (no mask is kept).

    Args:
        model: Module to prune in place.
        sparsity: Fraction of the prunable weights to zero, in ``[0, 1]``.
            Integers are accepted and treated as fractions too (``1`` means
            "prune everything"), not as an absolute number of weights.

    Returns:
        The same ``model`` object.
    """
    targets = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    if not targets:
        # Nothing prunable: leave the model untouched instead of failing.
        return model

    # torch.nn.utils.prune reads an int `amount` as an absolute count, so
    # force the fractional interpretation that `sparsity` is meant to have.
    prune.global_unstructured(
        targets, pruning_method=prune.L1Unstructured, amount=float(sparsity)
    )
    for module, name in targets:
        prune.remove(module, name)
    return model

In [15]:
def run_magnitude_once(ckpt_path, save_path, sparsity=0.9, seed=1, finetune_epochs=5, lr=1e-2):
    """
    Magnitude-prune a ResNet-18 checkpoint, fine-tune it briefly and save the result.

    Supports kuangliu/pytorch-cifar style checkpoints (`'net'`, `'acc'`, `'epoch'`).
    Weights zeroed by the pruning step are re-zeroed after every optimizer step,
    so fine-tuning cannot regrow them and the saved model really has the
    sparsity recorded in its metadata.

    Args:
        ckpt_path: Path of the dense baseline checkpoint.
        save_path: Where the pruned + fine-tuned checkpoint is written.
        sparsity: Fraction of Conv/Linear weights to prune.
        seed: Seed for the torch CPU and CUDA generators.
        finetune_epochs: Number of fine-tuning epochs after pruning.
        lr: Fine-tuning learning rate (SGD, Nesterov momentum 0.9).
    """
    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print(f"\n[MAG] seed={seed}, sparsity={sparsity}")

    # Fix the random seeds
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Datasets
    mean, std = (0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010)
    tf_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean,std)
    ])
    tf_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean,std)
    ])
    tr = datasets.CIFAR10("./data", train=True, download=True, transform=tf_train)
    te = datasets.CIFAR10("./data", train=False, download=True, transform=tf_test)
    train_loader = DataLoader(tr, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(te, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)

    # Build the model and restore the checkpoint. The shared loader resolves
    # the 'state_dict' / 'net' / 'model' keys and strips the DataParallel
    # "module." prefix.
    model = resnet.ResNet18().to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")
    load_state_dict_safely(model, ckpt)

    # Accuracy before pruning
    base_acc, _ = evaluate(model, test_loader, device)
    print(f"[BASELINE] acc={base_acc:.2f}%")

    # (1) Magnitude pruning
    model = magnitude_prune_global(model, sparsity=sparsity)
    total, nnz = count_params(model)
    print(f"[PRUNED] sparsity={1 - nnz/total:.4f}, params={total/1e6:.2f}M")

    # Remember which Conv/Linear weights are zero right after pruning. The
    # pruning helper leaves no mask behind, so without this the optimizer
    # would update the pruned weights and the model would become dense again.
    pruned_masks = []
    if sparsity > 0:
        pruned_masks = [
            (m.weight, m.weight.detach() == 0)
            for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear))
        ]

    # (2) Short fine-tuning
    ce = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    for ep in range(finetune_epochs):
        model.train()
        correct = total_ = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = ce(logits, y)
            loss.backward()
            opt.step()
            if pruned_masks:
                # Re-apply the sparsity pattern after the weight update.
                with torch.no_grad():
                    for weight, pruned in pruned_masks:
                        weight.masked_fill_(pruned, 0.0)
            correct += (logits.argmax(1) == y).sum().item()
            total_ += y.numel()
        tr_acc = 100 * correct / total_
        te_acc, _ = evaluate(model, test_loader, device)
        print(f"[FT {ep+1}/{finetune_epochs}] train_acc={tr_acc:.2f}%  test_acc={te_acc:.2f}%")

    # (3) Final evaluation and save
    acc, _ = evaluate(model, test_loader, device)
    total, nnz = count_params(model)  # recount so the metadata describes the saved weights
    out_dir = os.path.dirname(save_path)
    if out_dir:
        # dirname is '' for a bare filename and os.makedirs('') would raise
        os.makedirs(out_dir, exist_ok=True)
    torch.save({
        "net": model.state_dict(),        # keep the kuangliu checkpoint format
        "acc": float(acc),
        "base_acc": float(base_acc),
        "sparsity": float(sparsity),
        "seed": int(seed),
        "params_total": int(total),
        "params_nnz": int(nnz),
    }, save_path)
    print(f"[DONE] acc={acc:.2f}%  saved: {save_path}")

@torch.no_grad()
def evaluate(model, loader, device):
    """Compute accuracy and mean cross-entropy loss of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled.

    Args:
        model: Classifier returning one logit vector per input row.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        (accuracy in percent, mean loss per sample).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        # criterion returns the batch mean; weight it by the batch size
        loss_sum += criterion(outputs, targets).item() * inputs.size(0)
        n_correct += outputs.argmax(1).eq(targets).sum().item()
        n_seen += targets.numel()
    accuracy = 100 * n_correct / n_seen
    return accuracy, loss_sum / n_seen

In [21]:
# Sweep configuration for magnitude pruning
ckpt_path = "checkpoint/ResNet18.pth"
seeds = [1, 2, 3]
sparsities = [0.00, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95, 0.98]

os.makedirs("results/mag", exist_ok=True)

# One run per (sparsity, seed) pair; sparsity is the outer loop.
for sp, s in [(sp, s) for sp in sparsities for s in seeds]:
    run_magnitude_once(
        ckpt_path,
        f"results/mag/resnet18_mag_s{s}_sp{sp}.pth",
        sparsity=sp,
        seed=s,
        finetune_epochs=5,   # kept short for a quick experiment
        lr=1e-2,
    )


[MAG] seed=1, sparsity=0.0
[BASELINE] acc=95.58%
[PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=100.00%  test_acc=95.48%
[FT 2/5] train_acc=100.00%  test_acc=95.53%
[FT 3/5] train_acc=100.00%  test_acc=95.55%
[FT 4/5] train_acc=100.00%  test_acc=95.48%
[FT 5/5] train_acc=100.00%  test_acc=95.44%
[DONE] acc=95.44%  saved: results/mag/resnet18_mag_s1_sp0.0.pth

[MAG] seed=2, sparsity=0.0
[BASELINE] acc=95.58%
[PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=100.00%  test_acc=95.53%
[FT 2/5] train_acc=100.00%  test_acc=95.62%
[FT 3/5] train_acc=100.00%  test_acc=95.61%
[FT 4/5] train_acc=100.00%  test_acc=95.43%
[FT 5/5] train_acc=100.00%  test_acc=95.50%
[DONE] acc=95.50%  saved: results/mag/resnet18_mag_s2_sp0.0.pth

[MAG] seed=3, sparsity=0.0
[BASELINE] acc=95.58%
[PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=100.00%  test_acc=95.52%
[FT 2/5] train_acc=100.00%  test_acc=95.57%
[FT 3/5] train_acc=100.00%  test_acc=95.58%
[FT 4/5] train_acc=99.99%  tes

In [22]:
# --------------------------
# OBD 핵심: H_ii 대각 근사(E[g^2]) 계산
# --------------------------
@torch.no_grad()
def _is_prunable(m):
    return isinstance(m, (nn.Conv2d, nn.Linear))

def estimate_hessian_diag_eg2(model, loader, device, max_batches=100):
    """
    Estimate the Hessian diagonal as E[g^2] (the OBD saliency is 0.5 * H_ii * w_i^2).

    Only the ``weight`` tensors of prunable (Conv/Linear) modules that require
    gradients are considered. For each batch the gradient of the batch-mean
    cross-entropy loss is squared; the squares are averaged over the processed
    batches (iteration stops once ``max_batches`` batches have been used).

    Returns:
        (hdiag, modules) where hdiag[i] has the shape of modules[i].weight.
    """
    model.train()  # train mode; the original note gives BN-statistics stability as the reason
    criterion = nn.CrossEntropyLoss()

    # Target parameters: Conv/Linear weights only
    modules = [m for m in model.modules() if _is_prunable(m) and m.weight.requires_grad]
    params = [m.weight for m in modules]

    # Accumulators, one per target weight
    hdiag = [torch.zeros_like(w, device=device) for w in params]
    batches = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), targets)

        # g = dL/dw, differentiated with respect to the target weights only
        grads = torch.autograd.grad(loss, params, retain_graph=False, create_graph=False, allow_unused=False)
        for acc, g in zip(hdiag, grads):
            acc.add_(g.detach() ** 2)
        batches += 1
        if batches >= max_batches:
            break

    denom = max(1, batches)  # guard against an empty loader
    for acc in hdiag:
        acc.div_(denom)
    return hdiag, modules  # modules[i].weight and hdiag[i] correspond 1:1

In [23]:
# --------------------------
# OBD pruning (global)
# --------------------------
def obd_prune_global(model, loader, device, sparsity=0.9, max_batches=100):
    """
    Zero the least salient Conv/Linear weights, ranked across all layers.

    saliency_i = 0.5 * H_ii * w_i^2 is gathered from every layer and the
    fraction ``sparsity`` with the smallest saliency is set to 0. The threshold
    is the k-th smallest saliency; entries at or below it are zeroed.

    Returns:
        The same ``model`` object, modified in place.
    """
    hdiag_list, modules = estimate_hessian_diag_eg2(model, loader, device, max_batches=max_batches)

    def _saliency(module, h_diag):
        # OBD saliency of every weight in one layer
        return 0.5 * h_diag * (module.weight.data ** 2)

    # Per-layer saliencies, flattened into one vector for the global threshold
    all_scores = torch.cat([_saliency(m, h).flatten() for m, h in zip(modules, hdiag_list)])

    k = int(all_scores.numel() * sparsity)
    if k > 0:
        thresh = torch.topk(all_scores, k, largest=False).values.max()
    else:
        thresh = all_scores.min() - 1  # below every score: nothing is pruned

    # Zero everything at or below the threshold
    for m, h in zip(modules, hdiag_list):
        weights = m.weight.data
        weights[_saliency(m, h) <= thresh] = 0.0

    return model

# --------------------------
# End-to-end run: load ckpt -> OBD prune -> (short) fine-tune -> save
# --------------------------
def run_obd_once(ckpt_path, save_path, sparsity=0.9, seed=1, finetune_epochs=5, lr=1e-2, max_batches=100):
    """
    Prune a ResNet-18 checkpoint with OBD, fine-tune it briefly and save the result.

    Supports kuangliu/pytorch-cifar style checkpoints (weights under the 'net' key).
    Weights zeroed by the pruning step are re-zeroed after every optimizer step,
    so fine-tuning cannot regrow them and the saved model really has the
    sparsity recorded in its metadata.

    Args:
        ckpt_path: Path of the dense baseline checkpoint.
        save_path: Where the pruned + fine-tuned checkpoint is written.
        sparsity: Fraction of Conv/Linear weights to prune.
        seed: Seed for the torch CPU and CUDA generators.
        finetune_epochs: Number of fine-tuning epochs (0 skips fine-tuning).
        lr: Fine-tuning learning rate (SGD, Nesterov momentum 0.9).
        max_batches: Number of training batches used for the E[g^2] estimate.
    """
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Data
    train_loader, test_loader = get_loaders(batch_train=256, batch_test=512, workers=2)

    # Model and checkpoint. The shared loader resolves the 'state_dict' /
    # 'net' / 'model' keys and strips the DataParallel "module." prefix, and it
    # removes this function's reliance on an OrderedDict import from a later cell.
    model = resnet.ResNet18().to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")
    load_state_dict_safely(model, ckpt)

    # Accuracy before pruning
    base_acc, _ = evaluate(model, test_loader, device)
    print(f"[BASELINE] acc={base_acc:.2f}%")

    # OBD pruning
    model = obd_prune_global(model, train_loader, device, sparsity=sparsity, max_batches=max_batches)
    total, nnz = count_params(model)
    print(f"[OBD-PRUNED] sparsity={1 - nnz/total:.4f}, params={total/1e6:.2f}M")

    # Remember which Conv/Linear weights are zero right after pruning. The
    # pruning step only writes zeros, so without this the optimizer would
    # update them and the model would become dense again.
    pruned_masks = []
    if sparsity > 0:
        pruned_masks = [
            (m.weight, m.weight.detach() == 0)
            for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear))
        ]

    # Short fine-tuning
    if finetune_epochs > 0:
        ce = nn.CrossEntropyLoss()
        opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
        for ep in range(finetune_epochs):
            model.train()
            cor = tot = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt.zero_grad(set_to_none=True)
                logits = model(x)
                loss = ce(logits, y)
                loss.backward()
                opt.step()
                if pruned_masks:
                    # Re-apply the sparsity pattern after the weight update.
                    with torch.no_grad():
                        for weight, pruned in pruned_masks:
                            weight.masked_fill_(pruned, 0.0)
                cor += (logits.argmax(1) == y).sum().item()
                tot += y.numel()
            tr_acc = 100*cor/tot
            te_acc, _ = evaluate(model, test_loader, device)
            print(f"[FT {ep+1}/{finetune_epochs}] train_acc={tr_acc:.2f}%  test_acc={te_acc:.2f}%")

    # Save (kuangliu format)
    acc, _ = evaluate(model, test_loader, device)
    total, nnz = count_params(model)  # recount so the metadata describes the saved weights
    out_dir = os.path.dirname(save_path)
    if out_dir:
        # dirname is '' for a bare filename and os.makedirs('') would raise
        os.makedirs(out_dir, exist_ok=True)
    torch.save({
        "net": model.state_dict(),
        "acc": float(acc),
        "base_acc": float(base_acc),
        "sparsity": float(sparsity),
        "seed": int(seed),
        "params_total": int(total),
        "params_nnz": int(nnz),
        "method": "obd",
    }, save_path)
    print(f"[DONE][OBD] sp={sparsity} acc={acc:.2f}% → {save_path}")

In [24]:
def get_loaders(batch_train=256, batch_test=512, workers=2):
    """Build the CIFAR-10 train and test DataLoaders.

    The training pipeline applies random crop (padding 4) and horizontal flip
    before tensor conversion and normalisation; the test pipeline only converts
    and normalises. The training loader shuffles and drops the last incomplete
    batch; the test loader keeps the dataset order.

    Args:
        batch_train: Training batch size.
        batch_test: Test batch size.
        workers: Number of worker processes per loader.

    Returns:
        (train_loader, test_loader)
    """
    normalize = transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)

    # Options shared by both loaders
    common = dict(num_workers=workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_train, shuffle=True, drop_last=True, **common)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_test, shuffle=False, **common)
    return train_loader, test_loader

In [25]:
from collections import OrderedDict

# Sweep configuration for OBD pruning
ckpt_path = "checkpoint/ResNet18.pth"  # path of the baseline checkpoint
os.makedirs("results/obd", exist_ok=True)

seeds = [1, 2, 3]
sparsities = [0.00, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95, 0.98]

# One run per (sparsity, seed) pair; sparsity is the outer loop.
for sp, s in [(sp, s) for sp in sparsities for s in seeds]:
    run_obd_once(
        ckpt_path,
        f"results/obd/resnet18_obd_s{s}_sp{sp}.pth",
        sparsity=sp,
        seed=s,
        finetune_epochs=5,
        lr=1e-2,
        max_batches=200,
    )

[BASELINE] acc=95.58%
[OBD-PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=99.99%  test_acc=95.62%
[FT 2/5] train_acc=99.99%  test_acc=95.50%
[FT 3/5] train_acc=100.00%  test_acc=95.52%
[FT 4/5] train_acc=99.99%  test_acc=95.50%
[FT 5/5] train_acc=100.00%  test_acc=95.34%
[DONE][OBD] sp=0.0 acc=95.34% → results/obd/resnet18_obd_s1_sp0.0.pth
[BASELINE] acc=95.58%
[OBD-PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=100.00%  test_acc=95.54%
[FT 2/5] train_acc=100.00%  test_acc=95.45%
[FT 3/5] train_acc=99.99%  test_acc=95.36%
[FT 4/5] train_acc=100.00%  test_acc=95.53%
[FT 5/5] train_acc=100.00%  test_acc=95.64%
[DONE][OBD] sp=0.0 acc=95.64% → results/obd/resnet18_obd_s2_sp0.0.pth
[BASELINE] acc=95.58%
[OBD-PRUNED] sparsity=0.0000, params=11.17M
[FT 1/5] train_acc=100.00%  test_acc=95.49%
[FT 2/5] train_acc=100.00%  test_acc=95.51%
[FT 3/5] train_acc=100.00%  test_acc=95.46%
[FT 4/5] train_acc=100.00%  test_acc=95.46%
[FT 5/5] train_acc=100.00%  test_acc=95.47%
[DONE]

In [26]:
# --------------------------
# H_diag ~ E[g^2] estimate (same approach as in the OBD cell)
# --------------------------
def estimate_hessian_diag_eg2(model, loader, device, max_batches=100):
    """
    Estimate H_ii as E[g_i^2] for the weights of Conv/Linear modules.

    Note: this re-declares the helper of the same name defined in the OBD cell;
    once this cell has run, this definition is the one in effect.

    For each batch the gradient of the batch-mean cross-entropy loss with
    respect to the target weights is squared; the squares are averaged over the
    processed batches (iteration stops once ``max_batches`` batches were used).

    Returns:
        (hdiag, modules) where hdiag[i] has the shape of modules[i].weight.
    """
    model.train()  # train mode; the original note gives BN-statistics stability as the reason
    criterion = nn.CrossEntropyLoss()

    # Collect only the weights that are pruning targets
    modules = [
        m for m in model.modules()
        if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.requires_grad
    ]
    params = [m.weight for m in modules]

    hdiag = [torch.zeros_like(w, device=device) for w in params]
    steps = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), targets)
        grads = torch.autograd.grad(loss, params, retain_graph=False, create_graph=False)
        for slot, g in zip(hdiag, grads):
            slot.add_(g.detach() ** 2)
        steps += 1
        if steps >= max_batches:
            break

    scale = max(1, steps)  # guard against an empty loader
    for slot in hdiag:
        slot.div_(scale)
    return hdiag, modules  # modules[i].weight <-> hdiag[i]

# --------------------------
# OBS-fast pruning (global)
# --------------------------
def obs_fast_prune_global(model, loader, device, sparsity=0.9, lambda_damp=1e-3, max_batches=100, hdiag_precomputed=None):
    """
    Zero the lowest-scoring Conv/Linear weights of ``model``, ranked across all layers.

    Score: S_i ~ 0.5 * w_i^2 * (H_ii + lambda). The fraction ``sparsity`` with
    the smallest scores is set to 0 (entries at or below the k-th smallest
    score are zeroed).

    Args:
        model: Module pruned in place.
        loader: Batches used to estimate E[g^2]; unused when
            ``hdiag_precomputed`` is given.
        device: Device used for the estimate.
        sparsity: Fraction of the scored weights to zero.
        lambda_damp: Damping term added to the Hessian diagonal.
        max_batches: Number of batches for the E[g^2] estimate.
        hdiag_precomputed: Optional ``(hdiag_list, modules)`` pack from an
            earlier call. Only its ``hdiag_list`` is reused; the modules are
            always collected again from ``model``, because the cached module
            objects may belong to a different model instance.

    Returns:
        (model, (hdiag_list, modules)) with ``modules`` taken from ``model``.

    Raises:
        ValueError: if the cached diagonals do not match ``model``'s prunable
            weights in number or shape.
    """
    if hdiag_precomputed is None:
        hdiag_list, modules = estimate_hessian_diag_eg2(model, loader, device, max_batches=max_batches)
    else:
        hdiag_list = hdiag_precomputed[0]
        # Same selection rule as the estimator, applied to the model at hand.
        modules = [
            m for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.requires_grad
        ]
        if len(modules) != len(hdiag_list):
            raise ValueError(
                f"hdiag_precomputed has {len(hdiag_list)} entries but the model has "
                f"{len(modules)} prunable weights"
            )

    # Pair each module with its diagonal, skipping None entries consistently
    # for both the scoring and the masking step.
    pairs = []
    for m, Hd in zip(modules, hdiag_list):
        if Hd is None:
            continue
        if Hd.shape != m.weight.shape:
            raise ValueError(
                f"Hessian diagonal shape {tuple(Hd.shape)} does not match "
                f"weight shape {tuple(m.weight.shape)}"
            )
        pairs.append((m, Hd.to(m.weight.device)))

    if not pairs:
        # Nothing to score: leave the model untouched.
        return model, (hdiag_list, modules)

    with torch.no_grad():
        # Scores of all layers, computed once and reused for the masks
        scores = [0.5 * (m.weight ** 2) * (Hd + lambda_damp) for m, Hd in pairs]
        scores_flat = torch.cat([s.flatten() for s in scores])

        # Global threshold
        k = int(scores_flat.numel() * sparsity)
        if k > 0:
            thresh = torch.topk(scores_flat, k, largest=False).values.max()
        else:
            thresh = scores_flat.min() - 1  # below every score: nothing is removed

        # Apply the mask (set to 0)
        for (m, _), s in zip(pairs, scores):
            m.weight.masked_fill_(s <= thresh, 0.0)

    return model, (hdiag_list, modules)

# --------------------------
# End-to-end run: load ckpt -> OBS-fast prune -> (short) fine-tune -> save
# --------------------------
def run_obs_once(ckpt_path, save_path, sparsity=0.9, seed=1, finetune_epochs=5, lr=1e-2,
                 max_batches=100, lambda_damp=1e-3, reuse_hdiag=False, hdiag_cache=None):
    """
    Prune a ResNet-18 checkpoint with OBS-fast, fine-tune it briefly and save the result.

    Supports the kuangliu/pytorch-cifar checkpoint format ('net' key).
    Weights zeroed by the pruning step are re-zeroed after every optimizer step,
    so fine-tuning cannot regrow them and the saved model really has the
    sparsity recorded in its metadata.

    Args:
        ckpt_path: Path of the dense baseline checkpoint.
        save_path: Where the pruned + fine-tuned checkpoint is written.
        sparsity: Fraction of Conv/Linear weights to prune.
        seed: Seed for the torch CPU and CUDA generators.
        finetune_epochs: Number of fine-tuning epochs (0 skips fine-tuning).
        lr: Fine-tuning learning rate (SGD, Nesterov momentum 0.9).
        max_batches: Number of training batches used for the E[g^2] estimate.
        lambda_damp: Damping term added to the Hessian diagonal in the score.
        reuse_hdiag: When True and ``hdiag_cache`` is given, the E[g^2]
            estimate is not recomputed.
        hdiag_cache: ``(hdiag_list, modules)`` pack returned by an earlier call.

    Returns:
        The ``(hdiag_list, modules)`` pack produced by the pruning step.
    """
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Data
    train_loader, test_loader = get_loaders(batch_train=256, batch_test=512, workers=2)

    # Model and checkpoint. The shared loader resolves the 'state_dict' /
    # 'net' / 'model' keys and strips the DataParallel "module." prefix, and it
    # removes this function's reliance on an OrderedDict import from another cell.
    model = resnet.ResNet18().to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")
    load_state_dict_safely(model, ckpt)

    # Accuracy before pruning
    base_acc, _ = evaluate(model, test_loader, device)
    print(f"[BASELINE] acc={base_acc:.2f}%")

    # OBS-fast pruning (fresh or reused E[g^2])
    hdiag_pre = None
    if reuse_hdiag and hdiag_cache is not None:
        # The cached pack holds module objects of an earlier model instance.
        # Pair the cached diagonals with this model's own modules so that the
        # freshly loaded weights are the ones that get pruned.
        own_modules = [
            m for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.requires_grad
        ]
        hdiag_pre = (hdiag_cache[0], own_modules)
    model, hdiag_pack = obs_fast_prune_global(
        model, train_loader, device,
        sparsity=sparsity, lambda_damp=lambda_damp, max_batches=max_batches,
        hdiag_precomputed=hdiag_pre
    )

    total, nnz = count_params(model)
    print(f"[OBS-PRUNED] sparsity={1 - nnz/total:.4f}, params={total/1e6:.2f}M, lambda={lambda_damp}")

    # Remember which Conv/Linear weights are zero right after pruning. The
    # pruning step only writes zeros, so without this the optimizer would
    # update them and the model would become dense again.
    pruned_masks = []
    if sparsity > 0:
        pruned_masks = [
            (m.weight, m.weight.detach() == 0)
            for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear))
        ]

    # Short fine-tuning
    if finetune_epochs > 0:
        ce = nn.CrossEntropyLoss()
        opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
        for ep in range(finetune_epochs):
            model.train()
            cor = tot = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt.zero_grad(set_to_none=True)
                logits = model(x)
                loss = ce(logits, y)
                loss.backward()
                opt.step()
                if pruned_masks:
                    # Re-apply the sparsity pattern after the weight update.
                    with torch.no_grad():
                        for weight, pruned in pruned_masks:
                            weight.masked_fill_(pruned, 0.0)
                cor += (logits.argmax(1) == y).sum().item()
                tot += y.numel()
            tr_acc = 100*cor/tot
            te_acc, _ = evaluate(model, test_loader, device)
            print(f"[FT {ep+1}/{finetune_epochs}] train_acc={tr_acc:.2f}%  test_acc={te_acc:.2f}%")

    # Save (kuangliu format)
    acc, _ = evaluate(model, test_loader, device)
    total, nnz = count_params(model)  # recount so the metadata describes the saved weights
    out_dir = os.path.dirname(save_path)
    if out_dir:
        # dirname is '' for a bare filename and os.makedirs('') would raise
        os.makedirs(out_dir, exist_ok=True)
    torch.save({
        "net": model.state_dict(),
        "acc": float(acc),
        "base_acc": float(base_acc),
        "sparsity": float(sparsity),
        "seed": int(seed),
        "params_total": int(total),
        "params_nnz": int(nnz),
        "method": "obs",
        "lambda_damp": float(lambda_damp),
        "hdiag_batches": int(max_batches),
    }, save_path)
    print(f"[DONE][OBS] sp={sparsity} acc={acc:.2f}% → {save_path}")
    return hdiag_pack  # (hdiag_list, modules)

In [27]:
# Sweep configuration for OBS-fast pruning
ckpt_path = "checkpoint/ResNet18.pth"   # baseline checkpoint
os.makedirs("results/obs", exist_ok=True)

seeds = [1, 2, 3]
sparsities = [0.00, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95, 0.98]

# A lambda between 1e-3 and 1e-2 is a reasonable range; when short on time, just use 1e-3.
lambda_damp = 1e-3
max_batches = 100

# One run per (sparsity, seed) pair; sparsity is the outer loop.
for sp, s in [(sp, s) for sp in sparsities for s in seeds]:
    run_obs_once(
        ckpt_path,
        f"results/obs/resnet18_obs_s{s}_sp{sp}.pth",
        sparsity=sp,
        seed=s,
        finetune_epochs=5,
        lr=1e-2,
        max_batches=max_batches,
        lambda_damp=lambda_damp,
    )

[BASELINE] acc=95.58%
[OBS-PRUNED] sparsity=0.0000, params=11.17M, lambda=0.001
[FT 1/5] train_acc=99.99%  test_acc=95.61%
[FT 2/5] train_acc=99.99%  test_acc=95.47%
[FT 3/5] train_acc=100.00%  test_acc=95.52%
[FT 4/5] train_acc=99.99%  test_acc=95.48%
[FT 5/5] train_acc=100.00%  test_acc=95.36%
[DONE][OBS] sp=0.0 acc=95.36% → results/obs/resnet18_obs_s1_sp0.0.pth
[BASELINE] acc=95.58%
[OBS-PRUNED] sparsity=0.0000, params=11.17M, lambda=0.001
[FT 1/5] train_acc=100.00%  test_acc=95.55%
[FT 2/5] train_acc=100.00%  test_acc=95.45%
[FT 3/5] train_acc=99.99%  test_acc=95.36%
[FT 4/5] train_acc=100.00%  test_acc=95.58%
[FT 5/5] train_acc=100.00%  test_acc=95.62%
[DONE][OBS] sp=0.0 acc=95.62% → results/obs/resnet18_obs_s2_sp0.0.pth
[BASELINE] acc=95.58%
[OBS-PRUNED] sparsity=0.0000, params=11.17M, lambda=0.001
[FT 1/5] train_acc=100.00%  test_acc=95.49%
[FT 2/5] train_acc=100.00%  test_acc=95.49%
[FT 3/5] train_acc=100.00%  test_acc=95.47%
[FT 4/5] train_acc=100.00%  test_acc=95.46%
[FT 5/5]