This example walks you through the basic usage of multi-modal models in PromptBench. We hope it helps you get familiar with the APIs so you can use them in your own projects later.

First, a single unified import, `import promptbench as pb`, brings in the whole package.

In [1]:
import promptbench as pb



## Load dataset

PromptBench makes loading datasets easy.

In [2]:
# print all supported datasets in promptbench
print('All supported datasets: ')
print(pb.SUPPORTED_DATASETS_VLM)

# load a dataset, MMMU, for instance.
# if the dataset is not available locally, it will be downloaded automatically.
dataset = pb.DatasetLoader.load_dataset("mmmu")

# print the first 5 examples
for idx in range(5):
    print(dataset[idx])

All supported datasets: 
['vqav2', 'nocaps', 'science_qa', 'math_vista', 'ai2d', 'mmmu', 'chart_qa']
Images already saved to local, loading file:  /home/v-mingxia/promptbench/promptbench/data/mmmu/validation.json
{'images': [<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=733x237 at 0x7F13BA2CD160>], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/0_image_1.png'], 'answer': 'B', 'question': '<image 1> Baxter Company has a relevant range of production between 15,000 and 30,000 units. The following cost data represents average variable costs per unit for 25,000 units of production. If 30,000 units are produced, what are the per unit manufacturing overhead costs incurred?\nA: $6\nB: $7\nC: $8\nD: $9'}
{'images': [<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=342x310 at 0x7F13BA2CD550>], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/1_image_1.png'], 'answer': 'C', 'question': 'Assume accounts have normal balances, 
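Each MMMU example printed above is a dict with `images`, `image_paths`, `answer`, and `question` keys, and the multiple-choice options are embedded in the question text as lines such as `A: $6`. The sketch below splits such a question into its stem and a letter-to-option mapping. The helper `split_question` is hypothetical (it is not part of PromptBench) and assumes every option sits on its own line in the `LETTER: text` form; the example record is abridged from the first printed example.

```python
import re


def split_question(question: str) -> tuple[str, dict[str, str]]:
    """Split an MMMU-style question into its stem and lettered options.

    Hypothetical helper: assumes options appear on their own lines as
    'A: text', 'B: text', ...; every other line is treated as the stem.
    """
    stem_lines = []
    choices = {}
    for line in question.split("\n"):
        match = re.match(r"^([A-Z]): (.*)$", line)
        if match:
            choices[match.group(1)] = match.group(2)
        else:
            stem_lines.append(line)
    return "\n".join(stem_lines).strip(), choices


# Abridged from the first MMMU example shown above.
example = {
    "answer": "B",
    "question": "<image 1> If 30,000 units are produced, what are the per unit "
                "manufacturing overhead costs incurred?\nA: $6\nB: $7\nC: $8\nD: $9",
}
stem, choices = split_question(example["question"])
print(stem)
print(choices)
print(choices[example["answer"]])  # text of the gold option
```

Since `answer` stores only the letter, a mapping like this is handy when you want to inspect which option text the label refers to.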

## Load models

Then, you can easily load vision-language models (VLMs) via PromptBench.

In [3]:
# print all supported models in promptbench
print('All supported models: ')
print(pb.SUPPORTED_MODELS_VLM)

# load a model, llava-1.5-7b, for instance.
model = pb.VLMModel(model='llava-hf/llava-1.5-7b-hf', max_new_tokens=2048, temperature=0.0001, device='cuda')

All supported models: 
['Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-opt-6.7b', 'Salesforce/blip2-flan-t5-xl', 'Salesforce/blip2-flan-t5-xxl', 'llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-13b-hf', 'gemini-pro-vision', 'gpt-4-vision-preview', 'Qwen/Qwen-VL', 'Qwen/Qwen-VL-Chat', 'qwen-vl-plus', 'qwen-vl-max', 'internlm/internlm-xcomposer2-vl-7b']


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.48s/it]
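The model above is loaded with `temperature=0.0001` rather than 0. We assume the intent is near-greedy, near-deterministic decoding, since some generation backends reject a temperature of exactly zero. The plain-Python sketch below (independent of PromptBench, with made-up logits) shows why a tiny temperature has that effect: dividing logits by the temperature before the softmax pushes almost all probability onto the highest-scoring token.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing them by temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.5, 0.3]  # made-up scores for three candidate tokens
print(softmax_with_temperature(logits, 1.0))     # probability spread across tokens
print(softmax_with_temperature(logits, 0.0001))  # essentially all mass on the top token
```

With a temperature this small, sampling behaves like picking the argmax, which keeps evaluation runs reproducible across prompts.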


## Construct prompts

Prompts are the key interaction interface to VLMs. You can easily construct a prompt by calling the Prompt API.

In [4]:
# The Prompt API accepts a list, so you can pass multiple prompts at once.
prompts = pb.Prompt([
    "You are a helpful assistant. Here is the question:{question}\nANSWER:",
    "USER:{question}\nANSWER:",
])
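Each prompt is a template whose `{question}` placeholder is filled from the matching field of a dataset example. The sketch below approximates that step in plain Python with `str.format_map`; the helper `fill_prompt` and the toy example are hypothetical, and we do not claim this is how `pb.InputProcess.basic_format` is implemented internally.

```python
def fill_prompt(template: str, example: dict) -> str:
    """Fill {field} placeholders in a prompt template from an example dict.

    Hypothetical stand-in: str.format_map substitutes every placeholder that
    names a key of the example (here only 'question'); extra keys are ignored.
    """
    return template.format_map(example)


templates = [
    "You are a helpful assistant. Here is the question:{question}\nANSWER:",
    "USER:{question}\nANSWER:",
]
example = {"question": "What is 2 + 2?\nA: 3\nB: 4", "answer": "B"}  # toy example
for template in templates:
    print(repr(fill_prompt(template, example)))
```

Both templates end with `ANSWER:`, which is the same marker the evaluation loop later uses to locate the model's answer in its output.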

## Perform evaluation using prompts, datasets, and models

Finally, you can perform standard evaluation using the loaded prompts, datasets, and models.

In [5]:
from tqdm import tqdm
for prompt in prompts:
    preds = []
    labels = []
    for data in tqdm(dataset):
        # process input
        input_text = pb.InputProcess.basic_format(prompt, data)
        # Use data['image_paths'] instead of data['images'] for models that only
        # support image paths/URLs, such as GPT-4V.
        input_images = data['images']
        label = data['answer']
        raw_pred = model(input_images, input_text)
        # process output
        pred = pb.OutputProcess.pattern_split(raw_pred, 'ANSWER:')
        preds.append(pred)
        labels.append(label)
    
    # evaluate
    score = pb.Eval.compute_cls_accuracy(preds, labels)
    print(f"{score:.3f}, {repr(prompt)}")


100%|██████████| 900/900 [17:35<00:00,  1.17s/it]  


0.333, 'You are a helpful assistant. Here is the question:{question}\nANSWER:'


100%|██████████| 900/900 [17:27<00:00,  1.16s/it]  

0.316, 'USER:{question}\nANSWER:'
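The loop relies on two steps: extracting the answer from the raw model output with `pb.OutputProcess.pattern_split(raw_pred, 'ANSWER:')`, and scoring with `pb.Eval.compute_cls_accuracy(preds, labels)`. The sketch below approximates both in plain Python, assuming extraction keeps the text after the last `ANSWER:` marker and accuracy is the fraction of exact matches. The helpers `extract_after` and `accuracy` and the sample outputs are hypothetical, not PromptBench's actual code.

```python
def extract_after(text: str, marker: str) -> str:
    """Return the text after the last occurrence of marker, stripped.

    Hypothetical approximation: if the marker is absent, the whole
    (stripped) text is returned unchanged.
    """
    return text.split(marker)[-1].strip()


def accuracy(preds: list[str], labels: list[str]) -> float:
    """Fraction of predictions that exactly match their labels."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    if not labels:
        return 0.0
    return sum(pred == label for pred, label in zip(preds, labels)) / len(labels)


raw_outputs = [  # made-up model outputs
    "USER:Pick one.\nANSWER: B",
    "ANSWER: C",
    "The answer is A",
]
preds = [extract_after(raw, "ANSWER:") for raw in raw_outputs]
labels = ["B", "A", "A"]
print(preds)
print(f"{accuracy(preds, labels):.3f}")
```

The third output shows a limitation of exact matching: a correct but free-form reply that lacks the marker is not reduced to a bare letter, so it counts as wrong. Under this kind of scoring, the 0.333 and 0.316 above are fractions of the 900 validation examples answered with an exactly matching label.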



