# In[ ]:
import copy
import torch

def fgsm_attack(image, model, epsilon=0.02):
    """Perturb ``image`` with a single Fast Gradient Sign Method (FGSM) step.

    The loss gradient is taken against the model's *own* predicted label for
    ``image`` (not a ground-truth label). The step therefore pushes the input
    away from whatever class the model currently predicts.

    Args:
        image: Input tensor passed directly to ``model``. It is not modified.
            Because the labels are extracted with ``.item()``, the model
            output must contain exactly one sample (batch size 1).
        model: Callable returning class scores with classes along dim 1.
            It is assumed to return unnormalized logits, which is what
            cross-entropy expects. TODO confirm for the models used here.
        epsilon: Step size. Every element of the perturbation is
            ``-epsilon``, ``0`` or ``+epsilon``.

    Returns:
        Tuple ``(perturbation, original_label, perturbed_label,
        perturbed_image)``. The two labels are Python ints, and
        ``perturbed_image`` equals ``image + perturbation``.
    """
    # A detached leaf copy that tracks gradients. Without requires_grad the
    # input never receives a gradient, so there would be nothing to take the
    # sign of.
    attack_input = image.clone().detach().requires_grad_(True)

    # enable_grad lets the attack work even if the caller is inside no_grad().
    with torch.enable_grad():
        original_predictions = model(attack_input)
        # Keep the target as a tensor on the logits' device. A CPU-built
        # tensor would mismatch a GPU model.
        target = original_predictions.argmax(dim=1)
        loss_value = torch.nn.functional.cross_entropy(original_predictions, target)
        # autograd.grad returns only the input gradient. Unlike backward(), it
        # does not accumulate into the model parameters' .grad fields.
        (input_grad,) = torch.autograd.grad(loss_value, attack_input)

    original_label = target.item()

    # FGSM step: move each element by epsilon in the direction that
    # increases the loss.
    perturbation = epsilon * input_grad.sign()
    perturbed_image = image + perturbation

    # Only an evaluation of the adversarial input; no autograd graph needed.
    with torch.no_grad():
        perturbed_predictions = model(perturbed_image)
    perturbed_label = perturbed_predictions.argmax(dim=1).item()

    return perturbation, original_label, perturbed_label, perturbed_image