In [2]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import cv2
from ultralytics import YOLO
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Optional

In [4]:
class ImageState(TypedDict):
    image_path: str
    is_diagram: Optional[bool]
    classification_confidence: Optional[float]
    diagram_detections: Optional[list]
    verbose: bool

In [5]:
def read_image(state: ImageState):
    if state["verbose"]:
        print(f"🔍 Reading image: {state['image_path']}")
    return {}

In [6]:
def classify_image(state: ImageState):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    image = Image.open(state["image_path"]).convert("RGB")
    inputs = processor(text=["a diagram", "not a diagram"], images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    
    is_diagram = probs.argmax() == 0
    return {"is_diagram": is_diagram, "classification_confidence": probs.max().item()}

In [7]:
def handle_non_diagram(state: ImageState):
    print("🚫 Not a diagram. Ending pipeline.")
    return {}

In [8]:
def route_decision(state: ImageState) -> str:
    return "run_yolo_detection" if state["is_diagram"] else "handle_non_diagram"

In [9]:
graph = StateGraph(ImageState)
graph.add_node("read_image", read_image)
graph.add_node("classify_image", classify_image)
graph.add_node("handle_non_diagram", handle_non_diagram)
graph.add_node("run_yolo_detection", run_yolo_detection)

graph.add_edge(START, "read_image")
graph.add_edge("read_image", "classify_image")
graph.add_conditional_edges("classify_image", route_decision, {
    "run_yolo_detection": "run_yolo_detection",
    "handle_non_diagram": "handle_non_diagram"
})
graph.add_edge("run_yolo_detection", END)
graph.add_edge("handle_non_diagram", END)

compiled_graph = graph.compile()

NameError: name 'run_yolo_detection' is not defined

In [None]:
compiled_graph.invoke({
    "image_path": "datasets/test/images/1.png",
    "verbose": True
})