## Install Dependencies

In [None]:
!pip install torch torchvision torchinfo open_clip_torch Pillow grad-cam

## CLIP ViT-B/32 PGD Cross-Entropy Untargeted

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import TensorDataset, DataLoader
from torchinfo import summary
from torchvision import datasets
from torchvision import transforms
from tqdm.auto import tqdm

Загрузим предобученную CLIP модель от OpenAI с ViT-B-32 энкодером изображений.

In [None]:
# Prefer the first CUDA device, but fall back to CPU so the notebook still runs without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
# The model is only used for inference and for input gradients, so keep it in eval mode.
model = model.eval().to(device)
tokenizer = open_clip.get_tokenizer("ViT-B-32")
summary(model)

Отделим нормализацию от остальных преобразований препроцессинга: так возмущения будут накладываться на изображение в исходном диапазоне $[0, 1]$, и adversarial-примеры останутся визуально похожими на оригиналы. Операция нормализации дифференцируема, поэтому ее можно применять внутри PGD-атаки.

In [None]:
# Split the trailing Normalize step off the preprocessing pipeline exactly once.
# The isinstance guard makes the cell idempotent: re-running it no longer strips
# a second transform from `preprocess` or rebinds `normalize_transform` to a
# step that is not a Normalize.
if isinstance(preprocess.transforms[-1], transforms.Normalize):
    normalize_transform = preprocess.transforms[-1]
    preprocess.transforms = preprocess.transforms[:-1]
preprocess, normalize_transform

Загрузим тестовую часть датасета CIFAR-10. Датасет состоит из изображений объектов и их меток: всего 10 классов, изображения низкого разрешения (32x32).

In [None]:
# Number of test images to attack and the batch size used by every DataLoader below.
N = 64
batch_size = 64

cifar10 = datasets.CIFAR10("./data", train=False, download=True)

# Run the first N raw test images through the (normalization-free) preprocessing pipeline.
images = torch.stack(
    [preprocess(Image.fromarray(pixels).convert("RGB")) for pixels in cifar10.data[:N]]
)
# Ground-truth labels for exactly the images that were loaded.
labels = torch.tensor(cifar10.targets[: len(images)])

dataset = TensorDataset(images, labels)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

print(images.shape, labels.shape)

In [None]:
# Invert the dataset's name -> index mapping so predictions can be shown as class names.
cifar10.idx_to_class = {idx: name for name, idx in cifar10.class_to_idx.items()}
cifar10.idx_to_class

In [None]:
# Grid layout reused by the later image-grid cells.
nrows, ncols, figsize = 4, 8, (15, 10)
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, axis in enumerate(axes.ravel()):
    # CHW tensor -> HWC array, the layout imshow expects.
    axis.imshow(images[idx].cpu().permute(1, 2, 0).numpy())
    axis.set_title(cifar10.idx_to_class[labels[idx].item()])
    axis.axis("off")
plt.tight_layout()
plt.show()

Для CLIP в zero-shot классификации будем использовать текстовые названия классов, text_features - 10 текстовых эмбеддингов размера 512, соответствующих названиям классов.

In [None]:
# Tokenize every class name and stack the token tensors into one batch on the model device.
text_inputs = torch.cat(list(map(tokenizer, cifar10.idx_to_class.values()))).to(device)
# Text embeddings are fixed during the attack, so no autograd graph is needed for them.
with torch.no_grad():
    text_features = F.normalize(model.encode_text(text_inputs), dim=-1)
text_features.shape

PGD — white-box атака, которая итеративно делает шаги размера $\alpha$ в направлении знака градиента функции потерь по входным данным; при этом возмущение ограничено $L_\infty$-шаром $S$ радиуса $\epsilon$ вокруг исходного примера.
$$x^{t+1} = \Pi_{x+S}\left(x^t + \alpha \cdot \operatorname{sign}\left(\nabla_x L(\theta, x^t, y)\right)\right)$$

В качестве функции потерь будем оптимизировать классификационную функцию потерь (в нашем случае Cross-Entropy), причем в ненаправленном варианте мы стараемся повысить значение функции потерь между предсказанием и реальными метками.

In [1]:
def pgd_attack(dataloader, text_features, eps=0.1, alpha=1.0 / 255.0, steps=100):
    """Run an untargeted PGD attack against CLIP zero-shot classification.

    Each step moves the perturbation along the sign of the gradient of the
    cross-entropy between the image-text similarity logits and the given
    labels, i.e. it pushes the loss for those labels up.

    Args:
        dataloader: yields ``(images, labels)`` batches; images are expected
            un-normalized, since ``normalize_transform`` is applied here and
            pixels are clamped to ``[0, 1]``.
        text_features: class text embeddings, one row per class.
        eps: bound on the absolute value of every perturbation element.
        alpha: step size of one signed-gradient update.
        steps: number of PGD iterations per batch.

    Returns:
        CPU tensor with the adversarial images of all batches, concatenated
        in dataloader order and clamped to ``[0, 1]``.
    """
    all_adv = []

    for batch_idx, (batch_images, batch_labels) in enumerate(dataloader):
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        delta = torch.zeros_like(batch_images, requires_grad=True)

        pbar = tqdm(range(steps))
        for _ in pbar:
            adv_images = normalize_transform(torch.clamp(batch_images + delta, 0, 1))
            image_features = F.normalize(model.encode_image(adv_images), dim=-1)

            logits = image_features @ text_features.T

            loss = F.cross_entropy(logits, batch_labels)

            # Differentiate w.r.t. delta only: loss.backward() would also compute
            # and accumulate gradients for every model parameter, which the
            # attack never uses.
            (grad,) = torch.autograd.grad(loss, delta)

            # Ascent step followed by projection back onto the eps-box.
            with torch.no_grad():
                delta.add_(alpha * grad.sign()).clamp_(-eps, eps)

            pbar.set_description(
                f"Batch = {batch_idx+1} / {len(dataloader)}, loss = {loss.item():.3f}"
            )

        all_adv.append(torch.clamp(batch_images + delta, 0, 1).detach().cpu())

    return torch.cat(all_adv)

Получим тензор значений для adversarial примеров после PGD атаки с $\alpha = \frac{1}{255}$, $\epsilon = 0.1$ и $t = 100$.

In [None]:
# Attack with the default eps and step count; alpha is one 8-bit intensity level.
adversarial_images = pgd_attack(
    dataloader=dataloader, text_features=text_features, alpha=1 / 255
)
adversarial_images.shape

Функция для расчета классических классификационных метрик - accuracy, precision, recall, f1_score.

In [None]:
def calc_metrics(y_true, y_pred):
    """Return ``(accuracy, precision, recall, f1)`` for the given predictions.

    Precision, recall and F1 are macro-averaged; undefined per-class values
    (division by zero) are counted as 0.
    """
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    accuracy = accuracy_score(y_true, y_pred)
    return accuracy, precision, recall, f1


def print_metrics(real_labels, orig_labels, adv_labels):
    """Print a table comparing metrics on clean vs. adversarial predictions.

    Args:
        real_labels: tensor with ground-truth class indices.
        orig_labels: tensor with predictions for the clean images.
        adv_labels: tensor with predictions for the adversarial images.
    """
    y_true = real_labels.cpu().numpy()

    # Score both prediction sets against the same ground truth.
    clean_scores, adv_scores = (
        calc_metrics(y_true, preds.cpu().numpy()) for preds in (orig_labels, adv_labels)
    )

    print(f"{'Metric':<12} {'Original':>10} {'Adversarial':>12}")
    print("-" * 36)
    metric_names = ("Accuracy", "Precision", "Recall", "F1-Score")
    for row in zip(metric_names, clean_scores, adv_scores):
        name, orig, adv = row
        print(f"{name:<12} {orig:>10.2%} {adv:>12.2%}")

In [None]:
# Image-only datasets; shuffling is off in both loaders so clean and
# adversarial batches stay aligned sample by sample.
orig_dataset, adv_dataset = TensorDataset(images), TensorDataset(adversarial_images)

orig_dataloader, adv_dataloader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (orig_dataset, adv_dataset)
)

Функция для получения предсказаний, ищем наиболее близкий эмбеддинг класса по косинусной схожести к эмбеддингу изображения.

In [None]:
def get_labels(orig_dataloader, adv_dataloader, text_features):
    """Predict zero-shot labels for clean and adversarial images.

    Every image is assigned the class whose text embedding has the highest
    cosine similarity with the image embedding.

    Args:
        orig_dataloader: loader whose batches hold the clean images first.
        adv_dataloader: loader whose batches hold the adversarial images
            first, aligned batch by batch with ``orig_dataloader``.
        text_features: class text embeddings, one row per class.

    Returns:
        ``(orig_labels, adv_labels)``: predicted class indices for the clean
        and the adversarial images.
    """

    def _predict(batch):
        # Apply the same normalization that pgd_attack applies before
        # encode_image, so evaluation uses the input pipeline the attack optimized.
        inputs = normalize_transform(batch.to(device))
        features = F.normalize(model.encode_image(inputs), dim=-1)
        return (features @ text_features.T).argmax(dim=-1)

    orig_labels, adv_labels = [], []
    with torch.no_grad():
        for orig_batch, adv_batch in tqdm(
            zip(orig_dataloader, adv_dataloader), total=len(orig_dataloader)
        ):
            # Each batch is a sequence of tensors; the images come first.
            orig_labels.append(_predict(orig_batch[0]))
            adv_labels.append(_predict(adv_batch[0]))

    # torch.cat matches pgd_attack and also works where torch.concatenate is unavailable.
    return torch.cat(orig_labels, dim=0), torch.cat(adv_labels, dim=0)

In [None]:
# Predict on clean and adversarial images, then compare both against the true labels.
orig_labels, adv_labels = get_labels(
    orig_dataloader=orig_dataloader,
    adv_dataloader=adv_dataloader,
    text_features=text_features,
)
print_metrics(real_labels=labels, orig_labels=orig_labels, adv_labels=adv_labels)

In [None]:
# Adversarial images, each titled with the class predicted for it.
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, axis in enumerate(axes.ravel()):
    predicted_name = cifar10.idx_to_class[adv_labels[idx].item()]
    axis.imshow(adversarial_images[idx].cpu().permute(1, 2, 0).numpy())
    axis.set_title(predicted_name)
    axis.axis("off")
plt.suptitle("Adversarial")
plt.tight_layout()
plt.show()

In [None]:
# Clean images, each titled with the class predicted for it.
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, axis in enumerate(axes.ravel()):
    predicted_name = cifar10.idx_to_class[orig_labels[idx].item()]
    axis.imshow(images[idx].cpu().permute(1, 2, 0).numpy())
    axis.set_title(predicted_name)
    axis.axis("off")
plt.suptitle("Clean")
plt.tight_layout()
plt.show()

Вспомогательные функции для GradCAM++

In [None]:
def make_vit_reshape_transform(model):
    """Build a function that turns ViT token activations into a 2-D feature map.

    The patch grid is derived from ``model.visual.image_size`` (default 224)
    and ``model.visual.patch_size`` (default 32); each may be a scalar or a
    tuple/list of one or two elements.

    The returned function drops the first token, reshapes the remaining tokens
    to ``(batch, h, w, channels)`` and returns a contiguous
    ``(batch, channels, h, w)`` tensor.
    NOTE(review): assumes activations arrive as (batch, tokens, channels) with
    the class token first - confirm for the installed open_clip version.
    """

    def _as_pair(value):
        # Scalars are duplicated; a one-element sequence reuses its only item.
        if isinstance(value, (tuple, list)):
            first = value[0]
            second = value[1] if len(value) > 1 else first
            return first, second
        scalar = int(value)
        return scalar, scalar

    img_h, img_w = _as_pair(getattr(model.visual, "image_size", 224))
    ph, pw = _as_pair(getattr(model.visual, "patch_size", 32))
    grid_h, grid_w = img_h // ph, img_w // pw

    def _reshape_transform(tensor):
        batch, _, channels = tensor.shape
        without_first_token = tensor[:, 1:, :]
        grid = without_first_token.reshape(batch, grid_h, grid_w, channels)
        return grid.permute(0, 3, 1, 2).contiguous()

    return _reshape_transform


class CLIPImageWrapper(nn.Module):
    """Expose only the image encoder of a CLIP model as a module's forward pass."""

    def __init__(self, clip_model):
        super().__init__()
        # Keep a reference to the full model; only encode_image is used in forward.
        self.clip = clip_model

    def forward(self, x):
        """Return ``clip_model.encode_image(x)``."""
        image_embeddings = self.clip.encode_image(x)
        return image_embeddings


class CLIPTextTarget:
    """Callable target: similarity of image features to fixed text features.

    Both sides are divided by their L2 norm along the last dimension, so the
    returned values are cosine similarities.
    """

    def __init__(self, text_features):
        text_norms = text_features.norm(dim=-1, keepdim=True)
        self.text_features = text_features / text_norms

    def __call__(self, image_features):
        image_norms = image_features.norm(dim=-1, keepdim=True)
        unit_image_features = image_features / image_norms
        return unit_image_features @ self.text_features.T


def run_clip_gradcam(gradcam, image_tensor, text_features):
    """Compute the CAM of ``image_tensor`` with respect to ``text_features``.

    Puts the global ``model`` into eval mode, casts its visual tower (if
    present) to float32, and returns the first map produced by ``gradcam``.
    """
    model.eval()
    if hasattr(model, "visual"):
        model.visual.float()

    cam_batch = gradcam(
        input_tensor=image_tensor,
        targets=[CLIPTextTarget(text_features)],
        eigen_smooth=True,
    )
    # Only the first map of the returned batch is used.
    return cam_batch[0]


def visualize_gradcam(gradcam, images, adv_images, labels, n):
    """Plot clean/adversarial images next to their CAM heatmaps.

    For each of the first ``n`` samples, one row is drawn with four panels:
    clean image, clean heatmap, adversarial image, adversarial heatmap. The
    heatmaps are computed against the text embedding of the sample's label,
    taken from the global ``text_features``.

    Args:
        gradcam: CAM object passed through to ``run_clip_gradcam``.
        images: clean images, un-normalized (they are shown with imshow as-is).
        adv_images: adversarial images, aligned with ``images``.
        labels: tensor with the class index of every sample.
        n: number of samples (rows) to draw.
    """
    # squeeze=False keeps `ax` two-dimensional, so ax[i][j] also works for n == 1.
    _, ax = plt.subplots(nrows=n, ncols=4, figsize=(12, 3 * n), squeeze=False)

    for i in range(n):
        class_idx = labels[i].item()
        class_text = text_features[class_idx : class_idx + 1]

        # Normalize before the model sees the image, exactly as pgd_attack does
        # before encode_image; the raw tensors are kept for display only.
        mask_clean_i = run_clip_gradcam(
            gradcam, normalize_transform(images[i : i + 1]), class_text
        )
        mask_adv_i = run_clip_gradcam(
            gradcam, normalize_transform(adv_images[i : i + 1]), class_text
        )

        img = images[i].cpu().permute(1, 2, 0).numpy()
        adv_img = adv_images[i].cpu().permute(1, 2, 0).numpy()

        ax[i][0].imshow(img)
        ax[i][0].set_xlabel(
            f"Clean image with label = {cifar10.idx_to_class[class_idx]}"
        )

        ax[i][1].imshow(mask_clean_i, cmap="jet")
        ax[i][1].set_xlabel("Clean GradCAM+")

        ax[i][2].imshow(adv_img)
        ax[i][2].set_xlabel("Adversarial image")

        ax[i][3].imshow(mask_adv_i, cmap="jet")
        ax[i][3].set_xlabel("Adversarial GradCAM+")

    plt.tight_layout()
    plt.show()

Отрисуем GradCAM++ heatmap для слоя LayerNorm из последнего блока ViT.

GradCAM вычисляет важность каждой активации в заданном слое для предсказания определенного класса, используя в качестве весов градиенты целевой оценки (score) этого класса по этим активациям.

In [None]:
# Token activations are reshaped into a patch grid before the CAM is computed.
reshape_transform = make_vit_reshape_transform(model)
# First LayerNorm of the last transformer block of the visual tower.
target_layer = model.visual.transformer.resblocks[-1].ln_1

gradcam = GradCAMPlusPlus(
    model=CLIPImageWrapper(model),
    reshape_transform=reshape_transform,
    target_layers=[target_layer],
)

In [None]:
# Heatmaps for the first 10 samples, computed against their true labels.
visualize_gradcam(
    gradcam, images=images, adv_images=adversarial_images, labels=labels, n=10
)

## CLIP ViT-B/32 PGD Cross-Entropy Targeted

Добавим новый, 11-й класс, который будем использовать как целевой для атаки. Назовем его так, чтобы на чистых примерах модель почти никогда его не предсказывала.

In [None]:
# Register the artificial attack target as class index 10, then rebuild the inverse mapping.
cifar10.class_to_idx.update({"<UNKNOWN_TARGET>": 10})
cifar10.idx_to_class = dict(
    (idx, name) for name, idx in cifar10.class_to_idx.items()
)
cifar10.idx_to_class

Теперь скажем, что целевые метки - это наш новый класс.

In [None]:
# Every image gets the same target: the newly added class with index 10.
target_labels = torch.full((len(images),), 10, dtype=torch.long)

dataset = TensorDataset(images, target_labels)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

images.shape, target_labels.shape

In [None]:
# Re-tokenize the class names, which now include the extra target class.
text_inputs = torch.cat(list(map(tokenizer, cifar10.idx_to_class.values()))).to(device)

# Text embeddings stay fixed during the attack, so no autograd graph is needed.
with torch.no_grad():
    text_features = F.normalize(model.encode_text(text_inputs), dim=-1)

text_features.shape

Модифицируем функцию потерь для PGD: теперь мы хотим минимизировать Cross-Entropy относительно целевого (заведомо неверного) класса. Для этого просто добавим знак «-» перед Cross-Entropy, чтобы изменить направление оптимизации.

In [None]:
def pgd_attack(dataloader, text_features, eps=0.1, alpha=1.0 / 255.0, steps=100):
    """Run a targeted PGD attack against CLIP zero-shot classification.

    NOTE: this definition shadows the untargeted ``pgd_attack`` defined earlier
    in the notebook. The labels yielded by ``dataloader`` are treated as attack
    targets: the optimized loss is the negated cross-entropy, so each
    signed-gradient ascent step lowers the cross-entropy with respect to them.

    Args:
        dataloader: yields ``(images, target_labels)`` batches; images are
            expected un-normalized, since ``normalize_transform`` is applied
            here and pixels are clamped to ``[0, 1]``.
        text_features: class text embeddings, one row per class.
        eps: bound on the absolute value of every perturbation element.
        alpha: step size of one signed-gradient update.
        steps: number of PGD iterations per batch.

    Returns:
        CPU tensor with the adversarial images of all batches, concatenated
        in dataloader order and clamped to ``[0, 1]``.
    """
    all_adv = []

    for batch_idx, (batch_images, batch_labels) in enumerate(dataloader):
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        delta = torch.zeros_like(batch_images, requires_grad=True)

        pbar = tqdm(range(steps))
        for _ in pbar:
            adv_images = normalize_transform(torch.clamp(batch_images + delta, 0, 1))
            image_features = F.normalize(model.encode_image(adv_images), dim=-1)

            logits = image_features @ text_features.T

            # Negated so that the ascent step below moves towards the target class.
            loss = -1.0 * F.cross_entropy(logits, batch_labels)

            # Differentiate w.r.t. delta only: loss.backward() would also compute
            # and accumulate gradients for every model parameter, which the
            # attack never uses.
            (grad,) = torch.autograd.grad(loss, delta)

            # Ascent step followed by projection back onto the eps-box.
            with torch.no_grad():
                delta.add_(alpha * grad.sign()).clamp_(-eps, eps)

            pbar.set_description(
                f"Batch = {batch_idx+1} / {len(dataloader)}, loss = {loss.item():.3f}"
            )

        all_adv.append(torch.clamp(batch_images + delta, 0, 1).detach().cpu())

    return torch.cat(all_adv)

In [None]:
# Targeted attack with the default eps and step count; alpha is one 8-bit intensity level.
adversarial_images = pgd_attack(
    dataloader=dataloader, text_features=text_features, alpha=1 / 255
)
adversarial_images.shape

In [None]:
# Wrap the new adversarial images; no shuffling keeps them aligned with the clean loader.
adv_dataset = TensorDataset(adversarial_images)
adv_dataloader = DataLoader(dataset=adv_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Predict with the extended class list, then score both prediction sets against the true labels.
orig_labels, adv_labels = get_labels(
    orig_dataloader=orig_dataloader,
    adv_dataloader=adv_dataloader,
    text_features=text_features,
)
print_metrics(real_labels=labels, orig_labels=orig_labels, adv_labels=adv_labels)

In [None]:
# Adversarial images from the targeted attack, titled with their predicted class.
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, axis in enumerate(axes.ravel()):
    predicted_name = cifar10.idx_to_class[adv_labels[idx].item()]
    axis.imshow(adversarial_images[idx].cpu().permute(1, 2, 0).numpy())
    axis.set_title(predicted_name)
    axis.axis("off")
plt.suptitle("Adversarial")
plt.tight_layout()
plt.show()

In [None]:
# Clean images, titled with the class predicted from the extended class list.
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, axis in enumerate(axes.ravel()):
    predicted_name = cifar10.idx_to_class[orig_labels[idx].item()]
    axis.imshow(images[idx].cpu().permute(1, 2, 0).numpy())
    axis.set_title(predicted_name)
    axis.axis("off")
plt.suptitle("Clean")
plt.tight_layout()
plt.show()

In [None]:
# Heatmaps for the first 10 samples of the targeted attack, against their true labels.
visualize_gradcam(
    gradcam, images=images, adv_images=adversarial_images, labels=labels, n=10
)