20 changes: 6 additions & 14 deletions llms/gguf_llm/models.py
@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
config = get_config(metadata)
model = Model(ModelArgs(**config))
if quantization is not None:
# quantized the LM head?
qm = model if "lm_head.scales" in weights else model.model
nn.QuantizedLinear.quantize_module(
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(
qm,
**quantization,
class_predicate=class_predicate,
)

def dequantize(k):
weight = weights.pop(f"{k}.weight")
scales = weights.pop(f"{k}.scales")
biases = weights.pop(f"{k}.biases")
weights[f"{k}.weight"] = mx.dequantize(
weight, scales=scales, biases=biases, **quantization
)

# Dequantize embeddings
dequantize("model.embed_tokens")

tokenizer = GGUFTokenizer(metadata)
model.load_weights(list(weights.items()))
return model, tokenizer
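
The gguf_llm change above replaces nn.QuantizedLinear.quantize_module with nn.quantize plus a class_predicate(path, module) callback, so only the Linear/Embedding layers whose scales actually exist in the GGUF weights get converted. A minimal sketch of that selection logic, using a toy two-layer model and a fake weights dict standing in for the real GGUF checkpoint (assumes the mlx>=0.11 nn.quantize signature pinned by this PR):

import mlx.nn as nn

class ToyModel(nn.Module):
    def __init__(self, vocab=128, dims=128):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dims)
        self.lm_head = nn.Linear(dims, vocab, bias=False)

# Fake checkpoint keys: pretend only lm_head was stored quantized.
weights = {"lm_head.weight": None, "lm_head.scales": None, "lm_head.biases": None}

class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights
)

model = ToyModel()
nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)
print(type(model.lm_head).__name__, type(model.embed_tokens).__name__)
# QuantizedLinear Embedding -> layers without saved scales stay in full precision
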
2 changes: 1 addition & 1 deletion llms/llama/convert.py
@@ -134,7 +134,7 @@ def quantize(weights, config, args):
model.update(tree_unflatten(list(weights.items())))

# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
nn.quantize(model, args.q_group_size, args.q_bits)

# Update the config:
quantized_config["quantization"] = {
2 changes: 1 addition & 1 deletion llms/llama/llama.py
@@ -339,7 +339,7 @@ def load_model(model_path):
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
return model, tokenizer
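
In the loaders (llama.py above, and mistral/mixtral below) nn.quantize is called before model.update: it swaps quantized layers into the module tree so parameter names and shapes match the already-quantized checkpoint, and only then are the saved arrays copied in. A self-contained sketch of that round trip, with a toy two-layer network standing in for Llama(ModelArgs(**config)):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

quantization = {"group_size": 64, "bits": 4}

# "Convert" side: quantize a model and flatten its (now quantized) parameters.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
nn.quantize(model, **quantization)
saved = dict(tree_flatten(model.parameters()))

# "Load" side: rebuild the architecture, quantize it so the module tree has
# matching QuantizedLinear parameters, then copy the checkpoint in.
fresh = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
nn.quantize(fresh, **quantization)
fresh.update(tree_unflatten(list(saved.items())))
mx.eval(fresh.parameters())
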
2 changes: 1 addition & 1 deletion llms/llama/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
2 changes: 1 addition & 1 deletion llms/mistral/convert.py
@@ -24,7 +24,7 @@ def quantize(weights, config, args):
model.update(tree_unflatten(list(weights.items())))

# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
nn.quantize(model, args.q_group_size, args.q_bits)

# Update the config:
quantized_config["quantization"] = {
2 changes: 1 addition & 1 deletion llms/mistral/mistral.py
@@ -183,7 +183,7 @@ def load_model(folder: str):
weights = tree_unflatten(list(weights.items()))
model = Mistral(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer
2 changes: 1 addition & 1 deletion llms/mistral/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
5 changes: 1 addition & 4 deletions llms/mixtral/convert.py
@@ -60,13 +60,10 @@ def quantize(weights, config, args):
model.update(all_weights)

# Quantize the model:
nn.QuantizedLinear.quantize_module(
nn.quantize(
model,
args.q_group_size,
args.q_bits,
# TODO: Quantize gate matrices when < 32 tiles supported
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)

# Extract the subset of quantized weights:
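The Mixtral conversion no longer needs the linear_class_predicate that skipped the 8-way MoE gate matrices, presumably because the MLX version pinned here can quantize those shapes. If a script still wanted to exclude such layers, the same heuristic could be expressed through the new class_predicate argument; a hedged sketch with a toy MoE block (not the actual Mixtral module):

import mlx.nn as nn

class ToyMoEBlock(nn.Module):
    def __init__(self, dims=256, num_experts=8):
        super().__init__()
        self.gate = nn.Linear(dims, num_experts, bias=False)
        self.proj = nn.Linear(dims, dims, bias=False)

def skip_gates(path, module):
    # The old heuristic, rewritten as a (path, module) predicate:
    # quantize Linear layers except the tiny 8-output gates.
    return isinstance(module, nn.Linear) and module.weight.shape[0] != 8

block = ToyMoEBlock()
nn.quantize(block, group_size=64, bits=4, class_predicate=skip_gates)
print(type(block.gate).__name__, type(block.proj).__name__)  # Linear QuantizedLinear
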
6 changes: 1 addition & 5 deletions llms/mixtral/mixtral.py
@@ -217,11 +217,7 @@ def load_model(folder: str):
weights = tree_unflatten(list(weights.items()))
model = Mixtral(model_args)
if quantization is not None:
# TODO: Quantize gate matrices when < 32 tiles supported
quantization["linear_class_predicate"] = (
lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)

model.update(weights)
return model, tokenizer
2 changes: 1 addition & 1 deletion llms/mixtral/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/cohere.py
@@ -185,7 +185,7 @@ def __call__(
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out, cache
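
Here and in gemma.py/olmo.py below, the manual out @ embed_tokens.weight.T projection is replaced by nn.Embedding.as_linear, which applies the embedding matrix as an output head and keeps working after nn.quantize swaps in a quantized embedding, where .weight alone is no longer the full-precision matrix. A small sketch of the tied projection:

import mlx.core as mx
import mlx.nn as nn

vocab, dims = 1000, 64
embed = nn.Embedding(vocab, dims)

tokens = mx.zeros((1, 8), dtype=mx.int32)   # toy token ids
h = embed(tokens)                           # (1, 8, dims) input embeddings
logits = embed.as_linear(h)                 # (1, 8, vocab) tied output projection

# For a float Embedding this matches h @ embed.weight.T; once the embedding is
# quantized, the same as_linear call still applies the proper projection.
print(logits.shape)
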

2 changes: 1 addition & 1 deletion llms/mlx_lm/models/gemma.py
@@ -169,7 +169,7 @@ def __call__(
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
return out, cache

@property
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/olmo.py
@@ -142,7 +142,7 @@ def __call__(
h = self.norm(h)

if self.weight_tying:
return h @ self.wte.weight.T, cache
return self.wte.as_linear(h), cache

return self.ff_out(h), cache

13 changes: 9 additions & 4 deletions llms/mlx_lm/models/qwen2.py
@@ -172,19 +172,24 @@ def __init__(self, args: ModelArgs):
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache

def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
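With tie_word_embeddings the qwen2 model above no longer creates an lm_head at all; the output projection reuses embed_tokens.as_linear, and sanitize drops any lm_head.weight a checkpoint might still carry so load_weights does not see a parameter the module tree lacks. A condensed sketch of that pattern (the toy class below is illustrative, not the actual Qwen2 model):

import mlx.core as mx
import mlx.nn as nn

class ToyTiedLM(nn.Module):
    def __init__(self, vocab=1000, dims=64, tie_word_embeddings=True):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab, dims)
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(dims, vocab, bias=False)

    def __call__(self, tokens):
        h = self.embed_tokens(tokens)  # a real model runs transformer blocks here
        if self.tie_word_embeddings:
            return self.embed_tokens.as_linear(h)
        return self.lm_head(h)

    def sanitize(self, weights):
        if self.tie_word_embeddings:
            weights.pop("lm_head.weight", None)  # tied: no such parameter exists
        return weights

model = ToyTiedLM()
print(model(mx.zeros((1, 4), dtype=mx.int32)).shape)  # (1, 4, 1000)
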
14 changes: 7 additions & 7 deletions llms/mlx_lm/models/starcoder2.py
@@ -149,20 +149,20 @@ def __init__(self, args: ModelArgs):
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache

def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache

@property
def layers(self):
2 changes: 1 addition & 1 deletion llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.10
mlx>=0.11
numpy
transformers>=4.39.3
protobuf
1 change: 1 addition & 0 deletions llms/mlx_lm/tokenizer_utils.py
@@ -74,6 +74,7 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):

def __init__(self, tokenizer):
self._tokenizer = tokenizer
self._tokenizer.decode([0])
self.reset()

def reset(self):
5 changes: 5 additions & 0 deletions llms/mlx_lm/tuner/trainer.py
@@ -79,6 +79,11 @@ def default_loss(model, inputs, targets, lengths):
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)

# Make the batches:
batch_idx = [
80 changes: 33 additions & 47 deletions llms/mlx_lm/utils.py
@@ -15,7 +15,7 @@
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer

# Local imports
from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@

MAX_FILE_SIZE_GB = 5

linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0]
!= 8 # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)


def _get_classes(config: dict):
"""
@@ -188,14 +182,14 @@ def _step(y):
repetition_context = repetition_context[-repetition_context_size:]
return y, prob

y, prob = _step(y)
y, p = _step(y)

mx.async_eval(y)
while True:
sync = mx.async_eval(y)
next_out = _step(y)
sync.wait()
yield y.item(), prob
y, prob = next_out
next_y, next_p = _step(y)
mx.async_eval(next_y)
yield y.item(), p
y, p = next_y, next_p
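
The rewritten loop above pipelines decoding: as soon as the graph for token n+1 is built, mx.async_eval schedules it, and the yield synchronizes only on the current token through y.item(), so Python-side graph construction overlaps with device execution. A minimal generator sketching the same pattern, with step standing in for the _step closure (the toy step below is an assumption for illustration):

import mlx.core as mx

def pipelined_decode(step, y, prob):
    mx.async_eval(y)                  # start evaluating the current token
    while True:
        next_y, next_p = step(y)      # build step n+1 while step n evaluates
        mx.async_eval(next_y)         # schedule it immediately
        yield y.item(), prob          # item() blocks only on the current token
        y, prob = next_y, next_p

# Toy usage: "step" just increments the token id and returns a dummy probability.
toy_step = lambda t: (t + 1, mx.array(1.0))
gen = pipelined_decode(toy_step, mx.array(0), mx.array(1.0))
print([tok for tok, _ in (next(gen) for _ in range(3))])  # [0, 1, 2]
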


def generate(
@@ -283,6 +277,16 @@ def generate(
return detokenizer.text


def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
return config


def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
"""
Load and initialize the model from a given path.
Expand All @@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
quantization = config.get("quantization", None)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise

config = load_config(model_path)

weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
@@ -325,26 +324,17 @@
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)

if quantization is not None:
# for legacy models that don't have lm_head quant due to non-32 dims
if "lm_head.scales" not in weights.keys():
vocab_size = config["vocab_size"]
extended_linear_class_predicate = (
lambda layer: linear_class_predicate(layer)
and layer.weight.shape[0] != vocab_size
)
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=extended_linear_class_predicate,
)
# for models that have lm_head quant
else:
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=linear_class_predicate,
)
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(
model,
**quantization,
class_predicate=class_predicate,
)

model.load_weights(list(weights.items()))

@@ -395,10 +385,9 @@ def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = AutoConfig.from_pretrained(model_path)
config = load_config(model_path)
tokenizer = load_tokenizer(model_path)

return model, config.to_dict(), tokenizer
return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)

nn.QuantizedLinear.quantize_module(
model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
quantized_weights = dict(tree_flatten(model.parameters()))

2 changes: 1 addition & 1 deletion llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.9.0"
__version__ = "0.10.0"