Enabling dataset iteration on pipelines.

Unifying parameters under `set_parameters` function.

Small fix.

Last fixes after rebase

Remove print.

Fixing text2text `generate_kwargs`

No more `self.max_length`.

Fixing TF-only conversational.

Consistency in start/stop index over TF/PT.

Speeding up drastically on TF (nasty bug where max_length would increase a ton).

Adding test for support of non-fast tokenizers.

Fixing GPU usage on zero-shot.

Fix working on TF.

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Small cleanup.

Remove all asserts + simple format.
Narsil committed Sep 1, 2021
1 parent 53ee995 commit 7e8f4b9
Showing 19 changed files with 1,156 additions and 1,154 deletions.
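
For context, the usage this commit enables looks roughly like the sketch below: instead of calling the pipeline once per sample, a (PyTorch) Dataset can be fed directly and outputs are streamed back lazily. The dataset, column name and device are illustrative assumptions, not part of the diff.

from datasets import load_dataset

from transformers import pipeline
from transformers.pipelines.base import KeyDataset  # added in this commit

asr = pipeline("automatic-speech-recognition", device=0)            # assumption: any ASR checkpoint, GPU 0
dataset = load_dataset("superb", name="asr", split="test")          # assumption: any dataset with audio file paths

# KeyDataset selects a single column; passing a Dataset makes __call__ return a lazy iterator.
for output in asr(KeyDataset(dataset, "file")):
    print(output["text"])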
25 changes: 19 additions & 6 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -123,30 +123,43 @@ def __call__(
             - **text** (:obj:`str`) -- The recognized text.
         """
+        return super().__call__(inputs, **kwargs)
+
+    def set_parameters(self, **kwargs):
+        # No parameters on this pipeline right now
+        pass
+
+    def preprocess(self, inputs):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()

         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

-        assert isinstance(inputs, np.ndarray), "We expect a numpy ndarray as input"
-        assert len(inputs.shape) == 1, "We expect a single channel audio input for AutomaticSpeechRecognitionPipeline"
+        if not isinstance(inputs, np.ndarray):
+            raise ValueError("We expect a numpy ndarray as input")
+        if len(inputs.shape) != 1:
+            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
         )
-        processed = self.ensure_tensor_on_device(**processed)
+        return processed

+    def forward(self, model_inputs):
+        model_inputs = self.ensure_tensor_on_device(**model_inputs)
         name = self.model.__class__.__name__
         if name.endswith("ForConditionalGeneration"):
-            input_ids = processed["input_features"]
+            input_ids = model_inputs["input_features"]
             tokens = self.model.generate(input_ids=input_ids)
             tokens = tokens.squeeze(0)
         elif name.endswith("ForCTC"):
-            outputs = self.model(**processed)
+            outputs = self.model(**model_inputs)
             tokens = outputs.logits.squeeze(0).argmax(dim=-1)
+        return tokens

+    def postprocess(self, model_outputs):
         skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True
-        recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+        recognized_string = self.tokenizer.decode(model_outputs, skip_special_tokens=skip_special_tokens)
         return {"text": recognized_string}
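
As a rough illustration of how the new methods compose for this pipeline (this mirrors what Pipeline.run_single in base.py below does internally; the checkpoint and the file path are assumptions):

from transformers import pipeline

asr = pipeline("automatic-speech-recognition")

model_inputs = asr.preprocess("sample.wav")       # read / ffmpeg-decode the audio and run the feature extractor
model_outputs = asr.infer_forward(model_inputs)   # device placement + generate() or CTC forward pass
result = asr.postprocess(model_outputs)           # decode token ids back to text
print(result["text"])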
175 changes: 118 additions & 57 deletions src/transformers/pipelines/base.py
@@ -25,22 +25,27 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from ..feature_extraction_utils import PreTrainedFeatureExtractor
-from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..file_utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
 from ..models.auto.configuration_auto import AutoConfig
-from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
+from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import logging


+GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
+
 if is_tf_available():
     import tensorflow as tf

     from ..models.auto.modeling_tf_auto import TFAutoModel

 if is_torch_available():
     import torch
+    from torch.utils.data import DataLoader, Dataset, IterableDataset

     from ..models.auto.modeling_auto import AutoModel
+else:
+    Dataset = None

 if TYPE_CHECKING:
     from ..modeling_tf_utils import TFPreTrainedModel
@@ -50,6 +55,12 @@
 logger = logging.get_logger(__name__)


+def collate_fn(items):
+    if len(items) != 1:
+        raise ValueError("This collate_fn is meant to be used with batch_size=1")
+    return items[0]
+
+
 def infer_framework_load_model(
     model,
     config: AutoConfig,
@@ -585,6 +596,49 @@ def predict(self, X):
         Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
     """

+if is_torch_available():
+
+    class PipelineDataset(Dataset):
+        def __init__(self, dataset, process):
+            self.dataset = dataset
+            self.process = process
+
+        def __len__(self):
+            return len(self.dataset)
+
+        def __getitem__(self, i):
+            item = self.dataset[i]
+            processed = self.process(item)
+            return processed
+
+    class PipelineIterator(IterableDataset):
+        def __init__(self, loader, infer):
+            self.loader = loader
+            self.infer = infer
+
+        def __len__(self):
+            return len(self.loader)
+
+        def __iter__(self):
+            self.iterator = iter(self.loader)
+            return self
+
+        def __next__(self):
+            item = next(self.iterator)
+            processed = self.infer(item)
+            return processed
+
+    class KeyDataset(Dataset):
+        def __init__(self, dataset: Dataset, key: str):
+            self.dataset = dataset
+            self.key = key
+
+        def __len__(self):
+            return len(self.dataset)
+
+        def __getitem__(self, i):
+            return self.dataset[i][self.key]
+
+
 @add_end_docstrings(PIPELINE_INIT_ARGS)
 class Pipeline(_ScikitCompat):
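
For orientation, a hypothetical sketch of the chain these helpers form; Pipeline.get_iterator further down in this diff wires up exactly this, so the variable names and the sentiment-analysis pipeline here are assumptions for illustration only:

from torch.utils.data import DataLoader

from transformers import pipeline
from transformers.pipelines.base import PipelineDataset, PipelineIterator, collate_fn

clf = pipeline("sentiment-analysis")  # assumption: any converted pipeline works the same way
data = ["I love this.", "I hate this."]

dataset = PipelineDataset(data, clf.preprocess)                     # preprocess lazily, one item at a time
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)   # collate_fn just unwraps the batch of one
model_iterator = PipelineIterator(loader, clf.infer_forward)        # forward pass per item
final_iterator = PipelineIterator(model_iterator, clf.postprocess)  # postprocess per item
for output in final_iterator:
    print(output)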
@@ -618,6 +672,7 @@ def __init__(
         args_parser: ArgumentHandler = None,
         device: int = -1,
         binary_output: bool = False,
+        **kwargs,
     ):

         if framework is None:
@@ -641,6 +696,9 @@
         if task_specific_params is not None and task in task_specific_params:
             self.model.config.update(task_specific_params.get(task))

+        self.call_count = 0
+        self.set_parameters(**kwargs)
+
     def save_pretrained(self, save_directory: str):
         """
         Save the pipeline's model and tokenizer.
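
A hypothetical illustration of the unified parameter flow introduced here: kwargs passed at construction time and at call time both funnel through set_parameters. The task and the max_length kwarg are assumptions standing in for whatever a concrete pipeline's set_parameters accepts.

from transformers import pipeline

summarizer = pipeline("summarization", max_length=60)             # __init__ ends with self.set_parameters(max_length=60)
summarizer("A very long article to shorten ...", max_length=30)   # __call__ starts with self.set_parameters(max_length=30)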
@@ -739,65 +797,68 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
             )

-    def _parse_and_tokenize(
-        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
-    ):
-        """
-        Parse arguments and tokenize
-        """
-        # Parse arguments
-        if getattr(self.tokenizer, "pad_token", None) is None:
-            padding = False
-        inputs = self.tokenizer(
-            inputs,
-            add_special_tokens=add_special_tokens,
-            return_tensors=self.framework,
-            padding=padding,
-            truncation=truncation,
-        )
-        return inputs
-
-    def __call__(self, inputs, *args, **kwargs):
-        try:
-            model_inputs = self._parse_and_tokenize(inputs, *args, **kwargs)
-            outputs = self._forward(model_inputs)
-            return outputs
-        except ValueError:
-            # XXX: Some tokenizer do NOT have a pad token, hence we cannot run the inference
-            # in a batch, instead we run everything sequentially
-            if isinstance(inputs, list):
-                values = []
-                for input_ in inputs:
-                    model_input = self._parse_and_tokenize(input_, padding=False, *args, **kwargs)
-                    value = self._forward(model_input)
-                    values.append(value.squeeze(0))
-            else:
-                model_input = self._parse_and_tokenize(inputs, padding=False, *args, **kwargs)
-                values = self._forward(model_input)
-            return values
+    @abstractmethod
+    def set_parameters(self, **pipeline_parameters):
+        raise NotImplementedError("set_parameters not implemented")

-    def _forward(self, inputs, return_tensors=False):
-        """
-        Internal framework specific forward dispatching
+    @abstractmethod
+    def preprocess(self, input_: Any, **preprocess_parameters) -> Dict[str, GenericTensor]:
+        raise NotImplementedError("preprocess not implemented")

-        Args:
-            inputs: dict holding all the keyword arguments for required by the model forward method.
-            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array
+    @abstractmethod
+    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters) -> Any:
+        raise NotImplementedError("postprocess not implemented")

-        Returns:
-            Numpy array
-        """
-        # Encode for forward
+    @abstractmethod
+    def forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters) -> ModelOutput:
+        raise NotImplementedError("forward not implemented")

+    def infer_forward(self, model_inputs, *args, **kwargs):
         with self.device_placement():
             if self.framework == "tf":
-                # TODO trace model
-                predictions = self.model(inputs.data, training=False)[0]
-            else:
+                model_inputs["training"] = False
+                model_outputs = self.forward(model_inputs)
+            elif self.framework == "pt":
                 with torch.no_grad():
-                    inputs = self.ensure_tensor_on_device(**inputs)
-                    predictions = self.model(**inputs)[0].cpu()
-
-        if return_tensors:
-            return predictions
+                    model_inputs = self.ensure_tensor_on_device(**model_inputs)
+                    model_outputs = self.forward(model_inputs)
+            else:
+                raise ValueError(f"Framework {self.framework} is not supported")
+        return model_outputs

+    def get_iterator(self, inputs, num_workers: int):
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        dataset = PipelineDataset(inputs, self.preprocess)
+        dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, collate_fn=collate_fn)
+        model_iterator = PipelineIterator(dataloader, self.infer_forward)
+        final_iterator = PipelineIterator(model_iterator, self.postprocess)
+        return final_iterator
+
+    def __call__(self, inputs, *args, num_workers=8, **kwargs):
+        self.set_parameters(**kwargs)
+        self.call_count += 1
+        if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
+            warnings.warn(
+                "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
+                UserWarning,
+            )
+        if isinstance(inputs, list):
+            if self.framework == "pt":
+                final_iterator = self.get_iterator(inputs, num_workers)
+                outputs = [output for output in final_iterator]
+                return outputs
+            else:
+                return self.run_multi(inputs)
+        elif Dataset and isinstance(inputs, Dataset):
+            return self.get_iterator(inputs, num_workers)
         else:
-            return predictions.numpy()
+            return self.run_single(inputs)
+
+    def run_multi(self, inputs):
+        return [self.run_single(item) for item in inputs]
+
+    def run_single(self, inputs):
+        model_inputs = self.preprocess(inputs)
+        model_outputs = self.infer_forward(model_inputs)
+        outputs = self.postprocess(model_outputs)
+        return outputs
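
To make the new contract concrete, here is a minimal hypothetical subclass written against the abstract methods above (PyTorch only; the class name and scoring logic are assumptions, not code from this diff):

from transformers.pipelines.base import Pipeline

class MyTextClassificationPipeline(Pipeline):
    def set_parameters(self, **kwargs):
        # nothing configurable in this toy example
        pass

    def preprocess(self, inputs):
        # turn raw text into framework tensors
        return self.tokenizer(inputs, return_tensors=self.framework)

    def forward(self, input_tensors):
        # infer_forward has already handled device placement and torch.no_grad()
        return self.model(**input_tensors)

    def postprocess(self, model_outputs):
        scores = model_outputs.logits.softmax(-1)[0]
        label_id = int(scores.argmax())
        return {"label": self.model.config.id2label[label_id], "score": float(scores[label_id])}

Calling such a pipeline with a single string goes through run_single (preprocess -> infer_forward -> postprocess), a list of strings is streamed through get_iterator on PyTorch, and a torch Dataset returns the lazy iterator directly.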