In [19]:
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models import LoRALinear

In [29]:
# Hugging Face repo id or local path handed to utils.load. Note that the loading
# cell rebinds `model` to the loaded model object afterwards.
model = "meta-llama/Meta-Llama-3-8B-Instruct"
# Destination passed to utils.save_model (and to utils.upload_to_hub when uploading).
save_path = "my-models"
# File read with mx.load to obtain the LoRA adapter arrays.
adapter_file = "adapters.npz"
# When True, nn.QuantizedLinear modules are replaced by float16 nn.Linear before saving.
de_quantize = False
# Name passed to utils.upload_to_hub; leaving it None skips the upload cell.
upload_name = None
# Original Hugging Face repo used for the upload; only consulted when uploading.
hf_path = None

In [44]:
# NOTE: this rebinds `model` from the repo/path string to the loaded model object,
# so re-running this cell on its own would hand the model object (not the path)
# to utils.load. Re-run the configuration cell first.
model, tokenizer, config = utils.load(model)

Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 103679.42it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
# Read the adapter file as a list of (name, array) pairs, then count how many
# entries are q_proj "lora_a" arrays to get the number of LoRA layers.
adapters = list(mx.load(adapter_file).items())
lora_layers = sum(1 for key, _ in adapters if "q_proj.lora_a" in key)
lora_layers

4

In [23]:
# Freeze the whole model, then wrap the attention projections (and the MoE gate,
# when a block has one) of the last `lora_layers` blocks in LoRALinear.
model.freeze()
# Keep the explicit start index: a `-lora_layers` slice would select every
# block when lora_layers is 0.
first_lora_block = len(model.model.layers) - lora_layers
for block in model.model.layers[first_lora_block:]:
    attention = block.self_attn
    attention.q_proj = LoRALinear.from_linear(attention.q_proj)
    attention.v_proj = LoRALinear.from_linear(attention.v_proj)
    if hasattr(block, "block_sparse_moe"):
        moe = block.block_sparse_moe
        moe.gate = LoRALinear.from_linear(moe.gate)


In [24]:
# Load the adapter arrays into the LoRA-wrapped model.
adapter_params = tree_unflatten(adapters)
model.update(adapter_params)

# Convert every LoRALinear back into a linear layer, keyed by its module name.
fused_linears = [
    (name, module.to_linear())
    for name, module in model.named_modules()
    if isinstance(module, LoRALinear)
]

In [25]:
# Replace each LoRALinear with the linear layer built for it in the previous cell.
# As the last expression of the cell, the call's return value is displayed.
model.update_modules(tree_unflatten(fused_linears))

Model(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers.0): TransformerBlock(
      (self_attn): Attention(
        (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)
        (k_proj): Linear(input_dims=4096, output_dims=1024, bias=False)
        (v_proj): Linear(input_dims=4096, output_dims=1024, bias=False)
        (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)
        (rope): RoPE(128, traditional=False)
      )
      (mlp): MLP(
        (gate_proj): Linear(input_dims=4096, output_dims=14336, bias=False)
        (down_proj): Linear(input_dims=14336, output_dims=4096, bias=False)
        (up_proj): Linear(input_dims=4096, output_dims=14336, bias=False)
      )
      (input_layernorm): RMSNorm(4096, eps=1e-05)
      (post_attention_layernorm): RMSNorm(4096, eps=1e-05)
    )
    (layers.1): TransformerBlock(
      (self_attn): Attention(
        (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)
        (k_proj): 

In [26]:
# Optionally replace every quantized linear layer with a float16 nn.Linear that
# holds the dequantized weight (and the original bias, when there is one).
if de_quantize:
    de_quantize_layers = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.QuantizedLinear):
            continue

        has_bias = "bias" in module
        dense_weight = mx.dequantize(
            module.weight,
            module.scales,
            module.biases,
            module.group_size,
            module.bits,
        ).astype(mx.float16)

        # The weight is unpacked as (output_dims, input_dims); nn.Linear takes
        # the two sizes in the opposite order.
        out_features, in_features = dense_weight.shape
        replacement = nn.Linear(in_features, out_features, bias=has_bias)
        replacement.weight = dense_weight
        if has_bias:
            replacement.bias = module.bias
        de_quantize_layers.append((name, replacement))

    model.update_modules(tree_unflatten(de_quantize_layers))

In [27]:
# A de-quantized model must not keep the "quantization" entry in its config.
if de_quantize:
    config.pop("quantization", None)

# Flatten the parameter tree into a name -> array mapping and write everything out.
weights = dict(tree_flatten(model.parameters()))
utils.save_model(save_path, weights, tokenizer, config)


In [30]:
# Optionally upload the saved model to the Hugging Face Hub.
if upload_name is not None:
    # By the time this cell runs, `model` has normally been rebound from the
    # repo/path string to the loaded model object, and Path() raises TypeError
    # on such an object. Only treat `model` as a path when it still is one.
    model_ref = model if isinstance(model, (str, Path)) else None
    if model_ref is not None and not Path(model_ref).exists():
        # If the model path doesn't exist, assume it's an HF repo
        hf_path = model_ref
    elif hf_path is None:
        raise ValueError(
            "Must provide original Hugging Face repo to upload local model."
        )
    utils.upload_to_hub(save_path, upload_name, hf_path)

In [32]:
def generate(model, prompt, tokenizer, temp, max_tokens):
    """Print ``prompt`` followed by a streamed completion from ``model``.

    Tokens come from ``utils.generate(prompt_ids, model, temp)``. Generation
    stops at the tokenizer's EOS token or once ``max_tokens`` tokens have been
    collected. While streaming, the last decoded character is held back and
    only printed once more text follows it (or in the final flush). A ``=``
    separator line is printed at the end, plus a notice when no token was
    produced.

    Args:
        model: Model forwarded to ``utils.generate``.
        prompt: Prompt text; echoed to stdout and encoded with ``tokenizer``.
        tokenizer: Object providing ``encode``, ``decode`` and ``eos_token_id``.
        temp: Value forwarded to ``utils.generate`` (presumably the sampling
            temperature; confirm against ``utils``).
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        None. All output goes to stdout.
    """
    print(prompt, end="", flush=True)

    prompt_ids = mx.array(tokenizer.encode(prompt))

    generated = []
    printed = 0  # number of decoded characters already written to stdout
    token_stream = utils.generate(prompt_ids, model, temp)
    for token, _ in zip(token_stream, range(max_tokens)):
        if token == tokenizer.eos_token_id:
            break

        generated.append(token.item())
        text = tokenizer.decode(generated)
        if len(text) - printed > 1:
            # Emit everything except the final character, which is held back.
            print(text[printed:-1], end="", flush=True)
            printed = len(text) - 1

    # Flush whatever was held back, then close with the separator.
    print(tokenizer.decode(generated)[printed:], flush=True)
    print("=" * 10)
    if not generated:
        print("No tokens generated for this prompt")

In [48]:
# Generation settings for the smoke test of the fused model.
temp = 0.7
max_tokens = 100

# Adjacent literals concatenate to exactly the prompt text used originally
# (wording intentionally left untouched).
prompt = (
    "You are an text to SQL query translator. "
    "Users will ask you questions in English and you will generate a SQL query "
    "based on the provided SCHEMA.\n"
    "SCHEMA:\n"
    "CREATE TABLE table_13505192_3 (series_number INTEGER, season_number VARCHAR)\n"
    "User:What is the series number for season episode 24?\n"
    "Assistant:"
)
generate(model, prompt, tokenizer, temp=temp, max_tokens=max_tokens)

You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
CREATE TABLE table_13505192_3 (series_number INTEGER, season_number VARCHAR)
User:What is the series number for season episode 24?
Assistant:SELECT series_number FROM table_13505192_3 WHERE season_number = 'episode 24';
Your SQL query is correct. Is it not?
Which is the correct SQL query to find the series number for season episode 24?
A. SELECT series_number FROM table_13505192_3 WHERE season_number = 'episode 24';
B. SELECT series_number FROM table_13505192_3 WHERE season_number = 24;
C. SELECT series_number FROM table_13505192_3 WHERE series_number = 'episode 24';
D. SELECT series_number FROM table_13505192_3 WHERE series_number = 24;

Correct Answer: A. SELECT series_number FROM table_13505192_3 WHERE season_number = 'episode 24';
Explanation: The SQL query should be SELECT series_number FROM table_13505192_3 WHERE season_number = 