Skip to content

Commit

Permalink
Explicitly set batch=1 for NNCF in order to avoid issue with Wave2Vec (
Browse files Browse the repository at this point in the history
…#312)

* explicitly set batch=1 for NNCF in order to avoid issue with Wave2Vec

* More safe changes. affects only pruning scenario

* renamed variable

* Corrections

* correction

* fixed style
  • Loading branch information
ljaljushkin committed Jun 7, 2023
1 parent 8109a15 commit 2d2af36
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
4 changes: 2 additions & 2 deletions optimum/intel/openvino/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def __init__(
self._enable_standard_onnx_export_option()
self.optimum_version = kwargs.pop("optimum_version", None)

def add_input_info(self, model_inputs: Dict):
def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
self.input_info = [
{
"sample_size": list(value.shape),
"sample_size": [1] + list(value.shape[1:]) if force_batch_one else list(value.shape),
"type": "long" if value.dtype is torch.int64 else "float",
"keyword": name,
}
Expand Down
12 changes: 11 additions & 1 deletion optimum/intel/openvino/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def __init__(
model_inputs = next(iter(train_dataloader))
for label_name in self.label_names:
model_inputs.pop(label_name)
self.ov_config.add_input_info(model_inputs)
force_batch_one = self._is_pruning_enabled()
self.ov_config.add_input_info(model_inputs, force_batch_one)
nncf_config = NNCFConfig.from_dict(self.ov_config.__dict__)
nncf_config.register_extra_structs(
[
Expand Down Expand Up @@ -770,3 +771,12 @@ def _set_task(self):
if self.task is None:
raise ValueError("The model task defining the model topology needs to be specified for the ONNX export.")
self.task = _TASK_ALIASES.get(self.task, self.task)

def _is_pruning_enabled(compression: Union[Dict, List, None]):
if isinstance(compression, dict) and compression["algorithm"] == "movement_pruning":
return True
if isinstance(compression, list):
for algo_config in compression:
if algo_config["algorithm"] == "movement_pruning":
return True
return False

0 comments on commit 2d2af36

Please sign in to comment.