In [1]:
!pip install -U transformers accelerate safetensors httpx fastapi uvicorn qwen-vl-utils
%pip install -U bitsandbytes


Collecting transformers
  Downloading transformers-5.1.0-py3-none-any.whl.metadata (31 kB)
Collecting fastapi
  Downloading fastapi-0.129.0-py3-none-any.whl.metadata (30 kB)
Collecting qwen-vl-utils
  Downloading qwen_vl_utils-0.0.14-py3-none-any.whl.metadata (9.0 kB)
Collecting av (from qwen-vl-utils)
  Downloading av-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Downloading transformers-5.1.0-py3-none-any.whl (10.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m66.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fastapi-0.129.0-py3-none-any.whl (102 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading qwen_vl_utils-0.0.14-py3-none-any.whl (8.1 kB)
Downloading av-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl (41.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.2/41.2 MB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25

In [4]:
%%writefile app.py

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import time
import json
import queue
import torch
import httpx
import logging
import threading
import base64
from dataclasses import dataclass
from typing import Dict, Optional, Any, Union, List
from io import BytesIO
from PIL import Image

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor
)
from qwen_vl_utils import process_vision_info

# ------------------------
# Logging
# ------------------------

# Root logging at INFO; this service logs under the "vlm_extractor" logger name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vlm_extractor")

# ASGI application object, served with `uvicorn app:app`.
app = FastAPI()

# ------------------------
# CONFIG
# ------------------------

# Each setting can be overridden via an environment variable (defaults shown),
# so the service can be re-targeted -- e.g. a smaller checkpoint on a GPU with
# less memory -- without editing this file.

# Upper bound on tokens generated per request.
MAX_NEW_TOKENS = int(os.environ.get("VLM_MAX_NEW_TOKENS", "2048"))
# How many times a failed task is re-queued before a "failed" webhook is sent.
MAX_RETRIES = int(os.environ.get("VLM_MAX_RETRIES", "3"))
# Timeout in seconds for the httpx client that delivers webhooks.
WEBHOOK_TIMEOUT = float(os.environ.get("VLM_WEBHOOK_TIMEOUT", "120"))

# Hugging Face model id of the Qwen2-VL checkpoint to load.
MODEL_ID = os.environ.get("VLM_MODEL_ID", "Qwen/Qwen2-VL-7B-Instruct")

# ------------------------
# DATA MODEL
# ------------------------

@dataclass
class ExtractionTask:
    """One queued extraction job plus its retry bookkeeping.

    Field order matters: dataclass fields without defaults must come before
    fields with defaults, so the required ``retries_left`` and ``created_at``
    precede the optional ``metadata`` and ``images``.
    """
    task_id: str
    prompt: str
    system_prompt: str
    template: str
    webhook_url: str
    retries_left: int  # remaining re-queue attempts; decremented by the worker on failure
    created_at: float  # enqueue time as a time.time() timestamp (required, no default)
    metadata: Optional[Dict[str, Any]] = None  # None is normalized to {} in __post_init__
    images: Optional[List[str]] = None  # base64 images, URLs, or file paths

    def __post_init__(self):
        # Give each task its own dict: a shared mutable default would leak
        # metadata between tasks, and callers may pass None explicitly.
        if self.metadata is None:
            self.metadata = {}


# ------------------------
# VLM ENGINE
# ------------------------

class VLMProcessor:
    """Lazily loads Qwen2-VL and turns (text, images) requests into parsed JSON.

    The model and processor are loaded once -- on the first ``extract`` call or
    an explicit ``initialize()`` -- under a lock so concurrent callers load it
    only once.
    """

    # Seconds to wait when downloading an image URL in ``process_image``.
    IMAGE_FETCH_TIMEOUT = 30

    def __init__(self):
        self.model = None
        self.processor = None
        self.lock = threading.Lock()
        self.initialized = False

    def initialize(self):
        """Load the model and processor once; safe to call repeatedly from any thread."""
        # Fast path once loading has finished, without touching the lock.
        if self.initialized:
            return

        with self.lock:
            # Re-check: another thread may have finished loading while we waited.
            if self.initialized:
                return

            logger.info("Loading Qwen VLM model...")

            use_cuda = torch.cuda.is_available()
            # fp16 + automatic device placement on GPU; fp32 on CPU.
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                low_cpu_mem_usage=True
            )

            # Processor handles both the chat template/tokenization and image preprocessing.
            self.processor = AutoProcessor.from_pretrained(MODEL_ID)

            self.model.eval()
            self.initialized = True
            logger.info("Qwen VLM model loaded successfully")

    def process_image(self, image_input: str) -> Image.Image:
        """Load an image given as a URL, base64 data, or a file path.

        Accepted forms:
          * ``http://`` / ``https://`` URL (downloaded with a timeout);
          * ``data:<mime>;base64,<payload>`` URI or a bare base64 payload;
          * ``file://<path>`` or a plain filesystem path.

        Raises:
            requests.RequestException: the URL could not be downloaded.
            OSError: the data/file is not a readable image, or the path does
                not exist.
        """
        if image_input.startswith(('http://', 'https://')):
            import requests  # only needed for URL inputs
            response = requests.get(image_input, timeout=self.IMAGE_FETCH_TIMEOUT)
            response.raise_for_status()
            return self._open_image(BytesIO(response.content))

        if image_input.startswith('file://'):
            return self._open_image(image_input[len('file://'):])

        payload = image_input
        if payload.startswith('data:') and 'base64,' in payload:
            payload = payload.split('base64,', 1)[1]

        try:
            # validate=True rejects non-alphabet characters instead of silently
            # dropping them, so ordinary file paths are not mis-decoded as data.
            # Whitespace is removed first to accept line-wrapped base64.
            image_bytes = base64.b64decode(''.join(payload.split()), validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            image_bytes = None

        if image_bytes:
            try:
                return self._open_image(BytesIO(image_bytes))
            except OSError:
                pass  # decodes as base64 but is not an image: try it as a path

        return self._open_image(image_input)

    @staticmethod
    def _open_image(source) -> Image.Image:
        """Open ``source`` (path or file object) and force PIL's lazy decode now."""
        image = Image.open(source)
        # Image.open only reads the header; load() makes corrupt data fail here
        # rather than later inside the processor.
        image.load()
        return image

    def _build_messages(self, prompt_text, system_prompt, template, images):
        """Build the single user turn: each loadable image, then instructions + template."""
        user_content = []

        for img in images or []:
            try:
                # Decode here rather than handing raw strings to process_vision_info,
                # which does not accept bare base64. The resulting PIL Image is a
                # supported "image" value for qwen_vl_utils.
                loaded = img if isinstance(img, Image.Image) else self.process_image(img)
                user_content.append({
                    "type": "image",
                    "image": loaded
                })
            except Exception as e:
                # Best effort: an unusable image is skipped, not fatal to the task.
                logger.warning(f"Failed to process image: {e}")

        formatted_template = template.replace("{prompt}", prompt_text)
        user_content.append({
            "type": "text",
            "text": system_prompt + "\n\nOutput JSON:\n" + formatted_template
        })

        return [{"role": "user", "content": user_content}]

    @staticmethod
    def _parse_json_output(result: str) -> dict:
        """Parse the outermost ``{...}`` span of ``result``, or describe why that failed."""
        json_start = result.find("{")
        json_end = result.rfind("}")

        if json_start < 0 or json_end <= json_start:
            logger.warning("No JSON structure found in VLM output")
            return {
                "raw_output": result,
                "json_error": "No JSON structure found in output",
                "note": "VLM failed to produce JSON structure"
            }

        json_str = result[json_start:json_end + 1]

        try:
            parsed_json = json.loads(json_str)
        except json.JSONDecodeError:
            # Fallback for Python-style single-quoted dicts. Only applied after a
            # strict parse fails, because rewriting quotes unconditionally breaks
            # valid JSON that contains apostrophes (e.g. "O'Brien").
            try:
                parsed_json = json.loads(json_str.replace("'", '"'))
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON from VLM output: {e}")
                logger.warning(f"Raw output was: {result[:500]}...")
                return {
                    "raw_output": result,
                    "json_error": str(e),
                    "note": "VLM failed to produce valid JSON"
                }

        logger.info("Successfully parsed JSON from VLM output")
        return parsed_json

    def extract(self,
                prompt_text: str,
                system_prompt: str,
                template: str,
                images: Optional[List[str]] = None) -> Union[dict, str]:
        """
        Extract structured data from text and images using provided system prompt and template.
        Supports both text-only and multimodal inputs.

        Args:
            prompt_text: substituted for every ``{prompt}`` placeholder in ``template``.
            system_prompt: instructions placed before the template in the user turn.
            template: output template shown to the model after "Output JSON:".
            images: optional base64 strings / data URIs, http(s) URLs, or file
                paths. Images that fail to load are skipped with a warning.

        Returns:
            The parsed JSON object on success; otherwise a dict with
            ``raw_output``, ``json_error`` and ``note`` keys.
        """
        self.initialize()

        messages = self._build_messages(prompt_text, system_prompt, template, images)

        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )

        if torch.cuda.is_available():
            inputs = inputs.to(self.model.device)

        with torch.no_grad():
            # Greedy decoding (do_sample=False); temperature has no effect in this
            # mode, so it is not passed.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                repetition_penalty=1.1,
                eos_token_id=self.processor.tokenizer.eos_token_id
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        raw = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return self._parse_json_output(raw.strip())


# ------------------------
# SINGLE WORKER QUEUE
# ------------------------

class TaskQueue:
    """Single-worker FIFO queue that runs VLM extractions and reports via webhook.

    One daemon thread pulls tasks off ``self.q`` and processes them one at a
    time. A task whose extraction raises is re-queued at the back while
    ``retries_left`` > 0. After that, a ``"failed"`` payload is posted to its
    webhook.
    """

    def __init__(self):
        self.q = queue.Queue()
        self.worker = None
        self.stop_event = threading.Event()
        self.processor = VLMProcessor()

    def start(self):
        """Spawn the daemon worker thread that runs ``loop``."""
        self.worker = threading.Thread(
            target=self.loop,
            daemon=True
        )
        self.worker.start()
        logger.info("Single worker started")

    def stop(self):
        """Signal the worker to exit; the ``None`` sentinel wakes it if blocked in ``get()``."""
        self.stop_event.set()
        self.q.put(None)

    def enqueue(self, task: ExtractionTask):
        """Append ``task`` to the (unbounded) queue without blocking."""
        self.q.put(task)
        logger.info(f"Queued {task.task_id}")

    def loop(self):
        """Worker body: process tasks until stopped or a ``None`` sentinel arrives."""
        while not self.stop_event.is_set():
            task = self.q.get()
            try:
                if task is None:
                    break
                self.process(task)
            except Exception:
                # process() already turns extraction errors into retries/webhooks.
                # This guards against anything escaping it, e.g.
                # torch.cuda.empty_cache() raising after a CUDA fault. Without the
                # guard the only worker thread would die and every later task
                # would wait in the queue forever.
                task_id = getattr(task, "task_id", None)
                logger.exception(f"Worker error while handling task {task_id}")
            finally:
                # Balance every get() (including the sentinel) so q.join() cannot hang.
                self.q.task_done()

    def process(self, task: ExtractionTask):
        """Run one extraction and deliver the outcome to ``task.webhook_url``.

        If extraction raises, the task is re-queued while retries remain.
        Once they are exhausted, a ``"failed"`` payload carrying the error text
        is sent. Webhook delivery errors are swallowed and logged by
        ``send_webhook``, so they never trigger a retry (which would re-run
        inference).
        """
        try:
            logger.info(f"Processing {task.task_id}")

            extraction_result = self.processor.extract(
                prompt_text=task.prompt,
                system_prompt=task.system_prompt,
                template=task.template,
                images=task.images
            )

            # The result is either the parsed JSON or a dict describing the
            # parse failure; both are reported as "completed".
            payload = {
                "task_id": task.task_id,
                "status": "completed",
                "extracted_data": extraction_result,
                "metadata": task.metadata
            }

            self.send_webhook(task.webhook_url, payload)

        except Exception as e:
            # logger.exception keeps the traceback that str(e) alone would lose.
            logger.exception(f"Task {task.task_id} failed: {e}")

            if task.retries_left > 0:
                task.retries_left -= 1
                logger.info(f"Retrying {task.task_id}, {task.retries_left} retries left")
                self.q.put(task)
            else:
                fail = {
                    "task_id": task.task_id,
                    "status": "failed",
                    "error": str(e),
                    "metadata": task.metadata
                }
                self.send_webhook(task.webhook_url, fail)

        finally:
            # Release cached, unused CUDA memory between tasks.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def send_webhook(self, url, payload):
        """POST ``payload`` as JSON to ``url``; best effort, so errors are logged and never raised."""
        try:
            with httpx.Client(timeout=WEBHOOK_TIMEOUT) as client:
                response = client.post(url, json=payload)
                response.raise_for_status()
                logger.info(f"Webhook sent successfully to {url}")
        except Exception as e:
            logger.error(f"Webhook error for {url}: {e}")


# Module-level singleton shared by all endpoints: one queue, one worker thread,
# and one VLMProcessor (model loaded once per process).
task_queue = TaskQueue()

# ------------------------
# API SCHEMA
# ------------------------

class ExtractionRequest(BaseModel):
    """Request body for ``POST /extract``.

    NOTE(review): ``webhook_url`` and the URL/path entries in ``images`` are
    used as given. The worker POSTs results to the webhook URL and downloads or
    opens the images, so anyone who can reach this endpoint can make the server
    contact arbitrary hosts or read local image files. Consider restricting
    these before exposing the service publicly.
    """
    task_id: str  # echoed back in the webhook payload; not checked for uniqueness
    prompt: str  # substituted for "{prompt}" in ``template``
    system_prompt: str  # placed before the template in the user message
    template: str
    webhook_url: str  # receives the "completed" / "failed" JSON payload
    images: Optional[List[str]] = None  # List of base64 images, URLs, or file paths
    metadata: Optional[Dict[str, Any]] = {}  # returned in the webhook payload


# ------------------------
# FASTAPI
# ------------------------

@app.on_event("startup")
async def startup():
    """Start the queue worker, then load the model on a separate daemon thread.

    Loading off the event loop lets the server accept requests immediately.
    A task that reaches the worker before loading finishes blocks in
    ``VLMProcessor.initialize`` until the model is ready.
    """
    task_queue.start()
    model_loader = threading.Thread(
        target=task_queue.processor.initialize,
        daemon=True,
    )
    model_loader.start()

@app.on_event("shutdown")
async def shutdown():
    # Signal the worker thread to exit; this does not wait for it to finish.
    task_queue.stop()

@app.post("/extract")
async def extract(req: ExtractionRequest):
    """Queue an extraction job and return at once.

    The result is delivered later by the worker to ``req.webhook_url``.
    Any error while queueing is reported as HTTP 500.
    """
    try:
        task_queue.enqueue(
            ExtractionTask(
                task_id=req.task_id,
                prompt=req.prompt,
                system_prompt=req.system_prompt,
                template=req.template,
                webhook_url=req.webhook_url,
                retries_left=MAX_RETRIES,
                created_at=time.time(),
                metadata=req.metadata,
                images=req.images,
            )
        )
    except Exception as e:
        logger.error(f"Error queueing task: {e}")
        raise HTTPException(500, str(e))

    return {"status": "queued", "task_id": req.task_id}

@app.get("/health")
async def health():
    """Report worker liveness, queue depth and model state.

    ``status`` is ``"ok"`` only while the worker thread is running. If the
    thread has exited, or was never started, queued tasks can no longer be
    processed, so ``"degraded"`` is reported instead of a false "ok".
    """
    worker = task_queue.worker
    worker_alive = worker is not None and worker.is_alive()
    return {
        "status": "ok" if worker_alive else "degraded",
        "queue_size": task_queue.q.qsize(),
        "model": MODEL_ID,
        "model_loaded": task_queue.processor.initialized,
        "worker_alive": worker_alive,
    }


Overwriting app.py


In [None]:
# Runs the server in the foreground: this cell does not finish until the server
# is stopped (interrupt the kernel). Other cells in this kernel cannot run
# meanwhile, so start the tunnel command below from a separate terminal, or
# launch uvicorn in the background instead.
!uvicorn app:app --host 0.0.0.0 --port 8000

[32mINFO[0m:     Started server process [[36m1322[0m]
[32mINFO[0m:     Waiting for application startup.
INFO:vlm_extractor:Single worker started
INFO:vlm_extractor:Loading Qwen VLM model...
[32mINFO[0m:     Application startup complete.
[32mINFO[0m:     Uvicorn running on [1mhttp://0.0.0.0:8000[0m (Press CTRL+C to quit)
INFO:httpx:HTTP Request: HEAD https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/config.json "HTTP/1.1 307 Temporary Redirect"
INFO:httpx:HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/Qwen/Qwen2-VL-7B-Instruct/eed13092ef92e448dd6875b2a00151bd3f7db0ac/config.json "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET https://huggingface.co/api/resolve-cache/models/Qwen/Qwen2-VL-7B-Instruct/eed13092ef92e448dd6875b2a00151bd3f7db0ac/config.json "HTTP/1.1 200 OK"
config.json: 0.00B [00:00, ?B/s]config.json: 1.20kB [00:00, 3.40MB/s]
INFO:httpx:HTTP Request: HEAD https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/adapter_config.

In [None]:
# curl -L -o cloudflared https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 && chmod +x cloudflared && ./cloudflared tunnel --url http://localhost:8000
