<a href="https://colab.research.google.com/github/kairamilanifitria/PurpleBox-Intern/blob/main/RAG/2_IMAGE_DESCRIPTION.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

In [None]:
def load_internvl_model(path='OpenGVLab/InternVL2_5-1B'):
    """Load an InternVL checkpoint and its tokenizer for GPU inference.

    Args:
        path: Hugging Face model id or local directory handed to
            ``from_pretrained``. Defaults to the checkpoint this notebook
            was written for, so zero-argument calls behave as before.

    Returns:
        tuple: ``(model, tokenizer)``. The model is loaded as bfloat16,
        switched to eval mode and moved to the CUDA device, so a GPU is
        required.
    """
    model = AutoModel.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        # NOTE(review): this executes Python code shipped in the model
        # repository; only point `path` at sources you trust.
        trust_remote_code=True
    ).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
    return model, tokenizer

def build_transform(input_size=448):
    """Build the preprocessing pipeline applied to each image.

    Args:
        input_size: Side length, in pixels, of the square the image is
            resized to.

    Returns:
        A ``T.Compose`` that forces RGB mode, resizes to
        ``(input_size, input_size)`` with bicubic interpolation, converts
        to a tensor and normalizes each channel with fixed mean/std values.
    """
    def _ensure_rgb(img):
        # Only convert when needed; RGB images are passed through untouched.
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

    target_size = (input_size, input_size)
    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return T.Compose(steps)

def load_image(image_file, input_size=448):
    """Read an image file and preprocess it into a one-image GPU batch.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Side length passed on to ``build_transform``.

    Returns:
        The transformed image as a bfloat16 tensor on the CUDA device with
        a leading batch dimension of 1 added by ``unsqueeze(0)``.
    """
    # The context manager releases the file handle as soon as the RGB copy
    # has been made; `convert` returns a new in-memory image, so nothing
    # below depends on the file staying open.
    with Image.open(image_file) as img:
        image = img.convert('RGB')
    transform = build_transform(input_size)
    pixel_values = transform(image).unsqueeze(0).to(torch.bfloat16).cuda()
    return pixel_values

def extract_images_and_context(markdown_path):
    """Collect image links from a Markdown file with their nearby text.

    Only the first ``![alt](target)`` link on a line is considered. For
    each such line, up to two lines before and up to two lines after are
    joined with a space and stripped to form the context strings.

    Args:
        markdown_path: Path of the UTF-8 Markdown file to scan.

    Returns:
        tuple: ``(image_data, lines)`` where ``image_data`` is a list of
        ``(img_path, context_before, context_after)`` tuples in document
        order and ``lines`` is the file content as returned by
        ``readlines()``.
    """
    with open(markdown_path, "r", encoding="utf-8") as handle:
        lines = handle.readlines()

    image_link = re.compile(r'!\[.*?\]\((.*?)\)')
    image_data = []
    for idx, text in enumerate(lines):
        found = image_link.search(text)
        if found is None:
            continue
        preceding = lines[max(0, idx - 2):idx]
        # Slicing clamps at the end of the list, so no explicit bound is needed.
        following = lines[idx + 1:idx + 3]
        image_data.append((
            found.group(1),
            " ".join(preceding).strip(),
            " ".join(following).strip(),
        ))
    return image_data, lines

def generate_caption(model, tokenizer, image_path, context_before, context_after):
    """Ask the model for a short description of one image.

    Args:
        model: Object exposing ``chat(tokenizer, pixel_values, prompt, config)``.
        tokenizer: Tokenizer forwarded to ``model.chat``.
        image_path: Filesystem path of the image to describe.
        context_before: Text preceding the image, embedded in the prompt.
        context_after: Text following the image, embedded in the prompt.

    Returns:
        The model's response, or the placeholder
        ``"[Image description unavailable]"`` (after printing a warning)
        when ``image_path`` does not exist.
    """
    if os.path.exists(image_path):
        pixels = load_image(image_path)
        question = f"<image>\nContext: {context_before} ... {context_after}. Please describe the image shortly."
        settings = {"max_new_tokens": 1024, "do_sample": True}
        return model.chat(tokenizer, pixels, question, settings)

    # Missing file: report it and fall back to the placeholder text.
    print(f"Warning: Image not found - {image_path}")
    return "[Image description unavailable]"

def update_markdown(markdown_path, image_data, lines):
    """Rewrite the Markdown file with a description line after each image.

    Every line containing an ``![alt](target)`` link is followed by
    ``*Image Description:* <caption>``. Captions are matched by link
    target and consumed in order, so when the same image is referenced
    several times each occurrence receives the caption generated for it.
    If a target occurs more often than it has entries, the last entry is
    reused; targets with no entry get a placeholder.

    Args:
        markdown_path: File that is overwritten with the enriched content.
        image_data: Iterable of ``(img_path, context_before, context_after,
            caption)`` tuples, in document order.
        lines: Original lines of the Markdown document.
    """
    image_link = re.compile(r'!\[.*?\]\((.*?)\)')

    # Group captions by image path once instead of scanning image_data for
    # every line; the per-path list keeps document order.
    captions_by_path = {}
    for img, _, _, desc in image_data:
        captions_by_path.setdefault(img, []).append(desc)
    next_caption = {}  # per-path index of the next caption to hand out

    new_lines = []
    for line in lines:
        new_lines.append(line)
        match = image_link.search(line)
        if not match:
            continue
        img_path = match.group(1)
        candidates = captions_by_path.get(img_path)
        if candidates:
            # Clamp so extra occurrences reuse the last known caption.
            position = min(next_caption.get(img_path, 0), len(candidates) - 1)
            caption = candidates[position]
            next_caption[img_path] = position + 1
        else:
            caption = "[Image description unavailable]"
        new_lines.append(f"\n*Image Description:* {caption}\n")

    with open(markdown_path, "w", encoding="utf-8") as f:
        f.writelines(new_lines)

def main(markdown_path, image_folder):
    """Caption every image referenced in a Markdown file, in place.

    Args:
        markdown_path: Markdown file that is read and then overwritten
            with the image descriptions inserted.
        image_folder: Directory joined with each image link target to
            locate the image file on disk.
    """
    # Parse the document before touching the model: an unreadable path
    # fails immediately, and a document without images never pays for
    # loading the model onto the GPU.
    image_data, lines = extract_images_and_context(markdown_path)
    if not image_data:
        print("No images found in markdown; nothing to describe.")
        return

    model, tokenizer = load_internvl_model()
    enriched_data = [
        (
            img_path,
            context_before,
            context_after,
            generate_caption(
                model,
                tokenizer,
                os.path.join(image_folder, img_path),
                context_before,
                context_after,
            ),
        )
        for img_path, context_before, context_after in image_data
    ]
    update_markdown(markdown_path, enriched_data, lines)
    print("Markdown updated with image descriptions!")

In [None]:
# Example usage: main(<markdown file>, <image folder>).
# Both arguments are placeholders ("{filename}" is a plain string, not an
# f-string, so it is not substituted) — replace them with a real Markdown
# path and the folder that the Markdown's image links are joined onto.
main("_________.md", "{filename}_artifacts")