diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py
index dc3d502..defb2f5 100644
--- a/chatllms/utils/model_utils.py
+++ b/chatllms/utils/model_utils.py
@@ -7,7 +7,8 @@
 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
 from transformers.trainer_utils import get_last_checkpoint
-
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
 from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
                                       DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)

@@ -273,6 +274,21 @@ def find_last_checkpoint(checkpoint_dir):
     last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
     return last_checkpoint

+# Avoid runtime error in model.generate(do_sample=True).
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor,
+                 scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 0] = 1.0
+        return scores
+
+
+def get_logits_processor() -> LogitsProcessorList:
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InvalidScoreLogitsProcessor())
+    return logits_processor
+

 def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
     """Collects the state dict and dump to disk."""
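
For context (not part of this diff): the new get_logits_processor() helper is intended to be passed to model.generate(), which merges a user-supplied LogitsProcessorList with its default processors, so the NaN/inf guard runs before tokens are sampled. A minimal usage sketch, assuming a placeholder 'gpt2' checkpoint and prompt:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from chatllms.utils.model_utils import get_logits_processor

    # Placeholder model/tokenizer; any causal LM checkpoint works the same way.
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    inputs = tokenizer('Hello, how are you?', return_tensors='pt')
    # do_sample=True uses multinomial sampling, which raises a runtime error on
    # NaN/inf probabilities; the appended InvalidScoreLogitsProcessor zeroes such
    # score rows and puts all probability mass on token id 0 instead.
    outputs = model.generate(**inputs,
                             do_sample=True,
                             max_new_tokens=32,
                             logits_processor=get_logits_processor())
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))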