In [None]:
import time
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import FastGradientMethod
from skimage.metrics import structural_similarity as ssim
from torchvision import datasets, transforms
import torchvision.models as models
from PIL import Image

# Softmax Activation Function
def softmax_activation(inputs, axis=-1):
    """Return the numerically stable softmax of *inputs* as a NumPy array.

    Args:
        inputs: Logits. Anything exposing ``tolist()`` (torch tensors, NumPy
            arrays) is accepted, as are plain Python sequences and scalars.
        axis: Axis along which probabilities are normalized. The default (-1)
            treats each row of a 2-D ``(batch, classes)`` input as its own
            distribution; for 1-D input this is the whole vector.

    Returns:
        A float64 ``np.ndarray`` with the same shape as *inputs* whose entries
        along *axis* sum to 1.

    Raises:
        ValueError: if *inputs* is empty (NumPy cannot take the max of it).
    """
    if hasattr(inputs, "tolist"):
        # tolist() also works for CUDA tensors and tensors that require grad,
        # where .numpy() would fail.
        inputs = inputs.tolist()
    values = np.asarray(inputs, dtype=np.float64)
    # A 0-d (scalar) input has no axis -1; reduce over "everything" instead.
    reduce_axis = None if values.ndim == 0 else axis
    # Subtracting the max before exponentiating avoids overflow for large logits.
    shifted = values - np.max(values, axis=reduce_axis, keepdims=True)
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values, axis=reduce_axis, keepdims=True)

# Data Augmentation and Transformation
def _make_pipeline(*augmentations):
    """Build a Compose: resize to 32x32, run *augmentations*, then tensorize.

    The tail ``ToTensor`` + ``Normalize(0.5, 0.5)`` maps every channel via
    ``(x - 0.5) / 0.5``, so values in [0, 1] end up in [-1, 1].
    """
    return transforms.Compose(
        [
            transforms.Resize((32, 32)),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )


# Training pipeline: random flips, small rotations, shear/scale jitter and
# colour jitter are applied before tensor conversion.
transform_train = _make_pipeline(
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
)

# Evaluation pipeline: deterministic resize + normalization only.
transform = _make_pipeline()

# Loading Datasets
def load_datasets(dataset_name='CIFAR10', batch_size=64, root='./data'):
    """Build train/test DataLoaders for one of the supported torchvision datasets.

    Args:
        dataset_name: 'CIFAR10', 'CIFAR100', or 'MNIST'.
        batch_size: Batch size for both loaders (default 64).
        root: Directory the dataset is downloaded to / read from.

    Returns:
        ``(train_loader, test_loader, num_classes)``. The train loader shuffles
        and uses ``transform_train``; the test loader does not shuffle and uses
        ``transform``.

    Raises:
        ValueError: if *dataset_name* is not one of the supported names.
    """
    if dataset_name == 'CIFAR10':
        dataset_cls, num_classes = datasets.CIFAR10, 10
        train_tf, test_tf = transform_train, transform
    elif dataset_name == 'CIFAR100':
        dataset_cls, num_classes = datasets.CIFAR100, 100
        train_tf, test_tf = transform_train, transform
    elif dataset_name == 'MNIST':
        dataset_cls, num_classes = datasets.MNIST, 10
        # MNIST images are single-channel. The shared pipelines normalize with
        # three per-channel statistics and SimpleCNN's first conv expects three
        # input channels, so replicate the gray channel to RGB first.
        to_rgb = transforms.Grayscale(num_output_channels=3)
        train_tf = transforms.Compose([to_rgb, *transform_train.transforms])
        test_tf = transforms.Compose([to_rgb, *transform.transforms])
    else:
        raise ValueError("Invalid dataset name. Choose 'CIFAR10', 'CIFAR100', or 'MNIST'.")

    train_dataset = dataset_cls(root=root, train=True, download=True, transform=train_tf)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=test_tf)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, num_classes

# Dataset selection: set to 'CIFAR10', 'CIFAR100', or 'MNIST'.
DATASET_NAME = 'CIFAR10'
train_loader, test_loader, num_classes = load_datasets(DATASET_NAME)

# Convert tensor to image
def im_convert(tensor):
    """Turn a normalized CHW image tensor into a displayable HWC NumPy array.

    The tensor is copied to host memory and detached. Channels are moved last,
    the ``(x - 0.5) / 0.5`` normalization is undone, and the result is clipped
    to [0, 1] so it can be passed straight to ``imshow``.
    """
    hwc = tensor.cpu().clone().detach().numpy().transpose(1, 2, 0)
    half = np.array((0.5, 0.5, 0.5))
    return np.clip(hwc * half + half, 0, 1)

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Simple CNN Model
class SimpleCNN(nn.Module):
    """LeNet-style CNN for 32x32 images.

    Architecture: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three
    fully connected layers. For a 32x32 input the spatial size goes
    32 -> 28 -> 14 -> 10 -> 5, which is where the ``16 * 5 * 5`` flattened
    feature size of ``fc1`` comes from.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels (3 for RGB; default).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape ``(N, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. view(-1, 400) could silently merge features of
        # different samples when the input is not 32x32; flattening from dim 1
        # keeps the batch size, so fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network for the selected dataset and move it to the chosen device.
model = SimpleCNN(num_classes=num_classes)
model = model.to(device)

# Loss function and Optimizer
# CrossEntropyLoss consumes the model's raw logits; Adam with a fixed step size.
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Training Loop
# Train for a fixed number of epochs. After each epoch, evaluate on the test
# loader and report the mean per-batch loss and per-sample accuracy.
epochs = 5
num_train_samples = len(train_loader.dataset)
num_val_samples = len(test_loader.dataset)
for e in range(epochs):
    running_loss = 0.0
    running_corrects = 0
    val_running_loss = 0.0
    val_running_corrects = 0

    # Explicit mode switch (matters once dropout/batch-norm layers are added).
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item()
        # .item() keeps the counter a plain Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    model.eval()
    with torch.no_grad():
        for val_inputs, val_labels in test_loader:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            _, val_preds = torch.max(val_outputs, 1)
            val_running_loss += val_loss.item()
            val_running_corrects += (val_preds == val_labels).sum().item()

    # Loss is averaged over batches (criterion already averages within a batch).
    # Accuracy must be divided by the number of samples, not the number of
    # batches, to land in [0, 1].
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = running_corrects / num_train_samples
    val_epoch_loss = val_running_loss / len(test_loader)
    val_epoch_acc = val_running_corrects / num_val_samples

    print(f"Epoch {e+1}: Train Loss {epoch_loss:.4f}, Train Acc {epoch_acc:.4f}")
    print(f"Epoch {e+1}: Val Loss {val_epoch_loss:.4f}, Val Acc {val_epoch_acc:.4f}")

# FGSM Adversarial Attack with "Flame Flair" (1 - SSIM) image comparison
def compare_images(imageA, imageB, data_range=1.0):
    """Return the structural dissimilarity ``1 - SSIM`` between two images.

    Args:
        imageA, imageB: Equal-shape HWC arrays with channels on the last axis,
            e.g. the outputs of ``im_convert`` (clipped to [0, 1]).
        data_range: Span of possible pixel values. The default of 1.0 matches
            ``im_convert`` output.

    Returns:
        0.0 for identical images, growing as structure diverges (SSIM lies in
        [-1, 1], so the result lies in [0, 2]).
    """
    # ``channel_axis`` replaces the deprecated ``multichannel`` flag, and newer
    # scikit-image releases require an explicit ``data_range`` for float images.
    return 1 - ssim(imageA, imageB, channel_axis=-1, data_range=data_range)

# Wrap the trained model for ART. Inputs live in the space produced by
# ``transform`` ((x - 0.5) / 0.5, i.e. [-1, 1]), so adversarial examples are
# clipped to that range. ``eps`` is therefore in normalized units too
# (0.05 here corresponds to 0.025 in [0, 1] pixel units).
fgsm_attack = FastGradientMethod(
    PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        input_shape=(3, 32, 32),
        nb_classes=num_classes,
        clip_values=(-1.0, 1.0),
        # ART takes the strings 'gpu' or 'cpu' here, not a torch.device.
        device_type="gpu" if device.type == "cuda" else "cpu",
    ),
    eps=0.05,
)

# Take one batch from the (unshuffled) test loader to attack. DataLoader
# iterators do not provide a ``.next()`` method in current PyTorch; use the
# built-in ``next()``.
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Generate Adversarial Examples
# ART operates on NumPy arrays, so round-trip through host memory.
start = time.perf_counter()
x_test_adv = fgsm_attack.generate(images.cpu().detach().numpy())
x_test_adv = torch.from_numpy(x_test_adv).to(device=device, dtype=torch.float32)
print("Attack time: {:.4f} seconds".format(time.perf_counter() - start))

# Evaluate on Adversarial Examples
# Human-readable class names. torchvision's CIFAR10/CIFAR100/MNIST datasets
# expose them as ``.classes``.
classes = test_loader.dataset.classes

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    output_adv = model(x_test_adv)
_, preds_adv = torch.max(output_adv, 1)

# Compute per-true-class accuracy on the adversarial version of this single batch.
class_correct_adv = [0] * num_classes
class_total_adv = [0] * num_classes
for true_class_adv, predicted_class_adv in zip(labels.tolist(), preds_adv.tolist()):
    class_total_adv[true_class_adv] += 1
    if predicted_class_adv == true_class_adv:
        class_correct_adv[true_class_adv] += 1

# Classes absent from the batch report 0.0 rather than dividing by zero.
class_accuracy_adv = [correct / total if total > 0 else 0.0
                      for correct, total in zip(class_correct_adv, class_total_adv)]

for class_name, accuracy in zip(classes, class_accuracy_adv):
    print(f"Class {class_name} (Adversarial): {accuracy:.4f}")

# Compare the original and adversarial images ("Flame Flair" = 1 - SSIM)
flame_flair_comparison = compare_images(im_convert(images[0]), im_convert(x_test_adv[0]))
print(f"Flame Flair Comparison (1 - SSIM) between original and adversarial image: {flame_flair_comparison:.4f}")

# Visualizing the Adversarial Attack Effect
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(im_convert(images[0]))
ax[0].set_title(f"Original Image: {classes[labels[0].item()]}", color="green")
ax[1].imshow(im_convert(x_test_adv[0]))
ax[1].set_title(f"Adversarial Image: {classes[preds_adv[0].item()]}", color="red")
plt.show()
