Quantization for FSDPA #976
hsubramony committed May 28, 2024
1 parent 570cfa1 commit 87386b7
Showing 2 changed files with 151 additions and 3 deletions.
14 changes: 11 additions & 3 deletions examples/text-generation/README.md
@@ -265,7 +265,10 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python ../gaudi_spawn.py
 --use_hpu_graphs \
 --trim_logits \
 --use_kv_cache \
---reuse_cache \
+--bucket_size=128 \
+--bucket_internal \
+--use_flash_attention \
+--flash_attention_recompute \
 --bf16 \
 --batch_size 1
 ```
@@ -280,7 +283,10 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
 --use_hpu_graphs \
 --trim_logits \
 --use_kv_cache \
---reuse_cache \
+--bucket_size=128 \
+--bucket_internal \
+--use_flash_attention \
+--flash_attention_recompute \
 --bf16 \
 --batch_size 1 \
 --fp8
@@ -296,8 +302,10 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
 --trim_logits \
 --use_kv_cache \
 --reuse_cache \
+--use_flash_attention \
+--flash_attention_recompute \
 --bf16 \
---batch_size 277 \
+--batch_size 350 \
 --max_new_tokens 2048 \
 --max_input_tokens 2048 \
 --limit_hpu_graphs \
140 changes: 140 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -72,6 +72,134 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states):
        return self.weight * hidden_states.to(input_dtype)


class GaudiLlamaMLP(LlamaMLP):
    def pre_mlp_forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            output = sum(down_proj)
        else:
            input = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
            output = self.down_proj(input)
        return output

    def mlp_all_reduce(self, x):
        if hasattr(self.down_proj, "all_reduce"):
            self.down_proj.all_reduce(x)

    def post_mlp_forward(self, x):
        if self.config.pretraining_tp > 1:
            return x
        if hasattr(self.down_proj, "post_all_reduce"):
            return self.down_proj.post_all_reduce(x)
        return x
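    # Illustrative composition of the three phases above (a sketch; the class's forward() is not shown in
    # this excerpt, so the exact wiring is an assumption):
    #     def forward(self, x):
    #         output = self.pre_mlp_forward(x)
    #         self.mlp_all_reduce(output)
    #         return self.post_mlp_forward(output)
    # Splitting the compute (pre_mlp_forward), the tensor-parallel all_reduce on down_proj, and the
    # post-all-reduce step presumably lets callers schedule the collective separately from the compute.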


def gaudi_llama_repeat_kv(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    n_rep: int,
):
    """
    Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
    - Add a num_key_value_heads == 1 check: such kv states can be broadcast during matmuls, so there is no need to expand and reshape them.
    - Add new args query_states, key_states, value_states and attention_mask, and update the expansion logic.
    The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
    The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
    """
    batch, num_key_value_heads, kv_len, head_dim = key_states.shape
    if n_rep == 1 or num_key_value_heads == 1:
        return query_states, key_states, value_states, attention_mask

    new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
    key_states = key_states.reshape(new_kv_shape)
    value_states = value_states.reshape(new_kv_shape)

    batch, _, q_len, head_dim = query_states.shape
    new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
    query_states = query_states.reshape(new_q_shape)

    if attention_mask is not None:
        # Add groups dim and set to 1
        attention_mask = attention_mask.unsqueeze(1)

    return query_states, key_states, value_states, attention_mask
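

# Shape sketch for gaudi_llama_repeat_kv (illustrative helper, not part of the original modeling code):
# with 8 query heads grouped over 4 kv heads (n_rep=2), the queries are regrouped per kv head while the
# key/value states only gain a broadcastable group dimension of size 1, so the attention matmuls can
# broadcast instead of materializing repeated kv tensors.
def _repeat_kv_shape_example():
    query = torch.randn(2, 8, 16, 64)  # (batch, num_heads, q_len, head_dim)
    key = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, kv_len, head_dim)
    value = torch.randn(2, 4, 16, 64)
    mask = torch.zeros(2, 1, 16, 16)  # (batch, 1, q_len, kv_len)
    q, k, v, m = gaudi_llama_repeat_kv(query, key, value, mask, n_rep=2)
    assert q.shape == (2, 4, 2, 16, 64)  # (batch, num_key_value_heads, n_rep, q_len, head_dim)
    assert k.shape == v.shape == (2, 4, 1, 16, 64)
    assert m.shape == (2, 1, 1, 16, 16)  # group dim of size 1 added to the mask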


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
    def __init__(self, fusedSDPA):
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
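
# How this wrapper is expected to be wired up (a sketch; the attention __init__ is not shown in this
# excerpt, so treat the import path and the gating as assumptions):
#     from habana_frameworks.torch.hpex.kernels import FusedSDPA
#     self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
# Routing the fused kernel call through an nn.Module means it can be discovered and patched (e.g. by FP8
# quantization tooling) like any other submodule.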


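# Thin nn.Module wrapper around torch.matmul, presumably so the attention matmuls can be observed and
# swapped out (e.g. by quantization tooling) like any other submodule; this rationale is an assumption.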
class Matmul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)


class KVCache(torch.nn.Module):
    def __init__(self):
        super(KVCache, self).__init__()
        self.cache = None
        self.inp_seq_len = -1

    def allocate(self, inp_seq_len, dtype, device, shape):
        if self.cache is None or self.cache.shape != shape:
            self.inp_seq_len = inp_seq_len
            self.cache = torch.zeros(shape, dtype=dtype, device=device)
        else:
            assert (
                self.inp_seq_len == inp_seq_len
            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
            self.cache.fill_(0)

    def update(self, prev, cur, dim, idx, inp_seq_len):
        orig_cur = cur
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return orig_cur
        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
            # Initialize
            prev[:, :, :inp_seq_len, :].copy_(cur)
            return orig_cur
        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
        if idx is not None:
            prev.index_copy_(dim, idx - 1, cur)
            return prev
        else:
            return torch.cat((prev, cur), dim=dim)

    def get_shape(self):
        if self.cache is None:
            return None
        return self.cache.shape

    def forward(self, cur, dim, idx):
        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
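

# Illustrative usage of KVCache (a sketch, not part of the original modeling code): allocate a fixed-shape
# zero cache once (static shapes suit HPU graphs), copy the prompt in during prefill, then write one
# position per decode step via index_copy_. The "cpu" device is only for this sketch; on Gaudi the cache
# would live on "hpu".
def _kv_cache_usage_example():
    cache = KVCache()
    # (batch, num_key_value_heads, max_seq_len, head_dim) with a 16-token prompt
    cache.allocate(inp_seq_len=16, dtype=torch.bfloat16, device="cpu", shape=(1, 4, 128, 64))
    prompt_keys = torch.randn(1, 4, 16, 64, dtype=torch.bfloat16)
    cache(prompt_keys, dim=2, idx=None)  # prefill: copies into cache[:, :, :16, :]
    step_key = torch.randn(1, 4, 1, 64, dtype=torch.bfloat16)
    token_idx = torch.tensor([17])  # update() writes at token_idx - 1, i.e. slot 16
    cache(step_key, dim=2, idx=token_idx)  # decode: single-position index_copy_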


class GaudiLlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
@@ -442,20 +570,32 @@ def pre_attn_forward(
                # Next-token (decode) path: recompute is enabled when a quantization config (QUANT_CONFIG) is set.
                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
                with ht.sdp_kernel(enable_recompute=use_recompute):
                    attn_output = self.fused_scaled_dot_product_attention(
                        query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
                    )
            else:
                # first token
                if flash_attention_causal_mask:
                    # causal masking on first token requires inputs to be of the same length
                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                        attn_output = self.fused_scaled_dot_product_attention(
                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
                        )
                else:
                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                        attn_output = self.fused_scaled_dot_product_attention(
                            query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
                        )

        else:
