## Imports and Functions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
from scipy.optimize import fsolve
import plotly.graph_objects as go

def affine_gradient_projection(f1, f2):
    """
    Subtract from f2 the tangent (affine) approximation of f1 at a symbolic point.

    Every free symbol ``v`` of ``f1`` gets a counterpart named ``v_p`` that
    stands for the point of tangency ``p``.  The returned expression is

        f2 - f1(p) - sum over v of (d f1 / d v)(p) * (v - v_p)

    Parameters:
    f1: SymPy expression whose tangent hyperplane is built.
    f2: SymPy expression the tangent hyperplane is subtracted from.

    Returns:
    tuple: ``(projection, variables)`` where ``variables`` is a tuple holding
    the free symbols of ``f1`` sorted by name.
    """
    # free_symbols is a set, so its iteration order is not guaranteed; sorting
    # by name gives callers a reproducible variable order
    variables = tuple(sorted(f1.free_symbols, key=lambda symbol: symbol.name))
    point = {var: sp.symbols(f'{var}_p') for var in variables}

    projection = f2 - f1.subs(point)
    for var, var_p in point.items():
        # the partial derivative is evaluated at the whole point p; replacing
        # only `var` would leave the other variables of a cross term untouched
        projection -= sp.diff(f1, var).subs(point) * (var - var_p)

    return projection, variables


def recursive_discriminant(expr, vars):
    """
    Eliminate variables from an expression by taking discriminants in turn.

    Parameters:
    expr: SymPy expression to reduce.
    vars (iterable): Variables to eliminate. Variables that do not occur in
        ``expr`` are skipped.

    Returns:
    The result of applying ``sp.discriminant`` once per remaining variable,
    starting with the last variable of ``vars`` and finishing with the first.
    ``expr`` itself is returned when there is nothing to eliminate.
    """
    # Only variables that occur in the input take part.  Deciding this up front
    # avoids returning None when none of the given variables is present.
    present = [var for var in vars if expr.has(var)]

    result = expr
    # the innermost discriminant belongs to the last variable, the outermost
    # to the first one
    for var in reversed(present):
        result = sp.discriminant(result, var)

    return result


def solve_with_fsolve(equations, variables, initial_guess):
    """
    Find a numerical root of a symbolic system with SciPy's fsolve.

    Parameters:
    equations (list): SymPy expressions; fsolve drives each of them to zero.
    variables (list): SymPy variables the expressions are evaluated in.
    initial_guess (list): Starting values, one per entry of ``variables``.

    Returns:
    The value returned by fsolve for the system.
    """
    # compile every symbolic expression into a NumPy-backed callable up front
    compiled = []
    for equation in equations:
        compiled.append(sp.lambdify(variables, equation, 'numpy'))

    def residuals(point):
        # one numeric value per equation at the trial point
        return [evaluate(*point) for evaluate in compiled]

    return fsolve(residuals, initial_guess)

def get_dicriminants(function, other_functions):
    """
    Build one discriminant expression per entry of ``other_functions``.

    Each entry is combined with ``function`` through
    ``affine_gradient_projection`` and the resulting expression is reduced
    with ``recursive_discriminant`` over the variables that call returns.

    Parameters:
    function: SymPy expression passed as the first argument of the projection.
    other_functions (iterable): SymPy expressions projected one at a time.

    Returns:
    list: The discriminants, in the order of ``other_functions``.
    """
    discriminant_list = []

    for other_function in other_functions:
        projection, variables = affine_gradient_projection(function, other_function)
        discriminant_list.append(recursive_discriminant(projection, variables))

    return discriminant_list

def solve_convex_vertex(function, other_functions):
    """
    Numerically solve the discriminant system of ``function`` against the others.

    The discriminants from ``get_dicriminants`` are handed to
    ``solve_with_fsolve``; the search starts at a critical point of
    ``function`` (a point where its gradient vanishes).

    Parameters:
    function: SymPy expression whose ``*_p`` parameters are solved for.
    other_functions (list): SymPy expressions paired with ``function``.

    Returns:
    tuple: ``(solution, discriminant_list)``; the solution values are ordered
    like the free symbols of ``function`` sorted by name.

    Raises:
    ValueError: If no critical point with a value for every variable is found.
    """
    discriminant_list = get_dicriminants(function, other_functions)
    print(discriminant_list)

    # free_symbols is a set, so sort by name to keep the unknowns and the
    # initial guess in one consistent order
    variables = sorted(function.free_symbols, key=lambda symbol: symbol.name)
    # same `<name>_p` naming that affine_gradient_projection gives its parameters
    parameterized_vars = [sp.symbols(f'{var}_p') for var in variables]

    # TODO: we need to assert that there are n+1 number of functions provided for maximum number of n variables

    # start the numerical search from a critical point of `function` itself
    gradient = [sp.diff(function, var) for var in variables]
    critical_points = sp.solve(gradient, variables, dict=True)
    if not critical_points or any(var not in critical_points[0] for var in variables):
        raise ValueError(f'no isolated critical point found for {function}')
    initial_guess = [float(critical_points[0][var]) for var in variables]

    solution = solve_with_fsolve(discriminant_list, parameterized_vars, initial_guess)

    return solution, discriminant_list


# plotting function
def plot_funcs(f1, f2, f3, zrange=[-2, 12]):
    """
    Plot three surfaces, the zero curves of their discriminants and the solved vertices.

    Parameters:
    f1, f2, f3: SymPy expressions in the symbols named ``x`` and ``y``.
    zrange (list): Range applied to the z axis of the plot. It is only read,
        never modified.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    # solve the vertices
    vertex_f1, d1s = solve_convex_vertex(f1, [f2, f3])
    vertex_f2, d2s = solve_convex_vertex(f2, [f1, f3])
    vertex_f3, d3s = solve_convex_vertex(f3, [f2, f1])

    discs = d1s + d2s + d3s

    # Grid for the surfaces
    x_vals = np.linspace(-10, 10, 100)
    y_vals = np.linspace(-10, 10, 100)
    x_grid, y_grid = np.meshgrid(x_vals, y_vals)

    # Convert sympy expressions to numerical functions.  The arguments are
    # given by name so the function does not depend on module-level x and y.
    f1_func = sp.lambdify(('x', 'y'), f1, 'numpy')
    f2_func = sp.lambdify(('x', 'y'), f2, 'numpy')
    f3_func = sp.lambdify(('x', 'y'), f3, 'numpy')

    # Evaluate the functions on the grid
    f1_vals = f1_func(x_grid, y_grid)
    f2_vals = f2_func(x_grid, y_grid)
    f3_vals = f3_func(x_grid, y_grid)

    fig = go.Figure()

    # Add the discriminants: trace their zero level on a much finer grid
    x_vals = np.linspace(-10, 10, 4000)
    y_vals = np.linspace(-10, 10, 4000)
    x_mesh, y_mesh = np.meshgrid(x_vals, y_vals)

    colors = ['red', 'green', 'blue', 'orange', 'purple', 'brown', 'pink']

    # Matplotlib is used only to trace the contours, on a scratch figure that
    # is closed afterwards so nothing is left behind in pyplot's state.
    scratch_fig, scratch_ax = plt.subplots()
    try:
        for idx, expr in enumerate(discs):
            # evaluate one discriminant at a time so only a single fine grid
            # of values is alive at once
            z_vals = sp.lambdify(('x_p', 'y_p'), expr, 'numpy')(x_mesh, y_mesh)
            contour_set = scratch_ax.contour(x_mesh, y_mesh, z_vals, levels=[0])
            line_color = colors[idx % len(colors)]  # Assign a color from the list
            # allsegs holds, per level, one vertex array per contour line
            for segments in contour_set.allsegs:
                for vertices in segments:
                    if len(vertices) == 0:
                        continue
                    x_coords = vertices[:, 0]
                    y_coords = vertices[:, 1]
                    z_coords = np.zeros_like(x_coords)  # z=0 for all points
                    # Plot the line in 3D with the assigned color
                    fig.add_trace(go.Scatter3d(
                        x=x_coords,
                        y=y_coords,
                        z=z_coords,
                        mode='lines',
                        line=dict(color=line_color, width=5),
                        name=f'Discriminant {idx + 1}'
                    ))
            scratch_ax.clear()  # Drop the contour before tracing the next one
    finally:
        plt.close(scratch_fig)

    # Add the three surfaces, each with its own colorscale
    fig.add_trace(go.Surface(z=f1_vals, x=x_grid, y=y_grid, opacity=0.8, name='f1', showscale=False, colorscale='Viridis'))
    fig.add_trace(go.Surface(z=f2_vals, x=x_grid, y=y_grid, opacity=0.8, name='f2', showscale=False, colorscale='emrld'))
    fig.add_trace(go.Surface(z=f3_vals, x=x_grid, y=y_grid, opacity=0.8, name='f3', showscale=False, colorscale='Plasma'))

    # Height of every vertex on its own surface
    f1_val_float = float(f1_func(vertex_f1[0], vertex_f1[1]))
    f2_val_float = float(f2_func(vertex_f2[0], vertex_f2[1]))
    f3_val_float = float(f3_func(vertex_f3[0], vertex_f3[1]))

    # The first vertex is repeated at the end of each coordinate list
    x_coords = [vertex_f1[0], vertex_f2[0], vertex_f3[0], vertex_f1[0]]
    y_coords = [vertex_f1[1], vertex_f2[1], vertex_f3[1], vertex_f1[1]]
    z_coords = [f1_val_float, f2_val_float, f3_val_float, f1_val_float]

    # Add the polygon to the figure using Mesh3d
    fig.add_trace(go.Mesh3d(
        x=x_coords,
        y=y_coords,
        z=z_coords,
        color='red',
        opacity=0.50,
        i=[0],
        j=[1],
        k=[2]
    ))

    # Update layout for better visualization
    fig.update_layout(
        width=1000,
        height=800,
        scene=dict(
            xaxis_title='X axis',
            yaxis_title='Y axis',
            zaxis_title='Z axis',
        ),
        title='3D Plot of f1, f2, and f3',
        coloraxis_showscale=False
    )
    fig.update_layout(scene=dict(zaxis=dict(range=zrange)))

    # Show the figure
    fig.show()


## Demo

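The coordinates $\vec{t}$ of the tangent points on $f_1$ that define the bounding convex hyperplane between $f_1$ and $f_2$ are the roots in $\vec{t}$ of the following equation, where $\boldsymbol{\Delta}$ denotes the discriminant taken successively with respect to each component of $\vec{x}$: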
$$\boldsymbol{\Delta} \left( \begin{bmatrix}
  \vec{x} - \vec{t}\\
  f_2(x_1, x_2, \ldots) - f_1(t_1, t_2, \ldots)
\end{bmatrix} \cdot \begin{bmatrix}
  - \nabla f_1(x_1, x_2, \ldots)|_{\vec{t}}\\
  1
\end{bmatrix} \right)
 = 0
$$

In [None]:
# symbolic coordinates the demo surfaces are written in
x = sp.Symbol('x')
y = sp.Symbol('y')

In [None]:
# three example surfaces, each written explicitly as an expression in x and y
f1 = x**2 + y**2
f2 = 3 * (x - 1)**2 + (y - 4)**2 + 1
f3 = (x + 4)**2 + 2 * (y - 2)**2 + 2

# draw the surfaces together with their discriminant curves and solved vertices
plot_funcs(f1, f2, f3)