In [None]:
import torch
from tqdm import tqdm

from advsecurenet.models.model_factory import ModelFactory
from advsecurenet.datasets import DatasetFactory
from advsecurenet.dataloader import DataLoaderFactory
from advsecurenet.shared.types import DatasetType
from advsecurenet.utils.model_utils import train as train_model, test as test_model, save_model
from advsecurenet.shared.types.configs.train_config import TrainConfig
from advsecurenet.defenses import AdversarialTraining
from advsecurenet.attacks.fgsm import FGSM
from advsecurenet.attacks.pgd import PGD
from advsecurenet.shared.types.configs.defense_configs.adversarial_training_config import AdversarialTrainingConfig
import advsecurenet.shared.types.configs.attack_configs as AttackConfigs



In [None]:
# Build the baseline MNIST classifier: one input channel, ten output classes.
mnist_model = ModelFactory.get_model(
    model_variant='CustomMnistModel',
    num_classes=10,
    num_input_channels=1,
)
# Bare last expression so the notebook renders the model's repr.
mnist_model

In [None]:
# Load the MNIST train/test splits and wrap each one in a data loader.
dataset = DatasetFactory.load_dataset(DatasetType.MNIST)
train_data = dataset.load_dataset(train=True)
test_data = dataset.load_dataset(train=False)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = DataLoaderFactory.get_dataloader(
    dataset=train_data,
    batch_size=128,
    shuffle=True,
)
test_loader = DataLoaderFactory.get_dataloader(
    dataset=test_data,
    batch_size=128,
    shuffle=False,
)

print(f"Train dataset size: {len(train_data)}")
print(f"Test dataset size: {len(test_data)}")

In [None]:
# Baseline: one epoch of standard (non-adversarial) training,
# followed by an evaluation on the clean test set.
train_config = TrainConfig(
    model=mnist_model,
    train_loader=train_loader,
    epochs=1,
    device="mps",
)
train_model(train_config)
test_model(mnist_model, test_loader, device="mps")

In [None]:
# Function to test model robustness against different attacks
def test_model_robustness(model, test_loader, attack, device):
    model.eval()  # Set the model to evaluation mode

    # Initialize counters
    correct = 0
    adv_correct = 0
    total = 0

    for data, target in tqdm(test_loader, desc='Testing'):
        # Send data and target to the same device as your model
        data, target = data.to(device), target.to(device)
        
        # Get the original model's predictions
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
        # Generate adversarial data using the provided attack method
        fgsm_data = attack.attack(model=model, x=data, y=target)
        
        # Get the model's predictions on the adversarial data
        adv_output = model(fgsm_data)
        adv_pred = adv_output.argmax(dim=1, keepdim=True)
        adv_correct += adv_pred.eq(target.view_as(adv_pred)).sum().item()

        total += target.size(0)

    # Calculate the original accuracy
    original_accuracy = correct / total

    # Calculate the adversarial accuracy
    adversarial_accuracy = adv_correct / total

    # Calculate the robustness as the difference in accuracies
    robustness = original_accuracy - adversarial_accuracy

    print(f'Original Accuracy: {original_accuracy:.2%}')
    print(f'Adversarial Accuracy: {adversarial_accuracy:.2%}')
    print(f'Robustness (Accuracy Drop): {robustness:.2%}')


In [None]:
# Configure and instantiate the FGSM attack (epsilon = 0.5, running on "mps").
fgsm_config = AttackConfigs.FgsmAttackConfig(
    epsilon=0.5,
    device="mps",
)
fgsm = FGSM(fgsm_config)

In [None]:
# Measure how the normally trained baseline model holds up against FGSM.
test_model_robustness(
    model=mnist_model,
    test_loader=test_loader,
    attack=fgsm,
    device="mps",
)

In [None]:
# Adversarially train a freshly initialised model, using FGSM to craft the
# adversarial examples (five epochs on "mps").
robust_model = ModelFactory.get_model(
    model_variant='CustomMnistModel',
    num_classes=10,
    num_input_channels=1,
)
adversarial_training_config = AdversarialTrainingConfig(
    model=robust_model,
    models=[robust_model],
    attacks=[fgsm],
    train_loader=train_loader,
    epochs=5,
    device="mps",
)
adversarial_training = AdversarialTraining(adversarial_training_config)
adversarial_training.adversarial_training()

In [None]:
# Clean (non-adversarial) test accuracy of the adversarially trained model.
# The device is passed explicitly: the model was trained with device="mps",
# and the baseline evaluation passes the same device, so the test batches
# must be placed there too rather than on the function's default device.
test_model(robust_model, test_loader, device="mps")