Skip to content

Commit

Permalink
llm : add MPT support (ggerganov#3417)
Browse files Browse the repository at this point in the history
* CUDA: added support for ggml_clamp (see also: ggerganov/ggml#545)

* mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt

* mpt : protect against "clip_qkv": null in mpt-7b

* mpt : quick fix to avoid "Strange model" warning when quantizing MPT models

* mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?)

* mpt : standardized all tensor names to follow GGUF spec

* mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code

* mpt : fixed comment s/gptneox/mpt/

* mpt : remove tabs, trailing whitespace

* mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt

* mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252

* comment out n_past instead of marking it unused

* mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"]

* mpt : remove unused tokenizer_json in convert script

* ggml : remove obsolete n_past assert in ggml_alibi

* llama : print clam_kqv and max_alibi_bias hparams

---------

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
3 people committed Oct 10, 2023
1 parent 11ea5c7 commit f5f9121
Show file tree
Hide file tree
Showing 5 changed files with 685 additions and 9 deletions.
216 changes: 216 additions & 0 deletions convert-mpt-hf-to-gguf.py
@@ -0,0 +1,216 @@
#!/usr/bin/env python3
# HF mpt--> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf


def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1

if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "MPTForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH=gguf.MODEL_ARCH.MPT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["n_layers"]

gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_seq_len"])
gguf_writer.add_embedding_length(hparams["d_model"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
gguf_writer.add_head_count(hparams["n_heads"])
gguf_writer.add_layer_norm_eps(1e-05)
if hparams["attn_config"]["clip_qkv"] is not None:
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

print("gguf: get gpt2 tokenizer vocab")

# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
# accomodate some "reserved" tokens; this is causing problems down the line in
# llama.cpp, so we pad the vocab with dummy tokens:

vocab_size = hparams["vocab_size"]

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}

for i in range(vocab_size):
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
scores.append(0.0) # dummy
toktypes.append(gguf.TokenType.NORMAL)

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

for name in model_part.keys():
data = model_part[name]

old_dtype = data.dtype

# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)

data = data.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Cannot map tensor '" + name + "'")
continue # for the sake of compatibility with some old published models, don't quit
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(new_name, data)

# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
if new_name == "token_embd.weight":
gguf_writer.add_tensor("output.weight", data)

print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
47 changes: 45 additions & 2 deletions ggml-cuda.cu
Expand Up @@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
Expand Down Expand Up @@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}

static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}

template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
Expand Down Expand Up @@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}

template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
Expand Down Expand Up @@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0);

const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

GGML_ASSERT(ne01 + n_past == ne00);
//GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);

const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
Expand Down Expand Up @@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
(void) src1_dd;
}

inline void ggml_cuda_op_clamp(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const float min = ((float *) dst->op_params)[0];
const float max = ((float *) dst->op_params)[1];

clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());

(void) src1;
(void) dst;
(void) src1_dd;
}

static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
const int64_t nrows0 = ggml_nrows(src0);

Expand Down Expand Up @@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}

static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
}

static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
Expand Down Expand Up @@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_SCALE:
func = ggml_cuda_scale;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;
}
func = ggml_cuda_clamp;
break;
case GGML_OP_CPY:
func = ggml_cuda_cpy;
break;
Expand Down
2 changes: 1 addition & 1 deletion ggml-metal.m
Expand Up @@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(

const int nth = MIN(1024, ne00);

const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
Expand Down
4 changes: 1 addition & 3 deletions ggml.c
Expand Up @@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
return;
}

const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

assert(n_past >= 0);

const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
Expand Down

0 comments on commit f5f9121

Please sign in to comment.