Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/deepsparse/image_classification/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@
cv2_error = cv2_import_error


@Pipeline.register(task="image_classification")
@Pipeline.register(
task="image_classification",
default_model_path=(
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/"
"imagenet/pruned85_quant-none-vnni"
),
)
class ImageClassificationPipeline(Pipeline):
"""
Image classification pipeline for DeepSparse
Expand Down
33 changes: 28 additions & 5 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pydantic import BaseModel, Field

from deepsparse import Engine, Scheduler
from deepsparse.benchmark_model.ort_engine import ORTEngine
from deepsparse.benchmark.ort_engine import ORTEngine
from deepsparse.tasks import SupportedTasks


Expand Down Expand Up @@ -170,7 +170,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
@staticmethod
def create(
task: str,
model_path: str,
model_path: str = None,
engine_type: str = DEEPSPARSE_ENGINE,
batch_size: int = 1,
num_cores: int = None,
Expand All @@ -182,7 +182,7 @@ def create(
"""
:param task: name of task to create a pipeline for
:param model_path: path on local system or SparseZoo stub to load the model
from
from. Some tasks may have a default model path
:param engine_type: inference engine to use. Currently supported values
include 'deepsparse' and 'onnxruntime'. Default is 'deepsparse'
:param batch_size: static batch size to use for inference. Default is 1
Expand Down Expand Up @@ -214,7 +214,22 @@ def create(
f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}"
)

return _REGISTERED_PIPELINES[task](
pipeline_constructor = _REGISTERED_PIPELINES[task]

if (
model_path is None
and hasattr(pipeline_constructor, "default_model_path")
and pipeline_constructor.default_model_path
):
model_path = pipeline_constructor.default_model_path

if model_path is None:
raise ValueError(
f"No model_path provided for pipeline {pipeline_constructor}. Must "
"provide a model path for pipelines that do not have a default defined"
)

return pipeline_constructor(
model_path=model_path,
engine_type=engine_type,
batch_size=batch_size,
Expand All @@ -226,7 +241,12 @@ def create(
)

@classmethod
def register(cls, task: str, task_aliases: Optional[List[str]] = None):
def register(
cls,
task: str,
task_aliases: Optional[List[str]] = None,
default_model_path: Optional[str] = None,
):
"""
Pipeline implementer class decorator that registers the pipeline
task name and its aliases as valid tasks that can be used to load
Expand All @@ -238,6 +258,8 @@ def register(cls, task: str, task_aliases: Optional[List[str]] = None):
:param task: main task name of this pipeline
:param task_aliases: list of extra task names that may be used to reference
this pipeline. Default is None
:param default_model_path: path (ie zoo stub) to use as default for this
task if None is provided
"""
task_names = [task]
if task_aliases:
Expand Down Expand Up @@ -266,6 +288,7 @@ def _register_pipeline_tasks_decorator(pipeline_class: Pipeline):
# set task and task_aliases as class level property
pipeline_class.task = task
pipeline_class.task_aliases = task_aliases
pipeline_class.default_model_path = default_model_path

return pipeline_class

Expand Down
4 changes: 4 additions & 0 deletions src/deepsparse/transformers/pipelines/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class QuestionAnsweringOutput(BaseModel):
@Pipeline.register(
task="question_answering",
task_aliases=["qa"],
default_model_path=(
"zoo:nlp/question_answering/bert-base/pytorch/huggingface/"
"squad/12layer_pruned80_quant-none-vnni"
),
)
class QuestionAnsweringPipeline(TransformersPipeline):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/deepsparse/transformers/pipelines/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class TextClassificationOutput(BaseModel):
@Pipeline.register(
task="text_classification",
task_aliases=["glue", "sentiment_analysis"],
default_model_path=(
"zoo:nlp/sentiment_analysis/bert-base/pytorch/huggingface/"
"sst2/12layer_pruned80_quant-none-vnni"
),
)
class TextClassificationPipeline(TransformersPipeline):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/deepsparse/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class TokenClassificationOutput(BaseModel):
@Pipeline.register(
task="token_classification",
task_aliases=["ner"],
default_model_path=(
"zoo:nlp/token_classification/bert-base/pytorch/huggingface/"
"conll2003/12layer_pruned80_quant-none-vnni"
),
)
class TokenClassificationPipeline(TransformersPipeline):
"""
Expand Down
7 changes: 6 additions & 1 deletion src/deepsparse/yolo/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
cv2_error = cv2_import_error


@Pipeline.register(task="yolo")
@Pipeline.register(
task="yolo",
default_model_path=(
"zoo:cv/detection/yolov5-l/pytorch/ultralytics/coco/pruned_quant-aggressive_95"
),
)
class YOLOPipeline(Pipeline):
"""
Image Segmentation YOLO pipeline for DeepSparse
Expand Down