In [1]:
%%capture
!pip install -U "transformers>=4.45.0" accelerate torch qwen-vl-utils timm "Pillow<11.0.0"


In [None]:
import os
import json
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

In [None]:
# =========================
# CONFIG
# =========================
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
IMAGE_DIR = "/kaggle/input/eqn-img/extracted_eqs/"
OUTPUT_JSON = "/kaggle/working/equations.json"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# LOAD MODEL & PROCESSOR
# =========================
# We use torch_dtype="auto" as per your structure
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, 
    torch_dtype="auto", 
    device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(MODEL_ID)

In [5]:
# =========================
PROMPT_TEXT = (
    "Extract the mathematical equation from this image into valid LaTeX. "
    "If there is a label (like 1.1 or 6.3.2), extract only the number for the 'label' field. "
    "Return ONLY a valid JSON object without any markdown formatting or code blocks. "
    "Format: {\"latex\": \"...\", \"label\": \"...\"}. "
    "Do not wrap the response in ```json or ``` tags."
)

# =========================
# HELPER FUNCTION
# =========================
def extract_json(response):
    """
    Robust JSON extraction from model response.
    Handles markdown code blocks, extra text, and malformed outputs.
    """
    # Remove markdown code blocks
    response = re.sub(r'```json\s*', '', response)
    response = re.sub(r'```\s*', '', response)
    response = response.strip()
    
    # Try to find JSON object using regex
    json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
    if json_match:
        json_str = json_match.group(0)
    else:
        json_str = response
    
    try:
        parsed = json.loads(json_str)
        return {
            "latex": parsed.get("latex", ""),
            "label": parsed.get("label"),
            "status": "success"
        }
    except json.JSONDecodeError as e:
        # Fallback: try to extract latex and label with regex
        latex_match = re.search(r'"latex"\s*:\s*"([^"]*(?:\\.[^"]*)*)"', response)
        label_match = re.search(r'"label"\s*:\s*"([^"]*)"', response)
        
        return {
            "latex": latex_match.group(1) if latex_match else response,
            "label": label_match.group(1) if label_match else None,
            "status": "fallback_regex"
        }

# =========================
# PROCESS IMAGES
# =========================
results = []
os.makedirs(os.path.dirname(OUTPUT_JSON), exist_ok=True)

image_files = sorted([f for f in os.listdir(IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))])
print(f"ðŸš€ Processing {len(image_files)} images on {DEVICE}...")

for idx, img_name in enumerate(image_files):
    img_path = os.path.join(IMAGE_DIR, img_name)
    
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": PROMPT_TEXT},
            ],
        }
    ]
    
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(DEVICE)
    
    # Inference
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=256)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
    
    # Extract JSON
    parsed = extract_json(response)
    
    results.append({
        "image": img_name,
        "latex": parsed.get("latex"),
        "label": parsed.get("label")
    })
    
    print(f"âœ… [{idx+1}/{len(image_files)}] Extracted: {img_name} (Status: {parsed.get('status')})")

# =========================
# SAVE OUTPUT
# =========================
with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print(f"\nâœ¨ Done! Saved to {OUTPUT_JSON}")

# Optional: Print summary
success_count = sum(1 for r in results if r.get("status") == "success")
print(f"ðŸ“Š Successfully parsed: {success_count}/{len(results)}")

ðŸš€ Processing 14 images on cuda...
âœ… [1/14] Extracted: eq_00000_p1.png (Status: fallback_regex)
âœ… [2/14] Extracted: eq_00001_p2.png (Status: success)
âœ… [3/14] Extracted: eq_00002_p3.png (Status: success)
âœ… [4/14] Extracted: eq_00003_p4.png (Status: success)
âœ… [5/14] Extracted: eq_00004_p5.png (Status: success)
âœ… [6/14] Extracted: eq_00005_p6.png (Status: fallback_regex)
âœ… [7/14] Extracted: eq_00006_p7.png (Status: success)
âœ… [8/14] Extracted: eq_00007_p7.png (Status: fallback_regex)
âœ… [9/14] Extracted: eq_00008_p7.png (Status: success)
âœ… [10/14] Extracted: eq_00009_p7.png (Status: fallback_regex)
âœ… [11/14] Extracted: eq_00010_p7.png (Status: fallback_regex)
âœ… [12/14] Extracted: eq_00011_p8.png (Status: success)
âœ… [13/14] Extracted: eq_00012_p8.png (Status: fallback_regex)
âœ… [14/14] Extracted: eq_00013_p8.png (Status: fallback_regex)

âœ¨ Done! Saved to /kaggle/working/equations.json
ðŸ“Š Successfully parsed: 7/14


In [4]:
import re
