In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def function_for_roots(x, a=1.01, b=-3.04, c=2.07):
    """Evaluate the quadratic ``a*x**2 + b*x + c`` at ``x``.

    Intended as the target function whose roots the root-finding routines in
    this notebook search for. The default coefficients reproduce the original
    hard-coded polynomial ``1.01*x**2 - 3.04*x + 2.07``, so callers that pass
    only ``x`` (e.g. a root finder calling ``f(x)``) see unchanged results.

    Parameters
    ----------
    x : float or numpy.ndarray
        Point(s) at which to evaluate the polynomial; arrays are evaluated
        elementwise.
    a, b, c : float, optional
        Quadratic, linear and constant coefficients.

    Returns
    -------
    float or numpy.ndarray
        The value of ``a*x**2 + b*x + c`` (this evaluates the polynomial; it
        does not itself compute roots).
    """
    return a*x**2 + b*x + c

In [None]:
def check_initial_values(f, x_min, x_max, tol):
    """Validate an initial bracket ``[x_min, x_max]`` for a bracketing root finder.

    Parameters
    ----------
    f : callable
        Function of one variable whose root is sought.
    x_min, x_max : float
        Endpoints of the candidate bracket.
    tol : float
        An endpoint ``x`` is accepted as a root when ``|f(x)| < tol``.

    Returns
    -------
    int
        0 -- neither endpoint is a root and ``f`` does not change sign across
             the bracket (a diagnostic message is printed);
        1 -- ``x_min`` is a root;
        2 -- ``x_max`` is a root (and ``x_min`` is not);
        3 -- ``f`` changes sign across the bracket and neither endpoint is a
             root, so a search is required.
    """
    # Evaluate the function at both initial guesses.
    y_min = f(x_min)
    y_max = f(x_max)

    # Test the endpoints for roots *before* the sign test: an exact root makes
    # y_min*y_max == 0, which the sign test alone would wrongly reject.
    if np.fabs(y_min) < tol:
        return 1
    if np.fabs(y_max) < tol:
        return 2

    # Neither endpoint is a root, so a usable bracket needs a strict sign change.
    if y_min*y_max >= 0.0:
        print("no zero crossing found in range = ", x_min, x_max)
        s = "f(%f) = %f, f(%f) = %f" % (x_min, y_min, x_max, y_max)
        print(s)
        return 0

    # Valid bracket that still needs to be searched.
    return 3


# Backward-compatible alias for the original (misspelled) name.
check_inital_values = check_initial_values


In [None]:
def bisection_root_finding(f, x_min_start, x_max_start, tol):
    
    x_min = x_min_start
    x_max = x_max_start
    x_mid = 0.0
    
    y_min = f(x_min)
    y_max = f(x_max)
    y_mid = 0.0
    
    imax = 10000  #set max number of iterations
    i = 0.0       #iteration counter
    
    #check the initial values
    flag = check_initial_values(f, x_min, x_max, tol)
    if (flag==0):
        print("Error in bisection_root_finding().")
        raise ValueError('Initial values invalid', x_min, x_max)
    elif (flag==1): 
        #lucky guess
        return x_min
    elif (flag==2):
        #another lucky guess
        return x_max
    #if we reach here, then we need to conduct the search
    #set a flag
    flag = 1
    
    #lets do a while loop
    while(flag):
        x_mid = 0.5*(x_min + x_max)  #mid point
        y_mid = f(x_mid)
        
        #check if x_mid is a root
        if (np.fabs(y_mid)<tol):
            flag = 0
        else:
            #x_mid is not a root
            
            #if the product of the funtion at the midpoint and at 
            #one of the endpoints is greater than zero, replace this endpoint
            if (f(x_min)*f(x_mid)>0):
                #replace x_min with x_mid
                x_min = x_mid
            else:
                #replace x_max with x_mid
                x_max = x_mid
        
        #print the iteration
        print(x_min,f(x_min),x_max,f(x_max))
        
        #count the iteration
        i += 1
        
        #if we exceed the max number of iterations, exit
        if (i>=imax):
            print("Exceeded max number of iterations = ", i)
            