In [2]:
import random
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
import foolbox
from foolbox import PyTorchModel
from foolbox.attacks import (
    LinfBasicIterativeAttack,
    FGSM,
    PGD,
    L2DeepFoolAttack,
    L2CarliniWagnerAttack,
)
from torchvision import transforms
from PIL import Image
import torch.nn as nn


def _load_checkpoint_into_model(model, checkpoint_path, map_location="cpu", strict=True, verbose=False):
    ckpt = torch.load(checkpoint_path, map_location=map_location)

    if isinstance(ckpt, dict):
        for candidate in ("model_state", "model_state_dict", "state_dict", "state"):
            if candidate in ckpt:
                state = ckpt[candidate]
                break
        else:
            if all(isinstance(v, (torch.Tensor, type(None))) or hasattr(v, "shape") for v in ckpt.values()):
                state = ckpt
            else:
                nested = None
                for v in ckpt.values():
                    if isinstance(v, dict):
                        if nested is None or len(v) > len(nested):
                            nested = v
                state = nested if nested is not None else ckpt
    else:
        state = ckpt

    if not isinstance(state, dict):
        raise ValueError(f"Checkpoint {checkpoint_path} does not contain a state-dict (found type: {type(state)})")

    keys = list(state.keys())
    prefix = None
    for p in ("module.", "model."):
        cnt = sum(1 for k in keys if k.startswith(p))
        if cnt >= max(1, len(keys) // 2):
            prefix = p
            break

    if prefix:
        new_state = OrderedDict((k[len(prefix):], v) for k, v in state.items())
    else:
        new_state = state

    try:
        model.load_state_dict(new_state, strict=strict)
        if verbose:
            print(f"Loaded checkpoint {checkpoint_path} (strict={strict}).")
        return model
    except Exception as e:
        if verbose:
            print(f"Strict load failed: {e}. Trying non-strict load ...")
        res = model.load_state_dict(new_state, strict=False)
        if verbose:
            print("Loaded with strict=False. Missing keys:", getattr(res, "missing_keys", None))
            print("Unexpected keys:", getattr(res, "unexpected_keys", None))
        return model


def _build_attack(attack, epsilon=0.03, cw_steps=1000):
    """Map an attack name to ``(foolbox_attack, epsilons)``.

    The budgeted attacks ("fgsm", "bim", "pgd") get ``[epsilon]``. DeepFool
    ("df") and Carlini-Wagner ("cw") get ``None``, so foolbox runs them without
    a perturbation budget.

    Raises:
        ValueError: if ``attack`` is not a supported name.
    """
    if attack == "fgsm":
        return FGSM(), [epsilon]
    if attack == "bim":
        return LinfBasicIterativeAttack(), [epsilon]
    if attack == "pgd":
        return PGD(), [epsilon]
    if attack == "df":
        return L2DeepFoolAttack(), None
    if attack == "cw":
        return L2CarliniWagnerAttack(steps=cw_steps), None
    raise ValueError("Unknown attack: " + str(attack))


def run_attack(
    model,
    checkpoint_path,
    preprocessing = dict(
        mean=[0.4812775254249573, 0.4674863815307617, 0.4093940854072571],
        std=[0.19709135591983795, 0.1933959424495697, 0.19051066040992737],
        axis=-3
    ),
    attack="fgsm",
    csv_path="./data/clean_data.csv",
    imdir="./images",
    outdir="./data",
    use_cuda=True,
    model_device=None,
    load_map_location=None,
    batch_size=32,
    verbose=True,
    epsilon=0.03,
    cw_steps=1000,
):
    """Attack correctly-classified images with foolbox and save successful adversarials.

    Steps:
      1. Load ``checkpoint_path`` into ``model`` (non-strict).
      2. Wrap the model in a foolbox ``PyTorchModel`` with bounds (0, 1) and
         ``preprocessing``.
      3. Attack every CSV row whose ``true_idx`` equals ``pred_idx``.
      4. Save each successful adversarial tensor to
         ``<outdir>/<attack>/<rel_path with .pt suffix>``.
      5. Write ``<outdir>/metadata_<attack>.csv`` with a per-row ``success`` column.

    Args:
        model: torch module matching the checkpoint architecture. Its weights are
            overwritten in place.
        checkpoint_path: checkpoint file to load.
        preprocessing: foolbox preprocessing dict (mean/std/axis) applied inside
            the wrapped model. It is only read here, never mutated.
        attack: "fgsm", "bim" or "pgd" (budget ``epsilon``), or "df" / "cw"
            (run with ``epsilons=None``, i.e. without a budget).
        csv_path: CSV with at least ``rel_path``, ``true_idx`` and ``pred_idx``.
            ``true_class`` / ``pred_class`` are copied to the metadata when present.
        imdir: directory that ``rel_path`` is resolved against.
        outdir: output directory (created if missing).
        use_cuda: prefer ``cuda:0`` when available. Only used if ``model_device``
            is None.
        model_device: explicit device string. Overrides ``use_cuda``.
        load_map_location: ``map_location`` for ``torch.load``. Defaults to the
            selected device.
        batch_size: images per attack call.
        verbose: print progress and summary information.
        epsilon: perturbation budget for "fgsm" / "bim" / "pgd" (default 0.03).
        cw_steps: optimisation steps for "cw" (default 1000).

    Returns:
        float: successes divided by correctly-classified candidates (0.0 if none).

    Raises:
        ValueError: for an unknown ``attack`` name. This is raised before the
            checkpoint is loaded.
    """
    # Fixed seeds so any randomness in the attacks is repeatable across runs.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Resolve the attack first so an unknown name fails before the model is touched.
    atk, epsilons = _build_attack(attack, epsilon=epsilon, cw_steps=cw_steps)

    if model_device is None:
        device = torch.device("cuda:0" if (torch.cuda.is_available() and use_cuda) else "cpu")
    else:
        device = torch.device(model_device)

    if load_map_location is None:
        ckpt_map = "cpu" if device.type == "cpu" else device
    else:
        ckpt_map = load_map_location

    if verbose:
        print("Loading checkpoint:", checkpoint_path, "map_location:", ckpt_map)
    model = _load_checkpoint_into_model(model, checkpoint_path, map_location=ckpt_map, strict=False, verbose=verbose)

    model = model.to(device).eval()
    # Pass the device explicitly. Otherwise PyTorchModel defaults to cuda:0 whenever
    # CUDA is available, ignoring use_cuda=False / model_device and mismatching the model.
    fmodel = PyTorchModel(model, bounds=(0, 1), device=device, preprocessing=preprocessing)

    to_tensor = transforms.ToTensor()

    def load_image_tensor(rel_path, base_dir):
        """Load ``base_dir / rel_path`` as an RGB tensor via ``transforms.ToTensor``."""
        p = Path(base_dir) / str(rel_path).lstrip("/")
        # Context manager releases the file handle as soon as the pixels are read.
        with Image.open(p) as img:
            return to_tensor(img.convert("RGB"))

    df = pd.read_csv(csv_path)

    # Only attack images that the clean model already classifies correctly.
    df_correct = df[df["true_idx"] == df["pred_idx"]].copy()
    df_correct["original_index"] = df_correct.index
    total_candidates = len(df_correct)

    if verbose:
        print(f"CSV: {csv_path} rows={len(df)}; correctly-classified candidates={total_candidates}; epsilons={epsilons}")

    out_base = Path(outdir)
    out_base.mkdir(parents=True, exist_ok=True)

    attempted_count = 0
    success_count = 0
    saved_count = 0
    # original CSV index -> {"success": bool}. Rows never recorded default to False.
    candidate_results = {}

    for i in tqdm(range(0, total_candidates, batch_size), desc="Processing Batches"):
        batch_df = df_correct.iloc[i:i + batch_size]

        images, labels, batch_info = [], [], []
        for _, row in batch_df.iterrows():
            try:
                img_tensor = load_image_tensor(row["rel_path"], imdir)
                images.append(img_tensor)
                labels.append(int(row["true_idx"]))
                batch_info.append(row)
            except Exception as e:
                if verbose:
                    print(f"Skipping image {row['rel_path']} due to loading error: {e}")
                candidate_results[row["original_index"]] = {"success": False}

        if not images:
            continue

        try:
            # Stacking is inside the try so a batch with mismatched image sizes is
            # recorded as failed instead of aborting the whole run.
            images_t = torch.stack(images).to(device)
            labels_t = torch.tensor(labels, device=device)
            _, advs_t, success_t = atk(fmodel, images_t, criterion=foolbox.criteria.Misclassification(labels_t), epsilons=epsilons)
        except Exception as e:
            if verbose:
                print(f"Attack failed for a batch: {e}")
            for row in batch_info:
                candidate_results[row["original_index"]] = {"success": False}
            continue
        attempted_count += len(images)

        # With a list of epsilons the outputs carry a leading per-epsilon axis;
        # only one epsilon is used, so take index 0.
        if epsilons is not None:
            advs_t, success_t = advs_t[0], success_t[0]

        for j, row in enumerate(batch_info):
            original_idx = row["original_index"]
            is_successful = bool(success_t[j].item())
            candidate_results[original_idx] = {"success": is_successful}

            if not is_successful:
                continue

            success_count += 1
            adv_cpu = advs_t[j].cpu()
            try:
                rel_clean = str(row["rel_path"]).lstrip("/")
                out_path = out_base / attack / Path(rel_clean).with_suffix(".pt")
                out_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(adv_cpu, out_path)
                saved_count += 1
            except Exception as e:
                if verbose:
                    print(f"Failed to save adv tensor for {row['rel_path']}: {e}")

    df["success"] = df.index.map(lambda idx: candidate_results.get(idx, {}).get("success", False))

    # true_class / pred_class are optional: keep whichever metadata columns exist.
    wanted_cols = ["rel_path", "true_idx", "pred_idx", "true_class", "pred_class", "success"]
    meta_df = df[[c for c in wanted_cols if c in df.columns]]
    meta_csv_path = out_base / f"metadata_{attack}.csv"
    meta_df.to_csv(meta_csv_path, index=False)
    if verbose:
        print(f"Wrote metadata CSV to: {meta_csv_path}")

    # Note: "df"/"cw" run without a budget, so their success rate mainly reflects
    # whether the attack converged; compare perturbation sizes to rank them.
    success_rate_over_candidates = success_count / total_candidates if total_candidates > 0 else 0.0
    success_rate_over_attempts = success_count / attempted_count if attempted_count > 0 else 0.0

    if verbose:
        print(f"Attempted attacks: {attempted_count}/{total_candidates} candidates.")
        print(f"Successes: {success_count}.")
        print(f"Saved adv files: {saved_count}.")
        print(f"Success rate (over candidates): {success_rate_over_candidates:.4f} ({success_count}/{total_candidates})")
        print(f"Success rate (over attempted):  {success_rate_over_attempts:.4f} ({success_count}/{attempted_count if attempted_count > 0 else 0})")

    return success_rate_over_candidates

In [3]:
from torchvision import models, transforms
from torchvision.models import VGG16_BN_Weights

# VGG16-BN backbone with its final classifier layer replaced by a 7-output head.
# run_attack later overwrites these weights from `checkpoint_path`.
model = models.vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
model.avgpool = nn.AdaptiveAvgPool2d((7, 7))
model.classifier[6] = nn.Linear(4096, 7)

# Input resolution: selects the matching checkpoint, dataset and output folders.
img_size = 128
size_tag = f"{img_size}x{img_size}"
checkpoint_path = f"../01-CleanModel/Models/{size_tag}/best_min_acc_vgg16_{size_tag}_Model-2.pth"
csv_path = "../01-CleanModel/Evaluate/correct_in_all_models.csv"
imdir = f"../01-CleanModel/Dataset/{size_tag}"
outdir = f"./generated_images-test/{size_tag}"

In [3]:
# FGSM run (epsilons=[0.03] as configured for "fgsm" in run_attack).
fgsm_success_rate = run_attack(
    model=model,
    checkpoint_path=checkpoint_path,
    attack="fgsm",
    csv_path=csv_path,
    imdir=imdir,
    outdir=outdir,
    batch_size=4,
    verbose=True,
)
fgsm_success_rate

Loading checkpoint: ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth map_location: cuda:0
Loaded checkpoint ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth (strict=False).
CSV: ../01-CleanModel/Evaluate/correct_in_all_models.csv rows=2225; correctly-classified candidates=2225; epsilons=[0.03]


Processing Batches: 100%|████████████████████████████████████████████████████████████| 557/557 [01:42<00:00,  5.43it/s]

Wrote metadata CSV to: generated_images-test\128x128\metadata_fgsm.csv
Attempted attacks: 2225/2225 candidates.
Successes: 1324.
Saved adv files: 1324.
Success rate (over candidates): 0.5951 (1324/2225)
Success rate (over attempted):  0.5951 (1324/2225)





0.5950561797752809

In [4]:
# Basic Iterative Method run (LinfBasicIterativeAttack, epsilons=[0.03]).
bim_success_rate = run_attack(
    model=model,
    checkpoint_path=checkpoint_path,
    attack="bim",
    csv_path=csv_path,
    imdir=imdir,
    outdir=outdir,
    batch_size=4,
    verbose=True,
)
bim_success_rate

Loading checkpoint: ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth map_location: cuda:0
Loaded checkpoint ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth (strict=False).
CSV: ../01-CleanModel/Evaluate/correct_in_all_models.csv rows=2225; correctly-classified candidates=2225; epsilons=[0.03]


Processing Batches: 100%|████████████████████████████████████████████████████████████| 557/557 [03:36<00:00,  2.58it/s]

Wrote metadata CSV to: generated_images-test\128x128\metadata_bim.csv
Attempted attacks: 2225/2225 candidates.
Successes: 1723.
Saved adv files: 1723.
Success rate (over candidates): 0.7744 (1723/2225)
Success rate (over attempted):  0.7744 (1723/2225)





0.7743820224719101

In [5]:
# PGD run (epsilons=[0.03] as configured for "pgd" in run_attack).
pgd_success_rate = run_attack(
    model=model,
    checkpoint_path=checkpoint_path,
    attack="pgd",
    csv_path=csv_path,
    imdir=imdir,
    outdir=outdir,
    batch_size=4,
    verbose=True,
)
pgd_success_rate

Loading checkpoint: ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth map_location: cuda:0
Loaded checkpoint ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth (strict=False).
CSV: ../01-CleanModel/Evaluate/correct_in_all_models.csv rows=2225; correctly-classified candidates=2225; epsilons=[0.03]


Processing Batches: 100%|████████████████████████████████████████████████████████████| 557/557 [11:22<00:00,  1.23s/it]

Wrote metadata CSV to: generated_images-test\128x128\metadata_pgd.csv
Attempted attacks: 2225/2225 candidates.
Successes: 1766.
Saved adv files: 1766.
Success rate (over candidates): 0.7937 (1766/2225)
Success rate (over attempted):  0.7937 (1766/2225)





0.7937078651685393

In [4]:
# DeepFool (L2) run with epsilons=None, i.e. no perturbation budget; smaller batch.
df_success_rate = run_attack(
    model=model,
    checkpoint_path=checkpoint_path,
    attack="df",
    csv_path=csv_path,
    imdir=imdir,
    outdir=outdir,
    batch_size=2,
    verbose=True,
)
df_success_rate

Loading checkpoint: ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth map_location: cuda:0
Loaded checkpoint ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth (strict=False).
CSV: ../01-CleanModel/Evaluate/correct_in_all_models.csv rows=2225; correctly-classified candidates=2225; epsilons=None


Processing Batches: 100%|██████████████████████████████████████████████████████████| 1113/1113 [12:54<00:00,  1.44it/s]

Wrote metadata CSV to: generated_images-test\128x128\metadata_df.csv
Attempted attacks: 2225/2225 candidates.
Successes: 2225.
Saved adv files: 2225.
Success rate (over candidates): 1.0000 (2225/2225)
Success rate (over attempted):  1.0000 (2225/2225)





1.0

In [5]:
# Carlini-Wagner (L2, 1000 steps) run with epsilons=None; this is by far the slowest attack.
cw_success_rate = run_attack(
    model=model,
    checkpoint_path=checkpoint_path,
    attack="cw",
    csv_path=csv_path,
    imdir=imdir,
    outdir=outdir,
    batch_size=4,
    verbose=True,
)
cw_success_rate

Loading checkpoint: ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth map_location: cuda:0
Loaded checkpoint ../01-CleanModel/Models/128x128/best_min_acc_vgg16_128x128_Model-2.pth (strict=False).
CSV: ../01-CleanModel/Evaluate/correct_in_all_models.csv rows=2225; correctly-classified candidates=2225; epsilons=None


Processing Batches: 100%|█████████████████████████████████████████████████████████| 557/557 [11:04:01<00:00, 71.53s/it]

Wrote metadata CSV to: generated_images-test\128x128\metadata_cw.csv
Attempted attacks: 2225/2225 candidates.
Successes: 2225.
Saved adv files: 2225.
Success rate (over candidates): 1.0000 (2225/2225)
Success rate (over attempted):  1.0000 (2225/2225)





1.0