llava-next demo / tutorial code does not work #30294

Closed
2 of 4 tasks
bghira opened this issue Apr 17, 2024 · 35 comments

Comments

@bghira

bghira commented Apr 17, 2024

System Info

  • transformers version: 4.40.0.dev0
  • Platform: macOS-14.4.1-arm64-arm-64bit
  • Python version: 3.10.14
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.26.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: fp16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.4.0.dev20240407 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@amyeroberts @arth

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Install the required dependencies
  2. Copy the demonstration code:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True) 
model.to("mps:0")

# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"

inputs = processor(prompt, image, return_tensors="pt").to("mps:0")

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)

print(processor.decode(output[0], skip_special_tokens=True))
  3. Execute:
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 19.79it/s]
Traceback (most recent call last):
  File "/Users/bghira/src/SimpleTuner/toolkit/captioning/caption_with_llava.py", line 19, in <module>
    output = model.generate(**inputs, max_new_tokens=100)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py", line 555, in forward
    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py", line 411, in _merge_input_ids_with_image_features
    raise ValueError(
ValueError: The input provided to the model are wrong. The number of image tokens is 1 while the number of image given to the model is 1. This prevents correct indexing and breaks batch generation.

There is apparently some confusion on the other open issue about this error message, but it seems to be a distinctly different problem.

I can't use the Git version of Transformers or the latest release of Tokenizers, because this library (for SOME reason) needs an older release of Tokenizers.

Expected behavior

The demo should work.

@zucchini-nlp
Member

Can this be related to the "mps" device? I am not able to reproduce it on "cuda".

@bghira
Author

bghira commented Apr 17, 2024

happens on CPU too, and other people on the Community page of the 34b model have reported the same issue

@bghira
Author

bghira commented Apr 17, 2024

CPU Inputs: tensor([[    6,  1328,   144, 47329,   567,  3275,    98,     7,     6,  2942,
           144, 64000, 59568,   144,  5697,   620,  2709,   594,   719,  2728,
           100,     7,     6, 14135,   144]])
GPU Inputs: tensor([[    6,  1328,   144, 47329,   567,  3275,    98,     7,     6,  2942,
           144, 64000, 59568,   144,  5697,   620,  2709,   594,   719,  2728,
           100,     7,     6, 14135,   144]], device='mps:0')

@hvaara
Contributor

hvaara commented Apr 19, 2024

@bghira I'm unable to reproduce on CPU. Can you show the MRE you have for CPU? I.e. a script/gist that I can run top to bottom and that reproduces what you're seeing.

On MPS I'm able to reproduce, and I've verified there's a bug (in PyTorch/MPS). I have a way to mitigate the issue, but it's suboptimal to merge this approach in transformers since that specific bug is only affecting MPS. Please let me know if you want the transformers workaround while I work on getting this fixed in PyTorch.

@amyeroberts
Collaborator

@hvaara Thanks for looking into this!

If you have a workaround, it would be great if you could share. As you say, it doesn't make sense to merge it on the transformers side but potentially very useful for the community whilst the fix is still to be merged into PyTorch.

@bghira
Author

bghira commented Apr 19, 2024

hmm, i can't reproduce it on CPU, but it also never really returns anything now 😓 i might have seen this error there on a different version. maybe now it's simply working and computing the result.

@bghira
Author

bghira commented May 1, 2024

is there a link to the upstream bug?

@hvaara
Contributor

hvaara commented May 1, 2024

Sorry, I completely forgot about this. I'll get both the patch to upstream and a workaround for transformers out ASAP (aiming for within a few hours).

@hvaara
Contributor

hvaara commented May 1, 2024

The code changes for a workaround in transformers can be seen in hvaara@ed2f0df.

I've created a gist to show how to run the model with these changes patched in. Note that I'm using llava-hf/llava-v1.6-mistral-7b-hf in the gist as opposed to llava-hf/llava-v1.6-34b-hf from #30294 (comment).

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

is needed because the aten::isin operator is not natively implemented for MPS yet. This snippet lets PyTorch fall back to the CPU for that specific operator. Once pytorch/pytorch#124896 is merged, it should be available in the PyTorch nightlies. I think falling back to the CPU for our use case in transformers will not have a significant impact on performance.
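
For completeness, a minimal standalone sketch of that fallback pattern (the tensors below are just illustrative, not part of the demo); note the environment variable has to be set before torch is imported:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    input_ids = torch.tensor([1, 2, 3, 4], device=device)
    special_ids = torch.tensor([2, 4], device=device)
    # Without a native MPS kernel for aten::isin, this dispatches to the CPU
    # (with a UserWarning) instead of raising NotImplementedError.
    print(torch.isin(input_ids, special_ids))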

To answer your original question regarding a link to a bug in upstream, there isn't one yet (AFAIK). I think I'll skip the ceremony of creating a bug, and jump straight to a pull request since I have the code fix and regression test ready for it already. I'll provide an update when I've done that with a link to the PR. Please feel free to nudge me again if there isn't an update in a timely manner ;)

@bghira
Author

bghira commented May 1, 2024

@pcuenca will love that! you're the best @hvaara

@pcuenca
Member

pcuenca commented May 1, 2024

Thanks for the ping @bghira :) Very nice work @hvaara 🙌 Just fyi, note that an mps-friendly path for isin() was recently merged in #30376 :)
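
For anyone pinned to an older stack, an mps-friendly membership check can also be written without torch.isin at all; a rough sketch (not necessarily the exact approach taken in #30376):

import torch

def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Boolean mask of which `elements` appear in `test_elements`, avoiding aten::isin:
    # broadcast-compare against the candidate set and reduce over that axis, using only
    # ops that already have native MPS kernels.
    if elements.device.type == "mps":
        return elements.unsqueeze(-1).eq(test_elements.view(-1)).any(dim=-1)
    return torch.isin(elements, test_elements)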

@hvaara
Contributor

hvaara commented May 1, 2024

@pcuenca That's great! Thanks for letting us know 😄

pytorch/pytorch#96614 describes the problem in MPS/PyTorch seen in this issue.

@hvaara
Contributor

hvaara commented May 1, 2024

PyTorch PR: pytorch/pytorch#125318

@hvaara
Contributor

hvaara commented May 3, 2024

pytorch/pytorch#125318 has been merged. I don't know exactly when a new release will be cut, but I'd expect it to be available in the nightlies in <24h. Assuming it's not reverted of course ;)

@pcuenca
Member

pcuenca commented May 3, 2024

pytorch/pytorch#125318 has been merged

Nice! 🙌

@bghira
Author

bghira commented May 3, 2024

[screenshot]

@hvaara
Contributor

hvaara commented May 4, 2024

I guess this issue can be closed? :)

@bghira
Author

bghira commented May 4, 2024

yes, this error is fixed but the model still just doesn't work on MPS... big sad

@bghira bghira closed this as completed May 4, 2024
@hvaara
Contributor

hvaara commented May 5, 2024

For the 34B version? Does the 7b version work as expected? What's the failure mode? How do you provoke it? What do you expect to see?

@bghira
Author

bghira commented May 5, 2024

yes, the 34b on 128G M3 Max.

it just gets stuck and never returns. i left it for >12 hours overnight. how long is it expected to take?

i do not have the network bandwidth to download the 7B as well yet, i never focused on it because it performed very poorly in my group's initial tests. it is very hamstrung by the 7B LLM.

@hvaara
Contributor

hvaara commented May 5, 2024

What's the code/prompt you used? Can't guarantee I know how to fix it, but I want to try to replicate what you're seeing.

I don't know about speed, but I'd expect it to be measured in iterations per second, not iterations per day 😅

@bghira
Author

bghira commented May 5, 2024

used the same demo code as initially posted:

[:)] % python llava.py 
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 22.16it/s]

@hvaara
Contributor

hvaara commented May 7, 2024

I tested on an M3 Max 128 GB.

The problem you're seeing seems to be due to a resource utilization issue (see https://gist.github.com/hvaara/f8f911c6e271cae8aa0e4811b24d4eb5). I suspect another issue with MPS in PyTorch (memory leak in the MPS cache).

I'll try to give another update with a workaround in a couple of hours. Nudge me if I forget.

@bghira
Author

bghira commented May 7, 2024

oh 🤔 can we disable the mps cache? i assume leaked objects mean we can't even sync and force gc to clear them

@hvaara
Contributor

hvaara commented May 7, 2024

Yup. The proposed workaround was doing torch.mps.empty_cache() in the loop:

 for chunk in tqdm(streamer, total=max_new_tokens+2):
     text.append(chunk)
     print(chunk)
+    torch.mps.empty_cache()
     print(f"{torch.mps.driver_allocated_memory()/1024**3 = }")
     print(f"{torch.mps.current_allocated_memory()/1024**3 = }")
 print(f"{len(text)}")

I'll know if this worked very soon 😄
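
For context, the loop expanded into a fuller form looks roughly like this (a sketch assuming the streamer-based generation from the gist; model, processor and inputs as in the demo at the top):

from threading import Thread

import torch
from transformers import TextIteratorStreamer

max_new_tokens = 100
streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)

# generate() blocks, so run it in a background thread and consume the streamer here
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer},
)
thread.start()

text = []
for chunk in streamer:
    text.append(chunk)
    torch.mps.empty_cache()  # release cached MPS blocks between decode steps
    print(f"driver allocated: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GiB")
thread.join()
print("".join(text))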

@hvaara
Contributor

hvaara commented May 7, 2024

It worked https://gist.github.com/hvaara/9a7c361478ad8aa429d8172b3119a657.

Are you able to reproduce on your end?

@bghira
Author

bghira commented May 7, 2024

i can't really run the current pytorch nightly, so i'm unable to test this anymore. the error I receive is now inside pytorch, some hardcoded cuda codepath.

current pytorch 2.3 is back to:

ValueError: The input provided to the model are wrong. The number of image tokens is 1 while the number of image given to the model is 1. This prevents correct indexing and breaks batch generation.

current pytorch 2.4:

Backend not supported for None

Original traceback:
None

when looking through the pytorch codebase for this error, it's in some rng function that seems to not be called from anywhere :D

[screenshot]

@bghira
Author

bghira commented May 7, 2024

ok, by updating that method and adding a helper:

    # (patched into torch/_prims/rng_prims.py, next to the existing impl_cuda / impl_cpu helpers)
    @run_and_save_rng_state.py_impl(DispatchKey.MPS)
    def impl_mps(op, *args, **kwargs):
        # mirror the CUDA/CPU variants: capture the MPS RNG state, then run the op
        return torch.mps.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, **kwargs):
        # teach the backend-select dispatch about the new "mps" implementation
        impl_map = {"cuda": impl_cuda, "cpu": impl_cpu, "mps": impl_mps}
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(op, *args, **kwargs)

i can get past the errors on pytorch nightly:

Class: <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [01:51<00:00,  7.45s/it]
Processing inputs.
Generating completion for prompt: "<|im_start|>system
Answer the questions.<|im_end|><|im_start|>user
<image>
What is shown in this image?<|im_end|><|im_start|>assistant
"
Emptying MPS cache. Memory allocated: 75535089664 bytes.
Emptying MPS cache. Memory allocated: 78829928448 bytes.
Emptying MPS cache. Memory allocated: 80082075648 bytes.
Emptying MPS cache. Memory allocated: 80195469312 bytes.
Emptying MPS cache. Memory allocated: 80881369088 bytes.
Emptying MPS cache. Memory allocated: 80881500160 bytes.

seems marginal but it's hovering around that value

@hvaara
Contributor

hvaara commented May 7, 2024

The reason it works on PyTorch nightly is probably because of pytorch/pytorch#125318. If you can't use the nightlies, does it work if you apply hvaara@ed2f0df as a workaround in transformers (with PyTorch 2.3)?

From your output it looks like it's working?

@bghira
Author

bghira commented May 7, 2024

it appears equivalent to using pytorch 2.4 with my additional fix for rng impl

@hvaara
Contributor

hvaara commented May 7, 2024

Great! Is everything working as you'd expect on your end now? (Although it's currently held together with string and duct tape 😅)

@bghira
Author

bghira commented May 7, 2024

it appears this is the best we can do on mps for now, though it's so slow i've never seen the example finish :( i'll have to try downloading the smaller model at some point.

@hvaara
Contributor

hvaara commented May 7, 2024

Unfortunately the 34B model is quite slow on this hardware. Quantization and efficient alternative attention mechanisms with MPS support are still limited. This will eventually change. The 7B model is much faster.
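
For reference, trying the 7B checkpoint only means swapping the model id and the prompt template; a rough sketch based on the llava-hf/llava-v1.6-mistral-7b-hf model card (untested here):

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("mps")

url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
# Mistral-style prompt template instead of the ChatML template the 34B model uses
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

inputs = processor(prompt, image, return_tensors="pt").to("mps")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))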

FWIW I've allocated resources to find and hopefully mitigate the second batch of issues we saw (MPS cache).

@bghira
Author

bghira commented May 7, 2024

yes, MPS unfortunately seems to receive less love than MLX does?

@hvaara
Contributor

hvaara commented Jun 28, 2024

Short update: @skotapati (et al?) has found a solution to a memory leak in pooling (see pytorch/pytorch#125217) 🎉

The fix will be in an upcoming macOS release. I'm hopeful this will resolve the bulk of the memory congestion we saw in LLaVA-NeXT 34B.

I had nothing to do with the fix, just wanted to share the good news 😃

/cc @bghira
