In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial import Polynomial
import sympy
from functools import reduce
from scipy.spatial.distance import cosine
from tqdm import tqdm
import helpers
from sympy.solvers.solveset import linsolve
from sympy.printing.latex import LatexPrinter, print_latex

import polynomial_helpers as ph


In [None]:
def get_equations(x, P_poly, max_t=None):
    """Build the equations whose solution gives the coefficients of s(t).

    Parameters
    ----------
    x : sympy.Symbol
        The variable in which ``P_poly`` is written.
    P_poly : sympy.Expr
        Polynomial expression in ``x``.
    max_t : int, optional
        Number of time steps t = 1..max_t to constrain. When omitted, the
        notebook-level ``max_t`` is used (the original behaviour), so existing
        ``get_equations(x, P)`` calls keep working.

    Returns
    -------
    (list, dict)
        The equations produced by ``ph.get_dldx_dsdt_equation`` followed by
        those from ``ph.get_dynamical_system_equation`` (one of each per t),
        and the dict mapping coefficient index i -> sympy symbol ``a{i}``.
    """
    if max_t is None:
        # Backward compatibility: this function used to read the global directly.
        max_t = globals()["max_t"]

    # Derivative of P as a callable. Passing the generator x explicitly keeps
    # this a Poly even when the derivative is constant (e.g. linear P), where
    # `.as_poly()` would return None.
    dP_poly = sympy.Poly(P_poly.diff(x), x)
    dP = lambda value: dP_poly.eval(value)

    # Number of unknown coefficients a0 .. a{2*max_t+2}; s therefore has
    # degree 2*max_t + 2.
    # NOTE(review): why 2*max_t + 3 specifically is not visible here —
    # presumably tied to the 2*max_t equations below; confirm.
    num_symbols = 2*max_t + 3

    # The coefficients of the polynomial s
    a_dict = {i: sympy.symbols("a{}".format(i)) for i in range(num_symbols)}

    t_range = range(1, max_t + 1)

    # The equations to require that the derivative of s is dl/dx
    dldx_dsdt_equations = [ph.get_dldx_dsdt_equation(t=i, a_dict=a_dict, dP=dP) for i in t_range]

    # The equations to require that s passes through gradient descent
    dynamical_system_equations = [ph.get_dynamical_system_equation(t=i, a_dict=a_dict, dP=dP) for i in t_range]
    return dldx_dsdt_equations + dynamical_system_equations, a_dict


def get_values(x, P, max_t):
    """Solve for s(t)'s coefficients and evaluate s at t = 1..max_t.

    Returns
    -------
    (range, dict, list)
        ``t_values`` (1..max_t), ``a_values`` (coefficient index -> solved
        value), and ``s_values`` (``ph.evaluate_s`` at each t).
    """
    # Pass max_t through so the equations and the evaluation grid agree.
    equations, a_dict = get_equations(x, P, max_t)
    solution_dict = ph.get_solution_dict(equations=equations, a_dict=a_dict)
    t_values = range(1, max_t + 1)
    # NOTE(review): assumes solution_dict is keyed by a_dict[0..k-1]; if the
    # solver leaves some coefficients free this mapping may be incomplete — confirm.
    a_values = {i: solution_dict[a_dict[i]] for i in range(len(solution_dict))}
    s_values = [ph.evaluate_s(t_value=t_value, a_dict=a_values) for t_value in t_values]
    return t_values, a_values, s_values



In [None]:
from textwrap import wrap

max_t = 6  # fit and plot s(t) at t = 1..max_t


def break_title_lines(latex_title, break_exponents=(13, 10, 6)):
    """Insert mathtext line breaks after the ``t^{k} +`` / ``t^{k} -`` terms.

    Titles are drawn with ``wrap=False``, so the ``$...$`` math run is closed
    and reopened on a new line after each listed power of t. The sign that
    follows the term is preserved.
    """
    for exponent in break_exponents:
        for sign in ("+", "-"):
            term = "t^{{{}}} {}".format(exponent, sign)
            latex_title = latex_title.replace(term, term + " $\n$")
    return latex_title


x = sympy.symbols("x")
# One subplot per example polynomial P(x).
example_polynomials = [x*x, 2*x*x - 1]

fig, axes = plt.subplots(1, len(example_polynomials), figsize=(35, 20), squeeze=False)
for ax, P in zip(axes.flat, example_polynomials):
    t_values, a_values, s_values = get_values(x, P, max_t)
    s_latex = LatexPrinter().doprint(ph.evaluate_s(sympy.symbols("t"), a_values))
    title = break_title_lines("$s(t) = {}$".format(s_latex) + "\n")
    ax.set_title(title, wrap=False, fontsize=20)
    ax.scatter(t_values, s_values, s=400)
    ax.set_xlabel("$t$", fontsize=30)
    ax.set_ylabel("$s(t)$", fontsize=30)
    ax.tick_params(axis='both', which='both', labelsize=20)

# Must run after the axes are populated; on an empty figure it does nothing.
fig.tight_layout()


In [None]:
# Title combining l_P and s(t), with mathtext line breaks after the t^13,
# t^10 and t^6 terms; each term keeps its own sign.
# NOTE(review): `{x: 2}` is passed where ph.evaluate_s elsewhere receives the
# index-keyed coefficient dict `a_values`; the intended l_P expression is
# unclear — confirm (perhaps P itself was meant).
title = "$l_P = {}$\n$s(t) = {}$".format(
    LatexPrinter().doprint(ph.evaluate_s(sympy.symbols("t"), {x: 2})),
    LatexPrinter().doprint(ph.evaluate_s(sympy.symbols("t"), a_values)),
) + "\n"
for exponent in (13, 10, 6):
    for sign in ("+", "-"):
        term = "t^{{{}}} {}".format(exponent, sign)
        title = title.replace(term, term + " $\n$")