
Commit

Fix RoPE for llama
danielhanchen committed Feb 28, 2024
1 parent 1a50a4b commit b860a22
Showing 1 changed file with 14 additions and 3 deletions.
src/transformers/models/llama/modeling_llama.py (17 changes: 14 additions & 3 deletions)
@@ -126,16 +126,27 @@ def cos_cached(self):
         )
         return self._cos_cached

+    @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         if seq_len is not None:
             logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        with torch.autocast(device_type=self.inv_freq.device.type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        cos = cos.to(dtype=x.dtype)
+        sin = sin.to(dtype=x.dtype)
+        # backwards compatibility
+        self._cos_cached = cos
+        self._sin_cached = sin
+        return cos, sin


 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
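The added comment attributes the fix to bfloat16 losing precision on long contexts. Below is a rough, hypothetical sketch of that failure mode (not code from this commit; the head dimension of 128 and base of 10000 are assumed example values matching common Llama defaults): bfloat16 keeps only 8 significand bits, so integer position ids above 256 can no longer all be represented exactly, and neighbouring positions collapse onto the same value, which would produce identical rotary angles.

    import torch

    # Hypothetical illustration (not from the commit): precision loss when
    # position ids are handled in bfloat16 instead of float32.
    positions = torch.arange(4096, dtype=torch.float32)
    positions_bf16 = positions.to(torch.bfloat16).to(torch.float32)

    # Count consecutive positions that become indistinguishable in bfloat16.
    collisions = (positions_bf16[1:] == positions_bf16[:-1]).sum().item()
    print(f"positions that collide with a neighbour in bfloat16: {collisions}")

    # Example inv_freq (head dim 128 and base 10000 are assumptions) and the
    # resulting error in the rotary angles when positions pass through bfloat16.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
    freqs_fp32 = torch.outer(positions, inv_freq)
    freqs_bf16 = torch.outer(positions_bf16, inv_freq)
    print("max angle error from bfloat16 positions:", (freqs_fp32 - freqs_bf16).abs().max().item())

Disabling autocast around the matmul, as the diff does, keeps the angle computation in float32 even when the surrounding model runs under bfloat16 or float16 autocast; the result is only cast back to x.dtype at the end.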
