# Mini-Batch Gradient Descent (MBGD)

The goldilocks zone between vanilla GD and vanilla SGD. It uses the average gradient from a random 'batch' of points.

In [1]:
import numpy as np
def numpyMBGD(X, y, learning_rate=0.01, batch_size=16, epochs=100):
    m, n = X.shape
    theta = np.random.randn(n, 1)  # random initialization

    for epoch in range(epochs):
        shuffled_indices = np.random.permutation(m)
        X_shuffled = X[shuffled_indices]
        y_shuffled = y[shuffled_indices]

        for i in range(0, m, batch_size):
            xi = X_shuffled[i:i + batch_size]
            yi = y_shuffled[i:i + batch_size]

            gradients = 2 / batch_size * xi.T.dot(xi.dot(theta) - yi)
            theta = theta - learning_rate * gradients

    return theta

In [2]:
from sklearn.metrics import mean_absolute_error

# Apply function to some data, in this case, 100 triples from an i.i.d from 0 to 1
X = np.random.rand(100, 3)
y = 5 * X[:, 0] - 3 * X[:, 1] + 2 * X[:, 2] + np.random.randn(100, 1)  # sample linear regression problem
theta = numpyMBGD(X, y)

# Predict and calculate MAE
predictions = X.dot(theta)
mae = mean_absolute_error(y, predictions)
print(f"MAE: {mae}")  # MAE: 1.0887166179544072


MAE: 1.0469046993669393


In [25]:
import jax
import jax.numpy as jnp

@jax.jit
def jaxMBGD(X, y, learning_rate=0.01, batch_size=16, epochs=100, seed=42):
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    m, n = X.shape

    key = jax.random.PRNGKey(seed)
    key, subkey = jax.random.split(key)
    theta = jax.random.normal(subkey, (n, 1))


    for epoch in range(epochs):
        key, subkey = jax.random.split(key)

        shuffled_indices = jax.random.split(subkey, m)

        X_shuffled = X[shuffled_indices]
        y_shuffled = y[shuffled_indices]

        for i in range(0, m, batch_size):
            xi = X_shuffled[i:i + batch_size]
            yi = y_shuffled[i:i + batch_size]

            gradients = 2 / xi.shape[0] * xi.T.dot(xi.dot(theta) - yi)
            theta = theta - learning_rate * gradients

    return theta

In [26]:
key = jax.random.PRNGKey(42)
key, kX, kN = jax.random.split(key, 3)

X = jax.random.normal(kX, (100, 3))
noise = jax.random.normal(kN, (100, 1))

# ✅ keep (100,1) by using [:, [0]] etc
y = (5 * X[:, [0]] - 3 * X[:, [1]] + 2 * X[:, [2]]) + noise

theta = jaxMBGD(X, y)
print(theta.shape)  # (3,1)

TypeError: dot_general requires contracting dimensions to have the same shape, got (16,) and (2,).