In [None]:
# Import required packages, plus the helper modules (Models, Attacks, Tools, ssnp) adapted from decision-boundary visualization studies

import copy
import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from IPython.display import display
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import datasets, transforms

import Models
import Attacks
import Tools
from ssnp import SSNP, visualize_decision_boundaries

In [None]:
# Load the MNIST test split (downloaded into ../data on first run).
data_dir = os.path.abspath(os.path.join(os.getcwd(), "../data"))
MNIST_data = datasets.MNIST(root=data_dir, train=False, download=True, transform=transforms.ToTensor())

# Pre-select the images that will be turned into adversarial examples.
ADVERSARIAL_BATCH_N = 200
SEED = 42  # For reproducibility (also seeds any later torch RNG use, e.g. PGD)

torch.manual_seed(SEED)
# Sample without replacement so the same test image is never attacked/plotted twice.
random_indices = torch.randperm(MNIST_data.data.shape[0])[:ADVERSARIAL_BATCH_N]

# `MNIST_data.data` holds the raw uint8 pixels (0-255) and bypasses the dataset's
# ToTensor transform. Rescale to [0, 1] (what ToTensor does) so these images live on
# the same scale as `MNIST_data` samples and as the PGD budget EPSILON.
adversarial_images = MNIST_data.data[random_indices].float().div(255.0)
adversarial_labels = MNIST_data.targets[random_indices]

assert adversarial_images.shape[0] == adversarial_labels.shape[0] == min(ADVERSARIAL_BATCH_N, MNIST_data.data.shape[0])
assert 0.0 <= adversarial_images.min().item() and adversarial_images.max().item() <= 1.0

In [None]:
# Build the MNIST classifier architecture; trained weights are loaded per schedule later.
# "small" -> Models.SmallConvNet, "normal" -> Models.ConvNet.
# NOTE(review): the weight repo used later is named "MNIST-SmallConvs-..." — confirm
# it has checkpoints matching ConvNet before switching to "normal".
MODEL_SETTING = "small"

if MODEL_SETTING == "small":
    MNIST_model = Models.SmallConvNet()
elif MODEL_SETTING == "normal":
    MNIST_model = Models.ConvNet()
else:
    # Fail here instead of with a confusing NameError on MNIST_model further down.
    raise ValueError(f"Unknown MODEL_SETTING {MODEL_SETTING!r}; expected 'small' or 'normal'")

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
device_torch = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# For every adversarially-trained checkpoint: download its weights, run PGD on the
# pre-selected clean images, and build the SSNP decision-boundary map.
# Available schedules: "Constant", "Cyclic", "Exponential", "Linear", "LinearUniformMix", "Random"
MODELS = ["Cyclic", "Exponential", "Linear", "LinearUniformMix"]
MODEL_REPO_ID = "JulienStal/MNIST-SmallConvs-AdversarialSchedulers"
EPSILON = 0.3     # PGD perturbation budget; assumes pixel values in [0, 1]
K = 10            # presumably the number of PGD steps — confirm against Attacks.pgd_attack
BATCH_SIZE = 10000
EPOCHS = 200      # SSNP training epochs
PATIENCE = 50     # SSNP early-stopping patience

SAVE_SSNP_HF = False  # NOTE(review): not read in this cell
SSNP_REPO_ID = "JulienStal/SSNPs"

# The loss has no state, so one instance serves every attack.
loss_fn = nn.CrossEntropyLoss()

ssnps = []
im_grids = []
prob_grids = []
dbms = []
models = []
pts_arr = []
adv_images_arr = []  # adversarial images generated against each model, in MODELS order

for schedule in MODELS:
    model_name = f"{MODEL_SETTING}_conv_{schedule}"
    model_file_name = f"model_{schedule}.pth"
    local_pth_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=model_file_name, force_download=True)
    # weights_only=True: the checkpoint comes from the network, so refuse arbitrary pickled objects.
    state_dict = torch.load(local_pth_path, map_location=device_torch, weights_only=True)
    MNIST_model.load_state_dict(state_dict)
    MNIST_model.eval()  # attacks and visualization should not run in training mode

    # Attack a fresh copy of the clean images; writing into `adversarial_images` itself
    # would make the next model attack already-perturbed images (perturbations compound).
    model_adv_images = adversarial_images.clone()
    for idx, (image, label) in enumerate(zip(model_adv_images, adversarial_labels)):
        while image.dim() < 4:  # -> (batch=1, channel=1, H, W)
            image = image.unsqueeze(0)
        label = label.reshape(1)  # make the label a batch of one

        adv_image, _perturbation = Attacks.pgd_attack(image, label, MNIST_model, loss_fn, EPSILON, K, device_torch)
        # detach so no autograd graph is chained into the stored tensor
        model_adv_images[idx] = adv_image.detach().squeeze().cpu()

    ssnp_model_path = os.path.abspath(os.path.join(os.getcwd(), "../models/ssnp", f"ssnp_{schedule}"))

    ssnp, im_grid, prob_grid, dbm, pts = visualize_decision_boundaries(
        original_dataset=MNIST_data,
        dataset_name="MNIST",
        classifier_model=MNIST_model,
        classifier_model_name=model_name,
        ssnp_path_and_name=ssnp_model_path,
        batch_size=BATCH_SIZE,
        adversarial_images=model_adv_images,
        ssnp_training_epochs=EPOCHS,
        ssnp_training_patience=PATIENCE,
        verbose=False,
    )

    ssnps.append(ssnp)
    im_grids.append(im_grid)
    prob_grids.append(prob_grid)
    dbms.append(dbm)
    # Snapshot the classifier: MNIST_model is reloaded with the next schedule's weights
    # on the next iteration, so appending it directly would alias every entry.
    models.append(copy.deepcopy(MNIST_model))
    pts_arr.append(pts)
    adv_images_arr.append(model_adv_images)

In [None]:
# Inspect one pixel of a decision-boundary map: mark it, reconstruct the synthetic
# image that the SSNP inverse projection yields for it, and plot the classifier's prediction.
MODEL_INDEX = 0                # which entry of MODELS (and of ssnps/dbms/models/...) to inspect
PIXEL_X, PIXEL_Y = (150, 150)  # row (vertical) and column (horizontal) in the grid

ssnp = ssnps[MODEL_INDEX]
img_grid = im_grids[MODEL_INDEX]
prob_grid = prob_grids[MODEL_INDEX]
dbm = dbms[MODEL_INDEX]
model = models[MODEL_INDEX]
pts = pts_arr[MODEL_INDEX]

grid_rows, grid_cols = len(img_grid), len(img_grid[0])
if not (0 <= PIXEL_X < grid_rows and 0 <= PIXEL_Y < grid_cols):
    raise IndexError(f"Pixel ({PIXEL_X}, {PIXEL_Y}) is outside the {grid_rows}x{grid_cols} grid")

# Row-major flat index into `pts`; assumes pts is laid out like img_grid — TODO confirm.
index = PIXEL_X * grid_cols + PIXEL_Y

# Mark the pixel on a copy so the stored map in `dbms` stays clean and the cell is re-runnable.
# PIL coordinates are (x=column, y=row).
dbm_marked = dbm.copy()
dbm_marked.putpixel((PIXEL_Y, PIXEL_X), (255, 0, 0))  # Set pixel to red
display(dbm_marked)

# 2-D projected point for the chosen pixel
pt = pts[index]

# Map the point back to image space; float32 matches the classifier's parameters.
synthetic_img = torch.as_tensor(
    ssnp.inverse_transform(torch.as_tensor(pt).unsqueeze(0)), dtype=torch.float32
).reshape(1, 1, 28, 28)
label_tensor = torch.as_tensor(img_grid[PIXEL_X][PIXEL_Y]).reshape(1, 1)

print(f"Label of synthetic image at pixel ({PIXEL_X}, {PIXEL_Y}): {label_tensor.item()} (confidence {prob_grid[PIXEL_X][PIXEL_Y]:.2f})")

Tools.plot_predictions(model, synthetic_img, label_tensor, device_torch, 1)