In [None]:
from src import NeuralTester, ExperimentConfig
import torch
from src.learner import (
    PymooLearner,
    RevDELearner,
    GeneticLearner,
)
from src.learner.auxiliary_components import (
    PYMOO_DE_DEFAULT_PARAMS,
    REV_DE_DEFAULT_PARAMS,
    GENETIC_DEFAULT_PARAMS,
)
from src.objective_functions import penalized_distance, ms_ssim, ssim_d2, uqi
from models import load_stylegan
import numpy as np
import logging
import wandb
import sys
from functools import partial
from pymoo.algorithms.soo.nonconvex.de import DE
from pymoo.operators.sampling.lhs import LHS

## Example usage of the Neural Tester
We use wandb to log progress at various steps of the run.

In [None]:
# Log in to wandb up front; this notebook uses wandb to log progress.
wandb.login()

In [None]:
# Configure the root logger so that log records show up in the notebook output.
logging.basicConfig(
    format="[%(asctime)s - %(name)s - %(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger()
# Point the first root handler at stdout so its output is written there.
root_handler = logger.handlers[0]
root_handler.stream = sys.stdout

In [None]:
# Define the configurations for our experiments.
conf = ExperimentConfig(
    samples_per_class=10,
    generations=50,
    mix_dim_range=(0, 8),
    predictor="../models/wrn_mnist.pkl",
    generator="../models/sg2_mnist.pkl",
    learner=PymooLearner,
    metric=uqi,
)

# Build the learner parameters as a NEW dict instead of aliasing the imported
# defaults. PYMOO_DE_DEFAULT_PARAMS is a module-level object shared by every
# importer; assigning into it (or into its nested "algo_params") would silently
# change the defaults for the rest of the kernel session and make this cell
# non-idempotent when the configuration above is edited and re-run.
learner_params = {
    **PYMOO_DE_DEFAULT_PARAMS,
    "n_var": conf.genome_size,  # for pymoo
    "algo_params": {
        **PYMOO_DE_DEFAULT_PARAMS["algo_params"],
        "pop_size": 10 * conf.genome_size,
    },
}
# learner_params["x0"] = np.random.rand(learner_params["population_size"] * 2, conf.genome_size)  # for own

In [None]:
# The target device for all operations. Fall back to the CPU when no CUDA
# device is available so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The System under test (SUT).
# NOTE: torch.load unpickles the file, which can execute arbitrary code; only
# load checkpoints that come from a trusted source.
# map_location places the loaded tensors on `device`, so a checkpoint that was
# saved on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): recent PyTorch releases default to weights_only=True; if this
# file stores a whole pickled model, weights_only=False may be needed -- confirm
# against the installed PyTorch version.
predictor = torch.load(conf.predictor, map_location=device)

# The generator network (a stylegan in this case).
generator = load_stylegan(conf.generator)

# The learner for search based optimization of candidates.
learner = conf.learner(**learner_params)

# The objective function to calculate fitness with: penalized_distance with the
# configured metric bound in advance.
objective_function = partial(penalized_distance, metric=conf.metric)

In [None]:
# Here we initialize the Tester object from the components prepared above.
tester = NeuralTester(
    # Networks involved in the test.
    generator=generator,
    predictor=predictor,
    # Search configuration.
    learner=learner,
    objective_functions=[objective_function],
    num_generations=conf.generations,
    mix_dim_range=conf.mix_dim_range,
    # Execution target.
    device=device,
)

In [None]:
# We start the testing procedure: 10 classes, with the number of samples per
# class taken from the experiment configuration.
tester.test(samples_per_class=conf.samples_per_class, num_classes=10)