Skip to content

Commit

Permalink
Fix Llava chat template examples (#30130)
Browse files Browse the repository at this point in the history
  • Loading branch information
lewtun authored and Ita Zaporozhets committed May 14, 2024
1 parent 0020a04 commit 8c3019d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/llava.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
- For better results, we recommend users to prompt the model with the correct prompt format:

```bash
"USER: <image>\n<prompt>ASSISTANT:"
"USER: <image>\n<prompt> ASSISTANT:"
```

For multiple turns conversation:

```bash
"USER: <image>\n<prompt1>ASSISTANT: <answer1>USER: <prompt2>ASSISTANT: <answer2>USER: <prompt3>ASSISTANT:"
"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
```

### Using Flash Attention 2
Expand Down
9 changes: 5 additions & 4 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava model."""
"""PyTorch Llava model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -388,16 +389,16 @@ def forward(
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down
12 changes: 6 additions & 6 deletions tests/models/llava/test_modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Llava model. """
"""Testing suite for the PyTorch Llava model."""

import copy
import gc
Expand Down Expand Up @@ -398,13 +398,13 @@ def test_small_model_integration_test_llama(self):
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip

self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
Expand All @@ -421,8 +421,8 @@ def test_small_model_integration_test_llama_batched(self):
processor = AutoProcessor.from_pretrained(model_id)

prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
"USER: <image>\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
Expand All @@ -431,7 +431,7 @@ def test_small_model_integration_test_llama_batched(self):

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

Expand Down

0 comments on commit 8c3019d

Please sign in to comment.