# The Chain Rule

### Introduction

At this point, we have used gradient descent to find the minimum of a cost curve.  But we have skipped over the steps of finding the derivative for our cost curve.  In our Applying Gradient Descent Lab, for example, we gave you the derivatives without calculating them ourselves.

Why did we do this?

Remember that with gradient descent, we find how our function's output changes as we alter each parameter, and then take a step in the direction of steepest descent. But if we use this technique to see how changing our parameters changes the output of our cost function, we'll see that the impact on the output is more indirect.

### Seeing the issue

In previous lessons, we found the rate of change of a function's output as we changed its parameters, for a function like the following:

$$z(w,b) = 3w + b$$

To find how the output of $z$ changes as we change each parameter, we take the partial derivative with respect to each parameter, here $w$ and $b$:

$$\frac{\partial z}{\partial w} = 3, \quad \frac{\partial z}{\partial b} = 1$$
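We can sanity-check these partial derivatives numerically: nudge one parameter by a tiny amount while holding the other fixed, and divide the change in $z$ by the size of the nudge. This is a minimal sketch; the helper name `partial` and the step size `eps` are our own illustrative choices, not part of any library.

```python
def z(w, b):
    """The linear function z(w, b) = 3w + b."""
    return 3 * w + b


def partial(func, w, b, wrt, eps=1e-6):
    """Estimate a partial derivative of func at (w, b) with a central difference.

    wrt is "w" or "b": the parameter we nudge while holding the other fixed.
    """
    if wrt == "w":
        return (func(w + eps, b) - func(w - eps, b)) / (2 * eps)
    return (func(w, b + eps) - func(w, b - eps)) / (2 * eps)


# Because z is linear, the slopes are the same at every point we pick.
for w, b in [(0.0, 0.0), (2.0, 5.0), (-1.5, 4.0)]:
    print(round(partial(z, w, b, "w"), 6), round(partial(z, w, b, "b"), 6))
```

Each line should print `3.0 1.0`, matching $\frac{\partial z}{\partial w} = 3$ and $\frac{\partial z}{\partial b} = 1$.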

But when trying to find the parameters that minimize a *cost curve $J$*, we are not finding the impact of altering the parameters on that same function $z$, but on *another* function: our cost function, here the squared error for a single observation:

$$J(w,b) = (y - z(w, b))^2 = (y - (3w + b))^2$$

So calculating how a change to $w$ or $b$ impacts our cost function is more complicated than in our first function.
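To see the difference, we can estimate the slope of $J$ with respect to $w$ at a few points. This is a minimal sketch, assuming a single observation with target `y = 10` (a value picked only for illustration). Unlike $\frac{\partial z}{\partial w}$, which was always 3, the slope of $J$ changes as $w$ and $b$ change.

```python
Y = 10.0  # hypothetical target value for a single observation


def J(w, b, y=Y):
    """Squared error for one observation: (y - (3w + b))^2."""
    return (y - (3 * w + b)) ** 2


def dJ_dw_estimate(w, b, eps=1e-6):
    """Central-difference estimate of the partial derivative of J with respect to w."""
    return (J(w + eps, b) - J(w - eps, b)) / (2 * eps)


for w, b in [(1.0, 1.0), (2.0, 1.0), (3.0, 1.0)]:
    print(f"w={w}, b={b}: slope of J wrt w is about {dJ_dw_estimate(w, b):.3f}")
```

The slope is steep when the prediction $3w + b$ is far from `y` and flattens toward 0 as the prediction approaches `y`.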

And really it's more indirect than even that, because $z(w, b)$ is just the linear component, which is then fed to the activation function $\sigma$ to make the prediction, which is then passed into our cost function:

$$J(w,b) = (y - \sigma(z(w, b)))^2$$

So as we can see, the impact of changing our parameters, $w$ and $b$, on the output of our cost curve $J$ is, well, complicated.
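Writing the nesting out as code makes the chain of functions explicit. In this sketch we assume $\sigma$ is the sigmoid function, $\sigma(z) = \frac{1}{1 + e^{-z}}$; the lesson has not fixed a particular activation here, so treat that choice (and the sample values) as an illustration only.

```python
import math


def z(w, b):
    """Linear component: z(w, b) = 3w + b."""
    return 3 * w + b


def sigma(z_value):
    """Activation, assumed here to be the sigmoid 1 / (1 + e^(-z))."""
    return 1 / (1 + math.exp(-z_value))


def J(w, b, y):
    """Cost: the parameters pass through z, then sigma, then the squared error."""
    prediction = sigma(z(w, b))
    return (y - prediction) ** 2


# w and b never touch J directly; they only reach it through z and sigma.
print(J(0.5, -1.0, y=1))
```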

But don't worry, mathematicians have already figured out how to solve problems like the one above.  We just have to learn their approach.

### Introducing the Chain Rule

The problem that we are running into above is how to find the derivative of this nested function:

$$J(w,b) = (y - \sigma(z(w, b)))^2$$

By a nested function, we mean a function where our parameters $w$ and $b$ are nested inside of a function $z$ whose output is fed to a function $\sigma$ whose output is fed as a component to $J$.

Mathematicians call these functions [composite functions](https://en.wikipedia.org/wiki/Function_composition).  And lucky for us, they have developed a technique for calculating derivatives of composite functions.  That technique is the chain rule.

But that composite function $J(w,b)$ is a little too difficult to be our starting point.  

Instead, let's learn about the chain rule by finding the derivative of a simpler composite function. Then later, we'll use our new knowledge about the chain rule to find the derivative of our cost function $J(w, b)$.

This is our simpler composite function:

$$f(x) = (3x + 1)^2$$

#### 1. Break it down

* *In Math*

Do you see how this is a composite function?  We start with the function:

$$f(x) = (3x + 1)^2$$

And then we break this into two functions, $h(x)$ and $g(y)$ where:

$$h(x) = 3x + 1$$
$$ g(y) = y^2$$

So now think about what $g(h(x))$ equals.  

> Well $g(h(x)) = g(3x + 1) = (3x + 1)^2$

So given our functions $h(x)$ and $g(y)$ above, we can rewrite our function, $f(x)$ as:

$$f(x) = g(h(x)) $$ 

So we broke our function $f(x)$ down above, by defining two functions $h(x)$ and $g(y)$, and then passing the output of $h(x)$ into $g(y)$.

* *In Code* 

This idea of breaking down a composite function into its constituent parts is critical to understanding the chain rule. So let's also see how we can do this using some Python.

First, let's write the function $f(x) = (3x + 1)^2$ without breaking it down:

```python
def f(x):
    return (3*x + 1)**2
```

And then we can break this function into two components $h(x) = 3x + 1$ and $g(y) = y^2$.

```python
def h(x):
    # inner function: h(x) = 3x + 1
    output = 3*x + 1
    print(f"h({x}) = ", output)
    return output

def g(y):
    # outer function: g(y) = y^2
    output = y**2
    print(f"g({y}) = ", output)
    return output
```

And now we rewrite $f(x)$ as $f(x) = g(h(x))$.  

```python
# def f(x): return (3*x + 1)**2 is equivalent to:
def f(x):
    return g(h(x))
```

Ok, now let's try it out! 

```python
f(2)
```

```
h(2) =  7
g(7) =  49
49
```

So as we can see, we can break down a complicated function into its components and then rewrite the function. Once we do that, finding the derivative of this composite function becomes more manageable.

#### 2. Finding the derivative

So far, we have rewritten our function $f(x) = (3x + 1)^2$ as:

$$f(x) = g(h(x))$$

where:

$$h(x) = 3x + 1$$
$$g(h) = h^2$$

To find the derivative of $f(x)$ with respect to $x$, we apply the chain rule.


> **The chain rule:** take the derivative of the outer function $g'(h(x))$ and multiply it by the derivative of the inner function $h'(x)$.

$f'(x) = g'(h(x)) * h'(x)$

Confused?  

Good, me too.

Let's see this in practice by finding the derivative of our composite function:

$$f(x) = (3x + 1)^2$$

We know this can be rewritten as:

$$h(x) = 3x + 1$$
$$g(h) = h^2$$

$$f(x) = g(h(x)) $$

Now we find the derivative of the outer function, $g(h(x))$, and the derivative of the inner function, $h(x)$.

* Outer function:
    * $g(h(x)) = h(x)^2$, so
    * $g'(h(x)) = 2h(x)$


* Inner function:
    * $h(x) = 3x + 1$, so
    * $h'(x) = 3$

We found that the derivative of the outer function is $g'(h(x)) = 2h(x)$ and the derivative of the inner function is $h'(x) = 3$. Now it's time to apply:

> **The chain rule:** Multiply the derivative of the outer function by the derivative of the inner function.


Or $f'(x) = g'(h(x)) * h'(x)$.

So substituting we get $f'(x) = 2h(x)*h'(x) = 2h(x)*3 = 6*h(x)$

And because $h(x) = (3x + 1)$, substituting further we get: 

$f'(x) = 6h(x) = 6(3x + 1) = 18x + 6 $
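We can mirror these steps in Python. This sketch is our own (the names `h_prime`, `g_prime`, and `f_prime` are just for illustration, and this `h` skips the print statements): it builds $f'(x)$ from the chain rule as $g'(h(x)) \cdot h'(x)$ and compares the result with $18x + 6$ at a few points.

```python
def h(x):
    """Inner function: h(x) = 3x + 1."""
    return 3 * x + 1


def h_prime(x):
    """Derivative of the inner function: h'(x) = 3."""
    return 3


def g_prime(y):
    """Derivative of the outer function g(y) = y^2: g'(y) = 2y."""
    return 2 * y


def f_prime(x):
    """Chain rule: evaluate the outer derivative at h(x), then multiply by h'(x)."""
    return g_prime(h(x)) * h_prime(x)


# Columns: x, chain-rule result, 18x + 6. The last two columns match: 6, 24, 42, 60.
for x in [0, 1, 2, 3]:
    print(x, f_prime(x), 18 * x + 6)
```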

### Wrapping Up

So what we just did is pretty cool. We were able to calculate how nudging the value of $x$ impacts the composite function $f(x) = (3x + 1)^2$.

And found that: 

$f'(x) = 18x + 6 $

Now let's make sure we can interpret what we found.  We calculated how the output of $f(x)$ changes as we nudge our value of $x$ at different points. 

So, for example, when $x = 3$, the rate of change of our function with respect to $x$ is $f'(3) = 18 \cdot 3 + 6 = 60$.

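We can check this result with our code. We compute $f$ at $x = 3$ and at a slightly larger $x = 3.1$, then divide the change in output by the change in input. Because a nudge of 0.1 is not infinitesimally small, we should expect an answer close to, but not exactly, 60.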

```python
f(3)
```

```
h(3) =  10
g(10) =  100
100
```

```python
f(3.1)
```

```
h(3.1) =  10.3
g(10.3) =  106.09000000000002
106.09000000000002
```

```python
# approximate df/dx as (f(x2) - f(x1)) / (x2 - x1), with x1 = 3 and x2 = 3.1
(106.09 - 100)/.1
```

```
60.900000000000034
```
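Because 0.1 is a fairly large nudge, our estimate overshoots 60 a little. The quick sketch below repeats the calculation with smaller and smaller nudges and shows the estimate closing in on the value from the chain rule. To keep the output short, it uses a version of $f$ without the print statements.

```python
def f(x):
    """f(x) = (3x + 1)^2, written without print statements."""
    return (3 * x + 1) ** 2


x = 3
for step in [0.1, 0.01, 0.001, 0.0001]:
    slope = (f(x + step) - f(x)) / step
    print(f"step = {step}: estimated slope = {slope:.4f}")
```

Expanding the algebra shows why: $\frac{f(3 + s) - f(3)}{s} = \frac{60s + 9s^2}{s} = 60 + 9s$, so the estimates are about 60.9, 60.09, 60.009, and 60.0009, and the error shrinks in proportion to the step $s$.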

Being able to find the derivative of a composite function with the chain rule is critical, because the cost function for a neuron (and indeed for a neural network) is one big composite function:

$$J(w,b) = (y - \sigma(z(w, b)))^2$$

And we will need to calculate the instantaneous rate of change in our cost function as we change our parameters (that is, our weights and our bias). As we know, we'll use this instantaneous rate of change to determine how to update our parameters so that we can find the parameters that minimize our cost function $J$.
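As a preview, here is a sketch that applies the chain rule to the simpler cost $J(w, b) = (y - (3w + b))^2$, leaving out $\sigma$. Treating $u^2$ as the outer function of $u = y - z$ with $z = 3w + b$, the chain rule gives $\frac{\partial J}{\partial w} = 2(y - z) \cdot (-3)$ and $\frac{\partial J}{\partial b} = 2(y - z) \cdot (-1)$. The target `y`, the starting parameters, and the learning rate below are made up for illustration, and the update is one plain gradient-descent step.

```python
def gradients(w, b, y):
    """Chain-rule partial derivatives of J(w, b) = (y - (3w + b))^2."""
    z = 3 * w + b           # inner function
    outer = 2 * (y - z)     # derivative of u^2, evaluated at u = y - z
    dJ_dw = outer * -3      # d(y - z)/dw = -3
    dJ_db = outer * -1      # d(y - z)/db = -1
    return dJ_dw, dJ_db


def cost(w, b, y):
    """Squared error for one observation."""
    return (y - (3 * w + b)) ** 2


y = 10.0            # hypothetical target
w, b = 1.0, 1.0     # hypothetical starting parameters
learning_rate = 0.01

dJ_dw, dJ_db = gradients(w, b, y)
print("gradients:", dJ_dw, dJ_db)
print("cost before:", cost(w, b, y))

# One gradient-descent step: move each parameter against its gradient.
w -= learning_rate * dJ_dw
b -= learning_rate * dJ_db
print("cost after:", cost(w, b, y))
```

Both gradients are negative here because the prediction is below `y`, so the step increases $w$ and $b$ and the cost drops.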

<center>
<a href="https://www.jigsawlabs.io/free" style="position: center"><img src="https://storage.cloud.google.com/curriculum-assets/curriculum-assets.nosync/mom-files/jigsaw-labs.png" width="15%" style="text-align: center"></a>
</center>