In [47]:
from copy import deepcopy
from datasets import load_dataset
import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load, generate
import mlx.optimizers as optim
from mlx.utils import tree_map, tree_flatten

In [48]:
# 128k vocab size
teacher, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [55]:
ctx_len = 256

def prepare(x):
    x = tokenizer.encode(x["text"][0])
    cutoff = (len(x) // ctx_len) * ctx_len
    x = x[:cutoff]
    return { "ids": [x[i:i+ctx_len] for i in range(0, len(x), ctx_len)] }

ds = load_dataset("flpelerin/tinystories-100k", split="train")
ds = ds.map(prepare, batched=True, batch_size=1, remove_columns=ds.column_names)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [56]:
print(teacher.args)

ModelArgs(model_type='llama', hidden_size=2048, num_hidden_layers=16, intermediate_size=8192, num_attention_heads=32, rms_norm_eps=1e-05, vocab_size=128256, head_dim=64, max_position_embeddings=131072, num_key_value_heads=8, attention_bias=False, mlp_bias=False, rope_theta=500000.0, rope_traditional=False, rope_scaling={'factor': 32.0, 'high_freq_factor': 4.0, 'low_freq_factor': 1.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}, tie_word_embeddings=True)


In [72]:
args = deepcopy(teacher.args)
args.num_key_value_heads = 3
args.num_attention_heads = 9
args.head_dim = 64
args.hidden_size = 576
args.intermediate_size = 1536
args.num_hidden_layers = 2

model = type(teacher)(args)
mx.eval(model.parameters())
print("n_params", sum(v.size for _, v in tree_flatten(model.parameters())))

n_params 80956224


In [73]:
def loss(model, x, y):
    teacher_softmax = nn.softmax(y)
    teacher_log_softmax = nn.log_softmax(y)
    student_log_softmax = nn.log_softmax(model(x))
    return mx.mean(mx.sum(teacher_softmax * (teacher_log_softmax - student_log_softmax), axis=-1))

optimizer = optim.AdamW(learning_rate=0.0001)
step = nn.value_and_grad(model, loss)

state = [model.state, optimizer.state]

In [74]:
B=8
for i in range(0, 100):
    x = mx.array(ds[i*B:i*B+B]["ids"])
    y = teacher(x)

    loss, grads = step(model, x, y)
    optimizer.update(model, grads)
    mx.eval(state)

    print(f"iter: {i}, loss: {loss}, lr: {optimizer.learning_rate:.5f}")
    optimizer.learning_rate -= 0.000001


iter: 0, loss: 10.559500694274902, lr: 0.00010
iter: 1, loss: 9.183050155639648, lr: 0.00010
iter: 2, loss: 8.637887001037598, lr: 0.00010
iter: 3, loss: 8.360408782958984, lr: 0.00010
iter: 4, loss: 7.685101509094238, lr: 0.00010
iter: 5, loss: 7.479364395141602, lr: 0.00009
iter: 6, loss: 7.132399559020996, lr: 0.00009
iter: 7, loss: 6.698503494262695, lr: 0.00009
iter: 8, loss: 6.323685646057129, lr: 0.00009
iter: 9, loss: 6.087820053100586, lr: 0.00009
iter: 10, loss: 5.876641273498535, lr: 0.00009
iter: 11, loss: 5.553765773773193, lr: 0.00009
iter: 12, loss: 5.357820510864258, lr: 0.00009
iter: 13, loss: 5.042308330535889, lr: 0.00009
iter: 14, loss: 4.88326358795166, lr: 0.00009
iter: 15, loss: 4.839049816131592, lr: 0.00008
iter: 16, loss: 4.7905144691467285, lr: 0.00008
iter: 17, loss: 4.737854957580566, lr: 0.00008
iter: 18, loss: 4.692557334899902, lr: 0.00008
iter: 19, loss: 4.5652875900268555, lr: 0.00008
iter: 20, loss: 4.624118804931641, lr: 0.00008
iter: 21, loss: 4.399

In [78]:
def generate(model, str, n):
    x = mx.array(tokenizer.encode(str))[None, :]
    for _ in range(n):
        y = model(x)
        x = mx.concatenate([x, mx.random.categorical(y[:, -1])[None, :]], axis=1)
    return tokenizer.decode(x[0].tolist())

print(generate(model, tokenizer.decode(ds[0]["ids"][:20]), 80))


<|begin_of_text|><|begin_of_text|>Once upon a time, in a big, green park, there was a small, messy dog. The proud blocks! her'shelter. They branch. She do walked were teacher nods, they said. atικοί. She saw Ther. We idea at it is that trapped eager, " 
 four brings that other bed to the king on the tree are.

S thought that a dog,led away that a prizes drink and, "I brown and looked.

",、お had Career and happy
