Add support for GPTNeoX models #32

Merged: 17 commits, Oct 3, 2023
23 changes: 18 additions & 5 deletions fine-tune.py
@@ -24,6 +24,7 @@
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from llama_attn_replace import replace_llama_attn
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier

@@ -39,7 +40,8 @@

@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
model_type: Optional[str] = field(default="gpt-neox")

@dataclass
class TrainingArguments(transformers.TrainingArguments):
@@ -99,7 +101,11 @@ def train():
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()

replace_llama_attn(training_args.use_flash_attn)
# NOTE: May expand supported model types in the future
if model_args.model_type == "gpt-neox":
replace_gpt_neox_attn(training_args.use_flash_attn)
else:
replace_llama_attn(training_args.use_flash_attn)

# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
@@ -117,14 +123,15 @@ def train():
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
use_fast=True,
)

special_tokens_dict = dict()
@@ -157,10 +164,16 @@ def train():
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

if training_args.low_rank_training:
if model_args.model_type == "gpt-neox":
# include `dense` as well so the LoRA targets cover the same modules as the llama setup; targeting only 'query_key_value' would miss the attention output projection
targets = ["query_key_value", "dense"]
else:
targets=["q_proj", "k_proj", "v_proj", "o_proj"],

config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
target_modules=targets,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
@@ -169,9 +182,9 @@ def train():
# enable trainable params
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]

model.config.use_cache = False # required for gradient checkpointing
model.enable_input_require_grads() # required for gradient checkpointing
model.gradient_checkpointing_enable() # enable gradient checkpointing

trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args,
train_dataset=dataset["train"],
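Note: for a GPT-NeoX model, the LoRA branch above resolves to the configuration sketched below. This is a minimal reference sketch, not part of the diff; the hyperparameters mirror the values in `fine-tune.py`, and `model` is assumed to be an already loaded `GPTNeoXForCausalLM`.

```python
from peft import LoraConfig, get_peft_model

# GPT-NeoX fuses Q/K/V into a single `query_key_value` projection and uses
# `dense` as the attention output projection, so the llama-style
# ["q_proj", "k_proj", "v_proj", "o_proj"] targets do not exist here.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query_key_value", "dense"],
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # `model` is assumed to be loaded beforehand
```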
166 changes: 166 additions & 0 deletions gptneox_attn_replace.py
@@ -0,0 +1,166 @@
# Adapted from https://github.com/dvlab-research/LongLoRA

from typing import Optional, Tuple
import warnings
import torch
import transformers

from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


group_size_ratio = 1/4

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def _flash_attn(query, key, value, attention_mask=None, head_mask=None):
# transform the data into the qkv packed form
qkv = torch.stack(
[query, key, value], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
bsz, q_len = qkv.shape[:2]

qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)

# disable attn weights by returning None when using flash attention
return output, None


def get_forward_function(use_flash_attn=True, use_full=False):

def forward_attention(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
# NOTE: q_len below determines the S²-Attn (shifted sparse attention) group size
bsz, q_len, _ = hidden_states.size()
has_layer_past = layer_past is not None

# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)

# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)

# [batch, seq_len, num_attention_heads, 3 * head_size]
# --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# [bsz, nh, q_len, hd]

# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]

# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None

# NOTE: apply shift
group_size = int(q_len * group_size_ratio)
if q_len % group_size > 0:
raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
num_group = q_len // group_size
if self.training and not use_full:
def shift(qkv, num_heads, head_dim):
# qkv = [bsz, nh, q_len, d]
qkv = qkv.transpose(1, 2)
# qkv = [bsz, q_len, nh, d]
qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)

# -> [bsz * n_group, group_s, nh, d]
# -> [bsz * n_group, nh, group_s, d]
qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
return qkv

# contiguous is required as self._attn() will attempt to apply .view() on them
query = shift(query, self.num_attention_heads, self.head_size).contiguous()
key = shift(key, self.num_attention_heads, self.head_size).contiguous()
value = shift(value, self.num_attention_heads, self.head_size).contiguous()

attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)

# Compute attention
if use_flash_attn:
attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

# NOTE: shift back
if self.training and not use_full:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
# [bsz, q_len, nh, hd]
attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.transpose(1, 2)

# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs

return forward_attention


def replace_gpt_neox_attn(use_flash_attn=True, use_full=False):
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if use_flash_attn and cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
"Resorting to plain attention..."
)
use_flash_attn = False

forward_fn = get_forward_function(use_flash_attn, use_full)
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn
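To make the shift step in `forward_attention` easier to follow, here is a small self-contained sketch with toy shapes (chosen only to make the bookkeeping visible, not the real model dimensions) showing how the second half of the heads is rolled by half a group and the sequence is then split into groups:

```python
import torch

bsz, num_heads, q_len, head_dim = 1, 4, 8, 2
group_size = q_len // 4          # group_size_ratio = 1/4 -> group_size = 2
num_group = q_len // group_size  # 4

qkv = torch.randn(bsz, num_heads, q_len, head_dim)

# same operations as shift() above
x = qkv.transpose(1, 2)  # [bsz, q_len, nh, hd]
# roll the second half of the heads by half a group along the sequence axis
x[:, :, num_heads // 2:] = x[:, :, num_heads // 2:].roll(-group_size // 2, dims=1)
# fold the groups into the batch dimension
x = x.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
print(x.shape)  # torch.Size([4, 4, 2, 2]) == [bsz * num_group, nh, group_size, hd]
```

The attention kernel then treats each of the `bsz * num_group` chunks as an independent sequence of length `group_size`, which is what keeps the shifted attention cost linear in the original sequence length.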
22 changes: 18 additions & 4 deletions supervised-fine-tune.py
@@ -27,6 +27,7 @@
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from llama_attn_replace import replace_llama_attn
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier

@@ -64,7 +65,8 @@ def jload(f, mode="r"):

@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
model_type: Optional[str] = field(default="gpt-neox")


@dataclass
@@ -219,7 +221,11 @@ def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

replace_llama_attn(training_args.use_flash_attn, True)
# NOTE: May expand supported model types in the future
if model_args.model_type == "gpt-neox":
replace_gpt_neox_attn(training_args.use_flash_attn)
else:
replace_llama_attn(training_args.use_flash_attn)

# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
@@ -237,14 +243,15 @@ def train():
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
use_fast=True,
)

special_tokens_dict = dict()
@@ -266,10 +273,16 @@ def train():
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

if training_args.low_rank_training:
if model_args.model_type == "gpt-neox":
# include `dense` as well so the LoRA targets cover the same modules as the llama setup; targeting only 'query_key_value' would miss the attention output projection
targets = ["query_key_value", "dense"]
else:
targets=["q_proj", "k_proj", "v_proj", "o_proj"],

config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
target_modules=targets,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
@@ -278,6 +291,7 @@ def train():
# enable trainable params
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]

model.config.use_cache = False # required for gradient checkpointing
model.enable_input_require_grads() # required for gradient checkpointing
model.gradient_checkpointing_enable() # enable gradient checkpointing

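Putting the pieces together, a minimal sketch of how the new code path is wired up for a GPT-NeoX checkpoint (the model name and `model_max_length` below are illustrative; this is not a substitute for running `fine-tune.py`):

```python
import torch
import transformers
from gptneox_attn_replace import replace_gpt_neox_attn

# Patch GPTNeoXAttention.forward with the shifted, flash-attention variant
# before the model is instantiated.
replace_gpt_neox_attn(use_flash_attn=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped",
    torch_dtype=torch.bfloat16,   # matches the dtype added in this PR
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped",
    model_max_length=8192,        # illustrative target context length
    padding_side="right",
    use_fast=True,                # GPT-NeoX only ships a fast tokenizer
)
```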