add IPEX int8 model load func (#1357)
changwangss committed Oct 20, 2022
1 parent 103b4c4 commit 23c585e
Showing 9 changed files with 131 additions and 354 deletions.
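In short, this commit switches the examples from rebuilding the IPEX INT8 model by hand at benchmark time to saving the tuned model once and restoring it with neural_compressor.utils.pytorch.load(). A minimal sketch of that round trip, using only the calls that appear in the diffs below; the model, checkpoint directory, and YAML name are placeholders taken from the imagenet example:

import os
import torchvision.models as models
from neural_compressor.experimental import Quantization, common
from neural_compressor.utils.pytorch import load

tuned_checkpoint = "saved_results"               # placeholder checkpoint directory
model = models.resnet18(pretrained=True).eval()  # placeholder fp32 model

# Tuning phase: quantize with the IPEX backend and save the result once.
quantizer = Quantization("./conf_ipex.yaml")     # config name taken from the imagenet example
quantizer.model = common.Model(model)
q_model = quantizer.fit()
q_model.save(tuned_checkpoint)

# Benchmark/accuracy phase: rebuild the INT8 model from the checkpoint
# instead of re-running prepare/convert/trace by hand.
int8_model = load(os.path.expanduser(tuned_checkpoint), model)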
@@ -25,33 +25,6 @@
from neural_compressor.adaptor.pytorch import get_torch_version
from packaging.version import Version

try:
try:
import intel_pytorch_extension as ipex
IPEX_110 = False
IPEX_112 = False
except:
try:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
IPEX_110 = False
IPEX_112 = True
except:
import intel_extension_for_pytorch as ipex
import torch.fx.experimental.optimization as optimization
IPEX_110 = True
IPEX_112 = False
TEST_IPEX = True
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
except:
IPEX_110 = None
TEST_IPEX = False
model_names = sorted(name for name in quantize_models.__dict__
if name.islower() and not name.startswith("__")
and callable(quantize_models.__dict__[name]))

model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
@@ -142,9 +115,6 @@ def main():
args = parser.parse_args()
print(args)

if args.ipex:
assert TEST_IPEX, 'Please import intel_pytorch_extension or intel_extension_for_pytorch according to version.'

if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -323,13 +293,7 @@ def main_worker(gpu, ngpus_per_node, args):

if args.tune:
from neural_compressor.experimental import Quantization, common
if args.ipex:
quantizer = Quantization("./conf_ipex.yaml")
else:
model.eval()
if pytorch_version < Version("1.7.0-rc1"):
model.fuse_model()
quantizer = Quantization("./conf.yaml")
quantizer = Quantization("./conf_ipex.yaml")
quantizer.model = common.Model(model)
q_model = quantizer.fit()
q_model.save(args.tuned_checkpoint)
@@ -339,21 +303,12 @@ def main_worker(gpu, ngpus_per_node, args):
model.eval()
ipex_config_path = None
if args.int8:
if args.ipex:
if not IPEX_110 and not IPEX_112:
# TODO: remove this when IPEX supports saving script models.
model.to(ipex.DEVICE)
try:
new_model = torch.jit.script(model)
except:
new_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to(ipex.DEVICE))
else:
new_model = model
ipex_config_path = os.path.join(os.path.expanduser(args.tuned_checkpoint),
"best_configure.json")
from neural_compressor.utils.pytorch import load
q_model = load(os.path.expanduser(args.tuned_checkpoint), model)
model = q_model
else:
new_model = model
validate(val_loader, new_model, criterion, args, ipex_config_path)
model = model
validate(val_loader, model, criterion, args, ipex_config_path)
return

for epoch in range(args.start_epoch, args.epochs):
@@ -434,58 +389,17 @@ def validate(val_loader, model, criterion, args, ipex_config_path=None):
prefix='Test: ')

# switch to evaluate mode
model.eval()
if args.ipex:
if not IPEX_110 and not IPEX_112:
conf = (
ipex.AmpConf(torch.int8, configure_file=ipex_config_path)
if ipex_config_path is not None
else ipex.AmpConf(None)
)
if IPEX_110:
if ipex_config_path is not None:
conf = ipex.quantization.QuantConf(configure_file=ipex_config_path)
model = optimization.fuse(model, inplace=True)
for idx, (input, label) in enumerate(val_loader):
x = input.contiguous(memory_format=torch.channels_last)
break
model = ipex.quantization.convert(model, conf, x)
else:
model = model
if IPEX_112:
if ipex_config_path is not None:
x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
weight= PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
prepared_model = ipex.quantization.prepare(model, qconfig, x, inplace=True)
prepared_model.load_qconf_summary(qconf_summary=ipex_config_path)
model = ipex.quantization.convert(prepared_model)
model = torch.jit.trace(model, x)
model = torch.jit.freeze(model.eval())
y = model(x)
y = model(x)
print("running int8 model\n")
else:
model = model
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
input = input.contiguous(memory_format=torch.channels_last)
if i >= args.warmup_iter:
start = time.time()
if args.gpu is not None:
input = input.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)

# compute output
if args.ipex:
if not IPEX_110 and not IPEX_112:
with ipex.AutoMixPrecision(conf, running_mode='inference'):
output = model(input.to(ipex.DEVICE))
target = target.to(ipex.DEVICE)
else:
output = model(input)
else:
output = model(input)
output = model(input)

# measure elapsed time
if i >= args.warmup_iter:
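For reference, the inline restore path that this commit deletes from the imagenet example rebuilt the INT8 graph on every run. A condensed sketch of that removed IPEX >= 1.12 branch, assuming a placeholder torchvision model and a best_configure.json written by an earlier tuning run:

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

model = models.resnet50(pretrained=True).eval()         # placeholder model
ipex_config_path = "saved_results/best_configure.json"  # assumed location of the tuned scales
x = torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)

qconfig = QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))

# Insert observers, replay the tuned scales, then convert and freeze a TorchScript graph.
prepared_model = ipex.quantization.prepare(model, qconfig, x, inplace=True)
prepared_model.load_qconf_summary(qconf_summary=ipex_config_path)
converted = ipex.quantization.convert(prepared_model)
traced = torch.jit.trace(converted, x)
traced = torch.jit.freeze(traced.eval())
traced(x)  # two warm-up runs trigger the fusion passes
traced(x)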
@@ -45,9 +45,9 @@
from typing import Optional
from utils_qa import postprocess_qa_predictions
try:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
IPEX_112 = True
except:
assert False, "transformers 4.19.0 requests IPEX version higher or equal to 1.12"

@@ -630,6 +630,7 @@ def eval_func(model):
return take_eval_steps(model, trainer, metric_name)

if model_args.tune:
ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
from neural_compressor.experimental import Quantization, common
quantizer = Quantization('conf.yaml')
quantizer.eval_func = eval_func
@@ -640,32 +641,13 @@ def eval_func(model):
return

if model_args.benchmark or model_args.accuracy_only:
ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
model.eval()
if model_args.int8:
if IPEX_112:
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
import intel_extension_for_pytorch as ipex
import torch
static_qconfig = QConfig(activation=MinMaxObserver.with_args(
qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, \
qscheme=torch.per_channel_symmetric))
sample = next(iter(trainer.get_eval_dataloader()))
example_inputs = []
for key, value in sample.items():
example_inputs.append(value)
ipex.quantization.prepare(model, static_qconfig, \
example_inputs=example_inputs, inplace=True)
configure_dir = os.path.join(os.getcwd(), training_args.output_dir, "best_configure.json")
model.load_qconf_summary(qconf_summary = configure_dir)
model = ipex.quantization.convert(model)
with torch.no_grad():
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model)
output = model(**sample)
output = model(**sample)
trainer.model = model
else:
assert "this script request IPEX version higher or equal to 1.12, please see README.md for details"
from neural_compressor.utils.pytorch import load
q_model = load(training_args.output_dir, model)
trainer.model = q_model

start_time = timeit.default_timer()
results = trainer.evaluate()
evalTime = timeit.default_timer() - start_time
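The question-answering example follows the same load()-based restore; the one IPEX-specific step it keeps in both the tuning and benchmark branches is replacing dropout with identity before quantization (dropout is a no-op at inference time, so removing it presumably keeps the quantized/traced graph simpler). A condensed sketch, where the model name and output directory are placeholders:

import intel_extension_for_pytorch as ipex
from neural_compressor.utils.pytorch import load
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")  # placeholder model
output_dir = "./squad_output"                                               # placeholder tuning output dir

# Strip dropout modules (no-ops in eval mode) before quantization.
ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
model.eval()

# Restore the INT8 model written by the tuning run.
q_model = load(output_dir, model)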
@@ -33,13 +33,13 @@
import torch.fx.experimental.optimization as optimization

try:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
IPEX_112 = True
except:
IPEX_112 = False
assert False, "please install intel-extension-for-pytorch, support version higher than 1.10"

import intel_extension_for_pytorch as ipex
#disable for fp32, bf16
use_ipex = False


@@ -624,6 +624,8 @@ def coco_eval(model):
return False

if args.tune:
ssd_r34.eval()
ssd_r34.model = optimization.fuse(ssd_r34.model)
from neural_compressor.experimental import Quantization, common
quantizer = Quantization("./conf.yaml")
quantizer.model = common.Model(ssd_r34)
@@ -634,41 +636,12 @@ def coco_eval(model):
return

if args.benchmark or args.accuracy_mode:
ssd_r34.eval()
if args.int8:
config_file = os.path.join(args.tuned_checkpoint, "best_configure.json")
assert os.path.exists(config_file), "there is no ipex config file, Please tune with Neural Compressor first!"
if IPEX_112:
ssd_r34 = ssd_r34.eval()
print('int8 conv_bn_fusion enabled')
with torch.no_grad():
ssd_r34.model = optimization.fuse(ssd_r34.model, inplace=False)
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
example_inputs = torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)
prepared_model = prepare(ssd_r34, qconfig, example_inputs=example_inputs, inplace=False)
print("INT8 LLGA start trace")
# insert quant/dequant based on configure.json
prepared_model.load_qconf_summary(qconf_summary = config_file)
convert_model = convert(prepared_model)
with torch.no_grad():
model = torch.jit.trace(convert_model, example_inputs, check_trace=False).eval()
model = torch.jit.freeze(model)
print("done ipex default recipe.......................")
# After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
# At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
with torch.no_grad():
for i in range(2):
_, _ = model(example_inputs)
ssd_r34 = model
else:
ssd_r34 = ssd_r34.eval()
print('int8 conv_bn_fusion enabled')
ssd_r34.model = optimization.fuse(ssd_r34.model)
print("INT8 LLGA start trace")
# insert quant/dequant based on configure.json
conf = ipex.quantization.QuantConf(configure_file = config_file)
ssd_r34 = ipex.quantization.convert(ssd_r34, conf, torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
print("done ipex default recipe.......................")
config_file = os.path.join(args.tuned_checkpoint, "best_model.pt")
assert os.path.exists(config_file), "there is no ipex model file, Please tune with Neural Compressor first!"
from neural_compressor.utils.pytorch import load
ssd_r34 = load(args.tuned_checkpoint, ssd_r34)
coco_eval(ssd_r34)
return

@@ -12,6 +12,7 @@ function main {
function init_params {
tuned_checkpoint=saved_results
batch_size=16
iters=100
for var in "$@"
do
case $var in
@@ -50,21 +51,21 @@
# run_benchmark
function run_benchmark {

extra_cmd=""
if [[ ${mode} == "accuracy" ]]; then
mode_cmd="--accuracy-mode "
elif [[ ${mode} == "benchmark" ]]; then
mode_cmd="--benchmark "
extra_cmd=$extra_cmd" --iteration ${iters}"
else
echo "Error: No such mode: ${mode}"
exit 1
fi

extra_cmd=""
if [[ ${int8} == "true" ]]; then
extra_cmd=$extra_cmd"--int8"
extra_cmd=$extra_cmd" --int8"
fi


python infer.py \
--data ${dataset_location} \
--device 0 \
@@ -93,19 +93,14 @@
from torch.nn.parallel.scatter_gather import gather, scatter
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import _LRScheduler

# intel
import intel_extension_for_pytorch as ipex
from torch.utils import ThroughputBenchmark
# For distributed run
import extend_distributed as ext_dist

try:
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
IPEX_112 = True
import intel_extension_for_pytorch as ipex
except:
IPEX_112 = False
assert False, "please install intel-extension-for-pytorch, support version higher than 1.10"


exc = getattr(builtins, "IOError", "FileNotFoundError")
@@ -411,32 +406,12 @@ def trace_model(args, dlrm, test_ld, inplace=True):
dlrm.emb_l.bfloat16()
dlrm = ipex.optimize(dlrm, dtype=torch.bfloat16, inplace=inplace)
elif args.int8 and not args.tune:
if IPEX_112:
if args.num_cpu_cores != 0:
torch.set_num_threads(args.num_cpu_cores)
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
prepare(dlrm, qconfig, example_inputs=(X, lS_o, lS_i), inplace=True)
dlrm.load_qconf_summary(qconf_summary = args.int8_configure)
convert(dlrm, inplace=True)
dlrm = torch.jit.trace(dlrm, [X, lS_o, lS_i])
dlrm = torch.jit.freeze(dlrm)
else:
conf = ipex.quantization.QuantConf(args.int8_configure)
dlrm = ipex.quantization.convert(dlrm, conf, (X, lS_o, lS_i))
from neural_compressor.utils.pytorch import load
dlrm = load(args.save_model, dlrm)
elif args.int8 and args.tune:
dlrm = dlrm
else:
dlrm = ipex.optimize(dlrm, dtype=torch.float, inplace=inplace)
if not IPEX_112 and not args.tune:
if args.int8:
dlrm = freeze(dlrm)
else:
with torch.cpu.amp.autocast(enabled=args.bf16):
dlrm = torch.jit.trace(dlrm, (X, lS_o, lS_i), check_trace=True)
dlrm = torch.jit.freeze(dlrm)
dlrm(X, lS_o, lS_i)
dlrm(X, lS_o, lS_i)
return dlrm


@@ -79,7 +79,7 @@ function run_tuning {
--arch-sparse-feature-size=128 --max-ind-range=40000000 \
--numpy-rand-seed=727 --inference-only --ipex-interaction \
--print-freq=100 --print-time --mini-batch-size=2048 --test-mini-batch-size=16384 \
--test-freq=2048 --print-auc $ARGS \
--save-model ${tuned_checkpoint} --test-freq=2048 --print-auc $ARGS \
--load-model=${input_model}
elif [[ ${mode} == "benchmark" ]]; then
LOG_0="${LOG}/throughput.log"
@@ -92,7 +92,7 @@ function run_tuning {
--arch-sparse-feature-size=128 --max-ind-range=40000000 --ipex-interaction \
--numpy-rand-seed=727 --inference-only --num-batches=1000 \
--print-freq=10 --print-time --mini-batch-size=128 --test-mini-batch-size=${batch_size} \
--share-weight-instance=$CORES --num-cpu-cores=$CORES\
--save-model ${tuned_checkpoint} --share-weight-instance=$CORES --num-cpu-cores=$CORES\
$ARGS |tee $LOG_0
wait
set +x
