LLM Runtime

LLM Runtime provides efficient inference of large language models (LLMs) on Intel platforms using state-of-the-art (SOTA) model compression techniques. The work is inspired by llama.cpp, which organizes almost all of its core code (e.g., kernels) in a single large file with many pre-defined macros, making it hard for developers to add support for a new model. LLM Runtime has the following features:

  • Modular design to support new models
  • Highly optimized low precision kernels
  • Uses the AMX, VNNI, AVX512F and AVX2 instruction sets
  • Supports CPU (x86 platforms only) and initial (Intel) GPU
  • Supports 4-bit and 8-bit quantization

LLM Runtime is under active development so APIs are subject to change.

Supported Hardware

Hardware Optimization
Intel Xeon Scalable Processors
Intel Xeon CPU Max Series
Intel Core Processors
Intel Arc GPU Series WIP
Intel Data Center GPU Max Series WIP
Intel Gaudi2 Not yet

Supported Models

LLM Runtime supports the following models:

Text Generation

model name INT8 INT4
LLaMA2-7B, LLaMA2-13B, LLaMA2-70B
LLaMA-7B, LLaMA-13B
GPT-J-6B
GPT-NeoX-20B
Dolly-v2-3B
MPT-7B, MPT-30B
Falcon-7B, Falcon-40B
BLOOM-7B
OPT-125m, OPT-350m, OPT-1.3B, OPT-13B
ChatGLM-6B, ChatGLM2-6B
Baichuan-13B-Chat, Baichuan2-13B-Chat
Mistral-7B

Code Generation

model name INT8 INT4
Code-LLaMA-7B, Code-LLaMA-13B
StarCoder-1B, StarCoder-3B, StarCoder-15.5B

How to Use

There are two methods for utilizing the LLM Runtime: the Transformer-based API and the Python script.

How to use: Transformer-based API

1. Install

Install from binary

pip install intel-extension-for-transformers

2. Run LLM with Transformer-based API

You can run a Hugging Face model with a few lines of Python. Here is the sample code:

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
model_name = "Intel/neural-chat-7b-v1-1"     # Hugging Face model_id or local model
prompt = "Once upon a time, there existed a little girl,"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
streamer = TextStreamer(tokenizer)

model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True)
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300)

To enable StreamingLLM for infinite inference, here is the sample code:

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
model_name = "Intel/neural-chat-7b-v1-1"     # Hugging Face model_id or local model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
prompt = "Once upon a time, there existed a little girl,"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
streamer = TextStreamer(tokenizer)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)

# Paper: https://arxiv.org/pdf/2309.17453.pdf
# Recommended: n_keep=4 for attention sinks (the four initial tokens) and n_discard=-1 to drop half of the recent tokens when the length threshold is reached
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, ctx_size=100, n_keep=4, n_discard=-1)

Argument description of WeightOnlyQuantConfig:

Argument Type Description
compute_dtype String Data type of Gemm computation: int8/bf16/fp32 (default: int8)
weight_dtype String Data type of quantized weight: int4/int8 (default: int4)
alg String Quantization algorithm: sym/asym (default: sym)
group_size Int Group size (default: 32)
scale_dtype String Data type of scales: fp32/bf16 (default: fp32)
use_ggml Bool Enable ggml for quantization and inference (default: False)
not_quant Bool Determine whether or not the model will be quantized. (default: False)
use_cache Bool Use local quantized model if file exists (default: False)
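
To make the alg and group_size arguments more concrete, here is a minimal sketch of group-wise symmetric quantization in pure Python. It illustrates the general technique only and is not the runtime's kernel code: the function names quantize_sym and dequantize are hypothetical, and the rounding and clamping rules are assumptions.

```python
def quantize_sym(weights, group_size=32, bits=4):
    """Quantize a flat list of floats group by group with one scale per group.

    Assumption: symmetric quantization maps the largest magnitude in each
    group to the largest positive integer (7 for int4); no zero point is used.
    """
    qmax = 2 ** (bits - 1) - 1
    quantized, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        amax = max(abs(w) for w in group)
        scale = amax / qmax if amax > 0 else 1.0
        scales.append(scale)
        for w in group:
            # Round to the nearest integer and clamp to the signed range.
            quantized.append(max(-qmax - 1, min(qmax, round(w / scale))))
    return quantized, scales


def dequantize(quantized, scales, group_size=32):
    """Recover approximate floats from integers and per-group scales."""
    return [q * scales[i // group_size] for i, q in enumerate(quantized)]


weights = [7.0, -3.0, 1.0, 0.5]
q, scales = quantize_sym(weights, group_size=2)
restored = dequantize(q, scales, group_size=2)
print(q, scales, restored)
```

A smaller group_size stores more scales and follows local weight ranges more closely; a larger one stores fewer scales at the cost of coarser steps within each group.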

Argument description of generate function:

Argument Type Description
inputs List[Int] Input ids after tokenizer
interactive Bool Interactive mode, use history commands when True (default: False)
n_keep Int Number of tokens to keep from the initial prompt (default: 0, -1 = all)
n_discard Int Number of tokens to discard (default: -1, -1 = half of the tokens are discarded)
shift_roped_k Bool Use a ring buffer to avoid recomputation after reaching ctx_size (default: False)
ignore_prompt Bool Generate outputs w/o prompt (default: False)
batch_size Int Batch size for prompt processing (default: 512)
ctx_size Int Size of the prompt context (default: 512)
seed Int RNG seed (default: -1, use random seed for < 0)
threads Int Number of threads to use during computation (default: 8)
repetition_penalty Float Please refer to Transformer's generate
num_beams Int Please refer to Transformer's generate
do_sample Bool Please refer to Transformer's generate
top_k Int Please refer to Transformer's generate
top_p Float Please refer to Transformer's generate
temperature Float Please refer to Transformer's generate
min_new_tokens Int Please refer to Transformer's generate
length_penalty Float Please refer to Transformer's generate
early_stopping Bool Please refer to Transformer's generate
max_new_tokens Int Please refer to Transformer's generate
streamer Class Please refer to Transformer's generate
stopping_criteria Class Please refer to Transformer's generate
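
The interaction of ctx_size, n_keep and n_discard can be pictured with a small simulation on token positions. This is a sketch of the eviction policy as the table describes it, not the runtime's implementation. It assumes that the discarded tokens are the oldest ones after the kept prefix (as in the StreamingLLM paper) and that n_keep is non-negative; the helper names evict and simulate are hypothetical.

```python
def evict(cache, n_keep, n_discard):
    """Drop tokens that follow the kept prefix and return the new cache."""
    n_keep = min(n_keep, len(cache))
    evictable = len(cache) - n_keep
    if n_discard < 0:
        # -1 means: discard half of the tokens that are not part of the prefix.
        n_discard = evictable // 2
    n_discard = min(n_discard, evictable)
    return cache[:n_keep] + cache[n_keep + n_discard:]


def simulate(n_tokens, ctx_size, n_keep=4, n_discard=-1):
    """Feed n_tokens positions through a cache limited to ctx_size entries."""
    cache = []
    for pos in range(n_tokens):
        if len(cache) >= ctx_size:
            cache = evict(cache, n_keep, n_discard)
        cache.append(pos)
    return cache


print(simulate(10, ctx_size=8, n_keep=2))  # → [0, 1, 5, 6, 7, 8, 9]
```

In the printed example the first two positions survive as attention sinks, positions 2 to 4 are dropped when the cache fills up, and generation continues from there.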

3. Multi-Round Chat

Chat with LLaMA2:

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# Please change this to a local model path; LLaMA2 does not currently support online conversion.
model_name = "meta-llama/Llama-2-7b-chat-hf"
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    b_prompt = "[INST]{}[/INST]".format(prompt)  # prompt template for llama2
    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
                num_beams=1, max_new_tokens=-1, ctx_size = 1024, do_sample=True, threads=28, repetition_penalty=1.1)

Chat with ChatGLM2:

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "THUDM/chatglm2-6b"  # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    prompt = tokenizer.build_prompt(prompt)  # prompt template for chatglm2
    inputs = tokenizer([prompt], return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True,
                num_beams=1, max_new_tokens=-1, ctx_size = 1024, do_sample=True, threads=28, repetition_penalty=1.1, n_keep=2)

How to use: Python script

Install from binary

pip install intel-extension-for-transformers

Build from source

⚠️ If you want to use the from_pretrained API, please follow the Transformer-based API section instead.

# Linux
# make sure your path is in intel-extension-for-transformers/intel_extension_for_transformers/llm/runtime/graph folder
git submodule update --init --recursive
mkdir build
cd build
cmake .. -G Ninja
ninja
# Windows
# Install VisualStudio 2022 and open 'Developer PowerShell for VS 2022'
# make sure your path is in intel-extension-for-transformers/intel_extension_for_transformers/llm/runtime/graph folder
mkdir build
cd build
cmake ..
cmake --build . -j

Note: add compile args -DNE_AVX512=OFF -DNE_AVX512_VBMI=OFF -DNE_AVX512_VNNI=OFF to cmake when compiling it on a CPU without AVX512
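
If you are unsure whether the build machine has AVX512, the check can be scripted. The sketch below is one possible approach for Linux, under the assumption that the flags line of /proc/cpuinfo lists avx512f on CPUs with AVX512 support. The helpers read_cpu_flags and cmake_args_for are hypothetical and not part of this repository; the CMake options are the ones named in the note above.

```python
def read_cpu_flags(cpuinfo_text):
    """Return the set of flag names from the first 'flags' line of the text."""
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(":")
        if key.strip() == "flags":
            return set(value.split())
    return set()


def cmake_args_for(cpu_flags):
    """Return the extra CMake arguments for a CPU with the given flags."""
    if "avx512f" in cpu_flags:
        return []  # AVX512 is available, keep the defaults
    return ["-DNE_AVX512=OFF", "-DNE_AVX512_VBMI=OFF", "-DNE_AVX512_VNNI=OFF"]


# Example with a shortened, made-up cpuinfo excerpt.
sample = "processor\t: 0\nflags\t\t: fpu sse sse2 avx avx2\n"
print(" ".join(["cmake", "..", "-G", "Ninja"] + cmake_args_for(read_cpu_flags(sample))))
```

On a real machine, pass the contents of /proc/cpuinfo to read_cpu_flags instead of the sample string.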

1. Run LLM with Python Script

You can run an LLM with a one-click Python script that covers conversion, quantization and inference.

python scripts/run.py model-path --weight_dtype int4 -p "She opened the door and see"

Argument description of run.py:

Argument Description
model Directory containing model file or model id: String
--weight_dtype Data type of quantized weight: int4/int8 (default: int4)
--alg Quantization algorithm: sym/asym (default: sym)
--group_size Group size: Int (default: 32)
--scale_dtype Data type of scales: fp32/bf16 (default: fp32)
--compute_dtype Data type of Gemm computation: int8/bf16/fp32 (default: int8)
--use_ggml Enable ggml for quantization and inference
-p / --prompt Prompt to start generation with: String (default: empty)
-n / --n_predict Number of tokens to predict: Int (default: -1, -1 = infinity)
-t / --threads Number of threads to use during computation: Int (default: 56)
-b / --batch_size_truncate Batch size for prompt processing: Int (default: 512)
-c / --ctx_size Size of the prompt context: Int (default: 512, cannot be larger than the specific model's context window length)
-s / --seed RNG seed: Int (default: -1, use random seed for < 0)
--repeat_penalty Penalize repeat sequence of tokens: Float (default: 1.1, 1.0 = disabled)
--color Colorise output to distinguish prompt and user input from generations
--keep Number of tokens to keep from the initial prompt: Int (default: 0, -1 = all)
--shift-roped-k Use a ring buffer to avoid recomputation after reaching ctx_size (default: False)
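
When run.py is driven from another Python program, it can help to assemble the command line from the arguments listed above. The following is a minimal sketch that uses only flags from the table; build_run_command is a hypothetical helper, not part of the repository.

```python
import shlex


def build_run_command(model, prompt, weight_dtype="int4", group_size=32,
                      compute_dtype="int8", threads=None, ctx_size=None,
                      n_predict=None, seed=None):
    """Build the argument list for scripts/run.py from documented flags."""
    cmd = ["python", "scripts/run.py", model,
           "--weight_dtype", weight_dtype,
           "--group_size", str(group_size),
           "--compute_dtype", compute_dtype,
           "-p", prompt]
    # Optional flags are added only when a value is given, so the script's
    # own defaults apply otherwise.
    optional = {"-t": threads, "-c": ctx_size, "-n": n_predict, "-s": seed}
    for flag, value in optional.items():
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


cmd = build_run_command("model-path", "She opened the door and see",
                        threads=28, n_predict=64)
print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) from the graph folder. Passing a list rather than a single string avoids shell quoting problems with the prompt.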

Advanced Usage

Besides the one-click script, LLM Runtime also offers separate scripts for 1) conversion and quantization, and 2) inference.

1. Convert and Quantize LLM

LLM Runtime assumes a model format compatible with llama.cpp and ggml. You can also convert a model yourself by following the steps below:

# Convert the model directly from its Hugging Face model id (recommended)
python scripts/convert.py --outtype f32 --outfile ne-f32.bin EleutherAI/gpt-j-6b

# Alternatively, download the fp32 model (e.g., LLaMA2) from Hugging Face first, then convert the PyTorch model to ggml format.
git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
python scripts/convert.py --outtype f32 --outfile ne-f32.bin model_path

# To convert a model with a PEFT (Parameter-Efficient Fine-Tuning) adapter, first merge the adapter into the model with the command below and save the merged model. You can then run 'scripts/convert.py' as shown above.
python scripts/load_peft_and_merge.py --model_name_or_path meta-llama/Llama-2-7b-hf --peft_name_or_path dfurman/llama-2-7b-instruct-peft --save_path ./Llama-2-7b-hf-instruct-peft

# quantize weights of fp32 ggml bin
# model_name: llama, llama2, mpt, falcon, gptj, starcoder, dolly
# optimized INT4 model with group size 128 (recommended)
python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_j.bin --weight_dtype int4 --group_size 128 --compute_dtype int8

# Alternatively, you can use the ggml q4_0 format as follows
python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_0.bin --weight_dtype int4 --use_ggml
# optimized INT4 model with group size 32
python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_j.bin --weight_dtype int4 --group_size 32 --compute_dtype int8

Argument description of quantize.py:

Argument Description
--model_file Path to the fp32 model: String
--out_file Path to the quantized model: String
--build_dir Path to the build file: String
--config Path to the configuration file: String (default: "")
--nthread Number of threads to use: Int (default: 1)
--weight_dtype Data type of quantized weight: int4/int8 (default: int4)
--alg Quantization algorithm to use: sym/asym (default: sym)
--group_size Group size: Int (default: 32)
--scale_dtype Data type of scales: bf16/fp32 (default: fp32)
--compute_dtype Data type of Gemm computation: int8/bf16/fp32 (default: int8)
--use_ggml Enable ggml for quantization and inference

2. Inference LLM

We provide an inference script to run the quantized model. Please contact us if you want to use the C++ API directly.

# We recommend using numactl to bind cores on Intel CPUs for better performance.
# If you use a different number of cores, also change the -t argument.
# When running `StarCoder`, use a code prompt, for example, -p "def fibonacci(".
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python scripts/inference.py --model_name llama -m ne-q4_j.bin -c 512 -b 1024 -n 256 -t 56 --color -p "She opened the door and see"

# if you want to generate fixed outputs, please set --seed arg, for example:
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python scripts/inference.py --model_name llama -m ne-q4_j.bin -c 512 -b 1024 -n 256 -t 56 --color -p "She opened the door and see" --seed 12

# if you want to reduce repeated generated texts, please set --repeat_penalty (value > 1.0, default = 1.0), for example:
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python scripts/inference.py --model_name llama -m ne-q4_j.bin -c 512 -b 1024 -n 256 -t 56 --color -p "She opened the door and see" --repeat_penalty 1.2

Argument description of inference.py:

Argument Description
--model_name Model name: String
-m / --model Path to the executed model: String
--build_dir Path to the build file: String
-p / --prompt Prompt to start generation with: String (default: empty)
-n / --n_predict Number of tokens to predict: Int (default: -1, -1 = infinity)
-t / --threads Number of threads to use during computation: Int (default: 56)
-b / --batch_size Batch size for prompt processing: Int (default: 512)
-c / --ctx_size Size of the prompt context: Int (default: 512, cannot be larger than the specific model's context window length)
-s / --seed RNG seed: Int (default: -1, use random seed for < 0)
--repeat_penalty Penalize repeat sequence of tokens: Float (default: 1.1, 1.0 = disabled)
--color Colorise output to distinguish prompt and user input from generations
--keep Number of tokens to keep from the initial prompt: Int (default: 0, -1 = all)
--shift-roped-k Use a ring buffer to avoid recomputation after reaching ctx_size (default: False)
--glm_tokenizer The path of the chatglm tokenizer: String (default: THUDM/chatglm-6b)
--memory-f32 / --memory-f16 / --memory-auto Data type of kv memory (default: auto). If set to auto, the runtime tries the jblas flash-attention managed format (currently requires GCC11+ and AMX) and falls back to fp16 if that fails
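
To give an intuition for what --repeat_penalty does, here is a sketch of the common repetition-penalty formulation that is also used by Transformers' generate. Whether LLM Runtime applies exactly this rule is an assumption; apply_repeat_penalty is a hypothetical helper for illustration.

```python
def apply_repeat_penalty(logits, seen_token_ids, penalty=1.1):
    """Lower the scores of tokens that were already generated.

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so a penalty above 1.0 always makes a seen token
    less likely. A penalty of 1.0 leaves the logits unchanged.
    """
    adjusted = list(logits)
    for token_id in set(seen_token_ids):
        value = adjusted[token_id]
        adjusted[token_id] = value / penalty if value > 0 else value * penalty
    return adjusted


logits = [2.0, -1.0, 0.5]
print(apply_repeat_penalty(logits, seen_token_ids=[0, 1], penalty=2.0))  # → [1.0, -2.0, 0.5]
```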

3. Tensor Parallelism across nodes/sockets

We support a tensor parallelism strategy for distributed inference/training across multiple nodes and sockets. Refer to tensor_parallelism.md to enable this feature.

4. Contribution

To add your own models, see the graph developer document.

5. Custom Stopping Criteria

You can customize the stopping criteria according to your own needs by processing the input_ids to determine if text generation needs to be stopped. Here is a simple example, which requires a minimum generation length of 80 tokens. Once the min_length is met, encountering a terminator eos_token_id will end the generation.

import torch
from typing import List
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, min_length: int, start_length: int, stop_token_id: List[int]):
        self.min_length = min_length
        self.start_length = start_length
        self.stop_token_id = stop_token_id
 
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        if input_ids.shape[-1] - self.start_length > self.min_length:
            for stop_id in self.stop_token_id:
                if input_ids[0][input_ids.shape[-1] - 1] == stop_id:
                    return True
        return False

stopping_criteria = StoppingCriteriaList(
    [
        StopOnTokens(
            min_length=80,
            start_length=inputs.shape[1],
            stop_token_id=[tokenizer.eos_token_id],
        )
    ]
)

outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria)
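
The stopping rule can be checked without loading a model. The sketch below restates the condition from StopOnTokens.__call__ on plain Python lists; should_stop is a hypothetical helper written for illustration only.

```python
def should_stop(token_ids, start_length, min_length, stop_token_ids):
    """Mirror the StopOnTokens check on a plain list of token ids."""
    generated = len(token_ids) - start_length
    # The comparison is strict: generation stops only after more than
    # min_length new tokens, and only if the last token is a stop token.
    return generated > min_length and token_ids[-1] in stop_token_ids


prompt_ids = [11, 12, 13, 14, 15]            # made-up prompt of 5 tokens
sequence = prompt_ids + [21, 22, 23, 24, 25, 2]  # 6 new tokens, last one is id 2
print(should_stop(sequence, start_length=5, min_length=5, stop_token_ids=[2]))  # → True
print(should_stop(sequence, start_length=5, min_length=6, stop_token_ids=[2]))  # → False
```

Because the comparison is strict, a stop token that arrives exactly at min_length new tokens does not end the generation.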