Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix-contrast-attr
Browse files: browse the repository at this point in the history
* origin/main:
  Remove `max_input_length` from `model.encode` (#227)
  • Loading branch information
gsarti committed Oct 30, 2023
2 parents 3093bed + 1b52b04 commit b329e77
Showing 1 changed file with 0 additions and 10 deletions.
10 changes: 0 additions & 10 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def encode(
as_targets: bool = False,
return_baseline: bool = False,
include_eos_baseline: bool = False,
max_input_length: int = 512,
add_bos_token: bool = True,
add_special_tokens: bool = True,
) -> BatchEncoding:
Expand All @@ -249,21 +248,12 @@ def encode(
"""
if as_targets and not self.is_encoder_decoder:
raise ValueError("Decoder-only models should use tokenization as source only.")
max_length = self.tokenizer.max_len_single_sentence
# Some tokenizers report implausibly large values for max_len_single_sentence;
# cap the length with max_model_input_sizes instead.
if max_length > 1e6:
if hasattr(self.tokenizer, "max_model_input_sizes") and self.tokenizer.max_model_input_sizes:
max_length = max(v for _, v in self.tokenizer.max_model_input_sizes.items())
else:
max_length = max_input_length
batch = self.tokenizer(
text=texts if not as_targets else None,
text_target=texts if as_targets else None,
add_special_tokens=add_special_tokens,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
).to(self.device)
baseline_ids = None
Expand Down

0 comments on commit b329e77

Please sign in to comment.