# VLM-R3 Interactive Reasoning Demo

This notebook provides a polished, single-sample demonstration pipeline for the
VLM-R3 visual reasoning agent. It focuses on:

- **Running** the agent on either an image or a video (the latter is converted
  into a cinematic frame grid automatically).
- **Capturing** the reasoning trace for every iteration, including tool calls
  and bounding boxes.
- **Visualising** each step with a sleek, high-tech styled layout that aligns
  the textual explanation with the spatial grounding.
- **Exporting** an artefact that can be consumed by the companion web demo to
  reproduce the animated experience showcased on the project website.

> ⚠️ Running this notebook requires the VLM checkpoint, vLLM, and the video
> dependencies (`PyAV`, FFmpeg) to be available in the execution environment.


In [None]:
from __future__ import annotations

import html
import json
import math
import shutil
import sys
import textwrap
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import av
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML, Markdown, display
from matplotlib import patches
from PIL import Image, ImageDraw
from transformers import AutoProcessor
from vllm import LLM

# Dark, high-contrast matplotlib theme used by every figure in this notebook.
plt.style.use("dark_background")
plt.rcParams.update(
    {
        "figure.figsize": (12, 6),
        "font.size": 12,
        "axes.facecolor": "#05060A",
        "savefig.facecolor": "#05060A",
    }
)

# Repository root: the working directory when it contains ``src/``, otherwise
# its parent (e.g. when the notebook is launched from a subdirectory).
REPO_ROOT = Path.cwd().resolve()
REPO_ROOT = REPO_ROOT if (REPO_ROOT / "src").exists() else REPO_ROOT.parent

# Make ``src.*`` importable; appended, so existing sys.path entries keep precedence.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from qwen_vl_utils import smart_resize
from src.model.r3_vllm import AgentVLMVLLM
from src.utils.reasoning_viz import (
    describe_run_for_export,
    parse_reasoning_steps,
    remap_bbox,
    extract_answer_text,
)


In [None]:
# ==== Configuration ======================================================
# Placeholders: point these at a real checkpoint directory and media file.
MODEL_PATH = Path("/path/to/your/VLM-R3-7b-rl-v1")
MEDIA_PATH = Path("/path/to/your/media_file.mp4")
QUESTION = "Describe what is happening in this scene."
# Optional multiple-choice options; they are appended to QUESTION before the run.
ANSWER_CHOICES: List[str] = []  # e.g. ["(A) Option one", "(B) Option two", ...]

# vLLM / agent parameters
DEVICE = "cuda"
TENSOR_PARALLEL = 1
GPU_MEMORY_UTILIZATION = 0.92
MAX_ITERATIONS = 8

# Visual sampling: a video is tiled into a GRID_ROWS x GRID_COLS grid with one
# uniformly sampled frame per cell.
GRID_ROWS = 16
GRID_COLS = 16
# Derived rather than set independently: create_video_grid rejects any frame
# count other than GRID_ROWS * GRID_COLS, so changing the grid shape is enough.
NUM_FRAMES = GRID_ROWS * GRID_COLS

# Vision encoder constraints (must match the checkpoint expectations).
# Budgets are total pixel counts written as multiples of 28 * 28, the resize
# factor that is also passed to smart_resize when the agent runs.
MIN_PIXELS = 32 * 28 * 28
MAX_PIXELS = 8192 * 28 * 28
CROP_MIN_PIXELS = 32 * 28 * 28
CROP_MAX_PIXELS = 4096 * 28 * 28

# Output bookkeeping (OUTPUT_DIR is relative to the kernel's working directory).
OUTPUT_DIR = Path("demo_notebook_outputs")
RUN_NAME = "demo_run"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# Lower-case file suffixes (leading dot included) recognised as video / image media.
VIDEO_EXTS = {f".{ext}" for ext in "mp4 mov avi mkv webm flv wmv mpg mpeg".split()}
IMAGE_EXTS = {f".{ext}" for ext in "jpg jpeg png bmp tiff webp".split()}

# Colour cycle for reasoning steps; step N uses PALETTE[(N - 1) % len(PALETTE)].
PALETTE = [
    "#19F7FF",
    "#FF6B9A",
    "#8CFF55",
    "#FDBF2D",
    "#A855F7",
    "#2FD3C9",
    "#FF7847",
    "#63A4FF",
]


def is_video(path: Path) -> bool:
    """Return True when *path*'s extension (case-insensitive) is in VIDEO_EXTS."""
    suffix = path.suffix.lower()
    return suffix in VIDEO_EXTS


def decode_video_frames(video_path: Path) -> List[Image.Image]:
    """Decode every frame of the first video stream into RGB PIL images.

    All frames are held in memory at once, so memory use grows with both clip
    length and resolution.

    Raises:
        RuntimeError: if the stream yields no frames.
    """
    container = av.open(str(video_path))
    try:
        decoded = [
            frame.to_image().convert("RGB") for frame in container.decode(video=0)
        ]
    finally:
        # Release the demuxer even if decoding fails part-way.
        container.close()

    if len(decoded) == 0:
        raise RuntimeError(f"No frames could be decoded from {video_path}")
    return decoded


def sample_frames_uniformly(frames: List[Image.Image], count: int) -> List[Image.Image]:
    """Pick ``count`` evenly spread items from ``frames``.

    Indices come from ``np.linspace(0, len(frames) - 1, count)`` truncated to
    int. When ``frames`` has no more than ``count`` items, a shallow list copy
    of all of them is returned instead.

    Raises:
        ValueError: if ``count`` is not positive.
    """
    if count <= 0:
        raise ValueError("count must be positive")
    total = len(frames)
    if total > count:
        picks = np.linspace(0, total - 1, count, dtype=int)
        return [frames[pick] for pick in picks]
    return list(frames)


def _decode_uniform_frames(video_path: Path, count: int) -> "List[Image.Image]":
    """Decode ``count`` uniformly spaced frames of ``video_path`` as RGB images.

    Frame indices are chosen like ``sample_frames_uniformly``: all frames when
    the clip has at most ``count`` of them, otherwise
    ``np.linspace(0, total - 1, count)`` truncated to int. Only the selected
    frames are converted and kept, so memory is bounded by ``count`` frames
    instead of the clip length. The price is one extra decode pass to count
    the frames; converting every frame to PIL, as the full-decode approach
    does, is avoided.

    Raises:
        ValueError: if ``count`` is not positive.
        RuntimeError: if no frames can be decoded.
    """
    if count <= 0:
        raise ValueError("count must be positive")

    # Pass 1: count decodable frames without converting them to PIL images.
    container = av.open(str(video_path))
    try:
        total = sum(1 for _ in container.decode(video=0))
    finally:
        container.close()
    if total == 0:
        raise RuntimeError(f"No frames could be decoded from {video_path}")

    if total <= count:
        wanted = list(range(total))
    else:
        wanted = np.linspace(0, total - 1, count, dtype=int).tolist()

    # Pass 2: convert only the wanted frames. ``wanted`` is non-decreasing, so a
    # single cursor advances in step with the decoder. A repeated index (if any)
    # reuses the same image, as indexing a full frame list would.
    frames = []
    cursor = 0
    container = av.open(str(video_path))
    try:
        for frame_idx, frame in enumerate(container.decode(video=0)):
            if frame_idx != wanted[cursor]:
                continue
            image = frame.to_image().convert("RGB")
            while cursor < len(wanted) and wanted[cursor] == frame_idx:
                frames.append(image)
                cursor += 1
            if cursor == len(wanted):
                break  # every selected frame collected; skip decoding the tail
    finally:
        container.close()
    if not frames:
        raise RuntimeError(f"No frames could be decoded from {video_path}")
    return frames


def _grid_cell_size(frame_size, cells: int, max_grid_pixels: Optional[int]):
    """Return a (width, height) cell size such that ``cells`` cells fit in
    ``max_grid_pixels`` total pixels, keeping the aspect ratio approximately.

    ``frame_size`` is returned unchanged when it already fits, or when
    ``max_grid_pixels`` is None or not positive (meaning "no limit").
    """
    width, height = frame_size
    if max_grid_pixels is None or max_grid_pixels <= 0:
        return width, height
    total = width * height * cells
    if total <= max_grid_pixels:
        return width, height
    scale = math.sqrt(max_grid_pixels / total)
    # Truncation keeps cell_w * cell_h * cells <= max_grid_pixels.
    return max(1, int(width * scale)), max(1, int(height * scale))


def create_video_grid(
    video_path: Path,
    output_path: Path,
    num_frames: int,
    max_grid_pixels: Optional[int] = None,
) -> Path:
    """Tile ``num_frames`` uniformly sampled frames into a GRID_ROWS x GRID_COLS image.

    Frames fill the grid row by row in temporal order. Every cell takes the
    size of the first sampled frame, scaled down when needed so the whole grid
    fits ``max_grid_pixels``. Frames of any other size are resized to the cell.

    Args:
        video_path: Video file to sample.
        output_path: Destination of the grid image; parent dirs are created.
        num_frames: Must equal GRID_ROWS * GRID_COLS.
        max_grid_pixels: Total pixel budget for the grid.
            - ``None`` (default) uses ``PIL.Image.MAX_IMAGE_PIXELS``, so the
              saved grid can be reopened with ``Image.open`` without tripping
              Pillow's decompression-bomb check. A full-resolution 16x16 grid
              of 720p frames is ~236 MP and is rejected outright.
            - Pass 0 to disable the limit.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``num_frames`` does not match the grid shape.
        RuntimeError: if no frames can be decoded.
    """
    cells = GRID_ROWS * GRID_COLS
    if cells != num_frames:
        raise ValueError(
            f"Grid requires exactly {GRID_ROWS * GRID_COLS} frames, got {num_frames}."
        )
    if max_grid_pixels is None:
        # May itself be None when the user has disabled Pillow's check.
        max_grid_pixels = Image.MAX_IMAGE_PIXELS

    sampled = _decode_uniform_frames(video_path, num_frames)
    cell_w, cell_h = _grid_cell_size(sampled[0].size, cells, max_grid_pixels)

    grid = Image.new("RGB", (cell_w * GRID_COLS, cell_h * GRID_ROWS))
    for idx, frame in enumerate(sampled):
        if frame.size != (cell_w, cell_h):
            frame = frame.resize((cell_w, cell_h), Image.BICUBIC)
        row, col = divmod(idx, GRID_COLS)
        grid.paste(frame, (col * cell_w, row * cell_h))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    grid.save(output_path)
    return output_path


def format_reasoning_text(text: str, width: int = 70) -> str:
    """Wrap a step's reasoning text for the monospace side panel.

    Each line of ``text`` is wrapped to ``width`` columns on its own. The
    step's own line breaks and blank lines are therefore kept, rather than
    being collapsed into spaces as a single ``textwrap.wrap`` call would do.

    Returns a placeholder message when ``text`` is empty or whitespace-only.
    """
    text = text.strip()
    if not text:
        return "(No textual content emitted in this step.)"
    wrapped_lines = []
    for line in text.splitlines():
        # textwrap.wrap returns [] for a blank line; keep it as an empty separator.
        wrapped_lines.extend(textwrap.wrap(line, width=width) or [""])
    return "\n".join(wrapped_lines)


def map_steps_for_display(steps, resized_size, base_size):
    """Turn parsed reasoning steps into plain dicts for rendering and export.

    Each output dict holds the step's ``index`` and ``text`` plus two boxes:
    - ``bbox_agent``: the step's bbox as parsed.
    - ``bbox_display``: that bbox passed through ``remap_bbox(bbox,
      resized_size, base_size)``, or None when the step has no bbox.
    """

    def _to_record(step):
        display_box = (
            None
            if step.bbox is None
            else remap_bbox(step.bbox, resized_size, base_size)
        )
        return {
            "index": step.index,
            "text": step.text,
            "bbox_agent": step.bbox,
            "bbox_display": display_box,
        }

    return [_to_record(step) for step in steps]


def render_composite_figure(base_image_path: Path, mapped_steps):
    """Overlay the bounding box of every grounded step on the base image.

    Steps whose ``bbox_display`` is None are skipped. Each box and its
    "Step N" tag use the palette colour derived from the step's ``index``.

    Returns:
        The matplotlib Figure; showing and closing it is left to the caller.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    base = Image.open(base_image_path).convert("RGB")
    ax.imshow(base)
    ax.axis("off")

    grounded_steps = (s for s in mapped_steps if s["bbox_display"] is not None)
    for step in grounded_steps:
        number = step["index"]
        color = PALETTE[(number - 1) % len(PALETTE)]
        left, top, right, bottom = step["bbox_display"]
        outline = patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=2.5,
            edgecolor=color,
            facecolor="none",
            linestyle="-",
        )
        ax.add_patch(outline)
        tag_style = dict(facecolor="#05060A", alpha=0.65, edgecolor=color, pad=2)
        # Place the tag just above the box, never above y=0.
        ax.text(
            left,
            max(0, top - 10),
            f"Step {number}",
            color=color,
            fontsize=11,
            fontweight="bold",
            bbox=tag_style,
        )

    ax.set_title("Composite grounding overview", fontsize=14, color="#ECF2FF")
    plt.tight_layout()
    base.close()
    return fig


def render_single_step(base_image_path: Path, mapped_steps, step_index: int):
    """Render one reasoning step: grounded image on the left, text on the right.

    Args:
        base_image_path: Image the display-space bounding boxes refer to.
        mapped_steps: Records produced by ``map_steps_for_display``.
        step_index: 1-based position of the step within ``mapped_steps``.

    Returns:
        The matplotlib Figure; the caller shows and closes it.

    Raises:
        IndexError: if ``step_index`` is outside ``1..len(mapped_steps)``.
    """
    if not 1 <= step_index <= len(mapped_steps):
        # Without this check, 0 or a negative value silently wraps to the end of the list.
        raise IndexError(
            f"step_index must be between 1 and {len(mapped_steps)}, got {step_index}"
        )
    step = mapped_steps[step_index - 1]
    # Label and colour by the step's own index, matching the composite overlay
    # and the HTML summary, which also key on step["index"].
    label = step["index"]
    color = PALETTE[(label - 1) % len(PALETTE)]

    with Image.open(base_image_path) as source:
        base = source.convert("RGB")
    fig, axes = plt.subplots(1, 2, gridspec_kw={"width_ratios": [3.5, 2.2]})
    image_ax, text_ax = axes

    image_ax.imshow(base)
    image_ax.axis("off")

    bbox = step["bbox_display"]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=3,
            edgecolor=color,
            facecolor="none",
        )
        image_ax.add_patch(rect)
        image_ax.text(
            x1,
            max(0, y1 - 12),
            f"Focus {label}",
            color=color,
            fontsize=13,
            fontweight="bold",
            bbox=dict(facecolor="#05060A", alpha=0.7, edgecolor=color, pad=3),
        )
    else:
        image_ax.text(
            0.5,
            0.5,
            "No crop used",
            transform=image_ax.transAxes,
            ha="center",
            va="center",
            color="#AAAAAA",
            fontsize=14,
        )

    text_ax.axis("off")
    text_ax.text(
        0.05,
        0.98,
        format_reasoning_text(step["text"]),
        color=color,
        fontsize=12,
        fontfamily="monospace",
        va="top",
        ha="left",
        bbox=dict(facecolor="#0B111A", edgecolor=color, boxstyle="round,pad=0.6", alpha=0.9),
    )

    fig.suptitle(
        f"Reasoning step {label}",
        color=color,
        fontsize=16,
        fontweight="bold",
    )
    # Lay out this figure explicitly rather than whatever pyplot considers "current".
    fig.tight_layout()
    base.close()
    return fig


In [None]:
# ==== Load processor, vLLM, and agent ===================================
# Each heavy object is built only if a global with that name does not exist
# yet, so re-running this cell reuses instances from an earlier run.
# NOTE(review): the check is by name only. After changing MODEL_PATH or the
# pixel/parallelism settings, run `del processor, llm, agent` (or restart the
# kernel) so they are rebuilt with the new configuration.
if 'processor' in globals():
    print("✅ Reusing cached processor instance.")
else:
    print("Loading processor from", MODEL_PATH)
    # Same MIN_PIXELS / MAX_PIXELS bounds as the vLLM mm_processor_kwargs and
    # the agent below; presumably they must agree so all resize alike.
    processor = AutoProcessor.from_pretrained(
        str(MODEL_PATH),
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )

if 'llm' in globals():
    print("✅ Reusing cached vLLM instance.")
else:
    print("Spawning vLLM...")
    # Engine settings:
    # - limit_mm_per_prompt allows up to 16 images per prompt (presumably the
    #   base image plus crops gathered across iterations) and no video inputs.
    # - max_model_len is 8192 * 4 = 32768 tokens.
    llm = LLM(
        model=str(MODEL_PATH),
        device=DEVICE,
        tensor_parallel_size=TENSOR_PARALLEL,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
        dtype=torch.bfloat16,
        limit_mm_per_prompt={"image": 16, "video": 0},
        mm_processor_kwargs={
            "max_pixels": MAX_PIXELS,
            "min_pixels": MIN_PIXELS,
        },
        max_model_len=8192 * 4,
    )

if 'agent' in globals():
    print("✅ Reusing cached AgentVLMVLLM instance.")
else:
    # Agent wrapping the engine and processor:
    # - temp_dir is presumably where crop images are written; verify in
    #   AgentVLMVLLM.
    # - temperature=0.0 presumably means greedy decoding.
    # - crop_* bounds set the pixel budget for cropped regions.
    agent = AgentVLMVLLM(
        model=llm,
        processor=processor,
        temp_dir=str(OUTPUT_DIR / "crops"),
        device=DEVICE,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
        temperature=0.0,
        crop_min_pixels=CROP_MIN_PIXELS,
        crop_max_pixels=CROP_MAX_PIXELS,
    )


In [None]:
# ==== Prepare media and run the agent ====================================
media_path = MEDIA_PATH.expanduser().resolve()
if not media_path.exists():
    raise FileNotFoundError(f"Media file not found: {media_path}")

# Videos are tiled into a single frame-grid image; images are used directly.
if is_video(media_path):
    grid_path = OUTPUT_DIR / f"{media_path.stem}_grid_{NUM_FRAMES}.png"
    print(f"Decoding video and building {GRID_ROWS}x{GRID_COLS} grid...")
    base_image_path = create_video_grid(media_path, grid_path, NUM_FRAMES)
    media_kind = "video_grid"
else:
    base_image_path = media_path
    media_kind = "image"

base_image_path = base_image_path.resolve()
# Only the dimensions are needed. Read them from the header instead of
# decoding and RGB-converting the whole (possibly very large) grid image.
with Image.open(base_image_path) as base_image:
    base_width, base_height = base_image.size
# Resized dimensions computed with the same factor/min/max pixel bounds as the
# processor (presumably what the model sees). They are used to map agent-space
# boxes back onto the base image.
resized_height, resized_width = smart_resize(
    base_height,
    base_width,
    factor=28,
    min_pixels=MIN_PIXELS,
    max_pixels=MAX_PIXELS,
)

question_text = QUESTION.strip()
if ANSWER_CHOICES:
    question_text += " " + " ".join(ANSWER_CHOICES)

# The leading "\n" must stay an escape sequence: a raw line break inside the
# literal is an unterminated-string SyntaxError.
print("\n🚀 Launching multi-turn reasoning...")
response_list, crop_images, full_response, img_messages, text_messages = agent.process(
    str(base_image_path),
    question_text,
    max_iterations=MAX_ITERATIONS,
)

raw_steps = parse_reasoning_steps(response_list)
mapped_steps = map_steps_for_display(
    raw_steps,
    resized_size=(resized_width, resized_height),
    base_size=(base_width, base_height),
)
final_answer = extract_answer_text(full_response)

print(f"\nCaptured {len(mapped_steps)} reasoning segments.")
if final_answer:
    print("Model answer:", final_answer)
else:
    print("Model answer could not be extracted from <answer> tags.")


In [None]:
# ==== Summaries ==========================================================
# Model output can contain tag-like text (e.g. "<answer>"). Every dynamic value
# is HTML-escaped before being spliced into the markup, so it is shown literally
# instead of being parsed as HTML.
summary_lines = []
for step in mapped_steps:
    color = PALETTE[(step["index"] - 1) % len(PALETTE)]
    # Flatten newlines outside the f-string: a backslash inside an f-string
    # replacement field is a SyntaxError before Python 3.12.
    step_text = html.escape(step["text"].replace("\n", " "))
    summary_lines.append(
        f"<span style='color:{color}; font-weight:600;'>Step {step['index']}:</span> "
        f"<code>{step_text}</code>"
    )

html_template = '''
<div style="background:#05060A;border:1px solid #1E2A3A;border-radius:8px;padding:16px;">
  <h3 style="color:#7BE6FF;margin-top:0;">Reasoning Trace</h3>
  <p style="color:#C8D5FF;">{question}</p>
  <ol style="color:#E0E7FF;">
    {items}
  </ol>
  <p style="color:#7BE6FF;font-weight:600;">Final Answer: {answer}</p>
</div>
'''

html_summary = html_template.format(
    question=html.escape(question_text),
    items="\n".join(f"<li>{line}</li>" for line in summary_lines),
    answer=html.escape(str(final_answer)) if final_answer else "(not available)",
)

display(HTML(html_summary))
fig = render_composite_figure(base_image_path, mapped_steps)
plt.show()


In [None]:
# ==== Interactive step-by-step visualisation ============================
if mapped_steps:
    # One slider position per captured step (1-based). continuous_update=False
    # redraws only when the handle is released, not while it is dragged.
    slider = widgets.IntSlider(
        min=1,
        max=len(mapped_steps),
        value=1,
        step=1,
        description="Step",
        continuous_update=False,
        style={"description_width": "initial"},
        layout=widgets.Layout(width="600px"),
    )
    output = widgets.Output()

    def update_visual(change=None):
        """Redraw the figure for the slider's current step inside ``output``.

        ``change`` is the observer event payload; it is ignored because the
        current value is read straight from the slider.
        """
        with output:
            output.clear_output(wait=True)
            step_fig = render_single_step(base_image_path, mapped_steps, slider.value)
            display(step_fig)
            # Close after displaying so pyplot does not also show it again later.
            plt.close(step_fig)

    slider.observe(update_visual, names="value")
    display(slider)
    display(output)
    update_visual()  # draw the initial step
else:
    display(Markdown("**No reasoning steps were captured.**"))


In [None]:
# ==== Persist payload for the web visualiser ===========================
# Timezone-aware UTC timestamp: datetime.utcnow() is deprecated since Python
# 3.12. The "+00:00" offset is rewritten as "Z" to keep the previous format.
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
export_payload = describe_run_for_export(
    question=question_text,
    media_path=media_path,
    base_image_path=base_image_path,
    base_size=(base_width, base_height),
    resized_size=(resized_width, resized_height),
    response_chunks=response_list,
    full_response=full_response,
    media_type=media_kind,
)
export_payload.update(
    {
        "response_chunks": response_list,
        "steps_agent_space": [
            {
                "index": step["index"],
                "bbox": step["bbox_agent"],
                "text": step["text"],
            }
            for step in mapped_steps
        ],
        "resized_size": {"width": resized_width, "height": resized_height},
        "question_raw": QUESTION,
        "answer_choices": ANSWER_CHOICES,
        "run": {
            "timestamp": now,
            "model_path": str(MODEL_PATH),
            "device": DEVICE,
            "max_iterations": MAX_ITERATIONS,
            "num_frames": NUM_FRAMES if media_kind == "video_grid" else None,
        },
    }
)

output_json = OUTPUT_DIR / f"{RUN_NAME}_interactive_payload.json"
# NOTE(review): json.dump raises TypeError on non-JSON values (e.g. Path objects
# or numpy scalars). Confirm describe_run_for_export and the parsed bboxes
# only produce plain Python types.
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(export_payload, f, ensure_ascii=False, indent=2)

print(f"Payload saved to {output_json}")
