[Engine]: Enable gpt neox and dolly (#939)
zhenwei-intel committed Jun 19, 2023
1 parent 2580f3c commit 402bb90
Showing 39 changed files with 3,115 additions and 112 deletions.
7 changes: 3 additions & 4 deletions examples/.config/engine_deploy.json
100644 → 100755
@@ -553,16 +553,15 @@
"model": "/tf_dataset2/models/nlp_toolkit/llama-7b-hf",
"dtype": "fp32/bf16/int8",
"output_model": "ir",
"pt_file": "pt",
"model_type": "llama_7b"
"pt_file": "pt"
}
},
"benchmark": {
"cmd": "python run_llm.py",
"params": {
"max-new-tokens": "32",
"model_path": "ir",
"model_type": "llama_7b"
"model": "decapoda-research/llama-7b-hf",
"model_path": "ir"
}
},
"launcher": {
@@ -38,7 +38,7 @@ python optimize_llm.py --model=EleutherAI/gpt-j-6B --dtype=(fp32|bf16) --output_

# int8
wget https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static/resolve/main/pytorch_model.bin -O <path to int8_model.pt>
python gen_ir.py --model=EleutherAI/gpt-j-6B --dtype=int8 --output_model=<path to ir> --pt_file=<path to int8_model.pt>
python optimize_llm.py --model=EleutherAI/gpt-j-6B --dtype=int8 --output_model=<path to ir> --pt_file=<path to int8_model.pt>
```
- When the input dtype is fp32 or bf16, the model will be downloaded if it does not exist.
- When the input dtype is int8, the int8 trace model should exist.
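
These notes mirror the branch at the top of optimize_llm.py (shown later in this diff): fp32/bf16 models are downloaded and traced on the fly, while int8 requires an existing TorchScript file. A minimal sketch of that decision, assuming placeholder values for the model id, dtype, and file path:

```python
import os
import torch
from transformers import AutoModelForCausalLM

model_id = "EleutherAI/gpt-j-6B"   # placeholder; any supported causal LM
dtype = "bf16"                     # one of fp32 / bf16 / int8
pt_file = "int8_model.pt"          # only required for int8

if pt_file and os.path.exists(pt_file):
    # int8 path: load the pre-traced, quantized TorchScript model
    traced_model = torch.jit.load(pt_file)
else:
    # fp32/bf16 path: weights are fetched automatically, then traced
    assert dtype in ["fp32", "bf16"], "int8 needs an existing --pt_file"
    model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
    model.eval()
    # torch.jit.trace(...) with example inputs follows, as in optimize_llm.py
```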
@@ -92,6 +92,8 @@
# Load past kv caches from files
import numpy as np
import pickle
from optimum.utils import NormalizedConfigManager


logger = logging.get_logger(__name__)

@@ -551,7 +553,7 @@ def _prepare_model_inputs(
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
if not model_kwargs["llama"]:
if model_kwargs['model_type'] != 'llama':
# 3. models with `input_ids` can also make use of `inputs_embeds`
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
@@ -2854,6 +2856,7 @@ def beam_search(
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2866,7 +2869,8 @@
break

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
if not model_kwargs["llama"]:

if model_kwargs['model_type'] != "llama":
if model_inputs["past_key_values"] is None:
first_token = model_inputs["input_ids"].size()[1] != 1
if first_token:
@@ -2876,10 +2880,8 @@
model_inputs["input_ids"] = model_inputs["input_ids"][:1,:]
input_ids_1 = model_inputs['input_ids'].cpu().numpy().astype(np.int32)
attention_mask_1 = model_inputs['attention_mask'].cpu().numpy().astype(np.int32)

past_k_v = np.ones([1,0,16,256]).astype(np.float32)
predictions = engine_model.inference([input_ids_1] + [past_k_v for _ in range(2 * model_kwargs["past_kv_nums"])] + [attention_mask_1])

past_k_v = np.zeros([1,0,model_kwargs['num_attention_heads'],model_kwargs['d_k']]).astype(np.float32)
predictions = engine_model.inference([input_ids_1] + [past_k_v for _ in range(2 * model_kwargs['past_kv_nums'])] + [attention_mask_1])
for key in predictions:
predictions[key] = torch.from_numpy(predictions[key])

@@ -2901,28 +2903,22 @@
value = value.view(value.size(1) * value.size(0), value.size(2), value.size(3))
past_key_values.append(tuple([key, value]))
outputs.past_key_values = tuple(past_key_values)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]

else:
example_inputs = []
for k, v in model_inputs.items():
if v is not None and not isinstance(v, bool):
example_inputs.append(v)
example_inputs = tuple(example_inputs)

input_ids_1 = example_inputs[0].cpu().numpy().astype(np.int32)
attention_mask_1 = example_inputs[-1].cpu().numpy().astype(np.int32)
past_key_values = [example_inputs[1][i][j] for i in range(model_kwargs["past_kv_nums"]) for j in range(2)]
input_ids_1 = model_inputs['input_ids'].cpu().numpy().astype(np.int32)
attention_mask_1 = model_inputs['attention_mask'].cpu().numpy().astype(np.int32)
past_key_values = [model_inputs['past_key_values'][i][j] for i in range(model_kwargs["past_kv_nums"]) for j in range(2)]
predictions = engine_model.inference([input_ids_1] + past_key_values + [attention_mask_1])

# ts=time.time()
for key in predictions:
predictions[key] = torch.from_numpy(predictions[key])
outputs = CausalLMOutputWithPast()
outputs.logits = list(predictions.values())[0].reshape(-1,1,50400)
outputs.logits = list(predictions.values())[0].reshape(-1,1,model_kwargs['vocab_size'])
outputs.past_key_values = [(list(predictions.values())[2*i+1], list(predictions.values())[2*i+2]) for i in range(model_kwargs["past_kv_nums"])]

# print(2,time.time()-ts)
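
The hunks above drop the hard-coded GPT-J shapes (`np.ones([1,0,16,256])`, `reshape(-1,1,50400)`) in favour of geometry passed through `model_kwargs`. A minimal sketch of how the first-token engine call is assembled under those kwargs; the GPT-J-6B numbers below are for illustration only, and `engine_model` is assumed to be an already-compiled Neural Engine graph:

```python
import numpy as np

# Geometry normally supplied via model_kwargs (see run_llm.py later in this diff);
# these are GPT-J-6B values (28 layers, 16 heads, head dim 256), used purely as an example.
past_kv_nums = 28
num_attention_heads = 16
d_k = 256

input_ids_1 = np.array([[1]], dtype=np.int32)        # placeholder token ids, shape (1, seq_len)
attention_mask_1 = np.ones((1, 1), dtype=np.int32)   # placeholder attention mask

# Zero-length per-layer K/V caches for the very first forward pass
past_k_v = np.zeros([1, 0, num_attention_heads, d_k], dtype=np.float32)
engine_inputs = [input_ids_1] + [past_k_v for _ in range(2 * past_kv_nums)] + [attention_mask_1]
# predictions = engine_model.inference(engine_inputs)  # engine_model assumed compiled elsewhere
```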
@@ -0,0 +1 @@
{"pattern_switch": {"MatMulWithTranspose": false, "RemoveLastView": false, "NeoxReorderChange": true, "NeoxRoraryPosEmb": true, 'MultiHeadAttention': true}}
@@ -3,22 +3,38 @@
import argparse
import os
import sys
from optimum.utils import NormalizedConfigManager

class Net(torch.nn.Module):
def __init__(self, ori_model):
super(Net, self).__init__()
self.model = ori_model
def forward(self, input_ids, pastkv, mask):
return self.model(input_ids=input_ids, attention_mask=mask, past_key_values=pastkv, return_dict=False)

parser = argparse.ArgumentParser('GPT-J Generation ir', add_help=False)
parser.add_argument("--model",
type=str,
help="path to bfloat16 or int8 IR files",
help="path to original config and weight files",
default="EleutherAI/gpt-j-6B",
)
parser.add_argument('--dtype', default=None, type=str)
parser.add_argument('--output_model', default="./ir", type=str)
parser.add_argument('--model_type', default="gpt-j", type=str)
parser.add_argument('--pt_file', type=str)
args = parser.parse_args()
print(args)

model_id = args.model
model_type = args.model_type
model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
model.eval()

normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
num_layers = normalized_config.num_layers
num_attention_heads = normalized_config.num_attention_heads
hidden_size = normalized_config.hidden_size
d_k = hidden_size // num_attention_heads
model_type = model.config.model_type

if 'llama' in model_type:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_id)
@@ -28,54 +44,59 @@
prompt = "Once upon a time, there existed a little girl, who liked to have adventures." + \
" She wanted to go to places and meet new people, and have fun."
init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
input_ids = init_input_ids.clone()
attention_mask = torch.ones(len(input_ids)+1)
attention_mask[0] = 0
past_key_value_torch = tuple([(torch.zeros([1,16,32,256]), torch.zeros([1,16,32,256])) for i in range(28)])
input_ids = input_ids[0:1].unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
input_ids = init_input_ids.clone().unsqueeze(0)
attention_mask = torch.ones(len(input_ids)).unsqueeze(0)
past_key_value = tuple([(torch.zeros([1,num_attention_heads,0,d_k]),
torch.zeros([1,num_attention_heads,0,d_k])) for i in range(num_layers)])

if 'llama' in model_type:
input_ids = init_input_ids.clone()
attention_mask = torch.ones(len(input_ids)+1)
attention_mask[0] = 0
input_ids = input_ids[0:1].unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
past_key_value = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
if 'llama_13b' in model_type:
past_key_value = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])

traced_model = None
if 'llama' in model_type:
past_key_value_torch = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
if 'llama_13b' in model_type:
past_key_value_torch = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])

if args.pt_file and os.path.exists(args.pt_file):
print('PT model exists, compile will be executed.')
del model
traced_model = torch.jit.load(args.pt_file)
else:
model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
model.eval()
if args.dtype in ['fp32', 'bf16']:
if 'llama' in model_type:
traced_model = torch.jit.trace(model, (input_ids, attention_mask, past_key_value_torch))
print("Traced model is saved as {}".format(args.pt_file))
else:
traced_model = torch.jit.trace(model, (input_ids, past_key_value_torch, attention_mask))
print("Traced model is saved as {}".format(args.pt_file))
assert args.dtype in ['fp32', 'bf16'], "Model with {} can't be traced, please provide one.".format(args.dtype)
if 'llama' in model_type:
net = model
traced_model = torch.jit.trace(net, (input_ids, attention_mask, past_key_value))
else:
print("Model with {} can't be traced, please provide one.".format(args.dtype))
sys.exit(1)
net = Net(model)
traced_model = torch.jit.trace(net, (input_ids, past_key_value, attention_mask))

from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
if 'llama' not in model_type:
if 'llama' in model_type:
if args.dtype == "bf16":
with autocast("bf16"):
graph = compile(traced_model)
graph = compile(traced_model, './llama_pattern.conf')
elif args.dtype == "int8":
graph = compile(traced_model, './int8_pattern.conf')
graph = compile(traced_model, './llama_int8_pattern.conf')
else:
graph = compile(traced_model)
graph = compile(traced_model, './llama_pattern.conf')
elif 'gpt_neox' in model_type:
if args.dtype == "bf16":
with autocast("bf16"):
graph = compile(traced_model, './gpt_neox_pattern.conf')
else:
graph = compile(traced_model, './gpt_neox_pattern.conf')
else:
if args.dtype == "bf16":
with autocast("bf16"):
graph = compile(traced_model, './llama_pattern.conf')
graph = compile(traced_model)
elif args.dtype == "int8":
graph = compile(traced_model, './llama_int8_pattern.conf')
graph = compile(traced_model, './int8_pattern.conf')
else:
graph = compile(traced_model, './llama_pattern.conf')
graph = compile(traced_model)

graph.save(args.output_model)
print('Neural Engine ir is saved as {}'.format(args.output_model))
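
The rewritten optimize_llm.py above derives all cache geometry from optimum's NormalizedConfigManager instead of hard-coded per-model shapes and a --model_type flag. A quick, config-only sketch of the values that derivation yields; the checkpoint ids are listed only for illustration (gpt_neox/dolly reflect the architectures this commit enables), and only their configs are downloaded:

```python
from transformers import AutoConfig
from optimum.utils import NormalizedConfigManager

# Checkpoint ids chosen for illustration; only the configs are fetched.
for model_id in ["EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b", "databricks/dolly-v2-3b"]:
    config = AutoConfig.from_pretrained(model_id)
    norm = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
    d_k = norm.hidden_size // norm.num_attention_heads  # per-head dimension of the KV cache
    print(model_id, config.model_type, norm.num_layers,
          norm.num_attention_heads, d_k, norm.vocab_size)
```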
@@ -1,4 +1,5 @@
transformers==4.27.4
torch==2.0
accelerate
sentencepiece
sentencepiece
optimum
40 changes: 22 additions & 18 deletions examples/huggingface/pytorch/text-generation/deployment/run_llm.py
@@ -9,6 +9,7 @@
from torch.profiler import profile, record_function, ProfilerActivity
from accelerate import init_empty_weights
import generation_utils as itrex_generation_utils
from optimum.utils import NormalizedConfigManager

# args
parser = argparse.ArgumentParser('GPT-J generation script', add_help=False)
@@ -17,39 +18,42 @@
help="path to bfloat16 or int8 IR files",
default="bfloat16",
)
parser.add_argument("--model",
type=str,
help="path to original config and weight files",
default="EleutherAI/gpt-j-6B",
)
parser.add_argument('--max-new-tokens', default=32, type=int, help="output max new tokens")
parser.add_argument('--input-tokens', default='32', type=str)
parser.add_argument('--prompt', default=None, type=str)
parser.add_argument('--batch-size', default=1, type=int)
parser.add_argument('--weight_type', default=None, type=str)
parser.add_argument('--model_type', default='gpt-j', type=str)
args = parser.parse_args()
print(args)

generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
if args.model_type == 'llama_7b':
generate_kwargs["past_kv_nums"] = 32
generate_kwargs["llama"] = True
model_id = "decapoda-research/llama-7b-hf"
from transformers import LlamaForCausalLM, LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_id)
prompt_json = '/llamaprompt.json'
elif args.model_type == 'llama_13b':
generate_kwargs["past_kv_nums"] = 40
generate_kwargs["llama"] = True
model_id = "decapoda-research/llama-13b-hf"
from transformers import LlamaForCausalLM, LlamaTokenizer

model_id = args.model
config = AutoConfig.from_pretrained(model_id)
model_type = config.model_type
normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
num_attention_heads = normalized_config.num_attention_heads
hidden_size = normalized_config.hidden_size
generate_kwargs["past_kv_nums"] = normalized_config.num_layers
generate_kwargs["model_type"] = model_type
generate_kwargs["num_attention_heads"] = num_attention_heads
generate_kwargs["d_k"] = hidden_size // num_attention_heads
generate_kwargs["vocab_size"] = normalized_config.vocab_size

if 'llama' in model_type:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_id)
prompt_json = '/llamaprompt.json'
elif args.model_type == 'gpt-j':
generate_kwargs["past_kv_nums"] = 28
generate_kwargs["llama"] = False
model_id = "EleutherAI/gpt-j-6B"
else:
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt_json = '/prompt.json'

# load model
config = AutoConfig.from_pretrained(model_id)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
setattr(model, "generate", types.MethodType(itrex_generation_utils.GenerationMixin.generate, model))
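
The run_llm.py changes above build the model with empty weights and swap in the patched generate() from generation_utils.py, so the heavy compute stays inside the Neural Engine IR while Hugging Face's generation bookkeeping (beam search, stopping criteria) is reused. A condensed sketch of that wiring, assuming it runs from this deployment example directory (where generation_utils.py lives); the gpt-j config id is a placeholder:

```python
import types
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
import generation_utils as itrex_generation_utils  # the patched module shipped in this example directory

# Placeholder config id; the real script takes it from --model.
config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")

# No weights are materialized: the Neural Engine IR does the compute, and the empty
# HF model only provides generate()'s bookkeeping around engine_model.inference calls.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
setattr(model, "generate",
        types.MethodType(itrex_generation_utils.GenerationMixin.generate, model))
```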