Add IPEX models for audio and image classification tasks #536

Merged
merged 12 commits on Jan 29, 2024
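
In short, this PR wires two new task classes, IPEXModelForImageClassification and IPEXModelForAudioClassification, into the public optimum.intel API. A minimal usage sketch with transformers pipelines; the checkpoint names and the export=True flag are illustrative assumptions, not taken from this PR:

# Hedged usage sketch for the two classes added in this PR.
from transformers import AutoFeatureExtractor, AutoImageProcessor, pipeline
from optimum.intel import IPEXModelForAudioClassification, IPEXModelForImageClassification

vit_id = "google/vit-base-patch16-224"  # assumed checkpoint for illustration
image_model = IPEXModelForImageClassification.from_pretrained(vit_id, export=True)
image_pipe = pipeline(
    "image-classification",
    model=image_model,
    image_processor=AutoImageProcessor.from_pretrained(vit_id),
)

hubert_id = "superb/hubert-base-superb-ks"  # assumed checkpoint for illustration
audio_model = IPEXModelForAudioClassification.from_pretrained(hubert_id, export=True)
audio_pipe = pipeline(
    "audio-classification",
    model=audio_model,
    feature_extractor=AutoFeatureExtractor.from_pretrained(hubert_id),
)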
7 changes: 6 additions & 1 deletion optimum/intel/__init__.py
@@ -48,9 +48,11 @@
"IPEXModelForMaskedLM",
"IPEXModelForTokenClassification",
"IPEXModelForQuestionAnswering",
"IPEXModelForImageClassification",
"IPEXModelForAudioClassification",
"IPEXModel",
]


try:
if not (is_openvino_available() and is_nncf_available()):
raise OptionalDependencyNotAvailable()
@@ -159,7 +161,10 @@
from .utils.dummy_ipex_objects import *
else:
from .ipex import (
IPEXModel,
IPEXModelForAudioClassification,
IPEXModelForCausalLM,
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSequenceClassification,
4 changes: 1 addition & 3 deletions optimum/intel/generation/modeling.py
@@ -66,13 +66,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):

def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
model.config.return_dict = False
model.config.return_dict = task not in {"text-generation", "audio-classification"}
# check if the model_inputs is correct.
model(**model_inputs)

torch._C._jit_set_texpr_fuser_enabled(False)
if "past_key_values" in model_inputs.keys():
model.config.return_dict = False
if is_torch_version(">=", "2.1.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
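
The return_dict change above works because torch.jit.trace cannot record dict outputs under its default strict mode, while tuple outputs trace fine; keeping return_dict=True for audio-classification is therefore only safe because jit_trace runs the tracer with strict=False. A minimal illustration of that constraint (not part of the PR):

import torch

class TupleOut(torch.nn.Module):
    def forward(self, x):
        return (x * 2,)

class DictOut(torch.nn.Module):
    def forward(self, x):
        return {"logits": x * 2}

x = torch.randn(1, 3)
torch.jit.trace(TupleOut(), x)               # fine under the default strict=True
torch.jit.trace(DictOut(), x, strict=False)  # dict outputs require strict=False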
3 changes: 3 additions & 0 deletions optimum/intel/ipex/__init__.py
@@ -1,5 +1,8 @@
from optimum.intel.ipex.modeling_base import (
IPEXModel,
IPEXModelForAudioClassification,
IPEXModelForCausalLM,
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSequenceClassification,
70 changes: 62 additions & 8 deletions optimum/intel/ipex/modeling_base.py
@@ -25,7 +25,9 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
@@ -68,6 +70,9 @@ def __init__(
self.model.to(self._device)
self.model_save_dir = model_save_dir

self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.base_model_prefix, AutoConfig)
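
The input_names set added to IPEXModel.__init__ above is read straight off the TorchScript graph. A hedged standalone sketch of the same idea; the tiny checkpoint is an assumption, and torch >= 2.1 is assumed for example_kwarg_inputs, matching the version gate in jit_trace:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-bert"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, return_dict=False)

dummy = dict(tokenizer("hello world", return_tensors="pt"))
traced = torch.jit.trace(model, example_kwarg_inputs=dummy, strict=False)

# Same comprehension as in IPEXModel.__init__: graph inputs minus the module itself.
input_names = {
    i.debugName().split(".")[0] for i in traced.graph.inputs() if i.debugName() != "self"
}
print(input_names)  # typically {"input_ids", "attention_mask", "token_type_ids"}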
@@ -170,8 +175,22 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
torch.jit.save(self.model, output_path)

def forward(self, *args, **kwargs):
outputs = self.model(*args, **kwargs)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

def eval(self):
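
The new base forward() assembles its inputs from self.input_names instead of passing **kwargs straight through because a traced module only accepts the inputs it was traced with; an extra keyword such as token_type_ids would raise at call time. A minimal sketch of that constraint (not the PR's code; torch >= 2.1 assumed):

import torch

class NoTokenTypes(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return (input_ids + attention_mask,)

example = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
}
traced = torch.jit.trace(NoTokenTypes(), example_kwarg_inputs=example, strict=False)

traced(**example)  # works: same inputs as at trace time
# traced(**example, token_type_ids=torch.zeros(1, 4, dtype=torch.long)) would raise,
# hence the `"token_type_ids" in self.input_names` check before forwarding it.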
@@ -196,14 +215,52 @@ class IPEXModelForSequenceClassification(IPEXModel):
export_feature = "text-classification"


class IPEXModelForTokenClassification(IPEXModel):
auto_model_class = AutoModelForTokenClassification
export_feature = "token-classification"


class IPEXModelForMaskedLM(IPEXModel):
auto_model_class = AutoModelForMaskedLM
export_feature = "fill-mask"


class IPEXModelForTokenClassification(IPEXModel):
auto_model_class = AutoModelForTokenClassification
export_feature = "token-classification"
class IPEXModelForImageClassification(IPEXModel):
auto_model_class = AutoModelForImageClassification
export_feature = "image-classification"

def forward(
self,
pixel_values: torch.Tensor,
**kwargs,
):
inputs = {
"pixel_values": pixel_values,
}

outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForAudioClassification(IPEXModel):
auto_model_class = AutoModelForAudioClassification
export_feature = "audio-classification"

def forward(
self,
input_values: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_values": input_values,
}

if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
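
For reference, a hedged sketch of calling the new image class's forward directly, matching the pixel_values-only signature defined above; the checkpoint, the export=True flag, and the random input are illustrative assumptions:

import torch
from optimum.intel import IPEXModelForImageClassification

model = IPEXModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", export=True  # assumed checkpoint and flag
)
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for an AutoImageProcessor output
logits = model(pixel_values=pixel_values).logits
print(model.config.id2label[int(logits.argmax(-1))])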
@@ -233,9 +290,6 @@ def __init__(

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
33 changes: 33 additions & 0 deletions optimum/intel/utils/dummy_ipex_objects.py
@@ -22,6 +22,17 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])


class IPEXModel(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForSequenceClassification(metaclass=DummyObject):
_backends = ["ipex"]

@@ -75,3 +86,25 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForImageClassification(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForAudioClassification(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])