# Function differentiation

If you have a Python function `f(x)` that evaluates the mathematical function $f$ at $x$, then we would like to find a function $\nabla f$ so that we can evaluate $\nabla f(x)$.

Let's look at naive, symbolic and automatic differentiation:

- Finite difference
- SymPy
- Dual numbers, `autograd` and JAX

💡 There's a TODO at the end of the notebook.

## Examples

The two example functions are taken from [the Jax docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).

First, we look at $\tanh x$, whose derivative is famously convenient to calculate: $1 - \tanh^2 x$. 

Second, we look at a polynomial function. The derivatives of $f(x) = x^3 + 2x^2 - 3x + 1$ are:

$$
\begin{array}{l}
f'(x) = 3x^2 + 4x -3\\
f''(x) = 6x + 4\\
f'''(x) = 6\\
f^{iv}(x) = 0
\end{array}
$$
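As a check on the arithmetic, here are the function and its first two derivatives evaluated at $x = 2$ and $x = 1$:

$$
\begin{array}{l}
f(2) = 8 + 8 - 6 + 1 = 11\\
f'(2) = 12 + 8 - 3 = 17\\
f''(2) = 12 + 4 = 16\\
f'(1) = 3 + 4 - 3 = 4
\end{array}
$$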

## Analytic derivative

In [1]:
import numpy as np

x = 2.0

f = lambda x: np.tanh(x)
f(x)

0.9640275800758169

We might know or [look up](https://en.wikipedia.org/wiki/Hyperbolic_functions#Derivatives) that the derivative is $1 - \tanh^2 x$, so we can evaluate it directly:

In [2]:
dfdx = 1 - f(x)**2
dfdx

0.07065082485316443

## Finite difference

### One-sided

We can approximate the analytic result with a finite difference. With $\epsilon = 10^{-6}$, it agrees with the analytic result to about six decimal places:

$$
f'(x) \approx \frac{f(x + \epsilon) - f(x)}{\epsilon}
$$

In [3]:
ϵ = 1e-6
dfdx = (f(x + ϵ) - f(x)) / ϵ
dfdx

0.07065075680046107
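Why only about six decimal places? A Taylor expansion shows that the one-sided formula carries an error of roughly $\frac{\epsilon}{2} f''(x)$:

$$
\frac{f(x + \epsilon) - f(x)}{\epsilon} = f'(x) + \frac{\epsilon}{2} f''(x) + O(\epsilon^2)
$$

The self-contained sketch below checks this for $\tanh$ at $x = 2$, using the known second derivative $f''(x) = -2\tanh x\,(1 - \tanh^2 x)$.

```python
import numpy as np

x = 2.0
eps = 1e-6

exact = 1 - np.tanh(x)**2                        # analytic f'(x)
second = -2 * np.tanh(x) * (1 - np.tanh(x)**2)   # analytic f''(x)

one_sided = (np.tanh(x + eps) - np.tanh(x)) / eps
actual_error = one_sided - exact
predicted_error = 0.5 * eps * second              # leading Taylor term

print(f"actual error:    {actual_error:.3e}")
print(f"predicted error: {predicted_error:.3e}")
```

Both come out at about $-6.8 \times 10^{-8}$, so almost all of the discrepancy is truncation error rather than floating-point round-off.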

### Two-sided

We can do much better with a symmetric (central) difference, whose error shrinks like $\epsilon^2$ rather than $\epsilon$. But it's still only an estimate, and floating-point round-off stops us from making $\epsilon$ arbitrarily small.

$$
f'(x) \approx \frac{f(x + \epsilon) - f(x - \epsilon)}{2 \epsilon}
$$

In [4]:
ϵ = 1e-6
dfdx = (f(x + ϵ) - f(x - ϵ)) / (2 * ϵ)
dfdx

0.07065082485713248
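Choosing $\epsilon$ is a trade-off: large steps suffer from truncation error, while very small steps suffer from round-off, because $f(x + \epsilon)$ and $f(x)$ become nearly equal and their difference loses significant digits. A small sweep makes this visible; the list of step sizes is an arbitrary choice.

```python
import numpy as np

x = 2.0
exact = 1 - np.tanh(x)**2  # analytic f'(x) for tanh

# Map each step size to (one-sided error, two-sided error).
errors = {}
for eps in [1e-2, 1e-4, 1e-6, 1e-8, 1e-10, 1e-12]:
    one_sided = (np.tanh(x + eps) - np.tanh(x)) / eps
    two_sided = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)
    errors[eps] = (abs(one_sided - exact), abs(two_sided - exact))
    print(f"eps={eps:.0e}  one-sided error={errors[eps][0]:.1e}  "
          f"two-sided error={errors[eps][1]:.1e}")
```

In double precision the error typically bottoms out for moderate steps and then grows again as $\epsilon$ shrinks further.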

## Symbolic differentiation with SymPy

Examples inspired by [this Coursera tutorial](https://github.com/greyhatguy007/Mathematics-for-Machine-Learning-and-Data-Science-Specialization-Coursera/blob/main/C2/w1/C2_W1_Lab_1_differentiation_in_python.ipynb) by Luis Serrano.

In [5]:
from sympy import tanh, symbols, diff

x = symbols('x')
f = tanh(x)
f

tanh(x)

We can evaluate it at $x = 2$:

In [6]:
f.evalf(subs={x: 2})

0.964027580075817

And find the derivative as a symbolic expression:

In [7]:
dfdx = diff(f, x)
dfdx

1 - tanh(x)**2

And evaluate this:

In [8]:
dfdx.evalf(subs={x: 2})

0.0706508248531645

For the second derivative, we can pass the order as a third argument to `diff()` (differentiating `dfdx` again would give the same result):

In [9]:
d2fdx = diff(f, x, 2)  # Second derivative.
d2fdx

2*(tanh(x)**2 - 1)*tanh(x)
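Calling `evalf()` with substitutions is fine for a single value, but to evaluate a symbolic derivative many times it is usually more convenient (and typically much faster) to turn it into an ordinary numerical function with SymPy's `lambdify()`. A minimal sketch, assuming NumPy is used as the numerical backend:

```python
import numpy as np
from sympy import symbols, tanh, diff, lambdify

x = symbols('x')
d2fdx = diff(tanh(x), x, 2)  # 2*(tanh(x)**2 - 1)*tanh(x)

# Compile the symbolic expression into a plain Python function backed by NumPy.
d2fdx_num = lambdify(x, d2fdx, modules='numpy')

d2fdx_num(2.0)
```

The result, about $-0.1362$, is the second derivative of $\tanh$ at $x = 2$.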

### Polynomial example

In [10]:
f = x**3 + 2*x**2 - 3*x + 1
f

x**3 + 2*x**2 - 3*x + 1

In [11]:
f.evalf(subs={x: 2})

11.0000000000000

In [12]:
dfdx = diff(f, x)
dfdx

3*x**2 + 4*x - 3

In [13]:
dfdx.evalf(subs={x: 2})

17.0000000000000

## Automatic differentiation with dual numbers

I learned about [dual numbers](https://en.wikipedia.org/wiki/Dual_number) and autodiff from Håvard Berland, who described the method at a company conference, in a version of [part of his PhD defense](https://www.pvv.ntnu.no/~berland/resources/autodiff-triallecture.pdf).

Dual numbers are expressions of the form $a + b\varepsilon$, where $a$ and $b$ are real numbers, and $\varepsilon$ is a symbol taken to satisfy $\varepsilon^2 = 0$ with $\varepsilon\neq 0$. Evaluating a function with dual numbers produces the derivative automatically:

$$ P(a + b\varepsilon) = P(a) + bP'(a)\varepsilon $$

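This holds for any polynomial $P$: in the Taylor expansion of $P(a + b\varepsilon)$ about $a$, every term containing $\varepsilon^2$ or a higher power vanishes. So if we choose $b = 1$, the coefficient of $\varepsilon$ is exactly the derivative $P'(a)$.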
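For example, with the polynomial $f(x) = x^3 + 2x^2 - 3x + 1$ at $a = 2$, $b = 1$, every $\varepsilon^2$ term drops out:

$$
\begin{array}{l}
(2 + \varepsilon)^2 = 4 + 4\varepsilon + \varepsilon^2 = 4 + 4\varepsilon\\
(2 + \varepsilon)^3 = (2 + \varepsilon)(4 + 4\varepsilon) = 8 + 12\varepsilon\\
f(2 + \varepsilon) = (8 + 12\varepsilon) + 2(4 + 4\varepsilon) - 3(2 + \varepsilon) + 1 = 11 + 17\varepsilon
\end{array}
$$

The real part is $f(2) = 11$ and the dual part is $f'(2) = 17$.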

Dual numbers are implemented in several Python libraries, eg [`num-dual`](https://pypi.org/project/num-dual/). However, it's not too hard to implement them ourselves.

In [14]:
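class Dual:
    """Dual number real + dual*ε, with ε**2 = 0."""
    def __init__(self, real, dual):
        self.real = real  # Ordinary value.
        self.dual = dual  # Coefficient of ε (carries the derivative).
    def __add__(self, other):
        if isinstance(other, Dual):
            # (a + bε) + (c + dε) = (a + c) + (b + d)ε
            return Dual(self.real + other.real, self.dual + other.dual)
        else:
            # Adding a plain number only changes the real part.
            return Dual(self.real + other, self.dual)
    __radd__ = __add__
    def __mul__(self, other):
        if isinstance(other, Dual):
            # (a + bε)(c + dε) = ac + (ad + bc)ε, because ε² = 0.
            # The dual part is exactly the product rule.
            return Dual(self.real * other.real,
                        self.real * other.dual + other.real * self.dual)
        else:
            # Scaling by a plain number scales both parts.
            return Dual(self.real * other, self.dual * other)
    __rmul__ = __mul__
    def __neg__(self):
        return self.__mul__(-1)
    def __sub__(self, other):
        return self + -other  # Subtraction via addition and negation.
    def __rsub__(self, other):
        return other + -self  # Handles number - Dual.
    def __repr__(self):
        return f'Dual({self.real}, {self.dual})'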

Now we define our function, using multiplication instead of exponentiation (since we have not defined `__pow__()` in our class).

In [15]:
f = lambda x: x*x*x + 2*x*x - 3*x + 1

# Evaluate with x = 2:
f(2)

11

Now for the derivative at $x = 2$. Instead of evaluating on a real number, we evaluate on the dual number $x = 2 + 1\varepsilon$:

In [16]:
x = Dual(2, 1)

f(x)

Dual(11, 17)

The dual part, 17, is the derivative.
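The class above only handles polynomials. Other elementary functions can be supported by applying the chain rule to the dual part, $g(a + b\varepsilon) = g(a) + b\,g'(a)\,\varepsilon$. A minimal, self-contained sketch for $\tanh$: the `dual_tanh` helper is hypothetical (not part of any library), and the stripped-down `Dual` class has only what this example needs.

```python
import math

class Dual:
    """Stripped-down dual number real + dual*ε, with ε**2 = 0."""
    def __init__(self, real, dual):
        self.real = real
        self.dual = dual
    def __repr__(self):
        return f'Dual({self.real}, {self.dual})'

def dual_tanh(z):
    """Hypothetical helper: tanh(a + bε) = tanh(a) + b(1 - tanh²a)ε."""
    t = math.tanh(z.real)
    return Dual(t, z.dual * (1 - t**2))

x = Dual(2.0, 1.0)
y = dual_tanh(x)             # y.dual is tanh'(2) = 1 - tanh(2)**2
z = dual_tanh(dual_tanh(x))  # chain rule applied twice: d/dx tanh(tanh(x)) at x = 2
print(y)
print(z)
```

The dual part of `y` matches the analytic value $0.0706508\ldots$ from the start of the notebook.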

## Automatic differentiation with `autograd`

[The original `autograd`](https://github.com/HIPS/autograd) from HIPS, now LIPS (at Princeton).

In [17]:
import autograd.numpy as np  # Thinly-wrapped numpy
from autograd import grad    # The only autograd function you may ever need

def tanh(x):
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)       # Obtain its gradient function
grad_tanh(1.0)               # Evaluate the gradient at x = 1.0

0.419974341614026

In [18]:
(tanh(1.0001) - tanh(0.9999)) / 0.0002  # Compare to finite differences

0.41997434264973155

Now the polynomial example. Recall that the derivative at $x = 2$ was 17:

In [20]:
df = grad(f)

df(2.0)

17.0

## Automatic differentiation with Jax

[The Jax docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) are really good. The two example functions in this notebook were taken from there.

In [21]:
import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)

x = 2.0
grad_tanh(x)

Array(0.07065082, dtype=float32, weak_type=True)

We can get the second-order derivative:

In [22]:
grad(grad(jnp.tanh))(x)

Array(-0.13621868, dtype=float32, weak_type=True)

For the polynomial:

In [23]:
f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx  = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

We expect 4.0 here:

In [24]:
dfdx(1.0)

Array(4., dtype=float32, weak_type=True)
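The higher derivatives defined above can be checked the same way. This self-contained sketch builds them in a loop instead of by hand (equivalent to the chain of `jax.grad` calls above). From the analytic derivatives we expect $f'(1) = 4$, $f''(1) = 10$, $f'''(1) = 6$ and $f^{iv}(1) = 0$:

```python
import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Apply jax.grad repeatedly to build the first four derivatives.
derivs = []
g = f
for _ in range(4):
    g = jax.grad(g)
    derivs.append(g)

values = [float(d(1.0)) for d in derivs]
print(values)
```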

## TODO

- Compute on arrays
- Add plots
- Add forward and reverse mode automatic differentiation, eg [see this](https://kenndanielso.github.io/mlrefined/blog_posts/3_Automatic_differentiation/3_4_AD_forward_mode.html)
- Also see [Automatic differentiation from scratch](https://e-dorigatti.github.io/math/deep%20learning/2020/04/07/autodiff.html) for a very nice implementation of the computational graph
- Add [`tangent`](https://github.com/google/tangent) or whatever replaced it
- Add Torch
- Add Tensorflow
- Have a look at [`yaae`](https://hackmd.io/@machine-learning/blog-post-yaae) (Yet Another Autodiff Engine)

---

&copy; Matt Hall 2024 and various original authors linked in text, original content licensed CC BY