In [1]:
import torch
import torchvision
from lib.fcn import FCN
from lib.attack import FGSMAttack

In [2]:
# Run on the first CUDA device when one is available, otherwise fall back to CPU.
gpu_name = "cuda:0"
device = torch.device(gpu_name if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All. Note: this alone does not
# force deterministic cuDNN kernels on GPU.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# MNIST loading. ToTensor converts PIL images to float tensors scaled to [0, 1].
DATA_ROOT = 'data/mnist'
BATCH_SIZE = 32
NUM_WORKERS = 2

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# Training split, shuffled every epoch.
_train = torchvision.datasets.MNIST(root=DATA_ROOT, train=True,
                                    download=True, transform=transform)
train = torch.utils.data.DataLoader(_train, batch_size=BATCH_SIZE,
                                    shuffle=True, num_workers=NUM_WORKERS)

# Held-out test split (train=False). Evaluating on the training split would
# overstate both clean and adversarial accuracy.
_test = torchvision.datasets.MNIST(root=DATA_ROOT, train=False,
                                   download=True, transform=transform)
test = torch.utils.data.DataLoader(_test, batch_size=BATCH_SIZE,
                                   shuffle=False, num_workers=NUM_WORKERS)

In [4]:
# Build the classifier (FCN from lib.fcn) directly on the selected device;
# nn.Module.to returns the module itself, so construction and placement chain.
model = FCN().to(device)

# Multi-class classification loss, also handed to the FGSM attack below.
criterion = torch.nn.CrossEntropyLoss()

# Adam over all model parameters with its default hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters())

In [5]:
NUM_EPOCHS = 5
# Passed to FGSMAttack as `epsilon`; presumably the per-pixel perturbation
# size on [0, 1]-scaled inputs -- confirm against lib.attack.
EPSILON = 0.3


def train_one_epoch(model, loader, attack, criterion, optimizer, device):
    """Run one epoch of adversarial training.

    Each batch is optimised on the sum of the loss on the clean inputs and the
    loss on the attack's perturbed copy of those inputs.
    """
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        adv_inputs = attack.generate(inputs)

        outputs = model(inputs)
        adv_outputs = model(adv_inputs)
        loss = criterion(outputs, labels) + criterion(adv_outputs, labels)

        # zero_grad also clears any parameter gradients left by attack.generate.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def evaluate(model, loader, attack, device):
    """Return (clean accuracy %, adversarial accuracy %) of `model` on `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    total = correct = adv_correct = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # FGSM is gradient-based (the attack is given the loss criterion), so
        # generate() presumably needs autograd and must stay outside no_grad.
        adv_inputs = attack.generate(inputs)

        # Plain predictions need no autograd graph; skip it to save memory.
        with torch.no_grad():
            prediction = model(inputs).argmax(dim=1)
            adv_prediction = model(adv_inputs).argmax(dim=1)

        total += labels.size(0)
        correct += prediction.eq(labels).sum().item()
        adv_correct += adv_prediction.eq(labels).sum().item()

    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total, 100 * adv_correct / total


fgsm = FGSMAttack(model, criterion=criterion, epsilon=EPSILON)
for epoch in range(NUM_EPOCHS):
    train_one_epoch(model, train, fgsm, criterion, optimizer, device)
    test_acc, adv_test_acc = evaluate(model, test, fgsm, device)
    print("[Epoch:%2d] Test: %.2f%%     AdvTest: %.2f%%"%(epoch+1,test_acc,adv_test_acc))
    

In [6]:
import os

# Persist the adversarially trained weights. exist_ok=True makes directory
# creation idempotent and avoids the race between an isdir check and makedirs.
os.makedirs("model", exist_ok=True)
torch.save(model.state_dict(), "model/FCN_MNIST_ADT")