# Automatic differentiation tutorial


This hidden block defines some utility functions and imports packages. You should run it before you execute any of the cells below.

> Indented block



In [0]:
#@title
import math
import unittest

def forward_primitive(func):
  """Overload primitives.

  This makes sure that if a function is called on a regular
  value, the original primitive is called. However, if it
  is called on a dual number, the dual number primitive is
  called instead.

  """
  def wrap(dual_func):
    def overloaded_func(x):
      if isinstance(x, Dual):
        return dual_func(x)
      else:
        return func(x)
    return overloaded_func
  return wrap

def backward_primitive(func):
  """Overload primitives.

  This makes sure that if a function is called on a regular
  value, the original primitive is called. However, if it
  is called on a Float, the Float number primitive is
  called instead.

  """
  def wrap(float_func):
    def overloaded_func(x):
      if isinstance(x, Float):
        return float_func(x)
      else:
        return func(x)
    return overloaded_func
  return wrap


# Forward mode using operator overloading

As discussed in class, forward mode is commonly implemented by replacing numbers with dual numbers of the form $a + b\varepsilon$ where $\varepsilon^2 = 0$. The first component represents the intermediate value of the primal computation, and the second component represents the partial derivative at the intermediate value with respect to an input.

## Overloading arithmetic operators

We will begin by implementing a simple version of dual numbers in Python. Python supports operator overloading of arithmetic operators for arbitrary objects by defining special functions of the form `__add__`, `__mul__`, etc. Consider the following example.

In [0]:
class Dual:
  """A dual number of the form a + bε where ε² = 0."""
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def __str__(self):
    return f"{self.a} + {self.b}ε"
    
  def __add__(self, other):
    """Addition of dual numbers.
    
    Note that (a + bε) + (c + dε) = (a + c) + (b + d)ε.
    
    """
    return Dual(self.a + other.a, self.b + other.b)

  def __mul__(self, other):
    """(a + bε)(c + dε)."""
    
    return Dual(self.a * other.a , self.a * other.b + self.b * other.a)
  

  def __truediv__(self, other):
    """(a + bε) / (c + dε)."""
    
    return Dual(self.a/other.a , (self.b / other.a) -(self.a * other.b)/(other.a**2))
    

  def __neg__(self):
    """-(a + bε)."""
    return Dual(-self.a, -self.b)

In [3]:
# We can add two scalars
x, y = 0, 2
print(x + y)

# And we can add two dual numbers the same way
x, y = Dual(0, 1), Dual(2, 2)
print(x + y)

2
2 + 3ε


Note that the arithmetic of dual numbers obey the rules of differentiation for the second component. For example, if $w = yz$ then the product rule tells us that $\frac{dw}{dx} = \frac{dy}{dx}z + y\frac{dz}{dx}$. Similarly, $(a + b\varepsilon)(c + d\varepsilon) = ac + (ad + bc)\varepsilon + bd\varepsilon^2= ac + (ad + bc)\varepsilon$.


#### Exercise

Complete the class accordingly by using the arithmetic of dual numbers or the rules of differentiation.

**Solution**:

In [0]:
#@title
class Dual:
  """A dual number of the form a + bε where ε² = 0."""
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def __str__(self):
    return f"{self.a} + {self.b}ε"
    
  def __add__(self, other):
    """Addition of dual numbers.
    
    Note that (a + bε) + (c + dε) = (a + c) + (b + d)ε.
    
    """
    return Dual(self.a + other.a, self.b + other.b)

  def __mul__(self, other):
    """(a + bε)(c + dε)."""
    return Dual(self.a * other.a, self.a * other.b + self.b * other.a)

  def __truediv__(self, other):
    """(a + bε) / (c + dε)."""
    return Dual(self.a / other.a, (self.b * other.a - self.a * other.b) / other.a ** 2)

  def __neg__(self):
    """-(a + bε)."""
    return Dual(-self.a, -self.b)

**Unit tests**: run the following cell to check your implementation.

In [5]:
#@title
class TestDual(unittest.TestCase):
  def test_add(self):
    x, y = Dual(0, 1), Dual(2, 2)
    z = x + y
    self.assertEqual(z.a, 2)
    self.assertEqual(z.b, 3)

  def test_mul(self):
    x, y = Dual(0, 1), Dual(2, 2)
    z = x * y
    self.assertEqual(z.a, x.a * y.a)
    self.assertEqual(z.b, x.a * y.b + x.b * y.a)

  def test_truediv(self):
    x, y = Dual(0, 1), Dual(2, 2)
    z = x / y
    self.assertEqual(z.a, x.a / y.a)
    self.assertEqual(z.b, (x.b * y.a - x.a * y.b) / y.a ** 2)

  def test_neg(self):
    y = Dual(2, 2)
    z = -y
    self.assertEqual(z.a, -y.a)
    self.assertEqual(z.b, -y.b)


suite = unittest.TestLoader().loadTestsFromTestCase(TestDual)
_ = unittest.TextTestRunner(verbosity=2).run(suite)

test_add (__main__.TestDual) ... ok
test_mul (__main__.TestDual) ... ok
test_neg (__main__.TestDual) ... ok
test_truediv (__main__.TestDual) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.016s

OK


## Overloading primitives

Similarly, we can define mathematical functions which operate on dual numbers. Let's write a logarithm function which works on dual numbers.

In [0]:
@forward_primitive(math.log)
def log(x):
  return Dual(math.log(x.a), x.b / x.a)

In [7]:
x = Dual(2, 1)
print(log(x))

0.6931471805599453 + 0.5ε


#### Exercise

Implement the overloaded exponential operator similarly to how the logarithm operator was implemented.

In [0]:
@forward_primitive(math.exp)
def exp(x):
  
  return Dual(math.exp(x.a),math.exp(a)*x.b)

**Solution**:

In [0]:
#@title
@forward_primitive(math.exp)
def exp(x):
  exp_a = math.exp(x.a)
  return Dual(exp_a, x.b * exp_a)

**Unit tests**: run the following cell to check your implementation.

In [10]:
#@title
class TestPrimitives(unittest.TestCase):
  def test_log(self):
    x = Dual(2, 1)
    y = log(x)
    self.assertAlmostEqual(y.a, math.log(2))
    self.assertAlmostEqual(y.b, 0.5)

  def test_exp(self):
    x = Dual(2, 1)
    y = exp(x)
    self.assertEqual(y.a, math.exp(2))
    self.assertEqual(y.b, math.exp(2))


suite = unittest.TestLoader().loadTestsFromTestCase(TestPrimitives)
_ = unittest.TextTestRunner(verbosity=2).run(suite)

test_exp (__main__.TestPrimitives) ... ok
test_log (__main__.TestPrimitives) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.008s

OK


## Testing it out



We now have all the pieces to calculate the derivative of a weighted logistic function, $f(x) = 1 / (1 + e^{-tx})$. To calculate the derivative with respect to $x$, we have to set the initial dual values of $x$ and $t$ appropriately.


In [0]:
def logistic(x, t, one):
  return one / (one + exp(-(t * x)))

In [12]:
# The regular output
y = logistic(1, 2, 1)
print(f"y = {y:.4}")

y = 0.8808


#### Exercise

Set the correct `b` values of `Dual` instances for $x$ and $t$.

In [0]:
# The derivative
one = Dual(1, 0)
x = Dual(1, 1)
t = Dual(2, 0)

**Solution**:

In [0]:
#@title
one = Dual(1, 0)
x = Dual(1, 1)
t = Dual(2, 0)

**Unit tests**: run the following cell to check your implementation.

In [15]:
#@title
class TestLogistic(unittest.TestCase):
  def setUp(self):
    self.one = one
    self.x = x
    self.t = t

  def test_derivative(self):
    y = logistic(self.x, self.t, self.one)
    self.assertAlmostEqual(y.b, y.a * (1 - y.a) * self.t.a)

suite = unittest.TestLoader().loadTestsFromTestCase(TestLogistic)
_ = unittest.TextTestRunner(verbosity=2).run(suite)

test_derivative (__main__.TestLogistic) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK


A passing test indicates that our implementation is correct:

In [16]:
y = logistic(x, t, one)
dydx = y.b
print(f"dy/dx = {dydx:.4}")

dy/dx = 0.21


However, the test was written by deriving the correct expression for the gradient, which for more complex functions may be very tedious. Another way to verify our implementation is by comparing it with simple finite differences where we approximate $f'(x) \approx \frac{f(x + \Delta) - f(x - \Delta)}{2\Delta}$.

#### Exercise

Confirm that your implementation is correct by checking if the finite differences approximation is close to the derivative calculated using automatic differentiation.

In [17]:
Delta = 0.01
dydx = (logistic(x.a + Delta, t.a, one.a)-logistic(x.a - Delta, t.a, one.a))/(2*Delta) # Replace this with the finite difference approximation
print(f"dy/dx = {dydx:.4}")

dy/dx = 0.21


# Reverse mode using operator overloading

As discussed in class, forward mode scales with the number of inputs. If we want the derivative of the weighted logistic function with respect to both $x$ and $t$, we would have to execute the function twice, once with $x = 1 + \varepsilon$ and once with $t = 2 + \varepsilon$.

We will avoid this overhead by implementing reverse mode automatic differentiation. We will start by using something similar to dual numbers, replacing each variable with an object that holds both the intermediate value and its partial derivative.

However, unlike in forward mode, we won't calculate the partial derivative straight away. Instead, we will log the computation on a tape. The partial derivative will then be calculated during the backward pass, when the tape is walked in reverse.

In [0]:
from typing import NamedTuple, Callable, Sequence

# The tape is simply a list
TAPE = []

class TapeEntry(NamedTuple):
  """On the tape we log the called function, its inputs, and its outputs."""
  function: Callable[..., float]
  inputs: Sequence
  output: float  # For simplicity, we assume a single output

In [0]:


import operator

class Float:
  """Scalar values that also hold their partial derivative."""
  def __init__(self, value):
      self.value = value
      self.grad = 0.0

  def __add__(self, other):
    # Add two numbers and create a new Float
    out = Float(self.value + other.value)
    
    # Write to the tape that this addition was performed
    TAPE.append(TapeEntry(operator.add, (self, other), out))

    return out

  def __repr__(self):
    return repr(self.value)

In [20]:
# We can now generate numbers that have an associated gradient
x = Float(2.0)
print(x + x)
x.grad += 1
print(x.grad)

# And when we add two numbers, this was logged on the tape
print(TAPE)

4.0
1.0
[TapeEntry(function=<built-in function add>, inputs=(2.0, 2.0), output=4.0)]


We now need to define the gradient for each operation, which will be performed during the backward pass. Consider the following example for addition:

In [0]:
# For each primitive, we have to say what its derivative is
GRADS = {}

def add_grad(x, y, z):
  # Note that dL/dx = dL/dz * dz/dx = dL/dz
  # Since z = x + y, dz/dx = 1, so dL/dx = dL/dz
  x.grad += z.grad
  # Similarly for dL/dy
  y.grad += z.grad

GRADS[operator.add] = add_grad

Each operation is now logged to a tape, and for each operation we know what gradient computations to perform during the backward pass. The only thing left to do is then to implement the backward pass by walking the tape in reverse.

#### Exercise

Walk the tape in reverse. For each operation on the tape, call its corresponding gradient with the inputs and output.

In [0]:
def grad(f):
  def df(*args, **kwargs):
    # Begin with an empty tape
    TAPE.clear()

    # Now call the original function, which will write to the tape
    out = f(*args, **kwargs)

    # The initial gradient of the output is 1
    out.grad = 1

    for entry in reversed(TAPE):
      grad_function= GRADS[entry.function]
      #pass
      # TODO: Call the grad_function with *inputs and output.
      #pass
      grad_function(*entry.inputs,entry.output)

    # We return the gradient with respect to each of the input arguments
    return tuple(arg.grad if isinstance(arg, Float) else None for arg in args)
  return df

**Solution**:

In [0]:
#@title
def grad(f):
  def df(*args, **kwargs):
    # Begin with an empty tape
    TAPE.clear()

    # Now call the original function, which will write to the tape
    out = f(*args, **kwargs)

    # The initial gradient of the output is 1
    out.grad = 1

    for entry in reversed(TAPE):
      grad_function = GRADS[entry.function]
      grad_function(*entry.inputs, entry.output)

    # We return the gradient with respect to each of the input arguments
    return tuple(arg.grad if isinstance(arg, Float) else None for arg in args)
  return df

**Unit tests**: run the following cell to check your implementation.

In [24]:
#@title
class TestGrad(unittest.TestCase):
  def test_derivative(self):
    def _f(x, y):
      return x + y + y
    dx, dy = grad(_f)(Float(2), Float(3))
    self.assertAlmostEqual(dx, 1.0)
    self.assertAlmostEqual(dy, 2.0)

suite = unittest.TestLoader().loadTestsFromTestCase(TestGrad)
_ = unittest.TextTestRunner(verbosity=2).run(suite)

test_derivative (__main__.TestGrad) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK


Let's try our validated implementation on a simple example. For the following example, we want to calculate the gradient with respect to both $x$ and $y$.

In [25]:
def f(x, y):
  return x + y + y

dx, dy = grad(f)(Float(2), Float(3))
print(f"dx = {dx:.4}, dy = {dy:.4}")

dx = 1.0, dy = 2.0


Below there is a version of the `Float` class which traces division, multiplication, and negation as well. It also introduces an exponential function which operates on our special `Float` values. The only thing that is missing is their actual gradients.

In [0]:
def trace(primitive):
  """Overload a function.

  This function takes a primitive, and returns a new
  overloaded function which applies the primitive but
  also logs the application of the primitive to the tape.

  This is just a programmatic way of writing the __add__
  function from the previous Float class.

  """
  def overloaded_primitive(*args):
    out = Float(primitive(*(arg.value for arg in args)))
    TAPE.append(TapeEntry(primitive, args, out))
    return out
  return overloaded_primitive

class Float:
  def __init__(self, value):
      self.value = value
      self.grad = 0.0

  # Overloaded operators
  __add__ = trace(operator.add)
  __truediv__ = trace(operator.truediv)
  __mul__ = trace(operator.mul)
  __neg__ = trace(operator.neg)

  def __repr__(self):
    return repr(self.value)

@backward_primitive(math.exp)
def exp(x):
  out = Float(math.exp(x.value))
  TAPE.append(TapeEntry(math.exp, (x,), out))
  return out

#### Exercise

Implement the backwards gradients for multiplication, negation, division, and the exponential function so that you can calculate the gradient of the weighted logistic function with respect both $x$ and $t$.

In [0]:
def mul_grad(x, y, z):
  x.grad += z.grad * y.value# TODO: replace with the correct expression.
  y.grad += z.grad * x.value # TODO: replace with the correct expression.

GRADS[operator.mul] = mul_grad

def truediv_grad(x, y, z):
  x.grad += z.grad * (1/y.value)# TODO: replace with the correct expression.
  y.grad += -z.grad * x.value* (1/y.value**2) # TODO: replace with the correct expression.

GRADS[operator.truediv] = truediv_grad

def neg_grad(x, y):
  x.grad += -y.grad # TODO: replace with the correct expression.

GRADS[operator.neg] = neg_grad

def exp_grad(x, y):
  x.grad += y.grad * math.exp(x.value) # TODO: replace with the correct expression.

GRADS[math.exp] = exp_grad

**Solution**:

In [0]:
#@title
def mul_grad(x, y, z):
  x.grad += y.value * z.grad
  y.grad += x.value * z.grad

GRADS[operator.mul] = mul_grad

def truediv_grad(x, y, z):
  x.grad += z.grad / y.value
  y.grad += -z.grad * x.value / (y.value * y.value)

GRADS[operator.truediv] = truediv_grad

def neg_grad(x, y):
  x.grad += -y.grad

GRADS[operator.neg] = neg_grad

def exp_grad(x, y):
  x.grad += y.grad * y.value

GRADS[math.exp] = exp_grad

**Unit tests**: run the following cell to check your implementation.

In [29]:
#@title
class TestOperatorGrads(unittest.TestCase):
  def test_mul_grad(self):
    def _f(x, y):
      return x * y
    dx, dy = grad(_f)(Float(2), Float(3))
    self.assertAlmostEqual(dx, 3.0)
    self.assertAlmostEqual(dy, 2.0)

  def test_truediv_grad(self):
    def _f(x, y):
      return x / y
    dx, dy = grad(_f)(Float(2), Float(3))
    self.assertAlmostEqual(dx, 1.0 / 3.0)
    self.assertAlmostEqual(dy, -2.0 / 9.0)

  def test_neg_grad(self):
    def _f(x):
      return -x
    dx, = grad(_f)(Float(2))
    self.assertAlmostEqual(dx, -1.0)

  def test_exp_grad(self):
    def _f(x):
      return exp(x)
    dx, = grad(_f)(Float(2))
    self.assertAlmostEqual(dx, math.exp(2.0))

suite = unittest.TestLoader().loadTestsFromTestCase(TestOperatorGrads)
_ = unittest.TextTestRunner(verbosity=2).run(suite)

test_exp_grad (__main__.TestOperatorGrads) ... ok
test_mul_grad (__main__.TestOperatorGrads) ... ok
test_neg_grad (__main__.TestOperatorGrads) ... ok
test_truediv_grad (__main__.TestOperatorGrads) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.010s

OK


After you implemented these gradients, we can calculate the derivative of both $x$ and $t$ in one pass. Let's try our validated implementation on a simple example.

In [30]:
# We can calculate the gradients of all the arguments.
dx, dt, _ = grad(logistic)(Float(1), Float(2), Float(1))
print(f"dx = {dx:.4}, dt = {dt:.4}")

dx = 0.21, dt = 0.105


#### Exercise

Check to see if your code runs and returns correct gradients using finite differences.

In [38]:
# We can compare the gradient dx against the forward mode (0.21)
# Let's also check whether dt is correct
Delta = 0.01
dt = (logistic(x.a , t.a + Delta, one.a)-logistic(x.a , t.a - Delta, one.a))/(2*Delta) # Replace this with the finite difference approximation

print(f"correct dt = {dt:.4}")

correct dt = 0.105


# Bonus: Higher-order derivatives

A common problem with AD implementations is that higher-order derivatives are not supported. Operator overloading is actually one of the easiest cases to implement higher-order derivatives for, so let's see what we need to change to make AD closed under its own operation.

The main problem is that higher-order derivatives require us to write the gradient calculations to the tape as well. Right now, the gradients are just numbers (e.g. `self.grad = 0` instead of `self.grad = Float(0.0)`. Setting `self.grad = Float(0.0)` creates an infinite loop though, so we need to be a bit more clever.)

Also, in all the places where we assumed that the `grad` attribute was a number, we need to replace it with an instance of `Float`.

In [0]:
class Float:
  def __init__(self, value):
      self.value = value
      self._grad = None

  @property
  def grad(self):
    if self._grad is None:
      self._grad = Float(0.0)
    return self._grad

  @grad.setter
  def grad(self, value):
    self._grad = value

  # Overloaded operators
  __add__ = trace(operator.add)
  __truediv__ = trace(operator.truediv)
  __mul__ = trace(operator.mul)
  __neg__ = trace(operator.neg)

  def __repr__(self):
    return repr(self.value)

In [0]:
def mul_grad(x, y, z):
  x.grad += y * z.grad
  y.grad += x * z.grad

GRADS[operator.mul] = mul_grad

def truediv_grad(x, y, z):
  x.grad += z.grad / y
  y.grad += -z.grad * x / (y * y)

GRADS[operator.truediv] = truediv_grad

def neg_grad(x, y):
  x.grad += -y.grad

GRADS[operator.neg] = neg_grad

def exp_grad(x, y):
  x.grad += y.grad * y

GRADS[math.exp] = exp_grad

In [0]:
def grad(f):
  def df(*args, **kwargs):
    # Now call the original function, which will write to the tape
    out = f(*args, **kwargs)

    # The initial gradient of the output is 1
    out._grad = Float(1)

    for entry in reversed(TAPE):
      grad_function = GRADS[entry.function]
      grad_function(*entry.inputs, entry.output)

    # We return the gradient with respect to each of the input arguments
    return tuple(arg.grad if isinstance(arg, Float) else None for arg in args)
  return df
  

In [42]:
def grad_logistic(*args):
  dx, dt, _ = grad(logistic)(*args)
  return dx

# We can calculate the gradients of all the arguments.
TAPE.clear()
dxdx, dxdt, _ = grad(grad_logistic)(Float(1), Float(2), Float(1))
print(f"dxdx = {dxdx.value:.4}, dxdt = {dxdt.value:.4}")

dxdx = 0.9401, dxdt = 0.575
