This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
10 changes: 5 additions & 5 deletions src/deepsparse/transformers/pipelines/pipeline.py
@@ -169,12 +169,12 @@ def pipeline(
     **kwargs,
 ) -> Pipeline:
     """
-    [DEPRECATED] - deepsparse.transformers.pipeline is deprecated to craete DeepSparse
-    pipelines for tranformers tasks use deepsparse.Pipeline.create(task, ...)
+    [DEPRECATED] - deepsparse.transformers.pipeline is deprecated to create DeepSparse
+    pipelines for transformers tasks use deepsparse.Pipeline.create(task, ...)
 
     Utility factory method to build a Pipeline
 
-    :param task: name of the task to define which pipeline to create. Currently
+    :param task: name of the task to define which pipeline to create. Currently,
         supported task - "question-answering"
     :param model_name: canonical name of the hugging face model this model is based on
     :param model_path: path to model directory containing `model.onnx`, `config.json`,
@@ -194,8 +194,8 @@ def pipeline(
     :return: Pipeline object for the given taks and model
     """
     warnings.warn(
-        "[DEPRECATED] - deepsparse.transformers.pipeline is deprecated to craete "
-        "DeepSparse pipelines for tranformers tasks use deepsparse.Pipeline.create()"
+        "[DEPRECATED] - deepsparse.transformers.pipeline is deprecated to create "
+        "DeepSparse pipelines for transformers tasks use deepsparse.Pipeline.create()"
     )
 
     if config is not None or tokenizer is not None:
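For reference, here is a minimal sketch of the replacement call path named in the deprecation notice above. The model directory and the example inputs are hypothetical, and the exact keyword arguments accepted may vary by release; this is an illustration, not the project's documented example.

# Migration sketch: deepsparse.transformers.pipeline -> Pipeline.create.
# "./qa-model-dir" and the question/context inputs are hypothetical.
from deepsparse import Pipeline

# Previously (now deprecated):
#   from deepsparse.transformers import pipeline
#   qa = pipeline(task="question-answering", model_path="./qa-model-dir")

# Replacement suggested by the deprecation warning:
qa = Pipeline.create(task="question-answering", model_path="./qa-model-dir")
result = qa(
    question="What does this change deprecate?",
    context="deepsparse.transformers.pipeline is deprecated in favor of Pipeline.create.",
)
print(result)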
97 changes: 64 additions & 33 deletions src/deepsparse/transformers/pipelines_cli.py
@@ -22,7 +22,6 @@
 text-classification,token-classification}]
 -d DATA [--model-name MODEL_NAME] --model-path
 MODEL_PATH [--engine-type {deepsparse,onnxruntime}]
-[--config CONFIG] [--tokenizer TOKENIZER]
 [--max-length MAX_LENGTH] [--num-cores NUM_CORES]
 [-b BATCH_SIZE] [--scheduler {multi,single}]
 [-o OUTPUT_FILE]
@@ -50,12 +49,6 @@
 --engine-type {deepsparse,onnxruntime}, --engine_type {deepsparse,onnxruntime}
 inference engine name to use. Supported options are
 'deepsparse'and 'onnxruntime'
---config CONFIG Huggingface model config, if none provided, default
-will be usedwhich will be from the model name or
-sparsezoo stub if given for model path
---tokenizer TOKENIZER
-Huggingface tokenizer, if none provided, default will
-be used
 --max-length MAX_LENGTH, --max_length MAX_LENGTH
 Maximum sequence length of model inputs. default is
 128
@@ -78,7 +71,6 @@
 2) deepsparse.transformers.run_inference --task ner \
 --model-path models/bert-ner-test.onnx \
 --data input.txt \
---config ner-config.json \
 --output-file out.txt \
 --batch_size 2
 
@@ -91,16 +83,26 @@
"""

import argparse
from typing import Optional
import json
from typing import Any, Callable, Optional

from .loaders import SUPPORTED_EXTENSIONS
from .pipelines import SUPPORTED_ENGINES, SUPPORTED_TASKS, pipeline, process_dataset
from deepsparse.pipeline import SUPPORTED_PIPELINE_ENGINES
from deepsparse.transformers import fix_numpy_types
from deepsparse.transformers.loaders import SUPPORTED_EXTENSIONS, get_batch_loader
from deepsparse.transformers.pipelines import pipeline


__all__ = [
"cli",
]

SUPPORTED_TASKS = [
"question_answering",
"text_classification",
"token_classification",
"ner",
]


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@@ -111,8 +113,8 @@ def _parse_args() -> argparse.Namespace:
"-t",
"--task",
help="Name of the task to define which pipeline to create."
f" Currently supported tasks {list(SUPPORTED_TASKS.keys())}",
choices=SUPPORTED_TASKS.keys(),
f" Currently supported tasks {SUPPORTED_TASKS}",
choices=SUPPORTED_TASKS,
type=str,
default="sentiment-analysis",
)
@@ -150,24 +152,8 @@ def _parse_args() -> argparse.Namespace:
help="Inference engine name to use. Supported options are 'deepsparse'"
"and 'onnxruntime'",
type=str,
choices=SUPPORTED_ENGINES,
default=SUPPORTED_ENGINES[0],
)

parser.add_argument(
"--config",
help="Huggingface model config, if none provided, default will be used"
"which will be from the model name or sparsezoo stub if given for "
"model path",
type=str,
default=None,
)

parser.add_argument(
"--tokenizer",
help="Huggingface tokenizer, if none provided, default will be used",
type=str,
default=None,
choices=SUPPORTED_PIPELINE_ENGINES,
default=SUPPORTED_PIPELINE_ENGINES[0],
)

parser.add_argument(
@@ -229,8 +215,6 @@ def cli():
         model_name=_args.model_name,
         model_path=_args.model_path,
         engine_type=_args.engine_type,
-        config=_args.config,
-        tokenizer=_args.tokenizer,
         max_length=_args.max_length,
         num_cores=_args.num_cores,
         batch_size=_args.batch_size,
@@ -245,5 +229,52 @@ def cli():
     )
 
 
+def response_to_json(response: Any):
+    """
+    Converts a response to a json string
+
+    :param response: A List[Any] or Dict[Any, Any] or a Pydantic model,
+        that should be converted to a valid json string
+    :return: A json string representation of the response
+    """
+    if isinstance(response, list):
+        return [response_to_json(val) for val in response]
+    elif isinstance(response, dict):
+        return {key: response_to_json(val) for key, val in response.items()}
+    elif hasattr(response, "json") and callable(response.json):
+        return response.json()
+    return json.dumps(response)
+
+
+def process_dataset(
+    pipeline_object: Callable,
+    data_path: str,
+    batch_size: int,
+    task: str,
+    output_path: str,
+) -> None:
+    """
+    :param pipeline_object: An instantiated pipeline Callable object
+    :param data_path: Path to input file, supports csv, json and text files
+    :param batch_size: batch_size to use for inference
+    :param task: The task pipeline is instantiated for
+    :param output_path: Path to a json file to output inference results to
+    """
+    batch_loader = get_batch_loader(
+        data_file=data_path,
+        batch_size=batch_size,
+        task=task,
+    )
+    # Wraps pipeline object to make numpy types serializable
+    pipeline_object = fix_numpy_types(pipeline_object)
+    with open(output_path, "a") as output_file:
+        for batch in batch_loader:
+            batch_output = pipeline_object(**batch)
+            json_output = response_to_json(batch_output)
+
+            json.dump(json_output, output_file)
+            output_file.write("\n")
+
+
 if __name__ == "__main__":
     cli()
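For reference, a short usage sketch of the helpers this diff adds, mirroring the `ner` example from the module docstring. The model path and file names are hypothetical, and importing `process_dataset` directly from `pipelines_cli` is an assumption for illustration.

# Hypothetical driver for the new helpers; paths and file names are made up.
from deepsparse.transformers.pipelines import pipeline
from deepsparse.transformers.pipelines_cli import process_dataset

ner_pipeline = pipeline(
    task="ner",
    model_path="models/bert-ner-test.onnx",
)
# Streams batches from input.txt through the pipeline and appends one
# JSON record per batch to out.txt, as cli() does internally.
process_dataset(
    pipeline_object=ner_pipeline,
    data_path="input.txt",
    batch_size=2,
    task="ner",
    output_path="out.txt",
)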