### Load libraries

In [1]:
import pathlib
import sys

import git.repo
import numpy as np
import torch
import torchvision.datasets
import wandb
import wandb.sdk
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

# Root of the enclosing git checkout. It is added to sys.path below so the
# project's `src` package is importable regardless of the notebook's cwd.
_git_working_tree = git.repo.Repo(
    ".", search_parent_directories=True
).working_tree_dir
if _git_working_tree is None:
    # e.g. a bare repository: there is no working tree. Wrapping None in str()
    # would silently produce the bogus relative path "None".
    raise RuntimeError("Could not locate a git working tree for this notebook.")
GIT_ROOT = pathlib.Path(_git_working_tree)

In [2]:
# Make the project's `src` package importable from the repository root.
# The membership check keeps re-runs of this cell from appending the same
# entry to sys.path over and over.
if str(GIT_ROOT) not in sys.path:
    sys.path.append(str(GIT_ROOT))
from src.ax.attack.FastAutoAttack import FastAutoAttack
from src.ax.attack.FastPGD import FastPGD
from src.ax.models import wrn

### Load data

In [3]:
# MNIST test split only (train=False). download=True fetches the raw files
# into the scratch root if they are not already there.
mnist_test = torchvision.datasets.MNIST(
    "/var/tmp/scratch", train=False, download=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /var/tmp/scratch/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /var/tmp/scratch/MNIST/raw/train-images-idx3-ubyte.gz to /var/tmp/scratch/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /var/tmp/scratch/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /var/tmp/scratch/MNIST/raw/train-labels-idx1-ubyte.gz to /var/tmp/scratch/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /var/tmp/scratch/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /var/tmp/scratch/MNIST/raw/t10k-images-idx3-ubyte.gz to /var/tmp/scratch/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /var/tmp/scratch/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /var/tmp/scratch/MNIST/raw/t10k-labels-idx1-ubyte.gz to /var/tmp/scratch/MNIST/raw



In [4]:
# Scale raw 0-255 pixel values into [0, 1] and add an explicit channel axis
# (N, 1, 28, 28) to match the single-input-channel model built below.
xs_test = mnist_test.data.div(255.0).reshape(-1, 1, 28, 28)
ys_test = mnist_test.targets

# Quick sanity printout: shapes, dtypes, then value ranges of images and labels.
for lhs, rhs in (
    (xs_test.shape, ys_test.shape),
    (xs_test.dtype, ys_test.dtype),
    (xs_test.min(), xs_test.max()),
    (ys_test.min(), ys_test.max()),
):
    print(lhs, rhs)

torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.float32 torch.int64
tensor(0.) tensor(1.)
tensor(0) tensor(9)


### Load neural net

In [5]:
# Training run: https://wandb.ai/data-frugal-learning/adv-train/runs/134moxtz
api = wandb.Api()
# `Api.run` returns a public-API Run (read access to config/summary), not the
# live `wandb.sdk.wandb_run.Run` object that `wandb.init` creates.
run: wandb.apis.public.Run = api.run(
    "data-frugal-learning/adv-train/134moxtz"
)

# Accuracies logged by the training run, for comparison with the
# re-evaluation further down.
print("Nat (orig) acc:", run.summary["test_orig_acc_nat"])
print("Adv (orig) acc:", run.summary["test_orig_acc_adv"])

Nat (orig) acc: 0.9950999997138976
Adv (orig) acc: 0.9739000008583067


In [6]:
# The checkpoint sits in the training run's local wandb directory; the
# run-<timestamp>-<id> folder name is specific to run 134moxtz.
ckpt_path = (
    pathlib.Path(run.config["wandb_dir"])
    / "wandb"
    / "run-20221006_050356-134moxtz"
    / "files"
    / "model.ckpt"
)

# Rebuild the architecture from the run's config so the state dict matches.
# mean/std are handed to the model constructor (presumably for in-model input
# normalization -- TODO confirm against wrn.get_mup_wrn).
model = wrn.get_mup_wrn(
    depth=run.config["depth"],
    width=run.config["width"],
    num_classes=10,
    mean=(0.5,),
    std=(1.0,),
    num_input_channels=1,
)
model = model.to(memory_format=torch.channels_last)  # type: ignore

# Load checkpoint. map_location="cpu" makes loading independent of the device
# the tensors were saved from (the model is still on CPU here) and avoids a
# transient GPU copy; the model is moved to the GPU afterwards.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval().cuda();

### Evaluate natural accuracy (sanity check)

In [7]:
# Clean (unperturbed) test accuracy; expected to agree with the run's logged
# "test_orig_acc_nat" value printed earlier.
nat_loader = DataLoader(TensorDataset(xs_test, ys_test), batch_size=256)
n_correct = 0
with torch.no_grad():
    for images, labels in tqdm(nat_loader):
        preds = model(images.cuda()).argmax(dim=-1)
        n_correct += (preds == labels.cuda()).sum().item()

print(f"Nat (orig) acc: {n_correct / len(xs_test):.4f}")

  0%|          | 0/40 [00:00<?, ?it/s]

Nat (orig) acc: 0.9951


### Evaluate AutoAttack accuracy

In [8]:
# AutoAttack with perturbation budget eps=0.3 over the first n test examples.
# The attack itself needs gradients, but the final scoring forward pass does
# not, so it runs under no_grad to avoid building an autograd graph per batch.
attack = FastAutoAttack(model, eps=0.3)
n = len(xs_test)

n_correct = 0
for xs, ys in tqdm(
    DataLoader(TensorDataset(xs_test[:n], ys_test[:n]), batch_size=512)
):
    ys_cuda = ys.cuda()
    # Inputs are cast to half precision before being handed to the attack.
    xs_adv = attack(xs.half().cuda(), ys_cuda)

    with torch.no_grad():
        logits = model(xs_adv)
    preds = logits.argmax(dim=-1)
    n_correct += (preds == ys_cuda).sum().item()

print(f"Adv (orig) acc: {n_correct / n:.4f}")

  0%|          | 0/20 [00:00<?, ?it/s]

Adv (orig) acc: 0.9576


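### Evaluate 40-step PGD with random restarts

Entry k of the printed array is the robust accuracy when an example counts as broken if any of the first k+1 random-start restarts fools the model.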

In [12]:
# random_start makes PGD stochastic; seed so the reported numbers are
# reproducible (assumes FastPGD draws its random start from torch's RNG).
torch.manual_seed(0)

# 40-step PGD, eps=0.3, step size alpha = 2.3 * eps / steps.
attack = FastPGD(
    model,
    eps=0.3,
    alpha=0.3 / 40 * 2.3,
    steps=40,
    random_start=True,
)
n_reps = 10

n = len(xs_test)

# n_corrects[i]: examples that survived every one of the first i+1 restarts,
# so n_corrects / n is robust accuracy as a function of the restart count.
n_corrects = np.zeros(n_reps, dtype=np.int64)
for xs, ys in tqdm(
    DataLoader(TensorDataset(xs_test[:n], ys_test[:n]), batch_size=512)
):
    ys_cuda = ys.cuda()
    # failures[j] latches to True once any restart so far fooled example j.
    failures = torch.zeros_like(ys_cuda, dtype=torch.bool)
    for i in range(n_reps):
        xs_adv = attack(xs.half().cuda(), ys_cuda)

        # Scoring needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            preds = model(xs_adv).argmax(dim=-1)
        failures |= preds != ys_cuda

        n_corrects[i] += (~failures).sum().item()

print(n_corrects / n)


  0%|          | 0/20 [00:00<?, ?it/s]

[0.9726 0.9694 0.9687 0.9677 0.9667 0.966  0.9655 0.9653 0.9653 0.9651]
