From cc9b82e916edbd898a8e690a430552b7f4ad99b5 Mon Sep 17 00:00:00 2001 From: Miguel Monte e Freitas Date: Sun, 19 May 2024 22:02:56 +0100 Subject: [PATCH 1/2] Add YaRN and Dynamic-YaRN RoPE Scaling Methods YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments. We implement YaRN and Dynamic-YaRN for the following list of models: - LLaMA - Falcon - GPT-NeoX - Olmo - Persimmon - Phi - StableLM - OpenLLaMA New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071. Co-authored-by: Miguel Almeida --- .../open_llama/configuration_open_llama.py | 88 +++++++- .../open_llama/modeling_open_llama.py | 191 ++++++++++++++++++ .../models/falcon/configuration_falcon.py | 84 +++++++- .../models/falcon/modeling_falcon.py | 188 +++++++++++++++++ .../models/fuyu/configuration_fuyu.py | 4 +- .../models/gpt_neox/configuration_gpt_neox.py | 88 +++++++- .../models/gpt_neox/modeling_gpt_neox.py | 190 +++++++++++++++++ .../models/llama/configuration_llama.py | 87 +++++++- .../models/llama/modeling_llama.py | 189 +++++++++++++++++ .../models/olmo/configuration_olmo.py | 88 +++++++- src/transformers/models/olmo/modeling_olmo.py | 191 ++++++++++++++++++ .../persimmon/configuration_persimmon.py | 84 +++++++- .../models/persimmon/modeling_persimmon.py | 190 +++++++++++++++++ .../models/phi/configuration_phi.py | 84 +++++++- src/transformers/models/phi/modeling_phi.py | 190 +++++++++++++++++ .../models/stablelm/configuration_stablelm.py | 84 +++++++- .../models/stablelm/modeling_stablelm.py | 190 +++++++++++++++++ tests/models/falcon/test_modeling_falcon.py | 38 +++- .../models/gpt_neox/test_modeling_gpt_neox.py | 40 +++- tests/models/llama/test_modeling_llama.py | 38 +++- tests/models/olmo/test_modeling_olmo.py | 4 +- .../persimmon/test_modeling_persimmon.py | 40 +++- tests/models/phi/test_modeling_phi.py | 40 +++- .../models/stablelm/test_modeling_stablelm.py | 42 +++- 24 files changed, 2412 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py index ae2add5a5f29a..192dfa2eea728 100644 --- a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py @@ -66,13 +66,31 @@ class OpenLlamaConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling + strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. 
See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. + yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): + Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is + `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, + "beta_fast": float, "beta_slow": float,"finetuned": bool}`. + Fields: + original_max_position_embeddings (`int`, *optional*, defaults to 2048): + The original maximum sequence length. This is used to scale the RoPE embeddings. + extrapolation_factor (`float`, defaults to 1): + Factor to ajust the n-dimensional rotational scaling for extrapolation. + attention_factor (`float`, *optional*, defaults to 1): + Factor to adjust the weight attention scaling mechanism. + beta_fast (`float`, *optional*, defaults to 32): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + beta_slow (`float`, *optional*, defaults to 1): + Parameter to set the boundary for interpolation (only) in the linear ramp function. + finetuned (`bool`, *optional*, defaults to `False`): + [Dynamic] Whether the model is finetuned or not. + For more details please refer to https://arxiv.org/abs/2309.00071. Example: @@ -114,6 +132,14 @@ def __init__( shared_input_output_embedding=True, rope_theta=10000.0, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, **kwargs, ): self.vocab_size = vocab_size @@ -136,6 +162,8 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() + self.yarn_rope_scaling = yarn_rope_scaling + self._yarn_rope_scaling_validation() super().__init__( pad_token_id=pad_token_id, @@ -159,9 +187,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 098f8c7da50d5..9b12797568bfd 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -147,6 +147,165 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +# Copied from transformers.models.llama.modeling_llama.LlamaYaRNScalingRotaryEmbedding with Llama->OpenLlama +class OpenLlamaYaRNScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device, scaling_factor) + + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + 
self.yarn(device) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, position_ids=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicYaRNScalingRotaryEmbedding with Llama->OpenLlama +class OpenLlamaDynamicYaRNScalingRotaryEmbedding(OpenLlamaYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + 
attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + def forward(self, x, position_ids=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -238,6 +397,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( self.head_dim, @@ -252,6 +418,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = OpenLlamaYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = OpenLlamaDynamicYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/falcon/configuration_falcon.py 
b/src/transformers/models/falcon/configuration_falcon.py index 0dd61047dd275..25faf60fdf1f2 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -77,13 +77,31 @@ class FalconConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling + strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. + yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): + Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is + `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, + "beta_fast": float, "beta_slow": float,"finetuned": bool}`. + Fields: + original_max_position_embeddings (`int`, *optional*, defaults to 2048): + The original maximum sequence length. This is used to scale the RoPE embeddings. + extrapolation_factor (`float`, defaults to 1): + Factor to ajust the n-dimensional rotational scaling for extrapolation. + attention_factor (`float`, *optional*, defaults to 1): + Factor to adjust the weight attention scaling mechanism. + beta_fast (`float`, *optional*, defaults to 32): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + beta_slow (`float`, *optional*, defaults to 1): + Parameter to set the boundary for interpolation (only) in the linear ramp function. + finetuned (`bool`, *optional*, defaults to `False`): + [Dynamic] Whether the model is finetuned or not. + For more details please refer to https://arxiv.org/abs/2309.00071. bos_token_id (`int`, *optional*, defaults to 11): The id of the "beginning-of-sequence" token. 
         eos_token_id (`int`, *optional*, defaults to 11):
@@ -133,6 +151,14 @@ def __init__(
         max_position_embeddings=2048,
         rope_theta=10000.0,
         rope_scaling=None,
+        yarn_rope_scaling={
+            "original_max_position_embeddings": 2048,
+            "extrapolation_factor": 1.0,
+            "attention_factor": 1.0,
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
+            "finetuned": False,
+        },
         bos_token_id=11,
         eos_token_id=11,
         ffn_hidden_size=None,
@@ -162,12 +188,14 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self.yarn_rope_scaling = yarn_rope_scaling
         self.activation = activation
         if ffn_hidden_size is None:
             self.ffn_hidden_size = hidden_size * 4
         else:
             self.ffn_hidden_size = ffn_hidden_size
         self._rope_scaling_validation()
+        self._yarn_rope_scaling_validation()
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -201,3 +229,55 @@ def _rope_scaling_validation(self):
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
+    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation
+    def _yarn_rope_scaling_validation(self):
+        """
+        Validate the `yarn_rope_scaling` configuration.
+        """
+        if self.yarn_rope_scaling is None:
+            return
+
+        if not isinstance(self.yarn_rope_scaling, dict) or len(self.yarn_rope_scaling) > 6:
+            raise ValueError(
+                "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, "
+                "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, "
+                f"got {self.yarn_rope_scaling}"
+            )
+        original_max_position_embeddings = self.yarn_rope_scaling.get("original_max_position_embeddings", None)
+        extrapolation_factor = self.yarn_rope_scaling.get("extrapolation_factor", None)
+        attention_factor = self.yarn_rope_scaling.get("attention_factor", None)
+        beta_fast = self.yarn_rope_scaling.get("beta_fast", None)
+        beta_slow = self.yarn_rope_scaling.get("beta_slow", None)
+        finetuned = self.yarn_rope_scaling.get("finetuned", None)
+
+        if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
+            raise ValueError(
+                f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
+            )
+        if extrapolation_factor is not None and (
+            not isinstance(extrapolation_factor, float) or extrapolation_factor < 0 or extrapolation_factor > 1
+        ):
+            raise ValueError(
+                f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}"
+            )
+        if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+            raise ValueError(
+                f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+            )
+        if beta_fast is not None and not isinstance(beta_fast, float):
+            raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        if beta_slow is not None and not isinstance(beta_slow, float):
+            raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+        if finetuned is not None and not isinstance(finetuned, bool):
+            raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}")
+
+        b_fast = beta_fast if beta_fast is not None else 32
+        b_slow = beta_slow if beta_slow is not None else 1
+        if b_fast < b_slow:
+            raise
ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b9fbf8d70bd60..dcbf540ae55c3 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -209,6 +209,162 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +class FalconYaRNScalingRotaryEmbedding(FalconRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device) + + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, seq_len=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 
+ inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +class FalconDynamicYaRNScalingRotaryEmbedding(FalconYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + def forward(self, x, seq_len=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) @@ -308,6 +464,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = FalconLinearScalingRotaryEmbedding( self.head_dim, @@ -322,6 +485,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == 
"yarn": + self.rotary_emb = FalconYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = FalconDynamicYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/fuyu/configuration_fuyu.py b/src/transformers/models/fuyu/configuration_fuyu.py index 8a5013a65134c..b9f66d3a8ea70 100644 --- a/src/transformers/models/fuyu/configuration_fuyu.py +++ b/src/transformers/models/fuyu/configuration_fuyu.py @@ -200,9 +200,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index d559148a7221f..d7deb791dfac3 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -74,13 +74,31 @@ class GPTNeoXConfig(PretrainedConfig): Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training speedup at large scales (e.g. 20B). rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling + strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. 
+ yarn_rope_scaling (`Dict`, *optional*, defaults to`{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): + Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is + `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, + "beta_fast": float, "beta_slow": float,"finetuned": bool}`. + Fields: + original_max_position_embeddings (`int`, *optional*, defaults to 2048): + The original maximum sequence length. This is used to scale the RoPE embeddings. + extrapolation_factor (`float`, defaults to 1): + Factor to ajust the n-dimensional rotational scaling for extrapolation. + attention_factor (`float`, *optional*, defaults to 1): + Factor to adjust the weight attention scaling mechanism. + beta_fast (`float`, *optional*, defaults to 32): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + beta_slow (`float`, *optional*, defaults to 1): + Parameter to set the boundary for interpolation (only) in the linear ramp function. + finetuned (`bool`, *optional*, defaults to `False`): + [Dynamic] Whether the model is finetuned or not. + For more details please refer to https://arxiv.org/abs/2309.00071. attention_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the query, key, value and output projection layers during self-attention. @@ -124,6 +142,14 @@ def __init__( tie_word_embeddings=False, use_parallel_residual=True, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, attention_bias=True, **kwargs, ): @@ -146,8 +172,10 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.use_parallel_residual = use_parallel_residual self.rope_scaling = rope_scaling + self.yarn_rope_scaling = yarn_rope_scaling self.attention_bias = attention_bias self._rope_scaling_validation() + self._yarn_rope_scaling_validation() if self.hidden_size % self.num_attention_heads != 0: raise ValueError( @@ -168,9 +196,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e0b2309fc9658..666cd493044c0 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -14,6 +14,7 @@ # limitations under the License. 
""" PyTorch GPTNeoX model.""" +import math from typing import Optional, Tuple, Union import torch @@ -136,6 +137,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding( self.rotary_ndims, @@ -150,6 +158,31 @@ def _init_rope(self): base=self.config.rotary_emb_base, scaling_factor=scaling_factor, ) + elif scaling_type == "yarn": + self.rotary_emb = GPTNeoXYaRNScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = GPTNeoXDynamicYaRNScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -607,6 +640,163 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin(), persistent=False) +# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->GPTNeoX +class GPTNeoXYaRNScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device) + + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make `torch.jit.trace` work. 
+ self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, seq_len=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +class GPTNeoXDynamicYaRNScalingRotaryEmbedding(GPTNeoXYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, 
+ beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + def forward(self, x, seq_len=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index b406b12fc702c..7c6af7b16ba27 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -84,13 +84,31 @@ class LlamaConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling + strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. + yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): + Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is + `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, + "beta_fast": float, "beta_slow": float,"finetuned": bool}`. + Fields: + original_max_position_embeddings (`int`, *optional*, defaults to 2048): + The original maximum sequence length. This is used to scale the RoPE embeddings. + extrapolation_factor (`float`, defaults to 1): + Factor to ajust the n-dimensional rotational scaling for extrapolation. 
+ attention_factor (`float`, *optional*, defaults to 1): + Factor to adjust the weight attention scaling mechanism. + beta_fast (`float`, *optional*, defaults to 32): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + beta_slow (`float`, *optional*, defaults to 1): + Parameter to set the boundary for interpolation (only) in the linear ramp function. + finetuned (`bool`, *optional*, defaults to `False`): + [Dynamic] Whether the model is finetuned or not. + For more details please refer to https://arxiv.org/abs/2309.00071. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -134,6 +152,14 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, attention_bias=False, attention_dropout=0.0, mlp_bias=False, @@ -159,6 +185,8 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() + self.yarn_rope_scaling = yarn_rope_scaling + self._yarn_rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias @@ -184,9 +212,60 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5d8a3f987a47e..e2d5dd21131b7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -151,6 +151,163 @@ def forward(self, x, position_ids): return cos, sin +class LlamaYaRNScalingRotaryEmbedding(LlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device, scaling_factor) + + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make `torch.jit.trace` work. 
+ self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, position_ids=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +class LlamaDynamicYaRNScalingRotaryEmbedding(LlamaYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + 
original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + def forward(self, x, position_ids=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -277,6 +434,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, @@ -291,6 +455,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = LlamaYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = LlamaDynamicYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index 56cd01f7f2a72..b87934b6a1bb7 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -76,13 +76,31 
@@ class OlmoConfig(PretrainedConfig):
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
-            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
+            strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is
             `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
+            Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is
+            `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float,
+            "beta_fast": float, "beta_slow": float, "finetuned": bool}`.
+            Fields:
+                original_max_position_embeddings (`int`, *optional*, defaults to 2048):
+                    The original maximum sequence length. This is used to scale the RoPE embeddings.
+                extrapolation_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the n-dimensional rotational scaling for extrapolation.
+                attention_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the attention magnitude scaling mechanism.
+                beta_fast (`float`, *optional*, defaults to 32):
+                    Parameter to set the boundary for extrapolation (only) in the linear ramp function.
+                beta_slow (`float`, *optional*, defaults to 1):
+                    Parameter to set the boundary for interpolation (only) in the linear ramp function.
+                finetuned (`bool`, *optional*, defaults to `False`):
+                    Only used by the `dynamic-yarn` strategy: whether the model was fine-tuned to the extended context window.
+            For more details please refer to https://arxiv.org/abs/2309.00071.
         attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
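For reference, a minimal sketch of how these two config options are expected to be combined (the concrete values below are illustrative, not tuned recommendations):

    from transformers import OlmoConfig

    config = OlmoConfig(
        max_position_embeddings=8192,  # extended context window
        rope_scaling={"type": "yarn", "factor": 4.0},
        yarn_rope_scaling={
            "original_max_position_embeddings": 2048,
            "extrapolation_factor": 1.0,
            "attention_factor": 1.0,
            "beta_fast": 32.0,
            "beta_slow": 1.0,
            "finetuned": False,  # only consulted by the "dynamic-yarn" strategy
        },
    )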
attention_dropout (`float`, *optional*, defaults to 0.0): @@ -125,6 +143,14 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, attention_bias=False, attention_dropout=0.0, clip_qkv=None, @@ -148,6 +174,8 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() + self.yarn_rope_scaling = yarn_rope_scaling + self._yarn_rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.clip_qkv = clip_qkv @@ -174,9 +202,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise 
ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9b4b08239bc4d..bb139932cde61 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -149,6 +149,165 @@ def forward(self, x, position_ids): return cos, sin +# Copied from transformers.models.llama.modeling_llama.LlamaYaRNScalingRotaryEmbedding with Llama->Olmo +class OlmoYaRNScalingRotaryEmbedding(OlmoRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device, scaling_factor) + + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, position_ids=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + 
self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicYaRNScalingRotaryEmbedding with Llama->Olmo +class OlmoDynamicYaRNScalingRotaryEmbedding(OlmoYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. 
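+        # When `finetuned` is False the dynamic variant starts out as vanilla RoPE (plain `inv_freq`,
+        # `mscale` = 1); `forward` re-runs `yarn` with a scaling factor derived from the actual sequence
+        # length once the cached maximum is exceeded.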
+ self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) + + def forward(self, x, position_ids=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) + return ( + self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), + ) + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -261,6 +420,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = OlmoLinearScalingRotaryEmbedding( self.head_dim, @@ -275,6 +441,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = OlmoYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = OlmoDynamicYaRNScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/persimmon/configuration_persimmon.py b/src/transformers/models/persimmon/configuration_persimmon.py index 04bf792964c89..1180e96c093fd 100644 --- a/src/transformers/models/persimmon/configuration_persimmon.py +++ b/src/transformers/models/persimmon/configuration_persimmon.py @@ -67,6 +67,24 @@ class PersimmonConfig(PretrainedConfig): these scaling strategies behave: https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. 
+ yarn_rope_scaling (`Dict`, *optional*, defaults to`{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): + Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is + `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, + "beta_fast": float, "beta_slow": float,"finetuned": bool}`. + Fields: + original_max_position_embeddings (`int`, *optional*, defaults to 2048): + The original maximum sequence length. This is used to scale the RoPE embeddings. + extrapolation_factor (`float`, defaults to 1): + Factor to ajust the n-dimensional rotational scaling for extrapolation. + attention_factor (`float`, *optional*, defaults to 1): + Factor to adjust the weight attention scaling mechanism. + beta_fast (`float`, *optional*, defaults to 32): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + beta_slow (`float`, *optional*, defaults to 1): + Parameter to set the boundary for interpolation (only) in the linear ramp function. + finetuned (`bool`, *optional*, defaults to `False`): + [Dynamic] Whether the model is finetuned or not. + For more details please refer to https://arxiv.org/abs/2309.00071. qk_layernorm (`bool`, *optional*, default to `True`): Whether or not to normalize the Queries and Keys after projecting the hidden states hidden_dropout (`float`, *optional*, default to 0.0): @@ -103,6 +121,14 @@ def __init__( tie_word_embeddings=False, rope_theta=25000.0, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, qk_layernorm=True, hidden_dropout=0.0, attention_dropout=0.0, @@ -124,11 +150,13 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.yarn_rope_scaling = yarn_rope_scaling self.qk_layernorm = qk_layernorm self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() + self._yarn_rope_scaling_validation() super().__init__( pad_token_id=pad_token_id, @@ -152,9 +180,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 75ba7163ba23c..fddb7c61ca88a 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -129,6 +129,164 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->Persimmon +class PersimmonYaRNScalingRotaryEmbedding(PersimmonRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device) + + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build 
here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, seq_len=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->Persimmon +class PersimmonDynamicYaRNScalingRotaryEmbedding(PersimmonYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + 
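+        # `finetuned=True` indicates the checkpoint was already trained with YaRN at the extended length,
+        # so the scaling factor is pinned to max_position_embeddings / original_max_position_embeddings
+        # below; otherwise plain RoPE frequencies are used until `forward` sees a longer sequence and
+        # recomputes the scaling from the actual length.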
super().__init__( + dim, + max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + def forward(self, x, seq_len=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -232,6 +390,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = PersimmonLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -246,6 +411,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = PersimmonYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = PersimmonDynamicYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 
d221255f1182b..5e1da4fb454bd 100644
--- a/src/transformers/models/phi/configuration_phi.py
+++ b/src/transformers/models/phi/configuration_phi.py
@@ -83,6 +83,24 @@ class PhiConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
+            Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is
+            `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float,
+            "beta_fast": float, "beta_slow": float, "finetuned": bool}`.
+            Fields:
+                original_max_position_embeddings (`int`, *optional*, defaults to 2048):
+                    The original maximum sequence length. This is used to scale the RoPE embeddings.
+                extrapolation_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the n-dimensional rotational scaling for extrapolation.
+                attention_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the attention magnitude scaling mechanism.
+                beta_fast (`float`, *optional*, defaults to 32):
+                    Parameter to set the boundary for extrapolation (only) in the linear ramp function.
+                beta_slow (`float`, *optional*, defaults to 1):
+                    Parameter to set the boundary for interpolation (only) in the linear ramp function.
+                finetuned (`bool`, *optional*, defaults to `False`):
+                    Only used by the `dynamic-yarn` strategy: whether the model was fine-tuned to the extended context window.
+            For more details please refer to https://arxiv.org/abs/2309.00071.
         partial_rotary_factor (`float`, *optional*, defaults to 0.5):
             Percentage of the query and keys which will have rotary embedding.
qk_layernorm (`bool`, *optional*, defaults to `False`): @@ -129,6 +147,14 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, partial_rotary_factor=0.5, qk_layernorm=False, bos_token_id=1, @@ -155,9 +181,11 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.yarn_rope_scaling = yarn_rope_scaling self.partial_rotary_factor = partial_rotary_factor self.qk_layernorm = qk_layernorm self._rope_scaling_validation() + self._yarn_rope_scaling_validation() super().__init__( bos_token_id=bos_token_id, @@ -180,9 +208,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 1f82b09a25704..92e9eed36cb45 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -160,6 +160,164 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->Phi +class PhiYaRNScalingRotaryEmbedding(PhiRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device) + + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make `torch.jit.trace` work. 
+ self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, seq_len=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->Phi +class PhiDynamicYaRNScalingRotaryEmbedding(PhiYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, + max_position_embeddings, + base, + 
scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + def forward(self, x, seq_len=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -283,6 +441,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = PhiLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -297,6 +462,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = PhiYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = PhiDynamicYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/stablelm/configuration_stablelm.py b/src/transformers/models/stablelm/configuration_stablelm.py index 64b39fe20e518..ea5153aa96c86 100644 --- 
a/src/transformers/models/stablelm/configuration_stablelm.py
+++ b/src/transformers/models/stablelm/configuration_stablelm.py
@@ -78,6 +78,24 @@ class StableLmConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
+            Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is
+            `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float,
+            "beta_fast": float, "beta_slow": float, "finetuned": bool}`.
+            Fields:
+                original_max_position_embeddings (`int`, *optional*, defaults to 2048):
+                    The original maximum sequence length. This is used to scale the RoPE embeddings.
+                extrapolation_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the n-dimensional rotational scaling for extrapolation.
+                attention_factor (`float`, *optional*, defaults to 1):
+                    Factor to adjust the attention magnitude scaling mechanism.
+                beta_fast (`float`, *optional*, defaults to 32):
+                    Parameter to set the boundary for extrapolation (only) in the linear ramp function.
+                beta_slow (`float`, *optional*, defaults to 1):
+                    Parameter to set the boundary for interpolation (only) in the linear ramp function.
+                finetuned (`bool`, *optional*, defaults to `False`):
+                    Only used by the `dynamic-yarn` strategy: whether the model was fine-tuned to the extended context window.
+            For more details please refer to https://arxiv.org/abs/2309.00071.
         use_qkv_bias (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use bias for qkv layers.
qk_layernorm (`bool`, *optional*, defaults to `False`): @@ -124,6 +142,14 @@ def __init__( tie_word_embeddings=False, rope_theta=10_000, rope_scaling=None, + yarn_rope_scaling={ + "original_max_position_embeddings": 2048, + "extrapolation_factor": 1.0, + "attention_factor": 1.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "finetuned": False, + }, use_qkv_bias=False, qk_layernorm=False, use_parallel_residual=False, @@ -149,6 +175,7 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.yarn_rope_scaling = yarn_rope_scaling self.use_qkv_bias = use_qkv_bias self.qk_layernorm = qk_layernorm self.use_parallel_residual = use_parallel_residual @@ -156,6 +183,7 @@ def __init__( self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() + self._yarn_rope_scaling_validation() super().__init__( bos_token_id=bos_token_id, @@ -178,9 +206,61 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation + def _yarn_rope_scaling_validation(self): + """ + Validate the `yarn_rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " + "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + finetuned = self.rope_scaling.get("finetuned", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if ( + extrapolation_factor is not None + and not isinstance(extrapolation_factor, float) + or extrapolation_factor < 0 + or extrapolation_factor > 1 + ): + raise ValueError( + f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + if finetuned is not None and not isinstance(finetuned, bool): + raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index e8d07340d3b0f..f7683bee62822 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -155,6 +155,164 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->StableLm +class StableLmYaRNScalingRotaryEmbedding(StableLmRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device) + + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn(device) + + # Build here to make 
`torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + # Get positional embeddings based on the current max sequence length + def get_pos_embeddings(self, device): + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(self, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 + + def forward(self, x, seq_len=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + def yarn(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + + +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->StableLm +class StableLmDynamicYaRNScalingRotaryEmbedding(StableLmYaRNScalingRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attention_factor=1, + beta_fast=32, + beta_slow=1, + finetuned=False, + device=None, + ): + super().__init__( + dim, 
+ max_position_embeddings, + base, + scaling_factor, + original_max_position_embeddings, + extrapolation_factor, + attention_factor, + beta_fast, + beta_slow, + device, + ) + + if finetuned: + self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings + self.yarn(device) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.mscale = 1 + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + emb = self.get_pos_embeddings(device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) + + def forward(self, x, seq_len=None): + # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded + # x: [batch_size, seq_len, head_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.scaling_factor = seq_len / self.original_max_position_embeddings + self.yarn(x.device) + emb = self.get_pos_embeddings(x.device) + + self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) + return ( + self._cos_cached[:seq_len, ...].to(dtype=x.dtype), + self._sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -291,6 +449,13 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # YaRN parameters + original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] + extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] + attention_factor = self.config.yarn_rope_scaling["attention_factor"] + beta_fast = self.config.yarn_rope_scaling["beta_fast"] + beta_slow = self.config.yarn_rope_scaling["beta_slow"] + finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = StableLmLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -305,6 +470,31 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = StableLmYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + ) + elif scaling_type == "dynamic-yarn": + self.rotary_emb = StableLmDynamicYaRNScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + original_max_position_embeddings=original_max_position_embeddings, + extrapolation_factor=extrapolation_factor, + attention_factor=attention_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + finetuned=finetuned, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 59ab316140342..dccac591eea87 100644 --- 
a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -54,8 +54,10 @@ ) from transformers.models.falcon.modeling_falcon import ( FalconDynamicNTKScalingRotaryEmbedding, + FalconDynamicYaRNScalingRotaryEmbedding, FalconLinearScalingRotaryEmbedding, FalconRotaryEmbedding, + FalconYaRNScalingRotaryEmbedding, ) @@ -443,7 +445,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -510,6 +512,40 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = FalconYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = FalconDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) + dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) + dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) + torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + # TODO: @Fxmarty @is_flaky(max_attempts=3, description="flaky on some models.") @require_torch_sdpa diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 92d130b35101b..cc645c73d0c2e 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -40,8 +40,10 @@ ) from transformers.models.gpt_neox.modeling_gpt_neox import ( GPTNeoXDynamicNTKScalingRotaryEmbedding, + GPTNeoXDynamicYaRNScalingRotaryEmbedding, GPTNeoXLinearScalingRotaryEmbedding, GPTNeoXRotaryEmbedding, + GPTNeoXYaRNScalingRotaryEmbedding, ) @@ -305,7 +307,7 @@ def test_model_for_token_classification(self): def test_feed_forward_chunking(self): pass - 
@parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -329,7 +331,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -397,6 +399,40 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = GPTNeoXYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = GPTNeoXDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) + dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) + dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) + torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 58269d62e08c2..0d7be13d20a7c 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -53,8 +53,10 @@ ) from transformers.models.llama.modeling_llama import ( LlamaDynamicNTKScalingRotaryEmbedding, + LlamaDynamicYaRNScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding, + LlamaYaRNScalingRotaryEmbedding, ) @@ -397,7 +399,7 @@ def test_llama_token_classification_model(self): def test_save_load_fast_init_from_base(self): pass - 
@parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -420,7 +422,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -491,6 +493,38 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = LlamaYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = LlamaDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, position_ids_short) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + @require_flash_attn @require_torch_gpu @require_bitsandbytes diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 906bd73a70d2a..bea40355e326d 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -322,7 +322,7 @@ def test_save_load_fast_init_from_base(self): def test_eager_matches_sdpa_generate(self): super().test_eager_matches_sdpa_generate() - @parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -345,7 +345,7 @@ def test_model_rope_scaling(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer 
than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 46a650c55abfe..b692ac71c10db 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -49,8 +49,10 @@ ) from transformers.models.persimmon.modeling_persimmon import ( PersimmonDynamicNTKScalingRotaryEmbedding, + PersimmonDynamicYaRNScalingRotaryEmbedding, PersimmonLinearScalingRotaryEmbedding, PersimmonRotaryEmbedding, + PersimmonYaRNScalingRotaryEmbedding, ) @@ -390,7 +392,7 @@ def test_persimmon_token_classification_model(self): def test_save_load_fast_init_from_base(self): pass - @parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -414,7 +416,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -482,6 +484,40 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = PersimmonYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = PersimmonDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) + dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) + dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) + torch.testing.assert_close(dynamic_yarn_cos_short, 
original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + @require_torch class PersimmonIntegrationTest(unittest.TestCase): diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index e3c145bfa268c..63a72cfee806f 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -49,8 +49,10 @@ ) from transformers.models.phi.modeling_phi import ( PhiDynamicNTKScalingRotaryEmbedding, + PhiDynamicYaRNScalingRotaryEmbedding, PhiLinearScalingRotaryEmbedding, PhiRotaryEmbedding, + PhiYaRNScalingRotaryEmbedding, ) @@ -366,7 +368,7 @@ def test_phi_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -390,7 +392,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -458,6 +460,40 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = PhiYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = PhiDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) + dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) + dynamic_yarn_sin_short = 
dynamic_yarn_sin_short.squeeze(0) + torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + @require_flash_attn @require_torch_gpu @require_bitsandbytes diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 083f928612a03..a8b060acccab3 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -48,12 +48,14 @@ ) from transformers.models.stablelm.modeling_stablelm import ( StableLmDynamicNTKScalingRotaryEmbedding, + StableLmDynamicYaRNScalingRotaryEmbedding, StableLmLinearScalingRotaryEmbedding, StableLmRotaryEmbedding, + StableLmYaRNScalingRotaryEmbedding, ) -# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm +# Copied from transformers.tests.models.StableLm.test_modeling_StableLm.StableLmModelTester with StableLm -> StableLm class StableLmModelTester: # Ignore copy def __init__( @@ -376,7 +378,7 @@ def test_stablelm_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -400,7 +402,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. 
- if scaling_type == "dynamic": + if scaling_type in ("dynamic", "dynamic-yarn"): self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -468,6 +470,40 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check YaRN RoPE scaling + yarn_scaling_rope = StableLmYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + # Sanity check Dynamic YaRN RoPE scaling + dynamic_yarn_scaling_rope = StableLmDynamicYaRNScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) + dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) + dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) + dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) + torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) + @require_torch class StableLmModelIntegrationTest(unittest.TestCase): From 85552b3c057e8955ee5382292adf0861854d4963 Mon Sep 17 00:00:00 2001 From: Miguel Almeida Date: Sun, 16 Jun 2024 22:30:24 +0100 Subject: [PATCH 2/2] Refactor YaRN implementation for LLaMA Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. 
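
For reviewers, a minimal sketch of the merged dictionary this refactor targets, as
detailed in the change list below (field names taken from the YaRN parameters
introduced in PATCH 1/2, values hypothetical; `extrapolation_factor` and `finetuned`
are dropped, and the final schema is the one defined in configuration_llama.py):

    # Assumed post-refactor format: the YaRN fields live inside `rope_scaling`
    # instead of a separate `yarn_rope_scaling` dict (illustrative values only).
    rope_scaling = {
        "type": "yarn",                             # or "dynamic-yarn"
        "factor": 4.0,                              # context extension ratio, must be > 1
        "original_max_position_embeddings": 2048,   # pre-training context length
        "attention_factor": 1.0,                    # weight of the mscale attention correction
        "beta_fast": 32.0,                          # extrapolation boundary of the linear ramp
        "beta_slow": 1.0,                           # interpolation boundary of the linear ramp
    }
    # e.g. LlamaConfig(rope_scaling=rope_scaling); with "dynamic-yarn" the scaling
    # factor is recomputed from the observed sequence length once it exceeds
    # original_max_position_embeddings, rather than being fixed at load time.
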
This commit includes the following changes: - Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries - Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes - Inherit 'forward' method in YaRN classes from superclass - Rename 'yarn' method to 'compute_yarn_scaling' - Extend YaRN tests with further assertions - Fix style inconsistencies Co-authored-by: Miguel Monte e Freitas --- .../open_llama/configuration_open_llama.py | 88 +------- .../open_llama/modeling_open_llama.py | 191 ----------------- .../models/falcon/configuration_falcon.py | 84 +------- .../models/falcon/modeling_falcon.py | 189 ----------------- .../models/fuyu/configuration_fuyu.py | 5 +- .../models/gpt_neox/configuration_gpt_neox.py | 89 +------- .../models/gpt_neox/modeling_gpt_neox.py | 190 ----------------- .../models/llama/configuration_llama.py | 71 ++----- .../models/llama/modeling_llama.py | 104 ++++------ .../models/olmo/configuration_olmo.py | 89 +------- src/transformers/models/olmo/modeling_olmo.py | 192 ------------------ .../persimmon/configuration_persimmon.py | 85 +------- .../models/persimmon/modeling_persimmon.py | 190 ----------------- .../models/phi/configuration_phi.py | 85 +------- src/transformers/models/phi/modeling_phi.py | 190 ----------------- .../models/stablelm/configuration_stablelm.py | 85 +------- .../models/stablelm/modeling_stablelm.py | 190 ----------------- tests/models/falcon/test_modeling_falcon.py | 38 +--- .../models/gpt_neox/test_modeling_gpt_neox.py | 40 +--- tests/models/llama/test_modeling_llama.py | 24 ++- tests/models/olmo/test_modeling_olmo.py | 4 +- .../persimmon/test_modeling_persimmon.py | 40 +--- tests/models/phi/test_modeling_phi.py | 40 +--- .../models/stablelm/test_modeling_stablelm.py | 40 +--- 24 files changed, 104 insertions(+), 2239 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py index 296f8d95c68b6..e20c33f24a322 100644 --- a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py @@ -66,31 +66,13 @@ class OpenLlamaConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling - strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. 
The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. Example: @@ -132,14 +114,6 @@ def __init__( shared_input_output_embedding=True, rope_theta=10000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, **kwargs, ): self.vocab_size = vocab_size @@ -162,8 +136,6 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() - self.yarn_rope_scaling = yarn_rope_scaling - self._yarn_rope_scaling_validation() super().__init__( pad_token_id=pad_token_id, @@ -186,61 +158,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 7b446eb5c362b..7d2098f2f63ff 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -144,165 +144,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaYaRNScalingRotaryEmbedding with Llama->OpenLlama -class OpenLlamaYaRNScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device, scaling_factor) - - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - 
self.yarn(device) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, position_ids=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) - return ( - self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicYaRNScalingRotaryEmbedding with Llama->OpenLlama -class OpenLlamaDynamicYaRNScalingRotaryEmbedding(OpenLlamaYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - 
attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - - def forward(self, x, position_ids=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) - return ( - self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), - ) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -392,13 +233,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( self.head_dim, @@ -413,31 +247,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = OpenLlamaYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = OpenLlamaDynamicYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/falcon/configuration_falcon.py 
b/src/transformers/models/falcon/configuration_falcon.py index 25faf60fdf1f2..0dd61047dd275 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -77,31 +77,13 @@ class FalconConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling - strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. bos_token_id (`int`, *optional*, defaults to 11): The id of the "beginning-of-sequence" token. 
eos_token_id (`int`, *optional*, defaults to 11): @@ -151,14 +133,6 @@ def __init__( max_position_embeddings=2048, rope_theta=10000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, bos_token_id=11, eos_token_id=11, ffn_hidden_size=None, @@ -188,14 +162,12 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self.yarn_rope_scaling = yarn_rope_scaling self.activation = activation if ffn_hidden_size is None: self.ffn_hidden_size = hidden_size * 4 else: self.ffn_hidden_size = ffn_hidden_size self._rope_scaling_validation() - self._yarn_rope_scaling_validation() super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -229,55 +201,3 @@ def _rope_scaling_validation(self): ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise 
ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index a5ca2b9eae1e7..5f2f2d8a1bfc9 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -209,162 +209,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -class FalconYaRNScalingRotaryEmbedding(FalconRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device) - - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, seq_len=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 
- inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -class FalconDynamicYaRNScalingRotaryEmbedding(FalconYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - def forward(self, x, seq_len=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) @@ -453,7 +297,6 @@ def __init__(self, config: FalconConfig): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = FalconRotaryEmbedding( @@ -464,13 +307,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = 
self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = FalconLinearScalingRotaryEmbedding( self.head_dim, @@ -485,31 +321,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = FalconYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = FalconDynamicYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/fuyu/configuration_fuyu.py b/src/transformers/models/fuyu/configuration_fuyu.py index deb484db66848..d9d00d4828829 100644 --- a/src/transformers/models/fuyu/configuration_fuyu.py +++ b/src/transformers/models/fuyu/configuration_fuyu.py @@ -186,7 +186,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. @@ -200,9 +199,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index ab45c19e760e1..944dbb5e02f09 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -74,31 +74,13 @@ class GPTNeoXConfig(PretrainedConfig): Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training speedup at large scales (e.g. 20B). rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling - strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. 
Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to`{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. attention_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the query, key, value and output projection layers during self-attention. @@ -142,14 +124,6 @@ def __init__( tie_word_embeddings=False, use_parallel_residual=True, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, attention_bias=True, **kwargs, ): @@ -172,17 +146,14 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.use_parallel_residual = use_parallel_residual self.rope_scaling = rope_scaling - self.yarn_rope_scaling = yarn_rope_scaling self.attention_bias = attention_bias self._rope_scaling_validation() - self._yarn_rope_scaling_validation() if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size is not divisble by the number of attention heads! Make sure to update them!" ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. 
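
Note on the restored behaviour: with the YaRN branches reverted here, `GPTNeoXConfig` (like the other reverted models below) again accepts only the two-field `rope_scaling` dictionary. A minimal sketch of the accepted format, with illustrative values:

    >>> from transformers import GPTNeoXConfig

    >>> # Linear position interpolation, doubling the usable context
    >>> config = GPTNeoXConfig(rope_scaling={"type": "linear", "factor": 2.0})

    >>> # Dynamic NTK scaling uses the same two fields
    >>> config = GPTNeoXConfig(rope_scaling={"type": "dynamic", "factor": 2.0})

    >>> # "yarn" / "dynamic-yarn" are no longer accepted here and raise a ValueError
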
@@ -196,61 +167,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e027ae0e7dc09..bde881226fb8c 100755 --- 
a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch GPTNeoX model.""" -import math from typing import Optional, Tuple, Union import torch @@ -137,13 +136,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding( self.rotary_ndims, @@ -158,31 +150,6 @@ def _init_rope(self): base=self.config.rotary_emb_base, scaling_factor=scaling_factor, ) - elif scaling_type == "yarn": - self.rotary_emb = GPTNeoXYaRNScalingRotaryEmbedding( - self.rotary_ndims, - self.config.max_position_embeddings, - base=self.config.rotary_emb_base, - scaling_factor=scaling_factor, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = GPTNeoXDynamicYaRNScalingRotaryEmbedding( - self.rotary_ndims, - self.config.max_position_embeddings, - base=self.config.rotary_emb_base, - scaling_factor=scaling_factor, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -640,163 +607,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin(), persistent=False) -# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->GPTNeoX -class GPTNeoXYaRNScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device) - - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, seq_len=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -class GPTNeoXDynamicYaRNScalingRotaryEmbedding(GPTNeoXYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, 
-            beta_slow,
-            device,
-        )
-
-        if finetuned:
-            self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings
-            self.yarn(device)
-        else:
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-            self.register_buffer("inv_freq", inv_freq)
-            self.mscale = 1
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        emb = self.get_pos_embeddings(device)
-
-        self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype())
-        self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype())
-
-    def forward(self, x, seq_len=None):
-        # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded
-        # x: [batch_size, seq_len, head_dim]
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            self.scaling_factor = seq_len / self.original_max_position_embeddings
-            self.yarn(x.device)
-            emb = self.get_pos_embeddings(x.device)
-
-            self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype)
-            self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype)
-        return (
-            self._cos_cached[:seq_len, ...].to(dtype=x.dtype),
-            self._sin_cached[:seq_len, ...].to(dtype=x.dtype),
-        )
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 593c4221684cf..dc1abee0e0e3a 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -91,24 +91,15 @@ class LlamaConfig(PretrainedConfig):
         these scaling strategies behave:
         https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
         experimental feature, subject to breaking API changes in future versions.
-        yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
-            Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is
-            `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float,
-            "beta_fast": float, "beta_slow": float,"finetuned": bool}`.
-            Fields:
-            original_max_position_embeddings (`int`, *optional*, defaults to 2048):
+            For `yarn` and `dynamic-yarn` strategies, the dictionary may also contain the following fields:
+            `original_max_position_embeddings` (`int`, *optional*):
                 The original maximum sequence length. This is used to scale the RoPE embeddings.
-            extrapolation_factor (`float`, defaults to 1):
-                Factor to ajust the n-dimensional rotational scaling for extrapolation.
-            attention_factor (`float`, *optional*, defaults to 1):
-                Factor to adjust the weight attention scaling mechanism.
-            beta_fast (`float`, *optional*, defaults to 32):
+            `attention_factor` (`float`, *optional*):
+                The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `max_position_embeddings / original_max_position_embeddings` ratio (i.e. the RoPE scaling `factor`).
+            `beta_fast` (`float`, *optional*):
                 Parameter to set the boundary for extrapolation (only) in the linear ramp function.
-            beta_slow (`float`, *optional*, defaults to 1):
+            `beta_slow` (`float`, *optional*):
                 Parameter to set the boundary for interpolation (only) in the linear ramp function. 
- finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -152,14 +143,6 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, attention_bias=False, attention_dropout=0.0, mlp_bias=False, @@ -185,8 +168,6 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() - self.yarn_rope_scaling = yarn_rope_scaling - self._yarn_rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias @@ -206,9 +187,10 @@ def _rope_scaling_validation(self): if self.rope_scaling is None: return - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2: raise ValueError( - "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + "`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, " + f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) @@ -219,53 +201,38 @@ def _rope_scaling_validation(self): if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. 
- """ - if self.rope_scaling is None: + if rope_scaling_type not in ["yarn", "dynamic-yarn"]: return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " + "`rope_scaling` with type " + f"{rope_scaling_type}" + " must be a dictionary with a maximum of six fields, `type`, `factor`," + "`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, " f"got {self.rope_scaling}" ) original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) attention_factor = self.rope_scaling.get("attention_factor", None) beta_fast = self.rope_scaling.get("beta_fast", None) beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" + f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" ) if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" ) if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") + raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") b_fast = beta_fast if beta_fast is not None else 32 b_slow = beta_slow if beta_slow is not None else 1 if b_fast < b_slow: raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index bf434f87e3844..4eef0743800bd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -149,7 +149,7 @@ def forward(self, x, position_ids): return cos, sin -class LlamaYaRNScalingRotaryEmbedding(LlamaRotaryEmbedding): +class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__( self, dim, @@ -157,8 +157,7 
@@ def __init__( base=10000, scaling_factor=1, original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, + attention_factor=None, beta_fast=32, beta_slow=1, device=None, @@ -166,12 +165,14 @@ def __init__( super().__init__(dim, max_position_embeddings, base, device, scaling_factor) self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor self.attention_factor = attention_factor self.beta_fast = beta_fast self.beta_slow = beta_slow - self.yarn(device) + if self.attention_factor is None: + self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0 + + self.compute_yarn_scaling(device) # Build here to make `torch.jit.trace` work. self.max_seq_len_cached = max_position_embeddings @@ -205,28 +206,16 @@ def linear_ramp_mask(self, min, max, dim): ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - def forward(self, x, position_ids=None): # Difference to the original RoPE: applies a scaling factor computed with # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) - return ( - self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), - ) + # x: [bs, num_attention_heads, seq_len, head_size] + cos, sin = super().forward(x, position_ids) + cos = cos * self.mscale + sin = sin * self.mscale + return cos, sin - def yarn(self, device): + def compute_yarn_scaling(self, device): pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) @@ -235,17 +224,15 @@ def yarn(self, device): self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings ) # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor + inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.register_buffer("inv_freq", inv_freq) # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) + self.mscale = self.attention_factor -class LlamaDynamicYaRNScalingRotaryEmbedding(LlamaYaRNScalingRotaryEmbedding): +class LlamaDynamicYarnScalingRotaryEmbedding(LlamaYarnScalingRotaryEmbedding): def __init__( self, dim, @@ -253,11 +240,9 @@ def __init__( base=10000, scaling_factor=1, original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, + attention_factor=None, beta_fast=32, beta_slow=1, - finetuned=False, device=None, ): super().__init__( @@ -266,16 +251,15 @@ def __init__( base, scaling_factor, original_max_position_embeddings, - extrapolation_factor, attention_factor, beta_fast, beta_slow, device, ) - if finetuned: + if self.max_position_embeddings != self.original_max_position_embeddings: 
             self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings
-            self.yarn(device)
+            self.compute_yarn_scaling(device)
         else:
             inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
             self.register_buffer("inv_freq", inv_freq)
@@ -290,20 +274,13 @@ def __init__(
 
     def forward(self, x, position_ids=None):
         # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded
-        # x: [batch_size, seq_len, head_dim]
+        # x: [bs, num_attention_heads, seq_len, head_size]
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            self.scaling_factor = seq_len / self.original_max_position_embeddings
-            self.yarn(x.device)
-            emb = self.get_pos_embeddings(x.device)
-
-            self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype)
-            self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype)
-        return (
-            self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype),
-            self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype),
-        )
+        self.scaling_factor = seq_len / self.original_max_position_embeddings
+        self.compute_yarn_scaling(x.device)
+
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
 
 
 def rotate_half(x):
@@ -432,13 +409,15 @@ def _init_rope(self):
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
-            # YaRN parameters
-            original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"]
-            extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"]
-            attention_factor = self.config.yarn_rope_scaling["attention_factor"]
-            beta_fast = self.config.yarn_rope_scaling["beta_fast"]
-            beta_slow = self.config.yarn_rope_scaling["beta_slow"]
-            finetuned = self.config.yarn_rope_scaling["finetuned"]
+            # Optional YaRN parameters; omitted fields fall back to the rotary embedding defaults
+            kwargs = {
+                "original_max_position_embeddings": self.config.rope_scaling.get("original_max_position_embeddings", None),
+                "attention_factor": self.config.rope_scaling.get("attention_factor", None),
+                "beta_fast": self.config.rope_scaling.get("beta_fast", None),
+                "beta_slow": self.config.rope_scaling.get("beta_slow", None),
+            }
+            kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -454,29 +433,20 @@ def _init_rope(self):
                     base=self.rope_theta,
                 )
             elif scaling_type == "yarn":
-                self.rotary_emb = LlamaYaRNScalingRotaryEmbedding(
+                self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
-                    original_max_position_embeddings=original_max_position_embeddings,
-                    extrapolation_factor=extrapolation_factor,
-                    attention_factor=attention_factor,
-                    beta_fast=beta_fast,
-                    beta_slow=beta_slow,
+                    **kwargs,
                 )
             elif scaling_type == "dynamic-yarn":
-                self.rotary_emb = LlamaDynamicYaRNScalingRotaryEmbedding(
+                self.rotary_emb = LlamaDynamicYarnScalingRotaryEmbedding(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
-                    original_max_position_embeddings=original_max_position_embeddings,
-                    extrapolation_factor=extrapolation_factor,
-                    attention_factor=attention_factor,
-                    beta_fast=beta_fast,
-                    beta_slow=beta_slow,
-                    finetuned=finetuned,
+                    **kwargs,
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
diff --git a/src/transformers/models/olmo/configuration_olmo.py 
b/src/transformers/models/olmo/configuration_olmo.py index c162675948063..2b0a23d03df88 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -76,31 +76,13 @@ class OlmoConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling - strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. 
attention_dropout (`float`, *optional*, defaults to 0.0): @@ -143,14 +125,6 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, attention_bias=False, attention_dropout=0.0, clip_qkv=None, @@ -174,8 +148,6 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() - self.yarn_rope_scaling = yarn_rope_scaling - self._yarn_rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.clip_qkv = clip_qkv @@ -188,7 +160,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. @@ -202,61 +173,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 4ab81ae5f9dcb..2b29becf2ad05 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -149,165 +149,6 @@ def forward(self, x, position_ids): return cos, sin -# Copied from transformers.models.llama.modeling_llama.LlamaYaRNScalingRotaryEmbedding with Llama->Olmo -class OlmoYaRNScalingRotaryEmbedding(OlmoRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device, scaling_factor) - - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, position_ids=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) - return ( - self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicYaRNScalingRotaryEmbedding with Llama->Olmo -class OlmoDynamicYaRNScalingRotaryEmbedding(OlmoYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - 
device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype()) - - def forward(self, x, position_ids=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(x.dtype) - return ( - self._cos_cached[:, :seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:, :seq_len, ...].to(dtype=x.dtype), - ) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -409,7 +250,6 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self._init_rope() - # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = OlmoRotaryEmbedding( @@ -420,13 +260,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = OlmoLinearScalingRotaryEmbedding( self.head_dim, @@ -441,31 +274,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = OlmoYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = OlmoDynamicYaRNScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - 
original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/persimmon/configuration_persimmon.py b/src/transformers/models/persimmon/configuration_persimmon.py index faa757e9d3e19..11f4c66d73e6b 100644 --- a/src/transformers/models/persimmon/configuration_persimmon.py +++ b/src/transformers/models/persimmon/configuration_persimmon.py @@ -67,24 +67,6 @@ class PersimmonConfig(PretrainedConfig): these scaling strategies behave: https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to`{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. qk_layernorm (`bool`, *optional*, default to `True`): Whether or not to normalize the Queries and Keys after projecting the hidden states hidden_dropout (`float`, *optional*, default to 0.0): @@ -121,14 +103,6 @@ def __init__( tie_word_embeddings=False, rope_theta=25000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, qk_layernorm=True, hidden_dropout=0.0, attention_dropout=0.0, @@ -150,13 +124,11 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self.yarn_rope_scaling = yarn_rope_scaling self.qk_layernorm = qk_layernorm self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() - self._yarn_rope_scaling_validation() super().__init__( pad_token_id=pad_token_id, @@ -166,7 +138,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. 
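
For reference while reviewing these removals: the `yarn()` helper that each per-model rotary embedding duplicated (GPT-NeoX and Olmo above, Persimmon and Phi below), and that survives only as `compute_yarn_scaling` in the Llama class, blends interpolated and extrapolated inverse frequencies with a linear ramp. The following standalone sketch only illustrates that computation and is not part of the patch:

    import math
    import torch

    def yarn_inv_freq(dim, base=10000.0, scaling_factor=4.0,
                      original_max_position_embeddings=2048,
                      beta_fast=32.0, beta_slow=1.0):
        # Inverse frequencies of plain RoPE and of position-interpolated RoPE
        pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        # Dimension index at which a rotary channel completes the given number of
        # rotations over the original context window
        def correction_dim(num_rotations):
            return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
                2 * math.log(base)
            )

        low = max(math.floor(correction_dim(beta_fast)), 0)
        high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
        if low == high:
            high += 0.001  # avoid a zero-width ramp

        # Ramp from pure extrapolation (high-frequency channels) to pure interpolation (low-frequency channels)
        ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
        inv_freq_mask = 1 - ramp
        return inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

The accompanying attention scaling is simply `0.1 * ln(scaling_factor) + 1`, applied multiplicatively to the cached cos/sin tables.
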
@@ -180,61 +151,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index e87612f5ba381..9458c3361d2e8 100644 --- 
a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -130,164 +130,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->Persimmon -class PersimmonYaRNScalingRotaryEmbedding(PersimmonRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device) - - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, seq_len=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - 
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->Persimmon -class PersimmonDynamicYaRNScalingRotaryEmbedding(PersimmonYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - def forward(self, x, seq_len=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -391,13 +233,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = PersimmonLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -412,31 +247,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = PersimmonYaRNScalingRotaryEmbedding( - 
int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = PersimmonDynamicYaRNScalingRotaryEmbedding( - int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 7193c12cc18f7..3353199adafb8 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -82,24 +82,6 @@ class PhiConfig(PretrainedConfig): these scaling strategies behave: https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. 
qk_layernorm (`bool`, *optional*, defaults to `False`): @@ -146,14 +128,6 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, partial_rotary_factor=0.5, qk_layernorm=False, bos_token_id=1, @@ -180,11 +154,9 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self.yarn_rope_scaling = yarn_rope_scaling self.partial_rotary_factor = partial_rotary_factor self.qk_layernorm = qk_layernorm self._rope_scaling_validation() - self._yarn_rope_scaling_validation() super().__init__( bos_token_id=bos_token_id, @@ -193,7 +165,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. @@ -207,61 +178,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 7e7acc63bb137..a2c3793c01194 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -159,164 +159,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->Phi -class PhiYaRNScalingRotaryEmbedding(PhiRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device) - - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, seq_len=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->Phi -class PhiDynamicYaRNScalingRotaryEmbedding(PhiYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - 
scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - def forward(self, x, seq_len=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -440,13 +282,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = PhiLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -461,31 +296,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = PhiYaRNScalingRotaryEmbedding( - int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = PhiDynamicYaRNScalingRotaryEmbedding( - int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/stablelm/configuration_stablelm.py b/src/transformers/models/stablelm/configuration_stablelm.py index c160786f40b5a..006aa504cc1a6 100644 --- 
a/src/transformers/models/stablelm/configuration_stablelm.py +++ b/src/transformers/models/stablelm/configuration_stablelm.py @@ -78,24 +78,6 @@ class StableLmConfig(PretrainedConfig): these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): - Dictionary containing the YaRN-specific scaling configuration for the RoPE embeddings. The expected format is - `{"original_max_position_embeddings": int, "extrapolation_factor": float, "attention_factor": float, - "beta_fast": float, "beta_slow": float,"finetuned": bool}`. - Fields: - original_max_position_embeddings (`int`, *optional*, defaults to 2048): - The original maximum sequence length. This is used to scale the RoPE embeddings. - extrapolation_factor (`float`, defaults to 1): - Factor to ajust the n-dimensional rotational scaling for extrapolation. - attention_factor (`float`, *optional*, defaults to 1): - Factor to adjust the weight attention scaling mechanism. - beta_fast (`float`, *optional*, defaults to 32): - Parameter to set the boundary for extrapolation (only) in the linear ramp function. - beta_slow (`float`, *optional*, defaults to 1): - Parameter to set the boundary for interpolation (only) in the linear ramp function. - finetuned (`bool`, *optional*, defaults to `False`): - [Dynamic] Whether the model is finetuned or not. - For more details please refer to https://arxiv.org/abs/2309.00071. use_qkv_bias (`bool`, *optional*, defaults to `False`): Whether or not the model should use bias for qkv layers. qk_layernorm (`bool`, *optional*, defaults to `False`): @@ -142,14 +124,6 @@ def __init__( tie_word_embeddings=False, rope_theta=10_000, rope_scaling=None, - yarn_rope_scaling={ - "original_max_position_embeddings": 2048, - "extrapolation_factor": 1.0, - "attention_factor": 1.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "finetuned": False, - }, use_qkv_bias=False, qk_layernorm=False, use_parallel_residual=False, @@ -175,7 +149,6 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self.yarn_rope_scaling = yarn_rope_scaling self.use_qkv_bias = use_qkv_bias self.qk_layernorm = qk_layernorm self.use_parallel_residual = use_parallel_residual @@ -183,7 +156,6 @@ def __init__( self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() - self._yarn_rope_scaling_validation() super().__init__( bos_token_id=bos_token_id, @@ -192,7 +164,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. 
@@ -206,61 +177,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation - def _yarn_rope_scaling_validation(self): - """ - Validate the `yarn_rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: - raise ValueError( - "`yarn_rope_scaling` must be a dictionary with a maximum of six fields, `original_max_position_embeddings`, " - "`extrapolation_factor`, `attention_factor`, `beta_fast`, `beta_slow`, `finetuned`, " - f"got {self.rope_scaling}" - ) - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) - extrapolation_factor = self.rope_scaling.get("extrapolation_factor", None) - attention_factor = self.rope_scaling.get("attention_factor", None) - beta_fast = self.rope_scaling.get("beta_fast", None) - beta_slow = self.rope_scaling.get("beta_slow", None) - finetuned = self.rope_scaling.get("finetuned", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - f"`yarn_rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" - ) - if ( - extrapolation_factor is not None - and not isinstance(extrapolation_factor, float) - or extrapolation_factor < 0 - or extrapolation_factor > 1 - ): - raise ValueError( - f"`yarn_rope_scaling`'s extrapolation_factor field must be a float between 0 and 1, got {extrapolation_factor}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`yarn_rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`yarn_rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - if finetuned is not None and not isinstance(finetuned, bool): - raise ValueError(f"`yarn_rope_scaling`'s finetuned field must be a bool, got {finetuned}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`yarn_rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" - ) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 888ab1ac86958..264bc3e973944 100755 --- 
a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -156,164 +156,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.falcon.modeling_falcon.FalconYaRNScalingRotaryEmbedding with Falcon->StableLm -class StableLmYaRNScalingRotaryEmbedding(StableLmRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - device=None, - ): - super().__init__(dim, max_position_embeddings, base, device) - - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attention_factor = attention_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - self.yarn(device) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - # Get positional embeddings based on the current max sequence length - def get_pos_embeddings(self, device): - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - # Inverse dimension formula to find the dimension based on the number of rotations - def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - # Find dimension range bounds based on rotations - def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_mask(self, min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(self, scaling_factor=1): - if scaling_factor <= 1: - return 1.0 - return 0.1 * math.log(scaling_factor) + 1.0 - - def forward(self, x, seq_len=None): - # Difference to the original RoPE: applies a scaling factor computed with - # the YaRN method (NTK-by-Parts + Attn Scaling) - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - - low, high = self.find_correction_range( - self.beta_fast, 
self.beta_slow, self.dim, self.base, self.original_max_position_embeddings - ) - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - - self.register_buffer("inv_freq", inv_freq) - # Get n-dimensional magnitude scaling corrected for interpolation - self.mscale = float(self.get_mscale(self.scaling_factor) * self.attention_factor) - - -# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicYaRNScalingRotaryEmbedding with Falcon->StableLm -class StableLmDynamicYaRNScalingRotaryEmbedding(StableLmYaRNScalingRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attention_factor=1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__( - dim, - max_position_embeddings, - base, - scaling_factor, - original_max_position_embeddings, - extrapolation_factor, - attention_factor, - beta_fast, - beta_slow, - device, - ) - - if finetuned: - self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings - self.yarn(device) - else: - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - emb = self.get_pos_embeddings(device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(torch.get_default_dtype()) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(torch.get_default_dtype()) - - def forward(self, x, seq_len=None): - # Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded - # x: [batch_size, seq_len, head_dim] - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.scaling_factor = seq_len / self.original_max_position_embeddings - self.yarn(x.device) - emb = self.get_pos_embeddings(x.device) - - self._cos_cached = (emb.cos() * self.mscale)[:, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.mscale)[:, :].to(x.dtype) - return ( - self._cos_cached[:seq_len, ...].to(dtype=x.dtype), - self._sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -449,13 +291,6 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] - # YaRN parameters - original_max_position_embeddings = self.config.yarn_rope_scaling["original_max_position_embeddings"] - extrapolation_factor = self.config.yarn_rope_scaling["extrapolation_factor"] - attention_factor = self.config.yarn_rope_scaling["attention_factor"] - beta_fast = self.config.yarn_rope_scaling["beta_fast"] - beta_slow = self.config.yarn_rope_scaling["beta_slow"] - finetuned = self.config.yarn_rope_scaling["finetuned"] if scaling_type == "linear": self.rotary_emb = StableLmLinearScalingRotaryEmbedding( int(self.partial_rotary_factor * self.head_dim), @@ -470,31 +305,6 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) - elif scaling_type == "yarn": - self.rotary_emb = StableLmYaRNScalingRotaryEmbedding( - 
int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - ) - elif scaling_type == "dynamic-yarn": - self.rotary_emb = StableLmDynamicYaRNScalingRotaryEmbedding( - int(self.partial_rotary_factor * self.head_dim), - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - original_max_position_embeddings=original_max_position_embeddings, - extrapolation_factor=extrapolation_factor, - attention_factor=attention_factor, - beta_fast=beta_fast, - beta_slow=beta_slow, - finetuned=finetuned, - ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 275b2e646b850..50e8fcdbb4b0a 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -53,10 +53,8 @@ ) from transformers.models.falcon.modeling_falcon import ( FalconDynamicNTKScalingRotaryEmbedding, - FalconDynamicYaRNScalingRotaryEmbedding, FalconLinearScalingRotaryEmbedding, FalconRotaryEmbedding, - FalconYaRNScalingRotaryEmbedding, ) @@ -444,7 +442,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -511,40 +509,6 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = FalconYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = FalconDynamicYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) - dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) - dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) - dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) - torch.testing.assert_close(dynamic_yarn_cos_short, 
original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) - # TODO: @Fxmarty @is_flaky(max_attempts=3, description="flaky on some models.") @require_torch_sdpa diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ff2c36a0f0e70..ed5bcac55e45e 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -39,10 +39,8 @@ ) from transformers.models.gpt_neox.modeling_gpt_neox import ( GPTNeoXDynamicNTKScalingRotaryEmbedding, - GPTNeoXDynamicYaRNScalingRotaryEmbedding, GPTNeoXLinearScalingRotaryEmbedding, GPTNeoXRotaryEmbedding, - GPTNeoXYaRNScalingRotaryEmbedding, ) @@ -306,7 +304,7 @@ def test_model_for_token_classification(self): def test_feed_forward_chunking(self): pass - @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -330,7 +328,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -398,40 +396,6 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = GPTNeoXYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - ).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = GPTNeoXDynamicYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - ).to(torch_device) - dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) - dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) - dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) - dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) - 
torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) - @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 5ab3b54e697aa..6441e539d03d8 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -53,10 +53,10 @@ ) from transformers.models.llama.modeling_llama import ( LlamaDynamicNTKScalingRotaryEmbedding, - LlamaDynamicYaRNScalingRotaryEmbedding, + LlamaDynamicYarnScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding, - LlamaYaRNScalingRotaryEmbedding, + LlamaYarnScalingRotaryEmbedding, ) @@ -422,7 +422,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -493,8 +493,8 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = LlamaYaRNScalingRotaryEmbedding( + # Sanity check Yarn RoPE scaling + yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, @@ -504,13 +504,17 @@ def test_model_rope_scaling(self): yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_cos_long, original_cos_long) with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = LlamaDynamicYaRNScalingRotaryEmbedding( + # Sanity check Dynamic Yarn RoPE scaling + dynamic_yarn_scaling_rope = LlamaDynamicYarnScalingRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, @@ -518,8 +522,10 @@ def test_model_rope_scaling(self): ).to(torch_device) dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, position_ids_short) dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_cos_short, 
original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) with self.assertRaises(AssertionError): torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) with self.assertRaises(AssertionError): diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 3b76a4a93e9e3..ee87521c5ba09 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -322,7 +322,7 @@ def test_save_load_fast_init_from_base(self): def test_eager_matches_sdpa_generate(self): super().test_eager_matches_sdpa_generate() - @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) + @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -345,7 +345,7 @@ def test_model_rope_scaling(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 308ed2da8717c..518cb7e0379c2 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -48,10 +48,8 @@ ) from transformers.models.persimmon.modeling_persimmon import ( PersimmonDynamicNTKScalingRotaryEmbedding, - PersimmonDynamicYaRNScalingRotaryEmbedding, PersimmonLinearScalingRotaryEmbedding, PersimmonRotaryEmbedding, - PersimmonYaRNScalingRotaryEmbedding, ) @@ -391,7 +389,7 @@ def test_persimmon_token_classification_model(self): def test_save_load_fast_init_from_base(self): pass - @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -415,7 +413,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # maximum sequence length, so the outputs for the short input should match. 
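As the comment above notes, the plain `dynamic` strategy (and the removed Dynamic-YaRN variants) only rescales once an input exceeds the cached maximum length. The sketch below is illustrative rather than the retained Llama implementation; it captures the update rule from the deleted `forward()` methods, where the scaling factor grows with the sequence length relative to `original_max_position_embeddings` before `inv_freq` and the cos/sin cache are rebuilt.

def dynamic_yarn_update(seq_len, max_seq_len_cached, original_max_position_embeddings=2048):
    """Return (scaling_factor, rebuild_cache) for one forward pass, mirroring the removed logic."""
    if seq_len <= max_seq_len_cached:
        # Cached cos/sin still cover this input: nothing to recompute.
        return None, False
    # Longer than anything seen so far: grow the scaling factor with the input,
    # then recompute inv_freq / mscale via the YaRN blend and re-cache cos/sin.
    return seq_len / original_max_position_embeddings, True


# With a 2048-position cache, a 4096-token prompt doubles the scaling factor.
assert dynamic_yarn_update(4096, 2048) == (2.0, True)
assert dynamic_yarn_update(1024, 2048) == (None, False)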
- if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -483,40 +481,6 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = PersimmonYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = PersimmonDynamicYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) - dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) - dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) - dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) - torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) - @require_torch class PersimmonIntegrationTest(unittest.TestCase): diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index 0397b799f8a2e..f395b70c1ee2c 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -48,10 +48,8 @@ ) from transformers.models.phi.modeling_phi import ( PhiDynamicNTKScalingRotaryEmbedding, - PhiDynamicYaRNScalingRotaryEmbedding, PhiLinearScalingRotaryEmbedding, PhiRotaryEmbedding, - PhiYaRNScalingRotaryEmbedding, ) @@ -367,7 +365,7 @@ def test_phi_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -391,7 +389,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original # 
maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -459,40 +457,6 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = PhiYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = PhiDynamicYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) - dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) - dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) - dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) - torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) - @require_flash_attn @require_torch_gpu @require_bitsandbytes diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index ccc3fd3c82092..2e84612eca265 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -47,10 +47,8 @@ ) from transformers.models.stablelm.modeling_stablelm import ( StableLmDynamicNTKScalingRotaryEmbedding, - StableLmDynamicYaRNScalingRotaryEmbedding, StableLmLinearScalingRotaryEmbedding, StableLmRotaryEmbedding, - StableLmYaRNScalingRotaryEmbedding, ) @@ -377,7 +375,7 @@ def test_stablelm_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)]) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -401,7 +399,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Dynamic scaling does not change the RoPE embeddings until it receives 
an input longer than the original # maximum sequence length, so the outputs for the short input should match. - if scaling_type in ("dynamic", "dynamic-yarn"): + if scaling_type == "dynamic": self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) @@ -469,40 +467,6 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - # Sanity check YaRN RoPE scaling - yarn_scaling_rope = StableLmYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, short_input_length) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, long_input_length) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - # Sanity check Dynamic YaRN RoPE scaling - dynamic_yarn_scaling_rope = StableLmDynamicYaRNScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, short_input_length) - dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, long_input_length) - dynamic_yarn_cos_short = dynamic_yarn_cos_short.squeeze(0) - dynamic_yarn_sin_short = dynamic_yarn_sin_short.squeeze(0) - torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short) - torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long) - @require_torch class StableLmModelIntegrationTest(unittest.TestCase):
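After this follow-up, the duplicated YaRN classes are gone from Falcon, GPT-NeoX, Olmo, Persimmon, Phi and StableLm, while the Llama tests above keep exercising the renamed `LlamaYarnScalingRotaryEmbedding` / `LlamaDynamicYarnScalingRotaryEmbedding`. Assuming the Llama config still exposes the `rope_scaling` and `yarn_rope_scaling` fields introduced in the first commit, enabling YaRN would look roughly like the sketch below; the values are placeholders, not a tested checkpoint configuration.

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    max_position_embeddings=8192,                  # extended context window
    rope_scaling={"type": "yarn", "factor": 4.0},  # or "dynamic-yarn"
    yarn_rope_scaling={
        "original_max_position_embeddings": 2048,
        "extrapolation_factor": 1.0,
        "attention_factor": 1.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "finetuned": False,
    },
)
model = LlamaForCausalLM(config)  # randomly initialized; load real weights for inference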