In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),
])

testset = datasets.CIFAR10(root='root for the CIFAR10 dataset', train=False, download=True, transform=transform)
target_class = 0
# Select indices from the stored label list rather than iterating `testset`: iterating the
# dataset would load and ToTensor-transform every test image just to read its label.
# (Equivalent here because no target_transform is set.)
class_indices = [i for i, label in enumerate(testset.targets) if label == target_class]
from torch.utils.data import Subset
filtered_subset = Subset(testset, class_indices)
testloader = DataLoader(filtered_subset, batch_size=10, shuffle=True)

In [None]:
import argparse
import os

import torch
import yaml
from torchvision.utils import make_grid, save_image
from ema_pytorch import EMA

from model.models import get_models_class
from utils import Config, print0


def get_default_steps(model_type, steps):
    """Resolve the number of sampling steps.

    Returns ``steps`` unchanged whenever it is not None (including 0); otherwise
    returns the default for ``model_type`` ('DDPM' -> 100, 'EDM' -> 18).
    Raises KeyError for any other model_type when ``steps`` is None.
    """
    defaults_by_model = {'DDPM': 100, 'EDM': 18}
    return defaults_by_model[model_type] if steps is None else steps

In [None]:
# Sampling options, declared as (flag, add_argument keyword arguments) pairs.
_ARG_SPECS = [
    ("--config", dict(type=str)),
    ("--use_amp", dict(action='store_true', default=False)),
    ("--mode", dict(type=str, choices=['DDPM', 'DDIM'], default='DDIM')),
    ("--steps", dict(type=int, default=None)),
    ("--eta", dict(type=float, default=0.0)),
    ("--batches", dict(type=int, default=1)),
    ("--epoch", dict(type=int, default=-1)),
    ("--w", dict(type=float, default=0.3)),
]

parser = argparse.ArgumentParser()
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

# In a notebook sys.argv holds the kernel's own flags, so parse an explicit argument list.
_NOTEBOOK_ARGV = [
    '--config', 'config/cifar_conditional_EDM.yaml',
    '--use_amp',
    '--mode', 'DDIM',
]
opt = parser.parse_args(args=_NOTEBOOK_ARGV)

In [None]:
print0(opt)
# Copy the CLI options into plain names: `opt` is rebound to the YAML config below,
# and later cells read `use_amp`, `steps`, `eta` and `w` directly.
yaml_path = opt.config
use_amp = opt.use_amp
mode = opt.mode
steps = opt.steps
eta = opt.eta
batches = opt.batches
ep = opt.epoch
w = opt.w

# From here on `opt` is the experiment config (model_type, net_type, network, diffusion,
# save_dir, ema, n_epoch), not the argparse namespace.
with open(yaml_path, 'r') as f:
    opt = yaml.full_load(f)
print0(opt)
opt = Config(opt)
if ep == -1:
    # --epoch -1 selects the last epoch's checkpoint (n_epoch - 1).
    ep = opt.n_epoch - 1

# Same GPU as the first cell, but fall back to CPU instead of crashing when CUDA is absent.
device = "cuda:1" if torch.cuda.is_available() else "cpu"
# Honour --steps when it is given; otherwise use 50 sampling steps for this experiment
# (get_default_steps would choose 18 for EDM).
steps = 50 if steps is None else steps
DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type, guide=True)
diff = DIFFUSION(nn_model=NETWORK(**opt.network),
                 **opt.diffusion,
                 device=device,
                 drop_prob=0.1)
diff.to(device)

target = os.path.join(opt.save_dir, "ckpts", f"model_{ep}.pth")
print0("loading model at", target)
# NOTE: torch.load unpickles the checkpoint; only load trusted files.
checkpoint = torch.load(target, map_location=device)
# Wrap the freshly built model in an EMA container so the saved EMA weights can be
# restored into it, then use the averaged copy for sampling.
ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
ema.to(device)
ema.load_state_dict(checkpoint['EMA'])
model = ema.ema_model
model.eval()
print('model prepared.')

In [None]:
# Target classifier to attack. torch.load unpickles a whole nn.Module here, so the file must be
# trusted; on PyTorch >= 2.6 (weights_only=True by default) this call may need weights_only=False.
classifier = torch.load('the classifier to be attacked', map_location=device).to(device)
# Inference mode: otherwise BatchNorm/Dropout layers (if present) run in training mode during
# the attack and the success check, and BatchNorm running statistics get updated.
classifier.eval()

In [None]:
from tqdm import tqdm

def sdedit(edm_model, x, t=3, steps=18, eta=0.0, n_sample=1, class_label=0):
    """Purify ``x`` by perturbing it with ``edm_model`` and denoising it once, class-conditionally.

    ``x`` is mapped from [0, 1] to [-1, 1], passed to ``edm_model.perturb`` (which returns the
    perturbed input and its sigma), denoised with ``edm_model.D_x`` conditioned on
    ``class_label``, and mapped back with ``(y + 1) * 0.5``.

    Args:
        edm_model: diffusion model exposing prepare_single_class_condition_, perturb and D_x.
        x: input batch, presumably in [0, 1] — TODO confirm against callers.
        t: noise-level argument forwarded to ``perturb``.
        steps: step-count argument forwarded to ``perturb``.
        eta: accepted for signature compatibility; not used.
        n_sample: number of conditioning entries kept (each condition is sliced to this length).
        class_label: class used for conditioning.

    Returns:
        The denoised batch mapped back through ``(y + 1) * 0.5``. AMP is always disabled.
    """
    scaled = x * 2 - 1
    condition = edm_model.prepare_single_class_condition_(class_label, n_sample=n_sample)
    trimmed_condition = tuple(condition[k][:n_sample] for k in (0, 1))

    noisy, noise_sigma = edm_model.perturb(scaled, t=t, steps=steps)
    restored = edm_model.D_x(noisy, sigma=noise_sigma, model_args=trimmed_condition, use_amp=False)

    # Back from the model's [-1, 1] range to pixel space.
    return (restored + 1) * 0.5

In [None]:
@torch.no_grad()
def generate_x_adv_denoised_v2(x, y, model, classifier, t=3, eps=16/255, alpha = 2/255, iter=10, device='cuda:0', n_samples=10, class_label=0):
    """Craft L2-bounded perturbations of ``x`` against ``classifier`` seen through SDEdit purification.

    Each iteration purifies ``x + delta`` with ``sdedit`` (run under no_grad, so no gradient
    flows through the diffusion model). It then takes the gradient of the classifier's
    cross-entropy loss with respect to the purified image and adds it to ``delta``. This is a
    straight-through approximation of the gradient through purification. Finally ``delta`` is
    rescaled so that each sample's L2 norm is at most ``eps``.

    Args:
        x: clean input batch; presumably in [0, 1] since sdedit rescales by ``x * 2 - 1`` — TODO confirm.
        y: labels whose cross-entropy is maximised (untargeted attack).
        model: diffusion model forwarded to ``sdedit``.
        classifier: model under attack.
        t: noise-level argument forwarded to ``sdedit``.
        eps: per-sample L2 radius of ``delta``.
        alpha: multiplier applied to the raw gradient at each step.
        iter: number of attack iterations.
        device: unused; kept for backward compatibility.
        n_samples: forwarded to ``sdedit`` as ``n_sample``.
        class_label: conditioning class forwarded to ``sdedit``.

    Returns:
        ``clamp(x + delta, 0, 1)``, detached.
    """
    delta = torch.zeros_like(x)
    loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

    for _ in range(iter):
        x_diff = sdedit(edm_model=model, x=x + delta, t=t, n_sample=n_samples, class_label=class_label)
        # x_diff was produced under no_grad, so it is a fresh leaf we can differentiate w.r.t.
        x_diff.requires_grad_()

        with torch.enable_grad():
            loss = loss_fn(classifier(x_diff), y)
            # Gradient w.r.t. the purified input only. Unlike loss.backward(), this does not
            # accumulate .grad into every classifier parameter on each iteration.
            (grad,) = torch.autograd.grad(loss, x_diff)

        # NOTE(review): this is a raw (un-normalised) gradient step. With reduction="mean" its
        # size also scales with 1/batch, so `alpha` is not an L2 step length. Standard L2 PGD
        # normalises the gradient per sample — confirm this is intended.
        delta = delta + grad * alpha

        # Project each sample's perturbation back into the L2 ball of radius eps.
        norm = torch.norm(delta.reshape(delta.size(0), -1), p=2, dim=1).view(-1, 1, 1, 1)
        factor = torch.min(torch.ones_like(norm), eps / norm)
        delta = delta * factor

    print("Done")

    x_adv = torch.clamp(x + delta, 0, 1)
    return x_adv.detach()

In [None]:
import torch

# Class whose generated samples are attacked; it is also the label whose loss is maximised.
class_label = 1
label_tensor = torch.tensor([class_label], dtype=torch.long).to(device)

misclassified_samples = []
misclassified_labels = []

N_TARGET_UAES = 100   # stop after collecting this many successful adversarial examples
MAX_ROUNDS = 1000     # safety cap: without it the loop never ends if the attack never succeeds

success_count = 0
samples_per_process = 10
args = dict(n_sample=samples_per_process, size=opt.network['image_shape'], guide_w=w, notqdm=False, use_amp=use_amp)

rounds = 0
while success_count < N_TARGET_UAES and rounds < MAX_ROUNDS:
    rounds += 1
    print(success_count)

    # Generate a fresh class-conditional batch, then attack it.
    # NOTE(review): sdedit inside the attack assumes inputs in [0, 1]; confirm that
    # edm_sample_single_class returns that range.
    x_gen = model.edm_sample_single_class(**args, class_label=class_label, steps=steps, eta=eta).float()
    labels = label_tensor.repeat(x_gen.size(0))
    x_adv = generate_x_adv_denoised_v2(x=x_gen.to(device), y=labels.to(device), model=model, classifier=classifier,
                                       t=3, eps=1.5, alpha=0.2, iter=50,
                                       n_samples=samples_per_process, class_label=class_label)

    with torch.no_grad():
        pred_labels = classifier(x_adv).argmax(dim=1)

    # Keep examples the classifier no longer assigns to `class_label`, up to the target count.
    for i in range(len(x_adv)):
        if pred_labels[i] != labels[i]:
            misclassified_samples.append(x_adv[i].cpu())
            misclassified_labels.append(labels[i].cpu().item())
            success_count += 1
            if success_count >= N_TARGET_UAES:
                break

if not misclassified_samples:
    raise RuntimeError(f"attack produced no misclassified samples in {rounds} rounds")
if success_count < N_TARGET_UAES:
    print(f"warning: collected only {success_count}/{N_TARGET_UAES} UAEs after {rounds} rounds")

saved_samples = torch.stack(misclassified_samples)
saved_labels = torch.tensor(misclassified_labels)

In [None]:
save_path = 'path to save the UAEs'
# Bundle the collected UAEs with their source-class labels into a single checkpoint file.
uae_bundle = {'samples': saved_samples, 'labels': saved_labels}
torch.save(uae_bundle, save_path)

In [None]:
import torch
import matplotlib.pyplot as plt

def visualize_images(tensor, n_rows=2):
    """
    Visualize a batch of images stored in a PyTorch tensor as a grid of subplots.

    Parameters:
    - tensor: a torch.Tensor of shape (batch_size, C, H, W) on any device. C may be 1
      (shown in grayscale), 3 or 4. Float values are expected in [0, 1]; matplotlib clips
      anything outside that range.
    - n_rows: number of rows in the subplot grid.

    An empty batch shows nothing.
    """
    # detach() so tensors that still require grad can be converted to numpy.
    images = tensor.detach().cpu().numpy().transpose((0, 2, 3, 1))
    batch_size = images.shape[0]
    if batch_size == 0:
        return

    # imshow rejects (H, W, 1); drop the singleton channel and render as grayscale.
    cmap = None
    if images.shape[-1] == 1:
        images = images[..., 0]
        cmap = 'gray'

    n_cols = (batch_size + n_rows - 1) // n_rows

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid
    # (otherwise a single Axes object is returned and has no .flatten()).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < batch_size:
            ax.imshow(images[i], cmap=cmap)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Show the last attacked batch from the attack loop in a two-row grid.
visualize_images(tensor=x_adv, n_rows=2)