In [7]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell executes, so edits to the project's .py files are picked up live.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Make the project root importable, as if the notebook were being run from there.
# The root is taken to be the parent of the current working directory.
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
    # Appended only once, so re-running this cell does not grow sys.path.
    sys.path.append(project_root)

In [9]:
import pygad
import numpy as np
import torch
import os
from ga.fitness import constrained_fitness_func
from ga.model import load_model, evaluate_without_perturbation, evaluate_with_perturbation
from ga.utils import get_dataloader, visualize_images_batch,visualize_image, visualize_perturbation, compute_pixel_statistics, load_config

# Parameters

In [28]:
# Load the experiment configuration. Later cells read its "ga", "model",
# "visualization" and "fitness" sections.
config = load_config()

In [11]:
# Genetic-algorithm hyperparameters, from the "ga" section of the config.
num_generations = config["ga"]["num_generations"]
num_parents_mating = config["ga"]["num_parents_mating"]
sol_per_pop = config["ga"]["sol_per_pop"]
init_range_low = config["ga"]["init_range_low"]
init_range_high = config["ga"]["init_range_high"]
mutation_percent_genes = config["ga"]["mutation_percent_genes"]

# Model selection and batch size, from the "model" section.
model_type = config["model"]["model_type"]
batch_size = config["model"]["batch_size"]

# Visualization switches, from the "visualization" section.
visualize = config["visualization"]["visualize"]
visualize_every = config["visualization"]["visualize_every"]


# Model

In [12]:
# Load the model selected by config["model"]["model_type"].
model = load_model(config["model"]["model_type"])

# Loading

In [13]:
# Build the dataloader over the image folder under the project root.
# The project root is recomputed here as the parent of the current working
# directory, so this cell does not depend on the sys.path setup cell.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Presumably the ImageNet validation split, judging by the path — verify on disk.
image_dir = os.path.join(project_root, "nn/data/imagenet/val")
dataloader = get_dataloader(config["model"]["batch_size"], image_dir)


In [14]:
# Compute the pixel mean and standard deviation across the entire dataset.
# pixel_std is forwarded to the constrained fitness function by fitness_wrapper;
# pixel_mean is not read anywhere else in the visible notebook.
pixel_mean, pixel_std = compute_pixel_statistics(dataloader)

In [15]:
# First batch: the module-level input_batch / original_labels that
# fitness_wrapper reads when scoring a candidate perturbation.
input_batch, original_labels = next(iter(dataloader))

# Collect top perturbations from each generation (on_generation appends to this
# list). Re-running this cell discards anything collected so far.
top_perturbations = []

# GA Config

In [25]:
def fitness_wrapper(ga_instance, solution, solution_idx):
    """Adapt ``constrained_fitness_func`` to a three-argument fitness callback.

    The three callback arguments are passed through unchanged. The remaining
    arguments come from notebook globals: ``pixel_std``, ``model``, the current
    ``input_batch`` / ``original_labels``, and the two settings in the
    ``"fitness"`` section of ``config``.

    Returns:
        Whatever ``constrained_fitness_func`` returns for this solution.
    """
    fitness_cfg = config["fitness"]
    return constrained_fitness_func(
        ga_instance,
        solution,
        solution_idx,
        pixel_std,
        model,
        input_batch,
        original_labels,
        fitness_cfg["pixel_constraint_weight"],
        fitness_cfg["max_perturbation_magnitude"],
    )

In [26]:
def on_generation(ga_instance):
    """PyGAD ``on_generation`` callback: log, record the best perturbation, rotate the batch.

    For the generation that just finished this callback:

    1. prints the population's fitness values and the best fitness,
    2. reshapes the best solution to the per-image shape of the current batch
       and appends it to the global ``top_perturbations`` list,
    3. optionally visualizes that perturbation on the batch it was scored on,
    4. loads a batch into the *global* ``input_batch`` / ``original_labels``,
       so the next generation's fitness evaluations use it.

    Args:
        ga_instance: The running ``pygad.GA`` instance.
    """
    # Without this declaration the assignment at the end of the function would
    # create locals, and fitness_wrapper (which reads the globals) would keep
    # scoring against the very first batch.
    global input_batch, original_labels

    print(f"\nGeneration {ga_instance.generations_completed} completed with fitness: {ga_instance.last_generation_fitness}")

    # Reuse the fitness values PyGAD already computed for this generation.
    # Calling best_solution() without pop_fitness re-evaluates the whole population.
    best_solution, best_solution_fitness, _ = ga_instance.best_solution(
        pop_fitness=ga_instance.last_generation_fitness
    )
    print(f"Best Fitness = {best_solution_fitness}\n")

    # A solution is a flat gene vector; restore the per-image shape.
    best_perturbation = torch.tensor(best_solution).float().reshape(input_batch.shape[1:])
    top_perturbations.append(best_perturbation)
    print(f"Best perturbation magnitude: {torch.norm(best_perturbation).item()}")

    ########
    # VISUALIZATION
    ########
    # Done before the batch is swapped, so the perturbation is shown on the
    # batch it was actually evaluated against.
    viz_cfg = config["visualization"]
    if viz_cfg["visualize"] and ga_instance.generations_completed % viz_cfg["visualize_every"] == 0:
        print("Visualizing")
        visualize_images_batch(input_batch, best_perturbation)

    # Load the batch for the next generation.
    # NOTE(review): iter(dataloader) starts a fresh pass every time, so this only
    # yields a different batch if the dataloader shuffles — confirm in ga.utils.
    input_batch, original_labels = next(iter(dataloader))
    print(f"New batch loaded with first label: {original_labels[0]}")

In [31]:
# Configure the genetic algorithm from the "ga" section of the config.
# Each solution is one flattened perturbation. on_generation reshapes it to
# input_batch.shape[1:], so the gene count is derived from the per-image size
# of the current batch instead of being hard-coded.
ga_instance = pygad.GA(
    num_generations=config["ga"]["num_generations"],
    num_parents_mating=config["ga"]["num_parents_mating"],
    sol_per_pop=config["ga"]["sol_per_pop"],
    num_genes=input_batch[0].numel(),
    gene_type=float,
    init_range_low=config["ga"]["init_range_low"],
    init_range_high=config["ga"]["init_range_high"],
    mutation_percent_genes=config["ga"]["mutation_percent_genes"],
    fitness_func=fitness_wrapper,
    on_generation=on_generation,
)

# Runs every generation; fitness_wrapper scores each solution and
# on_generation fires after each completed generation.
ga_instance.run()

Perturbation magnitude after channel-wise clamping: 22.428010940551758
Perturbation magnitude after global clamping: 22.428010940551758
After applying pixel constraints: 22.428010940551758
Perturbation magnitude after channel-wise clamping: 22.33572006225586
Perturbation magnitude after global clamping: 22.33572006225586
After applying pixel constraints: 22.33572006225586
Perturbation magnitude after channel-wise clamping: 22.379501342773438
Perturbation magnitude after global clamping: 22.379501342773438
After applying pixel constraints: 22.379501342773438
Perturbation magnitude after channel-wise clamping: 22.397825241088867
Perturbation magnitude after global clamping: 22.397825241088867
After applying pixel constraints: 22.397825241088867
Perturbation magnitude after channel-wise clamping: 22.384401321411133
Perturbation magnitude after global clamping: 22.384401321411133
After applying pixel constraints: 22.384401321411133
Perturbation magnitude after channel-wise clamping: 22.453

KeyboardInterrupt: 

In [29]:
# Inspect the shape of the current input batch.
input_batch.shape

torch.Size([64, 3, 224, 224])

In [22]:
# After the run: report the best solution of the last evaluated population.
# Passing the stored fitness values avoids re-evaluating every solution, which
# best_solution() does when pop_fitness is omitted.
# NOTE(review): an earlier revision asked whether this reports the penultimate
# generation's best fitness — check against ga_instance.best_solutions_fitness.
solution, solution_fitness, _ = ga_instance.best_solution(
    pop_fitness=ga_instance.last_generation_fitness
)
print(f"Best solution fitness: {solution_fitness}")

Best solution fitness: 0.84375


# Calculating Universal Perturbation

In [25]:
# Average the per-generation best perturbations element-wise into a single
# universal perturbation. Nothing happens (and universal_perturbation is left
# unset) when no perturbation has been collected yet.
if top_perturbations:
    stacked_perturbations = torch.stack(top_perturbations)
    universal_perturbation = stacked_perturbations.mean(dim=0)
    print(universal_perturbation.shape)
    print("Universal perturbation created.")
    visualize_perturbation(universal_perturbation)

AttributeError: 'list' object has no attribute 'shape'

# Metrics

In [None]:
# Baseline metrics for the model over the dataloader.
# NOTE(review): the perturbation is passed even though the function is named
# "without perturbation" — confirm in ga.model whether the argument is used.
evaluate_without_perturbation(model, dataloader, universal_perturbation)

In [None]:
# Metrics for the model over the same dataloader, with the universal
# perturbation supplied.
evaluate_with_perturbation(model, dataloader, universal_perturbation)