Skip to content

Commit

Permalink
Improve a add-new-pipeline docs a bit (#14485)
Browse files Browse the repository at this point in the history
  • Loading branch information
stancld committed Nov 22, 2021
1 parent a4553e6 commit e0e2da1
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions docs/source/add_new_pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
from transformers import Pipeline
class MyPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
def preprocess(self, inputs, maybe_arg=2)
def preprocess(self, inputs, maybe_arg=2):
model_input = Tensor(....)
return {"model_input": model_input}
def _forward(self, model_inputs)
def _forward(self, model_inputs):
# model_inputs == {"model_input": model_input}
oututs = self.model(**model_inputs)
outputs = self.model(**model_inputs)
# Maybe {"logits": Tensor(...)}
return outputs
def postprocess(self, model_outputs)
def postprocess(self, model_outputs):
best_class = model_outputs["logits"].softmax(-1)
return best_class
Expand Down Expand Up @@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
.. code-block::
def postprocess(self, model_outputs, top_k=5)
def postprocess(self, model_outputs, top_k=5):
best_class = model_outputs["logits"].softmax(-1)
# Add logic to handle top_k
return best_class
def _sanitize_parameters(self, **kwargs)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
Expand Down

0 comments on commit e0e2da1

Please sign in to comment.