# Mini Prompting Pipeline for Image-Text Inference with InternVL2-Llama3-76B

Based on the quick-start guide on the model card:
https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B#quick-start
and the one in the InternVL documentation:
https://internvl.readthedocs.io/en/latest/internvl2.0/quick_start.html

Structure:
1. Split the model across two 80 GB GPUs, load it (and optionally save it)
2. Preprocess the images
3. Mini pipeline that feeds every image in a folder to the model with the same prompt
4. Playground for video input

## 0. Preparations

In [None]:
# installs
!pip install transformers==4.37.2
!pip install timm
!pip install accelerate
!pip install bitsandbytes
!pip install decord
!pip install pandas
!pip install einops

In [None]:
# packages
import torch
import numpy as np
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import pandas as pd
import math
from PIL import Image
import os

## 1. Split and load the model

In [None]:
# split the model across the available GPUs (written for two 80 GB GPUs)

def split_model(model_name, world_size=None):
    """Build a ``device_map`` that spreads an InternVL2 model over the visible GPUs.

    The language-model layers are distributed over the GPUs. GPU 0 also hosts
    the vision model, ``mlp1``, the token embeddings, the final norm, the
    output head and the last language-model layer. For that reason it is
    counted as half a GPU and receives about half as many layers as the others.

    Args:
        model_name: one of the InternVL2 model names in the ``num_layers`` table.
        world_size: number of GPUs to split over; defaults to
            ``torch.cuda.device_count()``.

    Returns:
        dict mapping module-name prefixes to GPU indices, for use as
        ``device_map`` in ``AutoModel.from_pretrained``.

    Raises:
        KeyError: if ``model_name`` is not in the ``num_layers`` table.
        ValueError: if fewer than one GPU is available.
    """
    device_map = {}
    if world_size is None:
        world_size = torch.cuda.device_count()
    num_layers = {
        'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32,
        'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
    if world_size < 1:
        raise ValueError(f'split_model needs at least one GPU, got world_size={world_size}')
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for gpu, gpu_layers in enumerate(num_layers_per_gpu):
        # ceil() above can over-allocate (e.g. 27 + 54 = 81 slots for 80 layers);
        # never create entries for layers the model does not have
        gpu_layers = min(gpu_layers, num_layers - layer_cnt)
        for _ in range(gpu_layers):
            device_map[f'language_model.model.layers.{layer_cnt}'] = gpu
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    # keep the last layer on the same GPU as the norm / output head
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map

In [None]:
# load the split model: the device_map built by split_model() decides which GPU each module goes to

path = "OpenGVLab/InternVL2-Llama3-76B"  # Hugging Face Hub model id

device_map = split_model('InternVL2-Llama3-76B')

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,  # inputs are later cast to bfloat16 to match
    load_in_8bit=True,  # 8-bit weight quantization (bitsandbytes, installed above)
    low_cpu_mem_usage=True,
    use_flash_attn=True,  # model-specific flag, presumably handled by the remote InternVL2 code
    trust_remote_code=True,  # the model classes come from the Hub repository
    device_map=device_map).eval()  # .eval() already puts the model in evaluation mode

# use_fast=False selects the slow (Python) tokenizer implementation
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

In [None]:
# Optional: save the model weights to disk and reload them.
# Disabled by default; set SAVE_MODEL = True and point model_dir at a
# writable directory to run it.
# NOTE(review): the model was loaded with load_in_8bit=True; confirm that its
# state dict round-trips through torch.save / load_state_dict.
SAVE_MODEL = False

if SAVE_MODEL:
    model_save_name = "InternVL2-Llama3-76B.pt"
    model_dir = "..."  # TODO: set to a writable directory
    # the checkpoint file is stored inside model_dir under model_save_name
    model_path = os.path.join(model_dir, model_save_name)

    # save the model
    torch.save(model.state_dict(), model_path)

    # load the saved model
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()

In [None]:
# put model in evaluation mode
# NOTE: the loading cell already calls .eval() on the model, so this is only a
# safeguard (e.g. after running the optional save/reload cell)

model.eval()

## 2. Preprocessing of the images

In [None]:
# functions for preprocessing the input image

# ImageNet per-channel mean / standard deviation used for normalization
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Return the torchvision pipeline that turns a PIL image into a model input tensor.

    Steps:
    1. Convert the image to RGB if it is not RGB already.
    2. Resize it to an ``input_size`` x ``input_size`` square with bicubic interpolation.
    3. Convert it to a tensor.
    4. Normalize it with the ImageNet mean/std.
    """
    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (columns, rows) grid from ``target_ratios`` closest to ``aspect_ratio``.

    Candidates are scanned in the given order, and a strictly smaller distance
    ``|aspect_ratio - columns / rows|`` always wins. On an exact tie, the later
    candidate replaces the current choice only if the image area
    (``width * height``) exceeds half the pixel area of that grid. That grid
    area is ``columns * rows`` tiles of ``image_size`` x ``image_size``.

    Returns:
        The chosen element of ``target_ratios``, or ``(1, 1)`` if it is empty.
    """
    source_area = width * height
    chosen, smallest_gap = (1, 1), float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            chosen, smallest_gap = candidate, gap
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split ``image`` into square tiles on the grid whose aspect ratio best matches it.

    How the tiles are produced:
    - Every (columns, rows) grid with ``min_num <= columns * rows <= max_num``
      is a candidate, and ``find_closest_aspect_ratio`` picks one.
    - The image is resized to ``(columns * image_size, rows * image_size)``.
    - It is then cut row by row into ``columns * rows`` tiles of
      ``image_size`` x ``image_size``.
    - If ``use_thumbnail`` is set and more than one tile was produced, the
      whole original image resized to ``image_size`` x ``image_size`` is
      appended as well.

    Returns:
        list of PIL images: the tiles, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate every candidate grid (columns, rows) whose tile count lies in
    # [min_num, max_num], then order the candidates by tile count
    candidate_grids = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # choose the grid whose aspect ratio is closest to the image's
    cols, rows = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    # resize so the grid covers the image exactly, then crop the tiles row by row
    resized_img = image.resize((cols * image_size, rows * image_size))
    tiles = []
    for idx in range(cols * rows):
        row, col = divmod(idx, cols)
        left, top = col * image_size, row * image_size
        tiles.append(resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == cols * rows

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and turn it into a stack of normalized tiles.

    Args:
        image_file: path (or file object) accepted by ``PIL.Image.open``.
        input_size: side length in pixels of each square tile.
        max_num: maximum number of grid tiles produced by ``dynamic_preprocess``.

    Returns:
        tensor of shape (num_tiles, 3, input_size, input_size). num_tiles
        includes the extra thumbnail tile that is added whenever the image is
        split into more than one tile.
    """
    # the context manager closes the file handle once the pixels are copied out;
    # convert() returns a new image, so it stays valid after the file is closed
    with Image.open(image_file) as img:
        image = img.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

## 3. Generate Predictions

In [None]:
# generate one example prediction

# load image, set the max number of tiles in `max_num`
# (the tiles are cast to bfloat16, the dtype the model was loaded with, and moved to the GPU)
pixel_values = load_image("/images/image1.jpg", max_num=12).to(torch.bfloat16).cuda()
# NOTE(review): no temperature is set here, unlike the pipeline cell below (temperature=0.01);
# sampling presumably falls back to the model's default generation settings — confirm
generation_config = dict(max_new_tokens=1024, do_sample=True)

# give 1 image and text as chat input (single-image single-round conversation)
# TODO: find good prompt wording, insert template from LEIZA experts
# the `<image>` token presumably marks where the image is placed in the prompt
question = '<image>\nPretend to be an archivist who wants to catalog this photo card digitally. Write a description including different fields which are provided below. Also use the text on the photo card. ...'
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'Assistant: \n {response}')

In [None]:
# mini pipeline: generate predictions for all images in the folder and collect them
# in the `responses` DataFrame (saved to csv in the next cell)

# prompt (the same for every image)
question = '<image>\nPretend to be an archivist who wants to catalog this photo card digitally. Write a description including different fields which are provided below. Also use the text on the photo card. ...'

# images folder
image_folder = "/images"

# accepted file extensions, matched case-insensitively so e.g. `.JPG` is included
image_extensions = ('.jpg', '.jpeg', '.png')  # Add other extensions if needed

# generation settings shared by all images (play around with temperature, num_beams and top_k)
generation_config = dict(max_new_tokens=1024, do_sample=True, temperature=0.01)

# sorted so the processing order (and the csv row order) is reproducible;
# os.listdir returns files in arbitrary order
image_files = sorted(
    name for name in os.listdir(image_folder)
    if name.lower().endswith(image_extensions)
)

# collect one row per image and build the DataFrame once at the end
# (repeated pd.concat is quadratic); if the loop is interrupted,
# the rows gathered so far are still available in `rows`
rows = []
for i, filename in enumerate(image_files):
    image_path = os.path.join(image_folder, filename)

    # load image, set the max number of tiles in `max_num`
    pixel_values = load_image(image_path, max_num=12).to(torch.bfloat16).cuda()

    # give image and text as chat input: single-image single-round conversation
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    print(f'Image: {i}, {filename}\nAssistant: {response}')
    rows.append({'image': filename, 'prompt': question, 'response': response})

# dataframe to store image, prompt and response
responses = pd.DataFrame(rows, columns=['image', 'prompt', 'response'])

In [None]:
# save model in- and outputs in csv
responses_path = "/responses"  # output directory
os.makedirs(responses_path, exist_ok=True)
# join with a path separator; plain concatenation would produce "/responsesresponses.csv"
responses_file = os.path.join(responses_path, "responses.csv")
responses.to_csv(responses_file)
responses.head()

---

## 4. Playground for video input

In [None]:
# video multi-round conversation

from decord import VideoReader, cpu

def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Return ``num_segments`` frame indices spread evenly over a clip.

    How the indices are chosen:
    - ``bound`` is an optional (start_seconds, end_seconds) pair; without it
      the whole video is covered.
    - The window is converted to frame numbers using ``fps``.
    - The start is clamped to at least ``first_idx`` and the end to at most
      ``max_frame``.
    - The window is cut into ``num_segments`` equal segments, and a frame near
      the middle of each segment is taken.

    Returns:
        1-D numpy integer array of length ``num_segments``.
    """
    start, end = (bound[0], bound[1]) if bound else (-100000, 100000)
    lo = max(first_idx, round(start * fps))
    hi = min(round(end * fps), max_frame)
    segment = float(hi - lo) / num_segments
    half = segment / 2
    return np.array([int(lo + half + np.round(segment * k)) for k in range(num_segments)])

def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
    """Sample frames from a video and turn each one into normalized tiles.

    Frames are chosen with ``get_index`` (``num_segments`` frames, optionally
    limited to the ``bound`` time window). Each frame is tiled with
    ``dynamic_preprocess`` and transformed with ``build_transform``.

    Returns:
        (pixel_values, num_patches_list):
        - pixel_values: the tiles of all frames, concatenated along dim 0.
        - num_patches_list: the number of tiles contributed by each frame.
    """
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    last_frame = len(reader) - 1
    avg_fps = float(reader.get_avg_fps())

    transform = build_transform(input_size=input_size)
    sampled = get_index(bound, avg_fps, last_frame, first_idx=0, num_segments=num_segments)

    frame_tensors = []
    for idx in sampled:
        frame = Image.fromarray(reader[idx].asnumpy()).convert('RGB')
        tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
        frame_tensors.append(torch.stack([transform(tile) for tile in tiles]))

    num_patches_list = [tensor.shape[0] for tensor in frame_tensors]
    return torch.cat(frame_tensors), num_patches_list

# load the video: sample 8 frames; with max_num=1 each frame yields a single tile
video_path = '/video01.avi'
pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
pixel_values = pixel_values.to(torch.bfloat16).cuda()
# one "FrameN: <image>" line per sampled frame, numbered from 1
frame_count = len(num_patches_list)
video_prefix = ''.join(f'Frame{n}: <image>\n' for n in range(1, frame_count + 1))

In [None]:
question = video_prefix + 'What is the man doing?'
# Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}

# NOTE(review): generation_config is whatever an earlier cell assigned last (the image
# pipeline cell sets temperature=0.01); redefine it here if video runs need other settings.
# First round: history=None starts without previous turns; return_history=True makes
# chat() return the conversation history as well, which is fed into the second round.
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                               num_patches_list=num_patches_list, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

# Second round: a follow-up question on the same frames (no frame prefix this time);
# the first round's history is passed in
question = 'Describe what happens in the video.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                               num_patches_list=num_patches_list, history=history, return_history=True)
print(f'User: {question}\nAssistant: {response}')

---
---