llama v2 with sharded weights

awni committed Dec 12, 2023 (commit f0c57c1, parent 9a02dce)

Showing 5 changed files with 190 additions and 124 deletions.
25 changes: 15 additions & 10 deletions llama/README.md
@@ -1,8 +1,9 @@
-# LLaMA
+# Llama
 
-An example of generating text with LLaMA using MLX.
+An example of generating text with Llama (1 or 2) using MLX.
 
-LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.
+Llama is a set of open source language models from Meta AI Research[^1][^2]
+ranging from 7B to 70B parameters.
 
 ### Setup
 
@@ -14,27 +15,31 @@ pip install -r requirements.txt
 
 Next, download and convert the model. If you do not have access to the model
 weights you will need to [request
-access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
+access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
 from Meta.
 
-Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step.
+Alternatively, you can download select converted checkpoints from the
+[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
+Face and skip the conversion step.
 
 Convert the weights with:
 
 ```
-python convert.py <path_to_torch_weights> <path_to_mlx_llama_weights.npz>
+python convert.py --model_path <path_to_torch_model>
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can interact with the
-LLaMA model:
+Llama model:
 
 ```
-python llama.py <path_to_mlx_llama_weights.npz> <path_to_tokenizer.model> "hello"
+python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
 ```
 
 Run `python llama.py --help` for more details.
 
-[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/).
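
For a concrete end-to-end run of the updated instructions, the two steps might look as follows. The directory name `Llama-2-7b` and its layout are hypothetical; substitute the path to your own downloaded checkpoint. Per the new `convert.py`, the converted `weights.npz` is written next to the Torch files:

```
# Convert the (possibly sharded) Torch checkpoint; weights.npz is saved
# into the same directory.
python convert.py --model_path Llama-2-7b

# Generate text with the converted weights and tokenizer.
python llama.py Llama-2-7b/weights.npz Llama-2-7b/tokenizer.model "hello"
```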
78 changes: 42 additions & 36 deletions llama/convert.py
@@ -1,53 +1,59 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
-from itertools import starmap
+import collections
+import glob
+from pathlib import Path
 
 import numpy as np
 import torch
 
+SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
+
 
-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
-
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
-
-    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
-        key = key.replace("wq", "query_proj")
-        key = key.replace("wk", "key_proj")
-        key = key.replace("wv", "value_proj")
-        key = key.replace("wo", "out_proj")
-
-    elif "w1" in key or "w2" in key or "w3" in key:
-        # The FFN is a separate submodule in PyTorch
-        key = key.replace("feed_forward.w1", "linear1")
-        key = key.replace("feed_forward.w3", "linear2")
-        key = key.replace("feed_forward.w2", "linear3")
-
-    elif "output" in key:
-        key = key.replace("output", "out_proj")
-
-    elif "rope" in key:
-        return None, None
-
-    return (
-        key,
-        value.numpy()
-        if value.dtype != torch.bfloat16
-        else value.to(torch.float32).numpy(),
-    )
+def shard_key(k):
+    keys = k.split(".")
+    if len(keys) < 2:
+        return None
+    return keys[-2]
+
+
+def unshard(k, v):
+    wn = shard_key(k)
+    if wn not in SHARD_WEIGHTS:
+        return v
+    elif wn in SHARD_FIRST:
+        axis = 0
+    elif wn in SHARD_SECOND:
+        axis = 1
+    else:
+        raise ValueError("Invalid weight name")
+    return np.concatenate(v, axis=axis)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser.add_argument(
+        "--model_path",
+        help="Path to the Torch model. The MLX weights will also be saved there.",
+    )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
-    np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
-    )
+    model_path = Path(args.model_path)
+    torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
+    weights = collections.defaultdict(list)
+    for wf in torch_files:
+        state = torch.load(wf, map_location=torch.device("cpu"))
+        for k, v in state.items():
+            v = v.to(torch.float16).numpy()
+            if shard_key(k) in SHARD_WEIGHTS:
+                weights[k].append(v)
+            else:
+                weights[k] = v
+
+    out_file = str(model_path / "weights.npz")
+    for k, v in weights.items():
+        weights[k] = unshard(k, v)
+    np.savez(out_file, **weights)
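
The sharding logic above is the crux of the change: weights that were split across model-parallel ranks are concatenated back along the correct axis before saving. A minimal sketch of that behavior, importing from the new `convert.py` and using made-up key names and shapes:

```python
import numpy as np

from convert import shard_key, unshard

# "wq" is the second-to-last dot-separated component of the key and is
# listed in SHARD_FIRST, so its shards concatenate along axis 0.
key = "layers.0.attention.wq.weight"
assert shard_key(key) == "wq"

# Two hypothetical shards from a 2-way model-parallel checkpoint.
shards = [np.ones((2048, 4096), np.float16), np.ones((2048, 4096), np.float16)]
print(unshard(key, shards).shape)  # (4096, 4096)

# "wo" is in SHARD_SECOND, so its shards join along axis 1 instead.
wo = [np.ones((4096, 2048), np.float16), np.ones((4096, 2048), np.float16)]
print(unshard("layers.0.attention.wo.weight", wo).shape)  # (4096, 4096)
```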