In [7]:
import sys
import os
# Make the project root importable (needed for `utils.load_models`). Use "../../"
# instead if the notebook lives two levels below the root. The membership guard
# keeps the cell idempotent: re-running it does not stack duplicate entries.
_project_root = os.path.abspath("../")
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
from utils.load_models import get_CNN, get_CNN_small, get_resnet

In [None]:

# --- Configuration ---
epsilon = 0.03  # FGSM step size: each pixel moves by at most this much (pixels are in [0, 1])
batch_size = 128

# Prefer Apple MPS, then CUDA, then CPU -- the same priority as the attack cell,
# so the global `device` is consistent from the first cell onward.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
save_path = "../cifar10_fgsm_attack.pt"  # default output path; the run cell passes its own

# --- Preprocessing (no normalization) ---
# ToTensor() yields float pixels in [0, 1], which matches the clamp range used by
# the attack. NOTE(review): if the target model was trained on normalized inputs,
# the normalization has to live inside the model, not here -- confirm.
transform_plain = transforms.ToTensor()

# Load the already-downloaded CIFAR-10 test split.
# NOTE(review): torchvision expects `root` to be the directory that *contains*
# `cifar-10-batches-py`. This path therefore only works if the data sits in a
# nested `cifar-10-batches-py/cifar-10-batches-py` folder -- verify, otherwise
# use "../data".
root = "../data/cifar-10-batches-py"
test_dataset = datasets.CIFAR10(
    root=root,
    train=False,
    download=False,
    transform=transform_plain,
)

# shuffle=False keeps the saved adversarial set aligned with the test-set order.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- FGSM attack ---
def fgsm_attack(model, images, labels, epsilon, device=None):
    """Craft FGSM adversarial examples for one batch.

    Takes a single signed-gradient step that increases the cross-entropy loss:
    ``x_adv = clamp(x + epsilon * sign(d loss / d x), 0, 1)``.

    Args:
        model: ``nn.Module`` returning logits for ``images``. It should already
            be in the desired mode (e.g. ``eval()``); this function does not
            change the mode and does not touch the parameters' ``.grad``.
        images: Input batch; assumed to hold pixel values in [0, 1] (the result
            is clamped to that range). Not modified in place.
        labels: Ground-truth class indices for ``images``.
        epsilon: Per-pixel step size (L-infinity budget).
        device: Device to run the attack on. Defaults to the device of the
            model's parameters (or of ``images`` for a parameter-free model).

    Returns:
        Detached adversarial batch on ``device``, same shape as ``images``.
    """
    if device is None:
        # Follow the model rather than a notebook-global `device`, so the attack
        # works for whatever device the caller moved the model to.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else images.device

    images = images.clone().detach().to(device).requires_grad_(True)
    labels = labels.to(device)
    loss = F.cross_entropy(model(images), labels)

    # autograd.grad computes only d(loss)/d(images); unlike backward() it leaves
    # the model's parameter .grad buffers untouched, so no zero_grad is needed.
    (input_grad,) = torch.autograd.grad(loss, images)

    adv_images = images.detach() + epsilon * input_grad.sign()
    return torch.clamp(adv_images, 0, 1)

# --- Generate and save the adversarial dataset ---
def generate_and_save_adversarial_dataset(model, save_path, device="cpu", loader=None, eps=None):
    """Run FGSM over every batch of a loader and save the results.

    Args:
        model: Classifier to attack. It is switched to ``eval()`` and moved to
            ``device``; both changes persist after the call.
        save_path: Destination file for ``torch.save``. Missing parent
            directories are created.
        device: Device the model is moved to before attacking.
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-global ``test_loader``.
        eps: FGSM step size. Defaults to the notebook-global ``epsilon``.

    Saves a dict with keys ``'original'``, ``'adversarial'`` and ``'labels'``,
    each being the per-batch CPU tensors concatenated along dim 0.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Resolve defaults at call time: existing callers keep reading the notebook
    # globals, while new callers can pass these explicitly (no hidden state).
    if loader is None:
        loader = test_loader
    if eps is None:
        eps = epsilon

    model.eval()
    model.to(device)

    orig_batches, adv_batches, label_batches = [], [], []
    for images, labels in loader:
        adv_images = fgsm_attack(model, images, labels, eps)
        orig_batches.append(images.cpu())
        adv_batches.append(adv_images.cpu())
        label_batches.append(labels.cpu())

    if not orig_batches:
        raise ValueError("loader produced no batches; nothing to save")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'original': torch.cat(orig_batches),
        'adversarial': torch.cat(adv_batches),
        'labels': torch.cat(label_batches),
    }, save_path)

    print(f"Saved adversarial dataset to: {save_path}")


Files already downloaded and verified


In [None]:

# Pick the fastest available backend: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `get_resnet` is presumably a factory function, so it must be *called*.
# Assigning the bare function makes generate_and_save_adversarial_dataset fail
# at model.eval() ("'function' object has no attribute 'eval'").
# NOTE(review): assumes get_resnet() needs no arguments and returns a loaded
# nn.Module (the captured traceback shows a load_state_dict call) -- confirm in
# utils/load_models.py.
model = get_resnet()
generate_and_save_adversarial_dataset(model, save_path="../cifar10_fgsm_attack_byResNet.pt", device=device)


  model.load_state_dict(torch.load("resnet18_cifar10.pth"))


KeyboardInterrupt: 