Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPT support in llama.cpp #3417

Merged
merged 17 commits into from Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b49792b
CUDA: added support for ggml_clamp (see also: https://github.com/gger…
jploski Sep 30, 2023
15236e8
mpt : added an implementation based (mostly) on falcon integration, m…
jploski Sep 30, 2023
84e30e8
mpt : protect against "clip_qkv": null in mpt-7b
jploski Sep 30, 2023
00e8c5c
mpt : quick fix to avoid "Strange model" warning when quantizing MPT …
jploski Sep 30, 2023
1be89c4
mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out f…
jploski Sep 30, 2023
26c253e
mpt : standardized all tensor names to follow GGUF spec
jploski Sep 30, 2023
df072d2
mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GE…
jploski Sep 30, 2023
90e7d6d
mpt : fixed comment s/gptneox/mpt/
jploski Oct 2, 2023
4708012
mpt : remove tabs, trailing whitespace
jploski Oct 2, 2023
1364bcd
mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) a…
jploski Oct 3, 2023
7d6a24a
mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to co…
jploski Oct 6, 2023
292363e
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 9, 2023
ad3c2f3
comment out n_past instead of marking it unused
cebtenzzre Oct 9, 2023
1a454eb
mpt : removed hardcoded +178 from convert script in favor of utilizin…
jploski Oct 9, 2023
32172f1
mpt : remove unused tokenizer_json in convert script
cebtenzzre Oct 9, 2023
96cf3f5
ggml : remove obsolete n_past assert in ggml_alibi
ggerganov Oct 10, 2023
9b66378
llama : print clam_kqv and max_alibi_bias hparams
ggerganov Oct 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
216 changes: 216 additions & 0 deletions convert-mpt-hf-to-gguf.py
@@ -0,0 +1,216 @@
#!/usr/bin/env python3
# HF mpt--> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf


def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1

if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
hparams = json.load(f)

if hparams["architectures"][0] != "MPTForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH=gguf.MODEL_ARCH.MPT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["n_layers"]

gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_seq_len"])
gguf_writer.add_embedding_length(hparams["d_model"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
gguf_writer.add_head_count(hparams["n_heads"])
gguf_writer.add_layer_norm_eps(1e-05)
if hparams["attn_config"]["clip_qkv"] is not None:
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

print("gguf: get gpt2 tokenizer vocab")

# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
# accomodate some "reserved" tokens; this is causing problems down the line in
# llama.cpp, so we pad the vocab with dummy tokens:

vocab_size = hparams["vocab_size"]

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}

for i in range(vocab_size):
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
scores.append(0.0) # dummy
toktypes.append(gguf.TokenType.NORMAL)

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

for name in model_part.keys():
data = model_part[name]

old_dtype = data.dtype

# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)

data = data.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Cannot map tensor '" + name + "'")
continue # for the sake of compatibility with some old published models, don't quit
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(new_name, data)

# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
if new_name == "token_embd.weight":
gguf_writer.add_tensor("output.weight", data)

print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
47 changes: 45 additions & 2 deletions ggml-cuda.cu
Expand Up @@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
Expand Down Expand Up @@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}

static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}

template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
Expand Down Expand Up @@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}

template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
Expand Down Expand Up @@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0);

const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

GGML_ASSERT(ne01 + n_past == ne00);
//GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);

const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
Expand Down Expand Up @@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
(void) src1_dd;
}

inline void ggml_cuda_op_clamp(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const float min = ((float *) dst->op_params)[0];
const float max = ((float *) dst->op_params)[1];

clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());

(void) src1;
(void) dst;
(void) src1_dd;
}

static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
const int64_t nrows0 = ggml_nrows(src0);

Expand Down Expand Up @@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}

static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
}

static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
Expand Down Expand Up @@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_SCALE:
func = ggml_cuda_scale;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;
}
func = ggml_cuda_clamp;
break;
case GGML_OP_CPY:
func = ggml_cuda_cpy;
break;
Expand Down
2 changes: 1 addition & 1 deletion ggml-metal.m
Expand Up @@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(

const int nth = MIN(1024, ne00);

const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
Expand Down
4 changes: 1 addition & 3 deletions ggml.c
Expand Up @@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
return;
}

const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

assert(n_past >= 0);

const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
Expand Down