Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

**Make a copy of this notebook!**

In [None]:
#@title Python imports
import collections
import datetime
from functools import partial
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy import stats
import seaborn as sns

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [None]:
# Make colab plots larger
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100

# Section 1

In this section we'll cover:

* Using Jax
* Computing derivatives with Jax
* Gradient Descent (single variable)
* Newton's Method (single variable)


## Jax

[Autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)

Jax is a Python library that can automatically differentiate many Python
functions.


In [None]:
import jax
import jax.numpy as jnp
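from jax import jit, grad, vmap, random, jacfwd, jacrev
# Newer JAX releases removed `jax.api` and moved `optimizers` and `stax`
# from `jax.experimental` to `jax.example_libraries`.
from jax.example_libraries import optimizers, stax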

## Basic autodiff example

Let's take the derivative of $y = f(x) = x^2$, which we know to be $f'(x) = 2x$.



In [None]:
def square(x):
  return x * x

# Compute the derivative with grad from Jax
dsquare = grad(square)

# Plot the function and the derivative
domain = np.arange(0, 2.5, 0.1)
plt.plot(domain, square(domain), label="$y=x^2$")
plt.plot(domain, list(map(dsquare, domain)), label="$y'=\\frac{dy}{dx} = 2x$")
# Note: we can also use JAX's vmap, i.e. vmap(dsquare)(domain) instead of list(map(dsquare, domain))
plt.legend()
plt.show()

### Exercise

Write a function to compute the sigmoid

$$f(x) = \frac{1}{1 + e^{-x}}$$

using `jnp.exp` for the exponential. You need to use the Jax version of
numpy for autodiff to work.

Then verify that

$$ f'(x) = f(x) (1 - f(x))$$

For example, compare some explicit values and plot the differences between the derivative and $f(x) (1 - f(x))$.

In [None]:
## Your code here

# Compute the sigmoid
def sigmoid(x):
  pass

# Compute the derivative (using grad)

# Compare derivative values to f(x) * (1 - f(x))


In [None]:
# @title Solution (double-click to show)

# Compute the sigmoid
def sigmoid(x):
  return 1. / (1. + jnp.exp(-x))

# Compute the derivative (using grad)
deriv = grad(sigmoid)

# Compare derivative values to f(x) * (1 - f(x))
xs = np.arange(-3, 3.1, 0.1)
ys = []
for x in xs:
  ys.append(deriv(x) - sigmoid(x) * (1. - sigmoid(x)))
plt.scatter(xs, ys)
plt.title("Differences between Jax derivative and $f(x)(1-f(x))$\nNote the scale modifier in the upper left")
plt.show()

# Single variable gradient descent examples

In this example we solve $x^2=x$, which we know has two solutions. Different
initial starting points yield convergence to different solutions, or
non-convergence to either solution at $x_0 = 1/2$.

We need to turn the problem into one of finding local extrema. So we consider
the function $f(x) = \left( x^2 - x\right)^2$, which is differentiable and
has local minima at $x=0$ and $x=1$. We square so that the points where
$x^2=x$ are minima instead of zeros of a function like $g(x) = x^2 - x$.

Notice in the plot below that the function also has a local maximum at $x=1/2$,
which is centered between the two solutions. Gradient descent starting at
$x_0=1/2$ will not move because $f'(1/2) = 0$: intuitively, there's no reason
to favor either local minimum.
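
The three critical points can also be checked by hand. The sketch below assumes the derivative $f'(x) = 2(x^2 - x)(2x - 1)$, obtained from the chain rule; the name `df_exact` is introduced here only for illustration.

```python
def f(x):
    return (x**2 - x)**2

def df_exact(x):
    # Chain rule: d/dx (x^2 - x)^2 = 2 (x^2 - x) (2x - 1)
    return 2 * (x**2 - x) * (2 * x - 1)

# The derivative vanishes at both minima and at the local maximum between them
critical_points = [0.0, 0.5, 1.0]
for x in critical_points:
    print("x =", x, " f(x) =", f(x), " f'(x) =", df_exact(x))
```

Since $f(1/2) = 1/16$ is larger than the values at nearby points, $x = 1/2$ is a local maximum rather than a minimum.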

Let's plot the function first:

In [None]:
def f(x):
  return (x**2 - x)**2

xs = np.arange(-1, 2.05, 0.05)
ys = f(xs)
plt.plot(xs, ys)
plt.show()

Let's define a function to compute iterations of gradient descent.

$$\begin{eqnarray}
x_{n+1} &=& x_n - \alpha f'(x_n) \\
\end{eqnarray}$$


In [None]:
def gradient_descent(dfunc, x0, iterations=100, alpha=0.1):
  """dfunc is the derivative of the function on which we
  perform descent."""
  xs = [x0]
  for i in range(iterations):
    x = xs[-1]
    x = x - alpha * dfunc(float(x))
    xs.append(x)
  return xs

In [None]:
# Let's try it on our function f now.

# Compute the derivative
df = grad(f)

# Try some different starting points.
for x0 in [0.25, 0.5, 0.50001, 0.85]:
  xs = gradient_descent(df, x0, iterations=30, alpha=1.)
  plt.plot(range(len(xs)), xs, label=x0)
plt.xlabel("iterations")
plt.legend()
plt.show()

### Check your understanding: Explain what happened with each of the four curves in the plot.

### Exercise: What happens if we decrease the learning rate $\alpha$?
Recreate the plot above using $\alpha=0.1$ instead.

In [None]:
## Your code here


## Example 2

In this example we use gradient descent to approximate $\sqrt{3}$. We use
the function $f(x) = \left(x^2 - 3\right)^2$ and construct a sequence converging
to the positive solution. In this case notice the impact of the learning rate
both on the time to convergence and whether the convergence is monotonic or
oscillatory. Larger learning rates such as $\alpha=1$ can cause divergence.
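
To see why a large learning rate diverges here, it may help to trace a few steps by hand. This is a minimal sketch that assumes the hand-derived derivative $f'(x) = 4x(x^2 - 3)$; the helper names `df2_exact` and `descend` are hypothetical and not part of the notebook.

```python
def df2_exact(x):
    # d/dx (x^2 - 3)^2 = 4 x (x^2 - 3)
    return 4 * x * (x**2 - 3)

def descend(x0, alpha, steps):
    # Plain gradient descent: x_{n+1} = x_n - alpha * f'(x_n)
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - alpha * df2_exact(xs[-1]))
    return xs

diverging = descend(2.0, alpha=1.0, steps=3)
converging = descend(2.0, alpha=0.01, steps=200)
print("alpha = 1:   ", diverging)
print("alpha = 0.01:", converging[-1])
```

With $\alpha = 1$ the first step jumps from $2$ to $-6$, where the gradient is far larger, so each later step overshoots by more than the previous one.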

In [None]:
def f2(x):
  return (x*x - 3)**2

df = grad(f2)

x0 = 2
for alpha in [0.08, 0.01]:
  xs = gradient_descent(df, x0, iterations=40, alpha=alpha)
  plt.plot(range(len(xs)), xs, label="$\\alpha = {}$".format(alpha))
plt.xlabel("iterations")

# Plot the correct value
sqrt3 = math.pow(3, 0.5)
n = len(xs)
plt.plot(range(n), [sqrt3]*n, label="$\\sqrt{3}$", linestyle="--")

plt.legend()
plt.show()
print("Sqrt(3) =", sqrt3)

## Exercise

Solve the equation $x = e^{-x}$, which does not have an easily obtainable solution algebraically.

The solution is approximately $x = 0.567$. Again note the impact of
the learning rate.

Use `jnp.exp` for the exponential function.


In [None]:
def f3(x):
  """Define the function f(x) = (x - e^(-x))^2."""
  ## Your code here
  pass

# Compute the gradient

# Initial guess
x0 = 0.4

## Add code here for gradient descent using the functions above

## Plot the gradient descent values


In [None]:
#@title Solution (double click to expand)

def f3(x):
  """Define the function f(x) = (x - e^(-x))^2."""
  return (x - jnp.exp(-x))**2

# Compute the gradient
df = grad(f3)

# Initial guess
x0 = 0.4
for alpha in [0.01, 0.1]:
  xs = gradient_descent(df, x0, iterations=50, alpha=alpha)
  plt.plot(range(len(xs)), xs, label="$\\alpha = {}$".format(alpha))
plt.xlabel("iterations")

plt.legend()
plt.show()

print("Final iteration:", xs[-1])

# Newton's method, single variable

We can use Newton's
method to find zeros of functions. Since local extrema occur at zeros of
the derivative, we can apply Newton's method to the first derivative to obtain
a second-order alternative to gradient descent.

$$\begin{eqnarray}
x_{n+1} &=& x_n - \alpha \frac{f'(x_n)}{f''(x_n)} \\
\end{eqnarray}$$ 
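
As a hand-worked illustration of this update, the sketch below minimizes $(x^2 - 3)^2$ using derivatives computed by hand instead of with Jax. The derivatives $f'(x) = 4x^3 - 12x$ and $f''(x) = 12x^2 - 12$ and the helper names are assumptions made for this example.

```python
def df_exact(x):
    # First derivative of (x^2 - 3)^2
    return 4 * x**3 - 12 * x

def d2f_exact(x):
    # Second derivative of (x^2 - 3)^2
    return 12 * x**2 - 12

def newton_minimize(x0, iterations=10, alpha=1.0):
    # x_{n+1} = x_n - alpha * f'(x_n) / f''(x_n)
    xs = [x0]
    for _ in range(iterations):
        x = xs[-1]
        xs.append(x - alpha * df_exact(x) / d2f_exact(x))
    return xs

newton_xs = newton_minimize(2.0)
print(newton_xs[:3], "...", newton_xs[-1])
```

Note that the update divides by $f''(x_n)$, so it is undefined wherever the second derivative is zero (here at $x = \pm 1$).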


In [None]:
def newtons_method(func, x0, iterations=100, alpha=1.):
  """Find a zero of func. To minimize f, pass func = f' so that
  the update uses f'(x) / f''(x)."""
  dfunc = grad(func)
  xs = [x0]
  for i in range(iterations):
    x = xs[-1]
    x = x - alpha * func(x) / dfunc(float(x))
    xs.append(x)
  return xs

Let's repeat the example of finding the value of $\sqrt{3}$ with Newton's method and compare to gradient descent.

For small $\alpha$, gradient descent seems to perform better:

In [None]:
def f(x):
  return (x**2 - 3)**2

# Let's make a function we can reuse
def compare_gradient_newton(func, x0, alpha=0.01, iterations=50):
  # Compute the first derivative; newtons_method differentiates it again
  df = grad(func)

  # Compute Newton's method iterations
  xs = newtons_method(df, x0, alpha=alpha, iterations=iterations)

  # Compute gradient descent with same alpha
  xs2 = gradient_descent(df, x0, alpha=alpha, iterations=iterations)

  # Plot it all
  plt.plot(range(len(xs2)), xs2, label="Gradient Descent")
  plt.plot(range(len(xs)), xs, label="Newton's method")
  plt.xlabel("iterations")
  plt.legend()
  plt.title("$\\alpha = {}$".format(alpha))

  # Plot the solution
  sqrt3 = math.pow(3, 0.5)
  n = len(xs)
  plt.plot(range(n), [sqrt3]*n, label="$\\sqrt{3}$", linestyle="--")

  plt.show()

compare_gradient_newton(f, 2., alpha=0.01, iterations=50)

But for larger $\alpha$, Newton's method is better behaved and gradient descent fails to converge.

In [None]:
compare_gradient_newton(f, 2., alpha=0.1, iterations=50)

In this case we can also apply Newton's method with just the first derivative to find a zero of $x^2 - 3$, i.e. we don't have to look for a minimum of 
$(x^2 - 3)^2$ since Newton's method can also find zeros of functions.
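
For $g(x) = x^2 - 3$ with $\alpha = 1$, the Newton update simplifies to $x_{n+1} = x_n - \frac{x_n^2 - 3}{2 x_n} = \frac{1}{2}\left(x_n + \frac{3}{x_n}\right)$. A minimal sketch of that iteration with the derivative $g'(x) = 2x$ written out by hand (the function name `newton_sqrt3` is hypothetical):

```python
def newton_sqrt3(x0, iterations=6):
    # Newton's method for a zero of g(x) = x^2 - 3, using g'(x) = 2x
    xs = [x0]
    for _ in range(iterations):
        x = xs[-1]
        xs.append(x - (x**2 - 3) / (2 * x))
    return xs

root_xs = newton_sqrt3(2.0)
for i, x in enumerate(root_xs):
    print(i, x)
```

The first step moves from $2$ to $1.75$, and the number of correct digits roughly doubles with each later step.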

In [None]:
def f(x):
  return x**2 - 3

alpha = 0.5  # Use the same value in the call and in the plot label
xs = newtons_method(f, 2., alpha=alpha, iterations=10)
plt.plot(range(len(xs)), xs, label="$\\alpha = {}$".format(alpha))
plt.xlabel("iterations")

# Plot the solution
sqrt3 = math.pow(3, 0.5)
n = len(xs)
plt.plot(range(n), [sqrt3]*n, label="$\\sqrt{3}$", linestyle="--")

plt.legend()
plt.show()
print(xs[-1])

# Section 2

Now we'll look at multivariate derivatives and gradient descent,
again using Jax.

## Multivariate derivatives

$$f(x, y) = x y^2$$

$$ \nabla f = [y^2, 2 x y]$$
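
One way to sanity-check a gradient formula like this is to compare it with central finite differences. The sketch below does so at the point $(3, 1)$; the helper `central_difference_gradient` and the step size `h` are assumptions chosen for illustration.

```python
def f(x, y):
    return x * y * y

def central_difference_gradient(func, x, y, h=1e-5):
    # Approximate each partial derivative with a symmetric difference quotient
    dfdx = (func(x + h, y) - func(x - h, y)) / (2 * h)
    dfdy = (func(x, y + h) - func(x, y - h)) / (2 * h)
    return [dfdx, dfdy]

approx = central_difference_gradient(f, 3.0, 1.0)
exact = [1.0**2, 2 * 3.0 * 1.0]  # [y^2, 2xy] evaluated at (3, 1)
print("finite differences:", approx)
print("formula:           ", exact)
```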

In [None]:
def f(x, y):
  return x * y * y

# Compute the partial derivatives with grad from Jax
# Use float as arguments, else Jax will complain
print("f(3, 1)=", f(3., 1.))

# argnums specifies, by position, which argument to differentiate with respect to
print("Partial x derivative at (3, 1):", grad(f, argnums=0)(3., 1.))
print("Partial y derivative at (3, 1):", grad(f, argnums=1)(3., 1.))

# We can get both partials at the same time
print("Gradient vector at (3, 1):", grad(f, (0, 1))(3., 1.))
g = [float(z) for z in grad(f, (0, 1))(3., 1.)]
print("Gradient vector at (3, 1):", g)

We can plot some of the vectors of a gradient. Let's consider
$$f(x, y) = x^2 + y^2$$
The partial derivatives are
$$\frac{\partial f}{\partial x} = 2x$$

$$\frac{\partial f}{\partial y} = 2y$$

So the gradient is $$\nabla f = [2x, 2y]^T$$
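
The gradient points in the direction of steepest increase, so stepping against it moves toward the minimum at the origin. A minimal sketch of gradient descent in two variables using the formula above (the function names, starting point, and $\alpha = 0.25$ are choices made for this example):

```python
def grad_f(x, y):
    # Gradient of f(x, y) = x^2 + y^2
    return 2 * x, 2 * y

def descend_2d(x, y, alpha=0.25, steps=3):
    # (x, y) <- (x, y) - alpha * grad f(x, y)
    path = [(x, y)]
    for _ in range(steps):
        gx, gy = grad_f(x, y)
        x, y = x - alpha * gx, y - alpha * gy
        path.append((x, y))
    return path

path = descend_2d(1.0, -1.0)
print(path)
```

For this function each step rescales the point by $1 - 2\alpha$, so with $\alpha = 0.25$ the distance to the origin halves at every iteration.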

In [None]:
def f(x, y):
  return x * x + y * y

partial_x = grad(f, argnums=0)
partial_y = grad(f, argnums=1)

xs = np.arange(-1, 1.25, 0.25)
ys = np.arange(-1, 1.25, 0.25)

plt.clf()

# Compute and plot the gradient vectors
for x in xs:
  for y in ys:
    u = partial_x(x, y)
    v = partial_y(x, y)
    plt.arrow(x, y, u, v, length_includes_head=True,
              head_width=0.1)
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.show()

## Jacobians and Hessians with Jax

Let's verify some of the examples from the slides.

Let $f(x) = \sum_i{x_i} = x_1 + \cdots + x_n$. Then the gradient is $\nabla f (x) = [1, \ldots, 1]^T.$

In [None]:
n = 4
def f(x):
  return sum(x)

test_point = [1., 2., 3., 4.]

print("x = ", test_point)
print("Gradient(x):", [float(x) for x in grad(f)(test_point)])

## Try other test points, even random ones:
test_point = np.random.rand(4)
print()
print("x = ", test_point)
print("Gradient(x):", [float(x) for x in grad(f)(test_point)])


### Exercise

Compute the gradient of the function that sums the squares of the elements of a vector.

In [None]:
# @title Solution (double click to show)
def sum_squares(x):
  return jnp.dot(x, x)

test_point = np.array([1., 2., 3., 4.])

print("x = ", test_point)
print("Gradient(x):", [float(x) for x in grad(sum_squares)(test_point)])

Let's try a Jacobian now. In the slides we saw that the Jacobian of $f(\mathbf{x}) = Ax$ is $A$. Let's verify with Jax.





In [None]:
A = np.array([[1., 2.], [3., 4.]])

def f(x):
  return jnp.dot(A, x)

x = np.array([1., 1.])
jacfwd(f)(x)

In [None]:
# Try some other matrices A and a 3x3 matrix
# Note that Jax handles larger matrices with the same code
# But you'll need a length 3 vector for x


We can compute the Hessian by taking the Jacobian twice. In this case, $f(\mathbf{x}) = Ax$ is a linear function, so the second derivatives should all be zero.
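
For contrast with the linear case, here is a minimal sketch on a hypothetical nonlinear scalar function $g(\mathbf{x}) = x_0^2 x_1$ (not from the slides), whose Hessian is $\begin{bmatrix} 2x_1 & 2x_0 \\ 2x_0 & 0 \end{bmatrix}$. It also compares the composed Jacobians with JAX's built-in `jax.hessian` helper.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued example: g(x) = x_0^2 * x_1
    return x[0] ** 2 * x[1]

point = jnp.array([1.0, 2.0])

# Hessian as the Jacobian of the gradient (forward-over-reverse)
hess_composed = jax.jacfwd(jax.jacrev(g))(point)

# The same quantity via the built-in helper
hess_builtin = jax.hessian(g)(point)

# Hand-computed Hessian [[2*x_1, 2*x_0], [2*x_0, 0]] at (1, 2)
expected = jnp.array([[4.0, 2.0], [2.0, 0.0]])
print(hess_composed)
print(hess_builtin)
```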

In [None]:
A = np.array([[1., 2.], [3., 4.]])

def f(x):
  return jnp.dot(A, x)

def hessian(f):
  return jacfwd(jacrev(f))

x = np.array([1., 1.])
hessian(f)(x)

In [None]:
# Try some other matrices A and a 3x3 matrix




### Exercise

Now try to take derivatives of the function $f(\mathbf{x}) = x \cdot A x$. Hints:
* Is it scalar or vector-valued?
* Does it matter if $A$ is symmetric $(A = A^T)$ or anti-symmetric $(A = -A^T)$?

In [None]:
# Your code here

def f(x):
  """Compute x . A x"""
  pass

# Compute the first and second derivatives at a test point x




In [None]:
#@title Solution (double-click to show)

A = np.array([[1., 0.], [0., 3.]])

def f(x):
  return jnp.dot(x, jnp.dot(A, x))

x = np.array([2., -1.])
print(grad(f)(x))
print(hessian(f)(x))

### Exercise: Entropy

Compute the Gradient and Hessian of the Shannon entropy:
$$S = -\sum_{i}{x_i \log x_i}$$

Note that for a test point you'll need to have all the elements positive and summing to 1, so a good choice is $$[1 / n, \ldots, 1 / n]$$


In [None]:
## Your code here

def entropy(x):
  pass

In [None]:
#@title Solution (double-click to show)

def entropy(x):
  return - sum(a * jnp.log(a) for a in x)  

x = np.array([1./2, 1./2])
print(entropy(x))
print(grad(entropy)(x))
print(hessian(entropy)(x))

## Example: Linear Regression with Jax

Given some data of the form $(x_i, y_i)$, let's find a best fit line $$ y = m x + b $$ by minimizing the sum of squared errors.

$$ S = \sum_{i}{\left(y_i - (m x_i + b) \right)^2}$$
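
For a single feature, this minimization also has a closed-form solution, $m = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2}$ and $b = \bar{y} - m \bar{x}$, which is handy for checking a gradient descent fit. A minimal sketch on hypothetical toy data (the function name and the data are made up for this example):

```python
def least_squares_line(xs, ys):
    # Closed-form simple linear regression
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    m = sxy / sxx
    b = y_mean - m * x_mean
    return m, b

# Hypothetical toy data lying exactly on the line y = 2x + 1
toy_x = [0.0, 1.0, 2.0, 3.0]
toy_y = [1.0, 3.0, 5.0, 7.0]
m, b = least_squares_line(toy_x, toy_y)
print("m =", m, " b =", b)
```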


In [None]:
## Adapted from the coax docs: https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html

# Generate some data using sklearn
X, y = make_regression(n_features=1, noise=10)
X, X_test, y, y_test = train_test_split(X, y)

# Plot the data
plt.scatter([x[0] for x in X], y)
plt.title("Randomly generated dataset")
plt.show()


Read through the following code, which minimizes the sum of squared errors for a linear model.

In [None]:
# In JAX, we can specify our parameters as various kinds of Python objects,
# including dictionaries.

# Initial model parameters
params = {
    'w': jnp.zeros(X.shape[1:]),
    'b': 0.
}

# The model function itself, a linear function.
def forward(params, X):
  """y = w x + b"""
  return jnp.dot(X, params['w']) + params['b']

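# The loss function we want to minimize: the squared errors of the model
# prediction versus the true values, averaged over the data points.
# Taking the mean instead of the sum only rescales the loss by 1/n,
# so the minimizing parameters are the same.
def sse(params, X, y):
  """Mean of squared errors (the sum of squared errors divided by n)"""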
  err = forward(params, X) - y
  return jnp.mean(jnp.square(err))

# Function to update our parameters in each step of gradient descent
def update(params, grads, alpha=0.1):
  # jax.tree_multimap was removed from JAX; tree_map accepts multiple trees
  return jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)

# We'll define a gradient descent function similarly to as before,
# and we'll track the loss function values for plotting
# Note also that we compute our gradients on the training data X and y
# but our loss function on the test data X_test and y_test
def gradient_descent(f, params, X, X_test, y, y_test, alpha=0.1, iterations=30):
  """
  Apply gradient descent to the loss function f, starting from params,
  with learning rate alpha. Gradients use the training data (X, y).

  params_{n+1} = params_n - alpha * grad_f(params_n)
  
  """
  grad_fn = grad(f)
  params_ = []
  losses = []

  for _ in range(iterations):
    grads = grad_fn(params, X, y)
    params = update(params, grads, alpha)
    params_.append(params)

    loss = f(params, X_test, y_test)
    losses.append(loss)

  return params_, losses

# Function to plot our residuals to see how the loss function progresses
def plot_residuals(params, model_fn, X, y, color='blue'):
  res = y - model_fn(params, X)
  plt.hist(res, bins=10, color=color, alpha=0.5)

# Find the best fit line
fit_params, losses = gradient_descent(sse, params, X, X_test, y, y_test)

# Plot the decrease in the loss function over iterations.
plt.plot(range(len(losses)), losses)
plt.ylabel("SSE")
plt.xlabel("Iteration")
plt.title("Loss evolution\nMinimizing sum of squared errors")
plt.show()

# Compare the errors of our initial guess model with the final model
# Note the differences in scales
plot_residuals(params, forward, X, y)
plt.title("Histogram of initial residuals (errors for each point)")
plt.show()

plot_residuals(fit_params[-1], forward, X, y, color='green')
plt.title("Histogram of final residuals (errors for each point)")
plt.show()


In [None]:
# Let's plot the best fit line

# Plot the data
xs = [x[0] for x in X]
plt.scatter(xs, y)
plt.title("Best fit line")

# Plot the best fit line
params = fit_params[-1]
xs = np.arange(min(xs), max(xs), 0.1)
m = float(params['w'][0])  # 'w' has shape (1,), so take its single entry
b = float(params['b'])
ys = [m * x + b for x in xs] 

plt.plot(xs, ys, color='black')
plt.show()


We can easily use another loss function, like the mean absolute error, where we use the absolute value of the residuals instead of the square, which reduces the
impact of outliers.

$$ S = \sum_{i}{\left|y_i - (m x_i + b) \right|}$$

This will give us a different best fit line for some data sets.
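
The effect of outliers is easiest to see for a constant model $y = b$: the squared error is minimized by the mean of the data, while the absolute error is minimized by the median. A minimal sketch with hypothetical data containing one outlier:

```python
import statistics

# Hypothetical data with a single outlier
values = [1.0, 2.0, 3.0, 4.0, 100.0]

def sum_abs_error(c):
    return sum(abs(v - c) for v in values)

def sum_sq_error(c):
    return sum((v - c) ** 2 for v in values)

mean_fit = statistics.mean(values)      # minimizes the squared error
median_fit = statistics.median(values)  # minimizes the absolute error

print("mean:  ", mean_fit, " median:", median_fit)
print("absolute error at mean / median:", sum_abs_error(mean_fit), sum_abs_error(median_fit))
print("squared error at mean / median: ", sum_sq_error(mean_fit), sum_sq_error(median_fit))
```

The outlier drags the mean far from the bulk of the data, while the median stays among the typical values.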

In [None]:
## Minimize MAE instead of SSE

# MAE is the p=1 case of this function.
def lp_norm(p=2):
  def norm(params, X, y):
    err = forward(params, X) - y
    return jnp.linalg.norm(err, ord=p)
  return norm

# Generate some noisier data
X, y = make_regression(n_features=1, noise=100, bias=5)
X, X_test, y, y_test = train_test_split(X, y)

fit_params, losses = gradient_descent(lp_norm(p=1.), params, X, X_test, y, y_test)

# Plot the data
xs = [x[0] for x in X]
plt.scatter(xs, y)
plt.title("Best fit line")

# Plot the best fit line
params = fit_params[-1]
xs = np.arange(min(xs), max(xs), 0.1)
m = float(params['w'][0])  # 'w' has shape (1,), so take its single entry
b = float(params['b'])
ys = [m*x+b for x in xs] 
plt.plot(xs, ys, color='black', label="MAE")

# Compare to SSE best fit line

# Find the best fit line
fit_params, losses = gradient_descent(sse, params, X, X_test, y, y_test)

# Plot the best fit line
params = fit_params[-1]
xs = np.arange(min(xs), max(xs), 0.1)
m = float(params['w'][0])  # 'w' has shape (1,), so take its single entry
b = float(params['b'])
ys = [m*x+b for x in xs] 
plt.plot(xs, ys, color='green', label="SSE")
plt.legend()
plt.show()
