In [0]:
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F


def load_dataset(dataset_name):
    """Create the train and test splits of a supported torchvision dataset.

    Files are stored under (and downloaded into) the local ``data`` directory.
    No ``Normalize`` step is applied; the CIFAR-10 training split additionally
    gets a random horizontal flip and a padded random 32x32 crop.

    Args:
        dataset_name: one of 'mnist', 'fmnist' or 'cifar10'.

    Returns:
        Tuple ``(train, test, in_channels, num_classes)``.

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """
    # Reject unknown names first so nothing is downloaded for a typo.
    if dataset_name not in ('mnist', 'fmnist', 'cifar10'):
        raise Exception("dataset must be one of mnist, fmnist and cifar10")

    def tensor_only():
        # Fresh pipeline per split: tensor conversion and nothing else.
        return transforms.Compose([transforms.ToTensor()])

    if dataset_name == 'cifar10':
        augmented = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
        ])
        train = datasets.CIFAR10('data', train=True, download=True,
                                 transform=augmented)
        test = datasets.CIFAR10('data', train=False, download=True,
                                transform=tensor_only())
        return train, test, 3, 10

    # MNIST and Fashion-MNIST share the same single-channel, 10-class setup.
    dataset_cls = datasets.MNIST if dataset_name == 'mnist' else datasets.FashionMNIST
    train = dataset_cls('data', train=True, download=True, transform=tensor_only())
    test = dataset_cls('data', train=False, download=True, transform=tensor_only())
    return train, test, 1, 10



class Model_MNIST(nn.Module):
    """Small CNN classifier: two (conv, conv, max-pool) stages and two linear layers.

    ``fc1`` expects 7*7*64 features, i.e. a 28x28 input halved twice by the
    two 2x2 max-pools with 64 channels left after the second stage.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Stage 1: in_channels -> 32 -> 32, then spatial size / 2.
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 32 -> 64 -> 64, then spatial size / 2.
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # Classifier head.
        self.fc1 = nn.Linear(7*7*64, 200)
        self.fc2 = nn.Linear(200, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        stages = (
            (self.conv1_1, self.conv1_2, self.maxpool1),
            (self.conv2_1, self.conv2_2, self.maxpool2),
        )
        for first_conv, second_conv, pool in stages:
            x = pool(F.relu(second_conv(F.relu(first_conv(x)))))

        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class Model_CIFAR(nn.Module):
    """CNN classifier with three (conv, conv, max-pool) stages and two linear layers.

    Each 2x2 max-pool halves the spatial size, so a 32x32 input reaches the
    classifier head as a 4x4 map with 64 channels (4*4*64 = 1024 features).

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(Model_CIFAR, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        self.conv1_1 = nn.Conv2d(self.in_channels, 32, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)

        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)

        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        # The flattened feature count depends on the last conv's 64 output
        # channels, not on the image's input channels: 32 -> 16 -> 8 -> 4
        # spatially, hence 4*4*64. (The previous 3*4*4*64 did not match what
        # forward() produces for 32x32 inputs and raised a shape error.)
        self.fc1 = nn.Linear(4*4*64, 200)
        self.fc2 = nn.Linear(200, self.num_classes)


    def forward(self, x):
        """Return the raw class logits for a batch of images."""

        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))

        x = self.maxpool1(x)

        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))

        x = self.maxpool2(x)

        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))

        x = self.maxpool3(x)

        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x



class ResidualBlock(nn.Module):
    """Residual unit: conv -> instance norm -> ReLU -> conv -> instance norm, plus skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the output has the same shape as the input.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        same_shape = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(channels, channels, **same_shape)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)

        self.conv2 = nn.Conv2d(channels, channels, **same_shape)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        branch = self.in1(self.conv1(x))
        # In-place ReLU acts on the branch tensor only, never on the input x.
        branch = self.relu(branch)
        branch = self.in2(self.conv2(branch))
        return branch + x


class UpsampleConvLayer(nn.Module):
    """Optional nearest-neighbour upsampling followed by a convolution.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        stride: convolution stride.
        upsample: scale factor for the upsampling step; a falsy value
            (the default ``None``) skips upsampling entirely.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()

        self.upsample = upsample
        # The resize layer only exists when a scale factor was given.
        if upsample:
            self.upsample_layer = nn.Upsample(mode='nearest', scale_factor=upsample)

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                padding=kernel_size // 2)

    def forward(self, x):
        """Upsample ``x`` when configured to, then apply the convolution."""
        resized = self.upsample_layer(x) if self.upsample else x
        return self.conv2d(resized)


class Generator(nn.Module):
    """Perturbation generator: conv encoder, four residual blocks, upsampling decoder.

    The encoder halves the spatial size twice (two stride-2 convolutions) and
    the decoder doubles it twice, and the last convolution maps back to the
    input's channel count.

    Args:
        dataset_name: 'mnist' or 'fmnist' (1 channel) or 'cifar10' (3 channels).

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """

    def __init__(self, dataset_name):
        super().__init__()
        self.dataset_name = dataset_name

        # Image channel count is the only dataset-dependent setting.
        if dataset_name == 'cifar10':
            channels = 3
        elif dataset_name in ('mnist', 'fmnist'):
            channels = 1
        else:
            raise Exception('dataset must be one of mnist, fmnist and cifar10')

        # Encoder: channels -> 8 -> 16 -> 32, the last two convs with stride 2.
        self.conv1 = nn.Conv2d(channels, 8, kernel_size=3, stride=1, padding=1)
        self.in1 = nn.InstanceNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.in2 = nn.InstanceNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.in3 = nn.InstanceNorm2d(32)

        # Bottleneck of four residual blocks at 32 channels.
        self.resblock1 = ResidualBlock(32)
        self.resblock2 = ResidualBlock(32)
        self.resblock3 = ResidualBlock(32)
        self.resblock4 = ResidualBlock(32)

        # Decoder: 32 -> 16 -> 8, each step upsampling by a factor of 2.
        self.up1 = UpsampleConvLayer(32, 16, kernel_size=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(16)
        self.up2 = UpsampleConvLayer(16, 8, kernel_size=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(8)

        # Output projection back to the image channel count.
        self.conv4 = nn.Conv2d(8, channels, kernel_size=3, stride=1, padding=1)
        self.in6 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        """Map a batch of images to a same-channel output tensor."""
        encoder = ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3))
        for conv, norm in encoder:
            x = F.relu(norm(conv(x)))

        for block in (self.resblock1, self.resblock2, self.resblock3, self.resblock4):
            x = block(x)

        decoder = ((self.up1, self.in4), (self.up2, self.in5))
        for up, norm in decoder:
            x = F.relu(norm(up(x)))

        # No ReLU on the last layer (original note: removed for better
        # performance and for inputs in the [-1, 1] range).
        return self.in6(self.conv4(x))



class Discriminator(nn.Module):
    """Convolutional discriminator producing one unbounded score per image.

    Every convolution is 4x4 with stride 2 and padding 1:

    * 'mnist' / 'fmnist': three convolutions, 28 -> 14 -> 7 -> 3, so the
      linear layer receives 3*3*32 features.
    * 'cifar10': four convolutions, 32 -> 16 -> 8 -> 4 -> 2, so the linear
      layer receives 2*2*32 features.

    Args:
        dataset_name: 'mnist', 'fmnist' or 'cifar10'.

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """

    def __init__(self, dataset_name):
        super(Discriminator, self).__init__()
        self.dataset_name = dataset_name

        if dataset_name in ['mnist', 'fmnist']:
            self.conv1 = nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1)
            self.conv2 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)
            self.in1 = nn.InstanceNorm2d(16)
            self.conv3 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
            self.in2 = nn.InstanceNorm2d(32)
            self.fc = nn.Linear(3 * 3 * 32, 1)

        elif dataset_name == 'cifar10':
            self.conv1 = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)
            self.conv2 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)
            self.in1 = nn.InstanceNorm2d(16)
            self.conv3 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
            self.conv4 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1)
            self.in2 = nn.InstanceNorm2d(32)
            # The flattened size is set by conv4's 32 output channels on a 2x2
            # map, not by the 3 input channels. (The previous 2*2*3*32 did not
            # match forward() and raised a shape error on 32x32 inputs.)
            self.fc = nn.Linear(2 * 2 * 32, 1)

        else:
            raise Exception("dataset must be one of mnist, fmnist and cifar10")

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores (no final activation)."""
        if self.dataset_name in ['mnist', 'fmnist']:
            x = F.leaky_relu(self.conv1(x), negative_slope=0.2)
            x = F.leaky_relu(self.in1(self.conv2(x)), negative_slope=0.2)
            x = F.leaky_relu(self.in2(self.conv3(x)), negative_slope=0.2)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
        else:
            # cifar10 variant: one extra convolution (conv4) before the head.
            x = F.leaky_relu(self.conv1(x), negative_slope=0.2)
            x = F.leaky_relu(self.in1(self.conv2(x)), negative_slope=0.2)
            x = F.leaky_relu(self.conv3(x), negative_slope=0.2)
            x = F.leaky_relu(self.in2(self.conv4(x)), negative_slope=0.2)
            x = x.view(x.size(0), -1)
            x = self.fc(x)

        return x


In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
import datetime
import os

# Turn on cuDNN benchmark mode (per the PyTorch docs, cuDNN then benchmarks
# several convolution algorithms and uses the fastest one).
torch.backends.cudnn.benchmark = True


def CWLoss(logits, target, is_targeted, num_classes=10, kappa=0):
    """Carlini-Wagner margin loss on logits, summed over the batch.

    For every sample let ``real`` be the logit of ``target`` and ``other`` the
    largest logit among the remaining classes. The per-sample loss is

    * targeted:   ``max(other - real, -kappa)``
    * untargeted: ``max(real - other, -kappa)``

    so minimising it pushes the margin down until it reaches ``-kappa``.
    See https://arxiv.org/pdf/1608.04644.pdf.

    Args:
        logits: ``(batch, num_classes)`` tensor of pre-softmax scores.
        target: ``(batch,)`` tensor of class indices; the class to reach when
            ``is_targeted`` is true, otherwise the true label to move away from.
        is_targeted: selects the targeted or the untargeted form.
        num_classes: number of classes (width of ``logits``).
        kappa: confidence margin; 0 stops as soon as the decision flips.

    Returns:
        Scalar tensor with the summed loss.
    """
    target_one_hot = torch.eye(num_classes).type(logits.type())[target.long()]

    # Workaround taken from the reference implementation
    # (https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py):
    # subtract a large value from the target class so the max picks the best
    # *other* class.
    real = torch.sum(target_one_hot*logits, 1)
    other = torch.max((-target_one_hot + 1)*logits - target_one_hot*10000, 1)[0]

    # The margin is floored at -kappa. Flooring at +kappa (as before) made the
    # loss constant, and its gradient zero, before the margin ever turned
    # negative whenever kappa > 0. Identical to the old code for kappa == 0.
    floor = torch.zeros_like(other).fill_(-kappa)

    margin = other - real if is_targeted else real - other
    return torch.sum(torch.max(margin, floor))


def train(G, D, f, target, is_targeted, thres, criterion_adv, criterion_gan, alpha, beta, train_loader, optimizer_G, optimizer_D, epoch, epochs, device, num_steps=3, verbose=True):
    """Run one epoch of generator/discriminator training.

    Per batch, the generator output is clamped to ``[-thres, thres]``, added to
    the image and the sum clamped to ``[0, 1]``. The generator is updated on
    ``loss_adv + alpha*loss_gan + beta*loss_hinge`` every iteration; the
    discriminator is updated only on iterations where ``i % num_steps == 0``.

    Args:
        G, D: generator and discriminator modules (put into train mode here).
        f: target classifier, called on the perturbed images.
        target: class index used as the goal when ``is_targeted`` is true.
        is_targeted: targeted (success = prediction equals ``target``) or
            untargeted (success = prediction differs from the label) attack.
        thres: clamp bound for the perturbation and threshold of the hinge
            term on its per-sample L2 norm.
        criterion_adv: callable ``(logits, labels, is_targeted) -> loss``.
        criterion_gan: callable ``(scores, labels) -> loss``.
        alpha, beta: weights of the GAN and hinge terms.
        train_loader: iterable of ``(img, label)`` batches.
        optimizer_G, optimizer_D: optimizers for ``G`` and ``D``.
        epoch, epochs: current epoch index and total, used only for logging.
        device: device the batches are moved to.
        num_steps: generator iterations per discriminator update.
        verbose: print a progress line after every batch.

    Returns:
        Fraction of samples for which the attack succeeded during the epoch.
    """
    seen = 0
    fooled = 0

    G.train()
    D.train()
    for i, (img, label) in enumerate(train_loader):
        batch = img.size(0)
        valid = torch.ones(batch, 1, requires_grad=False).to(device)
        fake = torch.zeros(batch, 1, requires_grad=False).to(device)
        img_real = img.to(device)

        # ---- generator step ----
        optimizer_G.zero_grad()

        pert = torch.clamp(G(img_real), -thres, thres)
        img_fake = (pert + img_real).clamp(min=0, max=1)

        y_pred = f(img_fake)

        if is_targeted:
            y_ref = torch.empty_like(label).fill_(target).to(device)
            loss_adv = criterion_adv(y_pred, y_ref, is_targeted)
            success = torch.max(y_pred, 1)[1] == y_ref
        else:
            y_ref = label.to(device)
            loss_adv = criterion_adv(y_pred, y_ref, is_targeted)
            success = torch.max(y_pred, 1)[1] != y_ref
        fooled += torch.sum(success).item()

        loss_gan = criterion_gan(D(img_fake), valid)

        # Hinge on the per-sample L2 norm of the perturbation: only the part
        # of the norm above thres is penalised.
        pert_norm = torch.norm(pert.view(pert.size(0), -1), p=2, dim=1)
        zero = torch.zeros(1, ).type(y_pred.type())
        loss_hinge = torch.mean(torch.max(zero, pert_norm - thres))

        loss_g = loss_adv + alpha*loss_gan + beta*loss_hinge
        loss_g.backward()
        optimizer_G.step()

        # ---- discriminator step (every num_steps-th iteration) ----
        # Always cleared: the generator's backward pass above also left
        # gradients on D's parameters.
        optimizer_D.zero_grad()
        if i % num_steps == 0:
            loss_real = criterion_gan(D(img_real), valid)
            # detach() keeps this loss from back-propagating into G.
            loss_fake = criterion_gan(D(img_fake.detach()), fake)

            loss_d = 0.5*loss_real + 0.5*loss_fake
            loss_d.backward()
            optimizer_D.step()

        seen += batch

        if verbose:
            # loss_d is the value from the most recent discriminator update.
            progress = (epoch+1, epochs, i, len(train_loader),
                        loss_d.mean().item(), loss_g.mean().item(),
                        loss_hinge.mean().item(), loss_adv.mean().item(),
                        fooled/seen)
            print("\rEpoch [%d/%d]: [%d/%d], D Loss: %1.4f, G Loss: %3.4f [H %3.4f, A %3.4f], Acc: %.4f"
                  % progress, end="")

    if verbose:
        print()
    return fooled/seen


def test(G, f, target, is_targeted, thres, test_loader, epoch, epochs, device, verbose=True):
    """Measure the attack success rate of ``G`` against ``f`` on a data loader.

    The perturbed image is built exactly as in training: the generator output
    is clamped to ``[-thres, thres]``, added to the image, and the sum is
    clamped to ``[0, 1]``.

    Args:
        G: generator module (switched to eval mode here).
        f: target classifier, called on the perturbed images.
        target: class index used as the goal when ``is_targeted`` is true.
        is_targeted: targeted (success = prediction equals ``target``) or
            untargeted (success = prediction differs from the label) attack.
        thres: clamp bound for the perturbation.
        test_loader: iterable of ``(img, label)`` batches.
        epoch, epochs: current epoch index and total, used only for logging.
        device: device the batches are moved to.
        verbose: print a progress line after every batch.

    Returns:
        Fraction of samples for which the attack succeeded.
    """
    n = 0
    acc = 0

    G.eval()
    # Evaluation only: without no_grad() every batch would build an autograd
    # graph through G and f that is never used.
    with torch.no_grad():
        for i, (img, label) in enumerate(test_loader):
            img_real = img.to(device)

            pert = torch.clamp(G(img_real), -thres, thres)
            img_fake = pert + img_real
            img_fake = img_fake.clamp(min=0, max=1)

            y_pred = f(img_fake)

            if is_targeted:
                y_target = torch.empty_like(label).fill_(target).to(device)
                acc += torch.sum(torch.max(y_pred, 1)[1] == y_target).item()
            else:
                y_true = label.to(device)
                acc += torch.sum(torch.max(y_pred, 1)[1] != y_true).item()

            n += img.size(0)

            if verbose:
                print('\rTest [%d/%d]: [%d/%d]' %(epoch+1, epochs, i, len(test_loader)), end="")

    if verbose: print()
    return acc/n


# This is an evaluation helper, not a unit test: stop pytest from trying to
# collect it because of its name.
test.__test__ = False


# ---------------------------------------------------------------------------
# AdvGAN training script: trains G (and D) to fool a pre-trained classifier f
# and writes a generator checkpoint after every epoch.
# ---------------------------------------------------------------------------
lr = 0.001  # learning rate
batch_size = 128
num_workers = 4  # number of cpu cores that can be used
epochs = 20
# Name of the target-model class; it is looked up among the module-level
# definitions (e.g. "Model_MNIST" or "Model_CIFAR") and also selects the
# checkpoint file name below.
model_name = "Model_MNIST"
dataset_name = "mnist"  # must be "mnist", "fmnist", or "cifar10"
target = -1  # Target label. -1 means untargeted.
thres = 0.2 # perturbation bound, used in loss_hinge. 0.2 and 0.3 work the best
gpu = True  # use gpu for training

# Only touch CUDA when it was requested and is present; an unconditional
# torch.cuda.set_device(0) fails on machines without a GPU.
if gpu and not torch.cuda.is_available():
    print('CUDA is not available, falling back to CPU')
device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
# print(torch.cuda.get_device_name(0))
if device == 'cuda':
    torch.cuda.set_device(0)

# bool variable to indicate targeted or untargeted attack
# (every supported dataset has the 10 classes 0..9)
is_targeted = target in range(0, 10)

print('Training AdvGAN ', '(Target %d)'%(target) if is_targeted else '(Untargeted)')

train_data, test_data, in_channels, num_classes = load_dataset(dataset_name)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

D = Discriminator(dataset_name).to(device)
G = Generator(dataset_name).to(device)

# Resolve the target-model class by name instead of eval()-ing the string, and
# fail with a clear message when the name is not a model class.
model_cls = globals().get(model_name)
if not (isinstance(model_cls, type) and issubclass(model_cls, nn.Module)):
    raise ValueError('model_name %r does not name a defined nn.Module class' % (model_name,))
f = model_cls(in_channels, num_classes).to(device)

# load a pre-trained target model
checkpoint_path = os.path.join('saved', 'target_models', 'best_%s.pth.tar'%(model_name))
checkpoint = torch.load(checkpoint_path, map_location='cpu')
f.load_state_dict(checkpoint["state_dict"])
f.eval()
# f is never optimised here; freezing its weights avoids computing and
# accumulating gradients for them. Gradients still flow through f to G.
for param in f.parameters():
    param.requires_grad_(False)

optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)

scheduler_G = StepLR(optimizer_G, step_size=5, gamma=0.1)
scheduler_D = StepLR(optimizer_D, step_size=5, gamma=0.1)

criterion_adv = CWLoss # loss for fooling target model
criterion_gan = nn.MSELoss() # for gan loss
alpha = 1 # gan loss multiplication factor
beta = 1 # for hinge loss
num_steps = 3 # number of generator updates for 1 discriminator update

for epoch in range(epochs):
    start_time = datetime.datetime.now()

    acc_train = train(G, D, f, target, is_targeted, thres, criterion_adv, criterion_gan, alpha, beta, train_loader, optimizer_G, optimizer_D, epoch, epochs, device, num_steps, verbose=True)
    acc_test = test(G, f, target, is_targeted, thres, test_loader, epoch, epochs, device, verbose=True)

    # Advance the schedules once per finished epoch. Passing the epoch index
    # to step() is deprecated and delayed each decay by one epoch.
    scheduler_G.step()
    scheduler_D.step()

    end_time = datetime.datetime.now()

    print('Epoch [%d/%d]: %.2f seconds\t'%(epoch+1, epochs, (end_time - start_time).total_seconds()))
    print('Train Acc: %.5f\t'%(acc_train))
    print('Test Acc: %.5f\n'%(acc_test))

    # Overwrite the generator checkpoint for this attack configuration.
    torch.save({"epoch": epoch+1,
                "epochs": epochs,
                "is_targeted": is_targeted,
                "target": target,
                "thres": thres,
                "state_dict": G.state_dict(),
                "acc_test": acc_test,
                "optimizer": optimizer_G.state_dict()
                }, "saved/%s_%s.pth.tar"%(model_name, 'target_%d'%(target) if is_targeted else 'untargeted'))


Training AdvGAN  (Untargeted)
Epoch [1/20]: [28/469], D Loss: 0.2722, G Loss: 457.8306 [H 5.0013, A 452.0983], Acc: 0.18132

KeyboardInterrupt: 