
Commit

Update the dual-GPU deployment of chatglm2-6b
hzg0601 committed Jun 26, 2023
1 parent 4750b74 commit d9a0315
Showing 2 changed files with 23 additions and 5 deletions.
6 changes: 3 additions & 3 deletions configs/model_config.py
@@ -89,7 +89,7 @@
     },
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
-        "pretrained_model_name": "vicuna-13b-hf",
+        "pretrained_model_name": "TheBloke/vicuna-13B-1.1-HF/",
         "local_model_path": None,
         "provides": "LLamaLLM"
     },
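As an aside, a minimal sketch of how an entry like this is typically consumed when resolving a checkpoint; the resolve_checkpoint helper below is hypothetical, not code from this repository:

    # Hedged sketch: resolve_checkpoint is a hypothetical helper, not repo code.
    llm_model_dict = {
        "vicuna-13b-hf": {
            "name": "vicuna-13b-hf",
            "pretrained_model_name": "TheBloke/vicuna-13B-1.1-HF/",
            "local_model_path": None,
            "provides": "LLamaLLM",
        },
    }

    def resolve_checkpoint(model_name: str) -> str:
        cfg = llm_model_dict[model_name]
        # Prefer a configured local path; otherwise fall back to the Hub repo id.
        return cfg["local_model_path"] or cfg["pretrained_model_name"]

    print(resolve_checkpoint("vicuna-13b-hf"))  # TheBloke/vicuna-13B-1.1-HF/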
@@ -173,7 +173,7 @@
 
 # LLM name
 #! bug: when calling the fastchat API with openai version 0.27.6, it raises AttributeError: 'str' object has no attribute 'get'
-LLM_MODEL = "moss-int8"
+LLM_MODEL = "chatglm2-6b"
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -193,7 +193,7 @@
 
 # LLM running device
 #? bug: if this is set to cpu, the loader stalls after the model finishes loading and never proceeds
-LLM_DEVICE = "cuda:0" #"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+LLM_DEVICE = "cuda" #"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 # Default storage path for the knowledge base
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
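The device change matters for the multi-GPU path: a bare "cuda" leaves accelerate free to spread layers across every visible card, while "cuda:0" pins the whole model to the first GPU. A minimal sketch of that distinction (the snippet mirrors the config variable but is illustrative, not repo code):

    import torch

    LLM_DEVICE = "cuda"  # as set above; "cuda:0" would force single-card placement

    # With a bare "cuda", every visible GPU is a dispatch target.
    num_gpus = torch.cuda.device_count() if LLM_DEVICE == "cuda" else 1
    print(f"dispatching across {num_gpus} GPU(s)")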
22 changes: 20 additions & 2 deletions models/loader/loader.py
@@ -26,13 +26,18 @@ def recursively_load_model(LoaderClass,
     try_turn = 0
     while True:
         try:
+
             if config is not None:
+                print(config,checkpoint)
+                print("*"*80)
                 model = LoaderClass.from_pretrained(checkpoint,
                                                     config=config,
                                                     torch_dtype=torch_dtype,
                                                     trust_remote_code=trust_remote_code,
                                                     resume_download=resume_download)
             else:
+                print(checkpoint, kwargs)
+                print("-"*80)
                 model = LoaderClass.from_pretrained(checkpoint,**kwargs)
             return model
         except Exception as e:
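The new print calls are plain debug tracing inside the retry loop. The except branch is truncated in this view, so for orientation here is a hedged sketch of the retry pattern the function appears to implement (the back-off details are assumptions, not the repository's code):

    import time

    def load_with_retries(load_fn, max_tries=3, wait_s=5):
        # Sketch only: retry a flaky load, e.g. an interrupted Hub download.
        try_turn = 0
        while True:
            try:
                return load_fn()
            except Exception:
                try_turn += 1
                if try_turn >= max_tries:
                    raise  # give up after max_tries attempts
                time.sleep(wait_s)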
@@ -182,19 +187,32 @@ def _load_model(self, model_name):
                     trust_remote_code=True).half()
                 # A custom device_map can be passed in to control placement on each card
                 if self.device_map is None:
-                    if 'chatglm' in model_name.lower():
+                    if 'chatglm' in model_name.lower() and "chatglm2" not in model_name.lower():
                         self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
                     elif 'moss' in model_name.lower():
                         self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
+                    elif "chatglm2" in model_name.lower():
+                        from accelerate.utils import get_balanced_memory
+                        max_memory = get_balanced_memory(model,
+                                                         dtype=torch.int8 if self.load_in_8bit else None,
+                                                         low_zero=False,
+                                                         no_split_module_classes=model._no_split_modules)
+                        self.device_map = infer_auto_device_map(model,
+                                                                dtype=torch.float16 if not self.load_in_8bit else torch.int8,
+                                                                max_memory=max_memory,
+                                                                no_split_module_classes=model._no_split_modules)
                     else:
                         # For models other than chatglm and moss, the device map should be inferred
                         # automatically rather than built with chatglm's scheme: other models almost
                         # never define the same layer classes as chatglm or moss, so
                         # chatglm_auto_configure_device_map would be guaranteed to fail (observed
                         # with the bloom model). infer_auto_device_map may balance the load less
                         # evenly, but at least it does not error out.
+                        # print(dir(model))
+                        # print("*"*80)
+                        # model.tie_weights()
                         self.device_map = infer_auto_device_map(model,
                                                                 dtype=torch.int8,
                                                                 no_split_module_classes=model._no_split_modules)
 
                 model = dispatch_model(model, device_map=self.device_map)
             else:
                 model = (
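The chatglm2 branch above is the heart of the commit: get_balanced_memory computes a per-GPU memory budget that spreads the weights evenly over the visible cards, infer_auto_device_map turns that budget into a per-module placement, and dispatch_model applies it. A standalone sketch of the same pattern outside the loader class (the model id and half-precision cast are illustrative assumptions):

    import torch
    from accelerate import dispatch_model, infer_auto_device_map
    from accelerate.utils import get_balanced_memory
    from transformers import AutoModel

    # Load first, then let dispatch_model move the pieces onto the GPUs.
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half()

    # Balance the layers across all visible GPUs instead of filling card 0 first.
    max_memory = get_balanced_memory(model,
                                     dtype=torch.float16,
                                     low_zero=False,
                                     no_split_module_classes=model._no_split_modules)
    device_map = infer_auto_device_map(model,
                                       dtype=torch.float16,
                                       max_memory=max_memory,
                                       no_split_module_classes=model._no_split_modules)
    model = dispatch_model(model, device_map=device_map)  # place each block on its GPU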
