
llama3 with torch.compile used more memory #31471

Closed · 2 of 4 tasks
songh11 opened this issue Jun 18, 2024 · 6 comments

songh11 commented Jun 18, 2024

System Info

  • transformers version: 4.41.2
  • Platform: Linux-5.4.252-0504252-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@SunMarc, @zucchini-nlp, @gante, I hope I can get your help.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# `torch.compile`-enabled Llama 3
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch, time, os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", torch_dtype=torch.float16
).half()

model.generation_config.cache_implementation = "static"
model._static_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,
    device=model.device,
    dtype=torch.float16,
)
model.model.forward = torch.compile(model.model.forward, fullgraph=True, mode="reduce-overhead", dynamic=True)

# The first iteration is slow (compilation overhead). Subsequent iterations are
# faster than the uncompiled call.
inputs = tokenizer(["The quick brown"], return_tensors="pt", padding=True).to(model.device)
for i in range(5):
    start = time.time()
    print('{} before generate: {} G, reserved: {} G'.format(i, torch.cuda.max_memory_allocated() / 1024**3, torch.cuda.max_memory_reserved() / 1024**3))
    gen_out = model.generate(
        **inputs, do_sample=True, max_new_tokens=100, use_cache=True, 
    )
    print('{} after generate: {} G, reserved: {} G'.format(i, torch.cuda.max_memory_allocated() / 1024**3, torch.cuda.max_memory_reserved() / 1024**3))
    print(f"Time taken: {time.time() - start:.2f}s")
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))

Expected behavior

During the first generate call, nvidia-smi showed about 16 GB of GPU memory in use, but during the second call it grew to about 20 GB, while torch.cuda.max_memory_reserved() still reported only about 16 GB. I don't know what is causing this difference; could you help me understand it?
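
To make the comparison concrete, here is a small sketch of how I could log both views side by side (the helper name log_gpu_memory is just illustrative, not part of the script above). The driver-level number is roughly what nvidia-smi shows and also covers memory outside PyTorch's caching allocator (the CUDA context, library workspaces, and possibly CUDA-graph bookkeeping), which the allocator counters do not include.

import torch

def log_gpu_memory(tag: str, device: int = 0) -> None:
    # Driver-level view: free/total bytes for the whole device (what nvidia-smi
    # sees, so it also includes other processes sharing the GPU).
    free, total = torch.cuda.mem_get_info(device)
    gib = 1024 ** 3
    # Allocator-level view: peaks tracked by PyTorch's caching allocator only.
    print(
        f"{tag}: allocated={torch.cuda.max_memory_allocated(device) / gib:.2f} GiB, "
        f"reserved={torch.cuda.max_memory_reserved(device) / gib:.2f} GiB, "
        f"driver-used={(total - free) / gib:.2f} GiB"
    )

Calling this before and after each generate in the loop above should show where the extra ~4 GB turns up.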

gante (Member) commented Jun 18, 2024

Hi @songh11 👋

If you check the documentation regarding torch.compile, especially relative to the "reduce-overhead" flag, you'll see an explanation :)

songh11 (Author) commented Jun 18, 2024

Hi @songh11 👋

If you check the documentation regarding torch.compile, especially relative to the "reduce-overhead" flag, you'll see an explanation :)

Many thanks. Another question: why does the second generate call use more memory?

zucchini-nlp (Member) commented:

Interestingly, I didn't get a sudden memory spike after the second generation, and after 5 steps the memory remained around 16 GB 🤔. My specs are:

PyTorch version: 2.3.0+cu121
CUDA used to build PyTorch: 12.1
OS: Ubuntu 20.04.6 LTS (x86_64) 
GPU: NVIDIA A100-SXM4-80GB

gante (Member) commented Jun 20, 2024

@zucchini-nlp In my experience the spikes are hardware-dependent, even when two devices have the same spare memory available.

@songh11 "You may might also notice that the second time we run our model with torch.compile is significantly slower than the other runs, although it is much faster than the first run. This is because the "reduce-overhead" mode runs a few warm-up iterations for CUDA graphs." (source)

songh11 (Author) commented Jun 21, 2024

Interestingly, I didn't get a sudden memory spike after the second generation, and after 5 steps the memory remained around 16 GB 🤔. My specs are:

PyTorch version: 2.3.0+cu121
CUDA used to build PyTorch: 12.1
OS: Ubuntu 20.04.6 LTS (x86_64) 
GPU: NVIDIA A100-SXM4-80GB

I'm on an NVIDIA RTX A5000. I think the second generation call is also part of the warm-up.

songh11 (Author) commented Jun 21, 2024

@zucchini-nlp In my experience the spikes are hardware-dependent, even when two devices have the same spare memory available.

@songh11 "You may might also notice that the second time we run our model with torch.compile is significantly slower than the other runs, although it is much faster than the first run. This is because the "reduce-overhead" mode runs a few warm-up iterations for CUDA graphs." (source)

Thank you for your reply. I can work around it by using the default torch.compile mode.
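
For reference, the workaround amounts to changing only the compile call from the reproduction above (a sketch: the default mode does not capture CUDA graphs, so it avoids their extra reserved memory at the cost of some of the reduce-overhead speedup):

# Same script as above, but compiled with the default mode instead of
# "reduce-overhead", so no CUDA graphs are captured.
model.model.forward = torch.compile(model.model.forward, fullgraph=True, dynamic=True)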

songh11 closed this as completed Jun 21, 2024