Phi: static cache & compile compatibility #30688

Closed
wants to merge 15 commits

Conversation

Member

@zucchini-nlp zucchini-nlp commented May 7, 2024

What does this PR do?

This PR enables torch.compile for Phi models. I checked correctness by running the speed benchmark script (results below) and a test that the dynamic and static cache outputs match.

A few observations while testing the generation quality:

  • Using the static cache sometimes produces total gibberish for batched input, when the input is left-padded and then effectively right-padded as well because of the static cache's pre-allocated length. The attention mask is there, but for some reason generation only improves when I crop the trailing zeros at the end of the key/values (see the sketch after this list).
  • The tests I ran were without the logits check; the logits currently fail to match for the compiled static cache case, even with a tolerance of 0.1.
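For reference, a minimal sketch of the kind of cropping described above (the helper name and the cache tensor layout are assumptions for illustration, not code from this PR):

import torch

def crop_trailing_zeros(key_states, value_states, seen_tokens):
    # Assumed static cache layout: [batch, num_heads, max_cache_len, head_dim].
    # Slots past `seen_tokens` still hold the pre-allocated zeros, so slicing them
    # off gives the shapes a dynamic cache would have produced.
    return key_states[:, :, :seen_tokens, :], value_states[:, :, :seen_tokens, :]

# toy usage: pretend only 5 of the 16 pre-allocated slots have been written
k = torch.zeros(2, 8, 16, 64)
v = torch.zeros(2, 8, 16, 64)
k[:, :, :5] = torch.randn(2, 8, 5, 64)
v[:, :, :5] = torch.randn(2, 8, 5, 64)
k_cropped, v_cropped = crop_trailing_zeros(k, v, seen_tokens=5)
print(k_cropped.shape)  # torch.Size([2, 8, 5, 64])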
Benchmark results (latency)

Script to evaluate the text-level match between the dynamic and static cache:
import os
import argparse

import torch
import torch._dynamo.config
import torch._inductor.config
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "0"

# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True
# torch._dynamo.config.cache_size_limit = 32
torch.set_float32_matmul_precision('high')

CBOLD = '\033[1m'
CRED = '\033[91m'
CEND = '\033[0m'

def check_outputs(
        text_dynamic_cache,
        text_other,
        dynamic_logits,
        other_logits,
        error_msg,
        check_logits,
        atol=1e-03,
        rtol=1e-03
    ): 
    assert(text_dynamic_cache == text_other), f"Texts do not match for {CBOLD}{error_msg}{CEND}"
    if check_logits:
        for token_id, (t1, t2) in enumerate(zip(dynamic_logits, other_logits)):
            assert(torch.allclose(t1, t2, atol=atol, rtol=rtol)), \
                    f"Logits at token position {token_id} do not match for {CBOLD}{error_msg}{CEND}"


def check_static_cache(model, tokenizer, max_new_tokens=100, check_logits=False, static_enabled=True, verbose=True):
    prompts = [
        "The sun dipped below the horizon, painting the sky in red.",
        "I almost missed the bus this morning, but luckily the driver saw me and",
    ]

    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
    inputs_length = inputs.input_ids.shape[1]
    generate_kwargs = {
        "pad_token_id": tokenizer.pad_token_id,
        "min_new_tokens": max_new_tokens,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "temperature": 1.0,
        "top_p": 1.0,
        "output_logits": True,
        "return_dict_in_generate": True
        }

    # eager + dynamic cache
    out_dynamic_cache = model.generate(**inputs, **generate_kwargs)
    text_dynamic_cache = tokenizer.batch_decode(
        out_dynamic_cache.sequences[:, inputs_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )
    dynamic_logits = out_dynamic_cache.logits

    if verbose:
        print("-" * 100)
        print(f"{CBOLD}Dynamic Cache output:{CEND} {text_dynamic_cache}")
        print("-" * 100)

    if static_enabled:
        # eager + static cache
        out_static_cache = model.generate(**inputs, **generate_kwargs, cache_implementation="static")
        text_static_cache = tokenizer.batch_decode(
            out_static_cache.sequences[:, inputs_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        static_logits = out_static_cache.logits
        
        if verbose:
            print(f"{CBOLD}Static Cache output:{CEND} {text_static_cache}") 
            print("-" * 100)
        
        check_outputs(
            text_dynamic_cache,
            text_static_cache,
            dynamic_logits,
            static_logits,
            "Static Cache vs Dynamic Cache",
            check_logits=check_logits,
            atol=0.1,
            rtol=0.1,
        )


        # compiled (fullgraph=true) + static cache
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        out_static_cache_compiled = model.generate(**inputs, **generate_kwargs, cache_implementation="static")
        text_static_cache_compiled = tokenizer.batch_decode(
            out_static_cache_compiled.sequences[:, inputs_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        static_compiled_logits = out_static_cache_compiled.logits

        if verbose:
            print(f"{CBOLD}Compiled Static Cache + compiled output:{CEND} {text_static_cache_compiled}")
            print("-" * 100)
        
        check_outputs(
            text_dynamic_cache,
            text_static_cache_compiled,
            dynamic_logits,
            static_compiled_logits,
            "Compiled Static Cache vs Dynamic Cache",
            check_logits=check_logits,
            atol=0.1,
            rtol=0.1,
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="microsoft/phi-2")
    parser.add_argument("--attn_implementation", type=str, default="eager")
    parser.add_argument("--trust_remote_code", action="store_false")
    parser.add_argument("--dtype", type=str, default="fp16")

    parser.add_argument("--max_new_tokens", type=int, default=100)
    parser.add_argument("--static_cache_enabled", action="store_false")
    parser.add_argument("--check_logits", action="store_true")
    parser.add_argument("--verbose", action="store_false")

    args = parser.parse_args()

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation=args.attn_implementation,
        torch_dtype=dtype
    ).to("cuda:0")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code), padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    check_static_cache(
        model,
        tokenizer,
        check_logits=args.check_logits,
        static_enabled=args.static_cache_enabled,
        max_new_tokens=args.max_new_tokens,
        verbose=args.verbose,
    )

if __name__ == "__main__":
    main()

@zucchini-nlp zucchini-nlp requested a review from gante May 7, 2024 09:13
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante
Member

gante commented May 8, 2024

@zucchini-nlp to clarify:

Using the static cache sometimes produces total gibberish for batched input, when the input is left-padded and then effectively right-padded as well because of the static cache's pre-allocated length. The attention mask is there, but for some reason generation only improves when I crop the trailing zeros at the end of the key/values.

This is with static cache AND compile, correct? Without compile it has no problems, correct? (I haven't seen them yet, if it happens without compile a reproduction example would be helpful!)

Member

@gante gante left a comment

Looks mostly good to me, a few nits to be addressed!

Also -- let's enable slow phi tests in this PR 🔥

Review comments on src/transformers/models/phi/modeling_phi.py
@zucchini-nlp
Member Author

@gante

This is with static cache AND compile, correct? Without compile it has no problems, correct? (I haven't seen them yet, if it happens without compile a reproduction example would be helpful!)

I found a pattern: for Phi models it happens only with eager attention in fp32, while in half precision everything is okay. Since Llama is also compile-compatible, I tested it too and found that Llama produces garbage generation with eager attention in fp16 😭

I am quite lost right now about what the issue might be; I will investigate more next week. If you have time, feel free to take a look. The commands below reproduce it with the script provided in the PR description:

python static.py --attn_implementation eager --static_cache_enabled --dtype fp16 --model_name_or_path meta-llama/Llama-2-7b-chat-hf  -> reproduces with DynamicCache
python static.py --attn_implementation eager --static_cache_enabled --dtype fp32 --model_name_or_path microsoft/phi-2  -> reproduces with compiled StaticCache

@zucchini-nlp
Member Author

@gante as we discussed, I will not dig into the gibberish generation for fp32. In that case the PR should be ready to merge once the slow tests pass. I pushed a [run-slow] commit, can you approve it to run?

@zucchini-nlp zucchini-nlp requested a review from gante May 16, 2024 08:32
@hegderavin

Can you please port the changes to Phi3 as well? I can help test it if you want

@zucchini-nlp
Member Author

@hegderavin sure, we will be porting models one by one (#28981). Right now I am waiting for this PR to be merged so that we can work on the other models.

I can add Phi3 in a separate PR around next week, if you want to pull the changes and compile the model :)

@zucchini-nlp
Member Author

Updates:

  1. Added Phi3, as requested above. But Phi3 cannot use a sliding window in flash attention with the static cache. Should it be able to? I think it's doable, but I haven't changed anything yet because I am not sure it's a real use case.
  2. Added SDPA support for Phi3; the implementation was already in the code, but the _supports_sdpa=True attribute was missing.

Also, I want to suggest moving _update_causal_mask to modeling_utils.py, given that we're getting more and more static-cache models.

@@ -171,7 +172,7 @@ def __init__(self, dim, config, device=None):

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
-        seq_len = torch.max(position_ids) + 1
+        seq_len = position_ids.shape[-1]
Member Author

Just realized that position ids are not always the same size as the input. I will come back and revert this later, which means that compile still doesn't work for RoPE scaling in Phi3.

Member

Correct: with left padding, the maximum value of position_ids is not the sequence length. For that (the sequence length) we have cache_position :)
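A tiny illustration of that difference, assuming the usual way position_ids are derived from the attention mask (not code from this PR):

import torch

# one left-padded sequence: 2 pad tokens followed by 3 real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

# position_ids restart at 0 on the first real token
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)   # tensor([[0, 0, 0, 1, 2]])

# cache_position indexes every slot written to the cache so far
cache_position = torch.arange(attention_mask.shape[-1])       # tensor([0, 1, 2, 3, 4])

print(int(position_ids.max()) + 1)    # 3 -> number of real tokens, not the sequence length
print(int(cache_position[-1]) + 1)    # 5 -> the actual sequence length in the cache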

@gante
Member

gante commented Jun 14, 2024

@zucchini-nlp

Also, I want to suggest moving _update_causal_mask to modeling_utils.py, given that we're getting more and more static-cache models.

Perhaps. It is not a model-architecture piece of code but rather an input-preparation one, so it could make sense for it to live in a shared module. Raise the discussion on Slack!

Added Phi3, as requested above. But Phi3 cannot use a sliding window in flash attention with the static cache. Should it be able to? I think it's doable, but I haven't changed anything yet because I am not sure it's a real use case.

Have you looked into the SlidingWindowCache class? It's a static cache with a sliding window :)
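A hedged sketch of what that would look like from generate (the checkpoint name and the "sliding_window" cache_implementation string are assumptions based on recent transformers versions, not something added by this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "microsoft/Phi-3-mini-4k-instruct"  # assumed checkpoint with a sliding_window set in its config
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(ckpt)

inputs = tokenizer("Sliding window attention lets the model", return_tensors="pt").to("cuda")
# SlidingWindowCache is requested like the static cache, just with a different string
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="sliding_window")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])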

@zucchini-nlp
Member Author

Have you looked into the SlidingWindowCache class? It's a static cache with a sliding window :)

Perfect, this will do the job, thanks!

@gante
Member

gante commented Jun 17, 2024

(lmk when it is ready for a re-review)

@zucchini-nlp zucchini-nlp added the WIP label Jul 12, 2024
@huggingface huggingface deleted a comment from github-actions bot Jul 12, 2024
@helunwencser
Contributor

helunwencser commented Jul 29, 2024

Hi @zucchini-nlp, @gante, is there any update on this PR? I want to use Phi-3 with the static cache; this PR is super useful.

@zucchini-nlp
Member Author

zucchini-nlp commented Jul 30, 2024

@helunwencser this PR has mostly moved into #31421 (review), where we're handling the new cache format for all models. That PR will probably enable compile for all models, except for cases where there is model-specific dynamic control flow. I am planning to deal with those after that PR is merged.

Also, regarding Phi-3: it has dynamic control flow when the scaled RoPE embeddings are used, and I haven't decided how best to solve that yet.

EDIT: hehe, sorry, I forgot that we already merged Phi with the new cache format, so it should be compilable as long as you use the '4k' checkpoint, which has no scaling. For scaling, the PR is coming soon.
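For reference, a minimal sketch of what that looks like, following the same pattern as the script in the PR description (the checkpoint name is an assumption):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "microsoft/Phi-3-mini-4k-instruct"  # the '4k' variant, i.e. no RoPE scaling
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# same recipe as the benchmark script above: compile the forward pass and
# ask generate for the static cache
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])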

@helunwencser
Contributor

Thanks! Unfortunately I need to use StaticCache for Phi3. Looking forward to the new PR.

@zucchini-nlp
Member Author

Closing, as Phi will be compile-compatible via #32617; for Phi3 we have to wait for dynamic control flow support from PyTorch.
