
Unify implementation of fast non-dominated sort #5160

Merged
Changes from all commits (45 commits)
b3fc441
Implement testing._create_frozen_trial()
Alnusjaponica Dec 21, 2023
63a9488
Replace test_cma._create_frozen_trial and test_nsgaii._create_frozen_…
Alnusjaponica Dec 21, 2023
0e70a1f
Implement _fast_non_dominated_sort()
Alnusjaponica Dec 21, 2023
319bce0
replace _calculate_nondomination_rank() with _fast_non_dominated_sort()
Alnusjaponica Dec 21, 2023
90fc55e
Move comparison-time validation to _validate_constraints()
Alnusjaponica Dec 21, 2023
01cb145
Add helper function to calculate penalty
Alnusjaponica Dec 21, 2023
9d358b4
Rename a module
Alnusjaponica Dec 21, 2023
43a9ae3
Update argument for _validate_constraints()
Alnusjaponica Dec 21, 2023
373903c
Remove _fast_non_dominated_sort()
Alnusjaponica Dec 21, 2023
516e0ab
Add wrapper of _fast_non_dominated_sort for constrained nsga algorithm
Alnusjaponica Dec 21, 2023
04385eb
Move test_calculate_nondomination_rank() as non dominated sort logic …
Alnusjaponica Dec 21, 2023
a613050
Replace _fast_non_dominated_sort with _rank_population in the tests
Alnusjaponica Dec 21, 2023
a4cd1cd
Handle the case where both trials are infeasible and have the same pe…
Alnusjaponica Dec 22, 2023
b88d010
Update penalty handling
Alnusjaponica Feb 1, 2024
881009f
Remove a comment
Alnusjaponica Feb 1, 2024
1bac8bd
Merge branch 'master' of https://github.com/optuna/optuna into unifiy…
Alnusjaponica Feb 1, 2024
a174c04
Remove unnecessary assertion
Alnusjaponica Feb 1, 2024
950fcc0
Remove unnecessary assertion
Alnusjaponica Feb 1, 2024
f5ae9d8
Fix rank starts
Alnusjaponica Feb 1, 2024
bd54538
Initialize num_constraints before the loop
Alnusjaponica Feb 1, 2024
a081ef0
Re-write test to use @pytest.mark.parametrize
Alnusjaponica Feb 1, 2024
cc46d17
Add test cases for duplicate values
Alnusjaponica Feb 1, 2024
a38e76e
Add test cases of different constraint dimension
Alnusjaponica Feb 1, 2024
51f4458
import annotations to use the type hint list[float]
Alnusjaponica Feb 1, 2024
b484b28
Merge branch 'master' of https://github.com/optuna/optuna into unifiy…
Alnusjaponica Feb 8, 2024
8fcd160
Reduce lines
Alnusjaponica Feb 8, 2024
13aa189
Remove unnecessary initialization
Alnusjaponica Feb 8, 2024
afac585
Merge branch 'unifiy-implementation-of-fast-nondominated-sort' of htt…
Alnusjaponica Feb 8, 2024
9876648
Reduce lines
Alnusjaponica Feb 8, 2024
60ea231
Reduce lines
Alnusjaponica Feb 8, 2024
a5c55e4
Move nondomination_rank definition for readability
Alnusjaponica Feb 15, 2024
c6a96d7
run np.isnan(penalty) only once
Alnusjaponica Feb 15, 2024
37b043d
pass n_below for calculating ranks of infeasible trials
Alnusjaponica Feb 15, 2024
dfbb3dd
Simplify while loop
Alnusjaponica Feb 15, 2024
6593e16
Rename test name
Alnusjaponica Feb 15, 2024
70d5802
Pass base rank to _calculate_nondomination_rank
Alnusjaponica Feb 15, 2024
378c390
Fix nondomination_rank update for rank=-1 cases
Alnusjaponica Feb 15, 2024
b6343d1
Ignore rank-1 trials in _rank_population
Alnusjaponica Feb 15, 2024
4bfe871
Fix docstring to make variable name consistent
Alnusjaponica Feb 22, 2024
5c25571
Rename variable from `is_nan` to `is_penalty_nan `
Alnusjaponica Feb 22, 2024
e4f193e
Cover test cases for invalid input
Alnusjaponica Feb 22, 2024
3c6547b
Early return when n_below<=-1
Alnusjaponica Feb 22, 2024
7ae7855
Replace _calculate_nondomination_rank with np.unique for 1d-array sort
Alnusjaponica Feb 22, 2024
a490277
Add docstring to _fast_non_dominated_sort()
Alnusjaponica Feb 22, 2024
542f011
Fix flake8 error
Alnusjaponica Feb 22, 2024
13 changes: 6 additions & 7 deletions optuna/samplers/_nsgaiii/_elite_population_selection_strategy.py
@@ -9,11 +9,9 @@
import numpy as np

from optuna.samplers._lazy_random_state import LazyRandomState
from optuna.samplers.nsgaii._dominates import _constrained_dominates
from optuna.samplers.nsgaii._dominates import _validate_constraints
from optuna.samplers.nsgaii._elite_population_selection_strategy import _fast_non_dominated_sort
from optuna.samplers.nsgaii._constraints_evaluation import _validate_constraints
from optuna.samplers.nsgaii._elite_population_selection_strategy import _rank_population
from optuna.study import Study
from optuna.study._multi_objective import _dominates
from optuna.trial import FrozenTrial


@@ -52,10 +50,11 @@ def __call__(self, study: Study, population: list[FrozenTrial]) -> list[FrozenTr
Returns:
A list of trials that are selected as elite population.
"""
_validate_constraints(population, self._constraints_func)
_validate_constraints(population, is_constrained=self._constraints_func is not None)
population_per_rank = _rank_population(
population, study.directions, is_constrained=self._constraints_func is not None
)

dominates = _dominates if self._constraints_func is None else _constrained_dominates
population_per_rank = _fast_non_dominated_sort(population, study.directions, dominates)
elite_population: list[FrozenTrial] = []
for population in population_per_rank:
if len(elite_population) + len(population) < self._population_size:
24 changes: 4 additions & 20 deletions optuna/samplers/_tpe/sampler.py
@@ -28,6 +28,7 @@
from optuna.search_space.group_decomposed import _GroupDecomposedSearchSpace
from optuna.search_space.group_decomposed import _SearchSpaceGroup
from optuna.study import Study
from optuna.study._multi_objective import _fast_non_dominated_sort
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
@@ -607,21 +608,6 @@ def after_trial(
self._random_sampler.after_trial(study, trial, state, values)


def _calculate_nondomination_rank(loss_vals: np.ndarray, n_below: int) -> np.ndarray:
ranks = np.full(len(loss_vals), -1)
num_ranked = 0
rank = 0
domination_mat = np.all(loss_vals[:, None, :] >= loss_vals[None, :, :], axis=2) & np.any(
loss_vals[:, None, :] > loss_vals[None, :, :], axis=2
)
while num_ranked < n_below:
counts = np.sum((ranks == -1)[None, :] & domination_mat, axis=1)
num_ranked += np.sum((counts == 0) & (ranks == -1))
ranks[(counts == 0) & (ranks == -1)] = rank
rank += 1
return ranks


def _split_trials(
study: Study,
trials: list[FrozenTrial],
@@ -693,13 +679,11 @@ def _split_complete_trials_multi_objective(
# The type of trials must be `list`, but not `Sequence`.
return [], list(trials)

lvals = np.asarray([trial.values for trial in trials])
for i, direction in enumerate(study.directions):
if direction == StudyDirection.MAXIMIZE:
lvals[:, i] *= -1
lvals = np.array([trial.values for trial in trials])
lvals *= np.array([-1.0 if d == StudyDirection.MAXIMIZE else 1.0 for d in study.directions])

# Solving HSSP for variables number of times is a waste of time.
nondomination_ranks = _calculate_nondomination_rank(lvals, n_below)
nondomination_ranks = _fast_non_dominated_sort(lvals, n_below=n_below)
assert 0 <= n_below <= len(lvals)

indices = np.array(range(len(lvals)))
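The sign flip for MAXIMIZE directions in the hunk above can be checked in isolation; a minimal sketch with made-up values:

import numpy as np

from optuna.study import StudyDirection

# Sketch of the sign flip above: values of MAXIMIZE objectives are negated so that
# "lower is better" holds for every column before the non-dominated sort.
directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]
lvals = np.array([[1.0, 10.0], [2.0, 20.0]])
lvals *= np.array([-1.0 if d == StudyDirection.MAXIMIZE else 1.0 for d in directions])
print(lvals)
# [[  1. -10.]
#  [  2. -20.]]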
2 changes: 1 addition & 1 deletion optuna/samplers/nsgaii/_child_generation_strategy.py
@@ -6,9 +6,9 @@

from optuna.distributions import BaseDistribution
from optuna.samplers._lazy_random_state import LazyRandomState
from optuna.samplers.nsgaii._constraints_evaluation import _constrained_dominates
from optuna.samplers.nsgaii._crossover import perform_crossover
from optuna.samplers.nsgaii._crossovers._base import BaseCrossover
from optuna.samplers.nsgaii._dominates import _constrained_dominates
from optuna.study import Study
from optuna.study._multi_objective import _dominates
from optuna.trial import FrozenTrial
optuna/samplers/nsgaii/_constraints_evaluation.py
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Sequence
import warnings

@@ -86,15 +85,43 @@ def _constrained_dominates(
return violation0 < violation1


def _evaluate_penalty(population: Sequence[FrozenTrial]) -> np.ndarray:
"""Evaluate feasibility of trials in population.
Returns:
A list of feasibility status T/F/None of trials in population, where T/F means
feasible/infeasible and None means that the trial does not have constraint values.
"""

penalty: list[float] = []
for trial in population:
constraints = trial.system_attrs.get(_CONSTRAINTS_KEY)
if constraints is None:
penalty.append(np.nan)
else:
penalty.append(sum(v for v in constraints if v > 0))
return np.array(penalty)
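A minimal standalone sketch of the penalty computation in _evaluate_penalty above, using plain lists in place of trial system_attrs (the constraint values are made up):

import numpy as np

# Hypothetical constraint vectors for three trials: the first is feasible,
# the second violates both constraints, and the third has no constraint values.
constraints_per_trial = [[-1.0, -0.5], [2.0, 3.0], None]

penalty = np.array(
    [np.nan if c is None else sum(v for v in c if v > 0) for c in constraints_per_trial]
)
print(penalty)  # [ 0.  5. nan]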


def _validate_constraints(
population: list[FrozenTrial],
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
*,
is_constrained: bool = False,
) -> None:
if constraints_func is None:
if not is_constrained:
return

num_constraints = max(
[len(t.system_attrs.get(_CONSTRAINTS_KEY, [])) for t in population], default=0
)
for _trial in population:
_constraints = _trial.system_attrs.get(_CONSTRAINTS_KEY)
if _constraints is None:
warnings.warn(
f"Trial {_trial.number} does not have constraint values."
" It will be dominated by the other trials."
)
continue
if np.any(np.isnan(np.array(_constraints))):
raise ValueError("NaN is not acceptable as constraint value.")
elif len(_constraints) != num_constraints:
raise ValueError("Trials with different numbers of constraints cannot be compared.")
73 changes: 29 additions & 44 deletions optuna/samplers/nsgaii/_elite_population_selection_strategy.py
@@ -3,13 +3,14 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Sequence
import itertools

import optuna
from optuna.samplers.nsgaii._dominates import _constrained_dominates
from optuna.samplers.nsgaii._dominates import _validate_constraints
import numpy as np

from optuna.samplers.nsgaii._constraints_evaluation import _evaluate_penalty
from optuna.samplers.nsgaii._constraints_evaluation import _validate_constraints
from optuna.study import Study
from optuna.study._multi_objective import _dominates
from optuna.study import StudyDirection
from optuna.study._multi_objective import _fast_non_dominated_sort
from optuna.trial import FrozenTrial


@@ -38,10 +39,10 @@
Returns:
A list of trials that are selected as elite population.
"""
_validate_constraints(population, self._constraints_func)
dominates = _dominates if self._constraints_func is None else _constrained_dominates
population_per_rank = _fast_non_dominated_sort(population, study.directions, dominates)

_validate_constraints(population, is_constrained=self._constraints_func is not None)
population_per_rank = _rank_population(
population, study.directions, is_constrained=self._constraints_func is not None
)
elite_population: list[FrozenTrial] = []
for individuals in population_per_rank:
if len(elite_population) + len(individuals) < self._population_size:
@@ -109,42 +110,26 @@
population.reverse()


def _fast_non_dominated_sort(
def _rank_population(
population: list[FrozenTrial],
directions: list[optuna.study.StudyDirection],
dominates: Callable[[FrozenTrial, FrozenTrial, list[optuna.study.StudyDirection]], bool],
directions: Sequence[StudyDirection],
*,
is_constrained: bool = False,
) -> list[list[FrozenTrial]]:
dominated_count: defaultdict[int, int] = defaultdict(int)
dominates_list = defaultdict(list)

for p, q in itertools.combinations(population, 2):
if dominates(p, q, directions):
dominates_list[p.number].append(q.number)
dominated_count[q.number] += 1
elif dominates(q, p, directions):
dominates_list[q.number].append(p.number)
dominated_count[p.number] += 1

population_per_rank = []
while population:
non_dominated_population = []
i = 0
while i < len(population):
if dominated_count[population[i].number] == 0:
individual = population[i]
if i == len(population) - 1:
population.pop()
else:
population[i] = population.pop()
non_dominated_population.append(individual)
else:
i += 1

for x in non_dominated_population:
for y in dominates_list[x.number]:
dominated_count[y] -= 1

assert non_dominated_population
population_per_rank.append(non_dominated_population)
if len(population) == 0:
return []

objective_values = np.array([trial.values for trial in population], dtype=np.float64)
objective_values *= np.array(
Collaborator:

Can we use loss_values to make it clear to us in the future that each objective is better when it is lower in this array?
Please check out other functions as well.

Collaborator (author):

Sounds reasonable, but the objective values are not necessarily 'loss' values, so the change might be confusing.
@HideakiImamura Do you have any opinion?

Member:

I think objective_values is appropriate here.

Collaborator:

Just to give @HideakiImamura the context:
loss_vals (or another name, lvals) is often used in TPESampler when lower values are better (please check here and here).

So the problem here is that I would need to look at the source of the function calls to check whether objective_values is always better when each objective is lower.
loss does not necessarily mean a machine-learning loss function, but just loss, for which we already have a universal consensus that lower is better, even in normal conversation.

I do not mind using objective_values, but I strongly encourage you to specify whether each objective is better when it is lower.
Again, the reason is simply that I would need to refer to the function call origin to see if objective_values is always better when it is lower.

[-1.0 if d == StudyDirection.MAXIMIZE else 1.0 for d in directions]
)
penalty = _evaluate_penalty(population) if is_constrained else None

domination_ranks = _fast_non_dominated_sort(objective_values, penalty=penalty)
population_per_rank: list[list[FrozenTrial]] = [[] for _ in range(max(domination_ranks) + 1)]
for trial, rank in zip(population, domination_ranks):
if rank == -1:
continue

Codecov warning: added line optuna/samplers/nsgaii/_elite_population_selection_strategy.py#L132 was not covered by tests.
population_per_rank[rank].append(trial)

return population_per_rank
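To make the grouping step of _rank_population concrete, a small sketch with placeholder trials and made-up ranks:

import numpy as np

# Hypothetical ranks as returned by _fast_non_dominated_sort for five trials;
# -1 marks trials the sort skipped (e.g. when n_below was reached).
domination_ranks = np.array([0, 1, 0, -1, 1])
trials = ["t0", "t1", "t2", "t3", "t4"]  # stand-ins for FrozenTrial objects

population_per_rank = [[] for _ in range(max(domination_ranks) + 1)]
for trial, rank in zip(trials, domination_ranks):
    if rank == -1:
        continue  # unranked trials are dropped from the elite selection
    population_per_rank[rank].append(trial)

print(population_per_rank)  # [['t0', 't2'], ['t1', 't4']]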
130 changes: 130 additions & 0 deletions optuna/study/_multi_objective.py
@@ -1,7 +1,12 @@
from __future__ import annotations

from collections import defaultdict
from typing import List
from typing import Optional
from typing import Sequence

import numpy as np

import optuna
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
@@ -69,6 +74,131 @@ def _get_pareto_front_trials(study: "optuna.study.Study") -> List[FrozenTrial]:
return _get_pareto_front_trials_by_trials(study.trials, study.directions)


def _fast_non_dominated_sort(
objective_values: np.ndarray,
*,
penalty: np.ndarray | None = None,
n_below: int | None = None,
) -> np.ndarray:
Collaborator (comment on lines +77 to +82):
from __future__ import annotations

from collections import defaultdict
import time

import numpy as np


def run_alnusjaponica(objective_values: np.ndarray) -> np.ndarray:
    domination_mat = np.all(
        objective_values[:, np.newaxis, :] >= objective_values[np.newaxis, :, :], axis=2
    ) & np.any(
        objective_values[:, np.newaxis, :] > objective_values[np.newaxis, :, :], axis=2
    )

    domination_list = np.nonzero(domination_mat)
    domination_map = defaultdict(list)
    for dominated_idx, dominating_idx in zip(*domination_list):
        domination_map[dominating_idx].append(dominated_idx)

    ranks = np.full(len(objective_values), -1)
    dominated_count = np.sum(domination_mat, axis=1)

    rank = -1
    ranked_idx_num = 0
    while ranked_idx_num < len(objective_values):
        (non_dominated_idxs,) = np.nonzero(dominated_count == 0)
        ranked_idx_num += len(non_dominated_idxs)
        rank += 1
        ranks[non_dominated_idxs] = rank

        dominated_count[non_dominated_idxs] = -1
        for non_dominated_idx in non_dominated_idxs:
            dominated_count[domination_map[non_dominated_idx]] -= 1
    return ranks


def is_pareto_front_nd(ordered_loss_values: np.ndarray, assume_unique: bool) -> np.ndarray:
    loss_values = ordered_loss_values.copy()
    n_trials = loss_values.shape[0]
    is_front = np.zeros(n_trials, dtype=bool)
    nondominated_indices = np.arange(n_trials)
    while len(loss_values):
        nondominated_and_not_top = np.any(loss_values < loss_values[0], axis=1)
        # NOTE: trials[j] cannot dominate trials[j] for i < j because of lexsort.
        # Therefore, nondominated_indices[0] is always non-dominated.
        if assume_unique:
            is_front[nondominated_indices[0]] = True
        else:
            top_indices = nondominated_indices[np.all(loss_values[~nondominated_and_not_top] == loss_values[0], axis=1)]
            is_front[top_indices] = True

        loss_values = loss_values[nondominated_and_not_top]
        nondominated_indices = nondominated_indices[nondominated_and_not_top]

    return is_front


def is_pareto_front_2d(ordered_loss_values: np.ndarray, assume_unique: bool) -> np.ndarray:
    n_trials = ordered_loss_values.shape[0]
    cummin_value1 = np.minimum.accumulate(ordered_loss_values[:, 1])
    is_value1_min = cummin_value1 == ordered_loss_values[:, 1]
    is_value1_new_min = cummin_value1[1:] < cummin_value1[:-1]

    on_front = np.ones(n_trials, dtype=bool)
    if assume_unique:
        on_front[1:] = is_value1_min[1:] & is_value1_new_min
    if not assume_unique:
        is_value0_same = ordered_loss_values[1:, 0] == ordered_loss_values[:-1, 0]
        on_front[1:] = is_value1_min[1:] & (is_value0_same | is_value1_new_min)

    return on_front


def is_pareto_front(ordered_loss_values: np.ndarray, assume_unique: bool) -> np.ndarray:
    (n_trials, n_objectives) = ordered_loss_values.shape
    if n_objectives == 1:
        return ordered_loss_values[:, 0] == ordered_loss_values[0]
    elif n_objectives == 2:
        return is_pareto_front_2d(ordered_loss_values, assume_unique)
    else:
        return is_pareto_front_nd(ordered_loss_values, assume_unique)


def calculate_nondomination_rank(loss_values: np.ndarray) -> np.ndarray:
    (n_trials, n_objectives) = loss_values.shape

    if n_objectives == 1:
        _, ranks = np.unique(loss_values[:, 0], return_inverse=True)
        return ranks
    else:
        # It ensures that trials[j] will not dominate trials[i] for i < j.
        # np.unique does lexsort.
        ordered_loss_values, order_inv = np.unique(loss_values, return_inverse=True, axis=0)

    n_unique = ordered_loss_values.shape[0]
    ranks = np.zeros(n_unique, dtype=int)
    rank = 0
    indices = np.arange(n_unique)
    while indices.size > 0:
        on_front = is_pareto_front(ordered_loss_values, assume_unique=True)
        ranks[indices[on_front]] = rank
        # Remove the recent Pareto solutions.
        indices = indices[~on_front]
        ordered_loss_values = ordered_loss_values[~on_front]
        rank += 1

    return ranks[order_inv]


def run_nabenabe(loss_values: np.ndarray) -> np.ndarray:
    return calculate_nondomination_rank(loss_values)


def measure_time(target, loss_values: np.ndarray) -> tuple[np.ndarray, float]:
    start = time.time()
    results = target(loss_values.copy())
    elapsed_time = (time.time() - start) * 1000
    return results, elapsed_time


if __name__ == "__main__":
    n_trials = 1000
    n_objectives = 1
    n_seeds = 5
    results = {"nabenabe": [], "alnusjaponica": []}
    for seed in range(n_seeds):
        rng = np.random.RandomState(seed)
        loss_values = rng.normal(size=(n_trials, n_objectives))
        ans_nabenabe, t = measure_time(run_nabenabe, loss_values)
        results["nabenabe"].append(t)
        ans_alnusjaponica, t = measure_time(run_alnusjaponica, loss_values)
        results["alnusjaponica"].append(t)
        print(np.all(ans_alnusjaponica == ans_nabenabe))

    print({k: f"{np.mean(v):.2f} +/- {np.std(v) / np.sqrt(n_seeds):.2f} [ms]" for k, v in results.items()})

For simplicity, I removed penalty and n_below, but this implementation produces the same results as your program yet is much quicker.

Collaborator:
The unit is milliseconds.

| | Mine | Yours |
| --- | --- | --- |
| n_trials=1000, n_objectives=2 | 1.57 $\pm$ 0.1 | 84.5 $\pm$ 3.1 |
| n_trials=10000, n_objectives=2 | 18.13 $\pm$ 1.4 | 9858.0 $\pm$ 126.8 |
| n_trials=1000, n_objectives=3 | 11.95 $\pm$ 0.9 | 61.5 $\pm$ 1.0 |
| n_trials=10000, n_objectives=3 | 310.5 $\pm$ 5.0 | 7025.0 $\pm$ 181.4 |

Collaborator:
I think your implementation of dominance with penalty includes an unnecessary consideration of the penalty for feasible cases.
Namely, we should not consider the penalty amount once a trial satisfies the constraints, but yours considers the penalty amount even for feasible cases.

Please check the following definition by Deb et al. [1]:

[screenshot: definition of constrained domination from Deb et al.]

Plus, the implementation below is much quicker.

def calculate_nondomination_rank_with_penalty(
    loss_values: np.ndarray, penalty: np.ndarray | None = None
) -> np.ndarray:
    if penalty is None:
        return calculate_nondomination_rank(loss_values)

    # If values[i] constrained-dominates values[j] given penalty[i] and penalty[j],
    # one of the following must be satisfied:
    # 1. penalty[i] <= 0 and penalty[j] > 0,
    # 2. penalty[i] > 0 and penalty[j] > 0 and penalty[i] < penalty[j], or
    # 3. penalty[i] <= 0 and penalty[j] <= 0 and values[i] dominates values[j].
    # Therefore, if trials[i] is feasible and trials[j] is infeasible,
    # nondomination_rank[i] <= nondomination_rank[j] always holds and we can separate the sortings
    # for feasible trials and infeasible trials.
    # Ref: Definition 1 by K. Deb et al. in
    # `A Fast and Elitist Multiobjective Genetic Algorithm: NSGA-II`
    if len(penalty) != len(loss_values):
        raise ValueError(
            f"The length of penalty and loss_values must be same, but got "
            f"len(penalty)={len(penalty)} and len(loss_values)={len(loss_values)}."
        )

    penalty[np.isnan(penalty)] = np.inf
    is_feasible = penalty <= 0
    nondomination_rank = np.zeros(len(loss_values), dtype=int)
    nondomination_rank[is_feasible] += calculate_nondomination_rank(loss_values[is_feasible])
    nondomination_rank[~is_feasible] += calculate_nondomination_rank(
        penalty[~is_feasible, np.newaxis],
    ) + np.max(nondomination_rank[is_feasible], initial=-1) + 1
    return nondomination_rank

[1] K. Deb et al. A Fast and Elitist Multiobjective Genetic Algorithm: NSGA-II

Collaborator (author):
> this implementation produces the same results as your program yet is much quicker.

Thanks for the suggestion. I just moved what was implemented in _tpe/sampler.py, so there is no need to stick to the first implementation. I'll consider how to reconcile the current _get_pareto_front_trials_by_trials function with your suggestion.

> I think your implementation of dominance with penalty includes an unnecessary consideration of the penalty for feasible cases.

In my understanding, the penalty is set to 0 when the trial is feasible, and thus has no influence on the result. Anyway, I'll use the faster implementation from the suggestion.

Collaborator:
from __future__ import annotations

import itertools

import numpy as np


def is_pareto_front(trials: list[Trial]) -> np.ndarray:
    n_trials = len(trials)
    is_front = np.zeros(n_trials, dtype=bool)
    next_index = 0
    nondominated_indices = np.arange(n_trials)
    while next_index < len(trials):
        nondominated_mask = np.array([not dominates(t, trials[next_index]) for t in trials])
        trials = list(itertools.compress(trials, nondominated_mask))
        nondominated_indices = nondominated_indices[nondominated_mask]
        next_index = np.sum(nondominated_mask[:next_index]) + 1

    is_front[nondominated_indices] = True
    return is_front


def calculate_nondomination_rank(trials: list[Trial]) -> np.ndarray:
    n_trials = len(trials)
    ranks = np.zeros(n_trials, dtype=int)
    rank = 0
    indices = np.arange(n_trials)
    while indices.size > 0:
        on_front = is_pareto_front(trials)
        ranks[indices[on_front]] = rank
        indices = indices[~on_front]
        trials = list(itertools.compress(trials, ~on_front))
        rank += 1

    return ranks

This is just a memo for future discussion; please ignore it for now.

When using dominates, the runtime will be much longer than the vectorized version, but it still runs quicker than creating the dominance matrix.
This implementation is much slower because:

  1. we do not use vectorization,
  2. we cannot pre-sort trials so that trials[j] cannot dominate trials[i] for $i < j$, and
  3. we cannot assume uniqueness in the trials.

"""Perform the fast non-dominated sort algorithm.

The fast non-dominated sort algorithm assigns a rank to each trial based on the dominance
relationship of the trials, determined by the objective values and the penalty values. The
algorithm is based on `the constrained NSGA-II algorithm
<https://doi.org/10.1109/4235.99601>`_, but the handling of the case when penalty
values are None is different. The algorithm assigns the rank according to the following
rules:

1. Feasible trials: First, the algorithm assigns ranks to feasible trials, whose penalty
values are less than or equal to 0, according to the unconstrained version of fast
non-dominated sort.
2. Infeasible trials: Next, the algorithm assigns ranks in order of penalty value, from the
minimum penalty value to the maximum penalty value.
3. Trials with no penalty information (constraint values are None): Finally, the algorithm
assigns ranks to trials with no penalty information according to the unconstrained version
of fast non-dominated sort. Note that only this step is different from the original
constrained NSGA-II algorithm.
Plus, the algorithm terminates whenever the number of sorted trials reaches n_below.

Args:
objective_values:
Objective values of each trial.
penalty:
Constraint values of each trial. Defaults to None.
n_below: The minimum number of top trials required to be sorted. The algorithm will
terminate when the number of sorted trials reaches n_below. Defaults to None.

Returns:
An ndarray in the shape of (n_trials,), where each element is the non-dominated rank of
each trial. The rank is 0-indexed and rank -1 means that the algorithm terminated before
the trial was sorted.
"""
if penalty is None:
ranks, _ = _calculate_nondomination_rank(objective_values, n_below=n_below)
return ranks

if len(penalty) != len(objective_values):
raise ValueError(
"The length of penalty and objective_values must be same, but got "
"len(penalty)={} and len(objective_values)={}.".format(
len(penalty), len(objective_values)
)
)
nondomination_rank = np.full(len(objective_values), -1)
Collaborator:
Since we can already bound max(nondomination_rank) by n_below and a nondomination_rank of n_below + 1 will not be used, what about using n_below + 1?
Another reason why we should probably avoid -1 is that it might cause unexpected bugs in the future if some developers assume that a lower nondomination_rank is always better.
Plus, this implementation requires ad-hoc handling of nondomination_rank=-1 in each place where the function is used.

is_penalty_nan = np.isnan(penalty)
n_below = n_below or len(objective_values)

# First, we calculate the domination rank for feasible trials.
is_feasible = np.logical_and(~is_penalty_nan, penalty <= 0)
ranks, bottom_rank = _calculate_nondomination_rank(
Collaborator:
If we define nondomination_rank as:

nondomination_rank = np.full(len(objective_values), n_below + 1)

bottom_rank becomes bottom_rank = np.max(ranks).
Note that if np.max(bottom_rank) = n_below + 1, the processes hereafter simply define each nondomination_rank as n_below + <positive_integer>, so they will be ignored.

Collaborator (author):
I totally agree with what you say, but it makes this PR even larger. Can I split the task off as a follow-up and resolve your comment in another PR?

Collaborator (author):
The suggestion is a little bit complicated, so I left the comment on #5089

objective_values[is_feasible], n_below=n_below
)
nondomination_rank[is_feasible] += 1 + ranks
n_below -= np.count_nonzero(is_feasible)

# Second, we calculate the domination rank for infeasible trials.
is_infeasible = np.logical_and(~is_penalty_nan, penalty > 0)
num_infeasible_trials = np.count_nonzero(is_infeasible)
if num_infeasible_trials > 0:
_, ranks = np.unique(penalty[is_infeasible], return_inverse=True)
ranks += 1
nondomination_rank[is_infeasible] += 1 + bottom_rank + ranks
bottom_rank += np.max(ranks)
n_below -= num_infeasible_trials

# Third, we calculate the domination rank for trials with no penalty information.
ranks, _ = _calculate_nondomination_rank(
objective_values[is_penalty_nan], n_below=n_below, base_rank=bottom_rank + 1
)
nondomination_rank[is_penalty_nan] += 1 + ranks

return nondomination_rank
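A hand-worked call illustrating the three docstring rules above (single minimization objective, made-up values, using the function as added in this file):

import numpy as np

from optuna.study._multi_objective import _fast_non_dominated_sort

# One feasible trial, two infeasible trials, and one trial without constraint values.
objective_values = np.array([[0.2], [0.1], [0.3], [0.5]])
penalty = np.array([0.0, 4.0, np.nan, 2.0])

ranks = _fast_non_dominated_sort(objective_values, penalty=penalty)
print(ranks)
# Expected: [0 2 3 1]
# Rule 1: the feasible trial gets rank 0.
# Rule 2: the infeasible trials are ordered by penalty (2.0 < 4.0), getting ranks 1 and 2,
#         even though trial 1 has the best objective value.
# Rule 3: the trial without constraint values comes last with rank 3.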


def _calculate_nondomination_rank(
objective_values: np.ndarray,
*,
n_below: int | None = None,
base_rank: int = 0,
) -> tuple[np.ndarray, int]:
if n_below is not None and n_below <= 0:
return np.full(len(objective_values), -1), base_rank
# Normalize n_below.
n_below = n_below or len(objective_values)
n_below = min(n_below, len(objective_values))

# The ndarray `domination_mat` is a boolean 2d matrix where
# `domination_mat[i, j] == True` means that the j-th trial dominates the i-th trial in the
# given multi objective minimization problem.
domination_mat = np.all(
objective_values[:, np.newaxis, :] >= objective_values[np.newaxis, :, :], axis=2
) & np.any(objective_values[:, np.newaxis, :] > objective_values[np.newaxis, :, :], axis=2)

domination_list = np.nonzero(domination_mat)
domination_map = defaultdict(list)
for dominated_idx, dominating_idx in zip(*domination_list):
domination_map[dominating_idx].append(dominated_idx)

ranks = np.full(len(objective_values), -1)
dominated_count = np.sum(domination_mat, axis=1)

rank = base_rank - 1
ranked_idx_num = 0
while ranked_idx_num < n_below:
# Find the non-dominated trials and assign the rank.
(non_dominated_idxs,) = np.nonzero(dominated_count == 0)
ranked_idx_num += len(non_dominated_idxs)
rank += 1
ranks[non_dominated_idxs] = rank

# Update the dominated count.
dominated_count[non_dominated_idxs] = -1
for non_dominated_idx in non_dominated_idxs:
dominated_count[domination_map[non_dominated_idx]] -= 1

return ranks, rank
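A tiny worked example of the domination matrix that drives the loop in _calculate_nondomination_rank (made-up values):

import numpy as np

# Three points with two objectives to minimize: (1, 1) dominates (2, 2),
# while (0, 3) is incomparable with both of them.
objective_values = np.array([[1.0, 1.0], [2.0, 2.0], [0.0, 3.0]])

# Same construction as above: domination_mat[i, j] is True iff trial j dominates trial i.
domination_mat = np.all(
    objective_values[:, np.newaxis, :] >= objective_values[np.newaxis, :, :], axis=2
) & np.any(objective_values[:, np.newaxis, :] > objective_values[np.newaxis, :, :], axis=2)

print(domination_mat.astype(int))
# [[0 0 0]
#  [1 0 0]
#  [0 0 0]]
# Repeatedly peeling off rows whose dominated_count is 0 yields ranks [0, 1, 0].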


def _dominates(
trial0: FrozenTrial, trial1: FrozenTrial, directions: Sequence[StudyDirection]
) -> bool: