# convert_grok.py (forked from ggerganov/llama.cpp)
"""
Convert Grok-1 weights to GGUF format.
Example invocation:
python -m convert_grok -i path/to/grok-1/ckpt-0 --vocab_dir path/to/grok -o grok.bin -t q4_0 --experts 1,2
To run:
./build/bin/main -m grok.bin -p "The answer to life the universe and everything is" -s 1 -n 3 -ngl 1
"""

import argparse
import logging
import mmap
import os
import pathlib
import pickletools
import sys
import time

import ml_dtypes
import numpy as np
import torch

try:
    from tabulate import tabulate
except ModuleNotFoundError:
    pass

from convert import SentencePieceVocab

if "NO_LOCAL_GGUF" not in os.environ:
    sys.path.insert(1, str(pathlib.Path(__file__).parent / "gguf-py"))

import gguf

QK8_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0][0]
QK4_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0][0]
QK4_1 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_1][0]

# Heuristic to avoid having to fully parse pickle files.
FP32_SHAPES = {805306368: (131072, 6144), 6144: (6144,), 49152: (6144, 8)}
BF16_SHAPES = {
    262144: (8, 1, 32768),
    393216: (8, 8, 6144),
    1024: (1, 1024),
    49152: (8, 6144),
    6144: (1, 6144),
}
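
# These flat element counts appear to correspond to Grok-1's dimensions as
# configured further below: 805306368 == 131072 * 6144 is the token embedding,
# 6144 a single norm vector, and 49152 == 6144 * 8 the expert-routing gate,
# while the bf16 entries are the per-expert quantization scales that accompany
# the int8 weights (e.g. 393216 == 8 * 8 * 6144).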


class AttributeDict(dict):
    def __getattr__(self, key):
        return self.__getitem__(key) if key in self else super().__getattr__(key)

    __setattr__ = dict.__setitem__


def _genops(data):
    view = memoryview(data)
    code2op = {ord(d.code): d for d in pickletools.opcodes}
    dataops = {
        "BINBYTES": pickletools.read_uint4,
        "BINBYTES8": pickletools.read_uint8,
    }
    while True:
        pos = data.tell()
        code = data.read_byte()
        opcode = code2op[code]
        arg = None
        if opcode.arg is not None:
            if opcode.name not in dataops:
                arg = opcode.arg.reader(data)
            else:
                size = dataops[opcode.name](data)
                p = data.tell()
                arg = np.frombuffer(view[p : p + size], dtype=np.uint8)
                data.seek(size, 1)
        yield opcode, arg, pos
        if code == ord(b"."):
            break


def genops(fn):
    """Yield (opcode, arg, pos) for a pickle file.

    Uses mmap to avoid copies of binary data (e.g., np and JAX arrays)."""
    with open(fn, "rb") as f:
        yield from _genops(mmap.mmap(f.fileno(), length=0, flags=mmap.MAP_PRIVATE))


def get_weights(fn):
    """Returns tensor/array data in Grok pickle files, zero copy."""
    arrays = []
    for unused_opcode, arg, unused_pos in genops(fn):
        if isinstance(arg, np.ndarray):
            arrays.append(arg)

    if len(arrays) == 1:
        # Plain numpy array.
        array = arrays[0].view(np.float32)
        array = array.reshape(FP32_SHAPES[array.size])
        return array, None
    elif len(arrays) == 2:
        weight, scales = arrays
        scales = scales.view(ml_dtypes.bfloat16)
        scales = scales.reshape(BF16_SHAPES[scales.size])

        weight = weight.view(np.int8)
        shape = list(scales.shape)
        shape[-2] = -1
        weight = weight.reshape(shape)
        return weight, scales

    assert len(arrays) in (1, 2)


def torch_roundf(t: torch.Tensor) -> torch.Tensor:
    """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py."""
    a = abs(t)
    floored = torch.floor(a)
    b = floored + torch.floor(2 * (a - floored))
    return torch.sign(t) * b
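
# For illustration: torch.round() rounds halves to the nearest even value,
# while torch_roundf() rounds them away from zero like C's roundf():
#   torch.round(torch.tensor([0.5, 1.5, -2.5]))   -> tensor([0., 2., -2.])
#   torch_roundf(torch.tensor([0.5, 1.5, -2.5]))  -> tensor([1., 2., -3.])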


def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
    # Equivalent to gguf.quantize_q8_0 but PyTorch instead of Numpy.
    assert tensor.shape[1] % QK8_0 == 0
    tensor = tensor.reshape(-1, QK8_0)
    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
    tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
    # add scale into each block
    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
    return tensor
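

# Illustrative only, not used by the converter: a minimal sketch of how the
# Q8_0 blocks emitted above could be decoded back to float. It assumes the
# layout produced by quantize_q8_0, i.e. rows of [2 bytes fp16 scale | QK8_0
# int8 quants] (34 bytes per block for QK8_0 == 32, matching ggml's
# block_q8_0), with x[i] ~= d * q[i].
def _dequantize_q8_0_reference(blocks: torch.Tensor) -> torch.Tensor:
    scale = blocks[:, :2].view(torch.half).float()  # per-block fp16 scale d
    quants = blocks[:, 2:].float()  # QK8_0 signed 8-bit quants
    return quants * scale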


def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
    # equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero)
    assert tensor.shape[1] % QK4_0 == 0
    tensor = tensor.reshape(-1, QK4_0)
    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
    scale = max_values / -8
    tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
    # compress two int4 weights into an int8
    tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
    # add scale into each block
    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
    return tensor
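
# Resulting Q4_0 block layout (as produced above; matches ggml's block_q4_0,
# 18 bytes per 32 weights): [2 bytes fp16 scale d | 16 bytes of nibbles],
# where byte j packs element j in its low nibble and element j + 16 in its
# high nibble, and dequantization is x[i] ~= (q[i] - 8) * d.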


def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
    # equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero)
    assert tensor.shape[1] % QK4_1 == 0
    tensor = tensor.reshape(-1, QK4_1)
    max_indices = tensor.max(dim=-1, keepdim=True).indices
    max_values = torch.take_along_dim(tensor, max_indices, dim=-1)
    min_indices = tensor.min(dim=-1, keepdim=True).indices
    min_values = torch.take_along_dim(tensor, min_indices, dim=-1)
    scale = (max_values - min_values) / 15
    tensor = ((tensor - min_values) / scale).round().clamp(min=0, max=15).char()
    # compress two int4 weights into an int8
    tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
    # add scale and min into each block
    tensor = torch.cat(
        (scale.half().view(torch.int8), min_values.half().view(torch.int8), tensor), dim=-1
    )
    return tensor
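

# Illustrative only, not used by the converter: a decoding sketch for the
# Q4_1 blocks above, assuming rows of [2 bytes fp16 scale d | 2 bytes fp16
# min m | 16 bytes of packed nibbles] (20 bytes per block, as in ggml's
# block_q4_1), with x[i] ~= q[i] * d + m.
def _dequantize_q4_1_reference(blocks: torch.Tensor) -> torch.Tensor:
    scale = blocks[:, 0:2].view(torch.half).float()
    minimum = blocks[:, 2:4].view(torch.half).float()
    packed = blocks[:, 4:]
    # Low nibbles hold elements 0..15, high nibbles elements 16..31.
    quants = torch.cat((packed & 0x0F, (packed >> 4) & 0x0F), dim=-1).float()
    return quants * scale + minimum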


def maybe_quantize_tensor(tensor, ggml_type):
    assert tensor.dtype == torch.float32

    if ggml_type == gguf.GGMLQuantizationType.F32:
        return tensor.float()
    elif ggml_type == gguf.GGMLQuantizationType.F16:
        return tensor.half()
    elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
        return quantize_q8_0(tensor)
    elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
        return quantize_q4_0(tensor)
    elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
        return quantize_q4_1(tensor)
    else:
        raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})")


def get_dtype_and_ggml_type(name, tensor, ggml_type):
    if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
        if tensor.shape[1] % QK8_0 == 0:
            return np.int8, ggml_type
        else:
            return np.float16, gguf.GGMLQuantizationType.F16
    else:
        return np.float32, gguf.GGMLQuantizationType.F32
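
# In effect: 2-d/3-d weight matrices get the requested quantization type when
# their row length is a multiple of QK8_0 and fall back to F16 otherwise,
# while 1-d norm vectors and the expert-routing gate (ffn_gate_inp) stay F32.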


def dump_state_dict(f, ggml_type, input_dir, config):
    weights = {}

    # Load weights in file order (mmap'ed).
    for idx, name in enumerate(get_weight_names(config.num_hidden_layers)):
        weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000")

    logging.debug("Loaded %i files", len(weights))

    # But write in layer order.
    weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False)

    # Operate on meta tensors to find shapes and dtypes for GGUF header.
    for name in weight_names:
        weight, scales = weights[name]
        meta_tensor = convert_weight(name, weight, scales, config, device="meta")
        dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
        quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
        f.add_tensor_info(
            f"{name}.weight",
            list(meta_tensor.shape),
            dtype,
            quantized_meta_tensor.nbytes,
            tensor_ggml_type,
        )

    f.write_header_to_file()
    f.write_kv_data_to_file()
    f.write_ti_data_to_file()

    # Now write actual tensor data.
    tensor_info = []
    for name in weight_names:
        weight, scales = weights.pop(name)
        tensor = convert_weight(name, weight, scales, config)
        _, tensor_ggml_type = get_dtype_and_ggml_type(name, tensor, ggml_type)
        array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
        logging.info(
            f"dumping {name}: "
            f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes"
        )
        f.write_tensor_data(array)
        tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))

    try:
        print(  # noqa: NP100
            tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")
        )
    except NameError:
        pass

    if weights:
        logging.warning("Not all tensors are converted")


def from_numpy(array):
    """Like torch.from_numpy, but handle ml_dtypes.bfloat16 too."""
    if array.dtype == ml_dtypes.bfloat16:
        return torch.from_numpy(array.view(np.uint8)).view(torch.bfloat16)
    return torch.from_numpy(array)


def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None):
    # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
    weight = from_numpy(weight).to(device=device, dtype=dtype)
    if scales is not None:
        scale = from_numpy(scales).to(device=device, dtype=dtype)
        # row parallel layers have sharded scale
        if len(scale.shape) >= 2 and scale.shape[-2] != 1:
            scale = scale[..., None, :]
            weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1])
            weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1])
        else:
            weight = weight * scale

    if name != "token_embd" and len(weight.shape) >= 2:
        # Transpose linear matrix
        weight = weight.transpose(-1, -2)

    if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
        weight = weight[config.experts]  # gather.

    return weight
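
# Shape sketch for the sharded-scale branch above (assuming Grok-1's expert
# FFN weights): an int8 weight of shape (8, M, 6144) with bf16 scales of
# shape (8, 8, 6144) is viewed as (8, 8, M // 8, 6144), multiplied by the
# scales broadcast from (8, 8, 1, 6144), and flattened back to (8, M, 6144)
# before the transpose and expert gather.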


def extract_vocabulary_from_model(vocab):
    tokens = []
    scores = []
    toktypes = []

    for text, score, toktype in vocab.all_tokens():
        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)

    assert len(tokens) == vocab.vocab_size

    return tokens, scores, toktypes


def get_weight_names(num_hidden_layers=64, lexicographic=True):
    """Return Grok-1 weight names.

    If `lexicographic` is set, the order is as in the tensor#####_000 files."""

    weight_names = [
        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
    ]

    layer = (
        gguf.MODEL_TENSOR.FFN_GATE_EXP,
        gguf.MODEL_TENSOR.FFN_DOWN_EXP,
        gguf.MODEL_TENSOR.FFN_UP_EXP,
        gguf.MODEL_TENSOR.ATTN_K,
        gguf.MODEL_TENSOR.ATTN_OUT,
        gguf.MODEL_TENSOR.ATTN_Q,
        gguf.MODEL_TENSOR.ATTN_V,
        gguf.MODEL_TENSOR.ATTN_NORM,
        gguf.MODEL_TENSOR.ATTN_OUT_NORM,
        gguf.MODEL_TENSOR.FFN_NORM,
        gguf.MODEL_TENSOR.LAYER_OUT_NORM,
        gguf.MODEL_TENSOR.FFN_GATE_INP,
    )

    layers = [str(bid) for bid in range(64)]

    if lexicographic:
        # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
        layers.sort()

    for bid in layers[:num_hidden_layers]:
        for key in layer:
            weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))

    return weight_names


def convert_grok(args, vocab, ggml_type):
    start = time.time()

    def ffn_size(emb_size, widening_factor):
        _ffn_size = int(widening_factor * emb_size) * 2 // 3
        _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
        return _ffn_size
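
    # Worked example: with hidden_size = 48 * 128 = 6144 and widening_factor 8,
    # int(8 * 6144) * 2 // 3 == 32768, which is already a multiple of 8, so
    # intermediate_size below ends up as 32768.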

    config = {
        "hidden_act": "gelu",
        "pad_token_id": 0,
        "eos_token_id": 2,
        "max_position_embeddings": 8192,
        "output_multiplier_scale": 0.5773502691896257,
        "embedding_multiplier_scale": 78.38367176906169,
        "hidden_size": 48 * 128,
        "intermediate_size": -1,
        "num_attention_heads": 48,
        "num_key_value_heads": 8,
        "num_hidden_layers": 64,  # Change to 1 for quicker debugging.
        "num_selected_experts": 2,
        "rope_theta": 10000,
        "attn_output_multiplier": 0.08838834764831845,
        "rms_norm_eps": 1e-5,
    }
    config = AttributeDict(config)

    config.intermediate_size = ffn_size(config.hidden_size, 8)

    config.experts = list(range(8))
    if args.experts != "":
        config.experts = [int(x, 0) for x in args.experts.split(",")]
    config.num_experts = len(config.experts)

    assert config.num_experts >= 2, "need at least 2 experts"
    logging.info("experts to export: %s", config.experts)

    f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)

    f.add_name("grok-1")
    f.add_context_length(config.max_position_embeddings)
    f.add_embedding_length(config.hidden_size)
    f.add_block_count(config.num_hidden_layers)
    f.add_feed_forward_length(config.intermediate_size)
    f.add_rope_dimension_count(config.hidden_size // config.num_attention_heads)
    f.add_head_count(config.num_attention_heads)
    f.add_head_count_kv(config.num_key_value_heads)
    f.add_expert_count(config.num_experts)
    f.add_expert_used_count(config.num_selected_experts)
    f.add_layer_norm_rms_eps(config.rms_norm_eps)
    f.add_rope_freq_base(config.rope_theta)
    f.add_tokenizer_model("llama")

    # Extract model vocabulary for model conversion
    tokens, scores, toktypes = extract_vocabulary_from_model(vocab)
    f.add_token_list(tokens)
    f.add_token_scores(scores)
    f.add_token_types(toktypes)

    f.add_quantization_version(ggml_type)

    dump_state_dict(f, ggml_type, args.input_dir, config)
    f.close()

    delta = time.time() - start
    logging.info(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec")


def load_vocab(path):
    def load_spm(p):
        logging.info(f"Loading vocab file {p}")
        return SentencePieceVocab(p)

    # Be extra-friendly and accept either a file or a directory. Also, if it's
    # a directory, it might be the model directory, and tokenizer.model might
    # be in the parent of that.
    if path.is_dir():
        path2 = path / "tokenizer.model"
        # Use `.parent` instead of /.. to handle the symlink case better.
        path3 = path.parent / "tokenizer.model"
        if path2.exists():
            return load_spm(path2)
        elif path3.exists():
            return load_spm(path3)

    raise FileNotFoundError(
        f"Could not find tokenizer.model in {path} or its parent; "
        "if it's in another directory, pass the directory as --vocab_dir"
    )


def main():
    parser = argparse.ArgumentParser("convert_grok")
    parser.add_argument("-i", "--input_dir", type=str)
    parser.add_argument("-o", "--save_path", type=pathlib.Path)
    parser.add_argument(
        "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
    )
    parser.add_argument("--vocab_dir", type=str, default="")
    parser.add_argument("--experts", type=str, default="")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    vocab = load_vocab(
        pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
    )
    ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
    convert_grok(args, vocab, ggml_type)


if __name__ == "__main__":
    main()