In [1]:
import torch
import math
import sys
import os
from transformers import AutoModel, AutoTokenizer

# Add the desired directory to the Python path
# (the local NVLM-D-72B directory, so modules stored there can be imported).
sys.path.append(os.path.abspath('/data/students/earl/llava-dissector/NVLM-D-72B'))
# NOTE(review): CUDA_VISIBLE_DEVICES both filters and re-orders the GPUs, so
# device indices used later in this notebook (e.g. device_ids = [5, 6, 7] in
# split_model) are presumably positions in this list, not physical GPU
# numbers -- with this ordering index 5 would be physical GPU 3. Confirm which
# GPUs are actually intended.
# NOTE(review): this is assigned after `import torch`; it presumably only takes
# effect if CUDA has not been initialized yet -- TODO confirm, or move it above
# the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6,7,1,2,3,0,4"

import urllib
from io import BytesIO
from PIL import Image

from typing import Optional, Union, List, Tuple

def split_model(device_ids=None, num_layers=80):
    """Build a ``device_map`` dict that spreads the decoder layers over several GPUs.

    Layers ``language_model.model.layers.<k>`` are assigned to the devices in
    contiguous runs, as evenly as possible (the first ``num_layers % len(device_ids)``
    devices receive one extra layer). Every non-layer module listed below, plus
    the final decoder layer, is pinned to the first device.

    Args:
        device_ids: Sequence of device identifiers to place modules on, in
            order. Defaults to ``[5, 6, 7]`` (the original hard-coded choice).
            These values are written into the map as-is.
        num_layers: Number of decoder layers to distribute. Defaults to 80.

    Returns:
        dict mapping module names to a device identifier.

    Raises:
        ValueError: if ``device_ids`` is empty or ``num_layers`` is less than 1.
    """
    if device_ids is None:
        device_ids = [5, 6, 7]  # List of GPU IDs you want to use
    device_ids = list(device_ids)
    if not device_ids:
        raise ValueError("device_ids must contain at least one device")
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")

    first_device = device_ids[0]
    base_count, extra = divmod(num_layers, len(device_ids))

    device_map = {}
    layer_idx = 0
    for position, device_id in enumerate(device_ids):
        # The first `extra` devices absorb the remainder, one layer each.
        layers_here = base_count + (1 if position < extra else 0)
        for _ in range(layers_here):
            device_map[f'language_model.model.layers.{layer_idx}'] = device_id
            layer_idx += 1

    # Assign other modules to the first GPU
    for module_name in (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.lm_head',
        'language_model.model.rotary_emb',
    ):
        device_map[module_name] = first_device

    # The last decoder layer is moved to the first device as well, overriding
    # the even split above, so it sits with the norm / output head.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = first_device
    return device_map


# Local checkpoint directory; also used as the tokenizer source below.
model_name = "/data/students/earl/llava-dissector/NVLM-D-72B"

# Module-name -> device placement produced by split_model().
device_map = split_model()

# NOTE(review): the captured output of this cell is a CUDA OOM on the first
# device of the map. Presumably the bf16 weights of a "72B" checkpoint
# (~2 bytes per parameter) exceed what three ~39 GiB GPUs can hold -- confirm
# and spread the map over more devices if so.
model = AutoModel.from_pretrained(
    model_name,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    use_flash_attn=False,
)
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    trust_remote_code=True,
)

# Greedy decoding, capped at 1024 newly generated tokens.
generation_config = {"max_new_tokens": 1024, "do_sample": False}



  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards:  24%|██▍       | 11/46 [00:16<00:53,  1.52s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 462.00 MiB. GPU 5 has a total capacity of 39.38 GiB of which 263.38 MiB is free. Including non-PyTorch memory, this process has 39.12 GiB memory in use. Of the allocated memory 38.71 GiB is allocated by PyTorch, and 890.00 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
import math
import requests
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

import sys
import os
import io


# Per-channel mean / std handed to T.Normalize inside build_transform
# (presumably the usual ImageNet RGB statistics, as the names suggest).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Return the per-tile preprocessing pipeline.

    The composed transform converts the image to RGB when it is in another
    mode, resizes it to ``input_size`` x ``input_size`` with bicubic
    interpolation, turns it into a tensor and normalizes it with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid whose aspect ratio is nearest to ``aspect_ratio``.

    Candidates are scanned in the order given. A strictly closer candidate
    always wins. A candidate that ties the current best distance replaces it
    only when the image area (``width * height``) exceeds half of the area the
    candidate grid would cover at ``image_size`` pixels per tile.

    Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    pixel_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    winner = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, winner = gap, candidate
            continue
        # Tie: prefer the later candidate only if the image is big enough for it.
        if gap == smallest_gap and pixel_area > half_tile_area * cols * rows:
            winner = candidate
    return winner


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize ``image`` to a grid of square tiles and return the tiles.

    A (cols, rows) grid with between ``min_num`` and ``max_num`` tiles is
    chosen via ``find_closest_aspect_ratio``; the image is resized to
    ``cols * image_size`` by ``rows * image_size`` and cropped into tiles in
    row-major order. When ``use_thumbnail`` is true and more than one tile was
    produced, the whole image resized to ``image_size`` x ``image_size`` is
    appended as an extra last element.

    ``image`` must expose ``size``, ``resize`` and ``crop`` (presumably a PIL
    image). Returns the list of tile images.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count is within bounds,
    # ordered by tile count.
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Choose the grid that best matches the image's own aspect ratio.
    grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    # Pixel dimensions of the resized image and the number of tiles to cut.
    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        # Crop box is (left, upper, right, lower).
        tile = resized_img.crop((left, upper, (col + 1) * image_size, (row + 1) * image_size))
        processed_images.append(tile)
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images

def load_image(image_file, input_size=448, max_num=12, timeout=30):
    """Load an image and turn it into a stacked tensor of preprocessed tiles.

    Args:
        image_file: An ``http://`` / ``https://`` URL string, or anything
            ``PIL.Image.open`` accepts (e.g. a local path).
        input_size: Side length in pixels of each square tile.
        max_num: Maximum number of tiles passed on to ``dynamic_preprocess``.
        timeout: Seconds to wait for the HTTP request when ``image_file`` is a
            URL, so an unresponsive server cannot hang the notebook.

    Returns:
        The result of ``torch.stack`` over the transformed tiles (the tiles
        from ``dynamic_preprocess`` with ``use_thumbnail=True``).

    Raises:
        requests.HTTPError: if the URL responds with an error status, instead
            of handing an error page to the image decoder.
    """
    is_url = isinstance(image_file, str) and image_file.startswith(('http://', 'https://'))
    if is_url:
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = image_file

    # convert() returns a new in-memory image, so the file opened here can be
    # closed right away rather than lingering until garbage collection.
    with Image.open(source) as opened:
        image = opened.convert('RGB')

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values

# Imports used by the parsing / visualization steps below. They must come
# before first use: numpy and cv2 were previously imported only after the code
# that calls them, which raises NameError on a fresh kernel.
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Referring expression to ground in the image.
cls = "tree at the far right"

url = "https://farm3.staticflickr.com/2402/2480652763_e6b62303ee_z.jpg"
#url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
pixel_values = load_image(url).to(torch.bfloat16)  # Load image from URL
question = f"<image>\nPlease provide the bounding box coordinate of the region this sentence describes: <ref>{cls}</ref>"
response, history = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'User: {question}\nAssistant: {response}')


## Visualize using cv2

# Extract every "[x1, y1, x2, y2]" group of integers from the model answer.
# Example output: "Bounding box coordinates: [[x1, y1, x2, y2], [x1, y1, x2, y2]]"
# Whitespace around the numbers and commas is optional so "[1,2,3,4]" also matches.
pattern = r'\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*\]'
matches = re.findall(pattern, response)
bounding_boxes = [list(map(int, match)) for match in matches]
print(f'Bounding Boxes: {bounding_boxes}')

# Download the original image again for drawing. A separate name is used so the
# model's text answer in `response` is not overwritten by the HTTP response.
image_response = requests.get(url, timeout=30)
image_response.raise_for_status()
orig_img = Image.open(io.BytesIO(image_response.content)).convert('RGB')
orig_img_np = np.array(orig_img)  # HWC, RGB, uint8
img = cv2.cvtColor(orig_img_np, cv2.COLOR_RGB2BGR)

img_width, img_height = img.shape[1], img.shape[0]
print(f'Image Width: {img_width}, Image Height: {img_height}')

# Draw each box and its label on the BGR image.
# NOTE(review): boxes are drawn as raw pixel coordinates. If the model reports
# coordinates on a normalized scale they would need rescaling to
# img_width / img_height first -- TODO confirm the model's output convention.
label = cls  # Example label, you can modify this as needed
for bbox in bounding_boxes:
    print(f'Drawing box: {bbox}')
    x1, y1, x2, y2 = bbox
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    # Keep the label baseline inside the image when the box touches the top edge.
    label_y = max(y1 - 10, 15)
    cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

plt.figure(figsize=(10, 10))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
