OpenVINO e5-small not working after model conversion #608

Open
2 of 4 tasks
ZeusFSX opened this issue Mar 14, 2024 · 2 comments
Labels
bug Something isn't working

Comments

ZeusFSX commented Mar 14, 2024

System Info

optimum==1.17.1
openvino==2024.0.0
PyTorch==2.2.1+cu121
python-3.10.12

Who can help?

@echarlaix @michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

I used the intfloat/multilingual-e5-small model and fine-tuned it for a token-classification task.

The conversion to OpenVINO format completed without errors, but when I tried to run the converted model I got the error below.
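
The export command isn't included here; a typical way to do the conversion with optimum-intel (my assumption about how it was done, with a placeholder path for the fine-tuned checkpoint) would be:

from optimum.intel import OVModelForTokenClassification

# Export the fine-tuned PyTorch checkpoint to OpenVINO IR and save it
# ('models/er_4-small' is a hypothetical path for the fine-tuned model)
ov_model = OVModelForTokenClassification.from_pretrained('models/er_4-small', export=True)
ov_model.save_pretrained('models/er_4-small-vino')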

Here is my inference code:

import torch

from transformers import AutoTokenizer, pipeline
from optimum.intel import OVModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained('models/er_4-small-vino')
model = OVModelForTokenClassification.from_pretrained('models/er_4-small-vino')

pipe = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    batch_size=32,
)

pipe('some text')
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 1
----> 1 pipe('some text')

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/token_classification.py:248, in TokenClassificationPipeline.__call__(self, inputs, **kwargs)
    245 if offset_mapping:
    246     kwargs["offset_mapping"] = offset_mapping
--> 248 return super().__call__(inputs, **kwargs)

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/base.py:1188, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1186     return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
   1187 elif self.framework == "pt" and isinstance(self, ChunkPipeline):
-> 1188     return next(
   1189         iter(
   1190             self.get_iterator(
   1191                 [inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
   1192             )
   1193         )
   1194     )
   1195 else:
   1196     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py:124, in PipelineIterator.__next__(self)
    121     return self.loader_batch_item()
    123 # We're out of items within a batch
--> 124 item = next(self.iterator)
    125 processed = self.infer(item, **self.params)
    126 # We now have a batch of "inferred things".

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py:266, in PipelinePackIterator.__next__(self)
    263             return accumulator
    265 while not is_last:
--> 266     processed = self.infer(next(self.iterator), **self.params)
    267     if self.loader_batch_size is not None:
    268         if isinstance(processed, torch.Tensor):

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/base.py:1102, in Pipeline.forward(self, model_inputs, **forward_params)
   1100     with inference_context():
   1101         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1102         model_outputs = self._forward(model_inputs, **forward_params)
   1103         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1104 else:

File ~/gpt/lib/python3.10/site-packages/transformers/pipelines/token_classification.py:285, in TokenClassificationPipeline._forward(self, model_inputs)
    283     logits = self.model(**model_inputs)[0]
    284 else:
--> 285     output = self.model(**model_inputs)
    286     logits = output["logits"] if isinstance(output, dict) else output[0]
    288 return {
    289     "logits": logits,
    290     "special_tokens_mask": special_tokens_mask,
   (...)
    294     **model_inputs,
    295 }

File ~/gpt/lib/python3.10/site-packages/optimum/modeling_base.py:90, in OptimizedModel.__call__(self, *args, **kwargs)
     89 def __call__(self, *args, **kwargs):
---> 90     return self.forward(*args, **kwargs)

File ~/gpt/lib/python3.10/site-packages/optimum/intel/openvino/modeling.py:340, in OVModelForTokenClassification.forward(self, input_ids, attention_mask, token_type_ids, **kwargs)
    337     inputs["token_type_ids"] = token_type_ids
    339 # Run inference
--> 340 outputs = self.request(inputs)
    341 logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
    342 return TokenClassifierOutput(logits=logits)

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/ie_api.py:387, in CompiledModel.__call__(self, inputs, share_inputs, share_outputs, shared_memory)
    384 if self._infer_request is None:
    385     self._infer_request = self.create_infer_request()
--> 387 return self._infer_request.infer(
    388     inputs,
    389     share_inputs=_deprecated_memory_arg(shared_memory, share_inputs),
    390     share_outputs=share_outputs,
    391 )

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/ie_api.py:144, in InferRequest.infer(self, inputs, share_inputs, share_outputs, shared_memory)
     68 def infer(
     69     self,
     70     inputs: Any = None,
   (...)
     74     shared_memory: Any = None,
     75 ) -> OVDict:
     76     """Infers specified input(s) in synchronous mode.
     77 
     78     Blocks all methods of InferRequest while request is running.
   (...)
    142     :rtype: OVDict
    143     """
--> 144     return OVDict(super().infer(_data_dispatch(
    145         self,
    146         inputs,
    147         is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
    148     ), share_outputs=share_outputs))

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py:426, in _data_dispatch(request, inputs, is_shared)
    424 if inputs is None:
    425     return {}
--> 426 return create_shared(inputs, request) if is_shared else create_copied(inputs, request)

File /usr/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py:220, in _(inputs, request)
    212 @create_shared.register(dict)
    213 @create_shared.register(tuple)
    214 @create_shared.register(OVDict)
   (...)
    217     request: _InferRequestWrapper,
    218 ) -> dict:
    219     request._inputs_data = normalize_arrays(inputs, is_shared=True)
--> 220     return {k: value_to_tensor(v, request=request, is_shared=True, key=k) for k, v in request._inputs_data.items()}

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py:220, in <dictcomp>(.0)
    212 @create_shared.register(dict)
    213 @create_shared.register(tuple)
    214 @create_shared.register(OVDict)
   (...)
    217     request: _InferRequestWrapper,
    218 ) -> dict:
    219     request._inputs_data = normalize_arrays(inputs, is_shared=True)
--> 220     return {k: value_to_tensor(v, request=request, is_shared=True, key=k) for k, v in request._inputs_data.items()}

File /usr/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/gpt/lib/python3.10/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py:51, in value_to_tensor(value, request, is_shared, key)
     44 @singledispatch
     45 def value_to_tensor(
     46     value: Union[Tensor, np.ndarray, ScalarTypes, str],
   (...)
     49     key: Optional[ValidKeys] = None,
     50 ) -> None:
---> 51     raise TypeError(f"Incompatible inputs of type: {type(value)}")

TypeError: Incompatible inputs of type: <class 'NoneType'>

I've also tested with an older openvino==2023.3 and with optimum==1.16.1, but got the same error.

I also tested OpenVINO with intfloat/multilingual-e5-base fine-tuned on the same dataset, and it works perfectly.
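
My guess from the traceback is that the exported e5-small graph declares a token_type_ids input while its tokenizer never returns one, so the pipeline ends up passing token_type_ids=None down into self.request(inputs). A quick check (a sketch, reusing the paths from the snippet above):

from transformers import AutoTokenizer
from optimum.intel import OVModelForTokenClassification

model = OVModelForTokenClassification.from_pretrained('models/er_4-small-vino')
tokenizer = AutoTokenizer.from_pretrained('models/er_4-small-vino')

# Inputs the exported OpenVINO model expects
print(model.input_names)
# Inputs the tokenizer actually produces for this checkpoint
print(tokenizer('some text').keys())
# If 'token_type_ids' is in the first set but not the second, the pipeline
# forwards token_type_ids=None, which is what triggers the TypeError above.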

Expected behavior

I expected it to work with the small version too.

@ZeusFSX ZeusFSX added the bug Something isn't working label Mar 14, 2024
@echarlaix echarlaix transferred this issue from huggingface/optimum Mar 14, 2024

ZeusFSX commented Mar 15, 2024

The problem is described in more detail in huggingface/optimum#1758.


ZeusFSX commented Mar 15, 2024

Here is a working implementation of OVModelForTokenClassification:

from typing import Union, Optional

import numpy as np
import torch

from transformers import AutoModelForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput
from optimum.intel.openvino.modeling import OVModel


class OVModelForTokenClassification(OVModel):
    export_feature = "token-classification"
    auto_model_class = AutoModelForTokenClassification

    def __init__(self, model=None, config=None, **kwargs):
        super().__init__(model, config, **kwargs)

    def forward(
        self,
        input_ids: Union[torch.Tensor, np.ndarray],
        attention_mask: Union[torch.Tensor, np.ndarray],
        token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
        **kwargs,
    ):
        self.compile()

        np_inputs = isinstance(input_ids, np.ndarray)
        if not np_inputs:
            input_ids = np.array(input_ids)
            attention_mask = np.array(attention_mask)
            token_type_ids = np.array(token_type_ids) if token_type_ids is not None else token_type_ids

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        # Add the token_type_ids when needed
        if "token_type_ids" in self.input_names:
            # if token_type_ids is None but it is in input_names, fill it with zeros
            # see: https://github.com/huggingface/transformers/blob/56b64bf1a51e29046bb3f8ca15839ff4d6a92c74/src/transformers/models/bert/modeling_bert.py#L976
            if token_type_ids is not None:
                inputs["token_type_ids"] = token_type_ids
            else:
                inputs["token_type_ids"] = np.zeros(input_ids.shape)

        # Run inference
        outputs = self.request(inputs)
        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
        return TokenClassifierOutput(logits=logits)
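
For reference, a sketch of how the patched class above can be dropped into the original pipeline (same paths and arguments as in the first snippet):

from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained('models/er_4-small-vino')
model = OVModelForTokenClassification.from_pretrained('models/er_4-small-vino')  # the subclass defined above

pipe = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    batch_size=32,
)
pipe('some text')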
