# In [None]:
import gc
import os
import json
import time
import math
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from google.colab import drive
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Memory management optimization: ask PyTorch's CUDA caching allocator to use
# expandable segments.
# NOTE(review): presumably meant to limit fragmentation when the number of
# image tiles varies between calls; the setting is expected to be read when
# CUDA memory is first allocated, so it is set before any model is loaded —
# confirm against the installed PyTorch version.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Google Drive connection: mounts Drive at /content/drive (Colab only). The
# JSON and image paths used at the bottom of this file live under this mount.
drive.mount('/content/drive')

# Constants for image preprocessing: per-channel mean / std handed to
# T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Return the pipeline that turns a PIL image into a normalized tensor.

    The composed steps run in this order: force RGB mode, bicubic-resize to a
    square of ``input_size`` x ``input_size``, convert to a tensor, and
    normalize each channel with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Images that are already RGB are passed through untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the grid from ``target_ratios`` that best matches ``aspect_ratio``.

    Each candidate is an indexable pair whose first element divided by its
    second is compared with ``aspect_ratio``. The candidate with the smallest
    absolute difference wins. On an exact tie, a later candidate replaces the
    current choice only when the image area (``width * height``) exceeds half
    the area that candidate's grid would cover at ``image_size`` per tile.

    Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    smallest_gap = float('inf')
    chosen = (1, 1)
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Tie: only switch to this candidate if the image is large enough
        # to justify the grid it describes.
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into square tiles on the grid closest to its aspect ratio.

    Every (columns, rows) grid whose tile count lies between ``min_num`` and
    ``max_num`` is a candidate; ``find_closest_aspect_ratio`` picks one. The
    image is resized to exactly fill that grid and cut into
    ``image_size`` x ``image_size`` crops, left to right and top to bottom.

    Returns the list of crops. When ``use_thumbnail`` is true and more than
    one crop was produced, a copy of the whole image resized to
    ``image_size`` x ``image_size`` is appended at the end.
    """
    src_w, src_h = image.size
    src_aspect = src_w / src_h

    # Enumerate candidate grids, then order them by tile count (smallest first).
    candidate_grids = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Pick the grid that best fits the source image.
    grid = find_closest_aspect_ratio(
        src_aspect, candidate_grids, src_w, src_h, image_size)

    # Pixel dimensions of the resized image and the number of tiles it holds.
    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    tile_count = grid[0] * grid[1]

    resized = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        crop_box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(resized.crop(crop_box))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        # A single-tile result already shows the whole image, so the
        # thumbnail is only added for multi-tile grids.
        tiles.append(image.resize((image_size, image_size)))
    return tiles


def load_image(image_file, input_size=448, max_num=12):
    """Open an image file and return its tiles as one stacked tensor.

    The image is converted to RGB, split by ``dynamic_preprocess`` (thumbnail
    enabled, at most ``max_num`` grid tiles of side ``input_size``), and each
    tile is run through ``build_transform``. The per-tile tensors are stacked
    along a new first dimension, one entry per tile.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])





def is_image_file(filename):
    """Return True when ``filename`` ends in .png, .jpg or .jpeg (case-insensitive)."""
    lowered = filename.lower()
    return any(lowered.endswith(ext) for ext in ('.png', '.jpg', '.jpeg'))


def _index_image_paths(image_root):
    """Map each file name under ``image_root`` to the first path os.walk yields for it."""
    index = {}
    for root, _, files in os.walk(image_root):
        # Directories whose path mentions 'text_category' are not image sources.
        if 'text_category' in root:
            continue
        for name in files:
            # setdefault keeps the first match, mirroring a walk that stops
            # at the first directory containing the file.
            index.setdefault(name, os.path.join(root, name))
    return index


def run_internvl3_on_existing_json(image_root, input_json_path, model_path="OpenGVLab/InternVL3-8B"):
    """Run InternVL3 OCR over the images listed in a JSON file and store the results.

    Every top-level key of the JSON that does not start with ``"_"`` is treated
    as an image file name. The file is looked up anywhere under ``image_root``
    (directories whose path contains ``'text_category'`` are ignored), sent to
    the model, and the answer is stored at
    ``data[filename]["models"][<model name>]`` with ``cer`` / ``wer`` set to
    ``None``. The total elapsed time is recorded under
    ``data["_meta"]["processing_times"][<model name>]`` and the JSON file is
    rewritten in place.

    Args:
        image_root: Directory tree searched for the image files.
        input_json_path: Path of the JSON file to read and overwrite.
        model_path: Model identifier passed to ``from_pretrained``; its last
            path component, lower-cased, is used as the result key.

    Returns:
        The updated dictionary that was written back to ``input_json_path``.

    Per-image failures (missing file, load or generation error) are printed
    and skipped so that one bad image does not abort the run.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Initializing InternVL3 model...")

    model_name = model_path.split("/")[-1].lower()

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

    # Single GPU setup. The model goes to the device reported above instead of
    # an unconditional .cuda(), so the log line and the placement agree.
    print("Using single GPU setup")
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True
    ).eval().to(device)

    # Load current JSON
    with open(input_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # InternVL3 prompt for Turkish OCR (identical for every image, so built once).
    question = '<image>\nExtract the Turkish text from the image exactly as it appears. Do not repeat any text, do not add any comments, and do not generate additional text. Only return the raw text found in the image, without any formatting or repetition. If the text or image is not clear to extract text or doing ocr process, just give me blank(empty) output.'

    # Generation configuration.
    # NOTE(review): with do_sample=False decoding is greedy, so temperature is
    # presumably ignored; it is kept unchanged for parity with earlier runs.
    generation_config = dict(max_new_tokens=512, do_sample=False, temperature=0.0)

    start_time = time.time()

    # Walk the image tree once instead of once per JSON entry.
    image_paths = _index_image_paths(image_root)

    for filename in tqdm([k for k in data if not k.startswith("_")], desc="OCR process"):
        img_path = image_paths.get(filename)
        if img_path is None:
            print(f"{filename} error: image not found under {image_root}")
            continue

        pixel_values = None
        try:
            try:
                # Load and preprocess image
                pixel_values = load_image(img_path, input_size=448, max_num=12).to(torch.bfloat16).to(device)

                # Generate response. A copy of the config is passed so the
                # shared dict stays untouched even if the model adds keys to it.
                with torch.inference_mode():
                    extracted_text = model.chat(tokenizer, pixel_values, question, dict(generation_config))

                # Write OCR results; create the "models" container if the
                # entry does not have one yet.
                data[filename].setdefault("models", {})[model_name] = {
                    "prediction": extracted_text,
                    "cer": None,
                    "wer": None
                }
            finally:
                # Clean up on success and on failure alike, so a failed image
                # (e.g. out of memory) does not keep its tensors cached.
                pixel_values = None
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
        except Exception as e:
            # Best-effort processing: report and move on to the next image.
            print(f"{filename} error: {str(e)}")

    elapsed = round(time.time() - start_time, 2)
    print(f"\n OCR completed: {elapsed:.2f} seconds")

    meta = data.get("_meta", {})
    processing_times = meta.get("processing_times", {})
    processing_times[model_name] = elapsed
    meta["processing_times"] = processing_times

    data["_meta"] = meta

    # save updated JSON
    with open(input_json_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f" Updated JSON saved: {input_json_path}")
    return data

# Benchmark JSON that is read and then overwritten with the OCR results, and
# the directory tree that is searched for the images it lists.
input_json_path = '/content/drive/MyDrive/nutuk/benchmark/converted_data.json'
image_root = '/content/drive/MyDrive/nutuk/benchmark/'

# Start the OCR process with the default model; `updated` receives the
# dictionary that was written back to input_json_path.
updated = run_internvl3_on_existing_json(image_root, input_json_path)