<a href="https://colab.research.google.com/github/neverneeth/gradient-descent/blob/master/Gradient_Descent(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
from sympy import Symbol, solve, lambdify

# SymPy symbol representing the polynomial variable x.
X = Symbol("x")

In [None]:
def poly_func(a, b, c, d, x):
    """Evaluate the cubic ``a*x**3 + b*x**2 + c*x + d`` at ``x``.

    ``x`` may be a scalar or a NumPy array (evaluated element-wise), as used
    when plotting over ``np.linspace``.
    """
    cubic_term = a * x ** 3
    quadratic_term = b * x ** 2
    linear_term = c * x
    return cubic_term + quadratic_term + linear_term + d

def poly_grad(a, b, c, d, x):
    """Return the derivative ``3*a*x**2 + 2*b*x + c`` of the cubic at ``x``.

    ``d`` is accepted only so the signature mirrors ``poly_func``; the
    constant term does not contribute to the derivative.
    """
    quadratic_part = 3 * a * x ** 2
    linear_part = 2 * b * x
    return quadratic_part + linear_part + c

def relu(x):
    """Rectified linear unit: element-wise ``max(0, x)`` via ``np.maximum``."""
    lower_bound = 0
    return np.maximum(lower_bound, x)

def gradient_descent(x0, a, b, c, d, lr=0.001, epochs=1000, *, clip=1e3, rectify=True):
    """Run fixed-step gradient descent on the cubic defined by ``a, b, c, d``.

    Each epoch evaluates ``poly_grad`` at the current ``x``, clips it to
    ``[-clip, clip]``, optionally passes it through ``relu`` and then takes
    the step ``x = x - lr * grad``.

    With ``rectify=True`` (the default, matching the original task) negative
    gradients are zeroed, so for a positive ``lr`` the iterate never
    increases: anything to the right of ``x0`` is unreachable. Note that this
    update descends on f(x) itself; it does not specifically target roots.

    Args:
        x0: Starting point.
        a, b, c, d: Coefficients of ``a*x**3 + b*x**2 + c*x + d``.
        lr: Step size.
        epochs: Number of update steps; ``epochs <= 0`` returns ``x0``
            unchanged instead of raising.
        clip: Symmetric bound applied to the gradient before each step, so a
            single step never moves ``x`` by more than ``lr * clip``.
        rectify: If True, apply ``relu`` to the clipped gradient (one-sided
            descent); if False, use the clipped gradient as-is.

    Returns:
        ``(x, grad, f_x)``: the final iterate; the clipped (and, if enabled,
        rectified) gradient used in the last update, or the one that would
        have been used at ``x0`` when no update ran; and ``poly_func`` at the
        final ``x``.
    """
    def step_gradient(point):
        # Gradient actually applied in an update: clipped, then optionally rectified.
        g = np.clip(poly_grad(a, b, c, d, point), -clip, clip)
        return relu(g) if rectify else g

    x = x0
    # Initialise before the loop so ``grad`` is defined even when epochs <= 0
    # (previously that case raised UnboundLocalError at the return).
    grad = step_gradient(x)
    for _ in range(epochs):
        grad = step_gradient(x)
        x = x - lr * grad
    final_value = poly_func(a, b, c, d, x)
    return x, grad, final_value


def _real_roots(a, b, c, d, tol=1e-9):
    """Return the sorted real roots of ``a*x**3 + b*x**2 + c*x + d`` as floats.

    Roots are obtained with SymPy's ``solve`` and then evaluated numerically.
    A root counts as real when its imaginary part is at most
    ``tol * max(1, |real part|)``.
    """
    x_sym = Symbol('x')
    expr = a * x_sym**3 + b * x_sym**2 + c * x_sym + d
    real = []
    for root in solve(expr, x_sym):
        # Radical (Cardano-form) roots of a cubic with three real roots can
        # carry a tiny spurious imaginary part after evaluation
        # (e.g. ``1.532 + 0.e-22*I``), so an ``is_real`` test may reject
        # genuinely real roots; compare the numeric imaginary part instead.
        value = complex(root.evalf())
        if abs(value.imag) <= tol * max(1.0, abs(value.real)):
            real.append(value.real)
    return sorted(real)


def plot_and_analyze(a, b, c, d):
    """Plot the cubic, mark its real roots and report a gradient-descent run.

    Used as the ``interact`` callback. It draws
    f(x) = a*x**3 + b*x**2 + c*x + d on [-10, 10], marks the real roots with
    red dots and prints them (to 5 decimals). It then runs
    ``gradient_descent`` from a random start drawn with
    ``np.random.uniform(-5, 5)`` and prints the start point, final x, f(x)
    and the last gradient used.
    """
    plt.figure(figsize=(8, 5))
    x_vals = np.linspace(-10, 10, 400)
    y_vals = poly_func(a, b, c, d, x_vals)

    plt.plot(x_vals, y_vals, label=f"${a:.1f}x^3 + {b:.1f}x^2 + {c:.1f}x + {d:.1f}$", color='blue')
    plt.axhline(0, color='gray', linestyle='--')
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.title("Cubic Polynomial Function")

    # Already plain Python floats, so no per-root conversion/exception handling is needed.
    real_roots = _real_roots(a, b, c, d)
    if real_roots:
        plt.plot(real_roots, [0.0] * len(real_roots), 'ro')

    plt.grid(True)
    plt.legend()
    plt.show()

    print("Real roots:")
    print([f"{r:.5f}" for r in real_roots])

    x0 = np.random.uniform(-5, 5)
    x_approx, grad, f_val = gradient_descent(x0, a, b, c, d, lr=0.001, epochs=1000)

    print("\nGradient Descent Finished")
    print(f"Initial x = {x0:.4f}")
    print(f"  Approx root x = {x_approx:.5f}, f(x) = {f_val:.5f}, grad = {grad:.5f}")
    print("P.S.: THE TASK ASKS US TO USE RELU TO SOLVE A POLYNOMIAL. THIS WORKS FOR POSITIVE ROOTS BUT LEADS TO VERY LARGE ERRORS FOR NEGATIVE ROOTS")

In [None]:
# One slider per coefficient, each on [-10, 10] in 0.1 steps; only ``a``
# starts at 1, so the initial curve is x**3.
interact(
    plot_and_analyze,
    **{
        coeff: widgets.FloatSlider(value=initial, min=-10, max=10, step=0.1)
        for coeff, initial in (("a", 1), ("b", 0), ("c", 0), ("d", 0))
    }
);

interactive(children=(FloatSlider(value=1.0, description='a', max=10.0, min=-10.0), FloatSlider(value=0.0, des…