Add support for flash attention (#2033)
Along with some miscellaneous changes, this PR adds support for patching
FlashAttention [1] into GPTNeoX models. This not only makes training
more efficient, but also allows for almost arbitrarily long context sizes.

A few caveats apply:
- FlashAttention requires inputs to be in `fp16` and on CUDA (see the
usage sketch after this list).
- FlashAttention outputs slightly different results compared to the
reference GPTNeoX implementation. For a single layer, the outputs are
almost the same (less than `1e-3` absolute difference), but the error
can accumulate during the forward pass. If and how this affects
generated samples will have to be evaluated once a patched model has
been trained. In the worst case, FlashAttention will also have to be used
during inference.
- As an alternative memory-efficient attention implementation, we try
out xFormers [2].
- Non-contiguous attention masks are not supported by FlashAttention.
This is not an issue as long as attention masking is only used to mask
padding (which is usually the case).
- FlashAttention does not return attention scores, so even if
`output_attentions=True` is passed to the model, the attention scores
will just be `None`.
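
A minimal usage sketch of the new `patch_model` helper (not part of this commit's diff; the import path follows the test file, and the checkpoint name and prompt are only illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from patching import patch_model  # model/model_training/models/patching.py

# Illustrative checkpoint; any GPTNeoX-based model should work.
model_name = "EleutherAI/pythia-1b-deduped"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# FlashAttention needs fp16/bf16 inputs on a CUDA device.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

# Replace the GPTNeoX attention with FlashAttention and add residual dropout.
patch_model(model, resid_pdrop=0.1, flash_attention=True)

batch = tokenizer("Hello world", return_tensors="pt").to("cuda")
logits = model(**batch).logits  # attention scores are not returned by FlashAttention
```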

Misc. changes include:
- Exporting script for uploading models to HF (authored by
@andreaskoepf, committed by me).
- Fixed a bug where only parts of the config were logged to WandB.
- Fixed a bug where WandB logging was activated even when disabled in
the config.
- Updated the training config.

[1] https://github.com/HazyResearch/flash-attention
[2] https://github.com/facebookresearch/xformers
dvruette committed Mar 12, 2023
1 parent 83f7f13 commit 2862c28
Showing 7 changed files with 264 additions and 68 deletions.
32 changes: 22 additions & 10 deletions model/model_training/configs/config.yaml
@@ -10,13 +10,13 @@ defaults:
weight_decay: 0.00
warmup_steps: 600
eval_steps: 200
- save_steps: 500
+ save_steps: 1000
max_length: 512
num_train_epochs: 3
logging_steps: 10
max_grad_norm: 2.0
save_total_limit: 4
- fp16: false
+ fp16: true
eval_accumulation_steps:
freeze_layer:
datasets:
@@ -55,9 +55,10 @@ defaults:
verbose: false
output_dir: saved_model
use_custom_sampler: false
- random_offset_probability: 0.5 # probability for random message offsets
+ random_offset_probability: 0.8 # probability for random message offsets
label_masking: true
- residual_dropout: 0.0
+ residual_dropout: 0.2
+ use_flash_attention: true
sort_by_length: false

oa_dataset_only:
@@ -94,20 +94,31 @@ pythia-1B:
model_name: EleutherAI/pythia-1b-deduped
weight_decay: 0.0
max_length: 520
- warmup_steps: 1000
+ warmup_steps: 10
gradient_checkpointing: false
- gradient_accumulation_steps: 2
- per_device_train_batch_size: 16
- per_device_eval_batch_size: 32
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 4
+ per_device_eval_batch_size: 16

pythia-6.9B:
learning_rate: 8e-6
model_name: EleutherAI/pythia-6.9b-deduped
weight_decay: 0.0
- max_length: 520
- warmup_steps: 5
+ max_length: 2048
+ warmup_steps: 20
gradient_checkpointing: false
gradient_accumulation_steps: 2
per_device_train_batch_size: 4
per_device_eval_batch_size: 4

+ pythia-12B:
+ learning_rate: 6e-6
+ model_name: EleutherAI/pythia-12b-deduped
+ weight_decay: 0.0
+ max_length: 2048
+ warmup_steps: 20
+ gradient_checkpointing: false
+ gradient_accumulation_steps: 4
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 2

7 changes: 5 additions & 2 deletions model/model_training/configs/zero_config.json
@@ -26,14 +26,17 @@
}
},
"zero_optimization": {
"stage": 1,
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
46 changes: 0 additions & 46 deletions model/model_training/models/patch_resid_dropout.py

This file was deleted.

122 changes: 122 additions & 0 deletions model/model_training/models/patching.py
@@ -0,0 +1,122 @@
import warnings
from functools import partial
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn.modules.mha import FlashSelfAttention
from torch.nn.utils.rnn import pad_sequence
from transformers import GPTNeoXForCausalLM, GPTNeoXModel

SUPPORTED_MODELS = [
GPTNeoXModel,
GPTNeoXForCausalLM,
]


def _patched_mlp_forward(post_module: nn.Module, module: nn.Module, *args, **kwargs):
post_module.train(module.training)
out = module.old_forward(*args, **kwargs)
out = post_module(out)
return out


def _patched_attn_forward(post_module: nn.Module, module: nn.Module, *args, **kwargs):
post_module.train(module.training)
out = module.old_forward(*args, **kwargs)
hiddens = post_module(out[0])
return (hiddens,) + out[1:]


def _patched_gpt_neox_attn(
module: nn.Module,
flash_attn: FlashSelfAttention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask=None,
head_mask=None,
):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
flash_attn.train(module.training)
out_dtype = value.dtype
batch_size, max_len = query.size(0), query.size(2)

q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
qkv = torch.stack([q, k, v], dim=2).to(torch.float16) # need to truncate here since rotary embeddings are fp32
cu_seqlens, max_seqlen = None, None

if attention_mask is not None:
# Limitation: attention mask can have "holes", which is currently not handled correctly
# model will be able to pay attention up to the last non-masked token, even if previous tokens are masked.
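# The additive mask is 0.0 at positions to attend to and a large negative value at padded
# positions, so the index of the last zero entry (plus one) gives each sequence's effective length.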
seqlens = (attention_mask[:, 0, 0, :] == 0).cumsum(dim=1).argmax(dim=1) + 1
qkv = torch.cat([qkv[i, : seqlens[i]] for i in range(batch_size)], dim=0)
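# cu_seqlens holds the cumulative sequence lengths (prefix sums starting at 0) that the
# variable-length FlashAttention kernel uses to locate each packed sequence in qkv.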
cu_seqlens = torch.cat([torch.zeros_like(seqlens[:1]), seqlens.cumsum(dim=0)], dim=0).to(torch.int32)
max_seqlen = seqlens.max().item()

out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
# out: [bs, seq_len, num_attention_heads, attn_head_size]

if attention_mask is not None:
seqs = [out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
out = pad_sequence(seqs, batch_first=True)
# restore original sequence length
out = F.pad(out, (0, 0) * (out.dim() - 2) + (0, max_len - out.size(1)), value=0.0)
out = out.transpose(1, 2).to(out_dtype)
return out, None


def add_dropout(module: nn.Module, patched_fwd: Callable, p_dropout: float = 0.1):
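# Keep a reference to the original forward and monkey-patch it with a wrapper that applies
# a freshly created dropout module to the module's output.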
dropout = nn.Dropout(p=p_dropout)
module.old_forward = module.forward
module.forward = partial(patched_fwd, dropout, module)


def add_flash_attn(module: nn.Module, causal: bool = True):
"""
Replaces the standard attention implementation with Flash Attention [1].
Limitations:
- Only works for fp16 or bf16 inputs
- Requires inputs to be on CUDA
- `output_attentions=True` has no effect after patching: attention weights will always be None
- Non-contiguous attention masks are not supported (e.g. [1, 1, 0, 1, 1, 0, 0] will just become [1, 1, 1, 1, 1, 0, 0]).
[1] https://github.com/HazyResearch/flash-attention
"""

if not hasattr(module, "_attn"):
warnings.warn("Provided module doesn't have a _attn() function to be patched.")
flash_attn = FlashSelfAttention(causal=causal)
module._attn = partial(_patched_gpt_neox_attn, module, flash_attn)


def patch_model(model: GPTNeoXModel, resid_pdrop: Optional[float] = 0.1, flash_attention: bool = True):
"""
Helper function for patching HF language models.
Currently supports: GPTNeoX-based models
Limitations:
- Flash attention requires CUDA and fp16/bf16 training. It also requires contiguous attention masks.
- Residual dropout does not support multi-GPU training without DeepSpeed.
"""

if resid_pdrop is not None and (resid_pdrop < 0 or resid_pdrop > 1.0):
raise ValueError("Invalid argument: `resid_pdrop` must be between 0.0 and 1.0")

if not any(isinstance(model, model_class) for model_class in SUPPORTED_MODELS):
warnings.warn(
"Patching residual dropout has only been tested with this model class. "
f"Please make sure that it also works for `{model.__class__.__name__}`."
)

if isinstance(model, GPTNeoXForCausalLM):
model = model.gpt_neox

for layer in model.layers:
if resid_pdrop is not None:
add_dropout(layer.attention, _patched_attn_forward, resid_pdrop)
add_dropout(layer.mlp, _patched_mlp_forward, resid_pdrop)

if flash_attention:
add_flash_attn(layer.attention, causal=True)
48 changes: 42 additions & 6 deletions model/model_training/models/test_patched_gpt_neox.py
@@ -1,9 +1,44 @@
import torch
- from patch_resid_dropout import patch_model
- from transformers import GPTNeoXModel
+ from patching import patch_model
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPTNeoXModel


- def main():
+ def test_flash_attention_patch(dtype=torch.float16, device="cuda"):
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
tokenizer.add_special_tokens({"pad_token": "<pad>"})

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped", torch_dtype=dtype).to(device)
patched_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped", torch_dtype=dtype).to(device)
patch_model(patched_model, resid_pdrop=None, flash_attention=True)

device = model.device
n_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // n_heads

with torch.no_grad():
for layer1, layer2 in zip(model.gpt_neox.layers, patched_model.gpt_neox.layers):
q = torch.randn(4, n_heads, 10, head_dim, dtype=dtype, device=device)
k = torch.randn(4, n_heads, 10, head_dim, dtype=dtype, device=device)
v = torch.randn(4, n_heads, 10, head_dim, dtype=dtype, device=device)
attn1, attn2 = layer1.attention, layer2.attention

out1, _ = attn1._attn(q, k, v)
out2, _ = attn2._attn(q, k, v)

assert ((out1 - out2).abs() < 1e-2).all()

batch = tokenizer(["hello world", "lorem ipsum dolor sit amet"], padding=True, return_tensors="pt").to(device)
out1 = model(**batch).logits
out2 = patched_model(**batch).logits

diff = (out1 - out2) * batch["attention_mask"].unsqueeze(-1)
assert (diff.abs() < 1).all()

input_ids = torch.randint(0, model.config.vocab_size, size=(2, 10), device=device)
patched_model(input_ids).logits.mean().backward()


def test_resid_dropout_patch():
model = GPTNeoXModel.from_pretrained("EleutherAI/pythia-70m-deduped")
model.eval()

@@ -13,7 +13,7 @@ def main():

logits_before = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

- patch_model(model, resid_pdrop=0.2)
+ patch_model(model, resid_pdrop=0.2, flash_attention=False)

logits_after = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

@@ -39,7 +39,7 @@ def main():
assert (y1 - y2).abs().sum() > 1e-5, "mlp output is the same for different forward passes"

model = GPTNeoXModel.from_pretrained("EleutherAI/pythia-70m-deduped")
- patch_model(model, resid_pdrop=0.0)
+ patch_model(model, resid_pdrop=0.0, flash_attention=False)

with torch.no_grad():
logits1 = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
@@ -54,4 +54,5 @@


if __name__ == "__main__":
- main()
+ test_flash_attention_patch()
+ test_resid_dropout_patch()
62 changes: 62 additions & 0 deletions model/model_training/tools/export.py
@@ -0,0 +1,62 @@
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, help="checkpoint path or model name")
parser.add_argument("--dtype", type=str, default="float16", help="float16, float32")
parser.add_argument("--hf_repo_name", type=str, help="Huggingface repository name (used with --push)")
parser.add_argument("--auth_token", type=str, help="")
parser.add_argument("--output_folder", type=str, help="output folder path (used with)")
parser.add_argument("--max_shard_size", type=str, default="10GB")
parser.add_argument("--cache_dir", type=str)
return parser.parse_args()


def main():
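# Example invocation (checkpoint path, repo name, and token are placeholders):
#   python export.py <checkpoint_dir> --dtype float16 --hf_repo_name <user>/<repo> --auth_token <token>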
args = parse_args()

if args.dtype in ("float16", "fp16"):
torch_dtype = torch.float16
elif args.dtype in ("float32", "fp32"):
torch_dtype = torch.float32
else:
print(f"Unsupported dtpye: {args.dtype}")
sys.exit(1)

if not args.hf_repo_name and not args.output_folder:
print(
"Please specify either `--hf_repo_name` to push to HF or `--output_folder` "
"to export the model to a local folder."
)
sys.exit(1)

print(f"Loading tokenizer '{args.model_name}' ...")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
print(f"{type(tokenizer).__name__} (vocab_size={len(tokenizer)})")

print(f"Loading model '{args.model_name}' ({args.dtype}) ...")
model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch_dtype, cache_dir=args.cache_dir)
print(f"{type(model).__name__} (num_parameters={model.num_parameters()})")

if args.output_folder:
print(f"Saving model to: {args.output_folder}")
model.save_pretrained(args.output_folder, max_shard_size=args.max_shard_size)

print(f"Saving tokenizer to: {args.output_folder}")
tokenizer.save_pretrained(args.output_folder)

if args.hf_repo_name:
print("Uploading model to HF...")
model.push_to_hub(args.hf_repo_name, use_auth_token=args.auth_token, max_shard_size=args.max_shard_size)

print("Uploading tokenizer to HF...")
tokenizer.push_to_hub(args.hf_repo_name, use_auth_token=args.auth_token)


if __name__ == "__main__":
main()
