[Inference] Add Bloom model inference support #5660
base: feature/colossal-infer
@@ -28,6 +28,8 @@
    "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
    "baichuan": "<reserved_106>{input_text}<reserved_107>",
    "vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
    "bloom": "Assume you are a helpful robot. Please help react to my question or auto complete my prompt."
The "bloom" template is missing the {input_text} placeholder.
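A minimal sketch of why the placeholder matters, assuming the templates are filled with `str.format`-style substitution. The names `PROMPT_TEMPLATES`, `build_prompt`, and `has_placeholder` are hypothetical and not taken from the PR, and the "fixed" template is only one possible correction:

```python
# Hypothetical names; only the template strings mirror the diff above.
BLOOM_TEXT = (
    "Assume you are a helpful robot. "
    "Please help react to my question or auto complete my prompt."
)

PROMPT_TEMPLATES = {
    "baichuan": "<reserved_106>{input_text}<reserved_107>",
    # Template as submitted: no placeholder, so the user input never appears.
    "bloom_as_submitted": BLOOM_TEXT,
    # One possible fix (an assumption): append the placeholder.
    "bloom_fixed": BLOOM_TEXT + "\n{input_text}",
}


def has_placeholder(template: str) -> bool:
    """Return True if the template can carry the user's text."""
    return "{input_text}" in template


def build_prompt(model_type: str, input_text: str) -> str:
    """Fill the template for model_type with the user's text."""
    # str.format ignores unused keyword arguments, so a template without
    # the placeholder drops the input silently instead of raising.
    return PROMPT_TEMPLATES[model_type].format(input_text=input_text)
```

Because the failure is silent, a check such as `has_placeholder` over every template could catch this kind of omission early.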
def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
    if hasattr(config, attr_name):
        return getattr(config, attr_name)
    if alter_attr is not None:
        return alter_attr
    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
        return getattr(config, config.attribute_map[attr_name])
    raise AttributeError(f"{attr_name} is not found in config")
As we discussed, please revise this part: remove the function and refer to the attribute_map usage in transformers 4.36.
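A toy sketch of the idea behind this suggestion, assuming the config class itself resolves aliased names through its `attribute_map`, so callers can use plain attribute access without a helper. `ToyConfig` is a hypothetical stand-in written for illustration; it is not the transformers implementation, and the alias names are assumptions:

```python
class ToyConfig:
    """Hypothetical stand-in for a model config; NOT the transformers class."""

    # Maps a common attribute name to the model-specific name that stores it.
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(self, n_layer: int, n_head: int, hidden_size: int):
        self.n_layer = n_layer
        self.n_head = n_head
        self.hidden_size = hidden_size

    def __getattribute__(self, key):
        # Redirect aliased names to the attribute that actually holds the value.
        attr_map = object.__getattribute__(self, "attribute_map")
        if key != "attribute_map" and key in attr_map:
            key = attr_map[key]
        return object.__getattribute__(self, key)


config = ToyConfig(n_layer=24, n_head=16, hidden_size=1024)

# Both the alias and the model-specific name resolve to the same value,
# so no separate lookup helper is needed at the call site.
num_layers = config.num_hidden_layers
num_heads = config.num_attention_heads
```

With aliasing handled inside the config, a missing attribute still raises `AttributeError` in the usual way.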
-import transformers
-from packaging.version import Version
+# import transformers
+# from packaging.version import Version

-assert Version(transformers.__version__) <= Version(
-    "4.33.0"
-), "The Bloom model should run on a transformers version not greater than 4.33.0."
+# assert Version(transformers.__version__) <= Version(
+#     "4.33.0"
+# ), "The Bloom model should run on a transformers version not greater than 4.33.0."
You can simply remove these commented-out lines. For your reference, the assertion has already been removed in main.
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

MODEL_PATH = "/home/lixingjian/models/bloom-560m"
Avoid exposing a path from your dev machine.
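One way to avoid the hard-coded path, sketched under assumptions: read the location from an environment variable and fall back to the public Hub ID of the same checkpoint. The variable name `BLOOM_MODEL_PATH` and the helper `resolve_model_path` are hypothetical, not part of the PR:

```python
import os

# Public Hub ID for the checkpoint used in the test.
DEFAULT_MODEL = "bigscience/bloom-560m"


def resolve_model_path(env_var: str = "BLOOM_MODEL_PATH") -> str:
    """Return a local model path from the environment, else the Hub ID.

    BLOOM_MODEL_PATH is a hypothetical variable name chosen for this sketch.
    An unset or empty variable falls back to DEFAULT_MODEL.
    """
    return os.environ.get(env_var) or DEFAULT_MODEL


MODEL_PATH = resolve_model_path()
```

A developer with a local copy can then export the variable, while the committed test contains no machine-specific path.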
-from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
+from transformers import BloomForCausalLM, BloomTokenizerFast, GenerationConfig
It would be better not to modify this pytest. You might want to test the model locally, or create a pytest in test_models/ (we'll refactor these tests anyway).
@@ -1,111 +0,0 @@
-import os
This commit removes tests/test_infer/test_models/test_baichuan.py.
# class NopadBloomAttention(nn.Module):
#     def __init__(
#         self,
#         hidden_size: int,
#         n_heads: int,
#         attn_qproj_w: torch.Tensor = None,
#         attn_kproj_w: torch.Tensor = None,
#         attn_vproj_w: torch.Tensor = None,
#         attn_oproj_w: torch.Tensor = None,
#     ):
As the end-to-end flow for Bloom is ready and none of the attention and MLP classes is going to be used, you might want to remove these classes.
# return attn_output


class NopadBloomMLP(nn.Module):
Same as above.
# print(f"[TEST] hidden_state {hidden_states} with shape {hidden_states.shape}\n qkv_weight {self.qkv_weight} with shape {self.qkv_weight.shape}")

# print(f"[DEBUG] after qkv: query_states {query_states} with shape {query_states.shape}, \nkey_states {key_states},\n value_states {value_states}")
Remember to remove commented-out code and testing messages before merging.
# print(f"[TEST] before merge bsz, query_states {query_states} with shape {query_states.shape}, \nkey_states {key_states},\n value_states {value_states}")

# [bsz * seq_len, num_heads head_dim]
Same as above.
Signed-off-by: char-1ee <xingjianli59@gmail.com>
for more information, see https://pre-commit.ci
d911664 to d36c173