From 4d1fa2f2c6581eb74c6ad01a64b961b41bc510c5 Mon Sep 17 00:00:00 2001
From: python273
Date: Wed, 26 Jul 2023 01:32:00 +0400
Subject: [PATCH] Export llama without llama

---
 export_meta_llama_bin.py | 153 ++++++++++++++++++++++-----------------
 1 file changed, 88 insertions(+), 65 deletions(-)

diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py
index e8d05d76..801077b5 100644
--- a/export_meta_llama_bin.py
+++ b/export_meta_llama_bin.py
@@ -1,91 +1,114 @@
 """
 This script exports the Llama 2 weights in llama2c.bin format.
+"""
+import sys
+import struct
+from pathlib import Path
+import json
 
-Place it into the root directory of:
-https://github.com/facebookresearch/llama
+import torch
 
-And then run it similar to their other examples, via torchrun sadly:
-torchrun --nproc_per_node 1 export_meta_llama_bin.py
-"""
+from model import precompute_freqs_cis
 
-from llama import Llama
 
-# -----------------------------------------------------------------------------
-def export(self, filepath='model.bin'):
+def export(p, state_dict, filepath='model.bin'):
     """export the model weights in fp32 into .bin file to be read from C"""
-    f = open(filepath, 'wb')
-    import struct
-    import numpy as np
-    def serialize(t):
-        d = t.detach().cpu().view(-1).numpy().astype(np.float32)
-        b = struct.pack(f'{len(d)}f', *d)
-        f.write(b)
+    f = open(filepath, 'wb')
+
+    def serialize(key):
+        print(f"writing {key}...")
+        t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
+        f.write(memoryview(t))
+        del state_dict[key]
 
     # first write out the header
-    hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-    p = self.params
-    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                         n_kv_heads, -p.vocab_size, p.max_seq_len)
+    hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
+    p['vocab_size'] = 32000
+    p['max_seq_len'] = 2048
+
+    n_kv_heads = p.get('n_kv_heads') or p['n_heads']
+    header = struct.pack(
+        'iiiiiii',
+        p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
+        n_kv_heads, -p['vocab_size'], p['max_seq_len']
+    )
     # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
     # in the checkpoint and should be loaded.
     f.write(header)
 
     # next write out the embedding weights
     print("writing tok_embeddings...")
-    serialize(self.tok_embeddings.weight)
-
+    serialize('tok_embeddings.weight')
+
     # now all the layers
     # attention weights
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention_norm layer {i}...")
-        serialize(layer.attention_norm.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wq layer {i}...")
-        serialize(layer.attention.wq.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wk layer {i}...")
-        serialize(layer.attention.wk.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wv layer {i}...")
-        serialize(layer.attention.wv.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wo layer {i}...")
-        serialize(layer.attention.wo.weight)
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
     # ffn weights
-    for i, layer in enumerate(self.layers):
-        print(f"writing ffn_norm layer {i}...")
-        serialize(layer.ffn_norm.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w1 layer {i}...")
-        serialize(layer.feed_forward.w1.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w2 layer {i}...")
-        serialize(layer.feed_forward.w2.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w3 layer {i}...")
-        serialize(layer.feed_forward.w3.weight)
+    for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
+
     # final rmsnorm
-    print("writing final rmsnorm, classifier and freq_cis...")
-    serialize(self.norm.weight)
+    serialize('norm.weight')
     # freqs_cis
-    serialize(self.freqs_cis.real[:p.max_seq_len])
-    serialize(self.freqs_cis.imag[:p.max_seq_len])
+    freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
+    state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
+    state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
+    serialize('freqs_cis.real')
+    serialize('freqs_cis.imag')
+
     # finally write the output weights
-    serialize(self.output.weight)
+    serialize('output.weight')
 
-    # write to binary file
     f.close()
     print(f"wrote {filepath}")
 
-# -----------------------------------------------------------------------------
-
-# init Llama as normal
-generator = Llama.build(
-    ckpt_dir="llama-2-7b",
-    tokenizer_path="tokenizer.model",
-    max_seq_len=4096,
-    max_batch_size=1,
-)
-export(generator.model, "llama2_7b.bin")
+
+
+def concat_weights(models):
+    state_dict = {}
+    for name in list(models[0]):
+        tensors = [model[name] for model in models]
+        if len(tensors) == 1 or len(tensors[0].shape) == 1:
+            state_dict[name] = tensors[0]
+            continue
+        is_axis_1 = (
+            name.startswith('tok_embeddings.')
+            or name.endswith('.attention.wo.weight')
+            or name.endswith('.feed_forward.w2.weight')
+        )
+        axis = 1 if is_axis_1 else 0
+        state_dict[name] = torch.cat(tensors, dim=axis)
+        for model in models:
+            del model[name]
+    return state_dict
+
+
+def load_and_export(model_path, output_path):
+    with open(Path(model_path) / 'params.json') as f:
+        params = json.load(f)
+    print(params)
+
+    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+    models = []
+    for i in model_paths:
+        print(f'Loading {i}')
+        models.append(torch.load(i, map_location='cpu'))
+
+    state_dict = concat_weights(models)
+    del models
+    export(params, state_dict, output_path)
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 3:
+        print('usage: export_meta_llama_bin.py [Llama model folder path] [output path]')
+        exit()
+
+    model_path = sys.argv[1]
+    output_path = sys.argv[2]
+    load_and_export(model_path, output_path)
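
Note: with this change the exporter is standalone. It no longer has to be
copied into the facebookresearch/llama repo or launched via torchrun; it only
needs torch plus the precompute_freqs_cis helper from this repo's model.py.
A typical invocation (paths are illustrative) looks like:

    python export_meta_llama_bin.py /path/to/llama-2-7b/ llama2_7b.bin

where the model folder contains params.json and the consolidated.*.pth
shard(s) that load_and_export() globs and merges.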
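As a sanity check on the export, the 7-int header written by export() can be
unpacked with the same struct format. A minimal sketch, assuming native int32
layout on the same machine that wrote the file; the read_header helper is
hypothetical, not part of the patch:

    import struct

    def read_header(path):
        # export() packs 7 int32s in this order
        names = ('dim', 'hidden_dim', 'n_layers', 'n_heads',
                 'n_kv_heads', 'vocab_size', 'max_seq_len')
        with open(path, 'rb') as f:
            header = dict(zip(names, struct.unpack('iiiiiii', f.read(7 * 4))))
        # per the NOTE in export(): a negative vocab_size signals that the
        # classifier weights (output.weight) are present in the file
        header['has_classifier'] = header['vocab_size'] < 0
        header['vocab_size'] = abs(header['vocab_size'])
        return header

    print(read_header('llama2_7b.bin'))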
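The axis choice in concat_weights() follows how Meta shards the larger
checkpoints for model parallelism: wq/wk/wv, w1/w3 and output are split along
dim 0, while wo, w2 and the token embedding are split along dim 1, and 1-D
norm weights are replicated across shards, so the first copy is kept as-is.
A toy sketch of the two merge cases, using made-up shapes:

    import torch

    shards = 2
    # dim-0 sharded weight (e.g. attention.wq.weight): each shard holds a
    # slice of the output rows, so merging stacks them back along dim 0
    wq = torch.cat([torch.randn(8, 16) for _ in range(shards)], dim=0)
    # dim-1 sharded weight (e.g. attention.wo.weight): each shard holds a
    # slice of the input columns, so merging stitches them along dim 1
    wo = torch.cat([torch.randn(16, 8) for _ in range(shards)], dim=1)
    print(wq.shape, wo.shape)  # torch.Size([16, 16]) torch.Size([16, 16])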