Merge branch 'master' into compilade/lazy-convert-hf
compilade committed May 8, 2024
2 parents 94e667a + 83330d8 commit bffdaf4
Showing 43 changed files with 1,719 additions and 252 deletions.
11 changes: 10 additions & 1 deletion CMakeLists.txt
@@ -103,6 +103,8 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)

option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
@@ -409,6 +411,9 @@ if (LLAMA_CUDA)
if (LLAMA_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
if (LLAMA_CUDA_NO_VMM)
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
if (DEFINED LLAMA_CUDA_DMMV_Y)
@@ -434,7 +439,11 @@ if (LLAMA_CUDA)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
if (LLAMA_CUDA_NO_VMM)
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
endif()

if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
44 changes: 20 additions & 24 deletions README.md
@@ -20,7 +20,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Hot topics

- **BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920**
- **Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021**
- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920
- MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387
- Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404
- Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
@@ -935,25 +936,35 @@ If your issue is with model generation quality, then please at least scan the fo
### Android
#### Building the Project using Android NDK
You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/).
#### Build on Android using Termux
[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required).
```
apt update && apt upgrade -y
apt install git make cmake
```
First, install the essential packages for termux:
It's recommended to move your model inside the `~/` directory for best performance:
```
pkg install clang wget git cmake
cd storage/downloads
mv model.gguf ~/
```
Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
You can execute the following commands on your computer to avoid downloading the NDK to your mobile. Of course, you can also do this in Termux.
[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
#### Building the Project using Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
```
$ mkdir build-android
$ cd build-android
$ export NDK=<your_ndk_directory>
$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
$ make
```
Install [termux](https://termux.dev/) on your device and run `termux-setup-storage` to get access to your SD card.
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
@@ -975,25 +986,10 @@ $cd /data/data/com.termux/files/home/bin
$./main -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml
```
Here is a demo of an interactive session running on Pixel 5 phone:
Here's a demo of an interactive session running on Pixel 5 phone:
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
#### Build on Android using Termux
[Termux](https://github.com/termux/termux-app#installation) is an alternative to execute `llama.cpp` on an Android device (no root required).
```
apt update && apt upgrade -y
apt install git
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
[Follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
### Docker
#### Prerequisites
11 changes: 6 additions & 5 deletions ci/run.sh
@@ -160,9 +160,8 @@ function gg_run_test_scripts_debug {

set -e

# TODO: too slow, run on dedicated node
#(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
#(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log

set +e
}
@@ -695,8 +694,10 @@ test $ret -eq 0 && gg_run ctest_release
if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run embd_bge_small

test $ret -eq 0 && gg_run test_scripts_debug
test $ret -eq 0 && gg_run test_scripts_release
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
test $ret -eq 0 && gg_run test_scripts_debug
test $ret -eq 0 && gg_run test_scripts_release
fi

if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then
if [ -z ${GG_BUILD_CUDA} ]; then
5 changes: 5 additions & 0 deletions common/common.cpp
@@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.instruct = true;
return true;
}
if (arg == "-cnv" || arg == "--conversation") {
params.conversation = true;
return true;
}
if (arg == "-cml" || arg == "--chatml") {
params.chatml = true;
return true;
@@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --version show version and build info\n");
printf(" -i, --interactive run in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n");
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
1 change: 1 addition & 0 deletions common/common.h
@@ -140,6 +140,7 @@ struct gpt_params {
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
5 changes: 5 additions & 0 deletions common/sampling.cpp
@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

result->prev.resize(params.n_prev);

result->n_considered = 0;

llama_sampling_set_rng_seed(result, params.seed);

return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_considered = 0;
}

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
}
}

ctx_sampling->n_considered = cur_p.size;

return id;
}

1 change: 1 addition & 0 deletions common/sampling.h
@@ -81,6 +81,7 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
size_t n_considered;

std::mt19937 rng;
};
5 changes: 5 additions & 0 deletions convert-hf-to-gguf-update.py
@@ -67,6 +67,9 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
]

# make directory "models/tokenizers" if it doesn't exist
@@ -150,6 +153,8 @@ def download_file_with_auth(url, token, save_path):
# print the "pre_tokenizer" content from the tokenizer.json
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
cfg = json.load(f)
normalizer = cfg["normalizer"]
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
pre_tokenizer = cfg["pre_tokenizer"]
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))

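The three tokenizer entries added above (`qwen2`, `olmo`, `dbrx`) feed the hash-based pre-tokenizer detection: the update script encodes a fixed test string with each downloaded tokenizer, hashes the resulting token IDs, and that hash is later matched in `convert-hf-to-gguf.py` (the new `res = "qwen2"` / `"olmo"` / `"dbrx"` branches in the next file). Below is a rough, illustrative sketch of that fingerprinting step; it assumes `transformers` is installed and the tokenizer files already sit under `models/tokenizers/<name>/`, and the test string is a placeholder rather than the script's real `chktxt`, so the printed hash will not match the committed values.

```python
# Illustrative only: placeholder test string, not the real chktxt used by
# convert-hf-to-gguf-update.py, so the hash will differ from the committed ones.
from hashlib import sha256
from transformers import AutoTokenizer

name = "qwen2"  # one of the newly added entries
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")

chktxt = "Hello world! 123 ½¾ example fingerprint text"  # placeholder
chktok = tokenizer.encode(chktxt)                        # token IDs for the test string
chkhsh = sha256(str(chktok).encode()).hexdigest()        # fingerprint of the pre-tokenizer

print(f'if chkhsh == "{chkhsh}":')
print(f'    res = "{name}"')
```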
14 changes: 12 additions & 2 deletions convert-hf-to-gguf.py
@@ -397,6 +397,15 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
res = "command-r"
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
res = "olmo"
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
# ref: https://huggingface.co/databricks/dbrx-instruct
res = "dbrx"

if res is None:
logger.warning("\n")
@@ -2248,8 +2257,9 @@ class OlmoModel(Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_layer_norm_eps(1e-5)
if "clip_qkv" in self.hparams is not None:
self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"])
clip_qkv = self.hparams.get("clip_qkv")
if clip_qkv is not None:
self.gguf_writer.add_clamp_kqv(clip_qkv)

# Same as super class, but permuting q_proj, k_proj
# Copied from: LlamaModel
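The `OlmoModel` hunk above replaces `if "clip_qkv" in self.hparams is not None:` with a `.get()` lookup. Because `in` and `is not` are chained comparison operators, the old expression is equivalent to `("clip_qkv" in hparams) and (hparams is not None)`, so a key that is present but set to `null` still passes and `None` would be forwarded to `add_clamp_kqv`. A short self-contained sketch of the difference, using made-up `hparams` dicts rather than any real model config:

```python
# Made-up hparams dicts, purely to illustrate the Python semantics of the two checks.
def old_condition(hparams: dict) -> bool:
    # Chained comparison: ("clip_qkv" in hparams) and (hparams is not None)
    return "clip_qkv" in hparams is not None

def new_condition(hparams: dict) -> bool:
    return hparams.get("clip_qkv") is not None

for hparams in ({}, {"clip_qkv": None}, {"clip_qkv": 8.0}):
    print(hparams, old_condition(hparams), new_condition(hparams))
# {} False False
# {'clip_qkv': None} True False   <- only the new form skips a null value
# {'clip_qkv': 8.0} True True
```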
51 changes: 32 additions & 19 deletions convert.py
@@ -1512,25 +1512,27 @@ def main(args_in: list[str] | None = None) -> None:
if args.big_endian:
endianess = gguf.GGUFEndian.BIG

params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
msg = """\
The model doesn't have a context size, and you didn't specify one with --ctx
Please specify one with --ctx:
- LLaMA v1: --ctx 2048
- LLaMA v2: --ctx 4096"""
parser.error(textwrap.dedent(msg))
params.n_ctx = args.ctx

if args.outtype:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype]

logger.info(f"params = {params}")
params = None
if args.pad_vocab or not args.vocab_only:
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
msg = """\
The model doesn't have a context size, and you didn't specify one with --ctx
Please specify one with --ctx:
- LLaMA v1: --ctx 2048
- LLaMA v2: --ctx 4096"""
parser.error(textwrap.dedent(msg))
params.n_ctx = args.ctx

if args.outtype:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype]

logger.info(f"params = {params}")

model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
@@ -1543,6 +1545,17 @@ def main(args_in: list[str] | None = None) -> None:
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile
if params is None:
params = Params(
n_vocab = vocab.vocab_size,
n_embd = 1,
n_layer = 1,
n_ctx = 1,
n_ff = 1,
n_head = 1,
n_head_kv = 1,
f_norm_eps = 1e-5,
)
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
endianess=endianess, pad_vocab=args.pad_vocab)
logger.info(f"Wrote {outfile}")
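With this change, `Params.load` only runs when its result is actually needed (`--pad-vocab` or a full conversion); the pure `--vocab-only` path now builds a placeholder `Params` whose non-vocab fields are dummy values. The sketch below is a minimal, self-contained illustration of that branching, with made-up names (`resolve_params`, `load_full_params`) that are not part of convert.py:

```python
# Simplified sketch of the new control flow; the names and the FileNotFoundError
# scenario are illustrative, not taken from convert.py itself.
from dataclasses import dataclass

@dataclass
class Params:  # stand-in for convert.py's Params; defaults mirror the placeholder above
    n_vocab: int
    n_embd: int = 1
    n_layer: int = 1
    n_ctx: int = 1
    n_ff: int = 1
    n_head: int = 1
    n_head_kv: int = 1
    f_norm_eps: float = 1e-5

def load_full_params() -> Params:
    # Stand-in for Params.load(model_plus); could fail if only tokenizer files exist.
    raise FileNotFoundError("no model config found")

def resolve_params(vocab_only: bool, pad_vocab: bool, n_vocab: int) -> Params:
    if pad_vocab or not vocab_only:
        return load_full_params()       # full conversion (or --pad-vocab) needs real hparams
    return Params(n_vocab=n_vocab)      # --vocab-only: placeholder values suffice

# Exporting just a tokenizer with a hypothetical 32000-token vocab:
print(resolve_params(vocab_only=True, pad_vocab=False, n_vocab=32000))
```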
2 changes: 1 addition & 1 deletion docs/BLIS.md
@@ -23,7 +23,7 @@ Install BLIS:
sudo make install
```

We recommend using openmp since it's easier to modify the cores been used.
We recommend using openmp since it's easier to modify the cores being used.

### llama.cpp compilation

4 changes: 2 additions & 2 deletions docs/HOWTO-add-model.md
@@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc

This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.

Have a look to existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.

When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support of missing backend operations can be added in another PR.
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.

Note: to debug the inference graph: you can use [eval-callback](../examples/eval-callback).

2 changes: 1 addition & 1 deletion examples/finetune/finetune.cpp
@@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);

auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
} else if (a->type == GGML_TYPE_F32) {
return ggml_add(ctx, a, b);