[Engine] Inner Product FP8 weight compression format dispatch for LLM
zhentaoyu committed May 8, 2023
1 parent e30ed31 commit 0065db0
Showing 7 changed files with 248 additions and 6 deletions.
10 changes: 10 additions & 0 deletions examples/huggingface/pytorch/text-generation/deployment/README.md
@@ -46,3 +46,13 @@ python gen_ir.py --model=EleutherAI/gpt-j-6B --dtype=int8 --output_model=<path t
# support single socket and multiple sockets
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_gptj.py --max-new-tokens 32 --input-tokens 32 --batch-size 1 --ir_path <path to ir>
```

Neural Engine supports FP8 weight compression **only when running a bf16 graph**. To try it, add the `--fp8_weight` argument, for example:
```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_gptj.py --max-new-tokens 32 --input-tokens 32 --batch-size 1 --ir_path <path to bf16 ir> --fp8_weight
```

For now, Neural Engine provides three weight storage format types: `fp8_4e3m`, `fp8_5e2m`, and `int8` (the default type when `--fp8_weight` is enabled). You can select the weight storage format type with the `--fp8_weight_type` argument. For example:
```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_gptj.py --max-new-tokens 32 --input-tokens 32 --batch-size 1 --ir_path <path to bf16 ir> --fp8_weight --fp8_weight_type=fp8_4e3m
```
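
These flags are thin wrappers over the `autocast` compile API used in the scripts below. A minimal sketch of the equivalent Python call (the IR path is a placeholder; `fp8_4e3m` can be swapped for `fp8_5e2m`, `int8`, or `any`):
```python
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast

# Compile a bf16 IR with FP8 weight compression enabled.
with autocast('bf16', weight_dtype='fp8_4e3m'):
    graph = compile('<path to bf16 ir>')  # placeholder IR path
```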
@@ -22,6 +22,8 @@
parser.add_argument('--prompt', default=None, type=str)
parser.add_argument('--greedy', default=None, type=str)
parser.add_argument('--batch-size', default=1, type=int)
parser.add_argument('--fp8_weight', action="store_true")
parser.add_argument('--fp8_weight_type', default='int8', type=str)
args = parser.parse_args()
print(args)

@@ -64,9 +66,15 @@
num_warmup = 4


from intel_extension_for_transformers.backends.neural_engine.compile import compile
graph = compile(args.ir_path)
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
print("Using IR file {}".format(args.ir_path))
if args.fp8_weight:
    with autocast('bf16', weight_dtype=args.fp8_weight_type):
        print("Using FP8 weight which has storage type {} and make sure your IR is BF16 type".format(
            args.fp8_weight_type))
        graph = compile(args.ir_path)
else:
    graph = compile(args.ir_path)
import numpy as np

prompt = [prompt] * args.batch_size
@@ -21,6 +21,8 @@
parser.add_argument('--input-tokens', default='32', type=str)
parser.add_argument('--prompt', default=None, type=str)
parser.add_argument('--batch-size', default=1, type=int)
parser.add_argument('--fp8_weight', action="store_true")
parser.add_argument('--fp8_weight_type', default='int8', type=str)
args = parser.parse_args()
print(args)

@@ -59,9 +61,14 @@
num_warmup = 4


from intel_extension_for_transformers.backends.neural_engine.compile import compile
graph = compile(args.ir_path)
print("Using IR file {}".format(args.ir_path))
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
if args.fp8_weight:
    with autocast('bf16', weight_dtype=args.fp8_weight_type):
        print("Using FP8 weight which has storage type {} and make sure your IR is BF16 type".format(
            args.fp8_weight_type))
        graph = compile(args.ir_path)
else:
    graph = compile(args.ir_path)
import numpy as np

prompt = [prompt] * args.batch_size
@@ -26,6 +26,7 @@
import numpy as np
from . import graph_utils as util
from copy import deepcopy
from .optimizer import Optimizer

COMPILES = OrderedDict({
    'loader': Loader,
@@ -36,14 +37,19 @@

class autocast:

    def __init__(self, cast_type: str) -> None:
    def __init__(self, cast_type: str, *args, **kwargs) -> None:
        util.autocast_init()
        self.prev_cast_type = util.get_autocast_info()['cast_type']
        self.cast_type = cast_type
        self.weight_dtype = None
        if 'weight_dtype' in kwargs:
            self.weight_dtype = kwargs['weight_dtype']

    def __enter__(self) -> None:
        self.prev_cast_type = util.get_autocast_info()['cast_type']
        util.set_autocast("cast_type", self.cast_type)
        if self.weight_dtype:
            util.set_autocast("weight_dtype", self.weight_dtype)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        util.set_autocast("cast_type", self.prev_cast_type)
@@ -106,6 +112,8 @@ def compile(model, config=None) -> Graph:
    else:
        config = _config_validation(config)
    model = start_pipeline(model, config=config)
    optimizer = Optimizer(model)
    optimizer.optimize()
    if util.get_autocast_info()['cast_type'] == "dynamic_int8":
        model = _dynamic_quantization(model)
    return model
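
The weight dtype only takes effect inside the `with` block: `__enter__` records the cast type and weight dtype in the autocast info, the new `Optimizer` pass checks for `cast_type == 'bf16'` plus a recognized FP8 tag, and `__exit__` restores the previous cast type. A minimal sketch of that behavior (the IR paths are placeholders):
```python
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast

with autocast('bf16', weight_dtype='int8'):
    bf16_graph = compile('<path to bf16 ir>')  # FP8/INT8 weight dispatch is active here

# After the block exits, the previous cast type is restored, so the FP8
# dispatch step in the Optimizer (which requires the bf16 cast type) no longer runs.
plain_graph = compile('<path to ir>')
```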
@@ -116,6 +116,29 @@ def get_quant_info():
    return _quant_info


def environ_info_init():
    """Initialize the environ info."""
    global _environ_info
    _environ_info = {}

def insert_environ_info(key, value):
    """Modify the environ info."""
    _environ_info[key] = value

def remove_environ_info_item(key):
    """Remove an item in environ info."""
    _environ_info.pop(key, None)

def remove_environ_info_items(keys):
    """Remove a list of items in environ info."""
    for key in keys:
        remove_environ_info_item(key)

def get_environ_info():
    """Get the environ info."""
    return _environ_info


def search_straight_pattern(input_pattern, graph):
"""Search user specified patterns on internal grpah structure.
@@ -1151,3 +1174,23 @@ def _is_neural_engine(model):
    assert fwk_name != 'NA', 'Framework is not detected correctly from model format.'

    return fwk_name

def set_environ_var(key, val='1'):
    """Set an env var."""
    assert type(val) == str, 'Environment variable must be string!'
    os.environ[key] = val

def set_environ_vars(kvs):
    """Set a list of env vars."""
    for key, val in kvs.items():
        set_environ_var(key, val)

def del_environ_var(key):
    """Delete an env var."""
    if key in os.environ:
        del os.environ[key]

def del_environ_vars(keys):
    """Delete a list of env vars."""
    for key in keys:
        del_environ_var(key)
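
These helpers keep a Python-side record of intended settings (`_environ_info`) separate from the actual process environment: the `Optimizer` below records flags with `insert_environ_info` and only exports them via `set_environ_vars` right before inference. A small usage sketch (the module path is assumed from the imports elsewhere in this commit):
```python
from intel_extension_for_transformers.backends.neural_engine.compile import graph_utils as util

util.environ_info_init()
util.insert_environ_info('NE_WEIGHT_FP8_4E3M', '1')  # record the intended flag
util.set_environ_vars(util.get_environ_info())       # export it to os.environ
util.del_environ_vars(['NE_WEIGHT_FP8_4E3M'])        # clean up os.environ again
```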
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The neural engine optimizer module."""

from .graph import Graph
from . import graph_utils as util
from . import logger

OPTIMIZED_WEIGHT_FORMAT_TAG = {'FP8': ['ANY', 'INT8', 'FP8_4E3M', 'FP8_5E2M']}


class Optimizer:
"""The defintion of the neural engine optimizer."""

def __init__(self, graph, input_shape=None, *args, **kwargs):
"""The optimizer initialization.
Args:
graph: neural engine Graph class
input_shape: list of list, model input data shape list
"""
assert isinstance(graph, Graph), 'graph must be an instance of Graph class'
self.graph = graph
self.input_shape = input_shape
self.cast_dtype = util.get_autocast_info()['cast_type']
self.weight_dtype = util.get_autocast_info().get('weight_dtype', 'native')
try:
util.get_environ_info()
except:
util.environ_info_init()

def optimize(self):
"""Optimize the graph."""
self.weight_optimization()
# Set env vars before inference. These env vars could help accelerate inference speed.
util.set_environ_vars(util.get_environ_info())

def weight_optimization(self):
"""Optimize weight format."""
if self.cast_dtype == 'bf16' and self.weight_dtype.upper() in \
OPTIMIZED_WEIGHT_FORMAT_TAG['FP8']:
self._weight_fp8_dispatch(self.weight_dtype.upper())

def _weight_fp8_dispatch(self, w_tag):
"""Optimize BF16 graph by using FP8 weight format."""
tag2env = {'INT8': 'NE_WEIGHT_INT8', 'FP8_4E3M': 'NE_WEIGHT_FP8_4E3M',
'FP8_5E2M': 'NE_WEIGHT_FP8_5E2M'}
util.del_environ_vars(list(tag2env.values()))
util.remove_environ_info_items(list(tag2env.values()))
if w_tag == 'ANY':
# TODO: Consider to add best fp8 weight format search
best_tag = 'INT8'
logger.info('Using FP8 weight storage format {} for BF16 model inference'.format(
best_tag))
util.insert_environ_info(tag2env[best_tag], '1')
elif w_tag in tag2env:
env_key = tag2env[w_tag]
logger.info('Using FP8 weight storage format {} for BF16 model inference'.format(
w_tag))
util.insert_environ_info(env_key, '1')
else:
logger.warning('Unknown FP8 weight compression format, please use {}'.format(
OPTIMIZED_WEIGHT_FORMAT_TAG['FP8']))
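
Putting the pieces together, compiling a bf16 IR under `autocast` with an FP8 weight dtype should leave exactly one `NE_WEIGHT_*` flag exported before inference; a hedged sketch of that end-to-end effect (the IR path is a placeholder, and the engine-side consumption of the flag is assumed from the commit title rather than shown here):
```python
import os
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast

with autocast('bf16', weight_dtype='fp8_5e2m'):
    graph = compile('<path to bf16 ir>')  # placeholder IR path

# _weight_fp8_dispatch clears the other NE_WEIGHT_* entries and records only the
# requested one, which optimize() then exports to the process environment.
assert os.environ.get('NE_WEIGHT_FP8_5E2M') == '1'
```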
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import shutil
from intel_extension_for_transformers.backends.neural_engine.compile.ops.op import OPERATORS
from intel_extension_for_transformers.backends.neural_engine.compile.ops.tensor import Tensor
from intel_extension_for_transformers.backends.neural_engine.compile.graph import Graph
from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
import copy


def fp32_to_bf16(fp32_np):
    """Reinterpret an FP32 numpy array as BF16 by keeping only the top 16 bits."""
    if fp32_np.dtype == np.int16:
        # Already stored as BF16 (int16 bit pattern), return unchanged.
        return fp32_np
    tmp = copy.deepcopy(fp32_np)
    tmp = tmp.view(dtype=np.int32)
    tmp = tmp >> 16
    tmp = tmp.astype(np.int16)
    return tmp

class TestExecutionOptions(unittest.TestCase):
    @classmethod
    def setUpClass(self):

        self.ir_path = 'optimizer_ir'
        graph = Graph()
        input_data_node = OPERATORS['Input']()
        input_tensors = []
        output_tensors = [Tensor(name="activation", shape=[-1, -1], dtype="bf16")]
        input_data_node.construct('input_data', 'Input', input_tensors=input_tensors,
                                  output_tensors=output_tensors)
        ip_node = OPERATORS['InnerProduct']()
        input_tensors = [Tensor(name="activation", shape=[-1, -1], dtype="bf16"),
                         Tensor(name="weight", shape=[256, 256], dtype="bf16",
                                data=fp32_to_bf16(np.random.randn(256, 256).astype(np.float32))),
                         Tensor(name="bias", shape=[256], dtype="bf16",
                                data=fp32_to_bf16(np.random.randn(256).astype(np.float32)))]
        output_tensors = [Tensor(name='ip:0', source_op=['ip'], dest_op=['output_data'])]
        ip_node.construct('ip', 'InnerProduct', input_tensors=input_tensors,
                          output_tensors=output_tensors)
        output_node = OPERATORS['Output']()
        input_tensors = [Tensor(name='ip:0', source_op=['ip'], dest_op=['output_data'])]
        output_tensors = []
        output_node.construct('output_data', 'Output', input_tensors=input_tensors,
                              output_tensors=output_tensors)
        graph.insert_nodes(len(graph.nodes), [input_data_node, ip_node, output_node])
        graph.save(self.ir_path)
        del graph

    @classmethod
    def tearDownClass(self):
        shutil.rmtree(self.ir_path)

    def test_fp8_weight_compression(self):
        graph = None
        data = fp32_to_bf16(np.random.randn(128, 256).astype(np.float32))
        graph = compile(self.ir_path)
        g_ret = copy.deepcopy(graph.inference([data])['ip:0'])
        fp8_ret = []
        for w_tag in ['any', 'int8', 'fp8_5e2m', 'fp8_4e3m']:
            with autocast('bf16', weight_dtype=w_tag):
                graph = compile(self.ir_path)
                fp8_ret.append(copy.deepcopy(graph.inference([data])['ip:0']))

        flag = True
        for ret in fp8_ret:
            flag = np.allclose(g_ret, ret, atol=1e0, equal_nan=True)
            if not flag:
                break
        self.assertTrue(flag)


if __name__ == "__main__":
    unittest.main()
