In [1]:
from src import NeuralTester, Config
import torch
from src.learner import (
    RevDELearner,
    GeneticLearner,
    REV_DE_DEFAULT_PARAMS,
    GENETIC_DEFAULT_PARAMS,
)
from src.objective_functions import get_penalized_distance
from models import load_stylegan
import numpy as np
import logging
import wandb
import sys



## Example usage of the Neural Tester
We use Weights & Biases (wandb) to log progress at various steps.

In [2]:
# Authenticate with Weights & Biases; the login call reports success as a boolean,
# which is displayed as this cell's output.
logged_in = wandb.login()
logged_in

[34m[1mwandb[0m: Currently logged in as: [33moliverweissl[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Setup logging so that log records are written to stdout and render inline in the notebook.
# The stream is chosen via ``basicConfig(stream=...)`` instead of patching
# ``handlers[0].stream`` afterwards: if the root logger already has a handler (installed by
# the kernel or an imported library), ``basicConfig`` without ``force`` does nothing, so the
# level/format would not be applied and an unrelated handler's stream would be overwritten.
# ``force=True`` (Python 3.8+) removes previously installed root handlers first, which also
# keeps re-runs of this cell idempotent.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s - %(name)s - %(levelname)s] %(message)s",
    stream=sys.stdout,
    force=True,
)
logger = logging.getLogger()

In [4]:
# Define the configurations for our experiments.
# Copy the library defaults instead of aliasing them: editing ``learner_params`` in a later
# cell must not mutate the shared GENETIC_DEFAULT_PARAMS mapping for the whole kernel session.
learner_params = dict(GENETIC_DEFAULT_PARAMS)
conf = Config(
    samples_per_class=10,  # Passed to ``tester.test`` as the number of samples per class.
    generations=50,  # Passed to the tester as ``num_generations`` for the search algorithm.
    # Passed to the tester as ``mix_dim_range``; presumably the range of generator
    # dimensions/layers used for style mixing — TODO confirm against NeuralTester.
    mix_dim_range=(0, 8),
    predictor="../models/wrn_mnist.pkl",  # The system under test, loaded with ``torch.load``.
    generator="../models/sg2_mnist.pkl",  # The StyleGAN generator, loaded with ``load_stylegan``.
    learner=GeneticLearner,
    learner_params=learner_params,
)

In [5]:
# The target device for all operations. Choose it first so the SUT can be loaded straight
# onto it, and fall back to the CPU when no CUDA device is available instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The System under test (SUT).
# NOTE(review): ``weights_only=False`` unpickles arbitrary Python objects. This is presumably
# required because the file stores a whole model that is used directly as the SUT, not just a
# state dict. Only load checkpoints you trust. Passing it explicitly keeps this working on
# torch >= 2.6, where the default changed to ``weights_only=True``.
predictor = torch.load(conf.predictor, map_location=device, weights_only=False)
generator = load_stylegan(conf.generator)  # The generator network (a StyleGAN in this case).

# The learner for search-based optimization of candidates. The initial population ``x0``
# has ``2 * population_size`` rows of ``conf.genome_size`` uniform [0, 1) values each.
# NOTE(review): the reason for the factor 2 is not visible here — confirm against the learner.
learner = conf.learner(
    x0=np.random.rand(conf.learner_params["population_size"] * 2, conf.genome_size),
    **conf.learner_params,
)
objective_function = get_penalized_distance  # The objective function to calculate fitness with.

In [6]:
# Here we initialize the Tester object: it wires together the SUT, the generator, the search
# learner and the fitness objective. The generation budget and ``mix_dim_range`` come from ``conf``.
tester_kwargs = dict(
    predictor=predictor,
    generator=generator,
    learner=learner,
    objective_function=objective_function,
    num_generations=conf.generations,
    mix_dim_range=conf.mix_dim_range,
    device=device,
)
tester = NeuralTester(**tester_kwargs)

In [None]:
# We start the testing procedure: for each of the 10 classes the tester generates
# ``conf.samples_per_class`` candidates (10 x 10 = 100 iterations with the config above).
tester.test(samples_per_class=conf.samples_per_class, num_classes=10)

[2024-11-07 12:09:23,486 - root - INFO] Start testing. Number of classes: 10, iterations per class: 10, total iterations: 100



[2024-11-07 12:09:24,822 - root - INFO] Generate seed(s) for class: 0.
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
[2024-11-07 12:09:25,404 - root - INFO] 	Found 1 valid seed(s) after: 1 iterations.
[2024-11-07 12:09:25,730 - root - INFO] Running Search-Algorithm for 50 generations.
