### Napari tests

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os , sys
sys.path.append('..')
from pathlib import Path
cur_path = Path(os.getcwd()).parent

from matplotlib import pyplot as plt
import cmasher as cmr
import numpy as np
import seaborn as sns
sns.set_theme(style='white')
%config InlineBackend.figure_format = 'retina'

In [3]:
import napari
from napari.qt.threading import thread_worker
import warnings
warnings.simplefilter(action='always', category=FutureWarning)
import zarr
import shutil

In [4]:
viewer = napari.Viewer()

2025-01-29 19:51:57.868 python[41002:15170145] +[IMKClient subclass]: chose IMKClient_Modern
2025-01-29 19:51:57.868 python[41002:15170145] +[IMKInputSession subclass]: chose IMKInputSession_Modern


#### Initiate the SAM2 model like you do in Napari

In [5]:
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
from octron.sam2_octron.helpers.build_sam2_octron import build_sam2_octron

In [6]:
sam2_folder = Path('sam2_octron')
checkpoint = 'sam2.1_hiera_base_plus.pt' # under folder /checkpoints
model_cfg = 'sam2.1/sam2.1_hiera_b+.yaml' # under folder /configs
# ----------------------------------------------------------------------------
sam2_checkpoint = cur_path / sam2_folder / Path(f'checkpoints/{checkpoint}')
model_cfg = Path(f'configs/{model_cfg}')


predictor, device  = build_sam2_octron(config_file=model_cfg.as_posix(), 
                                       ckpt_path=sam2_checkpoint.as_posix(), 
                                       )


Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might give numerically different outputs and sometimes degraded performance on MPS. See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.
Uing device: mps


Loaded SAM2VideoPredictor OCTRON
Model image size: 1024


### From napari, after loading video file, extract info 

In [7]:
viewer.dims.set_point(0,0)
current_indices = viewer.dims.current_step
print(current_indices)

(0, 499, 499)


In [8]:
video_data = viewer.layers[0].data   # the whole video

print(f'video shape: {video_data.shape}')   
num_frames, height, width, channels = video_data.shape


video shape: (4067, 1000, 1000, 3)


In [9]:
# Create zarr store to save all resized images 
chunk_size = 25





# Create temp output dir 
sample_dir = cur_path / 'sample_data'
sample_dir.mkdir(exist_ok=True)
sample_data_zarr = sample_dir / 'sample_data.zarr'
if sample_data_zarr.exists():
    shutil.rmtree(sample_data_zarr)

# Assuming local store on fast SSD, so no compression employed for now 
store = zarr.storage.LocalStore(sample_data_zarr, read_only=False)
root = zarr.create_group(store=store)
image_zarr = root.create_array(name='masks',
                                shape=(num_frames, 3, predictor.image_size, predictor.image_size),  
                                chunks=(chunk_size, 3, predictor.image_size, predictor.image_size), 
                                fill_value=np.nan,
                                dtype='float32')


In [10]:
predictor.init_state(napari_viewer=viewer, video_layer_idx=0, zarr_store=image_zarr)
predictor.reset_state()

  return torch._C._nn.upsample_bicubic2d(


Initialized SAM2 model


In [11]:
def run_new_pred(frame_idx,
                 obj_id, 
                 label,
                 point):
    
    assert label in [0,1]
    # Run prediction
    #obj_id : give a unique id to each object we interact with (it can be any integers)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                                                frame_idx=frame_idx,
                                                obj_id=obj_id,
                                                points=np.array([point],dtype=np.float32),
                                                labels=np.array([label], np.int32)
                                                )
    
    # Add the mask image as a new labels layer
    mask = (out_mask_logits[0] > 0).cpu().numpy().astype(np.uint8)
    current_label = obj_id+1
    if len(np.unique(mask))>1:
        mask[mask==np.unique(mask)[1]] = current_label 
    mask = mask.squeeze()
    return mask 

In [12]:
# Set up thread worker to deal with prefetching batches of images

@thread_worker
def thread_prefetch_images(batch_size):
    print('running prefetcher')
    global viewer
    current_indices = viewer.dims.current_step
    _ = predictor.images[slice(current_indices[0],current_indices[0]+batch_size)]
prefetcher_worker = thread_prefetch_images(chunk_size)   
prefetcher_worker.setAutoDelete(False)
prefetcher_worker.start()


running prefetcher


In [13]:
# Instantiate the mask and annotation layers 
# Keep them empty at start 
mask_layer_dummy = np.zeros((num_frames, height, width), dtype=np.uint8)
mask_layer_dummy.shape

colors = cmr.take_cmap_colors('cmr.tropical', 8, cmap_range=(0, 1),
                     return_fmt='int')
colors_norm = np.stack(colors) / 255.0  
cyclic_map = napari.utils.CyclicLabelColormap(
    np.hstack([np.stack(colors_norm), np.ones((len(colors), 1))])
)
labels_layer = viewer.add_labels(
    mask_layer_dummy, 
    name='Mask',  # Name of the layer
    opacity=0.4,  # Optional: opacity of the labels
    colormap=cyclic_map,
    blending='additive'  # Optional: blending mode
)
# Add the points layer to the viewer
points_layer = viewer.add_points(None, 
                                 ndim=3,
                                 name='Annotations', 
                                 scale=(1,1),
                                 size=40,
                                 border_color='dimgrey',
                                 border_width=.2,
                                 opacity=.6,
                                 )
# Store the initial length of the points data
previous_length_points = len(points_layer.data)


left_right_click = 'left'
def on_mouse_press(layer, event):
    '''
    Generic function to catch left and right mouse clicks
    '''
    global left_right_click
    if event.type == 'mouse_press':
        if event.button == 1:  # Left-click
            left_right_click = 'left'
        elif event.button == 2:  # Right-click
            left_right_click = 'right'     
    

def on_points_added(event):
    '''
    Function to run when points are added to the points layer
    '''
    
    global points_layer
    global labels_layer
    global left_right_click
    global prefetcher_worker
    global previous_length_points
    
    current_length = len(points_layer.data)
    if current_length > previous_length_points:
        previous_length_points = current_length 

        # # Execute prediction 
        newest_point_data =  points_layer.data[-1]
        if left_right_click == 'left':
            label = 1
            points_layer.face_color[-1] = [0.59607846, 0.98431373, 0.59607846, 1.]
            points_layer.symbol[-1] = 'o'
        elif left_right_click == 'right':
            label = 0
            points_layer.face_color[-1] = [1., 1., 1., 1.]
            points_layer.symbol[-1] = 'x'
        points_layer.refresh() 
        # Run prediction
        frame_idx  = int(newest_point_data[0])
        point_data = newest_point_data[1:][::-1]
        mask = run_new_pred(frame_idx=frame_idx,
                            obj_id=0,
                            label=label,
                            point=point_data,
                            )
        labels_layer.data[frame_idx,:,:] = mask
        labels_layer.refresh()   
        
        # Prefetch batch of images
        # This is done here since adding it as direct mouse interaction 
        # slows down the first prediction
        if not prefetcher_worker.is_running:
            prefetcher_worker.start()


points_layer.mouse_drag_callbacks.append(on_mouse_press)
points_layer.events.data.connect(on_points_added)

# Hide the transform, delete, and select buttons
qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]
getattr(qctrl, 'transform_button').setVisible(False)
getattr(qctrl, 'delete_button').setVisible(False)
getattr(qctrl, 'select_button').setVisible(False)
# Select the current, add tool for the points layer
viewer.layers.selection.active = points_layer
viewer.layers.selection.active.mode = 'add'

v0.6.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]


In [None]:
# Add the shapes layer to the viewer
shapes_layer = viewer.add_shapes(None, 
                                 ndim=3,
                                 name='Shape annotations', 
                                 scale=(1,1),
                                 edge_width=0,
                                 face_color=np.array(cyclic_map.colors),
                                 opacity=.6,
                                 )
# Store the initial length of the points data
previous_length_points = len(shapes_layer.data)

In [18]:
obj_id = 0


@thread_worker
def thread_predict(frame_idx, max_imgs):
    global labels_layer
    video_segments = {} 
    

    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(start_frame_idx=frame_idx, 
                                                                                    max_frame_num_to_track=max_imgs):
        
        for i, out_obj_id in enumerate(out_obj_ids):
            video_segments[out_frame_idx] = {out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()}
            
        #  Add the mask image as a new labels layer
        mask = video_segments[out_frame_idx][obj_id] # THIS NEEDS TO BE MADE LAYER SPECIFIC 
        current_label = obj_id+1
        if len(np.unique(mask))>1:
            mask[mask==np.unique(mask)[1]] = current_label 

        mask = mask.squeeze()
        labels_layer.data[out_frame_idx,:,:]= mask
        viewer.dims.set_point(0,out_frame_idx)
        labels_layer.refresh()
        
        
        
        

In [26]:
worker = thread_predict(frame_idx=viewer.dims.current_step[0], max_imgs=chunk_size)  # create "worker" object
#worker.returned.connect(viewer.add_image)  # connect callback functions
worker.start()  # start the thread!

Indices not in cache []


Predicting:   0%|          | 0/26 [00:00<?, ?it/s]

Indices not in cache []


Predicting:   8%|▊         | 2/26 [00:01<00:15,  1.53it/s]

Indices not in cache []


Predicting:  12%|█▏        | 3/26 [00:02<00:17,  1.32it/s]

Indices not in cache []


Predicting:  15%|█▌        | 4/26 [00:03<00:19,  1.11it/s]

Indices not in cache []


Predicting:  19%|█▉        | 5/26 [00:04<00:19,  1.09it/s]

Indices not in cache []


Predicting:  23%|██▎       | 6/26 [00:06<00:24,  1.20s/it]

Indices not in cache []


Predicting:  27%|██▋       | 7/26 [00:07<00:21,  1.14s/it]

Indices not in cache []


Predicting:  31%|███       | 8/26 [00:08<00:23,  1.30s/it]

Indices not in cache []


Predicting:  35%|███▍      | 9/26 [00:09<00:19,  1.13s/it]

Indices not in cache []


Predicting:  38%|███▊      | 10/26 [00:10<00:15,  1.05it/s]

Indices not in cache []


Predicting:  42%|████▏     | 11/26 [00:10<00:12,  1.22it/s]

Indices not in cache []


Predicting:  46%|████▌     | 12/26 [00:11<00:10,  1.34it/s]

Indices not in cache []


Predicting:  50%|█████     | 13/26 [00:11<00:09,  1.40it/s]

Indices not in cache []


Predicting:  54%|█████▍    | 14/26 [00:12<00:08,  1.48it/s]

Indices not in cache []


Predicting:  58%|█████▊    | 15/26 [00:12<00:06,  1.60it/s]

Indices not in cache []


Predicting:  62%|██████▏   | 16/26 [00:13<00:05,  1.68it/s]

Indices not in cache []


Predicting:  65%|██████▌   | 17/26 [00:14<00:05,  1.57it/s]

Indices not in cache []


Predicting:  69%|██████▉   | 18/26 [00:14<00:04,  1.67it/s]

Indices not in cache []


Predicting:  73%|███████▎  | 19/26 [00:15<00:03,  1.76it/s]

Indices not in cache []


Predicting:  77%|███████▋  | 20/26 [00:15<00:03,  1.84it/s]

Indices not in cache []


Predicting:  81%|████████  | 21/26 [00:16<00:02,  1.90it/s]

Indices not in cache []


Predicting:  85%|████████▍ | 22/26 [00:16<00:02,  1.94it/s]

Indices not in cache []


Predicting:  88%|████████▊ | 23/26 [00:17<00:01,  1.93it/s]

Indices not in cache []


Predicting:  92%|█████████▏| 24/26 [00:17<00:01,  1.97it/s]

Indices not in cache []


Predicting:  96%|█████████▌| 25/26 [00:18<00:00,  1.99it/s]

Indices not in cache [3457]


Predicting: 100%|██████████| 26/26 [00:18<00:00,  1.40it/s]


In [31]:
worker.quit()  # stop the thread

In [24]:
predictor.images.cached_indices

array([2637., 2638., 2639., 2640., 2641., 2642., 2734., 2735., 2736.,
       2737., 2738., 2739., 2740., 2741., 2742., 2743., 2744., 2745.,
       2746., 2747., 2748., 2749., 2750., 2751., 2752., 2753., 2754.,
       2755., 2756., 2757., 2758., 2759., 3016., 3017., 3018., 3019.,
       3020., 3021., 3022., 3023., 3024., 3025., 3026., 3027., 3028.,
       3029., 3030., 3031., 3032., 3033., 3034., 3035., 3036., 3037.,
       3038., 3039., 3040., 3041., 1473., 1474., 1475., 1476., 1477.,
       1478., 1479., 1480., 1481., 1482., 1483., 1484., 1485., 1486.,
       1487., 1488., 1489., 1490., 1491., 1492., 1913., 1914., 1915.,
       1916., 1917., 1918., 1919., 1920., 1921., 1922., 1923., 1924.,
       1925., 1926., 1927., 1928., 1929., 1930., 1931., 1932., 1933.,
       1934., 1935., 1936., 1937., 1938., 2165., 2166., 2167., 2168.,
       2169., 2170., 2171., 2172., 2173., 2174., 2175., 2176., 2177.,
       2178., 2179., 2180., 2181., 2182., 2183., 2184., 2185., 2186.,
       2187., 2188.,

In [32]:
prefetcher_worker.quit()  # stop the thread 

In [None]:
# output_dict_per_obj is huge 
# Structure
# -> obj_id
# --> cond_frame_outputs
# --> non_cond_frame_outputs