In [1]:
from desdeo.problem.testproblems.single_objective import new_branin_function, mystery_function, mishras_bird_constrained
from desdeo.problem import Problem
from desdeo.emo.hooks.archivers import Archive
from desdeo.emo import algorithms, selection, termination, generator, crossover, mutation, scalar_selection
import polars as pl
import numpy as np
import plotly.graph_objects as go


def run_nsga2_with_mode(
    problem: Problem,
    mode: str,
    constraint_threshold: float,
    pop_size: int,
    n_generations: int,
    constraint_symbol="c_1",
):
    """Run the NSGA-II style EA once for a given constraint-handling mode."""
    # ---- Problem ----
    nsga2_options = algorithms.nsga2_options()

    nsga2_options.template.crossover = crossover.SimulatedBinaryCrossoverOptions(
        xover_probability=0.9, xover_distribution=20
    )
    nsga2_options.template.mutation = mutation.BoundedPolynomialMutationOptions(
        mutation_probability=1.0 / len(problem.variables), distribution_index=20
    )
    nsga2_options.template.mate_selection = scalar_selection.TournamentSelectionOptions(
        name="TournamentSelection", tournament_size=2, winner_size=pop_size
    )
    nsga2_options.template.selection = selection.SingleObjectiveConstrainedRankingSelectorOptions(
        target_objective_symbol="f_1",
        target_constraint_symbol=constraint_symbol,
        constraint_threshold=constraint_threshold,
        population_size=pop_size,
        mode=mode,
    )
    nsga2_options.template.generator = generator.LHSGeneratorOptions(n_points=pop_size)
    nsga2_options.template.termination = termination.MaxGenerationsTerminatorOptions(max_generations=n_generations)

    solver, extras = algorithms.emo_constructor(emo_options=nsga2_options, problem=problem)

    archive = Archive(problem=problem, publisher=extras.publisher)

    extras.publisher.auto_subscribe(archive)
    extras.publisher.register_topics(archive.provided_topics[archive.verbosity], archive.__class__.__name__)

    # ---- Run optimization ----
    _ = solver()  # result object not strictly needed; archive holds all solutions

    # Full history of solutions as Polars DataFrame, last population
    return archive


def run(mode: str):
    # modes = ["relaxed", "baseline", "alternate"]
    # run baseline
    problem = new_branin_function()
    baseline_res = run_nsga2_with_mode(problem, mode, 1, pop_size=6, n_generations=100)

    baseline_solutions = baseline_res.solutions

    feasible_baseline = baseline_solutions.with_columns(
        pl.when(pl.col("c_1") <= 0).then(pl.col("f_1_min")).otherwise(float("inf")).alias("feasible_f_1")
    )

    best_baseline = (
        feasible_baseline.group_by("generation")
        .agg(pl.col("feasible_f_1").min().alias("best_f_1_this_gen"))
        .sort("generation")
    )
    # best_so_far = best_by_gen["best_f_1_this_gen"].cum_min()
    # best_by_gen = best_by_gen.with_columns(
    # pl.Series("best_f_1_so_far", best_so_far)
    # )
    best_so_far_baseline = best_baseline.with_columns(best_baseline["best_f_1_this_gen"].cum_min().alias("best"))

    return best_so_far_baseline


In [2]:
# run many times
times = 10
results_relaxed = []
results_baseline = []

for _ in range(times):
    res_relaxed, res_baseline = run("baseline"), run("baseline2")
    results_relaxed.append(res_relaxed)
    results_baseline.append(res_baseline)

combined_relaxed = pl.concat(results_relaxed)
combined_baseline = pl.concat(results_baseline)

stats_relaxed = (
    combined_relaxed.filter(pl.col("best").is_finite())
    .group_by("generation")
    .agg(
        [
            pl.col("best").mean().alias("best_mean"),
            pl.col("best").std().alias("best_std"),
        ]
    )
    .with_columns(
        [
            (pl.col("best_mean") + pl.col("best_std")).alias("best_upper"),
            (pl.col("best_mean") - pl.col("best_std")).alias("best_lower"),
        ]
    )
    .sort("generation")
).filter(
    pl.all_horizontal(
        pl.col("best_mean").is_finite(),
        pl.col("best_upper").is_finite(),
        pl.col("best_lower").is_finite(),
    )
)

stats_baseline = (
    combined_baseline.filter(pl.col("best").is_finite())
    .group_by("generation")
    .agg(
        [
            pl.col("best").mean().alias("best_mean"),
            pl.col("best").std().alias("best_std"),
        ]
    )
    .with_columns(
        [
            (pl.col("best_mean") + pl.col("best_std")).alias("best_upper"),
            (pl.col("best_mean") - pl.col("best_std")).alias("best_lower"),
        ]
    )
    .sort("generation")
).filter(
    pl.all_horizontal(
        pl.col("best_mean").is_finite(),
        pl.col("best_upper").is_finite(),
        pl.col("best_lower").is_finite(),
    )
)


fig = go.Figure()

fig.add_trace(go.Scatter(x=stats_relaxed["generation"], y=stats_relaxed["best_mean"], mode="lines", name="Relaxed"))
fig.add_trace(go.Scatter(x=stats_baseline["generation"], y=stats_baseline["best_mean"], mode="lines", name="Baseline"))
fig.add_traces(
    go.Scatter(
        x=stats_relaxed["generation"].to_list() + stats_relaxed["generation"].to_list()[::-1],
        y=stats_relaxed["best_upper"].to_list() + stats_relaxed["best_lower"].to_list()[::-1],
        fill="toself",
        name="std one dev",
        line={"width": 0},
        fillcolor="rgba(0,0,255,0.15)",
        hoverinfo="skip",
        showlegend=False,
    )
)
fig.add_traces(
    go.Scatter(
        x=stats_baseline["generation"].to_list() + stats_baseline["generation"].to_list()[::-1],
        y=stats_baseline["best_upper"].to_list() + stats_baseline["best_lower"].to_list()[::-1],
        fill="toself",
        name="std one dev",
        line={"width": 0},
        fillcolor="rgba(255,0,0,0.15)",
        hoverinfo="skip",
        showlegend=False,
    )
)

fig.update_layout(
    title=f"Best mean f_1 per generation (Relaxed vs Baseline)",
    xaxis_title="Generation",
    yaxis_title="Cumulative best mean f_1",
)

fig.show()

print(combined_relaxed)
"""
fig_scatter.add_trace(
    go.Scatter(
        x=best_so_far_baseline["generation"],
        y=best_so_far_baseline["best"],
        mode="markers",
        name=f"{mode}",
        opacity=0.6,
    )
)

fig_scatter.update_layout(
    title=f"f_1 vs c_1 {mode}",
    xaxis_title="Generation",
    yaxis_title="Best f_1 thus far",
)

fig_scatter.show()
print(best_so_far_baseline)
"""

Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing duplicates!
Removing dupl

KeyboardInterrupt: 

In [112]:
print(combined_baseline)

shape: (200_000, 3)
┌────────────┬───────────────────┬─────────────┐
│ generation ┆ best_f_1_this_gen ┆ best        │
│ ---        ┆ ---               ┆ ---         │
│ i32        ┆ f64               ┆ f64         │
╞════════════╪═══════════════════╪═════════════╡
│ 1          ┆ inf               ┆ 1.7977e308  │
│ 2          ┆ inf               ┆ 1.7977e308  │
│ 3          ┆ inf               ┆ 1.7977e308  │
│ 4          ┆ inf               ┆ 1.7977e308  │
│ 5          ┆ inf               ┆ 1.7977e308  │
│ …          ┆ …                 ┆ …           │
│ 196        ┆ -268.637847       ┆ -268.637925 │
│ 197        ┆ -268.637847       ┆ -268.637925 │
│ 198        ┆ -268.637846       ┆ -268.637925 │
│ 199        ┆ -268.637847       ┆ -268.637925 │
│ 200        ┆ -268.637847       ┆ -268.637925 │
└────────────┴───────────────────┴─────────────┘
