In [1]:
import torch
from tqdm.auto import tqdm
from advsecurenet.models.model_factory import ModelFactory
from advsecurenet.datasets import DatasetFactory
from advsecurenet.dataloader import DataLoaderFactory
from advsecurenet.shared.types import DatasetType
from advsecurenet.utils.model_utils import save_model
from advsecurenet.shared.types.configs.train_config import TrainConfig
from advsecurenet.defenses import AdversarialTraining
from advsecurenet.attacks.fgsm import FGSM
from advsecurenet.attacks.pgd import PGD
from advsecurenet.attacks.lots import LOTS
from advsecurenet.shared.types.configs.defense_configs.adversarial_training_config import AdversarialTrainingConfig
import advsecurenet.shared.types.configs.attack_configs as AttackConfigs
from advsecurenet.utils.tester import Tester
from advsecurenet.utils.trainer import Trainer
from advsecurenet.utils.evaluation import AdversarialAttackEvaluator



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load basic MNIST model
mnist_model = ModelFactory.create_model(model_name='CustomMnistModel', num_classes=10, num_input_channels=1)
mnist_model

CustomModel(
  (model): CustomMnistModel(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fc1): Linear(in_features=50176, out_features=512, bias=True)
    (relu): ReLU()
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [3]:
# load MNIST dataset
dataset = DatasetFactory.create_dataset(DatasetType.MNIST)
train_data = dataset.load_dataset(train=True)
test_data = dataset.load_dataset(train=False)
train_loader = DataLoaderFactory.create_dataloader(dataset=train_data, batch_size=128, shuffle=True)
test_loader = DataLoaderFactory.create_dataloader(dataset=test_data, batch_size=128, shuffle=False)
print(f"Train dataset size: {len(train_data)}")
print(f"Test dataset size: {len(test_data)}")

Train dataset size: 60000
Test dataset size: 10000


In [4]:
# first normal training
train_config = TrainConfig(model= mnist_model, train_loader=train_loader, epochs=1, device="cuda:2")
trainer = Trainer(train_config)
trainer.train()

100%|██████████| 469/469 [00:19<00:00, 23.97it/s]

Epoch 1 loss: 0.15511241101864368





In [5]:
# get FGSM attack
fgsm_config = AttackConfigs.FgsmAttackConfig(epsilon=0.5, device="cuda:2")
fgsm = FGSM(fgsm_config)

In [6]:
evaluator = AdversarialAttackEvaluator()

In [7]:
w

Generating Adversarial Images: 100%|██████████| 79/79 [00:01<00:00, 41.29it/s]

FGSM attack success rate: 0.15338212025316456





In [8]:
# use fgsm to adversarially train the model
robust_model = ModelFactory.create_model(model_name='CustomMnistModel', num_classes=10, num_input_channels=1)
adversarial_training_config = AdversarialTrainingConfig(model=robust_model, models=[robust_model], attacks=[fgsm], train_loader=train_loader, epochs=1, device="cuda:2")
adversarial_training = AdversarialTraining(adversarial_training_config)
adversarial_training.train()

Running epoch 1...


100%|██████████| 469/469 [00:19<00:00, 23.77it/s]

Epoch 1/1 Loss: 0.15664802291152527





In [9]:
# testing the clean accuracy of the adversarially trained model
tester = Tester(model=robust_model, test_loader=test_loader, device="cuda:2")
tester.test()

Testing on cuda:2


Testing: 100%|██████████| 79/79 [00:01<00:00, 65.71batch/s] 


Test set: Average loss: 0.0004, Accuracy: 9841/10000 (98.41%)





(0.0003975198087573517, 98.41)

In [10]:
attack_success_rate = 0
for images, labels in tqdm(test_loader, desc='Generating Adversarial Images'):
    images = images.to("cuda:2")
    labels = labels.to("cuda:2")
    fgsm_images = fgsm.attack(model=robust_model, x=images, y=labels)
    fgsm_images = fgsm_images.to("cuda:2")
    attack_success_rate += evaluator.evaluate_attack(robust_model, images, labels, fgsm_images)

print(f"FGSM attack success rate: {attack_success_rate/len(test_loader)}")

Generating Adversarial Images: 100%|██████████| 79/79 [00:01<00:00, 48.46it/s]

FGSM attack success rate: 0.014042721518987342



