[Feat] Add _set_model_kwargs_torch_dtype for HF model (#507)
* add _set_model_kwargs_torch_dtype for hf models

* add logger
Leymore committed Oct 27, 2023
1 parent 6405cd2 commit e3d4901
Showing 2 changed files with 25 additions and 4 deletions.
19 changes: 17 additions & 2 deletions opencompass/models/huggingface.py
@@ -131,13 +131,28 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
             self.tokenizer.eos_token = '</s>'
             self.tokenizer.pad_token_id = 0
 
+    def _set_model_kwargs_torch_dtype(self, model_kwargs):
+        if 'torch_dtype' not in model_kwargs:
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = {
+                'torch.float16': torch.float16,
+                'torch.bfloat16': torch.bfloat16,
+                'torch.float': torch.float,
+                'auto': 'auto',
+                'None': None
+            }.get(model_kwargs['torch_dtype'])
+        self.logger.debug(f'HF using torch_dtype: {torch_dtype}')
+        if torch_dtype is not None:
+            model_kwargs['torch_dtype'] = torch_dtype
+
     def _load_model(self,
                     path: str,
                     model_kwargs: dict,
                     peft_path: Optional[str] = None):
         from transformers import AutoModel, AutoModelForCausalLM
 
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 path, **model_kwargs)
@@ -409,7 +424,7 @@ def _load_model(self,
                     peft_path: Optional[str] = None):
         from transformers import AutoModelForCausalLM
 
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
         if peft_path is not None:
             from peft import PeftModel
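
For reference, here is a minimal standalone sketch of how the dtype resolution added above behaves. The resolve_torch_dtype helper is hypothetical and exists only to illustrate the mapping; in the commit itself the logic lives in the _set_model_kwargs_torch_dtype method of the HuggingFace model classes.

import torch

def resolve_torch_dtype(model_kwargs: dict) -> None:
    """Illustrative copy of the mapping introduced in this commit."""
    if 'torch_dtype' not in model_kwargs:
        # No dtype requested: fall back to float16, as before this commit.
        torch_dtype = torch.float16
    else:
        # Strings coming from a config or the CLI are mapped to real dtypes;
        # 'auto' is passed through so transformers picks the checkpoint dtype,
        # and anything that resolves to None leaves model_kwargs untouched.
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',
            'None': None,
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype

kwargs = {'torch_dtype': 'torch.bfloat16'}
resolve_torch_dtype(kwargs)
print(kwargs['torch_dtype'])  # torch.bfloat16

Compared with the old model_kwargs.setdefault('torch_dtype', torch.float16), this keeps float16 as the default but lets a config or the CLI request bfloat16, float32 (torch.float), or 'auto' via a plain string.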
10 changes: 8 additions & 2 deletions run.py
@@ -175,8 +175,14 @@ def parse_hf_args(hf_parser):
     hf_parser.add_argument('--hf-path', type=str)
     hf_parser.add_argument('--peft-path', type=str)
     hf_parser.add_argument('--tokenizer-path', type=str)
-    hf_parser.add_argument('--model-kwargs', nargs='+', action=DictAction)
-    hf_parser.add_argument('--tokenizer-kwargs', nargs='+', action=DictAction)
+    hf_parser.add_argument('--model-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
+    hf_parser.add_argument('--tokenizer-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
     hf_parser.add_argument('--max-out-len', type=int)
     hf_parser.add_argument('--max-seq-len', type=int)
     hf_parser.add_argument('--no-batch-padding',
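
A small sketch of why default={} matters for these two arguments, assuming DictAction is mmengine's key=value argparse action (the import path below is an assumption made for the sake of a runnable example, not part of this diff):

import argparse

from mmengine.config import DictAction  # assumed import path, not shown in the diff

parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='+', action=DictAction, default={})

# With the flag, key=value pairs become a dict of strings; the strings are
# resolved to real dtypes later by _set_model_kwargs_torch_dtype.
args = parser.parse_args(['--model-kwargs', 'torch_dtype=torch.bfloat16'])
print(args.model_kwargs)  # {'torch_dtype': 'torch.bfloat16'}

# Without the flag, default={} keeps the attribute a dict instead of None,
# so downstream code can treat it as a dict without a None check.
args = parser.parse_args([])
print(args.model_kwargs)  # {}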
