Release LLaVA-v1.6
haotian-liu committed Jan 31, 2024
1 parent ac89962 commit 0a16dea
Showing 32 changed files with 850 additions and 2,372 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*

[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]
[📢 [LLaVA-1.6 Blog](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)] [[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]

🤝Community Contributions: [[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)] [[Colab](https://github.com/camenduru/LLaVA-colab)] [[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)] [[Replicate](https://replicate.com/yorickvp/llava-13b)] [[AutoGen](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)] [[BakLLaVA](https://github.com/SkunkworksAI/BakLLaVA)]

@@ -19,6 +19,7 @@


## Release
- [1/30] 🔥 LLaVA-1.6 is out! By further scaling up LLaVA-1.5, LLaVA-1.6-34B outperforms Gemini Pro. It processes 4x more pixels and supports more tasks and applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/) and explore the [demo](https://llava.hliu.cc/)! Models are available in the [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts are coming soon.
- [11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
- [11/2] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
- [10/26] 🔥 LLaVA-1.5 with LoRA achieves performance comparable to full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
14 changes: 13 additions & 1 deletion docs/MODEL_ZOO.md
@@ -4,7 +4,19 @@

If you are interested in including any other details in Model Zoo, please open an issue :)

The model weights below are *merged* weights. You do not need to apply delta. The usage of LLaVA checkpoints should comply with the base LLM's model license: [Llama 2](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md).
The model weights below are *merged* weights. You do not need to apply delta. The usage of LLaVA checkpoints should comply with the base LLM's model license.

## LLaVA-v1.6

| Version | LLM | Schedule | Checkpoint | MMMU | MathVista | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED-IMG | LLaVA-Bench-Wild | MM-Vet |
|----------|----------|-----------|-----------|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LLaVA-1.6 | Vicuna-7B | full_ft-1e | [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) | 35.8 | 34.6 | 81.8 | 64.2 | 57.6 | 70.1 | 64.9 | 86.5 | 1519/332 | 67.4 | 60.6 | 70.2 | 81.6 | 43.9 |
| LLaVA-1.6 | Vicuna-13B | full_ft-1e | [liuhaotian/llava-v1.6-vicuna-13b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b) | 36.2 | 35.3 | 82.8 | 65.4 | 60.5 | 73.6 | 67.1 | 86.2 | 1575/326 | 70.0 | 64.4 | 71.9 | 87.3 | 48.4 |
| LLaVA-1.6 | Mistral-7B | full_ft-1e | [liuhaotian/llava-v1.6-mistral-7b](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b) | 35.3 | 37.7 | 82.2 | 64.8 | 60.0 | 72.8 | 65.7 | 86.7 | 1498/321 | 68.7 | 61.2 | 72.2 | 83.2 | 47.3 |
| LLaVA-1.6 | Hermes-Yi-34B | full_ft-1e | [liuhaotian/llava-v1.6-34b](https://huggingface.co/liuhaotian/llava-v1.6-34b) | 51.1 | 46.5 | 83.7 | 67.1 | 63.8 | 81.8 | 69.5 | 87.7 | 1631/397 | 79.3 | 79.0 | 75.9 | 89.6 | 57.4 |

*LLaVA-1.6-34B outperforms Gemini Pro on benchmarks like MMMU and MathVista.*
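Since the checkpoints above ship as merged weights, they can be loaded in one step with the repository's model builder. A minimal loading sketch, assuming the `llava` package from this repo is installed; the keyword names follow `load_pretrained_model` as it is used elsewhere in this commit:

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "liuhaotian/llava-v1.6-vicuna-7b"

# Merged weights: there is no delta to apply, so no base model is needed (model_base=None).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)
```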


## LLaVA-v1.5

27 changes: 26 additions & 1 deletion llava/conversation.py
@@ -68,7 +68,7 @@ def get_prompt(self):
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

@@ -357,13 +357,38 @@ def dict(self):
version="v1_mmtag",
)

conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)

conv_chatml_direct = Conversation(
system="""<|im_start|>system
Answer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
"mistral_direct": conv_chatml_direct,

"plain": conv_llava_plain,
"v0_plain": conv_llava_plain,
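The two templates added above (`mistral_instruct` and `chatml_direct`, with `mistral_direct` registered as an alias) are selected the same way as the existing entries. A minimal prompt-building sketch, following the pattern used by the eval scripts in this repo:

```python
from llava.conversation import conv_templates

# ChatML-style template added in this commit ("mistral_direct" maps to the same object).
conv = conv_templates["chatml_direct"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()
print(prompt)
```
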
21 changes: 5 additions & 16 deletions llava/eval/model_vqa.py
Expand Up @@ -9,7 +9,7 @@
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from PIL import Image
import math
@@ -55,17 +55,14 @@ def eval_model(args):

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

image = Image.open(os.path.join(args.image_folder, image_file))
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
image_tensor = process_images([image], image_processor, model.config)[0]

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().cuda(),
image_sizes=[image.size],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
@@ -74,15 +74,7 @@
max_new_tokens=1024,
use_cache=True)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
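Pieced together, the updated loop body of `model_vqa.py` amounts to the following sketch; surrounding setup such as `prompt`, `image_file`, `args`, and the loaded model is assumed to be prepared as earlier in `eval_model`:

```python
import os
import torch
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token, process_images

input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

# process_images applies the LLaVA-1.6 (AnyRes) preprocessing; the original image
# size is passed separately so the model can handle the tiled patches.
image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
image_tensor = process_images([image], image_processor, model.config)[0]

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.unsqueeze(0).half().cuda(),
        image_sizes=[image.size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=1024,
        use_cache=True)

# generate() here returns only newly generated tokens, which is why the old
# slicing by input length and the KeywordsStoppingCriteria were removed.
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
```
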
21 changes: 12 additions & 9 deletions llava/eval/model_vqa_loader.py
@@ -55,17 +55,24 @@ def __getitem__(self, index):

input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

return input_ids, image_tensor
return input_ids, image_tensor, image.size

def __len__(self):
return len(self.questions)


def collate_fn(batch):
input_ids, image_tensors, image_sizes = zip(*batch)
input_ids = torch.stack(input_ids, dim=0)
image_tensors = torch.stack(image_tensors, dim=0)
return input_ids, image_tensors, image_sizes


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
assert batch_size == 1, "batch_size must be 1"
dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
return data_loader


@@ -88,7 +95,7 @@ def eval_model(args):

data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
idx = line["question_id"]
cur_prompt = line["text"]

@@ -98,19 +105,15 @@
output_ids = model.generate(
input_ids,
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
image_sizes=image_sizes,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
max_new_tokens=args.max_new_tokens,
use_cache=True)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
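The custom `collate_fn` added above exists so the original image sizes stay as plain `(width, height)` tuples; the default collation would convert them into tensors, which is not the format `model.generate(..., image_sizes=...)` expects. A self-contained toy sketch of its behavior (the shapes below are made up for illustration):

```python
import torch

def collate_fn(batch):
    input_ids, image_tensors, image_sizes = zip(*batch)
    input_ids = torch.stack(input_ids, dim=0)
    image_tensors = torch.stack(image_tensors, dim=0)
    return input_ids, image_tensors, image_sizes

# One dataset item shaped like CustomDataset.__getitem__: (input_ids, image_tensor, image.size)
batch = [(torch.zeros(8, dtype=torch.long), torch.zeros(3, 336, 336), (640, 480))]
input_ids, image_tensors, image_sizes = collate_fn(batch)
print(input_ids.shape)      # torch.Size([1, 8])
print(image_tensors.shape)  # torch.Size([1, 3, 336, 336])
print(image_sizes)          # ((640, 480),)
```
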
14 changes: 2 additions & 12 deletions llava/eval/model_vqa_mmbench.py
@@ -106,14 +106,12 @@ def eval_model(args):
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

image_tensor = process_images([image], image_processor, model.config)[0]
# image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().cuda(),
image_sizes=[image.size],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
@@ -122,15 +120,7 @@
max_new_tokens=1024,
use_cache=True)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
122 changes: 0 additions & 122 deletions llava/eval/model_vqa_qbench.py

This file was deleted.


