# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from art.attacks.poisoning import PoisoningAttack
from art.defences.detector import Me
from art.estimators.classification import PyTorchClassifier
from art.utils import load_dataset
import matplotlib.pyplot as plt

# 1. Define the model
class SimpleCNN(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 images (e.g. MNIST).

    Layout: conv(1->32, 3x3, pad 1) -> ReLU -> 2x2 max-pool ->
    conv(32->64, 3x3, pad 1) -> ReLU -> 2x2 max-pool -> flatten (64*7*7) ->
    linear(128) -> ReLU -> linear(num_classes).

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Two 2x2 poolings shrink 28x28 to 7x7, hence 64 * 7 * 7 features.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for x of shape (batch, 1, 28, 28)."""
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        # flatten(x, 1) keeps the batch dimension fixed: an input with the wrong
        # spatial size now fails in fc1 instead of being silently reshaped into
        # a different batch size, as view(-1, 64 * 7 * 7) would do.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# 2. Load the dataset
def load_mnist_data(root='./data', batch_size=64, download=True):
    """Return ``(trainloader, testloader)`` for the MNIST train/test splits.

    Each image goes through ``ToTensor`` (float tensor in [0, 1]) and then
    ``Normalize((0.5,), (0.5,))``, which maps pixel values to [-1, 1].

    Args:
        root: directory where torchvision stores / looks for MNIST.
        batch_size: batch size for both loaders.
        download: whether torchvision may download the dataset if missing.

    Returns:
        A shuffled training DataLoader and an unshuffled test DataLoader.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root=root, train=True, download=download, transform=transform)
    testset = torchvision.datasets.MNIST(root=root, train=False, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader

# 3. Train the model
def train_model(model, trainloader, criterion, optimizer, epochs=2):
    """Train ``model`` in place for ``epochs`` passes over ``trainloader``.

    Args:
        model: module to train; it is switched to training mode.
        trainloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``trainloader``.

    Returns:
        list[float]: the mean batch loss of each epoch (NaN for an epoch in
        which ``trainloader`` yielded no batches).
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in trainloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Count batches rather than using len(trainloader): works for loaders
        # without __len__ and avoids ZeroDivisionError on an empty loader.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
    print("Finished Training")
    return epoch_losses

# 4. Convert PyTorch model to ART classifier
def get_art_classifier(model, *, clip_values=None, preprocessing_defences=None):
    """Wrap ``model`` in an ART ``PyTorchClassifier`` for 1x28x28 inputs and 10 classes.

    The wrapper gets its own SGD optimizer (lr=0.01) and a cross-entropy loss;
    it does not share the optimizer used by ``train_model``.

    Args:
        model: the PyTorch module to wrap.
        clip_values: optional ``(min, max)`` range of valid input values,
            forwarded to ``PyTorchClassifier``. Data from ``load_mnist_data``
            lies in [-1, 1].
        preprocessing_defences: optional ART preprocessor (or list of them),
            forwarded to ``PyTorchClassifier``. This constructor argument is
            how ART attaches defences; assigning an ad-hoc attribute to the
            returned object does not.

    Returns:
        PyTorchClassifier wrapping ``model``.
    """
    return PyTorchClassifier(
        model=model,
        loss=nn.CrossEntropyLoss(),
        optimizer=optim.SGD(model.parameters(), lr=0.01),
        input_shape=(1, 28, 28),
        nb_classes=10,
        clip_values=clip_values,
        preprocessing_defences=preprocessing_defences,
    )

# 5. Implement Poisoning Attack
def apply_poisoning_attack(art_classifier, trainloader, poison_fraction=0.1, poison_target=1,
                           *, patch_size=3, batch_size=64, seed=0):
    """Return a shuffled DataLoader over a backdoor-poisoned copy of ``trainloader``'s data.

    A random ``poison_fraction`` of all samples (drawn only from samples whose
    label is not already ``poison_target``) gets a square trigger stamped into
    its bottom-right ``patch_size`` x ``patch_size`` pixels and is relabelled
    to ``poison_target``. The trigger value is the brightest pixel value found
    in the data, so it works for both [0, 1] and [-1, 1] normalisations.

    The data is collected by iterating ``trainloader`` itself, so poisoned
    samples have exactly the shape, dtype and normalisation the model is
    trained on. (The raw ``dataset.data`` array of torchvision MNIST is uint8,
    un-normalised and has no channel axis, so it cannot be fed to the model.)

    Args:
        art_classifier: accepted for interface compatibility; not used by this
            self-contained implementation.
        trainloader: iterable of ``(images, labels)`` batches.
        poison_fraction: fraction of the whole dataset to poison, in [0, 1].
        poison_target: label assigned to poisoned samples.
        patch_size: side length of the square trigger, in pixels.
        batch_size: batch size of the returned DataLoader.
        seed: seed for choosing which samples are poisoned (deterministic).

    Returns:
        torch.utils.data.DataLoader over the poisoned ``(images, labels)``.

    Raises:
        ValueError: if ``poison_fraction`` is outside [0, 1], ``patch_size``
            is not positive, or ``trainloader`` yields no batches.
    """
    if not 0.0 <= poison_fraction <= 1.0:
        raise ValueError(f"poison_fraction must be in [0, 1], got {poison_fraction}")
    if patch_size < 1:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    # Single pass so images and labels stay aligned even if the loader shuffles.
    image_batches, label_batches = [], []
    for images, labels in trainloader:
        image_batches.append(images)
        label_batches.append(labels)
    if not image_batches:
        raise ValueError("trainloader yielded no batches")
    # torch.cat returns new tensors, so the source dataset is never modified.
    images = torch.cat(image_batches)
    labels = torch.cat(label_batches)

    # Relabelling a sample that already has the target label would be a no-op,
    # so only other classes are candidates.
    candidates = torch.nonzero(labels != poison_target, as_tuple=True)[0]
    n_poison = min(int(round(poison_fraction * len(labels))), len(candidates))
    generator = torch.Generator().manual_seed(seed)
    chosen = candidates[torch.randperm(len(candidates), generator=generator)[:n_poison]]

    trigger_value = images.max()
    images[chosen, ..., -patch_size:, -patch_size:] = trigger_value
    labels[chosen] = poison_target

    poisoned_dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(poisoned_dataset, batch_size=batch_size, shuffle=True)

# 6. Evaluate the model
def evaluate_model(model, testloader):
    """Print and return the top-1 accuracy of ``model`` on ``testloader``, in percent.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: module mapping a batch of images to per-class scores.
        testloader: iterable of ``(images, labels)`` batches.

    Returns:
        float: accuracy in percent, or NaN when ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        print('Accuracy of the network on the test images: n/a (no test images)')
        return float('nan')
    # True division: the old `100 * correct // total` floored, e.g. 98.9% -> 98%.
    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

# 7. Visualize some samples
def visualize_samples(data_loader, title, num_samples=5):
    """Show the first images of the next batch from ``data_loader`` in one row.

    Args:
        data_loader: iterable of ``(images, labels)`` tensor batches.
        title: figure super-title.
        num_samples: maximum number of images to show; clamped to the batch size.

    Raises:
        ValueError: if ``data_loader`` yields no batches, or the batch or
            ``num_samples`` is empty.
    """
    try:
        images, labels = next(iter(data_loader))
    except StopIteration:
        raise ValueError("data_loader yielded no batches") from None
    count = min(num_samples, len(images))
    if count <= 0:
        raise ValueError("nothing to visualize")
    images = images[:count].detach().cpu().numpy()
    labels = labels[:count].detach().cpu().numpy()

    # squeeze=False keeps `axes` a 2-D array even when count == 1.
    fig, axes = plt.subplots(1, count, figsize=(3 * count, 3), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title(f'Label: {label}')
        ax.axis('off')
    plt.suptitle(title)
    plt.show()

# 8. Main function
def _smooth_loader(loader, defense, *, shuffle, batch_size=64):
    """Return a DataLoader over all samples of ``loader`` after passing the images through ``defense``.

    ``defense`` is called as ``defense(x)`` on a NumPy array of every image and
    must return ``(x_transformed, y)``, the calling convention of ART
    preprocessors such as ``SpatialSmoothing``.
    """
    # Single pass so images and labels stay aligned even if `loader` shuffles.
    image_batches, label_batches = [], []
    for images, labels in loader:
        image_batches.append(images)
        label_batches.append(labels)
    smoothed, _ = defense(torch.cat(image_batches).cpu().numpy())
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(smoothed).float(), torch.cat(label_batches)
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def main():
    """Run the clean -> poisoned -> defended MNIST experiment end to end.

    1. Train ``SimpleCNN`` on clean MNIST and report its test accuracy.
    2. Poison the training data, keep training the same model on it, and
       report test accuracy again.
    3. Apply ART's ``SpatialSmoothing`` (median filter) preprocessing defence
       to the poisoned training data and the test data. Train a fresh model on
       the smoothed training data and report its accuracy on the smoothed
       test data.
    """
    # ART's median-filter preprocessing defence is SpatialSmoothing.
    from art.defences.preprocessor import SpatialSmoothing

    # Load data
    trainloader, testloader = load_mnist_data()

    # Initialize and train model
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    train_model(model, trainloader, criterion, optimizer, epochs=2)

    # Convert to ART classifier
    art_classifier = get_art_classifier(model)

    # Evaluate on clean model
    print("Evaluating clean model:")
    evaluate_model(model, testloader)
    visualize_samples(testloader, 'Clean Test Samples')

    # Apply poisoning attack
    poisoned_loader = apply_poisoning_attack(art_classifier, trainloader)

    # Re-train model on poisoned data
    print("Re-training model on poisoned data:")
    train_model(model, poisoned_loader, criterion, optimizer, epochs=2)

    # Evaluate on poisoned model
    print("Evaluating poisoned model:")
    evaluate_model(model, testloader)
    # The test set is never poisoned; show the poisoned training data instead.
    visualize_samples(poisoned_loader, 'Poisoned Training Samples')

    # Apply defense. train_model/evaluate_model call the raw PyTorch model, so
    # attaching a defence to an ART wrapper would never run; smooth the data
    # explicitly instead.
    print("Applying Median Smoothing defense:")
    defense = SpatialSmoothing(window_size=3, channels_first=True)
    defended_train_loader = _smooth_loader(poisoned_loader, defense, shuffle=True)
    defended_test_loader = _smooth_loader(testloader, defense, shuffle=False)

    # Train a fresh model so the result is not contaminated by the poisoned run.
    print("Training a fresh model on poisoned data with defense:")
    defended_model = SimpleCNN()
    defended_optimizer = optim.SGD(defended_model.parameters(), lr=0.01)
    train_model(defended_model, defended_train_loader, criterion, defended_optimizer, epochs=2)

    # Evaluate with defense (the defence is applied to the test inputs too).
    print("Evaluating model with defense:")
    evaluate_model(defended_model, defended_test_loader)
    visualize_samples(defended_test_loader, 'Defended Test Samples')


if __name__ == "__main__":
    main()