## How to solve a^2 + a^3 = 392 by cost function?

* assume the y = f(x) = theta^2 + theta^3 - 392
* while t = 0, cost(x,t) = (t-y)^2 = y^2 = f(x)^2
* grad(cost(x,t)) = 2 * f(x) * grad(f(x)) = 2 * f(x)*(2 * theta + 3 * theta^2)
* theta(n) = theta(n-1) - learning rate * 2 * fn-1(x) * (2 * theta(n-1) + 3 * theta(n-1)^2)

In [15]:
import random

def f(a):
    return a**2 + a**3 - 392

def cost(a):
    return f(a)**2

def grad(a):
    return 2*f(a)*((3*a)**2 + 2*a)

def solve(eps, lr, max_iterations = 1000, patience = 1, decay = 0.5):
    a = random.random()
    i = 1
    print('initial value :', a)
    print('initial learning rate:', lr)
    curr_cost = cost(a)
    d_cost = abs(curr_cost)
    prev_costs = []
    while d_cost > eps and max_iterations > 0:
        a -= grad(a) * lr
        print('iteration {}: {}'.format(i,a))
        max_iterations -= 1
        
        if len(prev_costs) > 0 and curr_cost < prev_costs[-1]:
            prev_costs.clear()
        elif len(prev_costs) == patience and curr_cost > prev_costs[-1]:
            lr *= decay
            print('new learning rate: ',lr)
            prev_costs.clear()
            
        prev_cost, curr_cost = curr_cost, cost(a)
        d_cost = abs(curr_cost - prev_cost)
        prev_costs.insert(0, prev_cost)
        i += 1
    return a if max_iterations > 0 else None

eps = 0.0001
lr = 0.00003
print('Solution: ', solve(eps,lr))

initial value : 0.8058571494440805
initial learning rate: 3e-05
iteration 1: 0.9807062178816265
iteration 2: 1.229215635088834
iteration 3: 1.6036351930513966
iteration 4: 2.212849558338123
iteration 5: 3.3076986793452003
iteration 6: 5.48210210578597
iteration 7: 8.81200141673078
iteration 8: -7.090355602292746
new learning rate:  1.5e-05
iteration 9: 2.089560748471028
new learning rate:  7.5e-06
iteration 10: 2.336399518529771
iteration 11: 2.638055266144595
iteration 12: 3.0115761491418995
iteration 13: 3.4791203088256073
iteration 14: 4.0663403936148
iteration 15: 4.791979271507468
iteration 16: 5.632111344996588
iteration 17: 6.440570205190652
iteration 18: 6.923475848472506
iteration 19: 7.0049045987148935
iteration 20: 6.999504246895627
iteration 21: 7.000048879933854
iteration 22: 6.999995168343219
Solution:  6.999995168343219
