In [1]:
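# Gradient descent is an iterative optimization method: it finds the inputs that minimize a function
# (e.g. the error of a model) or, run as gradient ascent, maximize one (e.g. the likelihood of the data).
# The point it converges to represents the solution to the optimization problem.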

def sum_of_squares(v):
    """Return the sum of the squared elements of v (the dot product v . v)."""
    total = 0
    for element in v:
        total += element ** 2
    return total

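# Our goal is to find the input v that minimizes (or maximizes) such functions.
# The gradient is the vector of partial derivatives; it points in the direction in which the function increases fastest.
# This step-by-step search resembles MCMC in that it iterates from a starting point, but gradient descent moves
# deterministically along the gradient rather than drawing random samples.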

def difference_quotient(f, x, h):
    """Forward-difference approximation of the derivative of f at x.

    Returns (f(x + h) - f(x)) / h, the slope of the secant line through
    (x, f(x)) and (x + h, f(x + h)).
    """
    rise = f(x + h) - f(x)
    return rise / h
# The true derivative is the limit of this quotient as h approaches 0.

def square(x):
    """Return x raised to the second power."""
    return pow(x, 2)

def derivative(x):
    """Analytic derivative of square: d/dx (x ** 2) = 2 * x."""
    return x * 2

In [3]:
import matplotlib.pyplot as plt 
def partial_difference_quotient(f,v,i,h):
    """Compute the i-th partial difference quotient of f at v.

    Builds a new point in which only coordinate i is shifted by h (all other
    coordinates are kept, v itself is not modified), then returns the
    forward-difference slope (f(shifted) - f(v)) / h.
    """
    shifted = []
    for index, value in enumerate(v):
        offset = h if index == i else 0
        shifted.append(value + offset)
    return (f(shifted) - f(v)) / h

def estimate_gradient(f,v,h=0.0001):
    """Estimate the gradient of f at v numerically.

    Returns a list holding one partial difference quotient per coordinate
    of v, each computed with the same step size h.
    """
    gradient = []
    for index, _coordinate in enumerate(v):
        gradient.append(partial_difference_quotient(f, v, index, h))
    return gradient

# Input points for the plotting example described below.
x = range(-10,10)
# NOTE(review): the plotting call previously left here,
#   plt.plot(x, map(estimate_gradient,x**3,x))
# cannot run as written: `x**3` is not defined for a range, estimate_gradient
# expects a function f and a vector v (not two scalars), and in Python 3 `map`
# returns a lazy iterator. A working equivalent for f(t) = t**3 would be:
#   plt.plot(x, [estimate_gradient(lambda w: w[0] ** 3, [x_i])[0] for x_i in x])


In [6]:
import random
import math

def step(v,direction, step_size):
    """Return the point reached by moving step_size along direction from v.

    Computes v + step_size * direction element-wise and returns a new list;
    pairing is done with zip, so the result is as long as the shorter input.
    """
    moved = []
    for coordinate, component in zip(v, direction):
        moved.append(coordinate + step_size * component)
    return moved

def sum_of_squares_gradient(v):
    """Analytic gradient of sum_of_squares: the partial derivative w.r.t. each v_i is 2 * v_i."""
    return list(map(lambda v_i: 2 * v_i, v))

# Minimize sum_of_squares by gradient descent from a random starting point.
# A seeded, local random generator makes the starting point (and therefore the
# whole trajectory) reproducible on Restart & Run All without touching the
# global `random` state.
SEED = 0
STEP_SIZE = -0.01          # negative: move against the gradient (descent)
MAX_ITERATIONS = 10_000    # safety cap so a bad step size cannot loop forever

rng = random.Random(SEED)
v = [rng.randint(-10, 10) for _ in range(3)]
tolerance = 0.000001       # stop once an update moves v by less than this (Euclidean distance)
print(v)
for _ in range(MAX_ITERATIONS):
    gradient = sum_of_squares_gradient(v)
    next_v = step(v, gradient, STEP_SIZE)
    if math.dist(next_v, v) < tolerance:
        break
    v = next_v
else:
    # The loop finished without hitting `break`: the update never became small enough.
    raise RuntimeError("gradient descent did not converge within MAX_ITERATIONS steps")
print(v)

[-10, -10, 5]
[-20, -20, 10]
[-19.6, -19.6, 9.8]
[-19.208000000000002, -19.208000000000002, 9.604000000000001]
[-18.82384, -18.82384, 9.41192]
[-18.4473632, -18.4473632, 9.2236816]
[-18.078415936000003, -18.078415936000003, 9.039207968000001]
[-17.716847617280003, -17.716847617280003, 8.858423808640001]
[-17.3625106649344, -17.3625106649344, 8.6812553324672]
[-17.01526045163571, -17.01526045163571, 8.507630225817856]
[-16.674955242602998, -16.674955242602998, 8.337477621301499]
[-16.34145613775094, -16.34145613775094, 8.17072806887547]
[-16.01462701499592, -16.01462701499592, 8.00731350749796]
[-15.694334474696001, -15.694334474696001, 7.847167237348001]
[-15.380447785202081, -15.380447785202081, 7.690223892601041]
[-15.072838829498039, -15.072838829498039, 7.5364194147490196]
[-14.771382052908079, -14.771382052908079, 7.385691026454039]
[-14.475954411849917, -14.475954411849917, 7.237977205924959]
[-14.186435323612919, -14.186435323612919, 7.0932176618064595]
[-13.90270661714066, -13.