### Napari tests

In [1]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules before
# each cell runs, so edits to the project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os , sys
# Make the parent directory importable from this notebook
sys.path.append('..')
from pathlib import Path
# Parent of the current working directory; used below as the base for model and data paths
cur_path = Path(os.getcwd()).parent

from matplotlib import pyplot as plt
import cmasher as cmr
import numpy as np
import seaborn as sns
sns.set_theme(style='white')
# Render inline figures at retina (high-DPI) resolution
%config InlineBackend.figure_format = 'retina'

In [3]:
import napari
from napari.qt.threading import thread_worker
import warnings
# Show every FutureWarning each time it is triggered instead of only once per location
warnings.simplefilter(action='always', category=FutureWarning)
import zarr
import shutil

In [4]:
# Open a napari viewer window; later cells read the video from viewer.layers[0].
# NOTE(review): no cell adds a video layer, so it appears to be loaded into the
# viewer by hand before the cells below are run.
viewer = napari.Viewer()

2025-01-30 10:31:58.840 python[53729:15536330] +[IMKClient subclass]: chose IMKClient_Modern
2025-01-30 10:31:58.840 python[53729:15536330] +[IMKInputSession subclass]: chose IMKInputSession_Modern


#### Initialize the SAM2 model the same way it is done in napari

In [5]:
# Ask PyTorch to fall back to the CPU for operators that are not implemented on Apple MPS.
# Set before the project import below, presumably so it is in place before torch is first imported.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from octron.sam2_octron.helpers.build_sam2_octron import build_sam2_octron

In [6]:
# Locations of the SAM2 checkpoint and model config inside the project folder
sam2_folder = Path('sam2_octron')
checkpoint = 'sam2.1_hiera_base_plus.pt' # under folder /checkpoints
model_cfg = 'sam2.1/sam2.1_hiera_b+.yaml' # under folder /configs
# ----------------------------------------------------------------------------
# The checkpoint path is anchored at cur_path; the config path stays relative ('configs/...').
sam2_checkpoint = cur_path / sam2_folder / Path(f'checkpoints/{checkpoint}')
# Note: model_cfg is rebound here from the bare config name (str) to a relative Path
model_cfg = Path(f'configs/{model_cfg}')


# Build the predictor; the helper also returns the device it selected
predictor, device  = build_sam2_octron(config_file=model_cfg.as_posix(), 
                                       ckpt_path=sam2_checkpoint.as_posix(), 
                                       )


Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might give numerically different outputs and sometimes degraded performance on MPS. See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.
Uing device: mps


Loaded SAM2VideoPredictor OCTRON
Model image size: 1024


### Extract video information from napari after loading a video file

In [7]:
# Move the viewer to position 0 along axis 0, then read back the current step of every dimension
viewer.dims.set_point(0,0)
current_indices = viewer.dims.current_step
print(current_indices)

(0, 499, 499)


In [8]:
# Data of the first layer; the unpacking below expects four axes: (frames, height, width, channels)
video_data = viewer.layers[0].data   # the whole video

print(f'video shape: {video_data.shape}')   
num_frames, height, width, channels = video_data.shape


video shape: (4067, 1000, 1000, 3)


In [9]:
# Create zarr store to save all resized images 
# Number of frames per zarr chunk; reused further down as the prefetch / prediction batch size
chunk_size = 20



# Create temp output dir 
sample_dir = cur_path / 'sample_data'
sample_dir.mkdir(exist_ok=True)
sample_data_zarr = sample_dir / 'sample_data.zarr'
# Start from a clean store: remove a zarr directory left over from a previous run
if sample_data_zarr.exists():
    shutil.rmtree(sample_data_zarr)

# Assuming local store on fast SSD, so no compression employed for now 
# NOTE(review): no `compressors` argument is passed to create_array below, so zarr's
# default applies — confirm that this really results in an uncompressed array.
store = zarr.storage.LocalStore(sample_data_zarr, read_only=False)
root = zarr.create_group(store=store)
# One float32 (3, image_size, image_size) entry per video frame, NaN until written;
# each chunk spans `chunk_size` whole frames.
# NOTE(review): the array is named 'masks' although the comment at the top of this
# cell says it stores resized images — confirm the intended name.
image_zarr = root.create_array(name='masks',
                                shape=(num_frames, 3, predictor.image_size, predictor.image_size),  
                                chunks=(chunk_size, 3, predictor.image_size, predictor.image_size), 
                                fill_value=np.nan,
                                dtype='float32')


In [10]:
# Hand the viewer, the index of the video layer and the zarr array to the predictor,
# then call reset_state() (presumably clears any previously stored prompts — TODO confirm)
predictor.init_state(napari_viewer=viewer, video_layer_idx=0, zarr_store=image_zarr)
predictor.reset_state()

  return torch._C._nn.upsample_bicubic2d(


Initialized SAM2 model


In [11]:
def run_new_pred(frame_idx,
                 obj_id, 
                 label,
                 point):
    '''
    Add a single point prompt for one object and return that object's labelled mask.

    Parameters
    ----------
    frame_idx : int
        Index of the video frame the prompt belongs to.
    obj_id : int
        Unique integer id of the object being annotated (any integer works).
    label : int
        1 for a positive (include) click, 0 for a negative (exclude) click.
    point : sequence of two numbers
        Prompt coordinate; forwarded to the predictor as a (1, 2) float32 array.

    Returns
    -------
    numpy.ndarray
        uint8 mask with singleton axes removed: 0 for background and
        ``obj_id + 1`` wherever the predicted logits are positive.

    Raises
    ------
    AssertionError
        If `label` is neither 0 nor 1.
    ValueError
        If the predictor does not report `obj_id` among its object ids.
    '''
    assert label in [0,1], 'label must be 0 (negative click) or 1 (positive click)'

    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                                                frame_idx=frame_idx,
                                                obj_id=obj_id,
                                                points=np.array([point],dtype=np.float32),
                                                labels=np.array([label], np.int32)
                                                )

    # The logits are ordered like out_obj_ids, so pick the entry of the requested
    # object rather than blindly taking the first one.
    obj_index = list(out_obj_ids).index(obj_id)
    foreground = (out_mask_logits[obj_index] > 0).cpu().numpy()

    # Label every foreground pixel with obj_id + 1 (0 stays background). Indexing with
    # the boolean mask also covers the case of a mask that is foreground everywhere.
    mask = foreground.astype(np.uint8)
    mask[foreground] = obj_id + 1
    return mask.squeeze()

In [12]:
# Worker that reads a batch of frames ahead of the frame currently shown in the viewer
@thread_worker
def thread_prefetch_images(batch_size):
    '''
    Access `batch_size` consecutive entries of `predictor.images`, starting at the
    viewer's current position on axis 0. The result is discarded; the access itself
    is the point (presumably it makes the predictor load/cache those frames).
    '''
    print('running prefetcher')
    first_frame = viewer.dims.current_step[0]
    last_frame = first_frame + batch_size
    predictor.images[first_frame:last_frame]
# Keep the worker object alive after it finishes so that it can be triggered again later
prefetcher_worker = thread_prefetch_images(chunk_size)
prefetcher_worker.setAutoDelete(False)
prefetcher_worker.start()


running prefetcher


In [13]:
# Create the (initially empty) mask layer and the point annotation layer

# One blank uint8 plane per video frame
mask_layer_dummy = np.zeros((num_frames, height, width), dtype=np.uint8)

# Eight colours from the 'cmr.tropical' map, rescaled to 0-1 and extended by an alpha column of ones
colors = cmr.take_cmap_colors('cmr.tropical', 8, cmap_range=(0, 1),
                              return_fmt='int')
colors_norm = np.stack(colors) / 255.0
cyclic_map = napari.utils.CyclicLabelColormap(
    np.hstack([colors_norm, np.ones((len(colors_norm), 1))])
)

# Labels layer that will display the predicted masks
labels_layer = viewer.add_labels(
    mask_layer_dummy,
    name='Mask',
    colormap=cyclic_map,
    opacity=0.4,
    blending='additive',
)

# Points layer that receives the user's click annotations
points_layer = viewer.add_points(
    None,
    name='Annotations',
    ndim=3,
    scale=(1,1),
    size=40,
    opacity=.6,
    border_width=.2,
    border_color='dimgrey',
)

# Number of points present right now; the data callback compares against it to detect additions
previous_length_points = len(points_layer.data)


left_right_click = 'left'
def on_mouse_press(layer, event):
    '''
    Mouse callback that remembers which button produced the latest press.

    Stores 'left' (button 1) or 'right' (button 2) in the module-level
    `left_right_click`; any other event type or button leaves it unchanged.
    '''
    global left_right_click
    if event.type != 'mouse_press':
        return
    pressed = {1: 'left', 2: 'right'}.get(event.button)
    if pressed is not None:
        left_right_click = pressed
    

def on_points_added(event):
    '''
    Callback for data changes of the points layer.

    When the number of points has grown since the last call, the newest point is
    styled according to the last mouse button (left = positive prompt drawn as 'o',
    right = negative prompt drawn as 'x'), a prediction is run for it and the
    resulting mask is written into the labels layer at the point's frame.

    Reads the module-level `points_layer`, `labels_layer`, `left_right_click` and
    `prefetcher_worker`; updates `previous_length_points`.
    '''
    global previous_length_points

    current_length = len(points_layer.data)
    point_was_added = current_length > previous_length_points
    # Always remember the current count. If it were only updated on growth, removing
    # a point would leave a stale, too-high count and the next added point would be
    # ignored.
    previous_length_points = current_length
    if not point_was_added:
        return

    newest_point_data = points_layer.data[-1]
    if left_right_click == 'left':
        label = 1
        points_layer.face_color[-1] = [0.59607846, 0.98431373, 0.59607846, 1.]
        points_layer.symbol[-1] = 'o'
    elif left_right_click == 'right':
        label = 0
        points_layer.face_color[-1] = [1., 1., 1., 1.]
        points_layer.symbol[-1] = 'x'
    points_layer.refresh()

    # The first coordinate is the frame index; the remaining two are handed to the
    # predictor in reversed order.
    frame_idx  = int(newest_point_data[0])
    point_data = newest_point_data[1:][::-1]
    mask = run_new_pred(frame_idx=frame_idx,
                        obj_id=0,
                        label=label,
                        point=point_data,
                        )
    labels_layer.data[frame_idx,:,:] = mask
    labels_layer.refresh()

    # Prefetch batch of images
    # This is done here since adding it as direct mouse interaction
    # slows down the first prediction
    # NOTE(review): run() is called directly rather than start(); this looks like it
    # executes the prefetch in the calling thread — confirm that this is intended.
    if not prefetcher_worker.is_running:
        prefetcher_worker.run()


# Register the callbacks: mouse presses record the button, data changes trigger a prediction
points_layer.events.data.connect(on_points_added)
points_layer.mouse_drag_callbacks.append(on_mouse_press)

# Hide the transform, delete, and select buttons of the points layer controls
# (napari warns that access through qt_viewer is deprecated)
qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]
for button_name in ('transform_button', 'delete_button', 'select_button'):
    getattr(qctrl, button_name).setVisible(False)

# Make the points layer the active layer and switch it to the 'add' tool
viewer.layers.selection.active = points_layer
viewer.layers.selection.active.mode = 'add'

v0.6.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]


In [14]:
# # Add the shapes layer to the viewer
# shapes_layer = viewer.add_shapes(None, 
#                                  ndim=3,
#                                  name='Shape annotations', 
#                                  scale=(1,1),
#                                  edge_width=0,
#                                  face_color=np.array(cyclic_map.colors),
#                                  opacity=.6,
#                                  )
# # Store the initial length of the points data
# previous_length_points = len(shapes_layer.data)

In [15]:
import time

In [16]:
obj_id = 0


@thread_worker
def thread_predict(frame_idx, max_imgs):
    '''
    Propagate the current prompts through the video and write the masks to the labels layer.

    Parameters
    ----------
    frame_idx : int
        Frame index at which propagation starts.
    max_imgs : int
        Passed to the predictor as `max_frame_num_to_track`.

    For every frame the predictor yields, the mask of the module-level `obj_id` is
    converted to uint8, labelled with ``obj_id + 1`` and stored in `labels_layer`;
    the viewer is moved to that frame. The elapsed time is printed at the end.

    Raises
    ------
    KeyError
        If the predictor does not report `obj_id` for a frame.
    '''
    tracked_obj_id = obj_id
    current_label = tracked_obj_id + 1

    start_time = time.time()
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(start_frame_idx=frame_idx, 
                                                                                    max_frame_num_to_track=max_imgs):

        # One boolean mask per reported object for this frame. Only the current frame
        # is kept, so memory use does not grow with the number of processed frames.
        frame_masks = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

        # THIS NEEDS TO BE MADE LAYER SPECIFIC
        foreground = frame_masks[tracked_obj_id].squeeze()
        # Convert to uint8 before labelling: writing the label into the boolean
        # array would collapse every label to True (i.e. 1).
        mask = foreground.astype(np.uint8)
        mask[foreground] = current_label

        # NOTE(review): the layer and viewer are modified from inside the worker
        # function — confirm that this is safe with napari's threading model.
        labels_layer.data[out_frame_idx,:,:] = mask
        viewer.dims.set_point(0,out_frame_idx)
        labels_layer.refresh()
    end_time = time.time()
    print(f'start idx {frame_idx} | {max_imgs} frames in {end_time-start_time} s')
        
        

running prefetcher
running prefetcher
running prefetcher
running prefetcher
running prefetcher
running prefetcher


In [29]:
# Propagate masks over `chunk_size` frames in a worker, starting at the frame currently shown in the viewer
worker = thread_predict(frame_idx=viewer.dims.current_step[0], max_imgs=chunk_size)  # create "worker" object
#worker.returned.connect(viewer.add_image)  # connect callback functions
worker.start()  # start the thread!

Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.41it/s]


start idx 2042 | 20 frames in 8.734952926635742 s


In [17]:
#worker.run()

In [18]:
#worker.quit()  # stop the thread

In [21]:
#predictor.images.cur_cache_idx, predictor.images.cached_indices

In [18]:
import time 

In [26]:
# Test 
# Start one propagation of `chunk_size` frames at every 80th frame index below 4000.
# NOTE(review): 4000 and the 12 s pause look tuned to this particular video (4067 frames)
# and to the ~8 s per chunk seen in the recorded output — confirm before reusing.
for i in range(0,4000,80):
    
    prediction_worker = thread_predict(frame_idx=i, max_imgs=chunk_size)  
    prediction_worker.setAutoDelete(True)
    #worker.returned.connect(viewer.add_image)  # connect callback functions
    prediction_worker.start()  
    # Printed right after start(), without waiting for this worker to finish
    print(f'Highest cached index {int(np.nanmax(predictor.images.cached_indices))}')
    # Fixed pause before the next worker is launched
    time.sleep(12)



Highest cached index 100


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.98it/s]


start idx 0 | 20 frames in 7.041176080703735 s
Highest cached index 100


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.81it/s]


start idx 80 | 20 frames in 7.489584922790527 s
Highest cached index 100


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.53it/s]


start idx 160 | 20 frames in 8.30475401878357 s
Highest cached index 180


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.60it/s]


start idx 240 | 20 frames in 8.08763599395752 s
Highest cached index 260


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.63it/s]


start idx 320 | 20 frames in 7.994042873382568 s
Highest cached index 340


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.53it/s]


start idx 400 | 20 frames in 8.293858766555786 s
Highest cached index 420


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.46it/s]


start idx 480 | 20 frames in 8.553947925567627 s
Highest cached index 500


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.65it/s]


start idx 560 | 20 frames in 7.9137163162231445 s
Highest cached index 580


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.56it/s]


start idx 640 | 20 frames in 8.200807809829712 s
Highest cached index 660


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.44it/s]


start idx 720 | 20 frames in 8.609652996063232 s
Highest cached index 740


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.50it/s]


start idx 800 | 20 frames in 8.401429176330566 s
Highest cached index 820


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.46it/s]


start idx 880 | 20 frames in 8.550720930099487 s
Highest cached index 900


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.38it/s]


start idx 960 | 20 frames in 8.81672215461731 s
Highest cached index 980


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.57it/s]


start idx 1040 | 20 frames in 8.183781147003174 s
Highest cached index 1060


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.50it/s]


start idx 1120 | 20 frames in 8.39837121963501 s
Highest cached index 1140


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.42it/s]


start idx 1200 | 20 frames in 8.679450750350952 s
Highest cached index 1220


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.61it/s]


start idx 1280 | 20 frames in 8.048112154006958 s
Highest cached index 1300


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.51it/s]


start idx 1360 | 20 frames in 8.378323078155518 s
Highest cached index 1380


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.45it/s]


start idx 1440 | 20 frames in 8.580260992050171 s
Highest cached index 1460


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.67it/s]


start idx 1520 | 20 frames in 7.86220908164978 s
Highest cached index 1540


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.59it/s]


start idx 1600 | 20 frames in 8.10556697845459 s
Highest cached index 1620


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.39it/s]


start idx 1680 | 20 frames in 8.79525089263916 s
Highest cached index 1700


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.43it/s]


start idx 1760 | 20 frames in 8.638482809066772 s
Highest cached index 1780


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.38it/s]


start idx 1840 | 20 frames in 8.824990034103394 s
Highest cached index 1860


Predicting: 100%|██████████| 21/21 [00:09<00:00,  2.27it/s]


start idx 1920 | 20 frames in 9.234575748443604 s
Highest cached index 1940


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.45it/s]


start idx 2000 | 20 frames in 8.58573317527771 s
Highest cached index 2020


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.51it/s]


start idx 2080 | 20 frames in 8.366966009140015 s
Highest cached index 2100


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.47it/s]


start idx 2160 | 20 frames in 8.492623805999756 s
Highest cached index 2180


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.56it/s]


start idx 2240 | 20 frames in 8.19845700263977 s
Highest cached index 2260


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.64it/s]


start idx 2320 | 20 frames in 7.954664945602417 s
Highest cached index 2340


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.55it/s]


start idx 2400 | 20 frames in 8.227787971496582 s
Highest cached index 2420


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.41it/s]


start idx 2480 | 20 frames in 8.701306104660034 s
Highest cached index 2500


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.69it/s]


start idx 2560 | 20 frames in 7.8061230182647705 s
Highest cached index 2580


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.37it/s]


start idx 2640 | 20 frames in 8.866070985794067 s
Highest cached index 2660


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.44it/s]


start idx 2720 | 20 frames in 8.596689224243164 s
Highest cached index 2740


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.65it/s]


start idx 2800 | 20 frames in 7.922217845916748 s
Highest cached index 2820


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.56it/s]


start idx 2880 | 20 frames in 8.209043025970459 s
Highest cached index 2900


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.45it/s]


start idx 2960 | 20 frames in 8.572962045669556 s
Highest cached index 2980


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.67it/s]


start idx 3040 | 20 frames in 7.864348888397217 s
Highest cached index 3060


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.55it/s]


start idx 3120 | 20 frames in 8.22220492362976 s
Highest cached index 3140


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.49it/s]


start idx 3200 | 20 frames in 8.442840814590454 s
Highest cached index 3220


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.68it/s]


start idx 3280 | 20 frames in 7.827042818069458 s
Highest cached index 3300


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.58it/s]


start idx 3360 | 20 frames in 8.139503240585327 s
Highest cached index 3380


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.45it/s]


start idx 3440 | 20 frames in 8.568807125091553 s
Highest cached index 3460


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.71it/s]


start idx 3520 | 20 frames in 7.737731218338013 s
Highest cached index 3540


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.59it/s]


start idx 3600 | 20 frames in 8.100906133651733 s
Highest cached index 3620


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.52it/s]


start idx 3680 | 20 frames in 8.346771001815796 s
Highest cached index 3700


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.72it/s]


start idx 3760 | 20 frames in 7.731273412704468 s
Highest cached index 3780


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.62it/s]


start idx 3840 | 20 frames in 8.027160167694092 s
Highest cached index 3860


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.52it/s]


start idx 3920 | 20 frames in 8.345791101455688 s


running prefetcher
running prefetcher


In [None]:
# output_dict_per_obj is huge 
# Structure
# -> obj_id
# --> cond_frame_outputs
# --> non_cond_frame_outputs