# SAM Interactive Test (Colab-compatible)

This notebook replicates the interactive click-based test from `custom_gym_implns/envs/sam_seg_env.py` using a Colab/Jupyter-friendly UI.

- Left click adds a prompt labeled according to the **Click mode** toggle (positive/foreground by default).
- Right click adds a negative prompt (background).
- If right-click doesn’t work in your environment, use the toggle to switch between positive/negative.

Requirements to run:
- The repository should be available in the runtime (this notebook lives under `AlignSAM-CVPR2024-Unofficial/tests/` inside the RL-CC-SAM repo).
- The RepViT-SAM weights file should exist at `RepViT/sam/weights/repvit_sam.pt`.
- A COCO dataset (the `val2017` split) should be available at `data/coco-dataset`; otherwise, adjust `dataset_config` below accordingly.

Tip: In Colab, run each cell in order. If interactivity doesn’t work, ensure the ipympl widget backend is enabled.


In [None]:
# Environment setup and Colab detection (mirrors RL-CC-SAM/notebooks).
# Resolves the RL-CC-SAM project root, makes it the working directory and
# appends it to sys.path so project packages can be imported in later cells.
import sys
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🌐 Running in Google Colab")
    from google.colab import drive
    drive.mount('/content/drive')

    # Set up paths for Colab
    DRIVE_ROOT = '/content/drive/MyDrive'
    PROJECT_ROOT = f'{DRIVE_ROOT}/RL-CC-SAM'

    # Change to project directory and expose it to Python
    os.chdir(PROJECT_ROOT)
    sys.path.append(PROJECT_ROOT)

    # Add prebuilt Colab env (same as other notebooks)
    sys.path.append(f"{DRIVE_ROOT}/colab_envs/llms/lib/python3.11/site-packages")

    print(f"📁 Working directory: {os.getcwd()}")
else:
    print("💻 Running locally")
    # This notebook is under <RL-CC-SAM>/AlignSAM-CVPR2024-Unofficial/tests/.
    # Assuming the kernel started in the notebook's directory, parents[0] is
    # AlignSAM-CVPR2024-Unofficial and parents[1] is the RL-CC-SAM root
    # (parents[2] would overshoot one level). Prefer locating the root by its
    # AlignSAM-CVPR2024-Unofficial child so re-running this cell after the
    # os.chdir below still resolves to the same directory.
    cwd = Path.cwd()
    PROJECT_ROOT = next(
        (p for p in (cwd, *cwd.parents) if (p / 'AlignSAM-CVPR2024-Unofficial').is_dir()),
        cwd.parents[1],
    )
    os.chdir(PROJECT_ROOT)
    sys.path.append(str(PROJECT_ROOT))

    print(f"📁 Working directory: {PROJECT_ROOT}")

# Keep compatibility with existing code below that expects `project_root` (string)
project_root = os.path.abspath(os.getcwd())
print('Project root:', project_root)


In [None]:
# Enable interactive backend for Jupyter/Colab
%config InlineBackend.figure_format = 'retina'
from IPython import get_ipython

try:
    # Enable ipywidgets support in Colab if available
    if 'google.colab' in sys.modules:
        try:
            from google.colab import output as colab_output
            colab_output.enable_custom_widget_manager()
        except Exception:
            pass
    import ipympl  # noqa: F401
    # Prefer ipympl via the Jupyter magic
    get_ipython().run_line_magic('matplotlib', 'widget')
    print('Using ipympl widget backend for interactivity.')
except Exception as e:
    # Fallbacks if ipympl is not available
    try:
        import matplotlib
        matplotlib.use('nbagg')
        print('ipympl not available; using nbagg backend.')
    except Exception as e2:
        get_ipython().run_line_magic('matplotlib', 'inline')
        print('ipympl/nbagg unavailable, falling back to inline (limited interactivity).', e2)

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
import cv2


In [None]:
# Ensure AlignSAM submodule is importable when running from RL-CC-SAM root
ALIGN_SAM_DIR = os.path.join(project_root, 'AlignSAM-CVPR2024-Unofficial')
if os.path.isdir(ALIGN_SAM_DIR) and ALIGN_SAM_DIR not in sys.path:
    sys.path.insert(0, ALIGN_SAM_DIR)

# Ensure Python can import `repvit_sam` package
# The package lives under <root>/RepViT/sam/repvit_sam (or the same inside AlignSAM)
REP_VIT_SAM_PARENTS = [
    os.path.join(project_root, 'RepViT', 'sam'),
    os.path.join(ALIGN_SAM_DIR, 'RepViT', 'sam'),
]
for parent in REP_VIT_SAM_PARENTS:
    if os.path.isdir(os.path.join(parent, 'repvit_sam')) and parent not in sys.path:
        sys.path.insert(0, parent)

# Import the environment that encapsulates SAM behavior
from custom_gym_implns.envs.sam_seg_env import SamSegEnv

# Configuration mirroring sam_seg_env.py's __main__ block
img_shape = (375, 500, 3)  # HxWxC
embedding_shape = (256, 64, 64)  # CxHxW
mask_shape = (256, 256)  # HxW
render_frame_shape = (320, 426)  # HxW
max_steps = 5
penalize_for_wrong_input = False
use_dice_score = True
img_patch_size = 32
render_mode = 'rgb_array'  # ensure render() returns RGB arrays for notebook

target_categories = ['person', 'cat', 'dog', 'car', 'bicycle', 'bus']

# Resolve COCO data dir under either RL-CC-SAM root or AlignSAM subfolder
coco_candidates = [
    os.path.join(project_root, 'data', 'coco-dataset'),
    os.path.join(ALIGN_SAM_DIR, 'data', 'coco-dataset'),
]
resolved_coco_dir = next((p for p in coco_candidates if os.path.isdir(p)), coco_candidates[0])
if not os.path.isdir(resolved_coco_dir):
    # Not fatal here; warn so a later dataset-loading failure is easy to
    # diagnose. (print, because warnings are globally ignored in this notebook.)
    print('WARNING: COCO dataset directory not found; tried:', coco_candidates)

dataset_config = {
    'type': 'coco',
    'data_dir': resolved_coco_dir,
    'data_type': 'val2017',
    'seed': 42,
    'max_instances': 5,
}

# Prefer the top-level RepViT weights (RL-CC-SAM/RepViT); otherwise fall back
# to the AlignSAM submodule location, using the same lookup pattern as COCO.
sam_ckpt_candidates = [
    os.path.join(project_root, 'RepViT', 'sam', 'weights', 'repvit_sam.pt'),
    os.path.join(ALIGN_SAM_DIR, 'RepViT', 'sam', 'weights', 'repvit_sam.pt'),
]
sam_ckpt_fp = next(
    (p for p in sam_ckpt_candidates if os.path.exists(p)),
    sam_ckpt_candidates[-1],
)
if not os.path.exists(sam_ckpt_fp):
    print('WARNING: RepViT-SAM weights not found; tried:', sam_ckpt_candidates)

env = SamSegEnv(
    img_shape=img_shape,
    embedding_shape=embedding_shape,
    mask_shape=mask_shape,
    render_frame_shape=render_frame_shape,
    max_steps=max_steps,
    target_categories=target_categories,
    dataset_config=dataset_config,
    penalize_for_wrong_input=penalize_for_wrong_input,
    use_dice_score=use_dice_score,
    sam_ckpt_fp=sam_ckpt_fp,
    img_patch_size=img_patch_size,
    render_mode=render_mode,
)

obs, info = env.reset()
print('Action space size:', env.action_space.n)
print('Current target category:', obs['target_category'])


In [None]:
# Interactive loop using matplotlib clicks
import ipywidgets as widgets
from IPython.display import display, clear_output

fig, ax = plt.subplots(figsize=(12, 4))

mode = widgets.ToggleButtons(
    options=[('Positive', 'pos'), ('Negative', 'neg')],
    value='pos',
    description='Click mode:',
)
status = widgets.HTML(value='Ready. Left click to add points. Use toggle for label.')

out = widgets.Output()

# Internal state analogous to the OpenCV version.
# NOTE(review): the meaning of index n - 2 is defined by SamSegEnv's action
# space; on_click overwrites this before step_and_render ever reads it.
sample_action = env.action_space.n - 2  # default

render_img = env.render()  # RGB array
display_img = ax.imshow(render_img)
ax.set_axis_off()
# Lay out only after the image is drawn and the axis hidden: running
# tight_layout() on the empty default axes reserves margin for tick labels
# that are never shown.
fig.tight_layout()

# Helper: map matplotlib click to env action
h_fig, w_fig = render_img.shape[:2]
img_h, img_w = img_shape[:2]
rf_h, rf_w = render_frame_shape  # single panel size (H, W)

def canvas_to_image_coords(event):
    """Map a matplotlib mouse event to integer (x, y) coordinates in the env image.

    The click is clipped to the rendered frame, folded onto a single render
    panel of size ``render_frame_shape`` (modulo, as in the OpenCV demo), and
    rescaled from panel size to ``img_shape``.

    Returns:
        ``(x, y)`` in original-image pixel coordinates, or ``None`` when the
        event carries no data coordinates (click outside the axes).
    """
    if event.xdata is None or event.ydata is None:
        return None
    # imshow centres pixel i on data coordinate i (covering [i - 0.5, i + 0.5)),
    # so round to the nearest pixel; int() truncation could pick the neighbour.
    x_full = int(np.clip(np.floor(event.xdata + 0.5), 0, w_fig - 1))
    y_full = int(np.clip(np.floor(event.ydata + 0.5), 0, h_fig - 1))
    # Map to single panel by modulo, same as OpenCV demo
    x_panel = x_full % rf_w
    y_panel = y_full % rf_h
    scaled_x = int(x_panel * img_w / rf_w)
    scaled_y = int(y_panel * img_h / rf_h)
    return scaled_x, scaled_y

@out.capture(clear_output=False)
def on_click(event):
    """Convert a mouse click into a SAM prompt and advance the environment.

    Button 1 (left) uses the label currently selected in the ``mode`` toggle;
    button 3 (right) always adds a negative prompt. Other buttons, missing
    button info, and clicks without data coordinates are ignored.
    """
    global sample_action
    button = event.button
    if button not in (1, 3):  # covers None and the middle button
        return
    tgt_label = mode.value if button == 1 else 'neg'

    coords = canvas_to_image_coords(event)
    if coords is not None:
        sample_action = env.convert_raw_input_to_action(coords, tgt_label)
        step_and_render()


# Fire the handler when a mouse button is released over the canvas.
cid = fig.canvas.mpl_connect('button_release_event', on_click)

def step_and_render():
    """Apply ``sample_action`` to the env, report the result and redraw the frame.

    Reward and the accumulated input labels/points are logged to ``out`` and
    summarized in the ``status`` widget. When the episode ends (``done`` or
    ``trunc``) the env is reset and the new target category is reported.
    Finally ``sample_action`` is set back to the last action index.
    """
    global sample_action
    obs, reward, done, trunc, info = env.step(sample_action)
    episode_over = done or trunc
    with out:
        print('reward:', reward)
        print(info['last_input_labels'], info['last_input_points'])
        if episode_over:
            print('Task done. Resetting...')
            obs, info = env.reset()
            # The reset may pick a different target; tell the user what to click.
            print('Current target category:', obs['target_category'])
    status.value = f'Last reward: {reward}' + (' (episode finished; env reset)' if episode_over else '')
    frame = env.render()
    display_img.set_data(frame)
    fig.canvas.draw_idle()
    sample_action = env.action_space.n - 1  # default to done

# Show the label toggle together with the status line (previously created but
# never displayed), then the log output and the interactive figure.
controls = widgets.HBox([mode, status])
display(controls, out)
plt.show()


In [None]:
# Manual reset control: a button that starts a fresh episode and redraws the canvas.
reset_btn = widgets.Button(description='Reset env', button_style='warning')


@out.capture(clear_output=False)
def on_reset_clicked(_):
    """Button callback: reset the env and show its newly rendered frame.

    The button instance passed by ipywidgets is unused. Printed output is
    captured into the shared ``out`` log widget.
    """
    env.reset()
    display_img.set_data(env.render())
    fig.canvas.draw_idle()
    print('Environment reset.')


reset_btn.on_click(on_reset_clicked)
display(reset_btn)
