# Nelder-Mead optimization

In [191]:
import plotly.express as px
import numpy as np
import pandas as pd

### Define functions

In [192]:
def objective_function(x):
    """Evaluate the shifted parabola f(x) = (x - 3)² + 20.

    The expression uses only arithmetic operators, so ``x`` may be a plain
    number or a NumPy array (in which case it is applied element-wise).
    The function attains its minimum value of 20 at x = 3.
    """
    offset = x - 3
    return offset**2 + 20

In [193]:
def plot_quadratic_function(min_x, min_name='global minimum', title="Plot of f(x) = (x - 3)²+20"):
    """Plot ``objective_function`` on [-10, 10] and highlight one point on the curve.

    Args:
        min_x: x-coordinate of the point to highlight (typically the
            minimizer). The marker's height is ``objective_function(min_x)``.
        min_name: Legend label for the highlighted point.
        title: Figure title.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    # Sample the curve on a fixed grid of 100 points.
    x_values = np.linspace(-10, 10, 100)
    y_values = objective_function(x_values)

    # Plotly Express wants tabular input.
    data = pd.DataFrame({
        "x": x_values,
        "f(x)": y_values,
    })

    # Honor the caller-supplied x instead of overwriting it with a
    # hard-coded 3, so the marker can be placed at any point of interest.
    min_y = objective_function(min_x)

    # Draw the curve itself.
    fig = px.line(
        data,
        x="x",
        y="f(x)",
        title=title,
        labels={"x": "x", "f(x)": "f(x)"},
    )

    # Overlay the highlighted point as a single red marker.
    fig.add_scatter(
        x=[min_x],
        y=[min_y],
        mode="markers",
        marker=dict(color="red", size=10, symbol="circle"),
        name=min_name,
    )

    # Fix the figure size in pixels.
    fig.update_layout(width=800, height=600)

    fig.show()

In [194]:
def evaluate_points(f, points):
    """Evaluate ``f`` at every simplex vertex and sort the vertices by value.

    Args:
        f: Callable applied to each vertex (each element along the first
            axis of ``points``); it must return a scalar.
        points: Simplex vertices as a NumPy array or any array-like
            (e.g. a plain list of numbers or of coordinate lists).

    Returns:
        A tuple ``(sorted_points, sorted_values)`` of NumPy arrays ordered
        from the lowest function value to the highest. Both are new arrays;
        the input is not modified.
    """
    # Coerce to an array first: fancy indexing with the argsort result below
    # raises TypeError on a plain Python list.
    points = np.asarray(points)

    # Function value at each vertex.
    values = np.array([f(x) for x in points])

    # Permutation that orders the values ascending (best vertex first).
    sorted_indices = np.argsort(values)

    return points[sorted_indices], values[sorted_indices]


In [196]:
def nelder_meads_algorithm(points, max_iter=1000, alpha=1, gamma=2, beta=0.5, delta=0.5, tol=1e-4):
    """Minimize ``objective_function`` with the Nelder-Mead simplex method.

    Each iteration sorts the simplex by function value, reflects the worst
    vertex through the centroid of the others, and then expands, accepts,
    contracts, or shrinks depending on how the reflected value compares with
    the current vertices. Progress is printed every iteration.

    Args:
        points: Initial simplex vertices (array-like; one vertex per element
            along the first axis).
        max_iter: Maximum number of iterations.
        alpha: Reflection coefficient.
        gamma: Expansion coefficient.
        beta: Contraction coefficient (used for both outside and inside
            contraction).
        delta: Shrink coefficient.
        tol: Stop once the standard deviation of the vertex function values
            falls below this threshold.

    Returns:
        A tuple ``(x_best, f_best)``: the best vertex of the final simplex
        and its function value.
    """
    # Work on a float array: writing fractional vertices into an integer
    # array would silently truncate them.
    points = np.asarray(points, dtype=float)

    for iteration in range(max_iter):
        # Step 0: Evaluate points and sort by function values
        points, f_values = evaluate_points(f=objective_function, points=points)

        print(f"iteration: {iteration}")
        print(f"f(x) values: {f_values}")
        print(f"points: {points}")

        # Step 1: Define best, second-worst, and worst points
        x_best, x_worst = points[0], points[-1]
        f_best, f_second_worst, f_worst = f_values[0], f_values[-2], f_values[-1]

        # Step 2: Calculate the centroid (excluding the worst point)
        x_centroid = np.mean(points[:-1], axis=0)

        # Step 3: Reflect the worst point through the centroid
        x_ref = x_centroid + alpha * (x_centroid - x_worst)
        f_ref = objective_function(x_ref)

        if f_ref < f_best:
            # Expansion: the reflection is a new best, so try going further
            # in the same direction and keep whichever point is better.
            x_exp = x_centroid + gamma * (x_ref - x_centroid)
            f_exp = objective_function(x_exp)
            points[-1] = x_exp if f_exp < f_ref else x_ref
        elif f_ref < f_second_worst:
            # Accept reflection: not a new best, but better than the
            # second-worst vertex (f_best <= f_ref holds on this branch).
            points[-1] = x_ref
        else:
            # Contraction
            if f_ref < f_worst:  # Outside contraction (towards the reflection)
                x_cont = x_centroid + beta * (x_ref - x_centroid)
            else:  # Inside contraction (towards the worst point)
                x_cont = x_centroid + beta * (x_worst - x_centroid)

            f_cont = objective_function(x_cont)
            if f_cont < min(f_ref, f_worst):
                points[-1] = x_cont
            else:
                # Shrink every vertex towards the best one. The best vertex
                # maps onto itself, so no filtering is needed, and the result
                # stays a NumPy array of the same shape, which the next
                # evaluate_points call relies on.
                points = x_best + delta * (points - x_best)
                print('shrink')

        # Convergence check (uses the values computed at the start of this
        # iteration, i.e. before the update above).
        if np.std(f_values) < tol:
            print(f"Converged after {iteration} iterations.")
            break

    # Re-sort once more so the returned vertex reflects the final simplex,
    # including the update made in the last iteration (or the untouched
    # input when max_iter is 0).
    points, f_values = evaluate_points(f=objective_function, points=points)
    return points[0], f_values[0]



### Plot the function

In [197]:
# Show the objective curve with its minimizer (x = 3) highlighted.
plot_quadratic_function(min_x=3, title='Plot of f(x) = (x - 3)² + 20')

### Run the algorithm

In [198]:
# Starting simplex: three scalar vertices, all far to the right of x = 3
initial_points = np.array([20.5, 19.1, 18.3])

# Optimize from that simplex; progress is printed on every iteration
nelder_meads_algorithm(initial_points)

iteration: 0
f(x) values: [254.09 279.21 326.25]
points: [18.3 19.1 20.5]
iteration: 1
f(x) values: [166.41 254.09 279.21]
points: [15.1 18.3 19.1]
iteration: 2
f(x) values: [ 99.21 166.41 254.09]
points: [11.9 15.1 18.3]
iteration: 3
f(x) values: [ 20.81  99.21 166.41]
points: [ 3.9 11.9 15.1]
iteration: 4
f(x) values: [20.81 25.29 99.21]
points: [ 3.9  0.7 11.9]
iteration: 5
f(x) values: [20.81 25.29 36.81]
points: [3.9 0.7 7.1]
iteration: 6
f(x) values: [20.81 22.89 25.29]
points: [3.9 4.7 0.7]
iteration: 7
f(x) values: [20.25 20.81 22.89]
points: [2.5 3.9 4.7]
iteration: 8
f(x) values: [20.25   20.3025 20.81  ]
points: [2.5  2.45 3.9 ]
iteration: 9
f(x) values: [20.03515625 20.25       20.3025    ]
points: [3.1875 2.5    2.45  ]
iteration: 10
f(x) values: [20.03515625 20.05640625 20.25      ]
points: [3.1875 3.2375 2.5   ]
iteration: 11
f(x) values: [20.02066406 20.03515625 20.05640625]
points: [2.85625 3.1875  3.2375 ]
iteration: 12
f(x) values: [20.00738525 20.02066406 20.0351562