RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285)
* Update modeling_llama.py

Llama - Force float32 since bfloat16 loses precision on long contexts

* Update modeling_llama.py

* Update modeling_gemma.py

Fix RoPE and logits.float()

* @torch.no_grad()

* @torch.no_grad()

* Cos, Sin to float32

* cos, sin to float32

* Update src/transformers/models/gemma/modeling_gemma.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Resolve PR conflicts

* Fix RoPE for llama

* Revert "Fix RoPE for llama"

This reverts commit b860a22.

* Fix RoPE for llama

* RoPE device

* Autocast device type

* RoPE

* RoPE isinstance

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
danielhanchen and ArthurZucker committed Feb 28, 2024
1 parent 7628b3a commit d3a4b47
Showing 2 changed files with 23 additions and 8 deletions.
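For context, the failure mode the commit title refers to: the rotary angles are position_id * inv_freq, and bfloat16 carries only about 8 bits of mantissa, so large position ids are rounded so coarsely that distinct positions in a long context can collapse onto the same angle. A minimal standalone sketch, not part of this commit, that reproduces the effect (the head dimension and positions are illustrative):

import torch

# inv_freq for a 128-dim head with base 10000, as in Llama/Gemma RoPE
dim = 128
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Two distinct positions deep into a long context
positions = torch.tensor([8190.0, 8191.0])

# Angles computed in float32 vs. bfloat16 (bfloat16 rounds both positions to 8192)
angles_fp32 = positions[:, None] * inv_freq[None, :]
angles_bf16 = (positions.bfloat16()[:, None] * inv_freq.bfloat16()[None, :]).float()

print(torch.allclose(angles_fp32[0], angles_fp32[1]))       # False: positions stay distinct
print(torch.allclose(angles_bf16[0], angles_bf16[1]))       # True: the two positions collapse
print((angles_fp32.cos() - angles_bf16.cos()).abs().max())  # visible error in the embedding

Keeping the inv_freq @ position_ids matmul and the cos/sin in float32, and only casting the result back to x.dtype at the end, avoids this while leaving the rest of the model in half precision; that is exactly what the two diffs below do.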
17 changes: 12 additions & 5 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -101,18 +101,25 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         self.base = base
         self.register_buffer("inv_freq", None, persistent=False)
 
+    @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if self.inv_freq is None:
             self.inv_freq = 1.0 / (
                 self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
             )
-
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -1082,7 +1089,7 @@ def forward(
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-
+        logits = logits.float()
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
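A side note on the `logits = logits.float()` line above: Llama's CausalLM head already upcasts its logits before computing the loss, and this change appears to bring Gemma in line with it, so the log-softmax inside the cross-entropy runs in float32 rather than in the model's half-precision dtype. A rough sketch of the difference, not part of the commit (sizes are illustrative):

import torch
import torch.nn.functional as F

vocab_size, seq_len = 256_000, 8  # Gemma-sized vocabulary, tiny illustrative sequence
logits_bf16 = torch.randn(seq_len, vocab_size, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (seq_len,))

# Upcasting first keeps the log-softmax reduction over 256k classes in float32,
# mirroring `logits = logits.float()` in the diff above.
loss_fp32 = F.cross_entropy(logits_bf16.float(), labels)
loss_bf16 = F.cross_entropy(logits_bf16, labels)
print(loss_fp32.item(), loss_bf16.item())  # the two values differ slightly due to dtype

Over a 256k-entry vocabulary, the reduction inside log-softmax is exactly the kind of long sum where bfloat16 rounding shows up.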
14 changes: 11 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
@@ -126,16 +126,24 @@ def cos_cached(self):
         )
         return self._cos_cached
 
+    @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         if seq_len is not None:
             logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
 
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
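One last note on the pattern shared by both files: wrapping the matmul in `torch.autocast(device_type=..., enabled=False)` matters because callers may run the model under mixed-precision autocast, which would otherwise cast the freshly upcast operands straight back down to float16/bfloat16. A minimal sketch of that behaviour in isolation, not taken from the commit (CPU and bfloat16 are chosen only so it runs anywhere):

import torch

x = torch.randn(4, 64)
w = torch.randn(64, 64)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_autocast = x @ w  # autocast casts the matmul inputs to bfloat16
    with torch.autocast(device_type="cpu", enabled=False):
        # inside the disabled region, ops run in the dtype of their inputs,
        # hence the explicit .float() casts in the patched RoPE forward
        y_full = x.float() @ w.float()

print(y_autocast.dtype)  # torch.bfloat16
print(y_full.dtype)      # torch.float32

The `isinstance(device_type, str)` guard in the patch simply makes sure autocast receives a plain string device type, falling back to "cpu" otherwise.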
