# Patch summary (preamble text placed before the first `diff --git` header;
# it is not part of any hunk).
#
# Target: dlt/pipeline/__init__.py
#
# What the hunks below change:
#   * Import `Type` from `typing` and `TypeVar` from `typing_extensions`.
#   * Define `TPipeline = TypeVar("TPipeline", bound=Pipeline, default=Pipeline)`.
#   * Add a `_impl_cls: Type[TPipeline] = Pipeline` parameter to the two
#     `pipeline(...)` signatures touched here (the one carrying the
#     "Creates a new instance of `dlt` pipeline" docstring and the one taking
#     `**kwargs`), and change their return annotation from `Pipeline` to
#     `TPipeline`.
#   * When a pipeline is already active, cast `context.pipeline()` to
#     `TPipeline` instead of `Pipeline`.
#   * Build the new instance with `_impl_cls(...)` instead of `Pipeline(...)`.
#
# NOTE(review): `cast` performs no runtime check, so the active-pipeline branch
#   returns whatever pipeline is active regardless of `_impl_cls` - confirm this
#   is acceptable for callers that pass a subclass.
# NOTE(review): `TypeVar` is presumably taken from `typing_extensions` because
#   of the `default=` argument - confirm the minimum supported version.
# NOTE(review): this patch was received collapsed onto a single line. Line
#   breaks were restored to match each hunk's line counts; indentation, the
#   spacing before inline comments and the exact position of blank context
#   lines are inferred - verify against the upstream patch before applying.
diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py
index 4101e58320..6b14eaf777 100644
--- a/dlt/pipeline/__init__.py
+++ b/dlt/pipeline/__init__.py
@@ -1,4 +1,5 @@
-from typing import Sequence, cast, overload
+from typing import Sequence, Type, cast, overload
+from typing_extensions import TypeVar
 
 from dlt.common.schema import Schema
 from dlt.common.schema.typing import TColumnSchema, TWriteDisposition, TSchemaContract
@@ -15,6 +16,8 @@
 from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR
 from dlt.pipeline.warnings import credentials_argument_deprecated
 
+TPipeline = TypeVar("TPipeline", bound=Pipeline, default=Pipeline)
+
 
 @overload
 def pipeline(
@@ -29,7 +32,8 @@ def pipeline(
     full_refresh: bool = False,
     credentials: Any = None,
     progress: TCollectorArg = _NULL_COLLECTOR,
-) -> Pipeline:
+    _impl_cls: Type[TPipeline] = Pipeline,  # type: ignore[assignment]
+) -> TPipeline:
     """Creates a new instance of `dlt` pipeline, which moves the data from the source ie. a REST API to a destination ie. database or a data lake.
 
     #### Note:
@@ -97,8 +101,9 @@ def pipeline(
     full_refresh: bool = False,
     credentials: Any = None,
     progress: TCollectorArg = _NULL_COLLECTOR,
+    _impl_cls: Type[TPipeline] = Pipeline,  # type: ignore[assignment]
     **kwargs: Any,
-) -> Pipeline:
+) -> TPipeline:
     ensure_correct_pipeline_kwargs(pipeline, **kwargs)
     # call without arguments returns current pipeline
     orig_args = get_orig_args(**kwargs)  # original (*args, **kwargs)
@@ -111,7 +116,7 @@ def pipeline(
         context = Container()[PipelineContext]
         # if pipeline instance is already active then return it, otherwise create a new one
         if context.is_active():
-            return cast(Pipeline, context.pipeline())
+            return cast(TPipeline, context.pipeline())
         else:
             pass
 
@@ -129,7 +134,7 @@ def pipeline(
 
     progress = collector_from_name(progress)
     # create new pipeline instance
-    p = Pipeline(
+    p = _impl_cls(
         pipeline_name,
         pipelines_dir,
         pipeline_salt,