support ipex xpu (#1348)
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
xin3he committed Nov 22, 2023
1 parent c3214c9 commit af0b50f
Showing 14 changed files with 555 additions and 249 deletions.
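Taken together, the changes below thread an `--xpu` / `device="xpu"` path through the examples, the IPEX adaptor, and the capability query. A minimal sketch of the user-facing flow this commit enables, assuming an XPU device and intel-extension-for-pytorch are available; the toy model and dataloader are illustrative only:

import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor import PostTrainingQuantConfig, quantization

# illustrative model and calibration data; any eval-mode nn.Module works
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
calib = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32), torch.zeros(8)), batch_size=4)

model = model.to("xpu")  # move the FP32 model to the XPU first, as the examples below do
conf = PostTrainingQuantConfig(backend="ipex", device="xpu")
q_model = quantization.fit(model, conf, calib_dataloader=calib)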
@@ -106,6 +106,8 @@
help='run benchmark')
parser.add_argument('--ipex', dest='ipex', action='store_true',
help='tuning or benchmark with Intel PyTorch Extension')
parser.add_argument('--xpu', action='store_true',
help='whether to use xpu')

best_acc1 = 0

@@ -225,7 +227,8 @@ def main_worker(gpu, ngpus_per_node, args):
model.cuda()
else:
model = torch.nn.DataParallel(model)

if args.xpu:
model = model.to("xpu")
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
#criterion = nn.CrossEntropyLoss().cuda(args.gpu)
@@ -297,7 +300,10 @@ def eval_func(model):
if args.tune:
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization
conf = PostTrainingQuantConfig(backend='ipex')
if args.xpu:
conf = PostTrainingQuantConfig(backend='ipex', device="xpu")
else:
conf = PostTrainingQuantConfig(backend='ipex')
q_model = quantization.fit(model,
conf,
calib_dataloader=val_loader,
@@ -417,6 +423,9 @@ def validate(val_loader, model, criterion, args):
if args.gpu is not None:
input = input.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
if args.xpu:
input = input.to("xpu")
target = target.to("xpu")

# compute output
output = model(input)
@@ -1,5 +1,5 @@
accelerate
datasets>=1.8.0
transformers==4.30.0
transformers>=4.34.1
tensorboard
tqdm
@@ -113,7 +113,9 @@ class ModelArguments:
"help": "The inference iterations to run for benchmark."
},
)

xpu: bool = field(
default=False, metadata={"help": "whether to use xpu"}
)

@dataclass
class DataTrainingArguments:
@@ -650,6 +652,9 @@ def take_eval_steps(model, trainer, metric_name, save_metrics=False):
def eval_func(model):
return take_eval_steps(model, trainer, metric_name)

if model_args.xpu:
model = model.to("xpu")

if model_args.tune:
ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
from neural_compressor.config import PostTrainingQuantConfig
@@ -664,6 +669,8 @@ def eval_func(model):
else:
example_inputs = None # please provide correct example_inputs if necessary.
conf = PostTrainingQuantConfig(backend="ipex", calibration_sampling_size=800, example_inputs=example_inputs)
if model_args.xpu:
conf.device = "xpu"
q_model = quantization.fit(model,
conf,
calib_dataloader=eval_dataloader,
@@ -680,7 +687,7 @@ def eval_func(model):
example_inputs = get_example_inputs(model, eval_dataloader)
model = ipex.optimize(model)
with torch.no_grad():
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.trace(model, example_inputs=example_inputs, strict=False)
model = torch.jit.freeze(model)

if model_args.benchmark or model_args.accuracy_only:
@@ -692,6 +699,8 @@ def eval_func(model):
iteration=model_args.iters,
cores_per_instance=4,
num_of_instance=1)
if model_args.xpu:
b_conf.device = "xpu"
benchmark.fit(model, b_conf, b_dataloader=eval_dataloader)
else:
eval_func(model)
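Note that the commit sets the target device in two ways: as a constructor argument in main.py and as a plain attribute assignment here in run_qa.py. Assuming the config applies no extra validation in a setter, the two spellings should be equivalent:

conf = PostTrainingQuantConfig(backend="ipex", device="xpu")

conf = PostTrainingQuantConfig(backend="ipex")
conf.device = "xpu"  # same effect, applied after construction (assumption: no setter-side logic)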
@@ -44,7 +44,7 @@ function run_tuning {
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--no_cuda \
# remove --no_cuda if using xpu
--no_cuda \
--tune \
--output_dir $tuned_checkpoint
fi
@@ -55,7 +55,7 @@ function run_tuning {
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--no_cuda \
# remove --no_cuda if using xpu
--no_cuda \
--tune \
--output_dir $tuned_checkpoint
fi
121 changes: 54 additions & 67 deletions neural_compressor/adaptor/pytorch.py
@@ -78,73 +78,37 @@ def get_torch_white_list(approach):
return white_list


def pytorch_forward_wrapper(model, input, device="cpu", conf=None, running_mode="inference"):
def pytorch_forward_wrapper(
model,
input,
conf=None,
backend="default",
running_mode="inference",
):
version = get_torch_version()
if isinstance(input, dict) or isinstance(input, UserDict):
if device == "cpu":
output = model(**input)
elif device == "ipex":
# have to split the case to avoid exposing ipex.DEVICE outside
# which require intel extension installed
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(**input)
else:
output = model(**input)
else:
output = model(**input)
else: # pragma: no cover
for inp in input.keys():
input[inp] = (
input[inp].to("dpcpp" if device == "gpu" else device)
if isinstance(input[inp], torch.Tensor)
else input[inp]
)
output = model(**input)
elif isinstance(input, list) or isinstance(input, tuple):
if device == "cpu":
output = model(*input)
elif device == "ipex":
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(*input)
else:
output = model(*input)
else:
output = model(*input)
else: # pragma: no cover
tmp_device = "dpcpp" if device == "gpu" else device
input = [
inp.to(tmp_device) if isinstance(inp, torch.Tensor) else inp for inp in input
] # pylint: disable=E1133
output = model(*input)
from .torch_utils.util import forward_wrapper

if (
version.release < Version("1.12.0").release and backend == "ipex" and running_mode == "calibration"
): # pragma: no cover
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = forward_wrapper(model, input)
else:
if device == "cpu" or not isinstance(input, torch.Tensor):
output = model(input)
elif device == "ipex":
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(input)
else:
output = model(input)
else:
output = model(input)
else: # pragma: no cover
input = input.to("dpcpp" if device == "gpu" else device) # pylint: disable=no-member
output = model(input)
output = forward_wrapper(model, input)
return output
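The rewritten wrapper delegates input dispatch to `forward_wrapper` from `torch_utils.util`. Its expected behavior can be read off the branches it replaces; a plausible sketch, not the actual implementation:

from collections import UserDict

def forward_wrapper(model, input):
    # mirror the removed branches: dict-like inputs are splatted as keyword
    # arguments, sequences as positional arguments, anything else passed as-is
    if isinstance(input, (dict, UserDict)):
        return model(**input)
    if isinstance(input, (list, tuple)):
        return model(*input)
    return model(input)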


def get_example_inputs(model, dataloader):
version = get_torch_version()
from .torch_utils.util import move_input_device

    # Suggestion: set this dataloader the same way as a calib_dataloader
if dataloader is None:
return None
device = next(model.parameters()).device
try:
for idx, (input, label) in enumerate(dataloader):
input = move_input_device(input, device)
output = pytorch_forward_wrapper(model, input)
if isinstance(input, (dict, UserDict)): # pragma: no cover
            assert version.release >= Version("1.12.0").release, "INC requires IPEX version >= 1.12.0"
@@ -162,6 +126,7 @@ def get_example_inputs(model, dataloader):
break
except Exception as e: # pragma: no cover
for idx, input in enumerate(dataloader):
input = move_input_device(input, device)
output = pytorch_forward_wrapper(model, input)
if isinstance(input, (dict, UserDict)): # pragma: no cover
            assert version.release >= Version("1.12.0").release, "INC requires IPEX version >= 1.12.0"
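`move_input_device` is also imported from `torch_utils.util`. From its call sites (tensors, dicts, and sequences must land on the model's device), a sketch consistent with the removed per-branch `.to(...)` calls might look like this; the shipped helper may differ:

import torch
from collections import UserDict

def move_input_device(input, device="cpu"):
    # move tensors (one level deep in dict/list/tuple containers) to the target device
    if isinstance(input, (dict, UserDict)):
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in input.items()}
    if isinstance(input, (list, tuple)):
        return type(input)(x.to(device) if isinstance(x, torch.Tensor) else x for x in input)
    if isinstance(input, torch.Tensor):
        return input.to(device)
    return input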
@@ -814,6 +779,7 @@ def __init__(self, framework_specific_info):
self.bf16_ops = []
self.use_bf16 = framework_specific_info.get("use_bf16", True)
self.device = framework_specific_info["device"]
self.backend = framework_specific_info.get("backend", "default")
self.q_dataloader = framework_specific_info["q_dataloader"]
self.q_func = framework_specific_info.get("q_func", None)
self.benchmark = GLOBAL_STATE.STATE == MODE.BENCHMARK
@@ -881,14 +847,14 @@ def calib_func(self, model, dataloader, tmp_iterations, conf=None):
try:
for idx, (input, label) in enumerate(dataloader):
output = pytorch_forward_wrapper(
model, input, device=self.device, conf=conf, running_mode="calibration"
model, input, backend=self.backend, conf=conf, running_mode="calibration"
)
if idx >= tmp_iterations - 1:
break
except Exception as e:
for idx, input in enumerate(dataloader):
output = pytorch_forward_wrapper(
model, input, device=self.device, conf=conf, running_mode="calibration"
model, input, backend=self.backend, conf=conf, running_mode="calibration"
)
if idx >= tmp_iterations - 1:
break
@@ -936,7 +902,7 @@ def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration
if measurer is not None:
measurer.start()

output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
output = pytorch_forward_wrapper(model, input, backend=self.backend, conf=conf)
if self.device != "cpu": # pragma: no cover
output = output.to("cpu")
label = label.to("cpu")
@@ -978,7 +944,7 @@ def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration
if measurer is not None:
measurer.start()

output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
output = pytorch_forward_wrapper(model, input, backend=self.backend, conf=conf)

if measurer is not None:
measurer.end()
@@ -2272,7 +2238,7 @@ def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kw
on_step_begin(cnt)
print(".", end="", flush=True)
cnt += 1
output = pytorch_forward_wrapper(model_, image, device=device)
output = pytorch_forward_wrapper(model_, image)
loss = criterion(output, target)
if hooks is not None:
loss = on_after_compute_loss(image, output, loss)
@@ -2639,7 +2605,9 @@ def __init__(self, framework_specific_info):
super(PyTorch_IPEXAdaptor, self).__init__(framework_specific_info)
self.version = get_torch_version()
query_config_file = "pytorch_ipex.yaml"
self.query_handler = PyTorchQuery(local_config_file=os.path.join(os.path.dirname(__file__), query_config_file))
self.query_handler = PyTorchQuery(
device=self.device, local_config_file=os.path.join(os.path.dirname(__file__), query_config_file)
)
self.cfgs = None
self.fuse_ops = None
self.op_infos_from_cfgs = None
@@ -2651,7 +2619,6 @@ def __init__(self, framework_specific_info):
os.remove(self.ipex_config_path)
except:
logger.warning("Fail to remove {}.".format(self.ipex_config_path))
self.device = "ipex"

@dump_elapsed_time("Pass quantize model")
def quantize(self, tune_cfg, model, dataloader, q_func=None):
Expand All @@ -2669,6 +2636,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
# IPEX bug #1: deepcopied prepared model cannot do calibration, need model._model
# q_model._model is useless, but we need to copy other attributes, and pass the converted
# model to q_model. Also, sq will collect state_dict to origin_stat for recover
if self.device == "xpu":
model.to(self.device)
if self.performance_only:
q_model = model
else:
@@ -2722,7 +2691,12 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if self.version.release >= Version("2.1").release:
if self.device == "xpu":
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
elif self.version.release >= Version("2.1").release:
static_qconfig = ipex.quantization.default_static_qconfig_mapping
else:
static_qconfig = QConfig(
@@ -3107,7 +3081,12 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
), "IPEX need q_dataloader or example_inputs to prepare the model"
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if self.version.release >= Version("2.1").release:
if self.device == "xpu":
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
elif self.version.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
@@ -3145,6 +3124,9 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if self.example_inputs is None:
self.example_inputs = get_example_inputs(model, self.q_dataloader)
from neural_compressor.adaptor.torch_utils.util import move_input_device

self.example_inputs = move_input_device(self.example_inputs, device=self.device)
if isinstance(self.example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=self.example_inputs, inplace=True
@@ -3400,7 +3382,7 @@ def _simple_inference(self, q_model, dataloader, iterations=1):
"""The function is used for ipex warm-up inference."""
if self.example_inputs is not None:
for _ in range(iterations):
if isinstance(self.example_inputs, tuple):
                if isinstance(self.example_inputs, (tuple, list)):
q_model(*self.example_inputs)
elif isinstance(self.example_inputs, dict):
q_model(**self.example_inputs)
@@ -3919,7 +3901,7 @@ def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kw
on_step_begin(cnt)
print(".", end="", flush=True)
cnt += 1
output = pytorch_forward_wrapper(model._model, input, device=device)
output = pytorch_forward_wrapper(model._model, input)
loss = criterion(output, target)
if hooks is not None:
loss = on_after_compute_loss(input, output, loss)
@@ -4936,10 +4918,11 @@ def query_fw_capability(self, model):


class PyTorchQuery(QueryBackendCapability):
def __init__(self, local_config_file=None):
def __init__(self, device="cpu", local_config_file=None):
super().__init__()
self.version = get_torch_version()
self.cfg = local_config_file
self.device = device
self.cur_config = None
self._one_shot_query()

@@ -4973,6 +4956,10 @@ def _one_shot_query(self):
raise ValueError(
"Please check if the format of {} follows " "Neural Compressor yaml scheme.".format(self.cfg)
)
if self.device == "xpu":
self.cur_config = self.cur_config[self.device]
elif "cpu" in self.cur_config:
self.cur_config = self.cur_config["cpu"]
self._update_cfg_with_usr_definition()

def _update_cfg_with_usr_definition(self):
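The device-keyed lookup in _one_shot_query implies that pytorch_ipex.yaml now nests its capability tables under per-device keys. A hypothetical shape of the parsed config, consistent with this code but not taken from the actual file:

# after yaml loading and version matching, self.cur_config is assumed to look like:
cur_config = {
    "cpu": {...},  # selected when a "cpu" key is present and device != "xpu"
    "xpu": {...},  # selected when self.device == "xpu"
}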
