Patch for Seq2Seq Model predictions (EleutherAI#1584)
* Differentiate _encode_pair setting for decoder and enc-dec models

* tok_decode to not skip special tokens so that eos doesn't become an empty string

* Update model.py

* Update model.py

* Update huggingface.py

* Update lm_eval/models/huggingface.py

Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update model.py

---------

Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
lintangsutawika and haileyschoelkopf committed Mar 17, 2024
1 parent 92f30af commit b7923a8
Showing 2 changed files with 18 additions and 8 deletions.
14 changes: 10 additions & 4 deletions lm_eval/api/model.py
@@ -5,6 +5,7 @@
 import os
 from typing import List, Optional, Tuple, Type, TypeVar
 
+import transformers
 from sqlitedict import SqliteDict
 from tqdm import tqdm
 
@@ -296,11 +297,16 @@ def _encode_pair(self, context, continuation):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
 
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
+        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            whole_enc = self.tok_encode(context + continuation)
+            context_enc = self.tok_encode(context)
 
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
+            context_enc_len = len(context_enc)
+            continuation_enc = whole_enc[context_enc_len:]
+
+        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation)
 
         return context_enc, continuation_enc
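In effect, decoder-only models still obtain the continuation tokens by encoding the concatenated string and slicing off the context tokens, while encoder-decoder models now tokenize context and continuation independently, since the continuation is consumed by the decoder on its own. Below is a minimal standalone sketch of the two paths, using plain Hugging Face tokenizers in place of the harness's tok_encode; the model names are only illustrative, not part of this patch.

    from transformers import AutoTokenizer

    def encode_pair_causal(tok, context, continuation):
        # Decoder-only path: encode the concatenation, then split at the
        # context length so the continuation tokens are exactly what follows.
        whole_enc = tok.encode(context + continuation, add_special_tokens=False)
        context_enc = tok.encode(context, add_special_tokens=False)
        return context_enc, whole_enc[len(context_enc):]

    def encode_pair_seq2seq(tok, context, continuation):
        # Encoder-decoder path: the context feeds the encoder and the
        # continuation feeds the decoder, so the pieces are tokenized separately.
        context_enc = tok.encode(context, add_special_tokens=False)
        continuation_enc = tok.encode(continuation, add_special_tokens=False)
        return context_enc, continuation_enc

    gpt2_tok = AutoTokenizer.from_pretrained("gpt2")       # stand-in decoder-only tokenizer
    t5_tok = AutoTokenizer.from_pretrained("t5-small")     # stand-in enc-dec tokenizer
    print(encode_pair_causal(gpt2_tok, "The capital of France is", " Paris"))
    print(encode_pair_seq2seq(t5_tok, "The capital of France is", " Paris"))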
12 changes: 8 additions & 4 deletions lm_eval/models/huggingface.py
@@ -711,11 +711,15 @@ def tok_batch_encode(
 
         return encoding["input_ids"], encoding["attention_mask"]
 
-    def tok_decode(self, tokens):
+    def tok_decode(self, tokens, skip_special_tokens=True):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            return self.tokenizer.decode(tokens)
+            return self.tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            )
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            return self.tokenizer.decode(tokens, skip_special_tokens=True)
+            return self.tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            )
 
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
@@ -1158,7 +1162,7 @@ def _collate(req: Tuple[str, dict]):
                 f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
             )
             # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
+            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
             if not until:
                 until = [eos]
             else:
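This is why tok_decode grows a skip_special_tokens argument: the EOS token is now decoded with skip_special_tokens=False so it survives as a literal stop string, whereas decoding it with special tokens skipped yields an empty string (and an empty stop sequence is useless). A quick standalone check, assuming a T5 tokenizer rather than the harness class itself:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("t5-small")
    eos_id = tok.eos_token_id

    # Skipping special tokens drops the EOS text entirely -> empty stop string.
    print(repr(tok.decode(eos_id, skip_special_tokens=True)))   # ''
    # Keeping special tokens preserves the literal EOS text -> usable stop string.
    print(repr(tok.decode(eos_id, skip_special_tokens=False)))  # '</s>'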
