Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RoPE loses precision for Llama / Gemma + Gemma logits.float() #29285

Merged
merged 20 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ def forward(self, x, position_ids, seq_len=None):

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

# Force float32 since bfloat16 loses precision on long contexts
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved
with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False):
danielhanchen marked this conversation as resolved.
Show resolved Hide resolved
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved

emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)

Expand Down Expand Up @@ -1079,7 +1083,8 @@ def forward(

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)

logits = logits.float()
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved

loss = None
if labels is not None:
# Shift so that tokens < n predict n
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
# Force float32 since bfloat16 loses precision on long contexts
with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False):
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved
danielhanchen marked this conversation as resolved.
Show resolved Hide resolved
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=x.dtype)
sin = emb.sin().to(dtype=x.dtype)
Expand Down