Cohere Support (#457)
Co-authored-by: kwonjihun-theori <jihun@theori.io>
TechxGenus and kwonjihun-theori committed Jun 8, 2024
1 parent 5f3785d commit 5fa02b5
Showing 8 changed files with 266 additions and 5 deletions.
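With the Cohere mappings added below, a Cohere checkpoint (e.g. Command-R) can go through the usual AutoAWQ quantization flow. A minimal sketch, assuming the standard AutoAWQ API; the model path and quantization settings are illustrative, not part of this commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "CohereForAI/c4ai-command-r-v01"  # assumed example checkpoint
quant_path = "command-r-awq"                   # assumed output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and tokenizer, run the AWQ search and quantization, then save.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)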
1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -17,3 +17,4 @@
from .gemma import GemmaAWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .cohere import CohereAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -26,6 +26,7 @@
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"cohere": CohereAWQForCausalLM,
}
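The new "cohere" entry is what lets the auto class pick CohereAWQForCausalLM from a checkpoint's config. Roughly, the lookup works like the sketch below (the function and map names here are illustrative, not AutoAWQ's exact internals):

from transformers import AutoConfig

def resolve_awq_class(model_path: str):
    # config.json for Command-R style checkpoints reports model_type == "cohere",
    # which now resolves to CohereAWQForCausalLM in the map above.
    model_type = AutoConfig.from_pretrained(model_path).model_type
    return AWQ_CAUSAL_LM_MODEL_MAP[model_type]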


1 change: 1 addition & 0 deletions awq/models/base.py
@@ -70,6 +70,7 @@
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
}


128 changes: 128 additions & 0 deletions awq/models/cohere.py
@@ -0,0 +1,128 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import CohereBlock
from awq.modules.fused.model import CohereModel
from transformers.models.cohere.modeling_cohere import (
    CohereDecoderLayer as OldCohereDecoderLayer,
    CohereForCausalLM as OldCohereForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

class CohereAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "CohereDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldCohereForCausalLM):
        fuser = CohereFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldCohereForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldCohereDecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldCohereForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldCohereDecoderLayer, input_feat, module_kwargs
    ):
        layers = []

        # input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                    module.mlp.gate_proj,
                    module.mlp.up_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear out
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers

class CohereFuser:
    def __init__(self, model: OldCohereForCausalLM):
        self.model = model

        self.cohere_blocks: List[Tuple[str, OldCohereDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "CohereDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldCohereDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = module.input_layernorm
            # norm_2 = FasterTransformerRMSNorm(
            #     module.post_attention_layernorm.weight,
            #     module.post_attention_layernorm.variance_epsilon,
            # )
            blocks.append(
                CohereBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    # norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                )
            )

        self.model.model = CohereModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
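A note on the shape check in get_layers_for_scaling above: the attention-output group is only added when the v_proj and o_proj weights have the same shape. With grouped-query attention they do not, and AWQ skips scaling across that boundary (see the linked llm-awq discussion). A small sketch with illustrative dimensions, not taken from any specific Cohere config:

# Illustrative numbers only.
hidden_size, n_heads, n_kv_heads = 8192, 64, 8
head_dim = hidden_size // n_heads                    # 128

v_proj_shape = (n_kv_heads * head_dim, hidden_size)  # (1024, 8192): out_features x in_features
o_proj_shape = (hidden_size, n_heads * head_dim)     # (8192, 8192)

# Shapes differ under GQA, so the o_proj scaling group above is skipped.
print(v_proj_shape == o_proj_shape)                  # False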
3 changes: 3 additions & 0 deletions awq/modules/fused/attn.py
@@ -171,6 +171,9 @@ def __init__(
        self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
        self.is_neox = True

        if kwargs.get("is_neox") is not None:
            self.is_neox = kwargs["is_neox"]

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
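The is_neox override added above exists because Cohere applies rotary position embeddings in the interleaved (GPT-J style) layout rather than the half-split (GPT-NeoX style) layout used by Llama-like models; the new CohereBlock passes is_neox=False. A rough sketch of the two pairings, for intuition only and not the fused kernel's actual implementation:

import torch

def rotate_half_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX / Llama layout: dimension i is paired with dimension i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J / Cohere layout: adjacent dimensions (0,1), (2,3), ... form the pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)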
67 changes: 67 additions & 0 deletions awq/modules/fused/block.py
@@ -132,6 +132,73 @@ def forward(
        return out, None, past_key_value


class CohereBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        # norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        partial_rotary_factor=1.0,
        use_alibi=False,
        head_dim=None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = hidden_size // n_heads

        # To support gemma-7b, its head_dim is separate
        if head_dim:
            self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
            partial_rotary_factor=partial_rotary_factor,
            head_dim=head_dim,
            is_neox=False,
        ).to(dev)
        # self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(norm_out)

        return out, None, past_key_value


class MPTBlock(nn.Module):
    def __init__(
        self,
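Unlike LlamaLikeBlock, the new CohereBlock implements Cohere's parallel residual layout: one shared input layernorm feeds both attention and the MLP, and both outputs are added back onto the residual stream, which is why norm_2 stays commented out. A schematic restatement of the forward above, not the class itself:

def cohere_parallel_block(hidden_states, norm_1, attn, mlp):
    # Shared pre-norm; attention and MLP run on the same normalized input,
    # and both contributions are added to the residual stream.
    norm_out = norm_1(hidden_states)
    return hidden_states + attn(norm_out) + mlp(norm_out)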
67 changes: 63 additions & 4 deletions awq/modules/fused/model.py
@@ -11,6 +11,7 @@
    FalconDecoderLayer,
    LlamaLikeBlock,
    MixtralBlock,
    CohereBlock,
)


@@ -83,11 +84,11 @@ def __init__(self, vocab_size, blocks, embedding, norm):
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @property
    def embed_tokens(self):
        return self.embedding

    @property
    def layers(self):
        return self.blocks
@@ -124,9 +125,67 @@ def forward(
                h,
                mask,
            )
-            h, _, _ = layer(
-                h, None, attention_mask=mask, is_causal=is_causal
+            h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=None,
            hidden_states=(),
            attentions=(),
        )


class CohereModel(nn.Module):
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[CohereBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @property
    def embed_tokens(self):
        return self.embedding

    @property
    def layers(self):
        return self.blocks

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return BaseModelOutputWithPast(
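With CohereModel and CohereBlock in place, CohereAWQForCausalLM.fuse_layers can swap in the fused modules at load time. A minimal usage sketch, assuming the standard AutoAWQ loading API; the checkpoint path is illustrative:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "command-r-awq"  # assumed path to a quantized Cohere checkpoint

# fuse_layers=True triggers the CohereFuser above and builds the fused CohereModel.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))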
3 changes: 2 additions & 1 deletion awq/quantize/scale.py
@@ -7,9 +7,10 @@
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
from transformers.models.cohere.modeling_cohere import CohereLayerNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation

-allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, CohereLayerNorm]
allowed_act_fns = [
    nn.GELU,
    BloomGelu,
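Adding CohereLayerNorm to allowed_norms lets the AWQ search fold its per-channel activation scales into Cohere's input layernorm instead of rejecting the norm type. The folding follows the usual norm-then-linear scheme, roughly as sketched below; this is an approximation of the idea, not AutoAWQ's exact scale_ln_fcs code:

import torch

@torch.no_grad()
def fold_scales_into_norm(norm_weight, linear_weights, scales):
    # Shrink the norm output per channel, and compensate by growing the same
    # input channels in every linear layer that consumes it, so the overall
    # function is unchanged while activations become easier to quantize.
    norm_weight.div_(scales)
    for w in linear_weights:
        w.mul_(scales.view(1, -1))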
