Fix batching in generate (#184)
gsarti committed May 23, 2023
1 parent 1046230 commit 1789acf
Showing 3 changed files with 9 additions and 3 deletions.
inseq/attr/attribution_decorators.py (6 changes: 4 additions & 2 deletions)
@@ -68,15 +68,17 @@ def get_batched(bs: Optional[int], seq: Sequence[Any]) -> List[List[Any]]:
             raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")

         if batch_size is None:
-            return [f(self, *args, **kwargs)]
+            out = f(self, *args, **kwargs)
+            return out if isinstance(out, list) else [out]
         batched_args = [get_batched(batch_size, arg) for arg in args]
         len_batches = len(batched_args[0])
         assert all(len(batch) == len_batches for batch in batched_args)
         output = []
         zipped_batched_args = zip(*batched_args) if len(batched_args) > 1 else [(x,) for x in batched_args[0]]
         for i, batch in enumerate(zipped_batched_args):
             logger.debug(f"Batching enabled: processing batch {i + 1} of {len_batches}...")
-            output.append(f(self, *batch, **kwargs))
+            out = f(self, *batch, **kwargs)
+            output += out if isinstance(out, list) else [out]
         return output

     return batched_wrapper
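
For context, the fix addresses the case where the wrapped function already returns a list (as generate does): the old code wrapped each batch result in another list, producing nested lists across batches. Below is a minimal, self-contained sketch of the same flattening logic; batched_sketch and ToyGenerator are hypothetical names used only for illustration and cover only a single list argument.

from functools import wraps
from typing import Any, Callable, List, Optional, Sequence


def batched_sketch(f: Callable[..., Any]) -> Callable[..., Any]:
    # Simplified stand-in for inseq's batched decorator (single sequence argument only).
    @wraps(f)
    def wrapper(self, seq: Sequence[Any], batch_size: Optional[int] = None, **kwargs):
        if batch_size is None:
            out = f(self, seq, **kwargs)
            return out if isinstance(out, list) else [out]
        output: List[Any] = []
        for i in range(0, len(seq), batch_size):
            out = f(self, seq[i : i + batch_size], **kwargs)
            # Extend rather than append, so list-valued outputs stay flat across batches.
            output += out if isinstance(out, list) else [out]
        return output

    return wrapper


class ToyGenerator:
    @batched_sketch
    def generate(self, texts):
        # Returns a list, like a model returning one generation per input text.
        return [t.upper() for t in texts]


print(ToyGenerator().generate(["a", "b", "c"], batch_size=2))  # ['A', 'B', 'C'], not [['A', 'B'], ['C']]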
inseq/models/attribution_model.py (4 changes: 3 additions & 1 deletion)
@@ -346,7 +346,9 @@ def attribute(
             if generate_from_target_prefix:
                 decoder_input = self.encode(generated_texts, as_targets=True)
                 generation_args["decoder_input_ids"] = decoder_input.input_ids
-            generated_texts = self.generate(encoded_input, return_generation_output=False, **generation_args)
+            generated_texts = self.generate(
+                encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
+            )
         else:
             if generation_args:
                 logger.warning(
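
In practice, this means the batch size passed to attribute now also applies when generation is needed to complete user-supplied target prefixes. A hedged usage sketch, assuming the usual inseq.load_model / attribute API; the model name, texts, and prefixes are illustrative, not taken from the commit.

import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")
out = model.attribute(
    ["The cat sat on the mat.", "Dogs love to play."],
    generated_texts=["Le chat ", "Les chiens "],  # target prefixes to be completed
    generate_from_target_prefix=True,
    batch_size=2,  # now also forwarded to self.generate(...) for the prefix completion step
)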
inseq/models/huggingface_model.py (2 changes: 2 additions & 0 deletions)
@@ -15,6 +15,7 @@
 )
 from transformers.modeling_outputs import CausalLMOutput, ModelOutput, Seq2SeqLMOutput

+from ..attr.attribution_decorators import batched
 from ..data import BatchEncoding
 from ..utils import check_device
 from ..utils.typing import (
@@ -179,6 +180,7 @@ def info(self) -> Dict[str, str]:
         return dic_info

     @unhooked
+    @batched
     def generate(
         self,
         inputs: Union[TextInput, BatchEncoding],
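
With the decorator applied, the wrapper added around generate exposes an extra batch_size keyword that splits the inputs before each generation call. A small usage sketch, assuming the model loaded in the previous example; the texts are illustrative.

texts = ["A first source sentence.", "A second one.", "And a third."]
# batch_size is consumed by the batched wrapper, not by the underlying HF generation call.
generations = model.generate(texts, batch_size=2)  # flat list with one generation per input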
