8 changes: 8 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@
/intel_extension_for_transformers/llm/runtime/graph/*
!/intel_extension_for_transformers/llm/runtime/graph/*.*
!/intel_extension_for_transformers/llm/runtime/graph/*/
### ignore binary files in llm-runtime ###

*.pyc
.vscode
.idea
@@ -11,6 +16,7 @@
*.log
*.swp
*.onnx
*.bin
tags
build/
_build
@@ -32,6 +38,8 @@ CMakeUserPresets.json

/intel_extension_for_transformers/llm/runtime/.vs
/intel_extension_for_transformers/llm/runtime/out
/intel_extension_for_transformers/llm/runtime/graph/out
/intel_extension_for_transformers/llm/runtime/graph/runtime_outs
/examples/**/*.npy
/examples/**/*.bin
/examples/**/*.yaml
5 changes: 4 additions & 1 deletion intel_extension_for_transformers/llm/runtime/graph/README.md
@@ -347,7 +347,7 @@ class StopOnTokens(StoppingCriteria):
self.min_length = min_length
self.start_length = start_length
self.stop_token_id = stop_token_id

def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
@@ -369,3 +369,6 @@ stopping_criteria = StoppingCriteriaList(

outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria)
```

### 6. Perplexity (measuring model quality)
You can use the [scripts/perplexity.py](./scripts/perplexity.py) script to measure perplexity over a given (subset of a) dataset. Run `python scripts/perplexity.py --help` for detailed usage. For more information on the perplexity metric, see https://huggingface.co/docs/transformers/perplexity.
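For intuition, here is a minimal sketch of the metric itself. It illustrates the standard definition (the exponential of the mean negative log-likelihood), not the repository script:

```python
# Toy illustration of perplexity (not scripts/perplexity.py):
# exp of the average negative log-likelihood of the evaluated tokens.
import math

def perplexity(token_log_probs):
    """token_log_probs: natural-log probabilities the model assigned to each reference token."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)

print(perplexity([-2.1, -0.3, -1.7]))  # ~3.92; lower is better
```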
112 changes: 70 additions & 42 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -15,11 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers import AutoConfig, AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model

import torch
from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
from transformers import AutoConfig, AutoTokenizer

model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}


class Model:
def __init__(self):
self.module = None
@@ -28,55 +31,68 @@ def __init__(self):
self.bin_file = None
self.generate_round = 0

def __import_package(self, model_name):
def __import_package(self, model_type):
if self.module:
return
if model_name == "gptj":
if model_type == "gptj":
import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as cpp_model
elif model_name == "falcon":
elif model_type == "falcon":
import intel_extension_for_transformers.llm.runtime.graph.falcon_cpp as cpp_model
elif model_name == "gptneox":
elif model_type == "gptneox":
import intel_extension_for_transformers.llm.runtime.graph.gptneox_cpp as cpp_model
elif model_name == "dolly":
elif model_type == "dolly":
import intel_extension_for_transformers.llm.runtime.graph.dolly_cpp as cpp_model
elif model_name == "llama" or model_name == "llama2":
elif model_type == "llama" or model_type == "llama2":
import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model
elif model_name == "mpt":
elif model_type == "mpt":
import intel_extension_for_transformers.llm.runtime.graph.mpt_cpp as cpp_model
elif model_name == "gpt_bigcode" or model_name == "starcoder":
elif model_type == "gpt_bigcode" or model_type == "starcoder":
import intel_extension_for_transformers.llm.runtime.graph.starcoder_cpp as cpp_model
elif model_name == "opt":
elif model_type == "opt":
import intel_extension_for_transformers.llm.runtime.graph.opt_cpp as cpp_model
elif model_name == "bloom":
elif model_type == "bloom":
import intel_extension_for_transformers.llm.runtime.graph.bloom_cpp as cpp_model
elif model_name == "chatglm":
elif model_type == "chatglm":
import intel_extension_for_transformers.llm.runtime.graph.chatglm_cpp as cpp_model
elif model_name == "chatglm2":
elif model_type == "chatglm2":
import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
elif model_name == "baichuan":
elif model_type == "baichuan":
import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model
elif model_name == "polyglot":
elif model_type == "polyglot":
import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model
elif model_name == "mistral":
elif model_type == "mistral":
import intel_extension_for_transformers.llm.runtime.graph.mistral_cpp as cpp_model
else:
raise TypeError("Unspported model type {}!".format(model_name))
raise TypeError("Unspported model type {}!".format(model_type))
self.module = cpp_model

@staticmethod
def get_model_type(model_config):
model_type = model_maps.get(model_config.model_type, model_config.model_type)
if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
model_type = "chatglm2"
return model_type

def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model_type = model_maps.get(self.config.model_type, self.config.model_type)
if model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
model_type = "chatglm2"
model_type = Model.get_model_type(self.config)
self.__import_package(model_type)

# check cache and quantization
output_path = "runtime_outs"
if not os.path.exists(output_path):
os.makedirs(output_path)
os.makedirs(output_path, exist_ok=True)
fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)
quant_desc = quant_kwargs['weight_dtype']
if quant_kwargs['use_ggml']:
quant_desc += "_ggml"
else:
quant_desc += "_jblas_c" + quant_kwargs['compute_dtype']
if quant_kwargs['group_size'] == -1:
quant_desc += "_pc"
else:
quant_desc += "_g{}".format(quant_kwargs['group_size'])
quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)

if not_quant:
self.bin_file = fp32_bin
@@ -85,20 +101,22 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
if use_cache and os.path.exists(self.bin_file):
return

convert_model(model_name, fp32_bin, "f32")
assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
if not use_cache or not os.path.exists(fp32_bin):
convert_model(model_name, fp32_bin, "f32")
assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

if not_quant:
print("FP32 model will be used.")
return
self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **quant_kwargs)
self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
assert os.path.exists(quant_bin), "Fail to quantize model"

# clean
os.remove(fp32_bin)
if not use_cache:
os.remove(fp32_bin)

def init_from_bin(self, model_name, model_path, **generate_kwargs):
self.__import_package(model_name)
def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.__import_package(model_type)
self.model = self.module.Model()
if "threads" not in generate_kwargs:
threads = os.getenv("OMP_NUM_THREADS")
@@ -108,11 +126,9 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs):
generate_kwargs["threads"] = int(threads)
self.model.init_model(model_path, **generate_kwargs)

def quant_model(self, model_name, model_path, out_path, **quant_kwargs):
self.__import_package(model_name)
self.module.Model.quant_model(model_path = model_path,
out_path = out_path, **quant_kwargs)

def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
self.__import_package(model_type)
self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
@@ -129,8 +145,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
ret = input_ids.tolist()

beam_search = False
if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \
generate_kwargs.get("do_sample", False):
if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
beam_search = True
if not beam_search:
# TODO support multi batch
@@ -142,30 +157,43 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
Make sure that `num_beams` is set to 1."
if self.generate_round == 0 and not ignore_prompt:
streamer.put(input_ids)

if interactive:
self.model.reset_token_end()
out_count = 0
input_list = input_ids.tolist()
while True:
response = self.model.generate(input_ids = input_ids.tolist())
response = self.model.generate(input_ids=input_list)
input_list = [] # next-token stage will use previous output
if len(response) == 0:
break
if streamer:
streamer.put(torch.tensor([response[0]]))
for i in range(len(response)):
ret[i].extend(response[i])
if beam_search:
break
if stopping_criteria is not None:
if stopping_criteria(torch.tensor(ret), None):
break
elif ret[0][-1] == self.tokenizer.eos_token_id or \
(max_new_tokens != -1 and out_count > max_new_tokens):
(max_new_tokens != -1 and out_count > max_new_tokens):
break
out_count += 1
if streamer:
streamer.end()

self.generate_round += 1
return ret

def is_token_end(self):
return self.model.is_token_end()

def __call__(self, input_ids, reinit=False, **kwargs):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.generate_round = 0
elif reinit:
self.model.reinit()
self.generate_round = 0
return self.model.evaluate(input_ids.tolist())
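
To put the new `get_model_type` helper, the cached `runtime_outs` naming, and the renamed `model_type` arguments in context, here is a minimal usage sketch of the `Model` wrapper this file defines. Method names and keyword arguments are taken from the diff above; the model id, prompt, and quantization values are assumptions and may need adjusting to options the runtime actually supports.

```python
# Illustrative sketch only: the model id, prompt, and quantization values below
# are assumptions, not part of this PR.
from transformers import AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph import Model

model_name = "EleutherAI/gpt-j-6b"  # any architecture handled by __import_package
model = Model()

# init(): fetch the HF config/tokenizer, convert the checkpoint to
# runtime_outs/ne_<type>_f32.bin, then quantize it into a descriptor-named
# ne_<type>_q_<desc>.bin; use_cache=True reuses both files on later runs.
model.init(model_name, use_cache=True,
           weight_dtype="int4", compute_dtype="int8", group_size=32, use_ggml=False)

# Load the produced binary into the C++ runtime and generate.
model.init_from_bin(Model.get_model_type(model.config), model.bin_file)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0]))
```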