In [2]:
from dataclasses import dataclass, asdict

@dataclass
class TrainerConfig:
    bsz: int = 16
    lr: float = 1e-4
    n_steps: int = 1000
    pad_token_id: int = 65535  # Max value of uint16

cfg_t = TrainerConfig()
cfg_m = LLaMAConfig()

In [3]:
from dataset import config_dataloader
load_data = config_dataloader(cfg_t.bsz, cfg_m.seq_len, cfg_t.pad_token_id)

In [4]:
import mlx.core as mx

mx.set_default_device(mx.gpu)

data_iter = iter(load_data())
inputs_, targets_ = next(data_iter)
while mx.all(inputs_ != cfg_t.pad_token_id):
    targets_ = next(data_iter)

In [8]:
from llama import LLaMAConfig, LLaMA
model = LLaMA(**asdict(cfg_m))

In [13]:
from mlx import nn

def forward(model, inputs, targets):
    pad_mask = (inputs != cfg_t.pad_token_id)
    logits = model(inputs * pad_mask)

    logprobs = nn.losses.cross_entropy(logits, targets)
    logprobs_m = logprobs * pad_mask
    loss = logprobs_m.sum() / pad_mask.sum()

    return loss

In [14]:
loss_and_grad = nn.value_and_grad(model, forward)
loss, grad = loss_and_grad(model, inputs_, targets_)
loss

array(10.5277, dtype=float32)

In [8]:
from functools import partial
import mlx.optimizers as optim

optimizer = optim.AdamW(learning_rate=cfg_t.lr)
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(inputs, targets):
    loss_and_grad = nn.value_and_grad(model, forward)
    loss, grads = loss_and_grad(model, inputs, targets)
    optimizer.update(model, grads)
    return loss

loss = train_step(inputs_, targets_)
loss

array(61.2031, dtype=float32)

In [9]:
train_step(inputs_, targets_)

array(55.4943, dtype=float32)

In [1]:
from dataclasses import dataclass, asdict
from mlx import nn
import mlx.core as mx
import mlx.optimizers as optim
from dataset import config_dataloader
from llama import LLaMAConfig, LLaMA

@dataclass
class TrainerConfig:
    bsz: int = 16
    lr: float = 1e-4
    n_steps: int = 1000
    pad_token_id: int = 65535  # Max value of uint16

cfg_t = TrainerConfig()
cfg_m = LLaMAConfig()

model = LLaMA(**asdict(cfg_m))
optimizer = optim.AdamW(learning_rate=cfg_t.lr)
load_data = config_dataloader(cfg_t.bsz, cfg_m.seq_len, cfg_t.pad_token_id)

inputs_BT, targets_BT = next(load_data(cfg_t.n_steps))

# One train forward pass: inputs_BT, cfg_t.pad_token_id, model, targets_BT
# pad_mask_BT = (inputs_BT != cfg_t.pad_token_id)
# logits_BTV = model(inputs_BT * pad_mask_BT)
# logprobs_BT = nn.losses.cross_entropy(logits_BTV, targets_BT)
# loss = (logprobs_BT * pad_mask_BT).sum() / pad_mask_BT.sum()

In [6]:
class Dummy:
    def __init__(self, val):
        self.val = val
    def __add__(self, rhs):
        return Dummy(self.val + rhs.val)
    def __repr__(self):
        return f'Dummy({self.val})'

def test_scope():
    def inner(x):
        return x + y
    y = Dummy(39)
    x_ = Dummy(85)
    print(inner(x_))
    y.val = -85
    print(inner(x_))
test_scope()

Dummy(124)
Dummy(0)


In [2]:
def train_forward_pass(model, inputs_BT, targets_BT):
    pad_mask_BT = (inputs_BT != cfg_t.pad_token_id)
    logits_BTV = model(inputs_BT * pad_mask_BT)
    logprobs_BT = nn.losses.cross_entropy(logits_BTV, targets_BT)
    loss = (logprobs_BT * pad_mask_BT).sum() / pad_mask_BT.sum()
    return loss

In [3]:
@mx.compile
def train_step(model, optimizer, inputs_BT, targets_BT):
    loss_and_grad = nn.value_and_grad(model, train_forward_pass)
    loss, grads = loss_and_grad(model, inputs_BT, targets_BT)
    optimizer.update(model, grads)
    return model.state, optimizer.state
model_state, opt_state = train_step(model, optimizer, inputs_BT, targets_BT)
model.update(model_state)
optimizer.update(opt_state)

ValueError: [compile] Function arguments must be trees of arrays or constants (floats, ints, or strings), but received type mlx.optimizers.optimizers.AdamW.

In [9]:
optimizer.state

{'step': array(0, dtype=uint64), 'learning_rate': array(0.0001, dtype=float32)}