## Imports and Functions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
from scipy.spatial import ConvexHull

# sympy symbols: x is the independent variable, a parameterises the tangent point used by calculate_simplices
x, a = sp.symbols('x a')

def _supports_from_below(f_coeffs, t0, tol):
    '''
    Check whether the tangent line of a polynomial at x = t0 never rises above the polynomial.

    Parameters:
        f_coeffs (array): float coefficients of the polynomial, highest degree first (numpy.polyval order);
            the degree must be at least 1
        t0 (float): x value of the tangent point
        tol (float): relative tolerance that absorbs floating-point rounding
    Returns:
        bool: True if f(x) - f(t0) - f'(t0)(x - t0) >= 0 (up to tol) for every real x
    '''
    gap = np.array(f_coeffs, dtype=float)
    # a gap of odd degree or with a negative leading coefficient is unbounded below
    if (len(gap) - 1) % 2 != 0 or gap[0] <= 0:
        return False

    slope = np.polyval(np.polyder(f_coeffs), t0)
    offset = np.polyval(f_coeffs, t0) - slope * t0
    # gap(x) = f(x) - slope*x - offset
    gap[-2] -= slope
    gap[-1] -= offset

    # the global minimum of an even, upward-opening polynomial sits at a real root of its derivative;
    # sampling the real parts of all derivative roots covers those (extra samples are harmless)
    samples = np.roots(np.polyder(gap)).real
    values = np.polyval(gap, samples)
    # sum |c_k| |x|^k bounds the magnitude of the evaluated terms, hence the rounding error, per sample
    scale = 1.0 + np.polyval(np.abs(gap), np.abs(samples))
    return bool(np.all(values >= -tol * scale))


def calculate_simplices(f, tol=1e-9):
    '''
    Find the tangent points of the lines that bridge the non-convex parts of a polynomial.

    For a tangent point t, the tangent line of f leaves the gap f(x) - f(t) - f'(t)(x - t),
    which always has a double root at x = t. Dividing that root out gives a polynomial in x
    whose discriminant (a polynomial in t) vanishes when the reduced gap has a repeated root,
    e.g. when the tangent line at t touches f at a second point (a bitangent). The real roots
    of that discriminant are then filtered down to tangent points whose line stays on or
    below f everywhere, i.e. lines that support the lower convex hull of f.

    Parameters:
        f (sympy expression): polynomial in the module-level symbol `x`; float coefficients
            are converted to exact rationals before the symbolic work
        tol (float): relative tolerance for the numerical "line stays below f" check
    Returns:
        points (array): (k, 2) float array of [x, f(x)] tangent points sorted by x; shape
            (0, 2) when there are none (e.g. degree < 4, or f unbounded below)
    '''
    # tangent-point parameter; a Dummy cannot collide with any symbol inside f
    t = sp.Dummy('t')
    f = sp.nsimplify(f, rational=True)
    f_prime = sp.diff(f, x)
    f_dprime = sp.diff(f_prime, x)

    # gap between f and its tangent line at t, with the trivial double root x = t divided out.
    # Dividing with x as the explicit generator makes the division exact (zero remainder).
    gap = f - f.subs(x, t) - f_prime.subs(x, t) * (x - t)
    grad_proj, _ = sp.div(gap, (x - t)**2, x)
    grad_proj = sp.Poly(grad_proj, x)
    if grad_proj.degree() < 2:
        # degree(f) < 4: no bitangent is possible
        return np.empty((0, 2))

    # discriminant of the reduced gap with respect to x, as a polynomial in t
    discriminant = sp.Poly(sp.discriminant(grad_proj.as_expr(), x), t)
    if discriminant.degree() < 1:
        return np.empty((0, 2))

    # keep each root once and drop, exactly, the roots shared with f'': flat tangencies such as
    # x**4 at 0 would otherwise pass the support check without being bitangent vertices
    candidates = discriminant.sqf_part()
    candidates = candidates.quo(candidates.gcd(sp.Poly(f_dprime.subs(x, t), t)))
    roots = [float(root) for root in sp.real_roots(candidates)] if candidates.degree() > 0 else []

    # keep only tangent points whose line supports f from below (discards crossing bitangents)
    f_coeffs = np.array(sp.Poly(f, x).all_coeffs(), dtype=float)
    vertices_xvals = np.sort(np.array(
        [root for root in roots if _supports_from_below(f_coeffs, root, tol)],
        dtype=float,
    ))

    # compute the corresponding y values of the vertices
    vertices_yvals = np.polyval(f_coeffs, vertices_xvals)
    return np.column_stack((vertices_xvals, vertices_yvals))


def bounding_points(points):
    '''
    Find the indices of the points that bound a point set in its own space.

    Parameters:
        points (array): (n, d) array of points on the phase boundary in d-dimensional phase space
    Returns:
        bounding_points (array): flat array of indices into `points`. For d == 1 these are the
            indices of the minimum and the maximum; otherwise they are the vertex indices of every
            convex-hull facet (a vertex shared by several facets appears once per facet)
    '''
    if points.shape[1] != 1:
        # each hull facet lists its vertex indices; flatten them into one array
        return ConvexHull(points).simplices.flatten()

    # in one dimension the hull is just the interval [min, max]
    return np.array([points.argmin(), points.argmax()])


def lower_convex_hull(points, lift=None):
    '''
    Build the lower convex hull of a point set whose last column is the energy.

    Copies of the points that bound the set in the energy-free projection are placed `lift`
    above the data. Every hull facet that touches such a copy is a side wall or the lid, so
    the facets made only of real points form the lower hull.

    Parameters:
        points (array): (n, d) array of points on the phase boundary; the last column is the energy
        lift (float, optional): energy added to the lifted copies; defaults to
            max(1000, energy range + 1), so the copies always sit above every real point
    Returns:
        lower_hull (array): (k, d) integer array; each row holds the indices into `points`
            of one lower-hull simplex
    '''
    pts = np.asarray(points, dtype=float)
    n_points, dim = pts.shape
    if n_points < dim:
        # too few points to span a facet: the lower hull is empty
        return np.empty((0, dim), dtype=int)

    # project out the energy and find the points that bound the set; duplicates add nothing
    bp = np.unique(bounding_points(pts[:, :-1]))

    # lift copies of the bounding points strictly above the highest real energy
    energies = pts[:, -1]
    if lift is None:
        lift = max(1000.0, float(energies.max() - energies.min()) + 1.0)
    fake_points = pts[bp]
    fake_points[:, -1] += lift

    # compute the convex hull of the real and the lifted points
    simplices = ConvexHull(np.vstack((pts, fake_points))).simplices

    # scrub every simplex that touches a fake point (index >= n_points)
    mask = np.all(simplices < n_points, axis=1)
    return simplices[mask]


def rectify_lower_hull(points, lower_hull):
    '''
    Select the bridging segments ("convex planes") from the edges of a 2-D lower convex hull.

    The edges are ordered by the x value of their midpoint, and every other one is kept,
    starting with the leftmost. This presumes that, in x order, the hull edges alternate
    between bridging segments and chords over convex stretches of the curve (TODO confirm
    for point sets other than bitangent tangent points).

    Parameters:
        points (array): (n, 2) array of [x, y] points
        lower_hull (array): (k, 2) integer array of edges; each row holds two indices into `points`
    Returns:
        (array): the kept rows of `lower_hull`, ordered by edge midpoint
    '''
    start_x = points[lower_hull[:, 0], 0]
    end_x = points[lower_hull[:, 1], 0]
    by_midpoint = lower_hull[np.argsort((start_x + end_x) / 2)]
    return by_midpoint[::2]


def plot_func(func, xrange=(-10, 10), yrange=(-10, 10)):
    '''
    Plot a polynomial together with the line segments that bridge its non-convex regions.

    Parameters:
        func (sympy expression): polynomial in the module-level symbol `x`
        xrange (pair of numbers): interval the curve is sampled on, also used as the x-axis limits
        yrange (pair of numbers): y-axis limits
    '''
    # calculate the simplices for the function that was passed in (not a global `f`)
    points = calculate_simplices(func)
    convex_planes = rectify_lower_hull(points, lower_convex_hull(points))

    # plot the convex planes; only the first one gets a legend entry
    for i, plane in enumerate(convex_planes):
        plt.plot(points[plane, 0], points[plane, 1], color='red', marker='o',
                 label='convex bounding planes' if i == 0 else None)

    # evaluate the function with a vectorized numpy callable instead of 1000 sympy substitutions;
    # broadcast_to covers a constant func, for which the callable returns a scalar
    x_space = np.linspace(xrange[0], xrange[1], 1000)
    f_numeric = sp.lambdify(x, func, 'numpy')
    plt.plot(x_space, np.broadcast_to(f_numeric(x_space), x_space.shape), label='f(x)')

    plt.legend()
    plt.ylim(yrange)
    plt.xlim(xrange)
    plt.grid()
    plt.show()

## Demo

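The $x$-coordinates $t$ of the tangent points that define the bounding convex hyperplanes are the real roots in $t$ of the equation below, where $\boldsymbol{\Delta}$ denotes the discriminant with respect to $x$. The numerator is the vertical gap $f(x) - f(t) - f'(t)\,(x - t)$ between $f$ and its tangent line at $t$. This gap always has a double root at $x = t$, and the division by $(x-t)^2$ removes it: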
$$\boldsymbol{\Delta} \left( \frac{\begin{bmatrix}
  x - t\\
  f(x) - f(t)
\end{bmatrix} \cdot \begin{bmatrix}
  - \nabla f(x)|_t\\
  1
\end{bmatrix}}{(x-t)^2} \right)
 = 0
$$

In [None]:
# a quartic with two wells; define it, then plot it with its convex bounding plane
f = x**4 - 4 * x**2 + x
plot_func(f, xrange=(-3, 3), yrange=(-6, 2))

In [None]:
# an asymmetric quartic; the coefficient is an exact sympy Rational rather than the float 2.5
const = sp.Integer(5) / 2
f = x + x**4 - const * x**3
plot_func(f, xrange=(-1, 3), yrange=(-3, 1))

#### Problems arise for the higher-order (degree-6) functions below

In [None]:
# a higher-order (sextic) function with three wells
f = 5 * x**2 - 4 * x**4 + x**6
plot_func(f, xrange=(-2, 2), yrange=(-1, 3))

In [None]:
# the same sextic tilted by a linear term, which makes its wells uneven
f = x + 5 * x**2 - 4 * x**4 + x**6
plot_func(f, xrange=(-2, 2), yrange=(-1, 4))