Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Add Bloom model inference support #5660

Draft
wants to merge 8 commits into
base: feature/colossal-infer
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"baichuan": " <reserved_106> {input_text} <reserved_107> ",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
"bloom": "Assume you are a helpful robot. Please help react to my question or auto complete my prompt."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing {input_text}

# "bloom": "[INST] <<SYS>>\nYou are an intelligent and comprehensive assistant. Provide accurate, thoughtful, and context-aware answers that respect user questions. Avoid content that is harmful, misleading, or unethical. Prioritize safety and fairness in all responses. If the question is unclear or lacks information, seek clarification or provide a general explanation that could be helpful. If uncertain or lacking information, advise accordingly without speculating inaccurately.\n<</SYS>>\n{input_text}[/INST]",
}


Expand Down
3 changes: 3 additions & 0 deletions colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
Expand All @@ -39,8 +40,10 @@
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
}


_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


Expand Down
21 changes: 11 additions & 10 deletions colossalai/inference/kv_cache/kvcache_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Any, List, Tuple

import torch
from transformers.configuration_utils import PretrainedConfig
Expand All @@ -15,9 +15,11 @@
GIGABYTE = 1024**3


def get_model_config_attr(config: PretrainedConfig, attr_name: str):
def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
if hasattr(config, attr_name):
return getattr(config, attr_name)
if alter_attr is not None:
return alter_attr
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
return getattr(config, config.attribute_map[attr_name])
raise AttributeError(f"{attr_name} is not found in config")
Comment on lines +18 to 25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed, revise this part (remove the function and refer to attr_mapusage in transformers 4.36

Expand Down Expand Up @@ -53,7 +55,12 @@ class KVCacheManager:
And it's possible to have a batch of sequences with different lengths of block tables.
"""

def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
def __init__(
self,
config: InferenceConfig,
model_config: PretrainedConfig,
verbose: bool = False,
) -> None:
self.logger = get_dist_logger(__name__)
self.device = get_current_device()

Expand All @@ -64,15 +71,9 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num

if hasattr(config, "num_key_value_heads"):
self.kv_head_num = getattr(config, "num_key_value_heads")
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
else:
self.kv_head_num = self.head_num

assert (
self.kv_head_num % self.tp_size == 0
), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
Expand Down
19 changes: 1 addition & 18 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
Expand Down Expand Up @@ -47,22 +46,6 @@
logger = get_dist_logger(__name__)


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes


def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,
Expand Down