Skip to content

Commit

Permalink
Merge pull request #94 from jianzhnie/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
jianzhnie committed Sep 13, 2023
2 parents 6fab54c + 4966498 commit 7c4e2e1
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions chatllms/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import bitsandbytes as bnb
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from transformers.trainer_utils import get_last_checkpoint
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
Expand Down Expand Up @@ -290,6 +292,22 @@ def get_logits_processor() -> LogitsProcessorList:
return logits_processor


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that overwrites score tensors holding NaN/inf values.

    Intended to avoid a runtime error in ``model.generate(do_sample=True)``.
    """

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores``, reset in place if any entry is NaN or infinite.

        Args:
            input_ids: Not read by this processor.
            scores: Tensor of scores to check; modified in place when a
                non-finite value is found.

        Returns:
            The same ``scores`` tensor object that was passed in.
        """
        # Fast path: every entry is finite, so there is nothing to repair.
        if torch.isfinite(scores).all():
            return scores

        # Wipe the whole tensor, then set index 0 along the last dim to 1.0.
        scores.zero_()
        scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    """Build a ``LogitsProcessorList`` holding one ``InvalidScoreLogitsProcessor``."""
    return LogitsProcessorList([InvalidScoreLogitsProcessor()])


def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
Expand Down

0 comments on commit 7c4e2e1

Please sign in to comment.