In [1]:
!pip install -q torch torchvision timm datasets

In [2]:
# =========================================================
# 0) 설치 (필요 시)
# =========================================================
# !pip install -q timm scikit-image

# =========================================================
# 1) 데이터: Imagenette 다운로드 → ImageFolder 로드
# =========================================================
import os, tarfile, urllib.request, pathlib, random
import torch, numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

root = pathlib.Path("./data")
root.mkdir(exist_ok=True)
tgz = root / "imagenette2-320.tgz"
dst = root / "imagenette2-320"

if not dst.exists():
    if not tgz.exists():
        url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
        print("Downloading:", url)
        urllib.request.urlretrieve(url, tgz.as_posix())
    print("Extracting...")
    with tarfile.open(tgz, "r:gz") as tf:
        tf.extractall(root)

train_dir = dst / "train"
val_dir   = dst / "val"

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_set = datasets.ImageFolder(train_dir.as_posix(), transform=train_tf)
val_set   = datasets.ImageFolder(val_dir.as_posix(),   transform=val_tf)
BATCH_SIZE = 32
NW = max(2, os.cpu_count() // 2) if os.cpu_count() else 2
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NW, pin_memory=True, persistent_workers=True)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NW, pin_memory=True, persistent_workers=True)

print("Train/Val sizes:", len(train_set), len(val_set))
print("Classes:", train_set.classes)


Device: cuda
Downloading: https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
Extracting...


  tf.extractall(root)


Train/Val sizes: 9469 3925
Classes: ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']


In [3]:
# =========================================================
# 2) 모델: timm ResNet-50 (로짓 반환)
# =========================================================
import timm
model = timm.create_model("resnet50", pretrained=True, num_classes=1000).eval().to(device)

# 배치 하나 미리 가져오기
xb_cpu, yb_cpu = next(iter(val_loader))
xb = xb_cpu.to(device, non_blocking=True)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# =========================================================
# 3) 유틸: 디노멀라이즈 / 텐서→이미지
# =========================================================
import matplotlib.pyplot as plt

def denorm(img_t):
    mean = torch.tensor(IMAGENET_MEAN, device=img_t.device).view(3,1,1)
    std  = torch.tensor(IMAGENET_STD,  device=img_t.device).view(3,1,1)
    return img_t * std + mean

def to_numpy_img(img_t):
    """(C,H,W) normalized tensor -> (H,W,3) numpy [0,1]"""
    img = denorm(img_t).clamp(0,1).permute(1,2,0).detach().cpu().numpy()
    return img


In [5]:
# =========================================================
# 4) IG 미니 구현 (+Smooth-IG 옵션)
# =========================================================
from torch import nn

@torch.no_grad()
def _alphas(steps:int, device):
    # 중점 규칙 (midpoint)
    return torch.linspace(0, 1, steps+1, device=device)[:-1] + 0.5/steps

def integrated_gradients(
    model: nn.Module,
    x: torch.Tensor,               # (B,C,H,W)
    target_idx: torch.Tensor,      # (B,)
    baseline: torch.Tensor = None, # (B,C,H,W)
    steps: int = 64,
    chunk: int = 64,
):
    model.eval()
    device = x.device
    B = x.size(0)
    if baseline is None:
        baseline = torch.zeros_like(x)

    diff = x - baseline
    al = _alphas(steps, device)                         # (steps,)
    total = torch.zeros_like(x, dtype=torch.float32)    # (B,C,H,W)

    for s in range(0, steps, chunk):
        a = al[s: s+chunk].view(-1, *([1]*x.dim()))     # (S,1,1,1)
        x_path = (baseline.unsqueeze(0) + a * diff.unsqueeze(0)).clone()
        x_path.requires_grad_(True)                     # (S,B,C,H,W)
        SB = x_path.shape[0] * x_path.shape[1]

        y = model(x_path.view(SB, *x.shape[1:]))       # (S*B, num_classes)
        y = y.view(x_path.shape[0], x_path.shape[1], -1)   # (S,B,C_cls)
        tgt = target_idx.view(1,B).expand(y.size(0),B)      # (S,B)
        f = y.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)     # (S,B)

        grads = torch.autograd.grad(outputs=f.sum(), inputs=x_path,
                                    create_graph=False, retain_graph=False, allow_unused=False)[0]  # (S,B,C,H,W)
        avg_grads = grads.mean(dim=0)                              # (B,C,H,W)
        total += avg_grads
        del x_path, y, tgt, f, grads, avg_grads

    ig = diff * (total / (steps / chunk))
    return ig

def smooth_ig(
    model, x, target_idx, baseline=None, steps=64, n_samples=8, sigma=0.08, chunk=64
):
    if baseline is None: baseline = torch.zeros_like(x)
    acc = torch.zeros_like(x)
    for _ in range(n_samples):
        eps = torch.randn_like(x) * sigma
        acc += integrated_gradients(model, x+eps, target_idx, baseline+eps, steps=steps, chunk=chunk)
    return acc / n_samples

def completeness_check(model, x, baseline, target_idx):
    with torch.no_grad():
        fx = model(x).gather(-1, target_idx.view(-1,1)).squeeze(-1)
        f0 = model(baseline).gather(-1, target_idx.view(-1,1)).squeeze(-1)
    return fx - f0


In [6]:
# =========================================================
# 5) IG → 단일맵 변환 / 정규화 / 마스크 / 컨투어
# =========================================================
import numpy as np
from skimage import measure, segmentation

def ig_reduce(ig_sample: torch.Tensor, how="sum_abs"):
    # (C,H,W) -> (H,W)
    if how == "sum_abs":
        m = ig_sample.abs().sum(dim=0)
    elif how == "l2":
        m = torch.linalg.vector_norm(ig_sample, dim=0)
    elif how == "sum":
        m = ig_sample.sum(dim=0)
    else:
        raise ValueError(how)
    return m

def normalize_01(m: torch.Tensor, clip_p=0.99):
    v = m.flatten()
    hi = torch.quantile(v, clip_p).clamp(min=1e-8)
    m = m.clamp(0, hi) / hi
    return m

def threshold_percentile(m01: torch.Tensor, keep=0.15):
    thr = torch.quantile(m01.flatten(), 1 - keep)
    mask = (m01 >= thr).float()
    return mask, float(thr)

def draw_contours_on(ax, m01: torch.Tensor, level=0.6, color="lime", lw=1.8):
    arr = m01.cpu().numpy()
    cs = measure.find_contours(arr, level=level)
    for c in cs:
        ax.plot(c[:,1], c[:,0], color=color, linewidth=lw)


In [7]:
# =========================================================
# 6) 한 배치에 대해: 타깃 선택 → IG 계산(Smooth-IG 권장)
# =========================================================
with torch.no_grad():
    logits = model(xb)                           # (B,1000)
top1_idx = logits.softmax(-1).argmax(-1)         # (B,)
baseline = torch.zeros_like(xb)

# 표준 IG 또는 Smooth-IG 중 택1
# ig = integrated_gradients(model, xb, target_idx=top1_idx, baseline=baseline, steps=128, chunk=64)
ig = smooth_ig(model, xb, target_idx=top1_idx, baseline=baseline, steps=128, n_samples=8, sigma=0.06, chunk=64)

lhs = ig.view(ig.size(0), -1).sum(dim=1)
rhs = completeness_check(model, xb, baseline, top1_idx)
print("Completeness gap (mean±std):", float((lhs - rhs).abs().mean()), float((lhs - rhs).abs().std()))



OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 15.18 GiB is allocated by PyTorch, and 81.96 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# =========================================================
# 7) 시각화 5종: 원본 / heat-only / 원본+overlay / 원본+Top-k mask / 원본+contour
# =========================================================
def show_panel(x_i, ig_i, keep=0.15, clip_p=0.99, overlay_alpha=0.45, cmap="magma", contour_level=0.6):
    """
    x_i: (C,H,W) normalized image
    ig_i: (C,H,W) IG
    """
    img = to_numpy_img(x_i)
    m   = ig_reduce(ig_i, "sum_abs")
    m01 = normalize_01(m, clip_p=clip_p)
    mask,_ = threshold_percentile(m01, keep=keep)

    fig, axs = plt.subplots(1,5, figsize=(15,3))
    # (1) 원본
    axs[0].imshow(img); axs[0].set_title("Original"); axs[0].axis("off")
    # (2) heat-only
    axs[1].imshow(m01.cpu().numpy(), cmap=cmap, vmin=0, vmax=1); axs[1].set_title("Heat-only"); axs[1].axis("off")
    # (3) 원본 + overlay
    axs[2].imshow(img); axs[2].imshow(m01.cpu().numpy(), cmap=cmap, alpha=overlay_alpha, vmin=0, vmax=1)
    axs[2].set_title("Overlay"); axs[2].axis("off")
    # (4) 원본 + Top-k mask(하얀색)
    axs[3].imshow(img); axs[3].imshow(mask.cpu().numpy(), cmap="gray", alpha=0.45, vmin=0, vmax=1)
    axs[3].set_title(f"Top-{int(keep*100)}% mask"); axs[3].axis("off")
    # (5) 원본 + 컨투어
    axs[4].imshow(img); draw_contours_on(axs[4], m01, level=contour_level, color="lime", lw=2.0)
    axs[4].set_title(f"Contour @{contour_level:.2f}"); axs[4].axis("off")
    plt.tight_layout(); plt.show()

# 배치의 앞 4장만 시각화
for i in range(min(4, xb.size(0))):
    show_panel(xb[i], ig[i], keep=0.15, clip_p=0.99, overlay_alpha=0.45, cmap="magma", contour_level=0.6)
