In [1]:
!pip show mlx

Name: mlx
Version: 0.0.4.dev20231211+b0cd092
Summary: A framework for machine learning on Apple Silicon.
Home-page: 
Author: MLX Contributors
Author-email: mlx@group.apple.com
License: 
Location: /Users/id4thomas/miniforge3/envs/mlx/lib/python3.8/site-packages
Requires: 
Required-by: 


In [2]:
import mlx
import mlx.nn as nn
import mlx.core as mx

In [3]:
## Seed MLX's global random state (via the `mx` alias for mlx.core)
mx.random.seed(42)

In [4]:
## Device
# Ask MLX which device it uses by default and show it.
device = mx.default_device()
print(device)
# Dump the attributes defined on the Device class itself.
print(vars(mx.Device))


Device(gpu, 0)
{'__init__': <instancemethod __init__ at 0x1082c7b50>, '__doc__': None, '__module__': 'mlx.core', 'type': <property object at 0x1082d3a40>, '__repr__': <instancemethod __repr__ at 0x1082c7c10>, '__eq__': <instancemethod __eq__ at 0x1082c7c70>, '__hash__': None}


In [5]:
## dtypes
# Look each dtype up on mlx.core by name. getattr() is the idiomatic way to do
# a dynamic attribute lookup; calling the __getattribute__ dunder directly is not.
for dtype_name in ["int8", "int16", "int32", "int64", "float16", "uint8", "uint16", "uint32"]:
	dtype = getattr(mx, dtype_name)
	print(type(dtype), dtype)

print("Getting dtypes")
# Bare last expression so the notebook displays the dtype of an array built
# from Python float literals.
mx.array([1.0, 0.1, 0.2]).dtype

<class 'mlx.core.Dtype'> mlx.core.int8
<class 'mlx.core.Dtype'> mlx.core.int16
<class 'mlx.core.Dtype'> mlx.core.int32
<class 'mlx.core.Dtype'> mlx.core.int64
<class 'mlx.core.Dtype'> mlx.core.float16
<class 'mlx.core.Dtype'> mlx.core.uint8
<class 'mlx.core.Dtype'> mlx.core.uint16
<class 'mlx.core.Dtype'> mlx.core.uint32
Getting dtypes


mlx.core.float32

In [6]:
## Sample Array
# A single row holding the integers from 0 up to (not including) 32.
arr = mx.expand_dims(mx.arange(0, 32), axis=0)
print(arr.shape, arr.size)

## Broadcast
# Repeat that row twice via broadcasting, then double every element.
arr2 = mx.broadcast_to(arr, (2, arr.shape[-1])) * 2
print(arr2.shape, arr2.size)
print(arr2)

[1, 32] 32
[2, 32] 64
array([[0, 2, 4, ..., 58, 60, 62],
       [0, 2, 4, ..., 58, 60, 62]], dtype=int32)


In [7]:
## Value and Grad
# https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.value_and_grad.html

def sample_fn(inputs, targets):
	"""Mean of the squared difference between `targets` and `inputs`.

	`inputs` is first broadcast to two rows (keeping its last-axis length)
	before being subtracted from `targets`.
	"""
	two_row_shape = (2, inputs.shape[-1])
	diff = targets - mx.broadcast_to(inputs, two_row_shape)
	return diff.square().mean()

## argnums -> indices of the arguments to compute gradients with respect to
grad_fn = mx.grad(sample_fn, argnums=[0, 1])

## Gradients with the integer arrays passed as-is
print("CALCULATING AS INT")
grad1, grad2 = grad_fn(arr, arr2)
print(grad1, grad2, sep="\n")
print("=" * 30)

## Gradients after casting both arrays to float32
print("CALCULATING AS FLOAT")
grad1, grad2 = grad_fn(arr.astype(mx.float32), arr2.astype(mx.float32))
print(grad1, grad2, sep="\n")
print("=" * 30)

CALCULATING AS INT
array([[0, 0, 0, ..., 0, 0, 0]], dtype=int32)
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)
CALCULATING AS FLOAT
array([[0, -0.0625, -0.125, ..., -1.8125, -1.875, -1.9375]], dtype=float32)
array([[0, 0.03125, 0.0625, ..., 0.90625, 0.9375, 0.96875],
       [0, 0.03125, 0.0625, ..., 0.90625, 0.9375, 0.96875]], dtype=float32)


In [9]:
## loss functions
# defined in https://github.com/ml-explore/mlx/blob/main/python/mlx/nn/losses.py
print(nn.losses.cross_entropy)
print(nn.losses.l1_loss)
# Was a second l1_loss print; the recorded output lists mse_loss in this slot.
print(nn.losses.mse_loss)
print(nn.losses.nll_loss)
print(nn.losses.kl_div_loss)

<function cross_entropy at 0x10815e8b0>
<function l1_loss at 0x1082ae9d0>
<function mse_loss at 0x1082aeca0>
<function nll_loss at 0x1082aedc0>
<function kl_div_loss at 0x1082aee50>


In [16]:
# https://zhang-yang.medium.com/how-is-pytorchs-binary-cross-entropy-with-logits-function-related-to-sigmoid-and-d3bd8fb080e7
def bce_with_logits(x, y):
	"""Binary cross-entropy computed directly from logits, averaged over all elements.

	Args:
		x: Raw (pre-sigmoid) logits.
		y: Targets with the same shape as `x`.

	Returns:
		The mean of -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))) as a scalar array.
	"""
	# The per-element loss is algebraically softplus(x) - x*y, with
	# softplus(x) = log(1 + exp(x)) = logaddexp(0, x). Computing it this way
	# never evaluates log(0), so saturated logits no longer yield inf/nan
	# (e.g. 0 * log(0) once sigmoid(x) rounds to exactly 0 or 1).
	return (mx.logaddexp(0.0, x) - x * y).mean()

# Logits: 10 rows x 4 columns.
x = mx.array([
	[ 2.3611, -0.8813, -0.5006, -0.2178],
	[ 0.0419,  0.0763, -1.0457, -1.6692],
	[-1.0494,  0.8111,  1.5723,  1.2315],
	[ 1.3081,  0.6641,  1.1802, -0.2547],
	[ 0.5292,  0.7636,  0.3692, -0.8318],
	[ 0.5100,  0.9849, -1.2905,  0.2821],
	[ 1.4662,  0.4550,  0.9875,  0.3143],
	[-1.2121,  0.1262,  0.0598, -1.6363],
	[ 0.3214, -0.8689,  0.0689, -2.5094],
	[ 1.1320, -0.6824,  0.1657, -0.0687],
])
# Targets: exactly one 1.0 per row, same layout as the logits.
y = mx.array([
	[0.0, 1.0, 0.0, 0.0],
	[0.0, 1.0, 0.0, 0.0],
	[0.0, 0.0, 0.0, 1.0],
	[1.0, 0.0, 0.0, 0.0],
	[0.0, 0.0, 1.0, 0.0],
	[1.0, 0.0, 0.0, 0.0],
	[0.0, 0.0, 1.0, 0.0],
	[0.0, 0.0, 1.0, 0.0],
	[0.0, 1.0, 0.0, 0.0],
	[0.0, 0.0, 1.0, 0.0],
])

print(bce_with_logits(x, y))
# Reference value to compare against (presumably PyTorch's result for the
# same inputs, per the article linked above): tensor(0.7739)

array(0.773853, dtype=float32)
