[Model] Add support for IBM Granite Code models (vllm-project#4636)
yikangshen authored and robertgshaw2-neuralmagic committed May 19, 2024
1 parent 9132d19 commit 18355a9
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -58,15 +58,16 @@ def __init__(
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
-            bias=False,
+            bias=bias,
             quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
-                                           bias=False,
+                                           bias=bias,
                                            quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
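
For context: this hunk threads a single bias flag through both MLP projections instead of hard-coding bias=False, since Granite Code checkpoints (unlike vanilla Llama) carry bias terms there. A minimal sketch of the resulting behavior, using plain torch.nn layers rather than vLLM's tensor-parallel ones (the class name and shapes are illustrative, not vLLM code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    # Mirrors the patched LlamaMLP: one `bias` flag toggles biases on
    # both the fused gate/up projection and the down projection.
    def __init__(self, hidden_size: int, intermediate_size: int,
                 bias: bool = False) -> None:
        super().__init__()
        # Fused gate+up projection, analogous to MergedColumnParallelLinear
        # with [intermediate_size] * 2.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size,
                                      bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)  # the "silu" activation path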
@@ -209,6 +210,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
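
The getattr lookup above is what keeps existing Llama configs working: only configs that explicitly set mlp_bias (as Granite Code's do) turn the biases on; everything else falls back to False. A small illustration, with SimpleNamespace standing in for a Hugging Face config object (both configs here are made up):

from types import SimpleNamespace

llama_cfg = SimpleNamespace(hidden_act="silu")                   # no mlp_bias attr
granite_cfg = SimpleNamespace(hidden_act="silu", mlp_bias=True)  # Granite-style

assert getattr(llama_cfg, "mlp_bias", False) is False   # default: behavior unchanged
assert getattr(granite_cfg, "mlp_bias", False) is True  # opt-in: biases enabled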
@@ -348,6 +350,8 @@ def __init__(
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
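
The new tie_word_embeddings branch makes the output projection reuse the input embedding matrix, which Granite Code checkpoints expect. A self-contained sketch of the tying semantics in plain PyTorch (sizes are arbitrary; this is not vLLM's ParallelLMHead):

import torch
import torch.nn as nn

vocab_size, hidden_size = 8, 4
embed_tokens = nn.Embedding(vocab_size, hidden_size)      # weight: (vocab, hidden)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: (vocab, hidden)

# Tie the weights: both modules now point at the same Parameter,
# mirroring `self.lm_head.weight = self.model.embed_tokens.weight`.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()

hidden = torch.randn(1, hidden_size)
logits = lm_head(hidden)                # hidden @ embedding_matrix.T
assert logits.shape == (1, vocab_size)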
