In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

## Define the function whose roots we find and plot

In [None]:
def fxn_for_roots(x, a=1.01, b=-3.04, c=2.07):
    """Evaluate the quadratic ``a*x**2 + b*x + c`` whose roots are searched for.

    Parameters
    ----------
    x : float or numpy.ndarray
        Point(s) at which to evaluate; numpy arrays are evaluated element-wise,
        so the same function can be used for root finding and for plotting.
    a, b, c : float, optional
        Quadratic, linear and constant coefficients. The defaults are the
        coefficients of the equation studied in this notebook.

    Returns
    -------
    float or numpy.ndarray
        The polynomial value(s), with the same shape as ``x``.
    """
    # Keep the expression in this exact form so results (and therefore the
    # bisection iteration counts reported below) are unchanged by rounding.
    return a*x**2 + b*x + c

## Function to check whether the initial bracket is valid

In [None]:
def check_initial_values(f, x_min, x_max, tol):
    """Classify a starting bracket ``[x_min, x_max]`` for bisection.

    Parameters
    ----------
    f : callable
        Function of one float argument.
    x_min, x_max : float
        Endpoints of the candidate bracket.
    tol : float
        An endpoint ``x`` counts as a root when ``|f(x)| < tol``.

    Returns
    -------
    int
        1 if ``x_min`` is already a root (checked first), 2 if ``x_max`` is,
        3 if ``f`` changes sign strictly across the interval so bisection can
        proceed, and 0 otherwise (a diagnostic is printed in that case).
    """
    # Evaluate f at both ends of the bracket
    y_min = f(x_min)
    y_max = f(x_max)

    # Endpoint roots are checked BEFORE the sign-change test: an exact root
    # makes y_min*y_max == 0, which the bracket test would wrongly reject
    # as "no root found".
    # If x_min is a root, return flag == 1
    if np.fabs(y_min) < tol:
        return 1

    # If x_max is a root, return flag == 2
    if np.fabs(y_max) < tol:
        return 2

    # Require a strict sign change across the interval. Multiplying the signs
    # (-1, 0 or +1) rather than the raw values avoids y_min*y_max underflowing
    # to 0 when both values are tiny, which would reject a valid bracket.
    if np.sign(y_min) * np.sign(y_max) >= 0.0:
        print("No root found in range = ", x_min, x_max)
        s = "f(%f) = %f, f(%f) = %f" % (x_min, y_min, x_max, y_max)
        print(s)
        return 0

    # If this is reached, the interval is a valid bracket, return 3
    return 3

## Define the main function that performs the iterative bisection search

In [None]:
# This fxn uses bisection search to find root
def bisection_root_finding(f, x_min_start, x_max_start, tol, imax=10000, verbose=True):
    """Find a root of ``f`` inside ``[x_min_start, x_max_start]`` by bisection.

    Each iteration halves the bracket, keeping the half on which ``f``
    changes sign, until the midpoint satisfies ``|f(x_mid)| < tol``.

    Parameters
    ----------
    f : callable
        Function of one float argument.
    x_min_start, x_max_start : float
        Initial bracket, validated by ``check_initial_values``.
    tol : float
        Convergence threshold on ``|f(x)|``.
    imax : int, optional
        Maximum number of bisection iterations (default 10000).
    verbose : bool, optional
        If True (default), print the bracket and iteration number every step.

    Returns
    -------
    float
        ``x_min_start`` / ``x_max_start`` if ``check_initial_values`` reports
        that endpoint as a root, otherwise the midpoint at which
        ``|f(x_mid)| < tol``.

    Raises
    ------
    ValueError
        If ``check_initial_values`` rejects the initial bracket.
    StopIteration
        If no root is found within ``imax`` iterations.
    """
    # Current bracket
    x_min = x_min_start
    x_max = x_max_start

    # Check initial values (also catches lucky guesses at the endpoints)
    flag = check_initial_values(f, x_min, x_max, tol)
    if flag == 0:
        print("Error in bisection_root_finding().")
        raise ValueError('Initial values invalid', x_min, x_max)
    elif flag == 1:
        # Lucky guess
        return x_min
    elif flag == 2:
        # Lucky guess
        return x_max

    # Cache endpoint values and keep them in sync with the bracket, so f is
    # evaluated only once per iteration (at the new midpoint).
    y_min = f(x_min)
    y_max = f(x_max)

    # Iteration counter
    i = 0
    while True:
        # Midpoint and fxn value there
        x_mid = 0.5*(x_min + x_max)
        y_mid = f(x_mid)
        converged = np.fabs(y_mid) < tol

        if not converged:
            # Keep the half on which f changes sign. Comparing signs rather
            # than testing y_min*y_mid > 0 avoids the product underflowing
            # to 0 and discarding the wrong half.
            if np.sign(y_min) == np.sign(y_mid):
                x_min, y_min = x_mid, y_mid
            else:
                x_max, y_max = x_mid, y_mid

        # Count the iteration (so the printed count starts at 1)
        i += 1
        if verbose:
            print(x_min, y_min, x_max, y_max)
            print("Iteration number is: ", i)

        # Return as soon as the midpoint is a root -- even on iteration imax,
        # which must not be reported as a failure.
        if converged:
            return x_mid

        # If exceeded max # of iterations, report the final bracket and stop
        if i >= imax:
            print("Exceeded max number of iterations = ", i)
            s = "Min bracket f(%f) = %f" % (x_min, y_min)
            print(s)
            s = "Max bracket f(%f) = %f" % (x_max, y_max)
            print(s)
            s = "Mid bracket f(%f) = %f" % (x_mid, y_mid)
            print(s)
            # StopIteration is kept for backward compatibility with callers;
            # note that if raised inside a generator it becomes RuntimeError
            # (PEP 479).
            raise StopIteration('Stopping iterations after ', i)


## Perform the search and plot the results

In [None]:
# Search settings. Each bracket is chosen so that fxn_for_roots changes sign
# across it (the endpoint values are printed below), isolating one root each.
tolerance = 1.0e-6

x_min = 0.0
x_max = 1.5
x_min_1 = 1.7
x_max_1 = 3.0


def search_bracket(lo, hi):
    """Print the bracket endpoints, bisect for a root, and return (x_root, y_root)."""
    # Print initial guess
    print(lo, fxn_for_roots(lo))
    print(hi, fxn_for_roots(hi))

    root = bisection_root_finding(fxn_for_roots, lo, hi, tolerance)
    y_at_root = fxn_for_roots(root)
    print("Root found with y(%f) = %f" % (root, y_at_root))
    return root, y_at_root


x_root, y_root = search_bracket(x_min, x_max)
x_root_1, y_root_1 = search_bracket(x_min_1, x_max_1)

# Evaluate the same function used in the search, so the plot cannot drift
# out of sync with duplicated coefficients.
x = np.linspace(0, 3, 1000)
y = fxn_for_roots(x)

fig, ax = plt.subplots()

# Plot fxn and the y = 0 axis
ax.plot(x, y, label='y')
ax.axhline(0.0, color='C1', label='y = 0')

# Plot bracket endpoints: red brackets the lower root, blue the upper root
ax.plot([x_min, x_max], [fxn_for_roots(x_min), fxn_for_roots(x_max)],
        'ro', label='Bracket (lower root)')
ax.plot([x_min_1, x_max_1], [fxn_for_roots(x_min_1), fxn_for_roots(x_max_1)],
        'bo', label='Bracket (upper root)')

# Plot roots, labelled with the values actually found
ax.plot(x_root, y_root, 'go', label='Root ~ %.2f' % x_root)
ax.plot(x_root_1, y_root_1, 'ko', label='Root ~ %.2f' % x_root_1)

ax.legend()
ax.set_title('Bisection roots of fxn_for_roots')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_ylim(-0.5, 2.1)
plt.show()





## It takes 18 iterations for the lower bracket and 17 iterations for the upper bracket