Starcoder has higher eval loss with flash attention 2 #28925

Closed
2 of 4 tasks
lidingsnyk opened this issue Feb 8, 2024 · 2 comments

Comments

lidingsnyk commented Feb 8, 2024

System Info

transformers version: 4.36.2
flash-attn: 2.5.2 flash_attn-2.5.2+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64
Platform: linux_x86_64 cp310 ubuntu-22.04
Python version: 3.10
Huggingface_hub version: 0.20.3
Safetensors version: 0.4.2
Accelerate version: 0.26.1
Accelerate config: not found
PyTorch version (GPU?): 2.1.2 (True) torch-2.1.2-cu118-cp310-cp310-linux_x86_64.whl
Tensorflow version (GPU?): not installed
Flax version (CPU?/GPU?/TPU?): not installed
Jax version: not installed
JaxLib version: not installed
Using GPU in script?: yes. A100
CUDA_VERSION: 11.8.0
Using distributed or parallel set-up in script?: yes (deepspeed 0.11.2)

Who can help?

There is a similar GitHub issue, but I also have additional observations around inference.

After GPTBigCode added support for flash attention 2 in transformers 4.36, I ran inference with flash attention 2 enabled on a fine-tuned starcoderbase-3b that had been trained with 4.35. The output-label exact-match metric dropped significantly, with some slices as low as 0%. Upon inspection, many outputs simply repeat a single token, suggesting a bug in the attention mechanism.

I then tried fine-tuning a new model with transformers 4.36 and flash attention 2 enabled. While exact match is now somewhat higher, all metrics still drop significantly compared with the previous model trained without flash attention 2. For instance, eval_loss increased from 0.53 to 0.75.

However, the final training loss is similar, at around 0.07. Fine-tuning with flash attention 2 is also very unstable: with a different batch_size, training loss lands at 0.28.

Enabling or disabling padding (batch_size=1, pad_to_multiple_of=None) in the trainer makes no meaningful difference in the metrics; a sketch of the two collator configurations is below.
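
For completeness, a minimal sketch of the two collator configurations compared above (pad_to_multiple_of=8 is only an illustrative value, and tokenizer is the same tokenizer passed to the trainer):

import transformers

# Sketch: toggling padding in the LM collator; neither variant changed the metrics.
collator_padded = transformers.DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="pt"
)
collator_unpadded = transformers.DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, pad_to_multiple_of=None, return_tensors="pt"
)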

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The model is loaded the same way for training and inference; the only difference is that inference loads the fine-tuned StarCoder checkpoint.

from typing import cast

import torch
import transformers
from transformers import AutoModelForCausalLM


class CustomTrainer(transformers.Trainer):
    def __init__(
        self, model, tokenizer, args, train_ds, val_ds,
    ):
        # ModelType is our own helper; causal LMs get the standard LM collator.
        model_type = ModelType.infer_from_model(model)
        data_collator = None
        if model_type == ModelType.CAUSAL:
            data_collator = transformers.DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,
                return_tensors="pt",
            )

        super().__init__(
            model=model,
            train_dataset=cast(torch.utils.data.dataset.Dataset, train_ds),
            eval_dataset=cast(torch.utils.data.dataset.Dataset, val_ds),
            tokenizer=tokenizer,
            args=args.training_args,
            data_collator=data_collator,
        )


model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoderbase-3b",
    trust_remote_code=True,
    use_flash_attention_2=True,
    torch_dtype=torch.bfloat16,
)
trainer = CustomTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_ds=train_ds,
    val_ds=val_ds,
)
trainer.train()
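
For reference, the same load can also be expressed with the attn_implementation argument introduced in transformers 4.36 (use_flash_attention_2 is the older flag); this is only a sketch of the loading call, not the full script:

import torch
from transformers import AutoModelForCausalLM

# Sketch: equivalent load using attn_implementation instead of use_flash_attention_2.
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoderbase-3b",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)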

Some important training args:
learning_rate: 1e-5
gradient_accumulation_steps: 16
bf16: "True"
torch_compile_mode: max-autotune

inference args:
beam_size: 5
tokenizer_max_length: 512
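
Roughly how these args are used at inference time (a sketch only; prompt is a placeholder, and I'm assuming beam_size maps to num_beams and tokenizer_max_length to the tokenizer's max_length):

# Sketch: beam-search generation with the inference args above.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
outputs = model.generate(**inputs, num_beams=5, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))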

Expected behavior

For training, loss should not go up compared with use_flash_attention_2=False.

For inference, a fine-tuned model (regardless of how it was trained) should produce the same, or nearly the same, results whether or not flash attention 2 is enabled.
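
A minimal way to check this (the checkpoint path is a placeholder): load the same fine-tuned model with eager attention and with flash attention 2 and compare logits on one input; the difference should be within bf16 noise.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: compare eager vs. flash-attention-2 logits on identical inputs.
ckpt = "path/to/finetuned-starcoderbase-3b"  # placeholder checkpoint
tok = AutoTokenizer.from_pretrained(ckpt)
inputs = tok("def hello_world():", return_tensors="pt").to("cuda")

logits = {}
for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        ckpt, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).to("cuda")
    with torch.no_grad():
        logits[impl] = model(**inputs).logits

print((logits["eager"] - logits["flash_attention_2"]).abs().max())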

@lidingsnyk lidingsnyk changed the title Starcoder has higher loss with flash attention 2 Starcoder has higher eval loss with flash attention 2 Feb 8, 2024
amyeroberts (Collaborator) commented

Hi @lidingsnyk, thanks for raising this issue!

There was a similar issue posted - #28891, which was resolved. Could you try installing from source to confirm if this resolves your issue?
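
For reference, installing transformers from source can be done with:

pip install git+https://github.com/huggingface/transformers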

lidingsnyk (Author) commented

Thanks a lot @amyeroberts. Indeed the issue is fixed: I'm getting the exact same metrics in our batch inference with flash attention 2 enabled. Looking forward to the next release.
