In [1]:
from dataset import config_dataloader
from llama import LLaMAConfig, LLaMA

In [2]:
from dataclasses import dataclass, asdict

@dataclass
class TrainerConfig:
    bsz: int = 16
    lr: float = 1e-4
    n_steps: int = 1000
    pad_token_id: int = 65535  # Max value of uint16

cfg_t = TrainerConfig()
cfg_m = LLaMAConfig()

In [3]:
load_data = config_dataloader(cfg_t.bsz, cfg_m.seq_len, cfg_t.pad_token_id)

In [4]:
import mlx.core as mx

mx.set_default_device(mx.gpu)

data_iter = iter(load_data())
inputs_, targets_ = next(data_iter)
while mx.all(inputs_ != cfg_t.pad_token_id):
    targets_ = next(data_iter)

In [8]:
model = LLaMA(**asdict(cfg_m))

In [13]:
from mlx import nn

def forward(model, inputs, targets):
    pad_mask = (inputs != cfg_t.pad_token_id)
    logits = model(inputs * pad_mask)

    logprobs = nn.losses.cross_entropy(logits, targets)
    logprobs_m = logprobs * pad_mask
    loss = logprobs_m.sum() / pad_mask.sum()

    return loss

In [14]:
loss_and_grad = nn.value_and_grad(model, forward)
loss, grad = loss_and_grad(model, inputs_, targets_)
loss

array(10.5277, dtype=float32)

In [8]:
from functools import partial
import mlx.optimizers as optim

optimizer = optim.AdamW(learning_rate=cfg_t.lr)
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(inputs, targets):
    loss_and_grad = nn.value_and_grad(model, forward)
    loss, grads = loss_and_grad(model, inputs, targets)
    optimizer.update(model, grads)
    return loss

loss = train_step(inputs_, targets_)
loss

array(61.2031, dtype=float32)

In [9]:
train_step(inputs_, targets_)

array(55.4943, dtype=float32)