# SAM Interactive Test (Colab-compatible)

This notebook replicates the interactive click-based test from `custom_gym_implns/envs/sam_seg_env.py` using a Colab/Jupyter-friendly UI.

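- Left click adds a prompt with the label currently selected in the "Click mode" toggle (positive/foreground by default).
- Right click always adds a negative prompt (background), regardless of the toggle.
- If right-click doesn’t work in your environment, switch the toggle to "Negative" and left click instead.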

Requirements to run:
- The repository should be available in the runtime (this notebook is under `tests/`).
- The RepViT-SAM weights file should exist at `RepViT/sam/weights/repvit_sam.pt`.
- A COCO dataset should be available at `data/coco-dataset`; if it lives elsewhere, adjust `data_config` below accordingly.

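Tip: In Colab, run each cell in order. If interactivity doesn’t work, ensure the ipympl widget backend is installed and enabled (`pip install ipympl`); on Colab, third-party widgets additionally require `from google.colab import output; output.enable_custom_widget_manager()`.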


In [None]:
# If running on Colab, uncomment below to clone and install (adjust as needed)
# !git clone https://github.com/<your-org-or-user>/AlignSAM-CVPR2024-Unofficial.git
# %cd AlignSAM-CVPR2024-Unofficial
# !pip install -r requirements.txt
# !pip install ipympl

import os
import sys

def find_project_root(start_dir, marker='custom_gym_implns'):
    """Return the closest ancestor of ``start_dir`` (inclusive) that contains ``marker``.

    The notebook sits under ``tests/``, so the working directory may be either
    the repository root (e.g. after ``%cd`` on Colab) or the ``tests/`` folder
    (a locally launched Jupyter). Walking upwards handles both cases.

    Args:
        start_dir: Directory from which the upward search starts.
        marker: Name of a sub-directory that identifies the project root.

    Returns:
        Absolute path of the first directory containing ``marker``; the
        absolute ``start_dir`` itself when no ancestor contains it.
    """
    start_dir = os.path.abspath(start_dir)
    current = start_dir
    while True:
        if os.path.isdir(os.path.join(current, marker)):
            return current
        parent = os.path.dirname(current)
        if parent == current:
            # Reached the filesystem root without a match: keep the old behaviour.
            return start_dir
        current = parent


# Ensure project root is on sys.path (this notebook sits under tests/)
project_root = find_project_root(os.getcwd())
if project_root not in sys.path:
    # Guard so that re-running the cell does not pile up duplicate entries.
    sys.path.append(project_root)

print('Project root:', project_root)


In [None]:
# Enable interactive backend for Jupyter/Colab
# In Colab, this requires ipympl
%config InlineBackend.figure_format = 'retina'

try:
    import ipympl  # noqa: F401
    import matplotlib
    matplotlib.use('module://ipympl.backend_nbagg')
    print('Using ipympl backend for interactivity.')
except Exception as e:
    print('ipympl not available, falling back to inline backend (limited interactivity).', e)

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
import cv2


In [None]:
# Import the environment that encapsulates SAM behavior
from custom_gym_implns.envs.sam_seg_env import SamSegEnv

# Settings below follow the __main__ block of sam_seg_env.py.

# Array layouts
img_shape = (375, 500, 3)          # HxWxC
embedding_shape = (256, 64, 64)    # CxHxW
mask_shape = (256, 256)            # HxW
render_frame_shape = (320, 426)    # HxW

# Episode / reward options
max_steps = 5
penalize_for_wrong_input = False
use_dice_score = True
img_patch_size = 32

# 'rgb_array' so that render() hands back RGB arrays the notebook can show
render_mode = 'rgb_array'

target_categories = ['person', 'cat', 'dog', 'car', 'bicycle', 'bus']

# Dataset location and sampling options
data_root = os.path.join(project_root, 'data', 'coco-dataset')
data_config = dict(
    type='coco',
    data_dir=data_root,
    data_type='val2017',
    seed=42,
    max_instances=5,
)

# RepViT-SAM checkpoint expected inside the repository
sam_ckpt_fp = os.path.join(project_root, 'RepViT', 'sam', 'weights', 'repvit_sam.pt')

# Collect the constructor arguments in one place, then build the environment.
env_kwargs = {
    'img_shape': img_shape,
    'embedding_shape': embedding_shape,
    'mask_shape': mask_shape,
    'render_frame_shape': render_frame_shape,
    'max_steps': max_steps,
    'target_categories': target_categories,
    'data_config': data_config,
    'penalize_for_wrong_input': penalize_for_wrong_input,
    'use_dice_score': use_dice_score,
    'sam_ckpt_fp': sam_ckpt_fp,
    'img_patch_size': img_patch_size,
    'render_mode': render_mode,
}
env = SamSegEnv(**env_kwargs)

# Start the first episode and report what the environment exposes.
obs, info = env.reset()
print('Action space size:', env.action_space.n)
print('Current target category:', obs['target_category'])


In [None]:
# Interactive loop using matplotlib clicks
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- Figure that will show the rendered environment frame ---
fig, ax = plt.subplots(figsize=(12, 4))
plt.tight_layout()

# --- Widgets ---
# Label applied to left clicks ('pos' or 'neg').
mode = widgets.ToggleButtons(
    options=[('Positive', 'pos'), ('Negative', 'neg')],
    value='pos',
    description='Click mode:',
)
status = widgets.HTML(value='Ready. Left click to add points. Use toggle for label.')
# Sink for everything the callbacks print.
out = widgets.Output()

# --- State shared with the callbacks (mirrors the OpenCV version) ---
# Initial default action: second-to-last index of the action space
# (its meaning is defined by SamSegEnv — TODO confirm).
sample_action = env.action_space.n - 2

# Draw the first frame (an RGB array) and hide the axes decorations.
render_img = env.render()
display_img = ax.imshow(render_img)
ax.set_axis_off()

# --- Geometry used to map a click back to image coordinates ---
h_fig, w_fig = render_img.shape[0], render_img.shape[1]  # size of the full rendered frame
img_h, img_w = img_shape[0], img_shape[1]                # size of the source image
rf_h, rf_w = render_frame_shape                          # single panel size (H, W)

def canvas_to_image_coords(event):
    """Convert a matplotlib mouse event into (x, y) coordinates of the source image.

    The rendered frame may hold several panels of size ``rf_h`` x ``rf_w``;
    the click is folded into a single panel by modulo (same as the OpenCV
    demo) and then rescaled from panel size to ``img_h`` x ``img_w``.

    Args:
        event: Matplotlib mouse event; only ``xdata`` and ``ydata`` are read.

    Returns:
        ``(scaled_x, scaled_y)`` as Python ints, or ``None`` when the event
        carries no data coordinates (click outside the axes).
    """
    if event.xdata is None or event.ydata is None:
        return None
    # imshow centres pixel i on data coordinate i (it spans [i - 0.5, i + 0.5)),
    # so the pixel under the cursor is the nearest integer, not the truncation.
    x_full = int(np.floor(event.xdata + 0.5))
    y_full = int(np.floor(event.ydata + 0.5))
    # Keep the pixel index inside the rendered frame.
    x_full = min(max(x_full, 0), w_fig - 1)
    y_full = min(max(y_full, 0), h_fig - 1)
    # Map to single panel by modulo, same as OpenCV demo
    x_panel = x_full % rf_w
    y_panel = y_full % rf_h
    # Rescale from panel resolution to source-image resolution.
    scaled_x = int(x_panel * img_w / rf_w)
    scaled_y = int(y_panel * img_h / rf_h)
    return scaled_x, scaled_y

@out.capture(clear_output=False)
def on_click(event):
    """Mouse-release handler: turn a click into an environment action and step.

    A left click uses the label currently selected in the ``mode`` toggle,
    a right click always uses 'neg'; any other button (or no button) is
    ignored, as are events without data coordinates.
    """
    global sample_action

    # Matplotlib button codes: 1=left, 2=middle, 3=right
    pressed = event.button
    if pressed == 1:
        tgt_label = mode.value
    elif pressed == 3:
        tgt_label = 'neg'
    else:
        return

    coords = canvas_to_image_coords(event)
    if coords is not None:
        sample_action = env.convert_raw_input_to_action(coords, tgt_label)
        step_and_render()

# Register the handler; the connection id is kept in `cid`.
cid = fig.canvas.mpl_connect('button_release_event', on_click)

def step_and_render():
    """Apply ``sample_action`` to the environment and refresh the displayed frame.

    Prints the reward and the last input labels/points into ``out``. When the
    step reports termination or truncation the environment is reset before
    the new frame is rendered. Afterwards ``sample_action`` is set back to
    the last action index (the "done" default).
    """
    global sample_action

    _, reward, terminated, truncated, step_info = env.step(sample_action)
    episode_over = terminated or truncated

    with out:
        print('reward:', reward)
        print(step_info['last_input_labels'], step_info['last_input_points'])
        if episode_over:
            print('Task done. Resetting...')
            env.reset()

    # Push the fresh frame into the existing image artist and request a redraw.
    display_img.set_data(env.render())
    fig.canvas.draw_idle()

    # default to done
    sample_action = env.action_space.n - 1

# Show the click-mode toggle next to the status hint (the `status` widget was
# previously created but never displayed), followed by the log output and figure.
controls = widgets.HBox([mode, status])
display(controls, out)
plt.show()


In [None]:
# Manual reset control: a button that restarts the environment on demand.

@out.capture(clear_output=False)
def on_reset_clicked(_):
    """Button callback: reset the environment and redraw the current frame.

    The button instance passed by ipywidgets is ignored. The confirmation
    message is captured into ``out``.
    """
    env.reset()
    display_img.set_data(env.render())
    fig.canvas.draw_idle()
    print('Environment reset.')

reset_btn = widgets.Button(description='Reset env', button_style='warning')
reset_btn.on_click(on_reset_clicked)
display(reset_btn)
