diff --git a/README.md b/README.md index 91e4d21b22..e0ea89bcc3 100644 --- a/README.md +++ b/README.md @@ -139,12 +139,12 @@ deepsparse.benchmark [-h] [-b BATCH_SIZE] [-shapes INPUT_SHAPES] ## 👩‍💻 NLP Inference Example ```python -from deepsparse.transformers import pipeline +from deepsparse import Pipeline # SparseZoo model stub or path to ONNX file model_path = "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/12layer_pruned80_quant-none-vnni" -qa_pipeline = pipeline( +qa_pipeline = Pipeline.create( task="question-answering", model_path=model_path, ) diff --git a/src/deepsparse/__init__.py b/src/deepsparse/__init__.py index 3d3113b74b..d9c28dc591 100644 --- a/src/deepsparse/__init__.py +++ b/src/deepsparse/__init__.py @@ -31,6 +31,7 @@ cpu_vnni_compatible, ) from .engine import * +from .pipeline import * from .version import __version__, is_release diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 45343e2a39..155744fd90 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -21,7 +21,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy from pydantic import BaseModel, Field @@ -72,7 +72,8 @@ class Pipeline(ABC): * `engine` <- `_initialize_engine` - on __call__: - * `pre_processed_inputs` <- `process_inputs(inputs: input_model)` + * `parsed_inputs: input_model` <- `parse_inputs(*args, **kwargs)` + * `pre_processed_inputs` <- `process_inputs(parsed_inputs)` * `engine_outputs` <- `engine(pre_processed_inputs)` * `outputs: output_model` <- `process_engine_outputs(engine_outputs)` @@ -133,25 +134,29 @@ def __init__( self._engine_args["scheduler"] = scheduler self.onnx_file_path = self.setup_onnx_file_path() - self._engine = self._initialize_engine() + self.engine = self._initialize_engine() - def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel: - if pipeline_inputs is None and kwargs: - # parse kwarg inputs into the expected input format - pipeline_inputs = self.input_model(**kwargs) - - # validate inputs format + def __call__(self, *args, **kwargs) -> BaseModel: + # parse inputs into input_model schema if necessary + pipeline_inputs = self.parse_inputs(*args, **kwargs) if not isinstance(pipeline_inputs, self.input_model): - raise ValueError( - f"Calling {self.__class__} requires passing inputs as an " - f"{self.input_model} object or a list of kwargs used to create " - f"a {self.input_model} object" + raise RuntimeError( + f"Unable to parse {self.__class__} inputs into a " + f"{self.input_model} object. 
Inputs parsed to {type(pipeline_inputs)}" ) # run pipeline engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs) + + if isinstance(engine_inputs, tuple): + engine_inputs, postprocess_kwargs = engine_inputs + else: + postprocess_kwargs = {} + engine_outputs: List[numpy.ndarray] = self.engine(engine_inputs) - pipeline_outputs = self.process_engine_outputs(engine_outputs) + pipeline_outputs = self.process_engine_outputs( + engine_outputs, **postprocess_kwargs + ) # validate outputs format if not isinstance(pipeline_outputs, self.output_model): @@ -306,17 +311,27 @@ class properties into an inference ready onnx file to be compiled by the raise NotImplementedError() @abstractmethod - def process_inputs(self, inputs: BaseModel) -> List[numpy.ndarray]: + def process_inputs( + self, + inputs: BaseModel, + ) -> Union[List[numpy.ndarray], Tuple[List[numpy.ndarray], Dict[str, Any]]]: """ :param inputs: inputs to the pipeline. Must be the type of the `input_model` of this pipeline :return: inputs of this model processed into a list of numpy arrays that - can be directly passed into the forward pass of the pipeline engine + can be directly passed into the forward pass of the pipeline engine. Can + also include a tuple with engine inputs and special key word arguments + to pass to process_engine_outputs to facilitate information from the raw + inputs to postprocessing that may not be included in the engine inputs """ raise NotImplementedError() @abstractmethod - def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseModel: + def process_engine_outputs( + self, + engine_outputs: List[numpy.ndarray], + **kwargs, + ) -> BaseModel: """ :param engine_outputs: list of numpy arrays that are the output of the engine forward pass @@ -327,7 +342,7 @@ def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseMod @property @abstractmethod - def input_model(self) -> BaseModel: + def input_model(self) -> Type[BaseModel]: """ :return: pydantic model class that inputs to this pipeline must comply to """ @@ -335,7 +350,7 @@ def input_model(self) -> BaseModel: @property @abstractmethod - def output_model(self) -> BaseModel: + def output_model(self) -> Type[BaseModel]: """ :return: pydantic model class that outputs of this pipeline must comply to """ @@ -365,13 +380,6 @@ def model_path(self) -> str: """ return self._model_path - @property - def engine(self) -> Union[Engine, ORTEngine]: - """ - :return: engine instance used for model forward pass in pipeline - """ - return self._engine - @property def engine_args(self) -> Dict[str, Any]: """ @@ -417,6 +425,28 @@ def to_config(self) -> "PipelineConfig": kwargs=kwargs, ) + def parse_inputs(self, *args, **kwargs) -> BaseModel: + """ + :param args: ordered arguments to pipeline, only an input_model object + is supported as an arg for this function + :param kwargs: keyword arguments to pipeline + :return: pipeline arguments parsed into the given `input_model` + schema if necessary. If an instance of the `input_model` is provided + it will be returned + """ + # passed input_model schema directly + if len(args) == 1 and isinstance(args[0], self.input_model) and not kwargs: + return args[0] + + if args: + raise ValueError( + f"pipeline {self.__class__} only supports either only a " + f"{self.input_model} object. or keyword arguments to be construct one. 
" + f"Found {len(args)} args and {len(kwargs)} kwargs" + ) + + return self.input_model(**kwargs) + def _initialize_engine(self) -> Union[Engine, ORTEngine]: engine_type = self.engine_type.lower() diff --git a/src/deepsparse/server/main.py b/src/deepsparse/server/main.py index 9879f92123..831de52443 100644 --- a/src/deepsparse/server/main.py +++ b/src/deepsparse/server/main.py @@ -166,7 +166,7 @@ def _add_pipeline_route( async def _predict_func(request: pipeline.input_model): results = await execute_async( pipeline, - **vars(request), + request, ) return serializable_response(results) diff --git a/src/deepsparse/transformers/__init__.py b/src/deepsparse/transformers/__init__.py index 89c7eb68ef..1264aa316d 100644 --- a/src/deepsparse/transformers/__init__.py +++ b/src/deepsparse/transformers/__init__.py @@ -120,4 +120,3 @@ def _check_transformers_install(): from .helpers import * from .loaders import * from .pipelines import * -from .server import * diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py index b434dec625..8f9e9c5d49 100644 --- a/src/deepsparse/transformers/eval_downstream.py +++ b/src/deepsparse/transformers/eval_downstream.py @@ -58,7 +58,7 @@ from tqdm.auto import tqdm -from deepsparse.transformers import pipeline +from deepsparse import Pipeline from datasets import load_dataset, load_metric # isort: skip @@ -79,14 +79,14 @@ def squad_eval(args): squad_metrics = load_metric("squad") # load QA pipeline - question_answer = pipeline( + question_answer = Pipeline.create( task="question-answering", model_path=args.onnx_filepath, engine_type=args.engine, num_cores=args.num_cores, - max_length=args.max_sequence_length, + sequence_length=args.max_sequence_length, ) - print(f"Engine info: {question_answer.model}") + print(f"Engine info: {question_answer.engine}") for idx, sample in enumerate(tqdm(squad)): pred = question_answer( @@ -96,7 +96,7 @@ def squad_eval(args): ) squad_metrics.add_batch( - predictions=[{"prediction_text": pred["answer"], "id": sample["id"]}], + predictions=[{"prediction_text": pred.answer, "id": sample["id"]}], references=[{"answers": sample["answers"], "id": sample["id"]}], ) @@ -114,21 +114,21 @@ def mnli_eval(args): mnli_metrics = load_metric("glue", "mnli") # load pipeline - text_classify = pipeline( + text_classify = Pipeline.create( task="text-classification", model_path=args.onnx_filepath, engine_type=args.engine, num_cores=args.num_cores, - max_length=args.max_sequence_length, + sequence_length=args.max_sequence_length, ) - print(f"Engine info: {text_classify.model}") + print(f"Engine info: {text_classify.engine}") label_map = {"entailment": 0, "neutral": 1, "contradiction": 2} for idx, sample in enumerate(tqdm(mnli_matched)): pred = text_classify([[sample["premise"], sample["hypothesis"]]]) mnli_metrics.add_batch( - predictions=[label_map.get(pred[0]["label"])], + predictions=[label_map.get(pred.labels[0])], references=[sample["label"]], ) @@ -154,14 +154,14 @@ def qqp_eval(args): qqp_metrics = load_metric("glue", "qqp") # load pipeline - text_classify = pipeline( + text_classify = Pipeline.create( task="text-classification", model_path=args.onnx_filepath, engine_type=args.engine, num_cores=args.num_cores, - max_length=args.max_sequence_length, + sequence_length=args.max_sequence_length, ) - print(f"Engine info: {text_classify.model}") + print(f"Engine info: {text_classify.engine}") label_map = {"not_duplicate": 0, "duplicate": 1} @@ -169,7 +169,7 @@ def qqp_eval(args): pred = 
text_classify([[sample["question1"], sample["question2"]]]) qqp_metrics.add_batch( - predictions=[label_map.get(pred[0]["label"])], + predictions=[label_map.get(pred.labels[0])], references=[sample["label"]], ) @@ -185,14 +185,14 @@ def sst2_eval(args): sst2_metrics = load_metric("glue", "sst2") # load pipeline - text_classify = pipeline( + text_classify = Pipeline.create( task="text-classification", model_path=args.onnx_filepath, engine_type=args.engine, num_cores=args.num_cores, - max_length=args.max_sequence_length, + sequence_length=args.max_sequence_length, ) - print(f"Engine info: {text_classify.model}") + print(f"Engine info: {text_classify.engine}") label_map = {"negative": 0, "positive": 1} @@ -202,7 +202,7 @@ def sst2_eval(args): ) sst2_metrics.add_batch( - predictions=[label_map.get(pred[0]["label"])], + predictions=[label_map.get(pred.labels[0])], references=[sample["label"]], ) diff --git a/src/deepsparse/transformers/pipelines.py b/src/deepsparse/transformers/pipelines.py deleted file mode 100644 index 7725a0e2c2..0000000000 --- a/src/deepsparse/transformers/pipelines.py +++ /dev/null @@ -1,1414 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Adaptation of transformers.pipelines and onnx_transformers.pipelines - -adapted from: -https://github.com/huggingface/transformers/blob/master/src/transformers/pipelines/base.py -https://github.com/patil-suraj/onnx_transformers/blob/master/onnx_transformers/pipelines.py - -""" - -import json -from abc import ABC, abstractmethod -from dataclasses import dataclass -from itertools import chain -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union - -import numpy as np -from transformers.configuration_utils import PretrainedConfig -from transformers.data import ( - SquadExample, - SquadFeatures, - squad_convert_examples_to_features, -) -from transformers.file_utils import ExplicitEnum -from transformers.models.auto import AutoConfig, AutoTokenizer -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy -from transformers.utils import logging - -from deepsparse import Engine, compile_model, cpu -from deepsparse.transformers.helpers import ( - fix_numpy_types, - get_onnx_path_and_configs, - overwrite_transformer_onnx_model_inputs, -) -from deepsparse.transformers.loaders import get_batch_loader - - -try: - import onnxruntime - - ort_import_error = None -except Exception as ort_import_err: - onnxruntime = None - ort_import_error = ort_import_err - -__all__ = [ - "ArgumentHandler", - "Pipeline", - "TextClassificationPipeline", - "TokenClassificationPipeline", - "QuestionAnsweringPipeline", - "pipeline", - "overwrite_transformer_onnx_model_inputs", - "SUPPORTED_ENGINES", - "SUPPORTED_TASKS", -] - -logger = logging.get_logger(__name__) if logging else None - - -class ArgumentHandler(ABC): - """ - Base interface for handling arguments 
for each Pipeline. - """ - - @abstractmethod - def __call__(self, *args, **kwargs): - raise NotImplementedError() - - -class DefaultArgumentHandler(ArgumentHandler): - """ - Default argument parser handling parameters for each Pipeline`. - """ - - @staticmethod - def handle_kwargs(kwargs: Dict) -> List: - """ - :param kwargs: key word arguments for a pipeline - :return: list of the processed key word arguments - """ - if len(kwargs) == 1: - output = list(kwargs.values()) - else: - output = list(chain(kwargs.values())) - - return DefaultArgumentHandler.handle_args(output) - - @staticmethod - def handle_args(args: Sequence[Any]) -> List[str]: - """ - :param args: sequence of arguments to a pipeline - :return: list of formatted, processed arguments - """ - - # Only one argument, let's do case by case - if len(args) == 1: - if isinstance(args[0], str): - return [args[0]] - elif not isinstance(args[0], list): - return list(args) - else: - return args[0] - - # Multiple arguments (x1, x2, ...) - elif len(args) > 1: - if all([isinstance(arg, str) for arg in args]): - return list(args) - - # If not instance of list, then it should be an instance of iterable - elif isinstance(args, Iterable): - return list(chain.from_iterable(chain(args))) - else: - raise ValueError( - f"Invalid input type {type(args)}. Pipeline supports " - "Union[str, Iterable[str]]" - ) - else: - return [] - - def __call__(self, *args, **kwargs): - if len(kwargs) > 0 and len(args) > 0: - raise ValueError("Pipeline cannot handle mixed args and kwargs") - - if len(kwargs) > 0: - return DefaultArgumentHandler.handle_kwargs(kwargs) - else: - return DefaultArgumentHandler.handle_args(args) - - -class _ScikitCompat(ABC): - """ - Interface layer for the Scikit and Keras compatibility. - """ - - @abstractmethod - def transform(self, X): - raise NotImplementedError() - - @abstractmethod - def predict(self, X): - raise NotImplementedError() - - -class Pipeline(_ScikitCompat): - """ - The Pipeline class is the class from which all pipelines inherit. - Refer to this class for methods shared across different pipelines. - This base Pipeline class provides support for multiple inference engine backends. - - Base class implementing pipelined operations. - Pipeline workflow is defined as a sequence of the following operations: - - Input -> Tokenization -> Model Inference -> - Post-Processing (task dependent) -> Output - - Pipeline supports running with the DeepSparse engine or onnxruntime. - - :param model: loaded inference engine to run the model with, can be a - deepsparse Engine or onnxruntime InferenceSession - :param tokenizer: tokenizer to be used for preprocessing - :param config: transformers model config for this model - :param engine_type: name of inference engine that is used. Options are - deepsparse and onnxruntime - :param max_length: maximum sequence length to set for model inputs by default. - default value is 128 - :param input_names: list of input names to the neural network - :param args_parser: Reference to the object in charge of parsing supplied - pipeline parameters. A default is provided if None - :param binary_output: if True, stores outputs as pickled binaries to avoid - storing large amount of textual data. 
Default is False - """ - - default_input_names = None - - def __init__( - self, - model: Union[Engine, "onnxruntime.InferenceSession"], - tokenizer: PreTrainedTokenizer, - config: PretrainedConfig, - engine_type: str, - max_length: int = 128, - input_names: Optional[List[str]] = None, - args_parser: ArgumentHandler = None, - binary_output: bool = False, - ): - - self.model = model - self.tokenizer = tokenizer - self.config = config - self.engine_type = engine_type - self.max_length = max_length - self.input_names = input_names - self.binary_output = binary_output - self._args_parser = args_parser or DefaultArgumentHandler() - self._framework = ( - "np" if self.engine_type in [DEEPSPARSE_ENGINE, ORT_ENGINE] else "pt" - ) - - def transform(self, X): - """ - Scikit / Keras interface to transformers' pipelines. - This method will forward to __call__(). - """ - return self(X=X) - - def predict(self, X): - """ - Scikit / Keras interface to transformers' pipelines. - This method will forward to __call__(). - """ - return self(X=X) - - def _parse_and_tokenize( - self, *args, padding=True, add_special_tokens=True, **kwargs - ): - # Parse arguments - inputs = self._args_parser(*args, **kwargs) - inputs = self.tokenizer( - inputs, - add_special_tokens=add_special_tokens, - return_tensors=self._framework, - padding=PaddingStrategy.MAX_LENGTH.value, - truncation=TruncationStrategy.LONGEST_FIRST.value, - ) - - return inputs - - def __call__(self, *args, **kwargs): - inputs = self._parse_and_tokenize(*args, **kwargs) - return self._forward(inputs) - - def _forward(self, inputs): - if not all(name in inputs for name in self.input_names): - raise ValueError( - f"pipeline expected arrays with names {self.input_names}, received " - f"inputs: {list(inputs.keys())}" - ) - - if self.engine_type == ORT_ENGINE: - inputs = {k: v for k, v in inputs.items() if k in self.input_names} - return self.model.run(None, inputs) - elif self.engine_type == DEEPSPARSE_ENGINE: - return self.model.run([inputs[name] for name in self.input_names]) - # TODO: torch - # with self.device_placement(): - # with torch.no_grad(): - # inputs = self.ensure_tensor_on_device(**inputs) - # predictions = self.model(**inputs)[0].cpu() - # if return_tensors: - # return predictions - # else: - # return predictions.numpy() - - -class TokenClassificationArgumentHandler(ArgumentHandler): - """ - Handles arguments for token classification. - """ - - def __call__(self, inputs: Union[str, List[str]], **kwargs): - - if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0: - inputs = list(inputs) - batch_size = len(inputs) - elif isinstance(inputs, str): - inputs = [inputs] - batch_size = 1 - else: - raise ValueError("At least one input is required.") - - offset_mapping = kwargs.get("offset_mapping") - if offset_mapping: - if isinstance(offset_mapping, list) and isinstance( - offset_mapping[0], tuple - ): - offset_mapping = [offset_mapping] - if len(offset_mapping) != batch_size: - raise ValueError( - "offset_mapping should have the same batch size as the input" - ) - return inputs, offset_mapping - - -class QuestionAnsweringArgumentHandler(ArgumentHandler): - """ - QuestionAnsweringPipeline requires the user to provide multiple arguments - (i.e. 
question & context) to be mapped - to internal `transformers.SquadExample` - - QuestionAnsweringArgumentHandler manages all the possible to create a - `transformers.SquadExample` from the command-line supplied arguments - """ - - def __call__(self, *args, **kwargs): - # Position args, handling is sensibly the same as X and data, - # so forwarding to avoid duplicating - if args is not None and len(args) > 0: - if len(args) == 1: - kwargs["X"] = args[0] - else: - kwargs["X"] = list(args) - - # Generic compatibility with sklearn and Keras - # Batched data - if "X" in kwargs or "data" in kwargs: - inputs = kwargs["X"] if "X" in kwargs else kwargs["data"] - - if isinstance(inputs, dict): - inputs = [inputs] - else: - # Copy to avoid overriding arguments - inputs = [i for i in inputs] - - for i, item in enumerate(inputs): - if isinstance(item, dict): - if any(k not in item for k in ["question", "context"]): - raise KeyError( - "You need to provide a dictionary with keys " - "{question:..., context:...}" - ) - - inputs[i] = QuestionAnsweringPipeline.create_sample(**item) - - elif not isinstance(item, SquadExample): - arg_name = "X" if "X" in kwargs else "data" - raise ValueError( - f"{arg_name} argument needs to be of type " - "(list[SquadExample | dict], SquadExample, dict)" - ) - - # Tabular input - elif "question" in kwargs and "context" in kwargs: - if isinstance(kwargs["question"], str): - kwargs["question"] = [kwargs["question"]] - - if isinstance(kwargs["context"], str): - kwargs["context"] = [kwargs["context"]] - - inputs = [ - QuestionAnsweringPipeline.create_sample(q, c) - for q, c in zip(kwargs["question"], kwargs["context"]) - ] - else: - raise ValueError(f"Unknown arguments {kwargs}") - - if not isinstance(inputs, list): - inputs = [inputs] - - return inputs - - -class TextClassificationPipeline(Pipeline): - """ - Text classification pipeline using any `ModelForSequenceClassification`. - - This text classification pipeline can currently be loaded from `pipeline()` - using the following task identifier: `"text-classification"`. - - The models that this pipeline can use are models that have been fine-tuned on - a text classification task. - - :param return_all_scores: set True to return all model scores. Default False - """ - - def __init__(self, return_all_scores: bool = False, **kwargs): - super().__init__(**kwargs) - - self.return_all_scores = return_all_scores - - def __call__(self, *args, **kwargs): - """ - Classify the text(s) given as inputs. - - :param args: One or several texts (or one list of prompts) to classify - :param args: kwargs for inner call function - :return: A list or a list of list of dicts: Each result comes as list of dicts - with the following keys: - - `label` -- The label predicted. - - `score` -- The corresponding probability. 
- If ``self.return_all_scores=True``, one dictionary is returned per label - """ - outputs = super().__call__(*args, **kwargs) - - if isinstance(outputs, list) and outputs: - outputs = outputs[0] - - if self.config.num_labels == 1: - scores = 1.0 / (1.0 + np.exp(-outputs)) - else: - scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True) - if self.return_all_scores: - return [ - [ - {"label": self.config.id2label[i], "score": score.item()} - for i, score in enumerate(item) - ] - for item in scores - ] - else: - return [ - { - "label": self.config.id2label[item.argmax()], - "score": item.max().item(), - } - for item in scores - ] - - -class AggregationStrategy(ExplicitEnum): - """ - All the valid aggregation strategies for TokenClassificationPipeline - """ - - NONE = "none" - SIMPLE = "simple" - FIRST = "first" - AVERAGE = "average" - MAX = "max" - - -class TokenClassificationPipeline(Pipeline): - """ - Named Entity Recognition pipeline using any `ModelForTokenClassification`. - - This token classification pipeline can currently be loaded from `pipeline()` - using the following task identifier: `"token-classification"`. - - The models that this pipeline can use are models that have been fine-tuned on - a token classification task. - - :param args_parser: argument parser to use default is - TokenClassificationArgumentHandler - :param aggregation_strategy: AggregationStrategy Enum object to determine - the pipeline aggregation strategy. Default is AggregationStrategy.NONE - :param ignore_labels: list of labels to ignore. Default is `["O"]` - """ - - default_input_names = "sequences" - - def __init__( - self, - args_parser: ArgumentHandler = None, - aggregation_strategy: AggregationStrategy = AggregationStrategy.NONE, - ignore_labels: List[str] = False, - **kwargs, - ): - super().__init__( - args_parser=args_parser or TokenClassificationArgumentHandler(), - **kwargs, - ) - - self.ignore_labels = ignore_labels or ["O"] - - if isinstance(aggregation_strategy, str): - aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()] - - if ( - aggregation_strategy - in { - AggregationStrategy.FIRST, - AggregationStrategy.MAX, - AggregationStrategy.AVERAGE, - } - and not self.tokenizer.is_fast - ): - raise ValueError( - "Slow tokenizers cannot handle subwords. Please set the " - '`aggregation_strategy` option to `"simple"` or use a fast tokenizer.' - ) - - self.aggregation_strategy = aggregation_strategy - - def __call__(self, inputs: Union[str, List[str]], **kwargs): - """ - Classify each token of the text(s) given as inputs. - - - :param inputs: One or several texts (or one list of texts) for token - classification - :return: A list or a list of list of :obj:`dict`: Each result comes as a list - of dictionaries (one for each token in the corresponding input, or each - entity if this pipeline was instantiated with an aggregation_strategy) - with the following keys: - - `word` -- The token/word classified. - - `score` -- The corresponding probability for `entity`. - - `entity` -- The entity predicted for that token/word (it is named - `entity_group` when `aggregation_strategy` is not `"none"`. - - `index` -- The index of the corresponding token in the sentence. - - `start` -- index of the start of the corresponding entity in the sentence - Only exists if the offsets are available within the tokenizer - - `end` -- The index of the end of the corresponding entity in the sentence. 
- Only exists if the offsets are available within the tokenizer - """ - - _inputs, offset_mappings = self._args_parser(inputs, **kwargs) - - answers = [] - - tokens = self.tokenizer( - _inputs, - return_tensors=self._framework, - truncation=TruncationStrategy.LONGEST_FIRST.value, - padding=PaddingStrategy.MAX_LENGTH.value, - return_special_tokens_mask=True, - return_offsets_mapping=self.tokenizer.is_fast, - ) - - if self.tokenizer.is_fast: - offset_mapping = tokens.pop("offset_mapping") - elif not offset_mappings: - offset_mapping = [None] * len(_inputs) - - special_tokens_mask = tokens.pop("special_tokens_mask") - - # Forward - _forward_pass = self._forward(tokens) - for entities_index, current_entities in enumerate(_forward_pass[0]): - input_ids = tokens["input_ids"][entities_index] - - scores = np.exp(current_entities) / np.exp(current_entities).sum( - -1, keepdims=True - ) - pre_entities = self.gather_pre_entities( - _inputs[entities_index], - input_ids, - scores, - offset_mapping[entities_index], - special_tokens_mask[entities_index], - ) - grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy) - # Filter anything that is in self.ignore_labels - current_entities = [ - entity - for entity in grouped_entities - if entity.get("entity", None) not in self.ignore_labels - and entity.get("entity_group", None) not in self.ignore_labels - ] - answers.append(current_entities) - - if len(answers) == 1: - return answers[0] - return answers - - def gather_pre_entities( - self, - sentence: str, - input_ids: np.ndarray, - scores: np.ndarray, - offset_mapping: Optional[List[Tuple[int, int]]], - special_tokens_mask: np.ndarray, - ) -> List[dict]: - pre_entities = [] - for idx, token_scores in enumerate(scores): - # Filter special_tokens, they should only occur - # at the sentence boundaries since we're not encoding pairs of - # sentences so we don't have to keep track of those. 
- if special_tokens_mask[idx]: - continue - - word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])) - if offset_mapping is not None: - start_ind, end_ind = offset_mapping[idx] - word_ref = sentence[start_ind:end_ind] - is_subword = len(word_ref) != len(word) - - if int(input_ids[idx]) == self.tokenizer.unk_token_id: - word = word_ref - is_subword = False - else: - start_ind = None - end_ind = None - is_subword = False - - pre_entity = { - "word": word, - "scores": token_scores, - "start": start_ind, - "end": end_ind, - "index": idx, - "is_subword": is_subword, - } - pre_entities.append(pre_entity) - return pre_entities - - def aggregate( - self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy - ) -> List[dict]: - if aggregation_strategy in { - AggregationStrategy.NONE, - AggregationStrategy.SIMPLE, - }: - entities = [] - for pre_entity in pre_entities: - entity_idx = pre_entity["scores"].argmax() - score = pre_entity["scores"][entity_idx] - entity = { - "entity": self.config.id2label[entity_idx], - "score": score, - "index": pre_entity["index"], - "word": pre_entity["word"], - "start": pre_entity["start"], - "end": pre_entity["end"], - } - entities.append(entity) - else: - entities = self.aggregate_words(pre_entities, aggregation_strategy) - - if aggregation_strategy == AggregationStrategy.NONE: - return entities - return self.group_entities(entities) - - def aggregate_word( - self, entities: List[dict], aggregation_strategy: AggregationStrategy - ) -> dict: - word = self.tokenizer.convert_tokens_to_string( - [entity["word"] for entity in entities] - ) - if aggregation_strategy == AggregationStrategy.FIRST: - scores = entities[0]["scores"] - idx = scores.argmax() - score = scores[idx] - entity = self.config.id2label[idx] - elif aggregation_strategy == AggregationStrategy.MAX: - max_entity = max(entities, key=lambda entity: entity["scores"].max()) - scores = max_entity["scores"] - idx = scores.argmax() - score = scores[idx] - entity = self.config.id2label[idx] - elif aggregation_strategy == AggregationStrategy.AVERAGE: - scores = np.stack([entity["scores"] for entity in entities]) - average_scores = np.nanmean(scores, axis=0) - entity_idx = average_scores.argmax() - entity = self.config.id2label[entity_idx] - score = average_scores[entity_idx] - else: - raise ValueError("Invalid aggregation_strategy") - new_entity = { - "entity": entity, - "score": score, - "word": word, - "start": entities[0]["start"], - "end": entities[-1]["end"], - } - return new_entity - - def aggregate_words( - self, entities: List[dict], aggregation_strategy: AggregationStrategy - ) -> List[dict]: - assert aggregation_strategy not in { - AggregationStrategy.NONE, - AggregationStrategy.SIMPLE, - }, "NONE and SIMPLE strategies are invalid" - - word_entities = [] - word_group = None - for entity in entities: - if word_group is None: - word_group = [entity] - elif entity["is_subword"]: - word_group.append(entity) - else: - word_entities.append( - self.aggregate_word(word_group, aggregation_strategy) - ) - word_group = [entity] - # Last item - word_entities.append(self.aggregate_word(word_group, aggregation_strategy)) - return word_entities - - def group_sub_entities(self, entities: List[dict]) -> dict: - # Get the first entity in the entity group - entity = entities[0]["entity"].split("-")[-1] - scores = np.nanmean([entity["score"] for entity in entities]) - tokens = [entity["word"] for entity in entities] - - entity_group = { - "entity_group": entity, - "score": np.mean(scores), - "word": 
self.tokenizer.convert_tokens_to_string(tokens), - "start": entities[0]["start"], - "end": entities[-1]["end"], - } - return entity_group - - def get_tag(self, entity_name: str) -> Tuple[str, str]: - if entity_name.startswith("B-"): - bi = "B" - tag = entity_name[2:] - elif entity_name.startswith("I-"): - bi = "I" - tag = entity_name[2:] - else: - # It's not in B-, I- format - bi = "B" - tag = entity_name - return bi, tag - - def group_entities(self, entities: List[dict]) -> List[dict]: - - entity_groups = [] - entity_group_disagg = [] - - for entity in entities: - if not entity_group_disagg: - entity_group_disagg.append(entity) - continue - - # If the current entity is similar and adjacent to the previous entity, - # append it to the disaggregated entity group - # The split is meant to account for the "B" and "I" prefixes - # Shouldn't merge if both entities are B-type - bi, tag = self.get_tag(entity["entity"]) - last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"]) - - if tag == last_tag and bi != "B": - # Modify subword type to be previous_type - entity_group_disagg.append(entity) - else: - # If the current entity is different from the previous entity - # aggregate the disaggregated entity group - entity_groups.append(self.group_sub_entities(entity_group_disagg)) - entity_group_disagg = [entity] - if entity_group_disagg: - # it's the last entity, add it to the entity groups - entity_groups.append(self.group_sub_entities(entity_group_disagg)) - - return entity_groups - - -class QuestionAnsweringPipeline(Pipeline): - """ - Question Answering pipeline using any `ModelForQuestionAnswering` - - This question answering pipeline can currently be loaded from `pipeline()` - using the following task identifier: `"question-answering"`. - - The models that this pipeline can use are models that have been fine-tuned on - a question answering task. - - :param model: loaded inference engine to run the model with, can be a - deepsparse Engine or onnxruntime InferenceSession - :param tokenizer: tokenizer to be used for preprocessing - :param config: transformers model config for this model - :param engine_type: name of inference engine that is used. Options are - deepsparse and onnxruntime - :param input_names: list of input names to the neural network - :param args_parser: Reference to the object in charge of parsing supplied - pipeline parameters. A default is provided if None - :param binary_output: if True, stores outputs as pickled binaries to avoid - storing large amount of textual data. 
Default is False - """ - - default_input_names = "question,context" - - def __init__( - self, - model: Union[Engine, "onnxruntime.InferenceSession"], - tokenizer: PreTrainedTokenizer, - engine_type: str, - input_names: Optional[List[str]] = None, - **kwargs, - ): - super().__init__( - model=model, - tokenizer=tokenizer, - engine_type=engine_type, - args_parser=QuestionAnsweringArgumentHandler(), - input_names=input_names, - **kwargs, - ) - - @staticmethod - def create_sample( - question: Union[str, List[str]], context: Union[str, List[str]] - ) -> Union[SquadExample, List[SquadExample]]: - """ - :param question: single question or list of question strings - :param context: single context or list of context strings - :return: processed SquadExample object(s) for each question/context pair given - """ - if isinstance(question, list): - return [ - SquadExample(None, q, c, None, None, None) - for q, c in zip(question, context) - ] - else: - return SquadExample(None, question, context, None, None, None) - - def __call__(self, *args, **kwargs): - """ - Answer the question(s) given as inputs by using the context(s). - Multiple arguments can be used to pass the context, question data - - :param args: SquadExample or list of them containing the question and context - :param X: SquadExample or list of them containing the question and context - :param data: SquadExample or list of them containing the question and context - :param question: single question or list of question strings - :param context: single context or list of context strings - :param topk: the number of answers to return. Will be chosen by - order of likelihood) - :param doc_stride: if the context is too long to fit with the question for the - model, it will be split in several chunks with some overlap. This argument - controls the size of that overlap - :param max_answer_len: maximum length of predicted answers (e.g., only - answers with a shorter length are considered) - :param max_seq_len: maximum length of the total sentence (context + question) - after tokenization. The context will be split in several chunks - (using the doc_stride) if needed - :param max_question_len: maximum length of the question after tokenization. - It will be truncated if needed - :param handle_impossible_answer: whether or not we accept impossible as an - answer - :param num_spans: maximum number of span to use as input from a long - context. Default is to stride the entire context string - :param preprocessed_inputs: if provided, preprocessing will be skipped in favor - of these inputs. 
Expected format is the output of self.preprocess; a tuple - of (examples, features_list) - :return: dict or list of dictionaries, each containing the following keys: - `"score"` - The probability associated to the answer - `"start"` - The start index of the answer - `"end"` - The end index of the answer - `"answer"` - The answer to the question - """ - # Set defaults values - kwargs.setdefault("topk", 1) - kwargs.setdefault("max_answer_len", 15) - kwargs.setdefault("handle_impossible_answer", False) - kwargs.setdefault("preprocessed_inputs", None) # (examples, features_list) - - if kwargs["topk"] < 1: - raise ValueError(f"topk parameter should be >= 1 (got {kwargs['topk']})") - - if kwargs["max_answer_len"] < 1: - raise ValueError( - "max_answer_len parameter should be >= 1 " - f"(got {kwargs['max_answer_len']})" - ) - - # run pre-processing if not provided - examples, features_list = kwargs["preprocessed_inputs"] or self.preprocess( - *args, **kwargs - ) - - # forward pass and post-processing - all_answers = [] - for features, example in zip(features_list, examples): - model_input_names = self.tokenizer.model_input_names + ["input_ids"] - fw_args = { - k: [feature.__dict__[k] for feature in features] - for k in model_input_names - } - - # Manage tensor allocation on correct device - fw_args = {k: np.array(v) for (k, v) in fw_args.items()} - start, end = self._forward(fw_args)[:2] - - # TODO: torch - # fw_args = {k: torch.tensor(v, device=self.device) - # for (k, v) in fw_args.items()} - # start, end = self.model(**fw_args)[:2] - # start, end = start.cpu().numpy(), end.cpu().numpy() - - min_null_score = 1000000 # large and positive - answers = [] - for (feature, start_, end_) in zip(features, start, end): - # Ensure padded tokens & question tokens cannot belong - undesired_tokens = ( - np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask - ) - - # Generate mask - undesired_tokens_mask = undesired_tokens == 0.0 - - # Make sure non-context indexes cannot contribute to the softmax - start_ = np.where(undesired_tokens_mask, -10000.0, start_) - end_ = np.where(undesired_tokens_mask, -10000.0, end_) - - # Normalize logits and spans to retrieve the answer - start_ = np.exp( - start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)) - ) - end_ = np.exp( - end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)) - ) - - if kwargs["handle_impossible_answer"]: - min_null_score = min(min_null_score, (start_[0] * end_[0]).item()) - - # Mask CLS - start_[0] = end_[0] = 0.0 - - starts, ends, scores = self.decode( - start_, end_, kwargs["topk"], kwargs["max_answer_len"] - ) - - if not self.tokenizer.is_fast: - char_to_word = np.array(example.char_to_word_offset) - answers += [ - { - "score": score.item(), - "start": np.where( - char_to_word == feature.token_to_orig_map[s] - )[0][0].item(), - "end": np.where( - char_to_word == feature.token_to_orig_map[e] - )[0][-1].item(), - "answer": " ".join( - example.doc_tokens[ - feature.token_to_orig_map[ - s - ] : feature.token_to_orig_map[e] - + 1 - ] - ), - } - for s, e, score in zip(starts, ends, scores) - ] - else: - question_first = bool(self.tokenizer.padding_side == "right") - - # Sometimes the max probability token is in the middle of a word so: - # we start by finding the right word containing the token with - # `token_to_word` then we convert this word in a character span - answers += [ - { - "score": score.item(), - "start": feature.encoding.word_to_chars( - feature.encoding.token_to_word(s), - sequence_index=1 if question_first else 
0, - )[0], - "end": feature.encoding.word_to_chars( - feature.encoding.token_to_word(e), - sequence_index=1 if question_first else 0, - )[1], - "answer": example.context_text[ - feature.encoding.word_to_chars( - feature.encoding.token_to_word(s), - sequence_index=1 if question_first else 0, - )[0] : feature.encoding.word_to_chars( - feature.encoding.token_to_word(e), - sequence_index=1 if question_first else 0, - )[ - 1 - ] - ], - } - for s, e, score in zip(starts, ends, scores) - ] - - if kwargs["handle_impossible_answer"]: - answers.append( - {"score": min_null_score, "start": 0, "end": 0, "answer": ""} - ) - - answers = sorted(answers, key=lambda x: x["score"], reverse=True)[ - : kwargs["topk"] - ] - all_answers += answers - - if len(all_answers) == 1: - return all_answers[0] - return all_answers - - def preprocess(self, *args, **kwargs) -> Tuple[Any, Any]: - """ - preprocess the given QA model inputs using squad_convert_examples_to_features - - :param args: SquadExample or list of them containing the question and context - :param X: SquadExample or list of them containing the question and context - :param data: SquadExample or list of them containing the question and context - :param question: single question or list of question strings - :param context: single context or list of context strings - :param doc_stride: if the context is too long to fit with the question for the - model, it will be split in several chunks with some overlap. This argument - controls the size of that overlap - :param max_seq_len: maximum length of the total sentence (context + question) - after tokenization. The context will be split in several chunks - (using the doc_stride) if needed - :param max_question_len: maximum length of the question after tokenization. - It will be truncated if needed - :param num_spans: maximum number of spans to use as input from a long - context. Default is to stride the entire context string - :return: tuple of SquadExample inputs and preprocessed features list - """ - kwargs.setdefault("doc_stride", 128) - kwargs.setdefault("max_seq_len", self.max_length) - kwargs.setdefault("max_question_len", 64) - kwargs.setdefault("num_spans", None) - - # Convert inputs to features - examples = self._args_parser(*args, **kwargs) - if not self.tokenizer.is_fast: - features_list = [ - squad_convert_examples_to_features( - examples=[example], - tokenizer=self.tokenizer, - max_seq_length=kwargs["max_seq_len"], - doc_stride=kwargs["doc_stride"], - max_query_length=kwargs["max_question_len"], - padding_strategy=PaddingStrategy.MAX_LENGTH.value, - is_training=False, - tqdm_enabled=False, - ) - for example in examples - ] - else: - features_list = self._encode_features_fast(examples, **kwargs) - - if kwargs["num_spans"]: - features_list = [ - features[: kwargs["num_spans"]] for features in features_list - ] - - return examples, features_list - - def decode( - self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int - ) -> Tuple: - """ - :param start: Individual start probabilities for each token - :param end: Individual end probabilities for each token - :param topk: Indicates how many possible answer span(s) to extract from the - model output - :param max_answer_len: Maximum size of the answer to extract from the model - output - :return: probabilities for each span to be the actual answer. 
Will filter out - unwanted and impossible cases - """ - # Ensure we have batch axis - if start.ndim == 1: - start = start[None] - - if end.ndim == 1: - end = end[None] - - # Compute the score of each tuple(start, end) to be the real answer - outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1)) - - # Remove candidate with end < start and end - start > max_answer_len - candidates = np.tril(np.triu(outer), max_answer_len - 1) - - # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA) - scores_flat = candidates.flatten() - if topk == 1: - idx_sort = [np.argmax(scores_flat)] - elif len(scores_flat) < topk: - idx_sort = np.argsort(-scores_flat) - else: - idx = np.argpartition(-scores_flat, topk)[0:topk] - idx_sort = idx[np.argsort(-scores_flat[idx])] - - start, end = np.unravel_index(idx_sort, candidates.shape)[1:] - return start, end, candidates[0, start, end] - - def span_to_answer( - self, text: str, start: int, end: int - ) -> Dict[str, Union[str, int]]: - """ - When decoding from token probabilities, this method maps token indexes to - actual word in the initial context. - - :param text: The actual context to extract the answer from - :param start: The answer starting token index - :param end: The answer end token index - :return: Dictionary containing the start, end, and answer - """ - words = [] - token_idx = char_start_idx = char_end_idx = chars_idx = 0 - - for i, word in enumerate(text.split(" ")): - token = self.tokenizer.tokenize(word) - - # Append words if they are in the span - if start <= token_idx <= end: - if token_idx == start: - char_start_idx = chars_idx - - if token_idx == end: - char_end_idx = chars_idx + len(word) - - words += [word] - - # Stop if we went over the end of the answer - if token_idx > end: - break - - # Append the subtokenization length to the running index - token_idx += len(token) - chars_idx += len(word) + 1 - - # Join text with spaces - return { - "answer": " ".join(words), - "start": max(0, char_start_idx), - "end": min(len(text), char_end_idx), - } - - def _encode_features_fast(self, examples: Any, **kwargs) -> List[SquadFeatures]: - features_list = [] - for example in examples: - # Define the side we want to truncate / pad and the text/pair sorting - question_first = bool(self.tokenizer.padding_side == "right") - - encoded_inputs = self.tokenizer( - text=example.question_text if question_first else example.context_text, - text_pair=( - example.context_text if question_first else example.question_text - ), - padding=PaddingStrategy.MAX_LENGTH.value, - truncation="only_second" if question_first else "only_first", - max_length=kwargs["max_seq_len"], - stride=kwargs["doc_stride"], - return_tensors="np", - return_token_type_ids=True, - return_overflowing_tokens=True, - return_offsets_mapping=True, - return_special_tokens_mask=True, - ) - - total_spans = len(encoded_inputs["input_ids"]) - - # p_mask: mask with 1 for token than cannot be in the answer - # We put 0 on the tokens from the context and 1 everywhere else - p_mask = np.asarray( - [ - [ - tok != 1 if question_first else 0 - for tok in encoded_inputs.sequence_ids(span_id) - ] - for span_id in range(total_spans) - ] - ) - - # keep the cls_token unmasked - if self.tokenizer.cls_token_id is not None: - cls_index = np.nonzero( - encoded_inputs["input_ids"] == self.tokenizer.cls_token_id - ) - p_mask[cls_index] = 0 - - features = [] - for span_idx in range(total_spans): - features.append( - SquadFeatures( - input_ids=encoded_inputs["input_ids"][span_idx], - 
attention_mask=encoded_inputs["attention_mask"][span_idx], - token_type_ids=encoded_inputs["token_type_ids"][span_idx], - p_mask=p_mask[span_idx].tolist(), - encoding=encoded_inputs[span_idx], - # the following values are unused for fast tokenizers - cls_index=None, - token_to_orig_map={}, - example_index=0, - unique_id=0, - paragraph_len=0, - token_is_max_context=0, - tokens=[], - start_position=0, - end_position=0, - is_impossible=False, - qas_id=None, - ) - ) - features_list.append(features) - return features_list - - -@dataclass -class TaskInfo: - """ - Information about an NLP task - - :param pipeline_constructor: reference to constructor for the given pipeline task - :param default model name: the transformers canonical name for the default model - :param base_stub: sparsezoo stub path for the base model for this task - :param default_pruned_stub: sparsezoo stub path for the default pruned model - for this task - :param default_quant_stub: sparsezoo stub path for the default quantized model - for this task - """ - - pipeline_constructor: Callable[[Any], Pipeline] - default_model_name: str - base_stub: Optional[str] = None - default_pruned_stub: Optional[str] = None - default_quant_stub: Optional[str] = None - - -# Register all the supported tasks here -SUPPORTED_TASKS = { - "ner": TaskInfo( - pipeline_constructor=TokenClassificationPipeline, - default_model_name="bert-base-uncased", - ), - "question-answering": TaskInfo( - pipeline_constructor=QuestionAnsweringPipeline, - default_model_name="bert-base-uncased", - base_stub=( - "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/base-none" - ), - default_pruned_stub=( - "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/" - "pruned-aggressive_98" - ), - ), - "sentiment-analysis": TaskInfo( - pipeline_constructor=TextClassificationPipeline, - default_model_name="bert-base-uncased", - ), - "text-classification": TaskInfo( - pipeline_constructor=TextClassificationPipeline, - default_model_name="bert-base-uncased", - ), - "token-classification": TaskInfo( - pipeline_constructor=TokenClassificationPipeline, - default_model_name="bert-base-uncased", - ), -} - -DEEPSPARSE_ENGINE = "deepsparse" -ORT_ENGINE = "onnxruntime" - -SUPPORTED_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE] - - -def pipeline( - task: str, - model_name: Optional[str] = None, - model_path: Optional[str] = None, - engine_type: str = DEEPSPARSE_ENGINE, - config: Optional[Union[str, PretrainedConfig]] = None, - tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, - max_length: int = 128, - num_cores: Optional[int] = None, - scheduler: Optional[str] = None, - batch_size: Optional[int] = 1, - **kwargs, -) -> Pipeline: - """ - Utility factory method to build a Pipeline - - :param task: name of the task to define which pipeline to create. Currently - supported task - "question-answering" - :param model_name: canonical name of the hugging face model this model is based on - :param model_path: path to model directory containing `model.onnx`, `config.json`, - and `tokenizer.json` files, ONNX model file, or SparseZoo stub - :param engine_type: inference engine name to use. Supported options are 'deepsparse' - and 'onnxruntime' - :param config: huggingface model config, if none provided, default will be used - which will be from the model name or sparsezoo stub if given for model path - :param tokenizer: huggingface tokenizer, if none provided, default will be used - :param max_length: maximum sequence length of model inputs. 
default is 128 - :param num_cores: number of CPU cores to run engine with. Default is the maximum - available - :param scheduler: The scheduler to use for the engine. Can be None, single or multi. - :param batch_size: The batch_size to use for the pipeline. Defaults to 1 - Note: `question-answering` pipeline only supports a batch_size of 1. - :param kwargs: additional key word arguments for task specific pipeline constructor - :return: Pipeline object for the given taks and model - """ - - # Retrieve the task - if task not in SUPPORTED_TASKS: - raise KeyError( - f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys())}" - ) - if engine_type not in SUPPORTED_ENGINES: - raise ValueError( - f"Unsupported engine {engine_type}, supported engines " - f"are {SUPPORTED_ENGINES}" - ) - if task == "question-answering" and batch_size != 1: - raise ValueError( - f"{task} pipeline only supports batch_size 1. " - f"Supplied batch_size = {batch_size}" - ) - task_info = SUPPORTED_TASKS[task] - - model_path = model_path or _get_default_model_path(task_info) - model_name = model_name or task_info.default_model_name - - onnx_path, config_path, tokenizer_path = get_onnx_path_and_configs(model_path) - - # default the tokenizer and config to file in model directory or given model name - config = config or config_path or model_name - tokenizer = tokenizer or tokenizer_path or model_name - - # create model - model, input_names = _create_model( - onnx_path, - engine_type, - num_cores, - max_length, - scheduler=scheduler, - batch_size=batch_size, - ) - - # Instantiate tokenizer if needed - if isinstance(tokenizer, (str, tuple)): - if isinstance(tokenizer, tuple): - # For tuple we have (tokenizer name, {kwargs}) - tokenizer_kwargs = tokenizer[1] - tokenizer_kwargs["model_max_length"] = max_length - tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1]) - else: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer, model_max_length=max_length - ) - - # Instantiate config if needed - if config is not None and isinstance(config, str): - config = AutoConfig.from_pretrained(config, finetuning_task=task) - - return task_info.pipeline_constructor( - model=model, - tokenizer=tokenizer, - config=config, - engine_type=engine_type, - max_length=max_length, - input_names=input_names, - **kwargs, - ) - - -def _get_default_model_path(task_info: TaskInfo) -> str: - if cpu.cpu_vnni_compatible() and task_info.default_quant_stub: - return task_info.default_quant_stub - return task_info.default_pruned_stub or task_info.base_stub - - -def _create_model( - model_path: str, - engine_type: str, - num_cores: Optional[int], - max_length: int = 128, - scheduler: Optional[str] = None, - batch_size: int = 1, -) -> Tuple[Union[Engine, "onnxruntime.InferenceSession"], List[str]]: - onnx_path, input_names, _ = overwrite_transformer_onnx_model_inputs( - model_path, max_length=max_length - ) - - if engine_type == DEEPSPARSE_ENGINE: - model = compile_model( - onnx_path, - batch_size=batch_size, - num_cores=num_cores, - scheduler=scheduler, - ) - elif engine_type == ORT_ENGINE: - _validate_ort_import() - sess_options = onnxruntime.SessionOptions() - if num_cores is not None: - sess_options.intra_op_num_threads = num_cores - sess_options.log_severity_level = 3 - sess_options.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - ) - - model = onnxruntime.InferenceSession(onnx_path, sess_options=sess_options) - - return model, input_names - - -def _validate_ort_import(): - if 
ort_import_error is not None: - raise ImportError( - "An exception occurred when importing onxxruntime. Please verify that " - "onnxruntime is installed in order to use the onnxruntime inference " - f"engine. \n\nException info: {ort_import_error}" - ) - - -def process_dataset( - pipeline_object: Callable, - data_path: str, - batch_size: int, - task: str, - output_path: str, -) -> None: - """ - :param pipeline_object: An instantiated pipeline Callable object - :param data_path: Path to input file, supports csv, json and text files - :param batch_size: batch_size to use for inference - :param task: The task pipeline is instantiated for - :param output_path: Path to a json file to output inference results to - """ - batch_loader = get_batch_loader( - data_file=data_path, - batch_size=batch_size, - task=task, - ) - # Wraps pipeline object to make numpy types serializable - pipeline_object = fix_numpy_types(pipeline_object) - with open(output_path, "a") as output_file: - for batch in batch_loader: - batch_output = pipeline_object(**batch) - json.dump(batch_output, output_file) - output_file.write("\n") diff --git a/src/deepsparse/transformers/pipelines/__init__.py b/src/deepsparse/transformers/pipelines/__init__.py new file mode 100644 index 0000000000..9986181a2a --- /dev/null +++ b/src/deepsparse/transformers/pipelines/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .pipeline import * +from .question_answering import * +from .text_classification import * +from .token_classification import * diff --git a/src/deepsparse/transformers/pipelines/pipeline.py b/src/deepsparse/transformers/pipelines/pipeline.py new file mode 100644 index 0000000000..2fdcd27236 --- /dev/null +++ b/src/deepsparse/transformers/pipelines/pipeline.py @@ -0,0 +1,219 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Base Pipeline class for transformers inference pipeline +""" + + +import warnings +from typing import Any, List, Mapping, Optional + +import numpy +from transformers.models.auto import AutoConfig, AutoTokenizer + +from deepsparse import Pipeline +from deepsparse.transformers.helpers import ( + get_onnx_path_and_configs, + overwrite_transformer_onnx_model_inputs, +) + + +__all__ = [ + "TransformersPipeline", + "pipeline", +] + + +class TransformersPipeline(Pipeline): + """ + Base deepsparse.Pipeline class for transformers model loading. This class handles + the parsing of deepsparse-transformers files and model inputs, supporting loading + from sparsezoo, a directory containing a model.onnx, tokenizer, and model config, + or just an ONNX file with the ability to load a tokenizer and model config from + a default huggingface-transformers model. + + Note, when implementing child tasks in deepsparse.transformers.pipelines, + in addition to registering task names with Pipeline.register, task names should + be added to the supported nlp tasks in deepsparse.tasks so they can be properly + imported at runtime. + + :param model_path: sparsezoo stub to a transformers model, an ONNX file, or + (preferred) a directory containing a model.onnx, tokenizer config, and model + config. If no tokenizer and/or model config(s) are found, then they will be + loaded from huggingface transformers using the `default_model_name` key + :param engine_type: inference engine to use. Currently supported values include + 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param sequence_length: static sequence length to use for inference + :param default_model_name: huggingface transformers model name to use to + load a tokenizer and model config when none are provided in the `model_path`. + Default is 'bert-base-uncased' + """ + + def __init__( + self, + *, + sequence_length: int = 128, + default_model_name: str = "bert-base-uncased", + **kwargs, + ): + + self._sequence_length = sequence_length + self._default_model_name = default_model_name + + self.config = None + self.tokenizer = None + self.onnx_input_names = None + + self._temp_model_directory = None + + super().__init__(**kwargs) + + @property + def sequence_length(self) -> int: + """ + :return: static sequence length to use for inference + """ + return self._sequence_length + + @property + def default_model_name(self) -> str: + """ + :return: huggingface transformers model name to use to + load a tokenizer and model config when none are provided in the + `model_path` + """ + return self._default_model_name + + def setup_onnx_file_path(self) -> str: + """ + Parses ONNX, tokenizer, and config file paths from the given `model_path`. + Supports sparsezoo stubs. 
If a tokenizer and/or config file are not found, + they will be defaulted to the default_model_name in the transformers repo + + :return: file path to the processed ONNX file for the engine to compile + """ + onnx_path, config_path, tokenizer_path = get_onnx_path_and_configs( + self.model_path + ) + + # default config + tokenizer if necessary + config_path = config_path or self.default_model_name + tokenizer_path = tokenizer_path or self.default_model_name + + self.config = AutoConfig.from_pretrained( + config_path, finetuning_task=self.task if hasattr(self, "task") else None + ) + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, model_max_length=self.sequence_length + ) + + # overwrite onnx graph to given required input shape + ( + onnx_path, + self.onnx_input_names, + self._temp_model_directory, + ) = overwrite_transformer_onnx_model_inputs( + onnx_path, max_length=self.sequence_length + ) + + return onnx_path + + def tokens_to_engine_input( + self, tokens: Mapping[Any, numpy.ndarray] + ) -> List[numpy.ndarray]: + """ + :param tokens: outputs of the pipeline tokenizer + :return: list of numpy arrays in expected order for model input + """ + if not all(name in tokens for name in self.onnx_input_names): + raise ValueError( + f"pipeline expected arrays with names {self.onnx_input_names}, " + f"received inputs: {list(tokens.keys())}" + ) + + return [tokens[name] for name in self.onnx_input_names] + + +def pipeline( + task: str, + model_name: Optional[str] = None, + model_path: Optional[str] = None, + engine_type: str = "deepsparse", + config: Optional[str] = None, + tokenizer: Optional[str] = None, + max_length: int = 128, + num_cores: Optional[int] = None, + scheduler: Optional[str] = None, + batch_size: Optional[int] = 1, + **kwargs, +) -> Pipeline: + """ + [DEPRECATED] - deepsparse.transformers.pipeline is deprecated to craete DeepSparse + pipelines for tranformers tasks use deepsparse.Pipeline.create(task, ...) + + Utility factory method to build a Pipeline + + :param task: name of the task to define which pipeline to create. Currently + supported task - "question-answering" + :param model_name: canonical name of the hugging face model this model is based on + :param model_path: path to model directory containing `model.onnx`, `config.json`, + and `tokenizer.json` files, ONNX model file, or SparseZoo stub + :param engine_type: inference engine name to use. Options are 'deepsparse' + and 'onnxruntime'. Default is 'deepsparse' + :param config: huggingface model config, if none provided, default will be used + which will be from the model name or sparsezoo stub if given for model path + :param tokenizer: huggingface tokenizer, if none provided, default will be used + :param max_length: maximum sequence length of model inputs. default is 128 + :param num_cores: number of CPU cores to run engine with. Default is the maximum + available + :param scheduler: The scheduler to use for the engine. Can be None, single or multi + :param batch_size: The batch_size to use for the pipeline. Defaults to 1 + Note: `question-answering` pipeline only supports a batch_size of 1. 
+ :param kwargs: additional key word arguments for task specific pipeline constructor + :return: Pipeline object for the given taks and model + """ + warnings.warn( + "[DEPRECATED] - deepsparse.transformers.pipeline is deprecated to craete " + "DeepSparse pipelines for tranformers tasks use deepsparse.Pipeline.create()" + ) + + if config is not None or tokenizer is not None: + raise ValueError( + "Directly passing in a config or tokenizer to DeepSparse transformers " + "pipelines is no longer supported. config and tokenizer objects should " + "be specified by including config.json and tokenizer.json files in the " + "model directory respectively" + ) + + return Pipeline.create( + task=task, + model_path=model_path, + engine_type=engine_type, + batch_size=batch_size, + num_cores=num_cores, + scheduler=scheduler, + sequence_length=max_length, + default_model_name=model_name, + **kwargs, + ) diff --git a/src/deepsparse/transformers/pipelines/question_answering.py b/src/deepsparse/transformers/pipelines/question_answering.py new file mode 100644 index 0000000000..f15f3ba45d --- /dev/null +++ b/src/deepsparse/transformers/pipelines/question_answering.py @@ -0,0 +1,405 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# postprocessing adapted from huggingface/transformers + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Pipeline implementation and pydantic models for question answering transformers +tasks +""" + + +from typing import Any, Dict, List, Tuple, Type + +import numpy +from pydantic import BaseModel, Field +from transformers.data import ( + SquadExample, + SquadFeatures, + squad_convert_examples_to_features, +) +from transformers.tokenization_utils_base import PaddingStrategy + +from deepsparse import Pipeline +from deepsparse.transformers.pipelines import TransformersPipeline + + +__all__ = [ + "QuestionAnsweringInput", + "QuestionAnsweringOutput", + "QuestionAnsweringPipeline", +] + + +class QuestionAnsweringInput(BaseModel): + """ + Schema for inputs to question_answering pipelines + """ + + question: str = Field(description="String question to be answered") + context: str = Field(description="String representing context for answer") + + +class QuestionAnsweringOutput(BaseModel): + """ + Schema for question_answering pipeline output. 
Values are in batch order + """ + + score: float = Field(description="confidence score for prediction") + answer: str = Field(description="predicted answer") + start: int = Field(description="start index of the answer") + end: int = Field(description="end index of the answer") + + +@Pipeline.register( + task="question_answering", + task_aliases=["qa"], +) +class QuestionAnsweringPipeline(TransformersPipeline): + """ + transformers question_answering pipeline + + example instantiation: + ```python + question_answering = Pipeline.create( + task="question_answering", + model_path="question_answering_model_dir/", + ) + ``` + + :param model_path: sparsezoo stub to a transformers model, an ONNX file, or + (preferred) a directory containing a model.onnx, tokenizer config, and model + config. If no tokenizer and/or model config(s) are found, then they will be + loaded from huggingface transformers using the `default_model_name` key + :param engine_type: inference engine to use. Currently supported values include + 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param sequence_length: sequence length to compile model and tokenizer for. + Default is 128 + :param default_model_name: huggingface transformers model name to use to + load a tokenizer and model config when none are provided in the `model_path`. + Default is 'bert-base-uncased' + :param doc_stride: if the context is too long to fit with the question for the + model, it will be split in several chunks with some overlap. This argument + controls the size of that overlap. Currently, only reading the first span + is supported (everything after doc_stride will be truncated). Default + is 128 + :param max_question_len: maximum length of the question after tokenization. + It will be truncated if needed. Default is 64 + :param max_answer_len: maximum length of answer after decoding. Default is 15 + """ + + def __init__( + self, + *, + doc_stride: int = 128, + max_question_length: int = 64, + max_answer_length: int = 15, + **kwargs, + ): + + if kwargs.get("batch_size") and kwargs["batch_size"] > 1: + raise ValueError( + f"{self.__class__.__name__} currently only supports batch size 1, " + f"batch size set to {kwargs['batch_size']}" + ) + + self._doc_stride = doc_stride + self._max_question_length = max_question_length + self._max_answer_length = max_answer_length + + super().__init__(**kwargs) + + @property + def doc_stride(self) -> int: + """ + :return: if the context is too long to fit with the question for the + model, it will be split in several chunks with some overlap. This argument + controls the size of that overlap. 
Currently, only reading the first span + is supported (everything after doc_stride will be truncated) + """ + return self._doc_stride + + @property + def max_answer_length(self) -> int: + """ + :return: maximum length of answer after decoding + """ + return self._max_answer_length + + @property + def max_question_length(self) -> int: + """ + :return: maximum length of the question after tokenization. + It will be truncated if needed + """ + return self._max_question_length + + @property + def input_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + return QuestionAnsweringInput + + @property + def output_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + return QuestionAnsweringOutput + + def process_inputs( + self, + inputs: QuestionAnsweringInput, + ) -> Tuple[List[numpy.ndarray], Dict[str, Any]]: + """ + :param inputs: inputs to the pipeline. Must be the type of the + QuestionAnsweringInput + :return: inputs of this model processed into a list of numpy arrays that + can be directly passed into the forward pass of the pipeline engine and + dictionary of parsed features and original extracted example + """ + squad_example = SquadExample( + None, inputs.question, inputs.context, None, None, None + ) + features = self._tokenize(squad_example) + tokens = features.__dict__ + + engine_inputs = self.tokens_to_engine_input(tokens) + # add batch dimension, assuming batch size 1 + engine_inputs = [numpy.expand_dims(inp, axis=0) for inp in engine_inputs] + + return engine_inputs, dict( + features=features, + example=squad_example, + ) + + def process_engine_outputs( + self, engine_outputs: List[numpy.ndarray], **kwargs + ) -> BaseModel: + """ + :param engine_outputs: list of numpy arrays that are the output of the engine + forward pass + :return: outputs of engine post-processed into an object in the `output_model` + format of this pipeline + """ + features = kwargs["features"] + example = kwargs["example"] + start_vals, end_vals = engine_outputs[:2] + + # assuming batch size 0 + start = start_vals[0] + end = end_vals[0] + + # Ensure padded tokens & question tokens cannot belong + undesired_tokens = ( + numpy.abs(numpy.array(features.p_mask) - 1) & features.attention_mask + ) + + # Generate mask + undesired_tokens_mask = undesired_tokens == 0.0 + + # Make sure non-context indexes cannot contribute to the softmax + start = numpy.where(undesired_tokens_mask, -10000.0, start) + end = numpy.where(undesired_tokens_mask, -10000.0, end) + + # Normalize logits and spans to retrieve the answer + start = numpy.exp( + start - numpy.log(numpy.sum(numpy.exp(start), axis=-1, keepdims=True)) + ) + end = numpy.exp( + end - numpy.log(numpy.sum(numpy.exp(end), axis=-1, keepdims=True)) + ) + + # Mask CLS + start[0] = 0.0 + end[0] = 0.0 + + ans_start, ans_end, scores = self._decode(start, end) + # assuming one stride, so grab first idx + ans_start = ans_start[0] + ans_end = ans_end[0] + score = scores[0] + + # decode start, end idx into text + if not self.tokenizer.is_fast: + char_to_word = numpy.array(example.char_to_word_offset) + return self.output_model( + score=score.item(), + start=numpy.where( + char_to_word == features.token_to_orig_map[ans_start] + )[0][0].item(), + end=numpy.where(char_to_word == features.token_to_orig_map[ans_end])[0][ + -1 + ].item(), + answer=" ".join( + example.doc_tokens[ + features.token_to_orig_map[ + ans_start + ] : 
features.token_to_orig_map[ans_end] + + 1 + ] + ), + ) + else: + question_first = bool(self.tokenizer.padding_side == "right") + + # Sometimes the max probability token is in the middle of a word so: + # we start by finding the right word containing the token with + # `token_to_word` then we convert this word in a character span + return self.output_model( + score=score.item(), + start=features.encoding.word_to_chars( + features.encoding.token_to_word(ans_start), + sequence_index=1 if question_first else 0, + )[0], + end=features.encoding.word_to_chars( + features.encoding.token_to_word(ans_end), + sequence_index=1 if question_first else 0, + )[1], + answer=example.context_text[ + features.encoding.word_to_chars( + features.encoding.token_to_word(ans_start), + sequence_index=1 if question_first else 0, + )[0] : features.encoding.word_to_chars( + features.encoding.token_to_word(ans_end), + sequence_index=1 if question_first else 0, + )[ + 1 + ] + ], + ) + + def _tokenize(self, example: SquadExample): + if not self.tokenizer.is_fast: + features = squad_convert_examples_to_features( + examples=[example], + tokenizer=self.tokenizer, + max_set_length=self.sequence_length, + doc_stride=self.doc_stride, + max_query_length=self.max_question_length, + padding_strategy=PaddingStrategy.MAX_LENGTH.value, + is_training=False, + tqdm_enabled=False, + ) + # only 1 span supported so taking only the first element of features + # to add support for num_spans switch to features = features[:num_spans] + # not included for now due to static batch requirements in production + features = features[0] + else: + question_first = bool(self.tokenizer.padding_side == "right") + encoded_inputs = self.tokenizer( + text=example.question_text if question_first else example.context_text, + text_pair=( + example.context_text if question_first else example.question_text + ), + padding=PaddingStrategy.MAX_LENGTH.value, + truncation="only_second" if question_first else "only_first", + max_length=self.sequence_length, + stride=self.doc_stride, + return_tensors="np", + return_token_type_ids=True, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_special_tokens_mask=True, + ) + + # only 1 span supported so taking only the first element of features + # to add support for num_spans switch hardcoded 0 idx lookups to loop + # over values in num_spans + + # p_mask: mask with 1 for token than cannot be in the answer + # We put 0 on the tokens from the context and 1 everywhere else + p_mask = numpy.asarray( + [ + [ + tok != 1 if question_first else 0 + for tok in encoded_inputs.sequence_ids(0) + ] + ] + ) + + # keep the cls_token unmasked + if self.tokenizer.cls_token_id is not None: + cls_index = numpy.nonzero( + encoded_inputs["input_ids"][0] == self.tokenizer.cls_token_id + ) + p_mask[cls_index] = 0 + + features = SquadFeatures( + input_ids=encoded_inputs["input_ids"][0], + attention_mask=encoded_inputs["attention_mask"][0], + token_type_ids=encoded_inputs["token_type_ids"][0], + p_mask=p_mask[0].tolist(), + encoding=encoded_inputs[0], + # the following values are unused for fast tokenizers + cls_index=None, + token_to_orig_map={}, + example_index=0, + unique_id=0, + paragraph_len=0, + token_is_max_context=0, + tokens=[], + start_position=0, + end_position=0, + is_impossible=False, + qas_id=None, + ) + + return features + + def _decode(self, start: numpy.ndarray, end: numpy.ndarray) -> Tuple: + # Ensure we have batch axis + if start.ndim == 1: + start = start[None] + + if end.ndim == 1: + end = end[None] + + # 
Compute the score of each tuple(start, end) to be the real answer + outer = numpy.matmul(numpy.expand_dims(start, -1), numpy.expand_dims(end, 1)) + + # Remove candidate with end < start and end - start > max_answer_len + candidates = numpy.tril(numpy.triu(outer), self.max_answer_length - 1) + + # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA) + scores_flat = candidates.flatten() + # only returning best result, use argsort for topk support + idx_sort = [numpy.argmax(scores_flat)] + + start, end = numpy.unravel_index(idx_sort, candidates.shape)[1:] + return start, end, candidates[0, start, end] diff --git a/src/deepsparse/transformers/pipelines/text_classification.py b/src/deepsparse/transformers/pipelines/text_classification.py new file mode 100644 index 0000000000..44449b5c46 --- /dev/null +++ b/src/deepsparse/transformers/pipelines/text_classification.py @@ -0,0 +1,217 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# postprocessing adapted from huggingface/transformers + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Pipeline implementation and pydantic models for text classification transformers +tasks +""" + + +from typing import List, Type, Union + +import numpy +from pydantic import BaseModel, Field +from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy + +from deepsparse import Pipeline +from deepsparse.transformers.pipelines import TransformersPipeline + + +__all__ = [ + "TextClassificationInput", + "TextClassificationOutput", + "TextClassificationPipeline", +] + + +class TextClassificationInput(BaseModel): + """ + Schema for inputs to text_classification pipelines + """ + + sequences: Union[List[List[str]], List[str], str] = Field( + description="A string or List of strings representing input to" + "text_classification task" + ) + + +class TextClassificationOutput(BaseModel): + """ + Schema for text_classification pipeline output. 
Values are in batch order + """ + + labels: List[str] = Field(description="The predicted labels in batch order") + scores: List[float] = Field( + description="The corresponding probability for each label in the batch" + ) + + +@Pipeline.register( + task="text_classification", + task_aliases=["glue", "sentiment_analysis"], +) +class TextClassificationPipeline(TransformersPipeline): + """ + transformers text classification pipeline + + example instantiation: + ```python + text_classifier = Pipeline.create( + task="text_classification", + model_path="text_classification_model_dir/", + batch_size=BATCH_SIZE, + ) + ``` + + example batch size 1, single text inputs (ie sentiment analysis): + ```python + sentiment = text_classifier("the food tastes great") + sentiment = text_classifier(["the food tastes great"]) + sentiment = text_classifier([["the food tastes great"]]) + ``` + + example batch size 1, multi text input (ie QQP like tasks): + ```python + prediction = text_classifier([["how is the food?", "what is the food?"]]) + ``` + + example batch size n, single text inputs: + ```python + sentiments = text_classifier(["the food tastes great", "the food tastes bad"]) + sentiments = text_classifier([["the food tastes great"], ["the food tastes bad"]]) + ``` + + :param model_path: sparsezoo stub to a transformers model, an ONNX file, or + (preferred) a directory containing a model.onnx, tokenizer config, and model + config. If no tokenizer and/or model config(s) are found, then they will be + loaded from huggingface transformers using the `default_model_name` key + :param engine_type: inference engine to use. Currently supported values include + 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param sequence_length: sequence length to compile model and tokenizer for. + Default is 128 + :param default_model_name: huggingface transformers model name to use to + load a tokenizer and model config when none are provided in the `model_path`. + Default is 'bert-base-uncased' + """ + + @property + def input_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + return TextClassificationInput + + @property + def output_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + return TextClassificationOutput + + def parse_inputs(self, *args, **kwargs) -> BaseModel: + """ + :param args: ordered arguments to pipeline, only an input_model object + is supported as an arg for this function + :param kwargs: keyword arguments to pipeline + :return: pipeline arguments parsed into the given `input_model` + schema if necessary. If an instance of the `input_model` is provided + it will be returned + """ + if args and kwargs: + raise ValueError( + f"{self.__class__} only support args OR kwargs. 
Found " + f" {len(args)} args and {len(kwargs)} kwargs" + ) + + if args: + if len(args) == 1: + # passed input_model schema directly + if isinstance(args[0], self.input_model): + return args[0] + return self.input_model(sequences=args[0]) + else: + return self.input_model(sequences=args) + + return self.input_model(**kwargs) + + def process_inputs(self, inputs: TextClassificationInput) -> List[numpy.ndarray]: + """ + :param inputs: inputs to the pipeline. Must be the type of the + TextClassificationInput + :return: inputs of this model processed into a list of numpy arrays that + can be directly passed into the forward pass of the pipeline engine + """ + tokens = self.tokenizer( + inputs.sequences, + add_special_tokens=True, + return_tensors="np", + padding=PaddingStrategy.MAX_LENGTH.value, + truncation=TruncationStrategy.LONGEST_FIRST.value, + ) + return self.tokens_to_engine_input(tokens) + + def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseModel: + """ + :param engine_outputs: list of numpy arrays that are the output of the engine + forward pass + :return: outputs of engine post-processed into an object in the `output_model` + format of this pipeline + """ + outputs = engine_outputs + if isinstance(outputs, list): + outputs = outputs[0] + + scores = ( + 1.0 / (1.0 + numpy.exp(-outputs)) + if self.config.num_labels == 1 + else numpy.exp(outputs) / numpy.exp(outputs).sum(-1, keepdims=True) + ) + + labels = [] + label_scores = [] + + for score in scores: + labels.append(self.config.id2label[score.argmax()]) + label_scores.append(score.max().item()) + + return self.output_model( + labels=labels, + scores=label_scores, + ) diff --git a/src/deepsparse/transformers/pipelines/token_classification.py b/src/deepsparse/transformers/pipelines/token_classification.py new file mode 100644 index 0000000000..6150085626 --- /dev/null +++ b/src/deepsparse/transformers/pipelines/token_classification.py @@ -0,0 +1,495 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# postprocessing adapted from huggingface/transformers + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +Pipeline implementation and pydantic models for token classification transformers +tasks +""" +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy +from pydantic import BaseModel, Field +from transformers.file_utils import ExplicitEnum +from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy + +from deepsparse import Pipeline +from deepsparse.transformers.pipelines import TransformersPipeline + + +__all__ = [ + "AggregationStrategy", + "TokenClassificationInput", + "TokenClassificationResult", + "TokenClassificationOutput", + "TokenClassificationPipeline", +] + + +class AggregationStrategy(ExplicitEnum): + """ + Valid aggregation strategies for postprocessing in the TokenClassificationPipeline + """ + + NONE = "none" + SIMPLE = "simple" + FIRST = "first" + AVERAGE = "average" + MAX = "max" + + +class TokenClassificationInput(BaseModel): + """ + Schema for inputs to token_classification pipelines + """ + + inputs: Union[List[str], str] = Field( + description=( + "A string or List of batch of strings representing input(s) to" + "a token_classification task" + ) + ) + + +class TokenClassificationResult(BaseModel): + """ + Schema for a classification of a single token + """ + + entity: str = Field(description="entity predicted for that token/word") + score: float = Field(description="The corresponding probability for `entity`") + index: int = Field(description="index of the corresponding token in the sentence") + word: str = Field(description="token/word classified") + start: Optional[int] = Field( + description=( + "index of the start of the corresponding entity in the sentence. " + "Only exists if the offsets are available within the tokenizer" + ) + ) + end: Optional[int] = Field( + description=( + "index of the end of the corresponding entity in the sentence. " + "Only exists if the offsets are available within the tokenizer" + ) + ) + is_grouped: bool = Field( + default=False, + description="True if this result is part of an entity group", + ) + + +class TokenClassificationOutput(BaseModel): + """ + Schema for results of TokenClassificationPipeline inference. Classifications of each + token stored in a list of lists of batch[sentence[token]] + """ + + predictions: List[List[TokenClassificationResult]] = Field( + description=( + "list of list of results of token classification pipeline. Outer list " + "has one item for each sequence in the batch. Inner list has one " + "TokenClassificationResult item per token in the given sequence" + ) + ) + + +@Pipeline.register( + task="token_classification", + task_aliases=["ner"], +) +class TokenClassificationPipeline(TransformersPipeline): + """ + transformers token classification pipeline + + example instantiation: + ```python + token_classifier = Pipeline.create( + task="token_classification", + model_path="token_classification_model_dir/", + batch_size=BATCH_SIZE, + ) + ``` + + :param model_path: sparsezoo stub to a transformers model, an ONNX file, or + (preferred) a directory containing a model.onnx, tokenizer config, and model + config. If no tokenizer and/or model config(s) are found, then they will be + loaded from huggingface transformers using the `default_model_name` key + :param engine_type: inference engine to use. Currently supported values include + 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. 
None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param sequence_length: sequence length to compile model and tokenizer for. + Default is 128 + :param default_model_name: huggingface transformers model name to use to + load a tokenizer and model config when none are provided in the `model_path`. + Default is 'bert-base-uncased' + :param aggregation_strategy: how to aggregate tokens in postprocessing. Options + include 'none', 'simple', 'first', 'average', and 'max'. Default is None + :param ignore_labels: list of label names to ignore in output. Default is + ['0'] which ignores the default known class label + """ + + def __init__( + self, + *, + aggregation_strategy: AggregationStrategy = AggregationStrategy.NONE, + ignore_labels: List[str] = None, + **kwargs, + ): + + if isinstance(aggregation_strategy, str): + aggregation_strategy = aggregation_strategy.strip().lower() + self._aggregation_strategy = AggregationStrategy(aggregation_strategy) + self._ignore_labels = ["0"] if ignore_labels is None else ignore_labels + + super().__init__(**kwargs) + + @property + def aggregation_strategy(self) -> str: + """ + :return: how to aggregate tokens in postprocessing. Options + include 'none', 'simple', 'first', 'average', and 'max' + """ + return self._aggregation_strategy.value + + @property + def ignore_labels(self) -> List[str]: + """ + :return: list of label names to ignore in output. Default is + ['0'] which ignores the default known class label + """ + return self._ignore_labels + + @property + def input_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + return TokenClassificationInput + + @property + def output_model(self) -> Type[BaseModel]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + return TokenClassificationOutput + + def parse_inputs(self, *args, **kwargs) -> BaseModel: + """ + :param args: ordered arguments to pipeline, only an input_model object + is supported as an arg for this function + :param kwargs: keyword arguments to pipeline + :return: pipeline arguments parsed into the given `input_model` + schema if necessary. If an instance of the `input_model` is provided + it will be returned + """ + if args and kwargs: + raise ValueError( + f"{self.__class__} only support args OR kwargs. Found " + f" {len(args)} args and {len(kwargs)} kwargs" + ) + + if args: + if len(args) == 1: + # passed input_model schema directly + if isinstance(args[0], self.input_model): + return args[0] + return self.input_model(inputs=args[0]) + else: + return self.input_model(inputs=args) + + return self.input_model(**kwargs) + + def process_inputs( + self, + inputs: TokenClassificationInput, + ) -> Tuple[List[numpy.ndarray], Dict[str, Any]]: + """ + :param inputs: inputs to the pipeline. 
Must be the type of the + TokenClassificationInput + :return: inputs of this model processed into a list of numpy arrays that + can be directly passed into the forward pass of the pipeline engine + and dictionary containing offset mappings and special tokens mask to + be used during postprocessing + """ + tokens = self.tokenizer( + inputs.inputs, + return_tensors="np", + truncation=TruncationStrategy.LONGEST_FIRST.value, + padding=PaddingStrategy.MAX_LENGTH.value, + return_special_tokens_mask=True, + return_offsets_mapping=self.tokenizer.is_fast, + ) + + offset_mapping = ( + tokens.pop("offset_mapping") + if self.tokenizer.is_fast + else [None] * len(inputs.inputs) + ) + special_tokens_mask = tokens.pop("special_tokens_mask") + postprocessing_kwargs = dict( + inputs=inputs, + tokens=tokens, + offset_mapping=offset_mapping, + special_tokens_mask=special_tokens_mask, + ) + + return self.tokens_to_engine_input(tokens), postprocessing_kwargs + + def process_engine_outputs( + self, + engine_outputs: List[numpy.ndarray], + **kwargs, + ) -> BaseModel: + """ + :param engine_outputs: list of numpy arrays that are the output of the engine + forward pass + :return: outputs of engine post-processed into an object in the `output_model` + format of this pipeline + """ + inputs = kwargs["inputs"] + tokens = kwargs["tokens"] + offset_mapping = kwargs["offset_mapping"] + special_tokens_mask = kwargs["special_tokens_mask"] + + predictions = [] # type: List[List[TokenClassificationResult]] + + for entities_index, current_entities in enumerate(engine_outputs[0]): + input_ids = tokens["input_ids"][entities_index] + + scores = numpy.exp(current_entities) / numpy.exp(current_entities).sum( + -1, keepdims=True + ) + pre_entities = self._gather_pre_entities( + inputs.inputs[entities_index], + input_ids, + scores, + offset_mapping[entities_index], + special_tokens_mask[entities_index], + ) + grouped_entities = self._aggregate(pre_entities) + # Filter anything that is in self.ignore_labels + current_results = [] # type: List[TokenClassificationResult] + for entity in grouped_entities: + if entity.get("entity") in self.ignore_labels or ( + entity.get("entity_group") in self.ignore_labels + ): + continue + if entity.get("entity_group"): + entity["entity"] = entity["entity_group"] + entity["is_grouped"] = True + del entity["entity_group"] + current_results.append(TokenClassificationResult(**entity)) + predictions.append(current_results) + + return self.output_model(predictions=predictions) + + # utilities below adapted from transformers + + def _gather_pre_entities( + self, + sentence: str, + input_ids: numpy.ndarray, + scores: numpy.ndarray, + offset_mapping: Optional[List[Tuple[int, int]]], + special_tokens_mask: numpy.ndarray, + ) -> List[dict]: + pre_entities = [] + for idx, token_scores in enumerate(scores): + # Filter special_tokens, they should only occur + # at the sentence boundaries since we're not encoding pairs of + # sentences so we don't have to keep track of those. 
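+            # special_tokens_mask comes from the tokenizer call in process_inputs
+            # (return_special_tokens_mask=True); entries are nonzero for special
+            # positions such as [CLS], [SEP], and padding, so those tokens are skipped here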
+ if special_tokens_mask[idx]: + continue + + word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])) + if offset_mapping is not None: + start_ind, end_ind = offset_mapping[idx] + word_ref = sentence[start_ind:end_ind] + is_subword = len(word_ref) != len(word) + + if int(input_ids[idx]) == self.tokenizer.unk_token_id: + word = word_ref + is_subword = False + else: + start_ind = None + end_ind = None + is_subword = False + + pre_entity = { + "word": word, + "scores": token_scores, + "start": start_ind, + "end": end_ind, + "index": idx, + "is_subword": is_subword, + } + pre_entities.append(pre_entity) + return pre_entities + + def _aggregate(self, pre_entities: List[dict]) -> List[dict]: + if self._aggregation_strategy in { + AggregationStrategy.NONE, + AggregationStrategy.SIMPLE, + }: + entities = [] + for pre_entity in pre_entities: + entity_idx = pre_entity["scores"].argmax() + score = pre_entity["scores"][entity_idx] + entity = { + "entity": self.config.id2label[entity_idx], + "score": score, + "index": pre_entity["index"], + "word": pre_entity["word"], + "start": pre_entity["start"], + "end": pre_entity["end"], + } + entities.append(entity) + else: + entities = self._aggregate_words(pre_entities) + + if self._aggregation_strategy == AggregationStrategy.NONE: + return entities + return self._group_entities(entities) + + def _aggregate_word(self, entities: List[dict]) -> dict: + word = self.tokenizer.convert_tokens_to_string( + [entity["word"] for entity in entities] + ) + if self._aggregation_strategy == AggregationStrategy.FIRST: + scores = entities[0]["scores"] + idx = scores.argmax() + score = scores[idx] + entity = self.config.id2label[idx] + elif self._aggregation_strategy == AggregationStrategy.MAX: + max_entity = max(entities, key=lambda entity: entity["scores"].max()) + scores = max_entity["scores"] + idx = scores.argmax() + score = scores[idx] + entity = self.config.id2label[idx] + elif self._aggregation_strategy == AggregationStrategy.AVERAGE: + scores = numpy.stack([entity["scores"] for entity in entities]) + average_scores = numpy.nanmean(scores, axis=0) + entity_idx = average_scores.argmax() + entity = self.config.id2label[entity_idx] + score = average_scores[entity_idx] + else: + raise ValueError( + f"Invalid aggregation_strategy: {self._aggregation_strategy}" + ) + new_entity = { + "entity": entity, + "score": score, + "word": word, + "start": entities[0]["start"], + "end": entities[-1]["end"], + } + return new_entity + + def _aggregate_words(self, entities: List[dict]) -> List[dict]: + word_entities = [] + word_group = None + for entity in entities: + if word_group is None: + word_group = [entity] + elif entity["is_subword"]: + word_group.append(entity) + else: + word_entities.append(self._aggregate_word(word_group)) + word_group = [entity] + # Last item + word_entities.append(self._aggregate_word(word_group)) + return word_entities + + def _group_sub_entities(self, entities: List[dict]) -> dict: + # Get the first entity in the entity group + entity = entities[0]["entity"].split("-")[-1] + scores = numpy.nanmean([entity["score"] for entity in entities]) + tokens = [entity["word"] for entity in entities] + + entity_group = { + "entity_group": entity, + "score": numpy.mean(scores), + "word": self.tokenizer.convert_tokens_to_string(tokens), + "start": entities[0]["start"], + "end": entities[-1]["end"], + } + return entity_group + + def _get_tag(self, entity_name: str) -> Tuple[str, str]: + if entity_name.startswith("B-"): + bi = "B" + tag = entity_name[2:] + elif 
entity_name.startswith("I-"): + bi = "I" + tag = entity_name[2:] + else: + # It's not in B-, I- format + bi = "B" + tag = entity_name + return bi, tag + + def _group_entities(self, entities: List[dict]) -> List[dict]: + + entity_groups = [] + entity_group_disagg = [] + + for entity in entities: + if not entity_group_disagg: + entity_group_disagg.append(entity) + continue + + # If the current entity is similar and adjacent to the previous entity, + # append it to the disaggregated entity group + # The split is meant to account for the "B" and "I" prefixes + # Shouldn't merge if both entities are B-type + bi, tag = self._get_tag(entity["entity"]) + last_bi, last_tag = self._get_tag(entity_group_disagg[-1]["entity"]) + + if tag == last_tag and bi != "B": + # Modify subword type to be previous_type + entity_group_disagg.append(entity) + else: + # If the current entity is different from the previous entity + # aggregate the disaggregated entity group + entity_groups.append(self._group_sub_entities(entity_group_disagg)) + entity_group_disagg = [entity] + if entity_group_disagg: + # it's the last entity, add it to the entity groups + entity_groups.append(self._group_sub_entities(entity_group_disagg)) + + return entity_groups diff --git a/src/deepsparse/transformers/server.py b/src/deepsparse/transformers/server.py deleted file mode 100644 index 59035dba80..0000000000 --- a/src/deepsparse/transformers/server.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Specs, schemas, and pipelines for use when serving transformers models -""" - -from typing import Any, Dict, List, Optional, Tuple, Union - -from deepsparse.tasks import SupportedTasks -from deepsparse.transformers.pipelines import Pipeline, pipeline - - -try: - from deepsparse.server.config import ServeModelConfig - - deepsparse_server_err = None -except Exception as _err: - deepsparse_server_err = _err - ServeModelConfig = object - -try: - from pydantic import BaseModel, Field - - pydantic_import_err = None -except Exception as _err: - pydantic_import_err = _err - BaseModel = object - Field = dict - - -__all__ = [ - "create_pipeline_definitions", - "QuestionAnsweringRequest", - "QuestionAnsweringResponse", - "TextClassificationRequest", - "TextClassificationResponse", - "TokenClassificationRequest", - "TokenClassificationResponse", -] - - -def create_pipeline_definitions( - model_config: ServeModelConfig, -) -> Tuple[Pipeline, Any, Any, Dict]: - """ - Create a pipeline definition and the supporting files for a given model config - to use for serving in the DeepSparse inference server - - :param model_config: the server model config describing the model and params - :return: a tuple containing (the pipeline to use for inference, - the expected request body, the expected response body, - any additional keyword args for use with the server) - """ - if deepsparse_server_err: - raise deepsparse_server_err - - if pydantic_import_err: - raise pydantic_import_err - - if SupportedTasks.nlp.question_answering.matches(model_config.task): - request_model = QuestionAnsweringRequest - response_model = Union[ - List[QuestionAnsweringResponse], - QuestionAnsweringResponse, - ] - kwargs = {} - elif SupportedTasks.nlp.text_classification.matches(model_config.task): - request_model = TextClassificationRequest - response_model = Union[ - List[TextClassificationResponse], List[List[TextClassificationResponse]] - ] - kwargs = {} - elif SupportedTasks.nlp.token_classification.matches(model_config.task): - request_model = TokenClassificationRequest - response_model = Union[ - List[TokenClassificationResponse], List[List[TokenClassificationResponse]] - ] - kwargs = {} - else: - raise ValueError( - f"unrecognized task given of {model_config.task} for config {model_config}" - ) - - pipeline_instance: Pipeline = pipeline( - task=model_config.task.lower().replace("_", "-"), - model_path=model_config.model_path, - engine_type=model_config.engine, - num_cores=model_config.num_cores, - scheduler=model_config.scheduler, - batch_size=model_config.batch_size, - **model_config.kwargs, - ) - - return pipeline_instance, request_model, response_model, kwargs - - -class QuestionAnsweringRequest(BaseModel): - """ - The request model for Question Answering Task - """ - - question: Union[List[str], str] = Field( - description="Either a string or a List of string questions to answer" - ) - context: Union[List[str], str] = Field( - description="Either a string or List of strings representing the context " - "for each question" - ) - - -class TokenClassificationRequest(BaseModel): - """ - Schema for TokenClassificationPipeline Request - """ - - inputs: Union[List[str], str] = Field( - description="A string or List of strings representing input to" - "TokenClassificationPipeline task" - ) - - -class TextClassificationRequest(BaseModel): - """ - Schema for TextClassificationPipeline Request - """ - - sequences: Union[List[str], str] = Field( - description="A string or List of strings representing input to" - 
"TextClassificationPipeline task" - ) - - -class QuestionAnsweringResponse(BaseModel): - """ - Schema for a result from Question Answering Task - """ - - score: float = Field(description="confidence score for prediction") - start: int = Field(description="The start index of the answer") - end: int = Field(description="The end index of the answer") - answer: str = Field(description="The predicted answer") - - -class TokenClassificationResponse(BaseModel): - """ - Schema for TokenClassificationPipeline Response - """ - - entity: str = Field( - description="The entity predicted for that token/word (it is named" - "`entity_group` when `aggregation_strategy` is not `none`." - ) - score: float = Field(description="The corresponding probability for `entity`.") - index: int = Field( - description="The index of the corresponding token in the sentence." - ) - word: str = Field(description="The token/word classified.") - start: Optional[int] = Field( - description="The index of the start of the corresponding entity in the " - "sentence. Only exists if the offsets are available within the tokenizer" - ) - end: Optional[int] = Field( - description="The index of the end of the corresponding entity in the sentence. " - "Only exists if the offsets are available within the tokenizer" - ) - - -class TextClassificationResponse(BaseModel): - """ - Schema for TextClassificationPipeline Response - """ - - label: str = Field(description="The label predicted.") - score: float = Field(description="The corresponding probability.")