In [13]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import cv2
from ultralytics import YOLO
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Optional
import os
import sys
from pathlib import Path

# Add the project root directory to Python path
project_root = Path.cwd()
while project_root.name.lower() != "aiml25-exam" and "aiml25-exam" in str(project_root).lower():
    project_root = project_root.parent
sys.path.append(str(project_root))

# Now you can import from src
from src.yolo.detect import run_custom_yolo
from src.edge_validator import EdgeValidator
from src.llm_detector import Detector
from src.llm_caller import LLMCaller
from src.yolo.yolo import Yolo
from src.graph import Graph

In [14]:

class ImageState(TypedDict):
    image_path: str
    is_diagram: Optional[bool]
    classification_confidence: Optional[float]
    diagram_detections: Optional[list]
    verbose: bool

In [15]:
def read_image(state: ImageState):
    if state["verbose"]:
        print(f"🔍 Reading image: {state['image_path']}")
    return {}

In [16]:
def classify_image(state: ImageState):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    image = Image.open(state["image_path"]).convert("RGB")
    inputs = processor(text=["a diagram", "not a diagram"], images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    
    is_diagram = probs.argmax() == 0
    return {"is_diagram": is_diagram, "classification_confidence": probs.max().item()}

In [17]:
def handle_non_diagram(state: ImageState):
    print("🚫 Not a diagram. Ending pipeline.")
    return {}

In [18]:
def route_decision(state: ImageState) -> str:
    return "run_yolo_detection" if state["is_diagram"] else "handle_non_diagram"

In [19]:
from src.yolo.detect import run_custom_yolo  # Import your helper function

def run_yolo_detection(state: ImageState):
    if state["verbose"]:
        print("📦 Running custom YOLO model...")
    detections = run_custom_yolo(state["image_path"])
    
    if state["verbose"]:
        for det in detections:
            print(f"Detected: {det['label']} at {det['coords']} (confidence: {det['confidence']:.2f})")

    return {"diagram_detections": detections}

In [20]:


def crop_diagram(state: ImageState):
    if state["verbose"]:
        print("✂️ Cropping detected parts of the diagram and saving to tmp/ ...")

    cropped_images = []
    image = Image.open(state["image_path"]).convert("RGB")

    os.makedirs("tmp", exist_ok=True)

    for idx, det in enumerate(state.get("diagram_detections", [])):
        x1, y1, x2, y2 = det["coords"]

        # ⚡ Convert to integers
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        cropped = image.crop((x1, y1, x2, y2))
        cropped_images.append(cropped)

        save_path = f"tmp/cropped_{idx+1}.png"
        cropped.save(save_path)

        if state["verbose"]:
            print(f"💾 Saved cropped image to {save_path}")

    return {"cropped_images": cropped_images}

In [21]:
def detect_nodes(state: ImageState):
    if state["verbose"]:
        print("🔍 Detecting nodes in cropped images...")
    
    detector = Detector(model, yolo)
    detector.initiate_image(state["image_path"])
    detector.detect_nodes()
    
    return {"nodes_detected": True}

def detect_edges(state: ImageState):
    if state["verbose"]:
        print("🔍 Detecting edges between nodes...")
    
    detector.detect_edges()
    return {"edges_detected": True}

def create_graph(state: ImageState):
    if state["verbose"]:
        print("📊 Creating graph from detected nodes and edges...")
    
    graph = detector.get_graph()
    return {"graph": graph}

def validate_edges(state: ImageState):
    if state["verbose"]:
        print("✅ Validating detected edges...")
    
    validator = EdgeValidator.from_json_file(
        str(from_root("datasets/test/json/4.json")), 
        state["graph"].edges
    )
    validation_results = validator.validate()
    
    return {"validation_results": validation_results}


In [26]:
# Complete workflow definition
workflow = StateGraph(ImageState)

# Add nodes
workflow.add_edge(START, "read_image")
workflow.add_node("read_image", read_image)
workflow.add_node("classify_image", classify_image)
workflow.add_node("handle_non_diagram", handle_non_diagram)
workflow.add_node("detect_nodes", detect_nodes)
workflow.add_node("detect_edges", detect_edges)
workflow.add_node("create_graph", create_graph)
workflow.add_node("validate_edges", validate_edges)

# Add edges
workflow.add_edge("read_image", "classify_image")
workflow.add_conditional_edges(
    "classify_image",
    {
        "handle_non_diagram": lambda x: not x["is_diagram"],
        "detect_nodes": lambda x: x["is_diagram"]
    }
)
workflow.add_edge("detect_nodes", "detect_edges")
workflow.add_edge("detect_edges", "create_graph")
workflow.add_edge("create_graph", "validate_edges")
workflow.add_edge("validate_edges", END)
workflow.add_edge("handle_non_diagram", END)

# Compile the workflow
compiled_graph = workflow.compile()

In [35]:
# Define once here — used everywhere
image_path = "../datasets/test/images/3.png"

# Run the agent
result = compiled_graph.invoke({
    "image_path": image_path,
    "verbose": True
})

🔍 Reading image: ../datasets/test/images/3.png


In [36]:

#Visualization Helper of bounding boxes
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_image_with_boxes(image_path, detections):
    image = Image.open(image_path).convert("RGB")
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    for det in detections:
        x1, y1, x2, y2 = det["coords"]
        width = x2 - x1
        height = y2 - y1
        rect = patches.Rectangle((x1, y1), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1 - 5, f"{det['label']} ({det['confidence']:.2f})", color='white', backgroundcolor='red')

    plt.axis("off")
    plt.show()

In [37]:

# Conditionally show results
if result.get("is_diagram") and "diagram_detections" in result:
    show_image_with_boxes(image_path, result["diagram_detections"])
else:
    print("❌ This image was not classified as a diagram. No detection results to show.")

❌ This image was not classified as a diagram. No detection results to show.


In [38]:
for idx, img in enumerate(result.get("cropped_images", [])):
    print(f"Cropped part {idx+1}:")
    display(img)