In [1]:
from minigrad import Scalar, draw_graph

In [2]:
from torch import tensor
from jax import grad

In [3]:
def is_close(x, y, z=None, eps=1e-10):
    """Return whether the given values agree to within an absolute tolerance.

    With two values, checks |x - y| < eps. When `z` is supplied, every
    pairwise difference among x, y and z must be below `eps`.
    """
    xy_within = abs(x - y) < eps
    if z is None:
        return xy_within
    # Check all three pairs: x~y and y~z alone only bound |x - z| by 2*eps.
    return xy_within and abs(y - z) < eps and abs(x - z) < eps

# Verify gradients against PyTorch and JAX

In [4]:
def f1(x, y):
    """Compute x*x*y + (y + 2).

    From "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow,
    2nd Edition", page 770.
    """
    # Same operation order as the step-by-step form: (x*x)*y first, then y + 2, then the sum.
    return x * x * y + (y + 2)

# MiniGrad: build the inputs, run f1 forward, then backpropagate from the output
x, y = Scalar(3.0), Scalar(4.0)
out = f1(x, y)
out.backward()

# PyTorch: same inputs as autograd-tracked tensors
x_p, y_p = tensor(3.0, requires_grad=True), tensor(4.0, requires_grad=True)
out_p = f1(x_p, y_p)
out_p.backward()

# JAX: differentiate f1 with respect to both positional arguments at once
x_j_grad, y_j_grad = grad(f1, argnums=(0, 1))(3.0, 4.0)

# All three frameworks must agree on each partial derivative
assert is_close(x.grad, x_p.grad, float(x_j_grad))
assert is_close(y.grad, y_p.grad, float(y_j_grad))
# draw_graph(out)



In [5]:
def f2(a, b):
    """Test expression mixing +, *, **, /, unary minus and relu.

    From https://github.com/karpathy/micrograd.

    The sums `b + a` and `b - a` must produce objects exposing a `.relu()`
    method, so plain Python floats are not accepted.
    """
    acc = a + b
    mix = a * b + b**3

    # Augmented assignments are kept as-is (they are in-place for torch tensors).
    acc += acc + 1
    acc += 1 + acc + (-a)

    doubled = mix * 2
    sum_gate = (b + a).relu()
    mix += doubled + sum_gate

    tripled = 3 * mix
    diff_gate = (b - a).relu()
    mix += tripled + diff_gate

    gap = acc - mix
    gap_sq = gap**2
    result = gap_sq / 2.0
    result += 10.0 / gap_sq
    return result

# Compute gradients with MiniGrad
a = Scalar(-4.0)
b = Scalar(2.0)
out = f2(a, b)
out.backward()

# Compute gradients with PyTorch.
# torch.tensor() defaults to float32, whose spacing near these gradient values
# (~1e-5) is far coarser than is_close's 1e-10 tolerance, so a float32 run can
# only pass if it happens to round identically. Build the inputs in float64
# instead; requires_grad_() is applied last so they remain leaf tensors and
# .grad is populated by backward().
# NOTE(review): assumes MiniGrad's Scalar computes with Python (double-precision) floats -- confirm.
a_p = tensor(-4.0).double().requires_grad_()
b_p = tensor(2.0).double().requires_grad_()
out_p = f2(a_p, b_p)
out_p.backward()

# JAX is skipped here: f2 calls .relu() on intermediate values, which JAX scalars don't support.
# a_j_grad, b_j_grad = grad(f2, (0, 1))(-4.0, 2.0)

# Check that gradients are equivalent
assert is_close(a.grad, a_p.grad), f'a.grad mismatch: MiniGrad={a.grad}, PyTorch={a_p.grad}'
assert is_close(b.grad, b_p.grad), f'b.grad mismatch: MiniGrad={b.grad}, PyTorch={b_p.grad}'
# draw_graph(out)