# Example of polynomial solving + delta correction pipeline

In [1]:
# Imports
from io import StringIO
import time
import random
import sys

import numpy as np
import pandas as pd

from sympy import symbols, solve, Poly, simplify, latex, Rational, init_printing

init_printing(use_latex=True)

In [2]:

def generate_nondimensionalized(max_n1):
    """
    Build a random three-term polynomial ``epsilon*x**n1 +/- x**n2 +/- 1``.

    The powers are drawn with ``2 <= n1 <= max_n1`` and ``1 <= n2 < n1``;
    the signs of the ``x**n2`` term and of the constant term are drawn
    independently from ``{-1, 1}``.

    Inputs:
    max_n1 (int): Maximum power in polynomial (must be at least 2)

    Outputs:
    polynomial (sympy expression): Nondimensionalized random polynomial of the given format
    x (sympy symbol): The variable of the polynomial
    epsilon (sympy symbol): The symbol multiplying the highest power

    Raises:
    ValueError: If max_n1 is smaller than 2
    """
    # Ensure max_n1 is at least 2, otherwise no valid (n1, n2) pair exists
    if max_n1 < 2:
        raise ValueError("max_n1 must be at least 2")

    # Randomly choose n1 and n2 (n2 strictly below n1)
    n1 = random.randint(2, max_n1)
    n2 = random.randint(1, n1 - 1)
    # Randomly choose signs for the x**n2 term and the constant term
    signs = random.choices([-1, 1], k=2)

    # Construct polynomial; the raw string keeps the backslash in the symbol
    # name without relying on an invalid escape sequence.
    x = symbols('x')
    epsilon = symbols(r'\epsilon')
    polynomial = epsilon * x**n1 + signs[0] * x**n2 + signs[1]

    return polynomial, x, epsilon

def solve_nondimensionalized(polynomial, x, epsilon):
    """
    Solve the three pairwise dominant balances of a three-term polynomial.

    The expanded polynomial is split into three terms A, B and C. Each pair
    of terms (A+B, B+C, A+C) is set to zero and solved for ``x``; zero roots
    are discarded. For the first root of each balance the two kept terms are
    compared in magnitude against the dropped term at a small and at a large
    probe value of ``epsilon``, and the outcome is written to the log.

    Args:
        polynomial (sympy expression): Polynomial whose expanded form has
            exactly three terms.
        x (symbol): The variable to solve for.
        epsilon (symbol): The parameter substituted in the validity checks.

    Returns:
        tuple: ``(sol_ab, sol_bc, sol_ac, s)`` - the three lists of simplified
        non-zero roots and a StringIO holding the LaTeX-formatted log.
    """
    s = StringIO()

    # Probe values for the validity checks. Both sides of every comparison
    # use the same value so the magnitudes are compared like-for-like.
    small_eps = 0.0001
    large_eps = 10000

    # Extract terms and solve.
    # NOTE(review): which term ends up as A/B/C depends on sympy's argument
    # ordering for the expanded sum - presumably A is the epsilon term; confirm.
    terms = list(polynomial.expand().args)[::-1]
    A, B, C = terms[0], terms[1], terms[2]
    print('Entire polynomial:', latex(A), latex(B), latex(C), file=s)
    print('Term A:', latex(A), file=s)
    print('Term B:', latex(B), file=s)
    print('Term C:', latex(C), file=s)

    def balance_roots(first, second):
        # Roots of first + second = 0, simplified, without the extraneous
        # root 0 (it never occurs in this problem formulation).
        simplified = [simplify(sol) for sol in solve(first + second, x)]
        return [root for root in simplified if root != 0]

    def dominates(kept, dropped, root, eps_value):
        # True when |kept| > |dropped| at x = root, epsilon = eps_value.
        kept_mag = abs(kept.subs(x, root).subs(epsilon, eps_value))
        dropped_mag = abs(dropped.subs(x, root).subs(epsilon, eps_value))
        return kept_mag > dropped_mag

    sol_ab = balance_roots(A, B)
    sol_bc = balance_roots(B, C)
    sol_ac = balance_roots(A, C)

    # Check dominant balances to see if roots belong to small or large epsilon regimes
    AB_valid_small_eps = dominates(A, C, sol_ab[0], small_eps)  # A,B >> C small eps
    AB_valid_large_eps = dominates(A, C, sol_ab[0], large_eps)  # A,B >> C large eps

    BC_valid_small_eps = dominates(B, A, sol_bc[0], small_eps)  # B,C >> A small eps
    BC_valid_large_eps = dominates(B, A, sol_bc[0], large_eps)  # B,C >> A large eps

    AC_valid_small_eps = dominates(A, B, sol_ac[0], small_eps)  # A,C >> B small eps
    AC_valid_large_eps = dominates(A, B, sol_ac[0], large_eps)  # A,C >> B large eps

    print(r"AB valid for small \epsilon:", latex(AB_valid_small_eps), file=s)
    print(r"AB valid for large \epsilon:", latex(AB_valid_large_eps), file=s)
    print(r"BC valid for small \epsilon:", latex(BC_valid_small_eps), file=s)
    print(r"BC valid for large \epsilon:", latex(BC_valid_large_eps), file=s)
    print(r"AC valid for small \epsilon:", latex(AC_valid_small_eps), file=s)
    print(r"AC valid for large \epsilon:", latex(AC_valid_large_eps), file=s)
    print("\n", file=s)

    # TODO: Figure out how to go backwards from multiple complex roots/roots of unity to simpler format (e.g. 1/eps**5)

    # Write results to the log
    print("polynomial:", file=s)
    print(latex(polynomial), file=s)
    print("\n", file=s)

    print("Balance A, B | ", len(sol_ab), "roots", file=s)
    print(latex(sol_ab), file=s)
    print("\n", file=s)

    print("Balance B, C | ", len(sol_bc), "roots", file=s)
    print(latex(sol_bc), file=s)
    # This separator must go to the log like the others, not to stdout.
    print("\n", file=s)

    print("Balance A, C | ", len(sol_ac), "roots", file=s)
    print(latex(sol_ac), file=s)
    print("\n", file=s)

    return sol_ab, sol_bc, sol_ac, s

In [3]:

def solve_delta_corrected_term(center, poly_eqn, x_var, term_trunc=3, eps_var=None, specific_val_eps=None):
    """
    Solve for the delta correction term in a polynomial equation.

    ``x_var`` is replaced by ``center + delta``, the result is expanded, its
    series in ``delta`` is truncated at order ``term_trunc`` and the truncated
    equation is solved for ``delta``.

    Args:
        center (symbol or numeric value): The point at which the delta correction is applied.
        poly_eqn (sympy expression): The polynomial equation.
        x_var (symbol): The variable in the polynomial equation.
        term_trunc (int, optional): Order at which the series in delta is truncated
            (passed as ``n`` to ``series``). Defaults to 3.
        eps_var (symbol, optional): The epsilon symbol if used in the equation. Defaults to None.
        specific_val_eps (numeric, optional): Specific epsilon value to substitute into the
            solutions; only applied when ``eps_var`` is also given. Defaults to None.

    Returns:
        tuple: A list with the solutions for the delta correction term and a StringIO object
        with step-by-step instructions.
    """
    # Raw string: keeps the backslash without an invalid escape sequence.
    delta = symbols(r'\delta')
    poly_delta = poly_eqn.subs(x_var, center + delta)

    s = StringIO()
    print('Substituting in the delta correction term:', file=s)
    print(poly_delta, file=s)

    # expand() returns a new expression; keep it so the log really shows the
    # expanded equation.
    poly_delta = poly_delta.expand()
    print('Expanding out the equation:', file=s)
    print(poly_delta, file=s)

    sol = solve(poly_delta.series(delta, n=term_trunc).removeO(), delta)

    # solve() returns a list, so substitute into each solution separately.
    # Compare against None so that a value of 0 is still substituted.
    if eps_var is not None and specific_val_eps is not None:
        sol = [root.subs(eps_var, specific_val_eps) for root in sol]

    print('Solving the problem gives us the following values for the delta term:', file=s)
    print(sol, file=s)

    return sol, s

def get_delta_corrections(guessed_root_arr, poly_eqn, x_var, term_trunc=3, eps_var=None, specific_val_eps=None):
    """
    Collect the delta corrections of every guessed root of a polynomial equation.

    Each guessed root is handed to ``solve_delta_corrected_term``; the
    returned corrections are concatenated into one flat list and the
    per-root logs are gathered, followed by one summary log.

    Args:
        guessed_root_arr (list): Guessed roots to compute delta corrections for.
        poly_eqn (sympy expression): The polynomial equation.
        x_var (symbol): The variable in the polynomial equation.
        term_trunc (int, optional): Truncation passed on to the solver. Defaults to 3.
        eps_var (symbol, optional): The epsilon symbol if used in the equation. Defaults to None.
        specific_val_eps (numeric, optional): Specific epsilon value passed on to the solver. Defaults to None.

    Returns:
        tuple: The flat list of delta correction values and the list of
        StringIO logs (one per guessed root plus a final summary).
    """
    # Arguments that are the same for every guessed root.
    solver_options = dict(
        poly_eqn=poly_eqn,
        x_var=x_var,
        term_trunc=term_trunc,
        eps_var=eps_var,
        specific_val_eps=specific_val_eps,
    )

    all_corrections = []
    step_logs = []
    for root_guess in guessed_root_arr:
        corrections, log = solve_delta_corrected_term(center=root_guess, **solver_options)
        all_corrections.extend(corrections)
        step_logs.append(log)

    # Final log entry listing every correction that was found.
    summary = StringIO()
    print('List of all delta correction values:', file=summary)
    print(all_corrections, file=summary)
    step_logs.append(summary)

    return all_corrections, step_logs

def format_instructions_arr(instructions_arr):
    """
    Merge several instruction logs into one newline-separated string.

    Args:
        instructions_arr (list): StringIO objects holding step-by-step instructions.

    Returns:
        str: The contents of all logs, in order, joined by a single newline.
    """
    return '\n'.join(log.getvalue() for log in instructions_arr)


In [11]:
def run_single_instruction_trial(highest_degree=10, term_trunc=2):
    """
    Run a single instruction trial: dominant-balance roots plus delta corrections.

    A random polynomial is generated, its pairwise dominant balances are
    solved, and a delta correction is computed around every balance root.

    Args:
        highest_degree (int, optional): The highest degree of the generated polynomial. Defaults to 10.
        term_trunc (int, optional): Truncation used in the delta correction expansion. Defaults to 2.

    Returns:
        tuple: The LaTeX representation of the generated polynomial and a string with
        step-by-step instructions.
    """
    # Generate a polynomial
    polynomial, x, epsilon = generate_nondimensionalized(highest_degree)
    polynomial_str = latex(polynomial)

    # Compile dominant balance solutions
    sol_ab, sol_bc, sol_ac, solving_nondim_instructions = solve_nondimensionalized(polynomial=polynomial, x=x, epsilon=epsilon)
    all_sols = sol_ab + sol_ac + sol_bc

    # Delta corrections around every balance root; only the logs are needed here
    _, instructions_arr = get_delta_corrections(guessed_root_arr=all_sols, poly_eqn=polynomial, x_var=x, term_trunc=term_trunc)

    # Combine instructions: balance log first, then the correction logs
    instruction_stuff = format_instructions_arr([solving_nondim_instructions] + instructions_arr)

    return polynomial_str, instruction_stuff


In [5]:
def generate_dataset_delta_correction(n_problems=10, highest_degree=5):
    """
    Build a question/answer dataset of delta-correction problems.

    Known error: if the delta correction of a specific root takes too long to
    compute, the process is not terminated. Keep truncations low so the
    algebra stays manageable.

    Args:
        n_problems (int, optional): How many problems to generate. Defaults to 10.
        highest_degree (int, optional): Highest polynomial degree to generate. Defaults to 5.

    Returns:
        pandas.DataFrame: Two columns, 'Question' and 'Answer'; each row holds one
        generated problem and its step-by-step answer.
    """
    rows = []
    for _ in range(n_problems):
        # One trial yields the problem statement and its worked answer.
        question, answer = run_single_instruction_trial(highest_degree=highest_degree)
        rows.append((question, answer))

    return pd.DataFrame({
        'Question': [question for question, _ in rows],
        'Answer': [answer for _, answer in rows],
    })


### Example usage starts here:

In [None]:
# Generate a demo dataset of 10 problems (polynomial degree at most 5) and
# preview its first rows.
df = generate_dataset_delta_correction(n_problems=10, highest_degree=5)
df.head()

## Adapting to plot numerics vs analytics

In [128]:
from sympy import *

def solve_delta_corrected_term(center, validity, poly_eqn, x_var, term_trunc=3, eps_var=None, specific_val_eps=None):
    """
    Compute the single delta correction that refines an approximate root.

    ``x_var`` is replaced by ``center + delta``; the expanded equation is
    truncated in ``delta`` at order ``term_trunc`` and solved. Among the
    solutions, the first one whose magnitude is smaller than the magnitude of
    ``center`` (both evaluated at a probe epsilon) is returned.

    NOTE(review): this definition reuses the name of the earlier
    ``solve_delta_corrected_term`` in this file but takes an extra
    ``validity`` argument and returns one expression instead of a
    ``(list, StringIO)`` pair; once executed it shadows the earlier one.

    Args:
        center (symbol or numeric value): Approximate root to correct.
        validity (str): ``"small"`` or ``"large"``; selects the probe epsilon
            (0.0001 or 10000) used to compare magnitudes.
        poly_eqn (sympy expression): The polynomial equation.
        x_var (symbol): The variable in the polynomial equation.
        term_trunc (int, optional): Order at which the series in delta is
            truncated; must be larger than 1. Defaults to 3.
        eps_var (symbol, optional): The epsilon symbol. When None, magnitudes
            are compared without any substitution. Defaults to None.
        specific_val_eps (numeric, optional): Epsilon value substituted into
            the returned correction; only applied when ``eps_var`` is given.
            Defaults to None.

    Returns:
        sympy expression: The selected delta correction.

    Raises:
        ValueError: If ``term_trunc`` is not larger than 1 or ``validity`` is
            not ``"small"``/``"large"``.
        RuntimeError: If no solution is smaller in magnitude than the root
            approximation.
    """
    if term_trunc <= 1:
        raise ValueError("Need term_trunc to be larger than 1.")

    if validity == "small":
        test_eps = 0.0001
    elif validity == "large":
        test_eps = 10000
    else:
        raise ValueError('validity must be "small" or "large".')

    delta = symbols(r'\delta')  # delta is the sum of eps terms (expansion)
    poly_delta = poly_eqn.subs(x_var, center + delta)
    print(poly_delta)  # this is delta subbed back into the original eqn
    poly_delta = poly_delta.expand()
    poly_delta = poly_delta.series(delta, n=term_trunc).removeO()  # selects correct number of terms
    print(poly_delta)

    sol = solve(poly_delta, delta)
    sol = [simplify(d) for d in sol]

    print(sol, len(sol))

    # Several solutions for delta may exist - one takes you closer to the root,
    # the others further away. Select the correct one using the fact that the
    # correction must be smaller than the root approximation itself.
    if eps_var is not None:
        delta_mags = [abs(d.subs(eps_var, test_eps).evalf()) for d in sol]
    else:
        delta_mags = [abs(d.evalf()) for d in sol]

    try:
        probed_center = center.subs(eps_var, test_eps) if eps_var is not None else center
        root_mag = abs(probed_center).evalf()
    except AttributeError:
        # center is a plain Python number without subs()/evalf()
        root_mag = abs(center)

    delta_idx = None
    for idx, mag in enumerate(delta_mags):
        if mag < root_mag:
            delta_idx = idx
            break
    if delta_idx is None:
        raise RuntimeError("Cannot calculate corrections.")

    print(sol, delta_idx)
    delta_correction = sol[delta_idx]

    # Compare against None so that a value of 0 is still substituted.
    if eps_var is not None and specific_val_eps is not None:
        delta_correction = simplify(delta_correction.subs(eps_var, specific_val_eps))

    return delta_correction

In [133]:
# Worked example: correct the root guess -I*(1/eps)**(1/4) of eps*x**5 - x + 1,
# probing the small-epsilon regime and evaluating the result at eps = 0.001.
# NOTE(review): 1/4 is evaluated by Python as the float 0.25 before sympy sees
# it, which is why the printed expressions below carry 0.25/0.75 exponents;
# Rational(1, 4) would keep the exponent exact.
x, eps = symbols('x epsilon')
polynomial = eps * x**5 - x + 1
root_approx = -I*(1/eps)**(1/4)
correction = solve_delta_corrected_term(root_approx, "small", polynomial, x, term_trunc=3, eps_var=eps, specific_val_eps=0.001)

-\delta + epsilon*(\delta - I*(1/epsilon)**0.25)**5 + I*(1/epsilon)**0.25 + 1
10*I*\delta**2*epsilon*(1/epsilon)**0.75 + \delta*(5*epsilon*(1/epsilon)**1.0 - 1) - I*epsilon*(1/epsilon)**1.25 + I*(1/epsilon)**0.25 + 1
[I*(0.2 - 0.316227766016838*sqrt(-I*epsilon*(1/epsilon)**0.75 + 0.4))/(epsilon*(1/epsilon)**(3/4)), I*(0.316227766016838*sqrt(-I*epsilon*(1/epsilon)**0.75 + 0.4) + 0.2)/(epsilon*(1/epsilon)**(3/4))] 2
[I*(0.2 - 0.316227766016838*sqrt(-I*epsilon*(1/epsilon)**0.75 + 0.4))/(epsilon*(1/epsilon)**(3/4)), I*(0.316227766016838*sqrt(-I*epsilon*(1/epsilon)**0.75 + 0.4) + 0.2)/(epsilon*(1/epsilon)**(3/4))] 0


In [134]:
# Display the selected correction (already evaluated at eps = 0.001).
correction

  ⎛                                     ___________________________⎞
ⅈ⋅⎝1.1246826503807 - 1.77827941003892⋅╲╱ 0.4 - 0.177827941003892⋅ⅈ ⎠

In [None]:
# Disabled code: the whole function below sits inside a string literal, so it
# is never defined or executed.
# NOTE(review): it calls fsolve, which is not imported anywhere in the visible
# part of this file, and reads a module-level x - confirm before re-enabling.
"""def calculate_correction_terms(root_approx_list, poly):
    epsilon, a_1, a_2, a_3 = symbols('epsilon a_1 a_2 a_3')
    correction_terms = []
    for root_approx in root_approx_list:
        root_approx = simplify(root_approx)
        print("Root approximation: ", root_approx)
        power = degree(root_approx, gen=epsilon)
        pm = sign(root_approx)
        if power == 0 or power == 1:
            root_expr = pm + a_1*epsilon + a_2*(epsilon**2) + a_3*(epsilon**3)
        else:
            root_expr = pm + a_1*(epsilon**power) + a_2*(epsilon**(2*power)) + a_3*(epsilon**(3*power))

        expr = poly.subs(x, root_expr)
        expr = collect(expr.expand(), epsilon)

        equations = [expr.coeff(epsilon, n) for n in range(4)]
        print("Equations to solve: ", equations)
        try:
            solutions = solve(equations, (a_1, a_2, a_3))
        except Exception as e:
            print("Sympy solve failed: ", e)
            # Convert sympy equations to lambda functions so they can be used with scipy's fsolve
            f = lambda x: [eq.subs({a_1: x[0], a_2: x[1], a_3: x[2]}).evalf() for eq in equations]
            solutions = fsolve(f, (0, 0, 0))
        print("Solutions: ", solutions)
        correction_terms.append(solutions)

    return correction_terms"""

In [18]:
def run_single_instruction_trial(term_trunc=2):
    """
    Run one dominant-balance and delta-correction trial on ``epsilon*x**5 - x + 1``.

    The polynomial is fixed (not random). Its pairwise dominant balances are
    solved and a delta correction is computed around every balance root.

    NOTE(review): this definition reuses the name of the earlier
    ``run_single_instruction_trial(highest_degree, term_trunc)`` in this file;
    it has no ``highest_degree`` parameter and returns three values, so once
    executed it shadows the earlier one.

    Args:
        term_trunc (int, optional): Truncation used in the delta correction expansion. Defaults to 2.

    Returns:
        tuple: The LaTeX representation of the polynomial, a string with
        step-by-step instructions, and the list of delta corrections.
    """
    # Fixed test polynomial; the raw string keeps the backslash in the symbol
    # name without relying on an invalid escape sequence.
    x = symbols('x')
    epsilon = symbols(r'\epsilon')
    polynomial = epsilon*x**5 - x + 1
    polynomial_str = latex(polynomial)

    # Compile dominant balance solutions
    sol_ab, sol_bc, sol_ac, solving_nondim_instructions = solve_nondimensionalized(polynomial=polynomial, x=x, epsilon=epsilon)
    all_sols = sol_ab + sol_ac + sol_bc

    # Delta corrections around every balance root
    delta_corrections_arr, instructions_arr = get_delta_corrections(guessed_root_arr=all_sols, poly_eqn=polynomial, x_var=x, term_trunc=term_trunc)

    # Combine instructions: balance log first, then the correction logs
    instruction_stuff = format_instructions_arr([solving_nondim_instructions] + instructions_arr)

    return polynomial_str, instruction_stuff, delta_corrections_arr


In [23]:
# Run the fixed-polynomial trial with term_trunc=3 and keep the LaTeX string,
# the instruction text and the list of delta corrections.
poly, ins, deltas = run_single_instruction_trial(term_trunc=3)





In [24]:
# Display the delta corrections returned by the trial.
deltas

⎡                                                                             
⎢                                                                             
⎢         ___________________________________________________________________ 
⎢        ╱                    __________                                      
⎢       ╱             3/4    ╱    1                     3/4       __________  
⎢  ⅈ⋅  ╱   10⋅\epsilon   ⋅4 ╱  ────────  - 10⋅ⅈ⋅\epsilon    - 6⋅╲╱ \epsilon   
⎢    ╲╱                   ╲╱   \epsilon                                       
⎢- ────────────────────────────────────────────────────────────────────────── 
⎢                                    __________                               
⎣                               10⋅╲╱ \epsilon                                

                                                                              
                                                                              
                         __________________________

In [25]:
# Number of delta corrections collected over all guessed roots.
len(deltas)

20