# In [None]:  (Jupyter cell marker left over from notebook export)
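"""
Demo: step-by-step RASC inference.

Runs each stage separately (object detection, relationship prediction,
caption generation) and prints its intermediate output, then runs the
full pipeline in one call and saves the result to JSON.
"""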

import json
from pathlib import Path
from pprint import pprint

from src.models.inference import RASCInference  # your RASCInference class

# -----------------------------
# Configuration
# -----------------------------
# All paths are relative, so they resolve against the current working
# directory; run the demo from the project root.
CONFIG_PATH = "configs/config.yaml"  # Path to your config file
IMAGE_PATH = "sample_images/living_room.jpg"  # Example image
# The "Defaults to ..." notes describe what RASCInference is said to fall back
# to when the argument is omitted; not verifiable from this file.
YOLO_WEIGHTS = "runs/detect/experiments/runs/yolo_experiment_1/weights/best.pt"  # Defaults to yolov8n.pt
RELATIONSHIP_WEIGHTS = "models/relationship_predictor/neural_motifs_best.pt"  # Defaults to neural_motifs.pt
CAPTION_MODEL = "models/caption_generator/t5_scene_best.pt"  # Defaults to t5_scene

# -----------------------------
# Build the inference pipeline
# -----------------------------
# Constructor arguments are gathered in one mapping so the checkpoint choices
# above are wired into RASCInference in a single, easy-to-edit place.
pipeline_kwargs = dict(
    config_path=CONFIG_PATH,
    yolo_weights=YOLO_WEIGHTS,
    relationship_weights=RELATIONSHIP_WEIGHTS,
    caption_model=CAPTION_MODEL,
)
pipeline = RASCInference(**pipeline_kwargs)

# -----------------------------
# Step 1: Detect objects
# -----------------------------
print("\n=== Step 1: Object Detection ===")
objects = pipeline.detect_objects(IMAGE_PATH)
# Each detection is unpacked as a (class id, bounding box) pair; the box is
# expected to expose .tolist() (presumably a tensor/ndarray — TODO confirm).
for idx, detection in enumerate(objects):
    class_id, box = detection
    print(f"Object {idx}: Class={class_id}, BBox={box.tolist()}")

# -----------------------------
# Step 2: Predict relationships
# -----------------------------
print("\n=== Step 2: Relationship Prediction ===")
# Relationships are inferred from the detections produced in Step 1.
relationships = pipeline.predict_relationships(objects)
# NOTE(review): if predict_relationships returns a one-shot iterator, this
# loop consumes it before Step 3 — confirm it returns a sequence.
for relation in relationships:
    print(relation)

# -----------------------------
# Step 3: Generate caption
# -----------------------------
print("\n=== Step 3: Caption Generation ===")
# The caption is generated from the Step 2 relationships only.
caption = pipeline.generate_caption(relationships)
caption_line = f"Generated Caption: {caption}"
print(caption_line)

# -----------------------------
# Step 4: Full Pipeline Results
# -----------------------------
print("\n=== Step 4: Full Pipeline Output ===")
# Runs the whole pipeline on the image in one call; verbose=True is passed
# through to RASCInference.run (what it prints is defined there).
results = pipeline.run(IMAGE_PATH, verbose=True)

# Optionally save results
OUTPUT_PATH = Path("results/demo_output.json")
# open()/write_text() do not create missing directories; ensure results/ exists
# so a fresh checkout does not fail with FileNotFoundError.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
# Serialize before touching the file: if results is not JSON-serializable
# (e.g. it holds tensors — unverified here), json.dumps raises TypeError and no
# truncated output file is left behind.
OUTPUT_PATH.write_text(json.dumps(results, indent=2), encoding="utf-8")

print(f"\nResults saved to {OUTPUT_PATH}")
