Export llama without llama #85

Merged
merged 1 commit on Jul 25, 2023
export_meta_llama_bin.py (153 changes: 88 additions & 65 deletions)
@@ -1,91 +1,114 @@
"""
This script exports the Llama 2 weights in llama2c.bin format.
"""
import sys
import struct
from pathlib import Path
import json

Place it into the root directory of:
https://github.com/facebookresearch/llama
import torch

And then run it similar to their other examples, via torchrun sadly:
torchrun --nproc_per_node 1 export_meta_llama_bin.py
"""
from model import precompute_freqs_cis

from llama import Llama

# -----------------------------------------------------------------------------
def export(self, filepath='model.bin'):
def export(p, state_dict, filepath='model.bin'):
"""export the model weights in fp32 into .bin file to be read from C"""

f = open(filepath, 'wb')
import struct
import numpy as np

def serialize(t):
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
b = struct.pack(f'{len(d)}f', *d)
f.write(b)
def serialize(key):
print(f"writing {key}...")
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
f.write(memoryview(t))
del state_dict[key]

# first write out the header
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
p = self.params
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, -p.vocab_size, p.max_seq_len)
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
p['vocab_size'] = 32000
p['max_seq_len'] = 2048

n_kv_heads = p.get('n_kv_heads') or p['n_heads']
header = struct.pack(
'iiiiiii',
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
n_kv_heads, -p['vocab_size'], p['max_seq_len']
)
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
# in the checkpoint and should be loaded.
f.write(header)

# next write out the embedding weights
print("writing tok_embeddings...")
serialize(self.tok_embeddings.weight)
serialize('tok_embeddings.weight')

# now all the layers
# attention weights
for i, layer in enumerate(self.layers):
print(f"writing attention_norm layer {i}...")
serialize(layer.attention_norm.weight)
for i, layer in enumerate(self.layers):
print(f"writing attention.wq layer {i}...")
serialize(layer.attention.wq.weight)
for i, layer in enumerate(self.layers):
print(f"writing attention.wk layer {i}...")
serialize(layer.attention.wk.weight)
for i, layer in enumerate(self.layers):
print(f"writing attention.wv layer {i}...")
serialize(layer.attention.wv.weight)
for i, layer in enumerate(self.layers):
print(f"writing attention.wo layer {i}...")
serialize(layer.attention.wo.weight)
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
# ffn weights
for i, layer in enumerate(self.layers):
print(f"writing ffn_norm layer {i}...")
serialize(layer.ffn_norm.weight)
for i, layer in enumerate(self.layers):
print(f"writing feed_forward.w1 layer {i}...")
serialize(layer.feed_forward.w1.weight)
for i, layer in enumerate(self.layers):
print(f"writing feed_forward.w2 layer {i}...")
serialize(layer.feed_forward.w2.weight)
for i, layer in enumerate(self.layers):
print(f"writing feed_forward.w3 layer {i}...")
serialize(layer.feed_forward.w3.weight)
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')

# final rmsnorm
print("writing final rmsnorm, classifier and freq_cis...")
serialize(self.norm.weight)
serialize('norm.weight')
# freqs_cis
serialize(self.freqs_cis.real[:p.max_seq_len])
serialize(self.freqs_cis.imag[:p.max_seq_len])
freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
serialize('freqs_cis.real')
serialize('freqs_cis.imag')

# finally write the output weights
serialize(self.output.weight)
serialize('output.weight')

# write to binary file
f.close()
print(f"wrote {filepath}")
# -----------------------------------------------------------------------------

# init Llama as normal
generator = Llama.build(
ckpt_dir="llama-2-7b",
tokenizer_path="tokenizer.model",
max_seq_len=4096,
max_batch_size=1,
)
export(generator.model, "llama2_7b.bin")


def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict


def load_and_export(model_path, output_path):
with open(model_path + 'params.json') as f:
params = json.load(f)
print(params)

model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = []
for i in model_paths:
print(f'Loading {i}')
models.append(torch.load(i, map_location='cpu'))

state_dict = concat_weights(models)
del models
export(params, state_dict, output_path)


if __name__ == '__main__':
if len(sys.argv) == 1:
print('[Llama model folder path] [output path]')
exit()

model_path = sys.argv[1]
output_path = sys.argv[2]
load_and_export(model_path, output_path)
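
With this change the exporter no longer needs torchrun or the llama package: it is a plain script that takes the Meta checkpoint folder (the directory holding params.json and the consolidated.*.pth shards) and an output path, e.g. python export_meta_llama_bin.py /path/to/llama-2-7b/ llama2_7b.bin. The paths here are placeholders; note the trailing slash, since the script builds model_path + 'params.json' by plain string concatenation. A quick way to sanity-check the result is to read back the header the script writes; the sketch below assumes the example output name above and mirrors the struct.pack('iiiiiii', ...) call in export().

# Minimal sketch: read back the header written by export() above.
# 'llama2_7b.bin' is an example path; point it at whatever output file you produced.
import struct

with open('llama2_7b.bin', 'rb') as f:
    header = f.read(struct.calcsize('iiiiiii'))  # seven native int32 values

(dim, hidden_dim, n_layers, n_heads,
 n_kv_heads, vocab_size, max_seq_len) = struct.unpack('iiiiiii', header)

# Per the NOTE in the script, a negative vocab_size signals that the classifier
# (output) weights are stored in the file; the actual vocabulary size is abs(vocab_size).
print(dim, hidden_dim, n_layers, n_heads, n_kv_heads, abs(vocab_size), max_seq_len)
assert vocab_size < 0  # this exporter always writes -p['vocab_size']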
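
The rewritten serialize() also changes how each tensor hits the disk: instead of unpacking every element through struct.pack(f'{len(d)}f', *d), it hands the raw buffer of a contiguous float32 numpy array to f.write(memoryview(t)). For a contiguous float32 tensor the two paths produce identical bytes in native byte order; a small self-contained sketch of that equivalence, using a made-up tensor:

# Sketch: old vs. new serialization paths emit the same bytes for a float32 tensor.
# The tensor is a stand-in for a real weight; both paths use native byte order.
import struct
import numpy as np
import torch

t = torch.randn(4, 3)

# old path: flatten, convert to float32 numpy, pack element by element
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
old_bytes = struct.pack(f'{len(d)}f', *d)

# new path: flatten, convert to float32 numpy, write the raw buffer
new_bytes = bytes(memoryview(t.contiguous().view(-1).type(torch.float32).numpy()))

assert old_bytes == new_bytes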
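
Finally, concat_weights() is what lets the script handle sharded checkpoints, where the larger models ship as several consolidated.*.pth files: 1-D tensors (the norms) are taken from the first shard, while 2-D weights are concatenated along dim 1 for tok_embeddings, attention.wo and feed_forward.w2, and along dim 0 for everything else. A toy illustration of that rule with made-up shard shapes:

# Toy sketch of the concat_weights() merge rule, using made-up shard shapes.
import torch

def concat_rule(name, tensors):
    # 1-D tensors (e.g. *_norm.weight) are replicated across shards: keep the first copy
    if len(tensors) == 1 or len(tensors[0].shape) == 1:
        return tensors[0]
    # these weights are sharded along their second dimension, the rest along the first
    is_axis_1 = (
        name.startswith('tok_embeddings.')
        or name.endswith('.attention.wo.weight')
        or name.endswith('.feed_forward.w2.weight')
    )
    return torch.cat(tensors, dim=1 if is_axis_1 else 0)

# two pretend shards of a wq weight, split along dim 0
wq = concat_rule('layers.0.attention.wq.weight', [torch.zeros(8, 16), torch.zeros(8, 16)])
print(wq.shape)  # torch.Size([16, 16])

# two pretend shards of a wo weight, split along dim 1
wo = concat_rule('layers.0.attention.wo.weight', [torch.zeros(16, 8), torch.zeros(16, 8)])
print(wo.shape)  # torch.Size([16, 16])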