In [1]:
import os
import sys
import warnings

import numpy as np
import torch

from modules.ig_utils import ig_analysis
from modules.model_utils import model_loss_optimizer_resnet

# --- Device and model setup --------------------------------------------------------
# Preferred GPU index. If that GPU is not visible, fall back to cuda:0 (with a
# warning), and to CPU when CUDA is unavailable. The original hard-coded 'cuda:3'
# failed on hosts with fewer GPUs.
CUDA_DEVICE_INDEX = 3

if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    device = f'cuda:{CUDA_DEVICE_INDEX}'
elif torch.cuda.is_available():
    warnings.warn(
        f"cuda:{CUDA_DEVICE_INDEX} is not available "
        f"({torch.cuda.device_count()} GPU(s) visible); using cuda:0 instead."
    )
    device = 'cuda:0'
else:
    device = 'cpu'

# Build a depth-152 ResNet. The other two return values (presumably loss and
# optimizer, per the helper's name) are not needed for attribution.
model, _, _ = model_loss_optimizer_resnet(
    device=device, weight_pos=1.0, lr=1e-3, show_model_summary=False, resnet_depth=152
)

# Checkpoint whose weights are explained below.
ckpt_path = "/zdisk/users/ext_user_03/01_yschoi/project_01_FVH_detection/02_results/fhv_resnet152_with_cam/best_model.pt"

if os.path.exists(ckpt_path):
    # NOTE: weights_only=False unpickles arbitrary objects. Only load checkpoints
    # produced by your own training runs; never point this at untrusted files.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
else:
    ckpt = None

if isinstance(ckpt, dict) and 'model' in ckpt:
    model.load_state_dict(ckpt['model'])
else:
    # Without the trained weights every attribution map below would describe the
    # untrained model, so surface the problem instead of skipping silently.
    warnings.warn(
        f"No trained weights loaded from {ckpt_path!r} (file missing or no 'model' key); "
        "the model keeps the weights returned by model_loss_optimizer_resnet."
    )

model.eval();

In [2]:
# --- Test data and loader ------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
from modules.data_utils import SliceDataset

# Directory holding the artifacts saved by the training run (test arrays, checkpoint).
RESULTS_DIR = '/zdisk/users/ext_user_03/01_yschoi/project_01_FVH_detection/02_results/fhv_resnet152_with_cam'

path_test_data = os.path.join(RESULTS_DIR, 'test_data.npy')
test_data = np.load(path_test_data)

path_test_label = os.path.join(RESULTS_DIR, 'test_y_true.npy')
test_label = np.load(path_test_label)

# SliceDataset presumably pairs images and labels by index. A length mismatch means
# the two files come from different runs or splits, so fail fast here rather than
# deep inside the IG loop.
assert len(test_data) == len(test_label), (
    f"test_data has {len(test_data)} samples but test_label has {len(test_label)}"
)

batch_size = 1
num_workers = 0
ds_test = SliceDataset(test_data, test_label)
# shuffle=False so batches arrive in the same order as test_data.
dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                     num_workers=num_workers, pin_memory=False,
                     # DataLoader rejects persistent_workers=True when num_workers == 0
                     persistent_workers=num_workers > 0)

from modules.ig_utils import ig_on_batch

# --- Integrated Gradients over the test set ------------------------------------------
model.eval()  # attributions must be computed in eval mode

# Where ig_on_batch writes its PNGs (exact sub-path and file names are decided by
# ig_on_batch). NOTE(review): this path is relative to the kernel's working
# directory, unlike the absolute paths used to load the data. Confirm it resolves
# to the intended results folder.
IG_OUT_DIR = "./02_results/fhv_resnet152_with_cam/ig_test"

for idx, (images, labels) in enumerate(dl_test):  # images: assumed (B, 1, 672, 672)
    # Compute IG for the whole batch in one call and save the PNGs.
    # fname_index is the batch index; with batch_size=1 it equals the sample index.
    # TODO confirm how ig_on_batch numbers files when batch_size > 1.
    res_list = ig_on_batch(
        model=model,
        batch_images=images,            # pass the (B, 1, H, W) tensor as-is
        device=device,
        out_dir=IG_OUT_DIR,
        fname_index=idx,
        ig_steps=16,
        ig_target="pos",                # single logit: sign(+); multi-logit: uses pos_class_idx
        pos_class_idx=1,                # positive-class index when the model has multiple logits
        baseline_mode="zeros",          # 'zeros' | 'constant' | 'blur'
        use_noise_tunnel=False,         # set True if needed
        viz_sign="positive",            # 'positive' | 'negative' | 'both' | 'absolute'
        display_mode="percentile",      # 'percentile' | 'zscore'
        save_raw_attr=False,
    )
    # Optional: res_list[i] holds each sample's attr/overlay/score/delta etc.
    # Only the saved image files are used here, so res_list is not accumulated.