Add RoPE Interpolation (#3564)
Adds support for RoPE interpolation via the SuperHOT method and its variants, as proposed in:

- [reddit](https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/)
- [scaled-rope](https://github.com/jquesnelle/scaled-rope/tree/master/scaled_rope)

Supported methods
- Linear scaling
- NTK aware scaling
- Dynamic NTK

Supported models
- LLaMA
- Falcon

The mechanism can be extended and experimented with by configuring two
parameters, `superhot` and `superhot_config`.
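
The underlying idea, sketched in plain torch below (illustrative only; the committed implementations live in `model/model_training/models/rope.py`): linear scaling compresses the position index by `1/scale`, NTK-aware scaling instead stretches the RoPE base by `alpha ** (dim / (dim - 2))`, and dynamic NTK recomputes that base from the current sequence length once it exceeds the trained context. The helper names below are illustrative, not from the commit.

```python
import torch

def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies for one attention head of size `dim`.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def linear_scaled_positions(seq_len: int, scale: float = 2.0) -> torch.Tensor:
    # Linear (SuperHOT) interpolation: positions are compressed by 1/scale so a
    # longer sequence maps back into the range the model was trained on.
    return torch.arange(seq_len).float() / scale

def ntk_scaled_base(dim: int, base: float = 10000.0, alpha: float = 2.0) -> float:
    # NTK-aware scaling: positions are left alone and the base is stretched
    # instead, which mostly slows down the low-frequency channels.
    return base * alpha ** (dim / (dim - 2))

# Example: head_dim = 128 (LLaMA-7B), extending the context from 2048 to 4096.
dim, seq_len = 128, 4096
freqs_linear = torch.outer(linear_scaled_positions(seq_len, scale=2.0), rope_inv_freq(dim))
freqs_ntk = torch.outer(torch.arange(seq_len).float(),
                        rope_inv_freq(dim, base=ntk_scaled_base(dim, alpha=2.0)))
print(freqs_linear.shape, freqs_ntk.shape)  # both torch.Size([4096, 64])
```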
shahules786 committed Jul 12, 2023
1 parent ed089f6 commit 018657b
Showing 4 changed files with 274 additions and 3 deletions.
29 changes: 29 additions & 0 deletions model/model_training/configs/config.yaml
@@ -779,3 +779,32 @@ debug:
verbose: true
num_train_epochs: 0.2
dtype: fp32

rope_scaling_test:
dtype: bf16
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
deepspeed_config: configs/zero_config_falcon.json
output_dir: llama
weight_decay: 0.0
max_length: 4048
warmup_steps: 100
gradient_checkpointing: true
gradient_accumulation_steps: 2
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
eval_steps: 100
save_steps: 500
num_train_epochs: 8
save_total_limit: 4
use_flash_attention: false
residual_dropout: 0.3
residual_dropout_lima: true
log_wandb: true
peft_model: true
peft_type: "lora"
superhot: true
superhot_config:
type: linear
scale: 2
54 changes: 53 additions & 1 deletion model/model_training/models/patching.py
@@ -6,12 +6,13 @@

import torch.nn as nn
import transformers
from transformers import GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope

SUPPORTED_MODELS = [
GPTNeoXModel,
@@ -176,3 +177,54 @@ def patch_model(
if resid_pdrop is not None and resid_pdrop > 0:
add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop)
add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)


class RopePatch:
def __init__(self, model_name, **kwargs):
self.args = kwargs
rope_type = self.args.pop("type")
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "RWForCausalLM" in architecture:
self.architecture = "RWForCausalLM"
if rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
raise NotImplementedError()
elif "LlamaForCausalLM" in architecture:
self.architecture = "LlamaForCausalLM"
if rope_type == "linear":
self.patch_fun = LlamaLinearScaledRope
elif rope_type == "ntk":
self.patch_fun = LlamaNTKScaledRope
elif rope_type == "dynamic-ntk":
self.patch_fun = LlamaDynamicScaledRotaryEmbedding
else:
raise NotImplementedError()
else:
raise NotImplementedError()

@classmethod
def from_config(cls, config):
model_name = config.model_name
args = config.superhot_config
return cls(model_name, **args)

def patch(self, model):
if self.architecture == "RWForCausalLM":
self.patch_rw_model(model, **self.args)
elif self.architecture == "LlamaForCausalLM":
self.patch_llama_model(model, **self.args)
else:
raise NotImplementedError()

def patch_rw_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)

def patch_llama_model(self, model, **kwargs):
kwargs.update({"device": model.device})
for each in model.model.layers:
each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)
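
For reference, a minimal usage sketch of the new `RopePatch` helper. The `SimpleNamespace` stand-in and its values are illustrative; in training, the same object is built from the `superhot` / `superhot_config` keys in config.yaml and applied in the trainer, as shown in the trainer_sft.py diff below.

```python
from types import SimpleNamespace

from model_training.models.patching import RopePatch

# Illustrative config object. RopePatch pops "type" from superhot_config and
# forwards the remaining keys to the rotary-embedding class it selects.
training_conf = SimpleNamespace(
    model_name="huggyllama/llama-7b",
    superhot=True,
    superhot_config={"type": "ntk", "alpha": 2},
)

superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
# After the model has been loaded:
# if superhot:
#     superhot.patch(model)
```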
187 changes: 187 additions & 0 deletions model/model_training/models/rope.py
@@ -0,0 +1,187 @@
import torch


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0


class RWNTKScaledRope(torch.nn.Module):

"""
NTK-Scaled RoPE for RefinedWebModel
"""

def __init__(
self,
head_dim: int,
base=10000,
alpha: int = 2,
):
super().__init__()
self.alpha = alpha
base = base * self.alpha ** (head_dim / (head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = None
self.batch_size_cached = None
self.cos_cached: torch.Tensor | None = None
self.sin_cached: torch.Tensor | None = None

def cos_sin(
self,
seq_len: int,
device="cuda",
dtype=torch.bfloat16,
) -> torch.Tensor:
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]

self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)

return self.cos_cached, self.sin_cached

def forward(self, q, k):
batch, seq_len, head_dim = q.shape
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class LlamaLinearScaledRope(torch.nn.Module):
"""
reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
super().__init__()
self.scale = 1 / scale
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaNTKScaledRope(torch.nn.Module):

"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
super().__init__()
base = base * alpha ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
super().__init__()
self.ntk = ntk
self.base = base
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
if self.ntk:
base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
self.dim / (self.dim - 2)
)
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
if not self.ntk:
t *= self.max_position_embeddings / seq_len
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
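
For reference, a quick self-contained check (plain torch, not part of the commit) of the NTK base adjustment `base * alpha ** (dim / (dim - 2))` used by `RWNTKScaledRope` and `LlamaNTKScaledRope` above:

```python
import torch

dim, base, alpha = 128, 10000.0, 2.0          # LLaMA-7B head_dim is 128
ntk_base = base * alpha ** (dim / (dim - 2))  # adjusted base, ~20221
print(round(ntk_base))

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq_ntk = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
# The highest-frequency channel is untouched; the lowest is slowed by 1/alpha.
print((inv_freq_ntk / inv_freq)[0].item())    # ≈ 1.0
print((inv_freq_ntk / inv_freq)[-1].item())   # ≈ 0.5 (= 1/alpha)
```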
7 changes: 5 additions & 2 deletions model/model_training/trainer_sft.py
@@ -11,6 +11,7 @@
# from model_training.custom_datasets.formatting import DatasetEntry
from model_training.custom_datasets.dialogue_collator import DialogueDataCollator
from model_training.efficiency_utils import fuse_gelu
from model_training.models.patching import RopePatch
from model_training.models.peft_modeling import peft_model
from model_training.utils.utils import (
PerDatasetSampler,
@@ -362,7 +363,6 @@ def main():
)

train, evals = get_dataset(training_conf)

show_dataset_stats = (training_conf.verbose or training_conf.show_dataset_stats) and (
not training_conf.deepspeed or training_conf.local_rank == 0
)
@@ -416,9 +416,12 @@
sampler = None

metrics, preprocess_fns = get_metrics(training_conf, tokenizer)

model = get_model(training_conf, tokenizer)

superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
if superhot:
superhot.patch(model)

if training_conf.peft_model:
print("Using PEFT model")
model = peft_model(
