Skip to content

Commit

Permalink
Fix mismatched session_len between the turbomind engine and the pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Jun 5, 2024
1 parent 2ab8e3d commit 0b0508a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
9 changes: 5 additions & 4 deletions lmdeploy/turbomind/deploy/target_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from pydantic.dataclasses import dataclass

from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.model import MODELS

from ..source_model.base import BaseInputModel, BaseReader

Expand Down Expand Up @@ -165,9 +164,11 @@ def __init__(self,
def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
"""Generate turbomind model config (config.ini)."""
_, bos_id, eos_id = self.input_model.tokenizer_info()
model_info = self.input_model.model_info()
session_len = model_info.get('max_position_embeddings',
MODELS.get(cfg.model_name)().session_len)
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(self.input_model.model_path,
trust_remote_code=True)
from lmdeploy.utils import _get_and_verify_max_len
session_len = _get_and_verify_max_len(hf_config, self.cfg.session_len)
final_cfg = cfg.__dict__
final_cfg.update(
dict(start_id=bos_id, end_id=eos_id, session_len=session_len + 8))
Expand Down
3 changes: 0 additions & 3 deletions lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,6 @@ def _from_hf(self, model_source: ModelSource, model_path: str,
input_model=input_model, cfg=cfg, to_file=False, out_dir='')

cfg = output_model.cfg
if engine_config.session_len is not None:
cfg.session_len = engine_config.session_len

cfg.update_prefill_conifg(engine_config)
self.model_name = cfg.model_name
self.config = cfg
Expand Down
5 changes: 5 additions & 0 deletions lmdeploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def _get_and_verify_max_len(
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
if not isinstance(hf_tm_config, PretrainedConfig) and \
hasattr(hf_tm_config, 'session_len'):
# turbomind backend already has session_len
return getattr(hf_tm_config, 'session_len')

logger = get_logger('lmdeploy')
derived_max_model_len = float('inf')
possible_keys = [
Expand Down

0 comments on commit 0b0508a

Please sign in to comment.