In [1]:
import torch
from tqdm.autonotebook import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from lib.datasets import (
    get_cifar10, get_imagenet, get_mnist,
    get_cifar10_networks, get_imagenet_networks, get_mnist_networks,
    Task,
)
from lib.utils import *
from lib.attack import Attack, FgsmAttack

%load_ext autoreload
%autoreload 2

  from tqdm.autonotebook import tqdm


[2024-01-06 22:34:47,050] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Images folder: /scratch/diego/semester-proj/images


# Attacking classifiers

In [2]:
# One Task per benchmark: the dataset loader's outputs followed by the
# matching network loader's outputs, tagged with a display name.
tasks = [
    Task(*load_dataset(), *load_networks(), name=task_name)
    for task_name, load_dataset, load_networks in (
        ("MNIST", get_mnist, get_mnist_networks),
        ("CIFAR10", get_cifar10, get_cifar10_networks),
        ("ImageNet", get_imagenet, get_imagenet_networks),
    )
]



### Inspecting the datasets

In [None]:
for task in tasks:
    # Probability the classifier assigns to the true label of each sample image.
    # Only the numbers are needed, so skip building the autograd graph.
    with torch.inference_mode():
        confidence = (
            task.classifier(task.sample_images)
            .softmax(-1)
            .gather(1, task.sample_labels[:, None])
            .squeeze(1)
        )
    # Shorten labels longer than `limit` characters so that they take exactly
    # `limit` characters, the trailing "..." included
    limit = 16
    labels = [f"{label[:limit-3]}..." if len(label) > limit else label for label in task.sample_labels_str]
    labels = [f"{label} ({conf:.1%})" for label, conf in zip(labels, confidence.tolist())]
    show(task.sample_images, labels,
         width=400, height=400,
         coloraxis_showscale=False,
         font=dict(size=15),
         save=f"sample_{task.name}")

### Computing the test accuracy

In [None]:
# Top-1 accuracy of every classifier over its whole dataset, in batches of 100.
for task in tasks:
    n_correct, n_seen = 0, 0
    with torch.inference_mode():
        for batch_dict in tqdm(task.dataset.iter(batch_size=100)):
            preprocessed = [task.preprocess(image) for image in batch_dict['image']]
            images = torch.stack(preprocessed).cuda()
            labels = torch.tensor(batch_dict['label']).cuda()
            hits = task.classifier(images).argmax(-1) == labels
            n_correct += hits.sum()
            n_seen += len(labels)

    print(f"{task.name}: {n_correct / n_seen:.1%}")

## Fast Gradient Sign Method (FGSM)

### Attacking 1000 images

In [None]:
epsilon = 10 / 255


def _true_label_confidence(classifier, inputs, targets):
    """Softmax probability that `classifier` gives to `targets[i]` for each `inputs[i]`."""
    probabilities = classifier(inputs).softmax(-1)
    return probabilities.gather(1, targets[:, None]).squeeze(1)


# One entry per task in each of these lists, in the order of `tasks`.
original_confidences, adversarial_confidences = [], []
all_attacks, all_labels = [], []
for task in tasks:
    print(task.name)
    images, labels = next(task.iter(1000))
    all_labels.append(labels)

    # One attack per image, pushing the prediction away from the true label.
    # The second constructor argument is None here; the autoencoder cells pass task.vae there.
    attacks = []
    for image, label in tqdm(zip(images, labels), total=len(images)):
        attack = FgsmAttack(image, None, task.classifier)
        attack.train(away_from_label=label, lr=epsilon)
        attacks.append(attack)
    all_attacks.append(attacks)

    # True-label confidence before and after the attack
    with torch.inference_mode():
        original_confidences.append(
            _true_label_confidence(task.classifier, images, labels)
        )
        adversarial_images = torch.stack([attack.adversarial_image for attack in attacks])
        adversarial_confidences.append(
            _true_label_confidence(task.classifier, adversarial_images, labels)
        )

### Plot confidence drop

In [None]:
# One scatter panel per task: true-label confidence on the clean image (x)
# against true-label confidence on the adversarial image (y).
fig = make_subplots(rows=1, cols=len(tasks),
                    subplot_titles=[task.name for task in tasks],
                    shared_xaxes=True, shared_yaxes=True,)
for col, (task, orig, adv) in enumerate(
    zip(tasks, original_confidences, adversarial_confidences), start=1
):
    fig.add_trace(go.Scatter(
        x=orig.cpu().detach().float(),
        y=adv.cpu().detach().float(),
        mode='markers',
        name=task.name,
        marker=dict(
            size=5,
            showscale=False,
            opacity=0.4,
        ),
    ), row=1, col=col)

fig.update_layout(
    font=dict(size=15),
    width=900,
    height=300,
    margin=dict(l=10, r=10, b=10, t=30),
)
# Axis titles: the plotted values are per-image confidences, not accuracies.
# The x title goes on the middle panel, the y title on the first one.
fig.update_xaxes(title_text="Confidence on clean images", row=1, col=(len(tasks) + 1) // 2)
fig.update_yaxes(title_text="Confidence on adversarial images", row=1, col=1)

# Dashed x=y reference in every panel
for col in range(len(tasks)):
    add_line(fig, "x=y", line=dict(color="black", dash="dash"),
             row=1, col=col+1)

# Hide the legend
fig.update_layout(showlegend=False)

fig.show()
fig.write_image(IMAGES_FOLDER / "fgsm_strength.png", scale=3)

### Show individual examples

In [None]:
# For every task, show the first attacked image at three perturbation strengths
# together with the classifier's top-5 predictions for each of them.
for task, task_attacks in zip(tasks, all_attacks):
    attack = task_attacks[0]

    fig = make_subplots(rows=2, cols=3,
                        subplot_titles=[
                            "Original image",
                            "Adversarial image (eps=4/255)",
                            "Adversarial image (eps=10/255)",
                            "Top-5 categories for original image",
                            "Top-5 categories for eps=4/255",
                            "Top-5 categories for eps=10/255",
                        ],
                        vertical_spacing=0.1,
                        )

    # Clean image, a weaker perturbation in the same direction, and the stored adversarial image
    panel_images = [
        attack.start,
        attack.start + attack.perturbation.sign() * 4 / 255,
        attack.adversarial_image,
    ]

    # A dedicated column index: the original reused the outer loop variable `i` here.
    for col, image in enumerate(panel_images, start=1):
        # NOTE(review): other cells pass the result of to_plotly(...) straight to
        # fig.add_trace, whereas here it is wrapped in go.Image(z=...) — confirm
        # which form the current lib.utils.to_plotly expects.
        fig.add_trace(go.Image(z=to_plotly(image)), row=1, col=col)
        fig.add_trace(mk_top5_trace(image, task), row=2, col=col)

    # Hide axis for the images
    fig.update_xaxes(showticklabels=False, row=1)
    fig.update_yaxes(showticklabels=False, row=1)
    fig.update_yaxes(range=[0, 1], row=2)

    fig.update_layout(
        font=dict(size=15),
        width=900,
        height=500,
        margin=dict(l=10, r=10, b=10, t=30),
        showlegend=False,
    )

    fig.show()
    fig.write_image(IMAGES_FOLDER / f"fgsm_example_{task.name}.png")

### Compute accuracy under attack

In [None]:
with torch.inference_mode():
    # Raw strings for the LaTeX labels: "\e" is not a valid escape sequence in a
    # regular string literal. The first label keeps its real tab characters.
    rows = ["Clean accuracy\t\t", r"Accuracy $\epsilon=4/255$", r"Accuracy $\epsilon=10/255$"]
    epsilons = [0, 4/255, 10/255]

    # Output directly the latex table: one line per perturbation strength,
    # one cell per task.
    for row, eps in zip(rows, epsilons):
        print(row, end="\t& ")
        for task, attacks, labels in zip(tasks, all_attacks, all_labels):
            images = torch.stack([attack.start for attack in attacks])
            perturbations = torch.stack([attack.perturbation for attack in attacks])
            correct = 0
            for image, perturbation, label in batch(images, perturbations, labels, shuffle=False, drop_last=False):
                # Rescale the stored perturbation to the requested strength and
                # keep the result a valid image
                attacked = (image + perturbation * eps).clamp(0, 1)
                predictions = task.classifier(attacked).argmax(-1)
                correct += (predictions == label).sum()
            total = len(labels)
            print(f"{correct / total:.1%}".replace("%", r"\,\%"), end="\t& ")
        print(r"\\")

## Iterated Projected Gradient Descent

# Attacking Autoencoders

In [5]:
# Tasks for the autoencoder experiments. This rebinds `tasks`, replacing the
# three-task list used in the classifier section (CIFAR10 is disabled here).
# NOTE(review): `all_attacks` / `all_labels` built in the FGSM section hold one
# entry per task of the earlier list (MNIST, CIFAR10, ImageNet), so they no
# longer line up index-by-index with this list — confirm how later cells pair them.
tasks = [
    Task(*get_mnist(), *get_mnist_networks(4), name="MNIST"),
    # Task(*get_cifar10(), *get_cifar10_networks(), name="CIFAR10"),
    Task(*get_imagenet(), *get_imagenet_networks(), name="ImageNet"),
]



### Autoencoder size and per-pixel error

In [11]:
from lib.utils import get_activations

# For every task, report the autoencoder's parameter count, its smallest
# activation size and its mean absolute reconstruction error per tensor element.
for task in tasks:
    print(task.name)
    n_parameters = sum(p.numel() for p in task.vae.parameters())
    print("Parameters:", n_parameters)
    # The bottleneck is taken to be the smallest of the reported activation sizes
    bottleneck_size = min(get_activations(task, task.sample_images[0], 'size'))
    print("Bottleneck size:", bottleneck_size)
    # Accumulate the absolute error and the element count over the whole dataset
    abs_error_sum, n_values = 0, 0
    with torch.inference_mode():
        for batch_dict in tqdm(task.dataset.iter(batch_size=10)):
            preprocessed = [task.preprocess(image) for image in batch_dict['image']]
            images = torch.stack(preprocessed).cuda()
            residual = images - task.vae(images)
            abs_error_sum += residual.abs().sum().item()
            n_values += images.numel()
    print("Average per-pixel error:", abs_error_sum / n_values)

MNIST
Parameters: 1639673
Bottleneck size: 8


1000it [00:01, 598.92it/s]


Average per-pixel error: 0.07281223453599579
ImageNet
Parameters: 83653863
Bottleneck size: 4096


1000it [03:39,  4.55it/s]

Average per-pixel error: 0.03313529102603594





### Show images and attacks through autoencoder

In [None]:
from matplotlib import cm

# `all_attacks` holds one list of attacks per task of the list that was in use
# when the FGSM section ran (MNIST, CIFAR10, ImageNet). `tasks` has been rebound
# to a shorter list since, so pairing the two positionally would show CIFAR10
# attacks through the ImageNet autoencoder. Pair them by task name instead; if
# the FGSM cell was re-run with the current list, the lengths match and the
# positional pairing is already right.
if len(all_attacks) == len(tasks):
    fgsm_task_names = [task.name for task in tasks]
else:
    fgsm_task_names = ["MNIST", "CIFAR10", "ImageNet"]
attacks_by_task = dict(zip(fgsm_task_names, all_attacks))

# 2 images for each task, 4 panels: (orig, orig+vae, adv 10eps, adv 10eps+vae)
fig = make_subplots(rows=2 * len(tasks), cols=4,
    column_titles=["Original image", "Original through VAE", "Adversarial image", "Adversarial through VAE"],
    row_titles=[task.name for task in tasks for _ in range(2)],
    vertical_spacing=0.01,
    horizontal_spacing=0.01,
)

row = 1
with torch.inference_mode():
    for task in tasks:
        for attack in attacks_by_task[task.name][:2]:
            fig.add_trace(to_plotly(attack.start), row=row, col=1)
            fig.add_trace(to_plotly(task.vae(attack.start[None])[0]), row=row, col=2)
            fig.add_trace(to_plotly(attack.adversarial_image), row=row, col=3)
            fig.add_trace(to_plotly(task.vae(attack.adversarial_image[None])[0]), row=row, col=4)
            row += 1

fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)

fig.update_layout(
    font=dict(size=15),
    width=700,
    height=700,
    margin=dict(l=10, r=10, b=10, t=30),
    showlegend=False,
    coloraxis_showscale=False,
)

save(fig, "ae_examples")


### Compute accuracy under attack, with encoder defense

In [None]:
# Produce a LaTeX table with two non-label columns: accuracy without the VAE
# and accuracy with the VAE in front of the classifier.
# On the left, one group per task, each subdivided into three rows:
# - Clean accuracy, Accuracy epsilon=4/255, Accuracy epsilon=10/255

print(r"""
  \begin{tabular}{c|r|cc}
    & & \multicolumn{2}{c}{Autoencoder} \\
    & & Without & In front \\""")

# `all_attacks` / `all_labels` hold one entry per task of the list that was in
# use when the FGSM section ran (MNIST, CIFAR10, ImageNet). `tasks` has been
# rebound to a shorter list since, so pair them by task name rather than by
# position; if the FGSM cell was re-run with the current list, the lengths
# match and the positional pairing is already right.
if len(all_attacks) == len(tasks):
    fgsm_task_names = [task.name for task in tasks]
else:
    fgsm_task_names = ["MNIST", "CIFAR10", "ImageNet"]
attacks_by_task = dict(zip(fgsm_task_names, all_attacks))
labels_by_task = dict(zip(fgsm_task_names, all_labels))

with torch.inference_mode():
    # Raw strings for the LaTeX labels: "\e" is not a valid escape sequence in a
    # regular string literal. The first label keeps its real tab characters.
    rows = ["Clean accuracy\t\t", r"Accuracy $\epsilon=4/255$", r"Accuracy $\epsilon=10/255$"]
    epsilons = [0, 4/255, 10/255]

    for task in tasks:
        task_attacks = attacks_by_task[task.name]
        labels = labels_by_task[task.name]
        print(r"    \hline")
        print(r"    \multirow{3}{*}{" + task.name + "}")
        images = torch.stack([attack.start for attack in task_attacks])
        perturbations = torch.stack([attack.perturbation for attack in task_attacks])
        # Bare classifier, then the classifier with the autoencoder in front of it
        networks = (task.classifier, torch.nn.Sequential(task.vae, task.classifier))
        for row, eps in zip(rows, epsilons):
            print("    & ", row, end="\t")
            # The attacked inputs are the same for both networks
            attacked_images = (images + perturbations * eps).clamp(0, 1)
            for network in networks:
                correct = 0
                for attacked, label in batch(attacked_images, labels, shuffle=False, drop_last=False):
                    predictions = network(attacked).argmax(-1)
                    correct += (predictions == label).sum()
                total = len(labels)
                print(f"& {correct / total:.1%}".replace("%", r"\,\%"), end="\t")
            print(r"\\")
    print(r"  \end{tabular}")


### Attack through the autoencoder, samples

In [None]:
# Plot with 2 rows per task and 3 columns
# each row: input image, image through vae, top-5 categories
# For each task: a row for the base image, then a row for the adversarial image
from lib.utils import mk_top5_trace


fig = make_subplots(rows=2  * len(tasks), cols=3,
                    column_titles=["Input image", "Input through VAE", "Top-5 predictions"],
                    row_titles=["Original image", "Adversarial image"] * len(tasks),
                    vertical_spacing=0.01,
                    horizontal_spacing=0.01,
                    )

epsilon = 10 / 255
# Next subplot row to fill (1-based)
row = 1
for task in tasks:
    # A batch of size 1, so this inner loop runs once per task
    for image, label in zip(*next(task.iter(batch_size=1))):
        fig.add_trace(to_plotly(image), row=row, col=1)
        with torch.no_grad():
            through_vae = task.vae(image[None])[0]
        fig.add_trace(to_plotly(through_vae), row=row, col=2)
        # Top-5 predictions are computed on the reconstruction, not on the raw input
        fig.add_trace(mk_top5_trace(through_vae, task), row=row, col=3)
        row += 1

        # Add the adversarial image
        # The autoencoder is passed as second argument here — presumably so that
        # the attack is optimised through the VAE; confirm in lib.attack.
        attack = FgsmAttack(image, task.vae, task.classifier)
        attack.train(
            away_from_label=label,
            lr=epsilon / 10,
            num_steps=1000,
            # Stop once both the mean perturbation and the loss are below 0.01
            early_stop=lambda env: env["self"].perturbation.mean() < 0.01 and env["loss"] < 0.01,
            l2_penalty=0.1,
            batch_size=10 if task.name == "MNIST" else 1,
        )

        fig.add_trace(to_plotly(attack.adversarial_image), row=row, col=1)
        with torch.no_grad():
            through_vae = task.vae(attack.adversarial_image[None])[0]
        fig.add_trace(to_plotly(through_vae), row=row, col=2)
        fig.add_trace(mk_top5_trace(through_vae, task), row=row, col=3)
        row += 1

fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
# NOTE(review): no properties are passed in the next call, so it appears to change nothing
fig.update_xaxes(col=3)
# Probabilities in the top-5 column share the [0, 1] range
fig.update_yaxes(col=3, range=[0, 1], showticklabels=False)

fig.update_layout(
    font=dict(size=15),
    width=700,
    height=700,
    font_size=15,
    margin=dict(l=10, r=10, b=10, t=30),
    showlegend=False,
    coloraxis_showscale=False,
)

save(fig, "ae_attack_examples")

Loss: 0.00025: 100%|██████████| 1000/1000 [00:04<00:00, 247.82it/s]
Loss: 0.00006:   0%|          | 3/1000 [00:00<02:44,  6.06it/s]


Saved to /scratch/diego/semester-proj/images/ae_attack_examples.png


PosixPath('/scratch/diego/semester-proj/images/ae_attack_examples.png')

### Showing many attacks through the autoencoder

In [None]:
# Show many attacks
# Each grid row holds `pairs_per_row` (adversarial image, adversarial image through VAE) pairs.
# `rows_per_task` rows per task, i.e. rows_per_task * pairs_per_row attacks per task.

rows_per_task = 3
pairs_per_row = 3

fig = make_subplots(rows=rows_per_task * len(tasks), cols=2 * pairs_per_row,
                    column_titles=["Image", "Image through VAE"] * pairs_per_row,
                    row_titles=[task.name for task in tasks for _ in range(rows_per_task)],
                    vertical_spacing=0.01,
                    horizontal_spacing=0.01,
                    )

epsilon = 10 / 255
# Number of pairs plotted so far, across all tasks
n_pairs = 0
for task in tasks:
    for image, label in task.iter(batch_size=1):
        image = image.squeeze(0)
        label = label.item()

        # Check if the image is correctly classified
        with torch.inference_mode():
            prediction = task.classifier(image[None]).argmax(-1).item()
            if prediction != label:
                print("Image misclassified")
                continue

        # Perform the attack
        attack = FgsmAttack(image, task.vae, task.classifier)
        attack.train(
            away_from_label=label,
            lr=epsilon / 10,
            num_steps=1000,
            # Stop once both the mean perturbation and the loss are below 0.01
            early_stop=lambda env: env["self"].perturbation.mean() < 0.01 and env["loss"] < 0.01,
            l2_penalty=0.1,
            batch_size=10 if task.name == "MNIST" else 1,
        )

        # Check whether the attack was successful
        with torch.no_grad():
            through_vae = task.vae(attack.adversarial_image[None])[0]
            prediction = task.classifier(through_vae[None]).argmax(-1).item()
            if prediction == label:
                print(f"Attack failed {prediction} == {label}")
                continue

        grid_row, pair_in_row = divmod(n_pairs, pairs_per_row)
        col = 2 * pair_in_row + 1
        fig.add_trace(to_plotly(attack.adversarial_image), row=grid_row + 1, col=col)
        fig.add_trace(to_plotly(through_vae), row=grid_row + 1, col=col + 1)

        n_pairs += 1
        # Move on to the next task once this task's rows are full. The original
        # test `img % (rows_per_task * 3)` only gave the right count for odd
        # `rows_per_task`, because `img` advanced by two per pair.
        if n_pairs % (rows_per_task * pairs_per_row) == 0:
            break


fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(
    font=dict(size=15),
    width=900,
    height=700,
    margin=dict(l=10, r=10, b=10, t=30),
    showlegend=False,
    coloraxis_showscale=False,
)

save(fig, "ae_many_attack_examples")


Loss: 0.00025: 100%|██████████| 1000/1000 [00:04<00:00, 244.35it/s]
Loss: 0.00025:  28%|██▊       | 282/1000 [00:01<00:02, 248.63it/s]
Loss: 0.00044: 100%|██████████| 1000/1000 [00:03<00:00, 254.62it/s]
Loss: 0.00045: 100%|██████████| 1000/1000 [00:04<00:00, 244.19it/s]
Loss: 1.00000: 100%|██████████| 1000/1000 [00:03<00:00, 254.29it/s]


Attack failed 4 == 4


Loss: 0.00040: 100%|██████████| 1000/1000 [00:04<00:00, 249.98it/s]
Loss: 0.00022:  33%|███▎      | 332/1000 [00:01<00:02, 241.33it/s]
Loss: 0.00028: 100%|██████████| 1000/1000 [00:03<00:00, 253.48it/s]
Loss: 0.00017:   4%|▍         | 44/1000 [00:00<00:05, 190.91it/s]
Loss: 0.00027: 100%|██████████| 1000/1000 [00:03<00:00, 255.08it/s]
Loss: 0.00035:   0%|          | 3/1000 [00:00<02:43,  6.08it/s]
Loss: 0.00597:   0%|          | 1/1000 [00:00<03:49,  4.35it/s]
Loss: 0.00654:   0%|          | 2/1000 [00:00<02:51,  5.81it/s]
Loss: 0.00927:   0%|          | 3/1000 [00:00<02:32,  6.53it/s]
Loss: 0.00125:   0%|          | 3/1000 [00:00<02:32,  6.52it/s]


Image misclassified
Image misclassified


Loss: 0.00382:   0%|          | 4/1000 [00:00<02:23,  6.93it/s]


Image misclassified


Loss: 0.00215:   0%|          | 3/1000 [00:00<02:33,  6.51it/s]
Loss: 0.00227:   0%|          | 4/1000 [00:00<02:26,  6.79it/s]
Loss: 0.00012:   0%|          | 4/1000 [00:00<02:24,  6.87it/s]


Saved to /scratch/diego/semester-proj/images/ae_many_attack_examples.png


PosixPath('/scratch/diego/semester-proj/images/ae_many_attack_examples.png')

## Targeted attacks
### Attacking the ImageNet autoencoder to produce a different image

In [13]:
from lib.utils import find_closest
# Work on the last task and use one of its sample images as the starting point
task = tasks[-1]
start_idx = 0
start, start_label = task.sample_images[start_idx], task.sample_labels[start_idx]

# Target image picked by find_closest after at most 15 checks
target = find_closest(start, start_label, task, max_checks=15)

Found target at 1 with distance 78821.6640625
Found target at 3 with distance 71796.390625
Found target at 4 with distance 53146.8984375
Found target at 9 with distance 32515.8359375
Found target at 12 with distance 31962.734375


In [None]:
# Show the start image next to the chosen target, then prepare a fresh attack on the start image
panel_titles = ["Start", "Target"]
show([start, target], panel_titles)
attack = FgsmAttack(start, task.vae, task.classifier)

In [None]:
# Train the attack towards the target image. This cell is meant to be re-run
# with different arguments; interrupting the kernel stops the current run
# early and still shows the result obtained so far.
try:
    attack.train(
        target_image=target,
        # target_class=9,
        # away_from_label=task.sample_labels[start_idx],
        lr=0.004,
        num_steps=100,
        # early_stop=lambda env: env["self"].perturbation.mean() < 0.01 and env["loss"] < 0.01,
        l2_penalty=1,
        max_l1_distance=50 / 255,
        batch_size=10,
        # normalize_grad=True,
    )
except KeyboardInterrupt:
    # Deliberate: a manual interrupt only ends the training, not the cell
    pass
attack.show(task.sample_labels[start_idx], task.labels)

In [46]:
# Export the arguments of the training runs performed on `attack` so far. The
# two names passed are presumably the keys to leave out (the displayed result
# contains neither) — confirm in lib.attack.
# The result gets its own name: `settings` is the per-task dict that other
# cells and `get_transform_attacks` index by task name, and rebinding it to
# this list would break them.
exported_settings = attack.export_settings("target_image", "target_class", )
exported_settings

{0: ['args/max_l1_distance', 'args/lr', 'args/l2_penalty', 'args/batch_size'], 100: ['args/max_l1_distance', 'args/lr', 'args/l2_penalty', 'args/batch_size'], 200: ['args/max_l1_distance', 'args/lr', 'args/l2_penalty', 'args/batch_size'], 300: ['args/max_l1_distance', 'args/lr', 'args/l2_penalty', 'args/batch_size']}


[{'max_l1_distance': 0.19607843137254902,
  'lr': 0.04,
  'l2_penalty': 1,
  'batch_size': 10,
  'num_steps': 100},
 {'max_l1_distance': 0.19607843137254902,
  'lr': 0.004,
  'l2_penalty': 1,
  'batch_size': 10,
  'num_steps': 200},
 {'max_l1_distance': 0.19607843137254902,
  'lr': 0.0004,
  'l2_penalty': 1,
  'batch_size': 10,
  'num_steps': 100}]

In [17]:
# Replay the task's training schedule, stage by stage, on a fresh attack.
# NOTE(review): `settings` is indexed by task name here, so it has to be the
# per-task dict built in the "Show many attacks at once" cell further down —
# that cell must have been run first; confirm the intended execution order.
attack = FgsmAttack(start, task.vae, task.classifier)
for setting in settings[task.name]:
    attack.train(target_image=target, **setting)

Loss: 0.00655: 100%|██████████| 600/600 [01:05<00:00,  9.20it/s]
Loss: 0.00513: 100%|██████████| 300/300 [00:32<00:00,  9.29it/s]
Loss: 0.00483: 100%|██████████| 200/200 [00:21<00:00,  9.24it/s]


In [None]:
# Show the attack result for the start image's label
attack.show(task.sample_labels[start_idx], task.labels)

In [19]:
# 2x2 plot with orig, target | adv, adv+vae
fig = make_subplots(rows=2, cols=2,
    subplot_titles=["Original image", "Target image", "Adversarial image", "Adversarial image through VAE"],
    vertical_spacing=0.04,
    horizontal_spacing=0.01,
    )

# Only the reconstruction's values are plotted, so run the autoencoder without
# tracking gradients (as the other plotting cells do).
with torch.no_grad():
    adversarial_through_vae = task.vae(attack.adversarial_image[None])

fig.add_trace(to_plotly(start), row=1, col=1)
fig.add_trace(to_plotly(target), row=1, col=2)
fig.add_trace(to_plotly(attack.adversarial_image), row=2, col=1)
fig.add_trace(to_plotly(adversarial_through_vae), row=2, col=2)

fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(
    font=dict(size=15),
    width=700,
    height=700,
    margin=dict(l=10, r=10, b=10, t=30),
    showlegend=False,
    coloraxis_showscale=False,
)

save(fig, "ae_targeted_attack_example")

Saved to /scratch/diego/semester-proj/images/ae_targeted_attack_example.png


PosixPath('/scratch/diego/semester-proj/images/ae_targeted_attack_example.png')

### Show many attacks at once

In [3]:
import torch
from lib.utils import find_closest

# MNIST: a single training stage that stops early once both the mean
# perturbation and the loss are below 0.01.
settings_mnist = [
    {
        "lr": 1 / 255,
        "num_steps": 1000,
        "early_stop": lambda env: env["self"].perturbation.mean() < 0.01 and env["loss"] < 0.01,
        "l2_penalty": 0.1,
        "batch_size": 10,
    }
]

# ImageNet: three stages sharing the same L1 bound and L2 penalty, with a
# decreasing learning rate and step count.
settings_imagenet = [
    {"max_l1_distance": 70 / 255, "lr": stage_lr, "l2_penalty": 1, "num_steps": stage_steps}
    for stage_lr, stage_steps in ((0.004, 600), (0.001, 300), (0.0004, 200))
]

# Training schedules (lists of keyword arguments for `attack.train`) by task name
settings = {"MNIST": settings_mnist, "ImageNet": settings_imagenet}


def get_transform_attacks(task, n=12, force_compute=False):
    """Return `n` attacks through the autoencoder of `task`, cached on disk.

    The attacks are stored in ``data/attacks_transform_{task.name}_{n}.pt``. When
    that file can be loaded and `force_compute` is false, its content is returned
    as is. Otherwise the attacks are computed on the first `n` images yielded by
    ``task.iter(batch_size=1)``, saved to that file and returned.

    For the task named "MNIST" each attack moves away from the image's label; for
    any other task it targets the image returned by `find_closest`. The training
    stages are read from the module-level `settings` dict, indexed by task name.

    Args:
        task: task providing `name`, `iter`, `vae` and `classifier`.
        n: number of attacks to compute (also part of the cache file name).
        force_compute: ignore an existing cache file and recompute.

    Returns:
        The list of trained `FgsmAttack` objects.
    """
    from itertools import islice
    from pathlib import Path

    path = f"data/attacks_transform_{task.name}_{n}.pt"

    # Load the attacks from disk if they exist
    if not force_compute:
        try:
            # The cache holds pickled attack objects written by this notebook,
            # which the weights-only loader refuses. Unpickling can run
            # arbitrary code: only load cache files that you created yourself.
            attacks = torch.load(path, weights_only=False)
        except (FileNotFoundError, RuntimeError):
            pass
        else:
            print(f"Loaded attacks from {path}")
            return attacks

    # Create the cache folder up front, so that a missing folder is not only
    # discovered after the expensive computation below
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    attacks = []
    # Exactly one attack per image, so taking n images yields n attacks
    for image, label in islice(task.iter(batch_size=1), n):
        image = image.squeeze(0)
        label = label.item()

        if task.name == "MNIST":
            target = dict(away_from_label=label)
        else:
            target = dict(target_image=find_closest(image, label, task, max_checks=1000))

        # Perform the attack, one training stage after the other
        attack = FgsmAttack(image, task.vae, task.classifier)
        for setting in settings[task.name]:
            attack.train(**target, **setting)

        attacks.append(attack)

    # Save the attacks to disk
    torch.save(attacks, path)

    return attacks


In [25]:
# Use the first task of the list and fetch its 12 cached (or freshly computed) attacks
task = tasks[0]
attacks = get_transform_attacks(task, n=12)

Loaded attacks from data/attacks_transform_MNIST_12.pt


In [26]:
# Grid of attacks for the currently selected `task`.
# Each row holds `pairs_per_row` pairs: adversarial image, then the same image through the VAE.
# The number of rows follows from the number of attacks (4 rows for 12 attacks).

pairs_per_row = 3
rows = max(1, -(-len(attacks) // pairs_per_row))  # ceiling division, at least one row
fig = make_subplots(rows=rows, cols=2 * pairs_per_row,
    column_titles=["Image", "Image through VAE"] * pairs_per_row,
    vertical_spacing=0.01,
    horizontal_spacing=0.01,
    )

for i, attack in enumerate(attacks):
    row = i // pairs_per_row + 1
    col = i % pairs_per_row * 2 + 1
    fig.add_trace(to_plotly(attack.adversarial_image), row=row, col=col)
    with torch.no_grad():
        through_vae = task.vae(attack.adversarial_image[None])
    fig.add_trace(to_plotly(through_vae), row=row, col=col+1)

fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(
    width=1600,
    height=1200,
    font_size=20,
    margin=dict(l=10, r=10, b=10, t=30),
    showlegend=False,
    coloraxis_showscale=False,
)

save(fig, f"ae_many_targeted_attack_examples_{task.name}")

Saved to /scratch/diego/semester-proj/images/ae_many_targeted_attack_examples_MNIST.png


PosixPath('/scratch/diego/semester-proj/images/ae_many_targeted_attack_examples_MNIST.png')

# Phase transition: evolution inside the network

### Setup

In [4]:
from lib.datasets import get_tasks

# Rebuild the task list and work on its last entry for this section
tasks = get_tasks()
task = tasks[-1]

# 12 attacks for that task, loaded from the on-disk cache when available
attacks = get_transform_attacks(task, n=12)



Loaded attacks from data/attacks_transform_ImageNet_12.pt


In [16]:
from functools import partial
from lib.utils import record_activations, get_activations
from itertools import islice


# Sizes reported by get_activations for a single sample image
activation_sizes = get_activations(task, task.sample_images[:1], "size")
n_activations = len(activation_sizes)

# Average the activation norms over the first `n_images` images of the task
n_images = 100
norm_total = sum(
    get_activations(task, image, "norm")
    for image, _ in islice(task.iter(batch_size=1), n_images)
)
mean_norms = norm_total / n_images
mean_norms.shape

torch.Size([18])

### Evolution of distance between original and adversarial images

In [19]:
from einops import reduce
# Batch the clean images and their adversarial counterparts
all_starts = torch.stack([attack.start for attack in attacks])
all_adversarial = torch.stack([attack.adversarial_image for attack in attacks])

with torch.no_grad():
    # Run both batches through the autoencoder with a recorder active
    with recorder() as cache_clean:
        task.vae(all_starts)

    with recorder() as cache_adv:
        task.vae(all_adversarial)

    # For every recorded entry: Euclidean distance between the clean and the
    # adversarial activations, computed separately for each image of the batch
    per_entry_distances = []
    for clean_activations, adv_activations in zip(cache_clean.values(), cache_adv.values()):
        squared_gap = (clean_activations - adv_activations) ** 2
        per_entry_distances.append(reduce(squared_gap, "b ... -> b", "sum").sqrt())
    distances = torch.stack(per_entry_distances)
    print(distances.shape)

halfway
torch.Size([82, 12])


In [30]:
# Plot the distances: one line per attack, first raw (log scale), then divided
# by the mean activation norms (linear scale)
for normalize in (False, True):

    if normalize:
        ys = distances.T / mean_norms
    else:
        ys = distances.T

    fig = go.Figure()
    for dist in ys:
        fig.add_trace(go.Scatter(
            y=(dist).cpu(),
            mode='lines',
        ))

    # Label the x axis with the names recorded during the clean forward pass.
    # The original referred to `cache`, a name this notebook never defines.
    annotate_modules(fig, cache_clean.keys(), has_classifier=False)
    fig.update_xaxes(title_text="Layer", showticklabels=False)
    fig.update_yaxes(title_text="Distance" if not normalize else "Distance / mean norm",
                     type="linear" if normalize else "log")

    fig.update_layout(
        width=500,
        height=400,
        font_size=15,
        margin=dict(l=10, r=10, b=10, t=30),
        showlegend=False,
        coloraxis_showscale=False,
    )
    save(fig, "ae_attack_distances" + ("_normalized" if normalize else ""))
    fig.show()


Images folder: /scratch/diego/semester-proj/images
Saved to /scratch/diego/semester-proj/images/ae_attack_distances.png


Saved to /scratch/diego/semester-proj/images/ae_attack_distances_normalized.png


### Norms inside the network

In [37]:
from lib.utils import annotate_modules, get_activations
from itertools import islice
import plotly

# For the first and the last task: compare the activation norms of clean and
# adversarial images, each divided by the mean norm measured on 100 images.
for task in [
    tasks[0],
    tasks[-1],
]:
    attacks = get_transform_attacks(task, n=12)

    # One row of norms per attack, for the clean and for the adversarial image
    clean_norms = torch.stack([
        get_activations(task, attack.start, "norm")
        for attack in attacks
    ])
    adversarial_norms = torch.stack([
        get_activations(task, attack.adversarial_image, "norm")
        for attack in attacks
    ])

    # distances = torch.tensor([
    #     [torch.dist(a1, a2)
    #     for a1, a2 in zip(
    #         get_activations(task, attack.start, "cache").values(),
    #         get_activations(task, attack.adversarial_image, "cache").values(),
    #     )]
    #     for attack in attacks
    # ], device="cuda")

    # Names used on the x axis
    names = get_activations(task, attacks[0].start, "name")

    # Compute the mean norm of the activations over 100 images
    n_images = 100
    mean_norms = 0
    for image, _ in islice(task.iter(batch_size=1), n_images):
        mean_norms += get_activations(task, image, "norm")
    mean_norms /= n_images

    # Shape checks
    print(mean_norms.shape)
    print(clean_norms.shape)

    # Normalise by the mean norms
    adv = adversarial_norms / mean_norms
    clean = clean_norms / mean_norms

    print(adv.shape)

    fig = go.Figure()
    for name, norms in zip(["Clean", "Adversarial"], [clean, adv]):
            # y = norms[4]
            # y = distances[0] / mean_norms
        if task.name == "MNIST":
            # MNIST: lines for the first four attacks; same colour per attack,
            # dashed for the clean image and solid for the adversarial one
            for i, y in enumerate(norms[:4]):
                fig.add_trace(go.Scatter(
                    x=names,
                    y=y.cpu(),
                    name=name + f" {i}",
                    mode='lines',
                    line=dict(
                        color=plotly.colors.DEFAULT_PLOTLY_COLORS[i],
                        dash="solid" if name != "Clean" else "dash",
                    ),
                    legendgroup=i
                ))
        else:
            # Other tasks: bars for the first attack only, passed as an offset
            # from 1 with base=1 (presumably so bars start at the mean norm)
            fig.add_trace(go.Bar(
                x=names,
                y=norms[0].cpu() - 1,
                base=1,
                name=name,
                legendgroup=name,
            ))


    annotate_modules(fig, names, has_classifier=False)

    fig.update_xaxes(showticklabels=False, title_text="Layer")
    fig.update_yaxes(title_text="Normalised activation norm")
    fig.update_layout(
        font=dict(size=15),
        width=900,
        height=500,
        font_size=15,
        margin=dict(l=10, r=10, b=10, t=30),
        # showlegend=False,
        coloraxis_showscale=False,
    )
    fig.show()
    save(fig, f"ae_activation_norm_{task.name}")


Loaded attacks from data/attacks_transform_MNIST_12.pt
torch.Size([18])
torch.Size([12, 18])
torch.Size([12, 18])


Saved to /scratch/diego/semester-proj/images/ae_activation_norm_MNIST.png
Loaded attacks from data/attacks_transform_ImageNet_12.pt
torch.Size([82])
torch.Size([12, 82])
torch.Size([12, 82])


Saved to /scratch/diego/semester-proj/images/ae_activation_norm_ImageNet.png


### Evolution of the norms inside the classifier

In [42]:
from lib.utils import annotate_modules, get_activations

# For every task but the first: attack one sample image and compare the
# normalised activation norms of the clean and the adversarial image.
for task in tasks[1:]:
    # using 1, because 0 is not classified correctly through the vae
    idx = 1
    sample_image = task.sample_images[idx]
    sample_label = task.sample_labels[idx]

    # Attack with default training options, apart from the step size
    attack = FgsmAttack(sample_image, task.vae, task.classifier)
    attack.train(away_from_label=sample_label, lr=10 / 255)
    attack.show(sample_label, task.labels)

    # Mean activation norms over the first `n_images` images of the task
    n_images = 100
    mean_norms = sum(
        get_activations(task, image, "norm")
        for image, _ in islice(task.iter(batch_size=1), n_images)
    ) / n_images

    adversarial_norms = get_activations(task, attack.adversarial_image, "norm") / mean_norms
    clean_norms = get_activations(task, attack.start, "norm") / mean_norms

    names = get_activations(task, attack.start, "name")

    # Dashed line for the clean image, solid line for the adversarial one
    fig = go.Figure()
    curves = (
        ("Clean", clean_norms, "dash"),
        ("Adversarial", adversarial_norms, "solid"),
    )
    for name, norms, dash in curves:
        curve = go.Scatter(
            x=names,
            y=norms.cpu(),
            name=name,
            mode='lines',
            line=dict(dash=dash),
        )
        fig.add_trace(curve)

    annotate_modules(fig, names, has_classifier=False)

    fig.update_xaxes(showticklabels=False, title_text="Layer")
    fig.update_yaxes(title_text="Normalised activation norm")

    layout_options = dict(
        font=dict(size=15),
        width=900,
        height=500,
        font_size=15,
        margin=dict(l=10, r=10, b=10, t=30),
        coloraxis_showscale=False,
    )
    fig.update_layout(**layout_options)
    fig.show()


[4mOriginal image[0m
[92mTop 1: 57.67% ship[0m
Top 2: 14.32% automobile
Top 3:  5.42% cat
Top 4:  5.08% bird
Top 5:  3.40% deer

[4mAdversarial image[0m
[92mTop 1: 29.69% ship[0m
Top 2: 24.94% automobile
Top 3: 13.98% airplane
Top 4:  6.48% cat
Top 5:  5.44% bird

[4mOriginal reconstruction[0m
[92mTop 1: 57.81% ship[0m
Top 2: 11.51% airplane
Top 3:  5.18% cat
Top 4:  4.64% automobile
Top 5:  4.13% dog

[4mAdversarial reconstruction[0m
Top 1: 50.34% airplane
Top 2:  9.65% automobile
Top 3:  9.53% horse
Top 4:  8.07% dog
Top 5:  4.42% truck
[92mTop 8:  3.67% ship[0m



[4mOriginal image[0m
[92mTop 1: 70.39% Italian greyhound[0m
Top 2: 14.14% Doberman, Doberman pinscher
Top 3:  3.81% Ibizan hound, Ibizan Podenco
Top 4:  3.61% Weimaraner
Top 5:  2.14% whippet

[4mAdversarial image[0m
Top 1: 11.14% African elephant, Loxodonta africana
Top 2:  9.70% ram, tup
Top 3:  7.05% Chesapeake Bay retriever
Top 4:  6.79% wild boar, boar, Sus scrofa
Top 5:  5.20% Indian elephant, Elephas maximus
[92mTop 17:  1.29% Italian greyhound[0m

[4mOriginal reconstruction[0m
[92mTop 1: 73.58% Italian greyhound[0m
Top 2: 10.69% Weimaraner
Top 3:  9.39% Doberman, Doberman pinscher
Top 4:  1.93% Ibizan hound, Ibizan Podenco
Top 5:  1.27% whippet

[4mAdversarial reconstruction[0m
Top 1: 45.72% African elephant, Loxodonta africana
Top 2:  5.12% warthog
Top 3:  5.04% Indian elephant, Elephas maximus
Top 4:  4.60% ram, tup
Top 5:  3.75% water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
[92mTop 28:  0.22% Italian greyhound[0m



### Bird extinction

In [54]:
from PIL import Image

# Start from the first sample image of the last task
task = tasks[-1]
bird = task.sample_images[0]
# Open the target picture in a context manager so that its file handle is
# closed as soon as the picture has been preprocessed
with Image.open("images/no-bird-dalle.png") as target_picture:
    target = task.preprocess(target_picture)

attack = FgsmAttack(bird, task.vae, task.classifier)


In [None]:
# Run the task's training stages with a quarter of the steps each and an L1 bound of 20/255
for setting in settings[task.name]:
    stage_options = {
        **setting,
        "num_steps": setting["num_steps"] // 4,
        "max_l1_distance": 20/255,
    }
    print(stage_options)
    attack.train(target_image=target, **stage_options)
attack.show(task.sample_labels[0], task.labels)

In [57]:
from torchvision.transforms import ToPILImage

to_pil = ToPILImage()

# Reconstruction of the adversarial image, clipped to [0, 1] before saving
with torch.no_grad():
    reconstruction = task.vae(attack.adversarial_image[None])[0]
    through_vae = reconstruction.clip(0, 1)

# Save the reconstruction and the adversarial image itself
to_pil(through_vae.cpu()).save("./images/no-more-bird.png")
to_pil(attack.adversarial_image).save("images/no-more-bird-pre.png")