In [17]:
import logging
import multiprocessing
import threading

import niapy
from niapy.algorithms.basic import (
    BatAlgorithm,
    FireflyAlgorithm,
    ParticleSwarmOptimization,
)
from niapy.algorithms import Algorithm
from niapy.problems.sphere import Sphere
from niapy.task import Task
import torch
import pandas
from matplotlib import pyplot as plt
import numpy as np
from numpy.random import default_rng

from util.optimization_data import SingleRunData, PopulationData
from util.diversity_metrics import DiversityMetric

#### Constants

In [18]:
# Seed intended for reproducible runs.
# NOTE(review): not referenced by run_optimization; presumably passed as
# `seed=` when the niapy algorithms are constructed — confirm at the call site.
RNG_SEED: int = 0

In [19]:
def run_optimization(
    algorithm: Algorithm,
    task,
    single_run_data: SingleRunData,
    diversity_metrics: "list[DiversityMetric] | None" = None,
):
    """Run ``algorithm`` on ``task`` and record the population after every iteration.

    The loop follows the same structure as niapy's ``Algorithm.run_task``
    (callbacks, ``init_population``, ``run_iteration``, ``task.next_iter``),
    with two additions:

    * After the initial population is created, the algorithm's RNG is replaced
      with an unseeded ``default_rng()``. Only the initial population therefore
      depends on whatever seed the algorithm was constructed with.
    * After each iteration, the population and its fitness are copied into a
      ``PopulationData`` object. The requested diversity metrics are computed
      on it, and it is appended to ``single_run_data``.

    Args:
        algorithm: niapy algorithm instance to run. Its ``rng`` attribute is
            overwritten after initialization. On a caught exception its
            ``exception`` attribute is set.
        task: niapy ``Task`` that supplies the problem and stopping condition.
        single_run_data: Collector that receives one ``PopulationData`` per
            iteration via ``add_population``.
        diversity_metrics: Metrics passed to ``PopulationData.calculate_metrics``
            for every iteration. ``None`` (the default) means
            ``[DiversityMetric.PDC]``, matching the previous hard-coded behavior.

    Returns:
        ``(best_solution, best_fitness)``, where ``best_fitness`` is multiplied
        by ``task.optimization_type.value``. In niapy this undoes the sign flip
        applied internally for maximization problems.
        Returns ``(None, None)`` when an exception is caught outside the main
        thread of the main process.

    Raises:
        BaseException: Any exception raised during the run is re-raised
            unchanged when called from the main thread of the main process.
    """
    try:
        # Resolve the metric selection once. A tuple snapshot means later edits
        # to the caller's list cannot change an in-progress run.
        if diversity_metrics is None:
            metrics = (DiversityMetric.PDC,)
        else:
            metrics = tuple(diversity_metrics)

        algorithm.callbacks.before_run()
        pop, fpop, params = algorithm.init_population(task)
        # Re-seed from OS entropy. The initial population above comes from the
        # algorithm's own (possibly seeded) RNG; subsequent iterations do not.
        # NOTE(review): presumably the algorithm is built with seed=RNG_SEED so
        # that all runs share a starting population — confirm at the call site.
        algorithm.rng = default_rng()
        xb, fxb = algorithm.get_best(pop, fpop)
        while not task.stopping_condition():
            algorithm.callbacks.before_iteration(pop, fpop, xb, fxb, **params)
            pop, fpop, xb, fxb, params = algorithm.run_iteration(
                task, pop, fpop, xb, fxb, **params
            )

            # np.array makes copies, so later in-place updates to pop/fpop by
            # the algorithm do not alter the recorded snapshot.
            pop_data = PopulationData(
                population=np.array(pop), population_fitness=np.array(fpop)
            )
            pop_data.calculate_metrics(list(metrics))
            single_run_data.add_population(pop_data)

            algorithm.callbacks.after_iteration(pop, fpop, xb, fxb, **params)
            task.next_iter()
        algorithm.callbacks.after_run()
        return xb, fxb * task.optimization_type.value
    except BaseException as e:
        # Same policy as niapy's run_task. In the main thread of the main
        # process, propagate the error. In worker threads/processes, store it
        # on the algorithm so the caller can inspect it, and return a sentinel.
        if (
            threading.current_thread() is threading.main_thread()
            and multiprocessing.current_process().name == "MainProcess"
        ):
            raise  # bare raise re-raises the active exception with its original traceback
        algorithm.exception = e
        return None, None