Add YaRN and Dynamic-YaRN RoPE Scaling Methods #30910

Open · wants to merge 4 commits into base: main
1 change: 0 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -297,7 +297,6 @@ def __init__(self, config: FalconConfig):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = FalconRotaryEmbedding(
1 change: 0 additions & 1 deletion src/transformers/models/fuyu/configuration_fuyu.py
@@ -186,7 +186,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -154,7 +154,6 @@ def __init__(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
58 changes: 52 additions & 6 deletions src/transformers/models/llama/configuration_llama.py
@@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For `yarn` and `dynamic-yarn` strategies, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the scaling `factor`, i.e. the `max_position_embeddings / original_max_position_embeddings` ratio.
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -178,15 +187,52 @@ def _rope_scaling_validation(self):
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

if rope_scaling_type not in ["yarn", "dynamic-yarn"]:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)

if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
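
For reference, here is a minimal sketch of a configuration that the validation above would accept. The concrete numbers are illustrative placeholders, not values from the PR, and it assumes this branch is installed:

```python
from transformers import LlamaConfig

# Illustrative: a model trained with a 2048-token context, extended 4x with YaRN.
config = LlamaConfig(
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,                             # must be a float > 1
        "original_max_position_embeddings": 2048,  # must be an int
        "beta_fast": 32.0,                         # floats, with beta_fast >= beta_slow
        "beta_slow": 1.0,
        # "attention_factor" omitted: defaults to 0.1 * ln(4.0) + 1 ≈ 1.139
    },
)
```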
159 changes: 159 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -149,6 +149,140 @@ def forward(self, x, position_ids):
return cos, sin


class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)

self.original_max_position_embeddings = original_max_position_embeddings
self.attention_factor = attention_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

if self.attention_factor is None:
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
Review comment (Member) on lines +172 to +173: Let's add a note saying that this is the default value according to the yarn paper, so it doesn't look like a magic number :)


self.compute_yarn_scaling(device)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())
Review comment (Member) on lines +181 to +182: The logic to build these tensors was removed in a recent PR (#30743) -- we can delete them as well as all related functions/variables (emb, get_pos_embeddings, ...). This comment applies to the other class too :)


# Get positional embeddings based on the current max sequence length
def get_pos_embeddings(self, device):
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb

# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

# Find dimension range bounds based on rotations
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(self, min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

def forward(self, x, position_ids=None):
# Difference to the original RoPE: applies a scaling factor computed with
# the YaRN method (NTK-by-Parts + Attn Scaling)
# x: [bs, num_attention_heads, seq_len, head_size]
cos, sin = super().forward(x, position_ids)
cos = cos * self.mscale
sin = sin * self.mscale
return cos, sin

def compute_yarn_scaling(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)

low, high = self.find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

self.register_buffer("inv_freq", inv_freq)
# Get n-dimensional magnitude scaling corrected for interpolation
self.mscale = self.attention_factor


class LlamaDynamicYarnScalingRotaryEmbedding(LlamaYarnScalingRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(
dim,
max_position_embeddings,
base,
scaling_factor,
original_max_position_embeddings,
attention_factor,
beta_fast,
beta_slow,
device,
)

if self.max_position_embeddings != self.original_max_position_embeddings:
self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings
self.compute_yarn_scaling(device)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.mscale = 1

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())

def forward(self, x, position_ids=None):
# Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded
# x: [bs, num_attention_heads, seq_len, head_size]
seq_len = torch.max(position_ids) + 1
self.scaling_factor = seq_len / self.original_max_position_embeddings
self.compute_yarn_scaling(x.device)

cos, sin = super().forward(x, position_ids)
return cos, sin


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -275,6 +409,15 @@ def _init_rope(self):
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
# Yarn parameters
kwargs = {
"dim": self.config.rope_scaling.get("original_max_position_embeddings", None),
"max_position_embeddings": self.config.rope_scaling.get("attention_factor", None),
"base": self.config.rope_scaling.get("beta_fast", None),
"scaling_factor": self.config.rope_scaling.get("beta_slow", None),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
@@ -289,6 +432,22 @@ def _init_rope(self):
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
elif scaling_type == "dynamic-yarn":
self.rotary_emb = LlamaDynamicYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

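As a standalone illustration of what `compute_yarn_scaling` above does (NTK-by-Parts: keep the original frequencies for high-frequency dimensions, interpolate the low-frequency ones, and blend the two regimes with a linear ramp), the following self-contained sketch mirrors the math in the diff. It is not part of the PR; names and values are illustrative:

```python
import math

import torch

dim, base = 128, 10000.0
scaling_factor = 4.0                      # extended_len / original_len
original_max_position_embeddings = 2048
beta_fast, beta_slow = 32, 1

def find_correction_dim(num_rotations, dim, base, max_pos):
    # Inverse RoPE wavelength formula: dimension index whose frequency completes
    # `num_rotations` full rotations over `max_pos` positions.
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = max(math.floor(find_correction_dim(beta_fast, dim, base, original_max_position_embeddings)), 0)
high = min(math.ceil(find_correction_dim(beta_slow, dim, base, original_max_position_embeddings)), dim - 1)

pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs                      # original RoPE frequencies
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)   # position-interpolated frequencies

# Ramp is 0 for dimensions below `low` and 1 above `high`; after applying the mask,
# low (high-frequency) dims keep the original frequencies, high (low-frequency) dims
# get the interpolated ones, and the range in between is blended linearly.
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
inv_freq_mask = 1 - ramp
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

# Default attention factor ("mscale") from the YaRN paper.
attention_factor = 0.1 * math.log(scaling_factor) + 1.0
```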
1 change: 0 additions & 1 deletion src/transformers/models/olmo/configuration_olmo.py
@@ -160,7 +160,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -250,7 +250,6 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = OlmoRotaryEmbedding(
@@ -138,7 +138,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/phi/configuration_phi.py
@@ -165,7 +165,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/stablelm/configuration_stablelm.py
@@ -164,7 +164,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
42 changes: 41 additions & 1 deletion tests/models/llama/test_modeling_llama.py
@@ -53,8 +53,10 @@
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaDynamicYarnScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)


@@ -397,7 +399,7 @@ def test_llama_token_classification_model(self):
def test_save_load_fast_init_from_base(self):
pass

@parameterized.expand([("linear",), ("dynamic",)])
@parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +493,44 @@ def test_model_rope_scaling(self):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())

# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)

# Sanity check Dynamic Yarn RoPE scaling
dynamic_yarn_scaling_rope = LlamaDynamicYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, position_ids_short)
dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, position_ids_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
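For intuition on the dynamic variant exercised in the test above: in this diff, `LlamaDynamicYarnScalingRotaryEmbedding.forward` recomputes `scaling_factor = seq_len / original_max_position_embeddings` on every call and rebuilds the inverse frequencies from it, while the attention factor keeps the value computed at construction time. A small illustrative sketch (values are placeholders, not from the PR):

```python
original_max_position_embeddings = 2048

# How the dynamic scaling factor evolves as the observed sequence length grows.
for seq_len in (2048, 4096, 8192, 16384):
    scaling_factor = seq_len / original_max_position_embeddings
    print(f"seq_len={seq_len:>5}  scaling_factor={scaling_factor:.1f}")
# seq_len= 2048  scaling_factor=1.0
# seq_len= 4096  scaling_factor=2.0
# seq_len= 8192  scaling_factor=4.0
# seq_len=16384  scaling_factor=8.0
```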
2 changes: 1 addition & 1 deletion tests/models/stablelm/test_modeling_stablelm.py
@@ -52,7 +52,7 @@
)


# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
# Copied from transformers.tests.models.StableLm.test_modeling_StableLm.StableLmModelTester with StableLm -> StableLm
class StableLmModelTester:
# Ignore copy
def __init__(