Add YaRN and Dynamic-YaRN RoPE Scaling Methods #30910

Open · wants to merge 7 commits into base: main
@@ -66,13 +66,31 @@ class OpenLlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
Member:
models on the deprecated folder should not be updated :) (let's remove the changes on open_llama)

strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is
`{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float,
"beta_fast": float, "beta_slow": float,"finetuned": bool}`.
Fields:
original_max_position_embeddings (`int`, *optional*, defaults to 2048):
The original maximum sequence length. This is used to scale the RoPE embeddings.
extrapolation_factor (`float`, *optional*, defaults to 1):
Factor to adjust the n-dimensional rotational scaling for extrapolation.
attention_factor (`float`, *optional*, defaults to 1):
Factor to adjust the attention weight scaling mechanism.
beta_fast (`float`, *optional*, defaults to 32):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
beta_slow (`float`, *optional*, defaults to 1):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
finetuned (`bool`, *optional*, defaults to `False`):
[Dynamic-YaRN only] Whether the model has been fine-tuned to the extended context length. If `True`, YaRN scaling is applied from initialization; otherwise the original RoPE frequencies are kept until the cached maximum sequence length is exceeded.
For more details please refer to https://arxiv.org/abs/2309.00071.

Example:

@@ -114,6 +132,14 @@ def __init__(
shared_input_output_embedding=True,
rope_theta=10000.0,
rope_scaling=None,
yarn_rope_scaling={
"original_max_position_embeddings": 2048,
"extrapolation_factor": 1.0,
"attention_factor": 1.0,
"beta_fast": 32.0,
"beta_slow": 1.0,
"finetuned": False,
},
**kwargs,
):
self.vocab_size = vocab_size
@@ -136,6 +162,8 @@ def __init__(
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.yarn_rope_scaling = yarn_rope_scaling
self._yarn_rope_scaling_validation()

super().__init__(
pad_token_id=pad_token_id,
@@ -159,9 +187,61 @@ def _rope_scaling_validation(self):
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation
def _yarn_rope_scaling_validation(self):
"""
Validate the `yarn_rope_scaling` configuration.
"""
if self.yarn_rope_scaling is None:
return

if not isinstance(self.yarn_rope_scaling, dict) or len(self.yarn_rope_scaling) > 6:
raise ValueError(
"`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, "
"`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.yarn_rope_scaling.get("original_max_position_embeddings", None)
extrapolation_factor = self.yarn_rope_scaling.get("extrapolation_factor", None)
attention_factor = self.yarn_rope_scaling.get("attention_factor", None)
beta_fast = self.yarn_rope_scaling.get("beta_fast", None)
beta_slow = self.yarn_rope_scaling.get("beta_slow", None)
finetuned = self.yarn_rope_scaling.get("finetuned", None)

if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if extrapolation_factor is not None and (
not isinstance(extrapolation_factor, float) or extrapolation_factor < 0 or extrapolation_factor > 1
):
raise ValueError(
f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}"
)
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
if finetuned is not None and not isinstance(finetuned, bool):
raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}")

b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
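
As a hedged usage sketch of the configuration and validation above (the import path is inferred from the file layout in this PR, and the numeric values are purely illustrative, not taken from a real checkpoint):

import math

from transformers.models.deprecated.open_llama.configuration_open_llama import OpenLlamaConfig

# Valid: YaRN scaling with the default YaRN hyper-parameters; per the docstring,
# max_position_embeddings is left untouched.
config = OpenLlamaConfig(
    rope_scaling={"type": "yarn", "factor": 4.0},
    yarn_rope_scaling={
        "original_max_position_embeddings": 2048,
        "extrapolation_factor": 1.0,
        "attention_factor": 1.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "finetuned": False,
    },
)

# Invalid: beta_fast may not be smaller than beta_slow, so validation raises ValueError.
try:
    OpenLlamaConfig(
        rope_scaling={"type": "yarn", "factor": 4.0},
        yarn_rope_scaling={"beta_fast": 1.0, "beta_slow": 32.0},
    )
except ValueError as err:
    print(err)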
191 changes: 191 additions, 0 deletions in src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -147,6 +147,165 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaYaRNScalingRotaryEmbedding with Llama->OpenLlama
class OpenLlamaYaRNScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attention_factor=1,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)

self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attention_factor = attention_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

self.yarn(device)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())

# Get positional embeddings based on the current max sequence length
def get_pos_embeddings(self, device):
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb

# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

# Find dimension range bounds based on rotations
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(self, min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

def get_mscale(self, scaling_factor=1):
if scaling_factor <= 1:
return 1.0
return 0.1 * math.log(scaling_factor) + 1.0

def forward(self, x, position_ids=None):
# Difference to the original RoPE: applies a scaling factor computed with
# the YaRN method (NTK-by-Parts + Attn Scaling)
# x: [batch_size, seq_len, head_dim]
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
emb = self.get_pos_embeddings(x.device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype)
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype)
return (
self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype),
self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype),
)

def yarn(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)

low, high = self.find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = (
1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

self.register_buffer("inv_freq", inv_freq)
# Get n-dimensional magnitude scaling corrected for interpolation
self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor)

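For intuition, here is a standalone, hedged sketch of the NTK-by-parts step performed by `yarn()` above, using the default beta values; `dim`, `base` and `scaling_factor` are illustrative choices, not values taken from this PR:

import math

import torch

# Illustrative hyper-parameters (assumed, not from a real checkpoint).
dim, base, scaling_factor = 128, 10000.0, 4.0
beta_fast, beta_slow, original_max_position_embeddings = 32.0, 1.0, 2048

def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Same inverse-dimension formula as in the class above.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = max(math.floor(find_correction_dim(beta_fast, dim, base, original_max_position_embeddings)), 0)
high = min(math.ceil(find_correction_dim(beta_slow, dim, base, original_max_position_embeddings)), dim - 1)

pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs                     # original RoPE frequencies
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # position-interpolated frequencies

# Linear ramp over the dim // 2 frequency dimensions, then blend the two sets.
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
inv_freq_mask = 1.0 - ramp  # 1 keeps extrapolation (high-frequency dims), 0 keeps interpolation (low-frequency dims)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

mscale = 0.1 * math.log(scaling_factor) + 1.0  # with attention_factor == 1, mscale is roughly 1.139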

# Copied from transformers.models.llama.modeling_llama.LlamaDynamicYaRNScalingRotaryEmbedding with Llama->OpenLlama
class OpenLlamaDynamicYaRNScalingRotaryEmbedding(OpenLlamaYaRNScalingRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attention_factor=1,
beta_fast=32,
beta_slow=1,
finetuned=False,
device=None,
):
super().__init__(
dim,
max_position_embeddings,
base,
scaling_factor,
original_max_position_embeddings,
extrapolation_factor,
attention_factor,
beta_fast,
beta_slow,
device,
)

if finetuned:
self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings
self.yarn(device)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.mscale = 1

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())

def forward(self, x, position_ids=None):
# Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded
# x: [batch_size, seq_len, head_dim]
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.scaling_factor = seq_len / self.original_max_position_embeddings
self.yarn(x.device)
emb = self.get_pos_embeddings(x.device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype)
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype)
return (
self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype),
self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype),
)

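As a quick numeric illustration of the dynamic branch above (non-finetuned case, assuming `max_position_embeddings == original_max_position_embeddings == 2048`; the sequence lengths are arbitrary):

import math

original_max_position_embeddings = 2048

for seq_len in (1024, 2048, 4096, 8192):
    if seq_len <= original_max_position_embeddings:
        scaling_factor, mscale = 1.0, 1.0  # plain RoPE, yarn() is not re-run
    else:
        scaling_factor = seq_len / original_max_position_embeddings  # recomputed in forward()
        mscale = 0.1 * math.log(scaling_factor) + 1.0                # attention_factor == 1
    print(f"seq_len={seq_len}: scaling_factor={scaling_factor:.1f}, mscale={mscale:.3f}")

# seq_len=4096 gives scaling_factor=2.0 and mscale of about 1.069; seq_len=8192 gives 4.0 and about 1.139.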

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -238,6 +397,13 @@ def _init_rope(self):
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
# YaRN parameters
original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"]
extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"]
attention_factor = self.config.yarn_rope_scaling["attention_factor"]
beta_fast = self.config.yarn_rope_scaling["beta_fast"]
beta_slow = self.config.yarn_rope_scaling["beta_slow"]
finetuned = self.config.yarn_rope_scaling["finetuned"]
if scaling_type == "linear":
self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
self.head_dim,
@@ -252,6 +418,31 @@ def _init_rope(self):
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "yarn":
self.rotary_emb = OpenLlamaYaRNScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
original_max_position_embeddings=original_max_position_embeddings,
extrapolation_factor=extrapolation_factor,
attention_factor=attention_factor,
beta_fast=beta_fast,
beta_slow=beta_slow,
)
elif scaling_type == "dynamic-yarn":
self.rotary_emb = OpenLlamaDynamicYaRNScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
original_max_position_embeddings=original_max_position_embeddings,
extrapolation_factor=extrapolation_factor,
attention_factor=attention_factor,
beta_fast=beta_fast,
beta_slow=beta_slow,
finetuned=finetuned,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

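For orientation, these are the configuration entries (written here as Python literals; the values are illustrative) that the dispatch above reads: `rope_scaling["type"]` selects the branch and `yarn_rope_scaling` supplies the remaining keyword arguments.

# Selects the OpenLlamaDynamicYaRNScalingRotaryEmbedding branch; "factor" becomes scaling_factor.
rope_scaling = {"type": "dynamic-yarn", "factor": 2.0}
# Remaining keyword arguments read by _init_rope above.
yarn_rope_scaling = {
    "original_max_position_embeddings": 2048,
    "extrapolation_factor": 1.0,
    "attention_factor": 1.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "finetuned": False,
}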