# Adversarial attack demo: detector P(fake) on clean vs. FGSM/PGD-perturbed images

In [1]:

from utils.dataset import get_dataloader, BatchedImageIterable
from utils.nets import load_network
from utils.adversarial_attacks import fgsm_attack, pgd_attack
from utils.vulnerability_map import get_vulnerability_map, visualize_vulnerability_map, plot_triple_res
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt


In [2]:

# Config: compute device, checkpoint location, data locations, batch size.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

weights_dir = './weights/weights_AdversarialRobustnessCLIP'

# Inpainted images, real images, and the masks paired with the inpainted set.
images_inpainted, images_real, masks_dir = (
    './data/COCO_inpainted',
    './data/COCO_real',
    './data/masks',
)
batch_size = 1


In [None]:

# Build the detector (moved to `device`, switched to eval mode) and two sample
# sources: inpainted images with their masks, and real images. Both use the
# detector's own `preprocess` transform.
detector = load_network('OJHA_latent_clip', weights_dir).to(device).eval()

# Architecture flag forwarded to get_vulnerability_map. It compares the exact
# class name, so a subclass of ResNet would NOT be detected.
is_resnet = detector.__class__.__name__ == 'ResNet'

inpainted_loader = get_dataloader(
    images_inpainted,
    masks_dir,
    model=detector,
    transform_img=detector.preprocess,
    batch_size=batch_size,
    num_workers=1,
    shuffle=True,
)

# The iterable takes batch_size itself (presumably yielding ready-made
# batches), so the DataLoader is told not to re-batch (batch_size=None).
real_dataset = BatchedImageIterable(
    images_dir=images_real,
    transform_img=detector.preprocess,
    shuffle=True,
    batch_size=batch_size,
)
real_loader = DataLoader(real_dataset, num_workers=1, batch_size=None)


In [4]:

# Helpers

def classifier_probs(images: torch.Tensor):
    """Run the detector on *images* and return its logits and P(fake).

    Args:
        images: input batch; moved to the global ``device`` before the
            forward pass.

    Returns:
        ``(logits, prob_fake)``:
        - ``logits`` is the raw detector output, still attached to the
          autograd graph (it is not detached).
        - ``prob_fake`` is a detached CPU tensor with one probability per
          sample.

    For a two-column output, softmax is applied and column 1 is taken.
    Column 1 is the "fake" class, i.e. label 1.

    For a single-logit output of shape (B,) or (B, 1), a sigmoid is used.
    Indexing column 1 of such an output would otherwise raise IndexError.
    """
    logits = detector(images.to(device))
    if logits.dim() == 1 or logits.size(1) == 1:
        # NOTE(review): assumes a single logit scores the "fake" class (label 1).
        # Confirm this for any single-output detector used here.
        prob_fake = torch.sigmoid(logits).reshape(-1)
    else:
        prob_fake = torch.softmax(logits, dim=1)[:, 1]
    return logits, prob_fake.detach().cpu()


def classify_with_attacks(images: torch.Tensor, label: int, fgsm_eps=4/255, pgd_eps=8/255, pgd_alpha=2/255, pgd_steps=10):
    """Attack *images* with FGSM and PGD and print P(fake) before and after.

    Every sample in the batch gets the same ground-truth ``label``. Both
    attacks receive the clean batch's global (min, max) value as
    ``post_clamp``. PGD is run with a random start.

    Returns:
        ``(adv_fgsm, adv_pgd)``: the FGSM and PGD adversarial batches.
    """
    clean = images.to(device)
    n_samples = clean.size(0)
    targets = torch.full(size=(n_samples,), fill_value=label, dtype=torch.long, device=device)
    lo, hi = float(clean.min()), float(clean.max())
    value_range = (lo, hi)

    # Order matters: the PGD random start consumes RNG state after FGSM.
    clean_prob = classifier_probs(clean)[1]

    adv_fgsm = fgsm_attack(detector, clean, targets, eps=fgsm_eps, post_clamp=value_range)
    fgsm_prob = classifier_probs(adv_fgsm)[1]

    pgd_kwargs = dict(
        eps=pgd_eps,
        alpha=pgd_alpha,
        steps=pgd_steps,
        random_start=True,
        post_clamp=value_range,
    )
    adv_pgd = pgd_attack(detector, clean, targets, **pgd_kwargs)
    pgd_prob = classifier_probs(adv_pgd)[1]

    # Left-pad tags to width 5 so the three report lines align.
    for tag, prob in (("clean", clean_prob), ("fgsm", fgsm_prob), ("pgd", pgd_prob)):
        print(f"{tag:<5} prob(fake): {prob.tolist()}")

    return adv_fgsm, adv_pgd


def show_vulnerability(image: torch.Tensor, mask: torch.Tensor, original: torch.Tensor, title: str = "clean"):
    """Plot the vulnerability map for *image*, then print the detector's P(fake).

    ``get_vulnerability_map`` computes the map, which is overlaid on
    ``original``. The three panels are drawn with ``plot_triple_res``. The
    printed line is prefixed with ``[title]``.
    """
    vuln_map, _ = get_vulnerability_map(image, mask, detector, is_resnet=is_resnet, device=device)
    plot_triple_res(original, visualize_vulnerability_map(vuln_map, original), mask)

    prob_fake = classifier_probs(image)[1]
    print(f"[{title}] prob(fake): {prob_fake.tolist()}")


In [None]:

# Inpainted sample (label=1, fake). Steps: take one batch, show its
# vulnerability map, then compare P(fake) on clean vs. FGSM/PGD inputs.
inpaint_image, inpaint_mask, inpaint_orig = next(iter(inpainted_loader))
print('inpainted batch -> image: {}, mask: {}'.format(
    tuple(inpaint_image.shape),
    tuple(inpaint_mask.shape),
))
show_vulnerability(inpaint_image, inpaint_mask, inpaint_orig, title='inpainted clean')
_ = classify_with_attacks(inpaint_image, label=1)


In [None]:

# Real sample (label=0, real). Pair it with an all-zero, single-channel mask
# that matches the batch size and spatial size (dims 2 and 3) of the images.
real_images, real_originals = next(iter(real_loader))
real_mask = torch.zeros(real_images.size(0), 1, real_images.size(2), real_images.size(3))
print('real batch -> image: {}'.format(tuple(real_images.shape)))
show_vulnerability(real_images, real_mask, real_originals, title='real clean')
_ = classify_with_attacks(real_images, label=0)
