diff --git a/docs/source/add_new_pipeline.rst b/docs/source/add_new_pipeline.rst index a49772fbd3d29..ccf587968f62c 100644 --- a/docs/source/add_new_pipeline.rst +++ b/docs/source/add_new_pipeline.rst @@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to from transformers import Pipeline class MyPipeline(Pipeline): - def _sanitize_parameters(self, **kwargs) + def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} if "maybe_arg" in kwargs: preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] return preprocess_kwargs, {}, {} - def preprocess(self, inputs, maybe_arg=2) + def preprocess(self, inputs, maybe_arg=2): model_input = Tensor(....) return {"model_input": model_input} - def _forward(self, model_inputs) + def _forward(self, model_inputs): # model_inputs == {"model_input": model_input} - oututs = self.model(**model_inputs) + outputs = self.model(**model_inputs) # Maybe {"logits": Tensor(...)} return outputs - def postprocess(self, model_outputs) + def postprocess(self, model_outputs): best_class = model_outputs["logits"].softmax(-1) return best_class @@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa .. code-block:: - def postprocess(self, model_outputs, top_k=5) + def postprocess(self, model_outputs, top_k=5): best_class = model_outputs["logits"].softmax(-1) # Add logic to handle top_k return best_class - def _sanitize_parameters(self, **kwargs) + def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} if "maybe_arg" in kwargs: preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]