# Evaluate Single-Frame VLM variant


## Requirements
```
pip install git+https://github.com/haotian-liu/LLaVA.git
pip install decord==0.6.0
```

In [1]:
from IPython.display import clear_output

import os
from PIL import Image
import time
from tqdm import tqdm

import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from utils import get_ego_schema, calc_loglikelihood, download_ego_schema_center_frames
from vlm_inference import prepare_inputs

clear_output(wait=False)

In [2]:
# Load Model
disable_torch_init()
model_path = os.path.expanduser("liuhaotian/llava-v1.5-13b")
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

# Setup conversation mode
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
    args.conv_mode = args.conv_mode + '_mmtag'
    print(
        f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

# Load Data
dataset = get_ego_schema()
data_root = download_ego_schema_center_frames(save_path="temp_data")

clear_output(wait=False)

In [3]:
correct, total = 0, 0
st = time.time()
for index, datum in tqdm(enumerate(dataset), total=len(dataset)):
    # load center frame
    frame_path = f"{data_root}/{datum['q_uid']}.png"
    c_frame = Image.open(frame_path)

    batch, raw_prompts = prepare_inputs(c_frame, datum, model.config, tokenizer, image_processor)
    batch = {x: y.to(device='cuda', non_blocking=True) for x, y in batch.items()}
    batch['images'] = batch['images'].to(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model(**batch)

    seq_len = batch['labels'].shape[-1]
    loss = calc_loglikelihood(outputs.logits.detach()[:, -seq_len:], batch['labels'])
    pred = loss.argmin().item()

    answer = datum['ans']
    correct += answer == pred
    total += 1
    if (total + 1) % 100 == 0:
        print(f"Accuracy: {correct / total}")
et = time.time()

print(f"Final Accuracy: {100 * correct / total} %")
print(f"Time Taken Per Iteration: {(et - st) / 500}")

 20%|████████████████████████████▌                                                                                                                   | 99/500 [02:46<11:11,  1.68s/it]

Accuracy: 0.5252525252525253


 40%|████████████████████████████████████████████████████████▉                                                                                      | 199/500 [05:36<08:28,  1.69s/it]

Accuracy: 0.5125628140703518


 60%|█████████████████████████████████████████████████████████████████████████████████████▌                                                         | 299/500 [08:27<05:42,  1.70s/it]

Accuracy: 0.5418060200668896


 80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████                             | 399/500 [11:17<02:53,  1.71s/it]

Accuracy: 0.5463659147869674


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋| 499/500 [14:08<00:01,  1.72s/it]

Accuracy: 0.5571142284569138


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [14:09<00:00,  1.70s/it]

Final Accuracy: 55.8 %
Time Taken Per Iteration: 1.6999149203300477



