# Automatic differentiation in MLX

Markus Enzweiler, markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

**Note: This requires a machine with an Apple SoC, e.g. M1/M2/M3 etc.**

See: https://github.com/ml-explore/mlx


## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [45]:
# Imports
import sys
import os

In [46]:
# Package Path
package_path = "./" # local
print(f"Package path: {package_path}")

Package path: ./


In [47]:
# Install requirements in the current Jupyter kernel
req_file = os.path.join(package_path, "requirements.txt")
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")



In [48]:
# Now we should be able to import the additional packages
import mlx 
import mlx.core as mx

## MLX

MLX provides composable function transformations, supporting automatic differentiation, automatic vectorization, and optimization of computation graphs. Computation graphs within MLX are dynamically constructed. A key feature of MLX is the use of (and optimization for) unified memory present in the Apple SoCs. 

See:
- https://github.com/ml-explore/mlx
- https://ml-explore.github.io/mlx/build/html/quick_start.html

### Automatic differentiation with scalar functions in MLX

In [49]:
def f(x):
    return x**2 + 3*x + 2


x_tensor = mx.arange(-5, 5, 1, dtype=mx.float32)
print(mx.grad(f)(x_tensor[0]))


array(-7, dtype=float32)


In [50]:
# MLX has mlx.core.grad to automatically compute gradients of a function. 
# See: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.grad.html#mlx.core.grad

# Let's try it out with simple functions first.

# Define the function, x^2+3x+2
def f(x):
    return x**2 + 3*x + 2

# Manual gradient w.r.t x
def f_grad(x):
    return 2*x + 3

def mlxgrad(func, x):
    # Similar to PyTorch autograd, mlx.core.grad is designed to compute gradients of scalar
    # outputs with respect to inputs.
    #
    # In our case, the function f(x) applied to x_tensor results in a vector (a tensor with multiple elements),
    # not a single scalar. Hence, mlx.core.grad cannot directly compute the gradient for each element
    # of this vector. To resolve this, we loop over each element of x, treat each function evaluation
    # f(x[i]) as a scalar output, and compute its gradient individually. This way, we are effectively computing
    # the gradient of multiple scalar functions, each dependent on a single element of x.

    # Build the gradient function once and reuse it for every element
    grad_func = mlx.core.grad(func)

    # Compute the gradient for each element in the tensor
    gradients = []
    for xi in x:
        gradients.append(grad_func(xi))

    return gradients



# Compute some function values and gradients.
# Unlike PyTorch, MLX needs no requires_grad flag on the array: mx.grad transforms the function itself.
x_tensor = mx.arange(-5, 5, 1, dtype=mx.float32)

f_value       = f(x_tensor)
f_grad_manual = f_grad(x_tensor)  # separate name, so the f_grad function is not overwritten
f_autograd    = mlxgrad(f, x_tensor)

for i in range(len(x_tensor)):
    print(f"x = {x_tensor[i].item():5.2f}: "
          f"f(x) = {f_value[i].item():5.2f}, "
          f"f_grad(x) = {f_grad_manual[i].item():5.2f}, "
          f"autograd(x) = {f_autograd[i].item():5.2f}")

x = -5.00: f(x) = 12.00, f_grad(x) = -7.00, autograd(x) = -7.00
x = -4.00: f(x) =  6.00, f_grad(x) = -5.00, autograd(x) = -5.00
x = -3.00: f(x) =  2.00, f_grad(x) = -3.00, autograd(x) = -3.00
x = -2.00: f(x) =  0.00, f_grad(x) = -1.00, autograd(x) = -1.00
x = -1.00: f(x) =  0.00, f_grad(x) =  1.00, autograd(x) =  1.00
x =  0.00: f(x) =  2.00, f_grad(x) =  3.00, autograd(x) =  3.00
x =  1.00: f(x) =  6.00, f_grad(x) =  5.00, autograd(x) =  5.00
x =  2.00: f(x) = 12.00, f_grad(x) =  7.00, autograd(x) =  7.00
x =  3.00: f(x) = 20.00, f_grad(x) =  9.00, autograd(x) =  9.00
x =  4.00: f(x) = 30.00, f_grad(x) = 11.00, autograd(x) = 11.00
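
The agreement between the manual and automatic gradients can also be cross-checked numerically without MLX. The following is a minimal sketch using only plain Python floats; the helper name `central_difference` and the step size `h` are assumptions chosen for illustration and are not part of the original notebook.

```python
# Library-independent sanity check (plain Python, no MLX).
# central_difference and h are illustrative choices, not part of the notebook.

def f(x):
    return x**2 + 3*x + 2

def f_grad(x):
    # Analytic derivative of f
    return 2*x + 3

def central_difference(func, x, h=1e-4):
    """Approximate func'(x) by (func(x + h) - func(x - h)) / (2 * h)."""
    return (func(x + h) - func(x - h)) / (2 * h)

xs = [float(v) for v in range(-5, 5)]
numeric = [central_difference(f, x) for x in xs]
analytic = [f_grad(x) for x in xs]

for x, n, a in zip(xs, numeric, analytic):
    print(f"x = {x:5.2f}: numeric = {n:5.2f}, analytic = {a:5.2f}")
```

For a quadratic function, the central difference has no truncation error, so the two columns should agree up to floating-point rounding.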


### Automatic differentiation with tensors in MLX

In [51]:
# Define two tensors (MLX needs no gradient-tracking flag on the arrays)
t1 = mx.array([[1, 2, 3],
               [4, 5, 6]], dtype=mx.float32)

t2 = mx.array([[7, 8, 9],
              [10, 11, 12]], dtype=mx.float32)

In [52]:
# Perform element-wise multiplication of t1 and t2
t1_mul_t2 = t1 * t2

def mul_and_reduce_sum(t1, t2):
    return mx.sum(t1*t2)

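Unlike PyTorch, MLX has no `backward()` call and its arrays have no `.grad` attribute. Instead, `mx.grad` and `mx.value_and_grad` return the gradients directly.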

The gradient of each element of `t1` is equal to the corresponding element in `t2`, and vice versa. This is because the derivative of `t1[i] * t2[i]` w.r.t. `t1[i]` is `t2[i]`, and w.r.t. `t2[i]` is `t1[i]`.
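
As a short worked derivation of this statement, write $A$ for `t1`, $B$ for `t2`, and $S$ for the summed product. These symbols are introduced here only for illustration.

$$
S(A, B) = \sum_{i,j} A_{ij} B_{ij}, \qquad
\frac{\partial S}{\partial A_{ij}} = B_{ij}, \qquad
\frac{\partial S}{\partial B_{ij}} = A_{ij}
$$

Every entry $A_{ij}$ appears in exactly one term of the sum, so all other terms vanish under differentiation. For the tensors defined above, this gives for example $\partial S / \partial A_{01} = B_{01} = 8$ and $\partial S / \partial B_{01} = A_{01} = 2$.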

In [53]:
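# Compute gradients of the sum of all elements in t1 * t2 with respect to t1 and t2.
# By default, mx.value_and_grad differentiates with respect to the first argument only,
# so the arguments are swapped in the second call to obtain the gradient w.r.t. t2.

lvalue, t1_grad = mx.value_and_grad(mul_and_reduce_sum)(t1,t2)
lvalue, t2_grad = mx.value_and_grad(mul_and_reduce_sum)(t2,t1)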


# The gradient at each element in t1 and t2 indicates the rate of change of the sum with respect to that element.
# For element-wise multiplication, the gradient at each element of t1 is equal to the corresponding element
# in t2 and vice versa. This is because the derivative of t1[i] * t2[i] w.r.t. t1[i] is t2[i],
# and w.r.t. t2[i] is t1[i].

print(f"t1_grad = {t1_grad}")
print(f"t2_grad = {t2_grad}")

t1_grad = array([[7, 8, 9],
       [10, 11, 12]], dtype=float32)
t2_grad = array([[1, 2, 3],
       [4, 5, 6]], dtype=float32)


In [54]:
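# Analyzing the gradient at t2[0,1]: t2_grad[0,1] is 2 (the value of t1[0,1]), so a unit change in t2[0,1]
# changes the sum by 2. Increasing t2[0,1] by 3 should therefore increase the sum by 3 * t2_grad[0,1] = 6.
# Because the sum is linear in t2, this linear approximation is exact here.

# Create a new tensor and add 3 to t2[0,1].
# mx.array(t2) creates a separate array object. A plain "t2_modified = t2" would only alias t2,
# so the in-place update below would also change t2.
t2_modified = mx.array(t2)
t2_modified[0,1] = t2[0,1] + 3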

# Perform the computation again with the modified t2
t1_mul_t2_updated = t1 * t2_modified
updated_sum = t1_mul_t2_updated.sum()

# Compare the change in sum
change_in_sum = updated_sum - t1_mul_t2.sum()
print(f"Change in sum: {change_in_sum}")

Change in sum: array(6, dtype=float32)
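
The same arithmetic can be reproduced without MLX. The sketch below uses nested Python lists in place of MLX arrays; the variable names `delta`, `actual_change`, and `predicted_change` are assumptions made for illustration.

```python
# Plain-Python version of the experiment above (nested lists instead of MLX arrays).
t1 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]

def mul_and_reduce_sum(a, b):
    # Element-wise product of two 2D lists, summed over all entries
    return sum(x * y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))

delta = 3.0

# Copy t2 row by row so that the original stays unchanged
t2_modified = [row[:] for row in t2]
t2_modified[0][1] += delta

actual_change = mul_and_reduce_sum(t1, t2_modified) - mul_and_reduce_sum(t1, t2)

# The gradient of the sum w.r.t. t2[0][1] equals t1[0][1]
predicted_change = delta * t1[0][1]

print(f"actual change = {actual_change}, predicted change = {predicted_change}")
```

Both values come out as 6.0, matching the MLX result: the sum is linear in each entry of `t2`, so the gradient-based prediction is exact rather than approximate.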
