In [3]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import time
from functools import partial

hidden_size = 20
learning_rate = 0.01
batch_size = 500
epochs = 20

# 1. Define the model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)  # No sigmoid here
        )

    def __call__(self, x):
        return self.net(x)

# 2. Generate synthetic training data
def generate_data(n_samples=10000):
    np.random.seed(42)
    X = np.random.randn(n_samples, 2).astype(np.float32)
    y = ((X[:, 0] * X[:, 1]) > 0).astype(np.float32).reshape(-1, 1)  # class 1 if signs match
    return X, y

X_train, y_train = generate_data(8000)
X_test, y_test = generate_data(2000)

# 3. Initialize model and optimizer
model = MyModel()
optimizer = optim.Adam(model, learning_rate)

# 4. Define binary cross entropy loss (logits)
def loss_fn(model, X, y):
    logits = model(X)
    return mx.mean(
        mx.maximum(logits, 0) - logits * y + mx.log1p(mx.exp(-mx.abs(logits)))
    )

# 5. Gradient & optimization
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

@partial(mx.compile, inputs=model.state, outputs=model.state)
def step(X, y):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    return loss

@partial(mx.compile, inputs=model.state)
def eval_fn(X, y):
    preds = mx.sigmoid(model(X)) > 0.5
    return mx.mean(preds == y)

def batch_iterate(batch_size, X, y):
    perm = np.random.permutation(len(X))
    for i in range(0, len(X), batch_size):
        idx = perm[i : i + batch_size]
        yield mx.array(X[idx]), mx.array(y[idx])

# 6. Training loop
tic = time.perf_counter()
for e in range(epochs):
    for Xb, yb in batch_iterate(batch_size, X_train, y_train):
        step(Xb, yb)
        mx.eval(model.state)

    acc = eval_fn(mx.array(X_test), mx.array(y_test)).item()
    loss = loss_fn(model, mx.array(X_test), mx.array(y_test)).item()
    print(f"Epoch {e+1:02d}: Accuracy = {acc:.3f}, Loss = {loss:.4f}, Time = {time.perf_counter() - tic:.2f}s")


ValueError: [addmm] Got 0 dimension input. Inputs must have at least one dimension.

In [2]:
model.parameters()

{'net': {'layers': [{'weight': array([[-0.301643, -0.09589],
           [-0.679963, -0.502911],
           [-0.596912, 0.162901],
           ...,
           [-0.208396, -0.292126],
           [0.0380046, 0.585519],
           [0.124601, 0.663408]], dtype=float32),
    'bias': array([-0.466846, 0.0375038, 0.493884, ..., -0.423702, -0.0151615, 0.435534], dtype=float32)},
   {},
   {'weight': array([[-0.0809622, -0.208339, 0.141565, ..., 0.0374379, -0.221195, -0.151838]], dtype=float32),
    'bias': array([0.107936], dtype=float32)}]}}