Comment:
Using AdamW with bias correction (not the one in mlx), combined with clip_grad_norm, gives results similar to PyTorch's; a sketch of that combination is given below.
One possible explanation is that PyTorch also has a trick to avoid gradient explosion.
Another, more plausible explanation is that the initialization has a big impact on gradient stability.
A last possibility is that the LR changing at each iteration instead of at each epoch plays a role.
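For reference, here is a minimal sketch (not the exact code used here) of a bias-corrected AdamW step combined with MLX's `clip_grad_norm`; the hyperparameters and the functional `(params, grads, state, step)` layout are illustrative assumptions:

```python
# Sketch only: a bias-corrected AdamW step with global-norm gradient clipping.
# Hyperparameters and the (params, grads, state, step) layout are illustrative.
import mlx.core as mx
from mlx.optimizers import clip_grad_norm
from mlx.utils import tree_map


def adamw_step(params, grads, state, step, lr=1e-3, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=1e-2, max_norm=1.0):
    """One update; `state` holds the moment trees {"m": ..., "v": ...}, `step` starts at 1."""
    # Clip the global gradient norm first (mirrors torch.nn.utils.clip_grad_norm_).
    grads, _ = clip_grad_norm(grads, max_norm)

    b1, b2 = betas
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, state["m"], grads)
    v = tree_map(lambda v_i, g: b2 * v_i + (1 - b2) * mx.square(g), state["v"], grads)

    # Bias-correction factors, as in the original Adam paper.
    c1 = 1.0 - b1 ** step
    c2 = 1.0 - b2 ** step

    def update(p, m_i, v_i):
        m_hat = m_i / c1
        v_hat = v_i / c2
        # Decoupled weight decay applied directly to the parameter (AdamW).
        return p - lr * (m_hat / (mx.sqrt(v_hat) + eps) + weight_decay * p)

    new_params = tree_map(update, params, m, v)
    return new_params, {"m": m, "v": v}
```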
Feature suggestion: an LR scheduler with an option to step either per batch or per epoch would be a great feature to evaluate.
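For context, a schedule passed as `learning_rate` in MLX is evaluated on every optimizer update (i.e. per batch); a per-epoch behaviour can be emulated today by setting the learning rate manually, roughly as in this sketch (names such as `num_epochs` and `steps_per_epoch` are illustrative):

```python
# Sketch only: per-batch vs. emulated per-epoch LR scheduling with MLX's
# built-in cosine_decay schedule. num_epochs / steps_per_epoch are made up.
import mlx.optimizers as optim

num_epochs, steps_per_epoch = 50, 200

# Per-batch stepping: a schedule passed as learning_rate advances on every
# optimizer.update() call.
per_batch = optim.AdamW(
    learning_rate=optim.cosine_decay(1e-3, num_epochs * steps_per_epoch)
)

# Per-epoch stepping: keep the LR fixed inside an epoch and set it by hand.
epoch_schedule = optim.cosine_decay(1e-3, num_epochs)
per_epoch = optim.AdamW(learning_rate=1e-3)
for epoch in range(num_epochs):
    per_epoch.learning_rate = epoch_schedule(epoch)
    # ... run the epoch's batches with per_epoch.update(model, grads) ...
```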
Here is a view of 10 runs using the same initialization (loading identical saved weights for each run) and the same data split (mv = validation, mt = test):

[plot of validation (mv) and test (mt) RMSE over the 10 runs]
We can see the actual deviation of the RMSE across the runs.
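For completeness, a minimal sketch of that setup (the file name and seeding are assumptions): every run reloads the same saved initial weights, so only the run-to-run training stochasticity differs.

```python
# Sketch only: each run starts from the identical saved initialization
# ("initial_weights.safetensors" is a made-up file name) and the same data split.
import mlx.core as mx
import mlx.nn as nn

def run_once(model: nn.Module, run_id: int):
    model.load_weights("initial_weights.safetensors")  # identical init for every run
    mx.random.seed(run_id)  # only the run-to-run stochasticity varies
    # ... train on the fixed split, then record validation (mv) and test (mt) RMSE ...
```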
Desktop (please complete the following information):
- OS Version: [e.g. MacOS 15.1.1]
- Version: [e.g. 0.21.0]