In [4]:
import os
import json
import tqdm
import torch
import argparse
import datetime
import requests
from PIL import Image
from io import BytesIO
# from peft import PeftModel
from copy import deepcopy
from torch.utils.data import DataLoader,Dataset
from dataclasses import dataclass, field

from typing import Dict, Sequence, Optional,List
# from accelerate import PartialState,Accelerator
from tqdm import tqdm
from functools import partial
import threading

# from mhr.alignment.models.llava_v1_5.llava.utils import disable_torch_init
# from mhr.alignment.models.llava_v1_5.llava.model.builder import load_pretrained_model
# from mhr.alignment.models.llava_v1_5.llava.conversation import conv_templates, SeparatorStyle
# from mhr.alignment.models.llava_v1_5.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
# from mhr.alignment.models.llava_v1_5.llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model
from transformers import HfArgumentParser

def initialize_model(model_path, device='cuda', peft_model_path=None):
    """Load a LLaVA checkpoint and optionally wrap it with a PEFT adapter.

    Args:
        model_path: Checkpoint location; used both to derive the model name
            and as the path handed to ``load_pretrained_model``.
        device: Device identifier forwarded to ``load_pretrained_model``.
        peft_model_path: Optional adapter location. When truthy, the adapter
            is loaded on top of the base model under the adapter name "dpo".

    Returns:
        The tuple ``(tokenizer, model, image_processor, context_len)`` produced
        by ``load_pretrained_model``, with ``model`` cast to ``torch.float16``.

    Raises:
        ImportError: If ``peft_model_path`` is given but ``peft`` is not
            installed.
    """
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name=model_name,
        load_8bit=False,
        load_4bit=False,
        device=device,
    )
    if peft_model_path:
        # The file-level `from peft import PeftModel` is commented out, so the
        # name must be imported here; otherwise this branch raises NameError.
        # Importing lazily keeps `peft` optional for adapter-free runs.
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, peft_model_path, adapter_name="dpo")
        print("peft model loaded")
    # Cast parameters to half precision before returning the model.
    model.to(torch.float16)
    return tokenizer, model, image_processor, context_len




# print(inp)
# Checkpoint directory of the LLaVA model loaded for this notebook.
model_path = "/mnt/petrelfs/songmingyang/songmingyang/model/others/llava-v1.5-13b"
# Load once; the later cells reuse these four objects.
tokenizer, model, image_processor, context_len = initialize_model(
    model_path=model_path,
    device="cuda",
)




You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 3/3 [00:57<00:00, 19.29s/it]


In [5]:
# Alternative sample image:
# image_path = "/mnt/petrelfs/songmingyang/code/mm/robustLMM/robustlmm/model_inference/llava_infer/samples/test1.jpg"
image_path = "/mnt/petrelfs/songmingyang/songmingyang/data/mm/imgs/llava_aug/controlaug/000000000595_aug.jpg"
inp = "Please Describe this image in detail"

# Load the image as RGB and move the processed tensor to the model's dtype/device.
image = Image.open(image_path).convert("RGB")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.dtype).to(model.device)

conv_mode = "llava_v1"
conv = conv_templates[conv_mode].copy()

# Flatten the question to one line and strip any image placeholder tokens so
# that exactly one placeholder is inserted below.
inp = (
    inp.strip()
    .replace('\n', ' ')
    .replace(DEFAULT_IMAGE_TOKEN, '')
    .replace(DEFAULT_IM_START_TOKEN, '')
    .replace(DEFAULT_IM_END_TOKEN, '')
    .replace("<image>", "")
)
assert DEFAULT_IMAGE_TOKEN not in inp
assert image is not None

if image is not None and DEFAULT_IMAGE_TOKEN not in inp:
    # First message: prepend the image placeholder, wrapped in start/end
    # markers when the model config asks for them.
    inp = (
        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        if model.config.mm_use_im_start_end
        else DEFAULT_IMAGE_TOKEN
    ) + '\n' + inp
    image = None
# The user turn is appended either way, followed by an empty assistant turn.
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()
assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
assert prompt.count(DEFAULT_IM_START_TOKEN) == 0
# assert prompt.count("\n") == 0

# Tokenize the prompt and add a leading batch dimension.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to(model.device)

generation_num = 1

# Stop string: the secondary separator for the TWO style, the primary otherwise.
stop_str = conv.sep2 if conv.sep_style == SeparatorStyle.TWO else conv.sep
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)


In [2]:
# Run generation with sampling disabled, under inference mode, then decode
# the returned ids to text while dropping special tokens.
with torch.inference_mode():
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=False,
        temperature=0,
        max_new_tokens=512,
        use_cache=True,
        stopping_criteria=[stopping_criteria],
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)



In [3]:
# Display the decoded generation (list returned by tokenizer.batch_decode).
outputs

['The image features a graph displaying sales data for a company over a period of time. The graph shows a steady increase in sales, with the sales numbers rising from 2009 to 2012. The sales data is presented in a clear and organized manner, making it easy to understand the trend. The graph is predominantly blue, with the sales numbers represented by various shades of blue, ranging from light to dark. The overall trend of the graph indicates a positive growth in sales for the company.']

In [None]:
# Profile a single generation pass on CPU and CUDA with FLOPs estimation on.
# NOTE(review): per the PyTorch profiler docs, `with_flops` estimates FLOPs by
# formula for specific operators only, so treat the total as an estimate.
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    with_flops=True,  # enable FLOPs accounting
    record_shapes=True,
    with_stack=True,
) as prof:
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=False,
        temperature=0,
        max_new_tokens=512,
        use_cache=True,
        stopping_criteria=[stopping_criteria],
    )

# Sum the per-operator FLOPs estimates over the aggregated profiler events.
total_flops = sum(event.flops for event in prof.key_averages() if event.flops is not None)
# `convert_to_human_readable_size` is defined in a later cell, so calling it
# here raises NameError on a fresh kernel. Use built-in formatting instead;
# the human-readable figure is printed by a later cell once the helper exists.
print(f"Total FLOPs: {total_flops:.3e}")

STAGE:2025-01-31 02:58:43 60069:60069 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2025-01-31 02:59:02 60069:60069 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2025-01-31 02:59:03 60069:60069 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [9]:
def convert_to_human_readable_size(num):
    """Format a number with a magnitude suffix for compact display.

    The value is divided by the largest scale whose threshold its absolute
    value reaches and rendered with two decimals followed by the suffix
    (K, M, B, T, P, E, Z, Y, R for 1e3 ... 1e27). Values whose magnitude is
    below 1e3 are returned unscaled via plain string formatting.

    Args:
        num: Integer or float to format. Negative values are scaled by their
            magnitude and keep their sign.

    Returns:
        A string such as ``"19.83 T"`` or, for small values, ``"999"``.
    """
    # Ordered from largest to smallest so the first match is the best fit.
    scales = (
        (1e27, "R"),
        (1e24, "Y"),
        (1e21, "Z"),
        (1e18, "E"),
        (1e15, "P"),
        (1e12, "T"),
        (1e9, "B"),
        (1e6, "M"),
        (1e3, "K"),
    )
    magnitude = abs(num)
    for scale, suffix in scales:
        # `>=` so exact boundaries (e.g. 1000) get the suffix; comparing the
        # magnitude lets negative values be scaled as well.
        if magnitude >= scale:
            return f"{num / scale:.2f} {suffix}"
    return f"{num}"

In [12]:

# Report the profiled FLOPs total for one generation in human-readable form.
print(f"Total FLOPs: {convert_to_human_readable_size(total_flops)}")

Total FLOPs: 19.83 T


In [11]:
# NOTE(review): `balance` and `aug` look like dataset sample counts (before
# and after augmentation) -- confirm; their provenance is not shown here.
balance = 581745
aug = 665298
# Scale the single-generation FLOPs figure by the count difference.
real_flops = (aug - balance) * total_flops
print(f"real flops:{convert_to_human_readable_size(real_flops)}")

real flops:1.66 E


In [13]:
# NOTE(review): second pair of counts, presumably for another dataset
# configuration -- confirm; this rebinds `balance`, `aug` and `real_flops`.
balance = 1168639
aug = 1246901
# Scale the single-generation FLOPs figure by the count difference.
real_flops = (aug - balance) * total_flops
print(f"real flops:{convert_to_human_readable_size(real_flops)}")

real flops:1.55 E
