In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from model.nn_model import VAEClassifier, StAEClassifier
from pgd_purify import vae_purify, stae_purify, pgd_linf

# Evaluation settings shared by every section below.
batch_size = 1024      # size of the single test batch that is attacked / purified
num_visualize = 12     # number of images top_k_vis lays out side by side
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _load_eval_classifier(model, ckpt_path):
    """Load CPU-mapped weights from ``ckpt_path`` into ``model``, move it to ``device`` and switch it to eval mode."""
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')))
    return model.to(device).eval()


vae_classifier = _load_eval_classifier(VAEClassifier(), './model/vae_clf.pth')
stae_classifier = _load_eval_classifier(StAEClassifier(), './model/stae_clf.pth')

# FashionMNIST test split; shuffle=False keeps the batch order identical on every run.
fmnist_test = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_loader = torch.utils.data.DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, num_workers=4)

def pred2label(y_pred_vae, y_pred_stae, labels):
    """Turn class scores into FashionMNIST names and print each model's batch accuracy.

    Args:
        y_pred_vae: VAE-classifier scores; the predicted class is ``argmax`` over the last dim.
        y_pred_stae: Standard-AE-classifier scores, same layout as ``y_pred_vae``.
        labels: tensor of ground-truth integer class indices.

    Returns:
        ``(label_selected, vae_pred, stae_pred)``: numpy arrays with the
        ground-truth names, the VAE-predicted names and the Standard-AE-predicted names.
    """
    class_names = np.array(['top', 'trouser', 'pullover', 'dress', 'coat',
                            'sandal', 'shirt', 'sneaker', 'bag', 'ank boot'])

    def _scores_to_indices(scores):
        return scores.argmax(-1).cpu().detach().numpy()

    truth = labels.detach().cpu().numpy()
    vae_idx = _scores_to_indices(y_pred_vae)
    stae_idx = _scores_to_indices(y_pred_stae)
    names_truth, names_vae, names_stae = (class_names[idx] for idx in (truth, vae_idx, stae_idx))

    n_samples = len(truth)
    print('Accuracy of the selected batch')
    print('Accuracy of Standard-AE-Classifier:', np.sum(truth == stae_idx) / n_samples)
    print('Accuracy of VAE-Classifier:', np.sum(truth == vae_idx) / n_samples)
    return names_truth, names_vae, names_stae


def top_k_vis(num_visualize, images, x_reconst_vae, x_reconst_stae):
    """Build horizontal image strips for the inputs and both reconstructions.

    Every tensor argument is indexed as ``[sample, channel, ...]``. Channel 0 of
    the first ``num_visualize`` samples is taken, and those per-sample arrays are
    joined left to right (``axis=1``) into one numpy array per source.

    Returns:
        ``(input_sample, reconst_vae, reconst_stae)`` strips.
    """
    per_source = [t[:num_visualize, 0].detach().cpu().numpy()
                  for t in (images, x_reconst_vae, x_reconst_stae)]
    input_sample, reconst_vae, reconst_stae = (np.concatenate(arr, axis=1) for arr in per_source)
    return input_sample, reconst_vae, reconst_stae

# Clean baseline: classify and reconstruct the first test batch (test_loader does
# not shuffle, so this is the same batch on every run). These `images`/`labels`
# are also the inputs the adversarial attacks below start from.
data, target = next(iter(test_loader))
images, labels = data.to(device), target.to(device)
with torch.no_grad():
    # NOTE(review): deterministic=True presumably disables latent sampling in the VAE — confirm in model.nn_model.
    x_reconst_vae, z, y_pred_vae, mu, log_var = vae_classifier(images, deterministic=True,
                                                               classification_only=False)
    x_reconst_stae, z, y_pred_stae = stae_classifier(images, classification_only=False)

label_selected, vae_pred, stae_pred = pred2label(y_pred_vae, y_pred_stae, labels)
input_sample, reconst_vae, reconst_stae = top_k_vis(num_visualize, images, x_reconst_vae, x_reconst_stae)

# Generate adversarial versions of the first test batch against each classifier
# with pgd_linf (atk_itr=200, eps=50/255, alpha=2/255). This is the slowest step,
# so the results are written to disk. Set REUSE_SAVED_ADV = True to load the saved
# tensors instead of re-running the attack; if either file is missing, the attack
# runs as usual.
import os

REUSE_SAVED_ADV = False
ADV_VAE_PATH = './adv_vae.pt'
ADV_STAE_PATH = './adv_stae.pt'

if REUSE_SAVED_ADV and os.path.exists(ADV_VAE_PATH) and os.path.exists(ADV_STAE_PATH):
    # Only load files this notebook wrote itself: torch.load unpickles its input.
    adv_vae = torch.load(ADV_VAE_PATH, map_location=device)
    adv_stae = torch.load(ADV_STAE_PATH, map_location=device)
else:
    adv_vae = pgd_linf(images, labels, vae_classifier, atk_itr=200, eps=50/255, alpha=2/255, device=device)
    adv_stae = pgd_linf(images, labels, stae_classifier, atk_itr=200, eps=50/255, alpha=2/255, device=device)
    torch.save(adv_vae, ADV_VAE_PATH)
    torch.save(adv_stae, ADV_STAE_PATH)

# Classify the adversarial batches; each batch is scored by the model it was crafted against.
with torch.no_grad():
    x_reconst_vae, z, y_pred_vae, mu, log_var = vae_classifier(
        adv_vae, deterministic=True, classification_only=False)
    x_reconst_stae, z, y_pred_stae = stae_classifier(adv_stae, classification_only=False)

label_selected, vae_pred, stae_pred = pred2label(y_pred_vae, y_pred_stae, labels)

# top_k_vis returns (inputs, VAE recon, StAE recon); keep the input strip plus the
# reconstruction strip of the model each adversarial batch targets.
input_adv_sample_vae, reconst_vae = top_k_vis(num_visualize, adv_vae, x_reconst_vae, x_reconst_stae)[:2]
input_adv_sample_stae, reconst_stae = top_k_vis(num_visualize, adv_stae, x_reconst_vae, x_reconst_stae)[::2]

# MagNet ConvAE

In [None]:
import torch
import torch.nn as nn

class ConvAutoEncoder(nn.Module):
    """Small convolutional autoencoder for 1x28x28 images (used as MagNet detector and reformer).

    Encoder: 1x28x28 -> 16x14x14 -> 32x7x7 -> 64x1x1, with ReLU between the convolutions.
    The decoder mirrors it with transposed convolutions back to 1x28x28 and ends
    in a Sigmoid, so outputs lie in [0, 1].

    The layers sit at fixed positions inside the ``encoder`` / ``decoder``
    Sequentials, so saved ``state_dict`` keys (``encoder.0.weight``, ...) stay loadable.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),   # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7),                      # 7 -> 1
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),             # 1 -> 7
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2,
                               padding=1, output_padding=1),                                 # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2,
                               padding=1, output_padding=1),                                 # 14 -> 28
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x``; a 1x28x28 input yields a 1x28x28 reconstruction."""
        return self.decoder(self.encoder(x))

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# FashionMNIST training split, used only to fit the MagNet autoencoders below.
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Training loaders should reshuffle every epoch (the original used shuffle=False,
# so each epoch presented the batches in the same fixed order). A seeded generator
# keeps the shuffle order reproducible across runs.
TRAIN_SHUFFLE_SEED = 0
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    generator=torch.Generator().manual_seed(TRAIN_SHUFFLE_SEED),
    num_workers=4
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
import torch.optim as optim
from tqdm import tqdm

def train_ae(model, data_loader, criterion, optimizer, num_epochs=20, device=None):
    """Train an autoencoder to reconstruct its own inputs.

    Labels from ``data_loader`` are ignored: every batch is its own target, so the
    loss is ``criterion(model(images), images)``.

    Args:
        model: autoencoder to train; it is switched to train mode and left there.
        data_loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: reconstruction loss called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``data_loader``.
        device: where each batch is moved before the forward pass. Defaults to the
            device of ``model``'s first parameter (CPU for a parameterless model),
            so the function no longer depends on a global ``device``.

    Returns:
        List with the mean per-batch loss of each epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    print(f"begin training {model.__class__.__name__}...")
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, _ in pbar:
            images = images.to(device)
            outputs = model(images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.6f}"})

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_loss = total_loss / max(len(data_loader), 1)
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f}")

    print("Completed")
    return epoch_losses

def _fit_magnet_ae(criterion, ckpt_path):
    """Train a fresh ConvAutoEncoder on ``device`` for 50 epochs with Adam and save its weights to ``ckpt_path``."""
    model = ConvAutoEncoder().to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    train_ae(model, train_loader, criterion, opt, num_epochs=50)
    torch.save(model.state_dict(), ckpt_path)
    return model, opt


# Detector AE (MSE loss): magnet_purify_pytorch later thresholds its reconstruction error.
print("="*30)
criterion_detector = nn.MSELoss()
detector_ae, optimizer_detector = _fit_magnet_ae(criterion_detector, './detector_ae.pth')
print("Detector AE has been saved detector_ae.pth")

# Reformer AE (L1 loss): later used to replace inputs with their reconstruction.
print("="*30)
criterion_reformer = nn.L1Loss()
reformer_ae, optimizer_reformer = _fit_magnet_ae(criterion_reformer, './reformer_ae.pth')
print("Reformer AE has been saved reformer_ae.pth")

In [None]:
def _load_frozen_ae(ckpt_path):
    """Rebuild a ConvAutoEncoder on ``device`` from ``ckpt_path`` and put it in eval mode."""
    ae = ConvAutoEncoder().to(device)
    ae.load_state_dict(torch.load(ckpt_path, map_location=device))
    ae.eval()
    return ae


# Reload both MagNet autoencoders from the checkpoints written by the training cell.
detector_ae = _load_frozen_ae('./detector_ae.pth')
reformer_ae = _load_frozen_ae('./reformer_ae.pth')


@torch.no_grad()
def magnet_purify_pytorch(adv_data, threshold=0.001, device=None, detector=None, reformer=None):
    """MagNet-style purification: reform only the samples the detector flags.

    For each sample, the detector's reconstruction error is computed as the mean
    squared error over dims 1-3. Samples whose error exceeds ``threshold`` are
    replaced by the reformer's reconstruction; all other samples are returned unchanged.

    Args:
        adv_data: 4-D image batch; the error is averaged over dims 1, 2 and 3.
        threshold: per-sample MSE above which a sample is treated as adversarial.
        device: device to run on; ``None`` leaves ``adv_data`` where it is.
        detector: detection autoencoder; defaults to the module-level ``detector_ae``.
        reformer: reforming autoencoder; defaults to the module-level ``reformer_ae``.

    Returns:
        A new tensor shaped like ``adv_data``; ``adv_data`` itself is not modified.
    """
    # Resolve defaults at call time so the notebook's reloaded models are picked up.
    if detector is None:
        detector = detector_ae
    if reformer is None:
        reformer = reformer_ae

    data_in = adv_data.to(device) if device is not None else adv_data
    reconstructed = detector(data_in)
    rec_error = torch.mean((data_in - reconstructed) ** 2, dim=(1, 2, 3))
    mask_adv = rec_error > threshold
    purify_data = data_in.clone()
    if torch.any(mask_adv):
        purify_data[mask_adv] = reformer(data_in[mask_adv])
    return purify_data

In [None]:

# Detection thresholds on the detector's per-image reconstruction MSE, one per
# attacked model. NOTE(review): 1e-5 is far below the StAE value, so presumably
# nearly every VAE-attacked image gets reformed — confirm this is intended.
BEST_THRESHOLD_VAE = 1e-5
BEST_THRESHOLD_STAE = 7.88e-3

# Labels of the first test batch (test_loader does not shuffle), i.e. the batch that was attacked.
labels = next(iter(test_loader))[1].to(device)

purify_data_vae = magnet_purify_pytorch(adv_vae, threshold=BEST_THRESHOLD_VAE, device=device)
purify_data_stae = magnet_purify_pytorch(adv_stae, threshold=BEST_THRESHOLD_STAE, device=device)

with torch.no_grad():
    x_reconst_vae, z, y_pred_vae, mu, log_var = vae_classifier(
        purify_data_vae, deterministic=True, classification_only=False)
    x_reconst_stae, z, y_pred_stae = stae_classifier(purify_data_stae, classification_only=False)

label_selected, vae_pred, stae_pred = pred2label(y_pred_vae, y_pred_stae, labels)

# ConvAE reformer only (no MagNet detector)

In [None]:
import torch
import torch.nn as nn

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Reformer-only baseline: reload the L1-trained reformer from its checkpoint.
reformer_state = torch.load('./reformer_ae.pth', map_location=device)
reformer_ae = ConvAutoEncoder().to(device)
reformer_ae.load_state_dict(reformer_state)
reformer_ae = reformer_ae.eval()

@torch.no_grad()
def reformer_purify(adv_data, reformer_model, device=None):
    """Reform every sample of ``adv_data`` with ``reformer_model`` (no detection step).

    Runs without gradient tracking, like ``magnet_purify_pytorch``. The purified
    batch is only classified afterwards, so there is no reason to keep an autograd
    graph through the reformer's parameters alive for the whole batch.

    Args:
        adv_data: batch of (possibly adversarial) images.
        reformer_model: callable mapping a batch to its reconstruction.
        device: device to run on; ``None`` leaves ``adv_data`` where it is.

    Returns:
        ``reformer_model(adv_data)``, not attached to any autograd graph.
    """
    data_in = adv_data.to(device) if device is not None else adv_data
    return reformer_model(data_in)

# Every adversarial image is replaced by the reformer's reconstruction.
purify_data_vae, purify_data_stae = (
    reformer_purify(adv, reformer_ae, device=device) for adv in (adv_vae, adv_stae)
)

with torch.no_grad():
    x_reconst_vae, z, y_pred_vae, mu, log_var = vae_classifier(
        purify_data_vae, deterministic=True, classification_only=False)
    x_reconst_stae, z, y_pred_stae = stae_classifier(purify_data_stae, classification_only=False)

label_selected, vae_pred, stae_pred = pred2label(y_pred_vae, y_pred_stae, labels)

# DAE-Recon

In [None]:
def _forward_accepts(model, name):
    """Return True if ``model``'s forward (or ``model`` itself) can take keyword ``name``."""
    import inspect

    fn = getattr(model, 'forward', model)
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        # Signature not introspectable: assume the keyword is accepted.
        return True
    return any(p.name == name or p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def dae_purify(data, model, atk_itr=100, eps=0.2, alpha=1/255, random_iteration=1, device=None):
    """Purify inputs by signed-gradient descent on the model's own reconstruction error.

    Each of ``random_iteration`` rounds starts a perturbation ``delta`` at zero
    and updates it ``atk_itr`` times with steps of ``alpha * sign(grad)``. The
    steps *decrease* ``sum((model(x + delta)[0] - x - delta) ** 2)``. After each
    step, ``delta`` is clipped to ``[-eps, eps]`` and ``x + delta`` to ``[0, 1]``.
    At the end of a round the perturbation is folded into ``x``.

    NOTE(review): because ``delta`` restarts at zero every round, the total change
    can reach ``eps * random_iteration``; confirm that is intended.

    ``model`` is called with ``deterministic=False`` and ``classification_only=False``
    when its forward accepts ``deterministic``; otherwise it is called with
    ``classification_only=False`` only. The reconstruction must be the first
    returned element. The signature is checked once up front, so a TypeError
    raised inside the model propagates instead of triggering a silent retry.

    Only the gradient w.r.t. ``delta`` is computed, so the model's parameter
    ``.grad`` fields are neither computed nor modified.

    Args:
        data: batch of images to purify; not modified.
        model: autoencoder-style classifier returning ``(x_reconst, ...)``.
        atk_itr: gradient steps per round.
        eps: bound on ``|delta|`` within a round.
        alpha: step size.
        random_iteration: number of rounds.
        device: device to run on; ``None`` keeps ``data``'s device.

    Returns:
        Purified tensor clamped to ``[0, 1]``, detached from autograd.
    """
    purified_data = data.clone().detach()
    if device is not None:
        purified_data = purified_data.to(device)

    forward_kwargs = {'classification_only': False}
    if _forward_accepts(model, 'deterministic'):
        forward_kwargs['deterministic'] = False

    for _ in range(random_iteration):
        delta = torch.zeros_like(purified_data, requires_grad=True)

        for _ in range(atk_itr):
            x_reconst, *_rest = model(purified_data + delta, **forward_kwargs)
            loss = torch.sum((x_reconst - purified_data - delta) ** 2)
            (grad,) = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                delta -= alpha * torch.sign(grad)
                delta.clamp_(-eps, eps)
                # Keep purified_data + delta inside the valid [0, 1] pixel range.
                delta.copy_(torch.clamp(purified_data + delta, 0, 1) - purified_data)

        purified_data = torch.clamp(purified_data + delta.detach(), 0, 1)

    return purified_data


# DAE-Recon: 100 signed-gradient steps on each classifier's own reconstruction
# error, with the perturbation bounded by eps = 50/255.
purify_kwargs = dict(atk_itr=100, eps=50/255, random_iteration=1, device=device)
purify_data_vae = dae_purify(adv_vae, vae_classifier, **purify_kwargs)
purify_data_stae = dae_purify(adv_stae, stae_classifier, **purify_kwargs)

with torch.no_grad():
    x_reconst_vae, z, y_pred_vae, mu, log_var = vae_classifier(
        purify_data_vae, deterministic=True, classification_only=False)
    x_reconst_stae, z, y_pred_stae = stae_classifier(purify_data_stae, classification_only=False)

label_selected, vae_pred, stae_pred = pred2label(y_pred_vae, y_pred_stae, labels)

# top_k_vis returns (inputs, VAE recon, StAE recon); keep the strips for the matching model.
input_pfy_sample_vae, reconst_vae = top_k_vis(num_visualize, purify_data_vae, x_reconst_vae, x_reconst_stae)[:2]
input_pfy_sample_stae, reconst_stae = top_k_vis(num_visualize, purify_data_stae, x_reconst_vae, x_reconst_stae)[::2]