Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Optional AVX2 Support, Cache Alignment, and Enhances Model Export Speed #94

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,23 @@ rundebug: run.c
runfast: run.c
gcc -Ofast -o run run.c -lm

# additionally compiles with AVX2 intrinsics enabled
.PHONY: runavx2
runavx2: run.c
gcc -Ofast -march=native -mavx2 -DLLAMAC_AVX2 -o run run.c -lm

# additionally compiles with OpenMP, allowing multithreaded runs
# make sure to also enable multiple threads when running, e.g.:
# OMP_NUM_THREADS=4 ./run out/model.bin
.PHONY: runomp
runomp: run.c
gcc -Ofast -fopenmp -march=native run.c -lm -o run

# additionally compiles with AVX2 intrinsics enabled
.PHONY: runompavx2
runompavx2: run.c
gcc -Ofast -fopenmp -march=native -mavx2 -DLLAMAC_AVX2 run.c -lm -o run

.PHONY: clean
clean:
rm -f run
178 changes: 95 additions & 83 deletions export_meta_llama_bin.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,126 @@
"""
This script exports the Llama 2 weights in llama2c.bin format.

Place it into the root directory of:
https://github.com/facebookresearch/llama

And then run:
python export_meta_llama_bin.py
"""
import sys
import struct
from pathlib import Path

import json
from pathlib import Path

import torch

from model import precompute_freqs_cis


def export(p, state_dict, filepath='model.bin'):
# -----------------------------------------------------------------------------
def export(checkpoint, params, filepath='model.bin'):
"""export the model weights in fp32 into .bin file to be read from C"""

f = open(filepath, 'wb')
import struct
import numpy as np

def serialize1(t):
d = t.detach().cpu().view(-1).numpy()
b = d.tobytes()
f.write(b)

def serialize(t):
d = t.detach().cpu().float().numpy()
b = d.flatten().tobytes()
f.write(b)

def serialize(key):
print(f"writing {key}...")
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
f.write(memoryview(t))
del state_dict[key]
def serialize_old(t):
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
b = struct.pack(f'{len(d)}f', *d)
f.write(b)

# first write out the header
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
hidden_dim = checkpoint['layers.0.feed_forward.w1.weight'].shape[0] #self.layers[0].feed_forward.w1.weight.shape[0]
p = params
p['max_seq_len'] = 4096
p['vocab_size'] = 32000
p['max_seq_len'] = 2048

n_kv_heads = p.get('n_kv_heads') or p['n_heads']
header = struct.pack(
'iiiiiii',
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
n_kv_heads, -p['vocab_size'], p['max_seq_len']
)
n_layers = p['n_layers']
n_kv_heads = p.get('n_kv_heads', p['n_heads'])
# header magic version integer added for two reasons
# 1) so that we can version the header
# 2) so that the struct maintains strict cache alignment
# which is necessary so that the weights that follow the header are also cache aligned
header_magic_version = 0x42000000
header = struct.pack('iiiiiiii', header_magic_version, p['dim'], hidden_dim, n_layers, p['n_heads'],
n_kv_heads, -p['vocab_size'], p['max_seq_len'])
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
# in the checkpoint and should be loaded.
f.write(header)

# next write out the embedding weights
print("writing tok_embeddings...")
serialize('tok_embeddings.weight')

serialize(checkpoint['tok_embeddings.weight'].type(torch.HalfTensor))
# now all the layers
# attention weights
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
for i in range(n_layers):
print(f"writing attention_norm layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.attention_norm.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing attention.wq layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.attention.wq.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing attention.wk layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.attention.wk.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing attention.wv layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.attention.wv.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing attention.wo layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.attention.wo.weight'].type(torch.HalfTensor))
# ffn weights
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
for i in range(n_layers):
print(f"writing ffn_norm layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.ffn_norm.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing feed_forward.w1 layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.feed_forward.w1.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing feed_forward.w2 layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.feed_forward.w2.weight'].type(torch.HalfTensor))
for i in range(n_layers):
print(f"writing feed_forward.w3 layer {i}...")
serialize(checkpoint['layers.'+str(i)+'.feed_forward.w3.weight'].type(torch.HalfTensor))


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

freqs_cis = precompute_freqs_cis(
p['dim'] // p['n_heads'], p['max_seq_len'] * 2
)

# final rmsnorm
serialize('norm.weight')
print("writing final rmsnorm, classifier and freq_cis...")
serialize(checkpoint['norm.weight'].type(torch.HalfTensor))
# freqs_cis
freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
serialize('freqs_cis.real')
serialize('freqs_cis.imag')

serialize(freqs_cis.real[:p['max_seq_len']].type(torch.HalfTensor))
serialize(freqs_cis.imag[:p['max_seq_len']].type(torch.HalfTensor))
# finally write the output weights
serialize('output.weight')
serialize(checkpoint['output.weight'].type(torch.HalfTensor))

# write to binary file
f.close()
print(f"wrote {filepath}")


def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict


def load_and_export(model_path, output_path):
with open(model_path + 'params.json') as f:
params = json.load(f)
print(params)

model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = []
for i in model_paths:
print(f'Loading {i}')
models.append(torch.load(i, map_location='cpu'))

state_dict = concat_weights(models)
del models
export(params, state_dict, output_path)

# -----------------------------------------------------------------------------

if __name__ == '__main__':
if len(sys.argv) == 1:
print('[Llama model folder path] [output path]')
exit()

model_path = sys.argv[1]
output_path = sys.argv[2]
load_and_export(model_path, output_path)
ckpt_dir = "llama-2-7b"
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[0]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())

export(checkpoint, params, "llama2_7b.bin")