### Napari tests

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os , sys
sys.path.append('..')
from pathlib import Path
cur_path = Path(os.getcwd()).parent

from matplotlib import pyplot as plt
import cmasher as cmr
import numpy as np
import seaborn as sns
sns.set_theme(style='white')
%config InlineBackend.figure_format = 'retina'

In [3]:
import napari
from napari.qt.threading import thread_worker
from napari.utils import DirectLabelColormap
import warnings
warnings.simplefilter(action='always', category=FutureWarning)
import zarr
import time

#### Importing additional stuff 
from skimage import measure
from skimage.draw import polygon2mask

In [4]:
# Open the napari viewer; the cells below read the video from viewer.layers[0],
# so a video has to be loaded into the viewer before continuing.
viewer = napari.Viewer()

#### Initialize the SAM2 model the same way it is done in napari

In [6]:
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from octron.sam2_octron.helpers.build_sam2_octron import build_sam2_octron

In [7]:
# File names of the SAM2 weights and of the model config to use
sam2_folder = Path('sam2_octron')
checkpoint = 'sam2.1_hiera_base_plus.pt' # under folder /checkpoints
model_cfg = 'sam2.1/sam2.1_hiera_b+.yaml' # under folder /configs
# ----------------------------------------------------------------------------
# Checkpoint path is anchored at cur_path (the parent of the working directory)
sam2_checkpoint = cur_path / sam2_folder / Path(f'checkpoints/{checkpoint}')
# NOTE: model_cfg is rebound here from a str to a Path and, unlike the checkpoint,
# stays relative (presumably resolved by the config loader -- TODO confirm)
model_cfg = Path(f'configs/{model_cfg}')


# Build the predictor from config + checkpoint; both paths are handed over as POSIX strings
predictor, device  = build_sam2_octron(config_file=model_cfg.as_posix(), 
                                       ckpt_path=sam2_checkpoint.as_posix(), 
                                       )


Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might give numerically different outputs and sometimes degraded performance on MPS. See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.
Uing device: mps


Loaded SAM2VideoPredictor OCTRON
Model image size: 1024


### From napari, after loading video file, extract info 

In [8]:
# Move the viewer to position 0 along axis 0, then read back the current step of every dimension
viewer.dims.set_point(0,0)
current_indices = viewer.dims.current_step
print(current_indices)

(0, 499, 499)


In [9]:
# The first layer of the viewer is expected to hold the video
video_data = viewer.layers[0].data   # the whole video

print(f'video shape: {video_data.shape}')   
# The unpacking requires a 4D array (frames, height, width, channels);
# any other number of dimensions raises a ValueError here
num_frames, height, width, channels = video_data.shape


video shape: (4067, 1000, 1000, 3)


In [10]:
# Create a zarr store to save all resized images.
# chunk_size is the number of frames per zarr chunk; the same value is reused
# further down as the batch size for prefetching and prediction.
chunk_size = 20



# Create temp output dir 
sample_dir = cur_path / 'sample_data'
sample_dir.mkdir(exist_ok=True)
sample_data_zarr = sample_dir / 'sample_data.zip'
# Start from a clean archive: a zip left over from a previous run is deleted first
if sample_data_zarr.exists():
    os.remove(sample_data_zarr)

# Assuming local store on fast SSD, so no compression employed for now 
# NOTE(review): the store is opened in write mode and never closed in this notebook;
# a ZipStore presumably needs store.close() to finalise the archive -- TODO confirm
store = zarr.storage.ZipStore(sample_data_zarr, mode='w')
root = zarr.create_group(store=store)
# One float32 (3, image_size, image_size) entry per video frame, NaN until written.
# NOTE(review): the array is named 'masks' although it is meant to hold images -- confirm the name is intended
image_zarr = root.create_array(name='masks',
                                shape=(num_frames, 3, predictor.image_size, predictor.image_size),  
                                chunks=(chunk_size, 3, predictor.image_size, predictor.image_size), 
                                fill_value=np.nan,
                                dtype='float32')


In [11]:
# Hand the viewer, the index of the video layer and the zarr image array to the predictor,
# then reset its state
predictor.init_state(napari_viewer=viewer, video_layer_idx=0, zarr_store=image_zarr)
predictor.reset_state()

  return torch._C._nn.upsample_bicubic2d(


Initialized SAM2 model


In [12]:
# Loading a previous checkpoint does not work,
# since the model still expects at least one input

# import torch 
# state_path = '/Users/horst/Documents/python/OCTRON/octron/sample_data/model_output.pth'
# checkpoint = torch.load(state_path, weights_only=True)
# predictor.load_state_dict(checkpoint['model_state_dict'])

In [13]:
def run_new_pred(frame_idx,
                 obj_id, 
                 label,
                 point=None,
                 mask=None
                 ):
    '''
    Run a single-frame prediction from a mask prompt and/or a point prompt.

    At least one of `point` and `mask` must be given. If both are given, both
    prompts are sent to the predictor (mask first) and the result of the point
    prompt is returned.

    Parameters
    ----------
    frame_idx : int
        Frame the prompt belongs to.
    obj_id : int
        Object ID passed to the predictor. The returned mask uses
        `obj_id + 1` as its foreground value.
    label : int
        0 or 1. Only forwarded to the predictor for point prompts.
    point : sequence of length 2, optional
        Point prompt, forwarded unchanged (as float32) to
        `predictor.add_new_points_or_box`.
    mask : 2D array-like, optional
        Mask prompt, converted to bool for `predictor.add_new_mask`.

    Returns
    -------
    np.ndarray
        uint8 mask (squeezed): 0 = background, `obj_id + 1` = foreground.
    '''
    assert label in [0,1], f'label must be 0 or 1, got {label}'
    assert point is not None or mask is not None

    # The prediction is kept in its own variable so the `mask` argument is not shadowed
    predicted = None
    if mask is not None:
        assert len(mask.shape) == 2
        
        print('Running mask prediction')
        _, _, video_res_masks = predictor.add_new_mask(
                                                    frame_idx=frame_idx,
                                                    obj_id=obj_id,
                                                    mask=np.array(mask, dtype=bool)
                                                    )
        predicted = (video_res_masks[0] > 0).cpu().numpy().astype(np.uint8)
        
    if point is not None:
        assert len(point) == 2
        # Run point prediction
        print('Running point prediction')
        _, _, out_mask_logits = predictor.add_new_points_or_box(
                                                    frame_idx=frame_idx,
                                                    obj_id=obj_id,
                                                    points=np.array([point],dtype=np.float32),
                                                    labels=np.array([label], np.int32)
                                                    )
        predicted = (out_mask_logits[0] > 0).cpu().numpy().astype(np.uint8)
        
    # Relabel every foreground pixel. Thresholding gives only 0/1, so `> 0`
    # selects the foreground even when the mask has no background pixel at all
    # (the previous np.unique-based lookup skipped that case).
    current_label = obj_id+1
    predicted[predicted > 0] = current_label
    return predicted.squeeze()

In [14]:
# Set up thread worker to deal with prefetching batches of images
@thread_worker
def thread_prefetch_images(batch_size):
    '''
    Request `batch_size` frames from `predictor.images`, starting at the frame
    currently shown in the viewer (first entry of `viewer.dims.current_step`).

    The sliced result is discarded; presumably the access itself makes the
    predictor load/cache those frames -- TODO confirm.
    '''
    print('Running prefetcher')
    first_frame = viewer.dims.current_step[0]
    predictor.images[first_frame:first_frame + batch_size]
# Create one prefetch worker for a chunk of frames; the layer callbacks further down
# call prefetcher_worker.run() on this same object again
prefetcher_worker = thread_prefetch_images(chunk_size)   
# presumably keeps the worker object alive after its first run so it can be reused -- TODO confirm
prefetcher_worker.setAutoDelete(False)
prefetcher_worker.start()

Running prefetcher


In [15]:
# - Create a continuous colormap and a cmap_range for each label
# --> this gives one sub-colormap per label, whose colors can be cycled through for objects sharing that label name
def create_label_colors(cmap='cmr.tropical', label_n=4, obj_n=5):
    '''
    Build a nested color lookup: label index -> {object key -> RGBA color}.

    The range [0, 1] of the colormap `cmap` is cut into `label_n` equally wide,
    consecutive sub-ranges (one per label) and `obj_n` colors are taken from
    each sub-range.

    Parameters
    ----------
    cmap : str
        Colormap name handed to cmr.take_cmap_colors.
    label_n : int
        Number of labels, i.e. number of sub-ranges / outer dictionary entries.
    obj_n : int
        Number of colors taken per label.

    Returns
    -------
    dict
        {label index (starting at 0): {1 .. obj_n: RGBA array with alpha 1,
                                       None: all-zero float32 RGBA array}}
    '''
    boundaries = np.linspace(0, 1, label_n + 1)

    palette = {}
    for label_idx, (lower, upper) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        # Colors are requested as 0-255 integers and rescaled to 0-1 floats
        rgb = np.array(cmr.take_cmap_colors(cmap,
                                            obj_n,
                                            cmap_range=(lower, upper),
                                            return_fmt='int')) / 255.0
        opaque = np.ones((len(rgb), 1))
        rgba = np.hstack([rgb, opaque])
        per_object = dict(enumerate(rgba, start=1))  # Keys start at 1 !
        per_object[None] = np.zeros(4, dtype=np.float32)
        palette[label_idx] = per_object
    return palette

In [16]:
# label_id picks the sub-colormap (colors[label_id]) used for the layers created below
label_id = 1 # Only plays a role here ... not in predictor  
# obj_id is the object ID handed to the predictor; obj_id+1 is the matching color key / mask value
obj_id = 0

colors = create_label_colors(cmap='cmr.tropical')

In [17]:
# Instantiate the mask layer with an all-zero array (one uint8 plane per video frame)
mask_layer_dummy = np.zeros((num_frames, height, width), dtype=np.uint8)
# NOTE: bare expression in the middle of the cell -- it is neither displayed nor used
mask_layer_dummy.shape


# Select colormap for labels layer based on category (label) and current object ID 
current_color_labelmap = DirectLabelColormap(color_dict=colors[label_id], 
                                             use_selection=True, 
                                             selection=obj_id+1,
                                             )
labels_layer = viewer.add_labels(
    mask_layer_dummy, 
    name='Mask',  # Name of the layer
    opacity=0.4,  # Optional: opacity of the labels
    blending='additive',  # Optional: blending mode
    colormap=current_color_labelmap, 
)

# NOTE(review): napari emits a warning for `qt_viewer` access (see the output of this cell);
# it treats qt_viewer as an implementation detail of the application.
qctrl = viewer.window.qt_viewer.controls.widgets[labels_layer]
# Despite the name, the listed buttons are disabled (setEnabled(False)), not hidden
buttons_to_hide = ['erase_button',
                   'fill_button',
                   'paint_button',
                   'pick_button',
                   'polygon_button',
                   'transform_button',
                   ]
for btn in buttons_to_hide:
    getattr(qctrl, btn).setEnabled(False)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
v0.6.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  qctrl = viewer.window.qt_viewer.controls.widgets[labels_layer]


In [18]:
# Add an (initially empty) shapes layer to the viewer.
# Shapes are filled with the color of the current label / object and drawn without an edge.
shapes_layer = viewer.add_shapes(None, 
                                 ndim=3,
                                 name='Shape annotations', 
                                 scale=(1,1),
                                 edge_width=0,
                                 face_color=colors[label_id][obj_id+1],
                                 opacity=.5,
                                 )

# Store the initial length of the shapes data
# (on_shapes_added compares against it to detect newly added shapes)
previous_length_shapes = len(shapes_layer.data)

# Function to convert polygon data to mask
def polygons_to_mask(polygon, shape):
    '''
    Rasterize a polygon into a uint8 mask of the given shape.

    Parameters
    ----------
    polygon : array-like
        Polygon vertices, passed on unchanged to polygon2mask.
    shape : tuple of int
        Shape of the output mask.

    Returns
    -------
    np.ndarray
        The result of polygon2mask converted to uint8.
    '''
    # The shape is passed straight through; no throw-away zeros array is
    # allocated just to read its .shape back.
    return polygon2mask(tuple(shape), polygon).astype(np.uint8)

def on_shapes_added(event):
    '''
    Callback for data events of the shapes layer.

    When the number of shapes has grown since the last call, the newest shape
    is converted to a mask prompt, a prediction is run for its frame and the
    result is written into the labels layer. Afterwards the prefetch worker is
    run if it is not running already.

    Reads the module-level `shapes_layer`, `labels_layer`, `height`, `width`
    and `prefetcher_worker`; updates `previous_length_shapes`.
    '''
    global previous_length_shapes

    current_length = len(shapes_layer.data)
    grew = current_length > previous_length_shapes
    # Track the count on every event, not only when it grows: otherwise the
    # counter stays at its old maximum after a shape is removed and the next
    # shape that gets added is silently ignored.
    previous_length_shapes = current_length
    if not grew:
        return

    # Execute prediction 
    newest_shape_data = shapes_layer.data[-1]
    # Column 0 of the vertices holds the frame index, the rest the in-plane coordinates
    frame_idx = int(newest_shape_data[:,0][0])
    print(frame_idx)
    input_mask = polygons_to_mask(newest_shape_data[:,1:], (height, width))
    mask = run_new_pred(frame_idx=frame_idx,
                        obj_id=0,
                        label=1,
                        mask=input_mask,
                        )

    labels_layer.data[frame_idx] = mask
    labels_layer.refresh()

    # Prefetch batch of images
    # This is done here since adding it as direct mouse interaction 
    # slows down the first prediction
    # NOTE(review): .run() presumably executes the worker in the calling thread -- TODO confirm
    if not prefetcher_worker.is_running:
        prefetcher_worker.run()

# Disable the listed control buttons of the shapes layer
# (setEnabled(False): the buttons are greyed out, not hidden).
# NOTE(review): napari emits a warning for `qt_viewer` access (see the output of this cell)
qctrl = viewer.window.qt_viewer.controls.widgets[shapes_layer]
buttons_to_hide = ['transform_button', 
                   'delete_button', 
                   'select_button', 
                   'direct_button',
                   'ellipse_button',
                   'line_button',
                   'move_back_button',
                   'move_front_button',
                   'path_button',
                   'polygon_button',
                   'polyline_button',
                   'rectangle_button',
                   'vertex_insert_button',
                   'vertex_remove_button',

                   ]
for btn in buttons_to_hide:
    getattr(qctrl, btn).setEnabled(False)

# Make the shapes layer the active layer and put it into pan/zoom mode
viewer.layers.selection.active = shapes_layer
viewer.layers.selection.active.mode = 'pan_zoom'

# Run on_shapes_added on every data event of the shapes layer
shapes_layer.events.data.connect(on_shapes_added)

v0.6.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  qctrl = viewer.window.qt_viewer.controls.widgets[shapes_layer]


<function __main__.on_shapes_added(event)>

In [19]:
# Add an (initially empty) points layer to the viewer
points_layer = viewer.add_points(None, 
                                 ndim=3,
                                 name='Annotations', 
                                 scale=(1,1),
                                 size=40,
                                 border_color='dimgrey',
                                 border_width=.2,
                                 opacity=.6,
                                 )
# Store the initial length of the points data
# (on_points_added compares against it to detect newly added points)
previous_length_points = len(points_layer.data)


# Mouse button of the most recent press ('left' or 'right'); written by on_mouse_press,
# read by on_points_added to decide between a positive and a negative point
left_right_click = 'left'
def on_mouse_press(layer, event):
    '''
    Remember which mouse button caused the latest press event.

    Button 1 stores 'left' and button 2 stores 'right' in the module-level
    `left_right_click`. Any other event type or button leaves it unchanged.
    '''
    global left_right_click
    if event.type != 'mouse_press':
        return
    side = {1: 'left', 2: 'right'}.get(event.button)
    if side is not None:
        left_right_click = side
    

def on_points_added(event):
    '''
    Callback for data events of the points layer.

    When the number of points has grown since the last call, the newest point
    is styled according to the last mouse button (left: positive point, label 1,
    symbol 'o'; right: negative point, label 0, symbol 'x'), a point prediction
    is run for its frame and the result is written into the labels layer.
    Afterwards the prefetch worker is run if it is not running already.

    Reads the module-level `points_layer`, `labels_layer`, `left_right_click`
    and `prefetcher_worker`; updates `previous_length_points`.
    '''
    global previous_length_points

    current_length = len(points_layer.data)
    grew = current_length > previous_length_points
    # Track the count on every event, not only when it grows: otherwise the
    # counter stays at its old maximum after a point is removed and the next
    # point that gets added is silently ignored.
    previous_length_points = current_length
    if not grew:
        return

    # Execute prediction 
    newest_point_data = points_layer.data[-1]
    if left_right_click == 'left':
        label = 1
        points_layer.face_color[-1] = [0.59607846, 0.98431373, 0.59607846, 1.]
        points_layer.symbol[-1] = 'o'
    elif left_right_click == 'right':
        label = 0
        points_layer.face_color[-1] = [1., 1., 1., 1.]
        points_layer.symbol[-1] = 'x'
    points_layer.refresh() 

    # Run prediction
    # Entry 0 of the point is the frame index; the remaining coordinates are
    # passed to the predictor in reversed order.
    frame_idx  = int(newest_point_data[0])
    point_data = newest_point_data[1:][::-1]
    mask = run_new_pred(frame_idx=frame_idx,
                        obj_id=0,
                        label=label,
                        point=point_data,
                        )
    labels_layer.data[frame_idx,:,:] = mask
    labels_layer.refresh()   

    # Prefetch batch of images
    # This is done here since adding it as direct mouse interaction 
    # slows down the first prediction
    # NOTE(review): .run() presumably executes the worker in the calling thread -- TODO confirm
    if not prefetcher_worker.is_running:
        prefetcher_worker.run()


# on_mouse_press records the mouse button; on_points_added reacts to data events of the layer
points_layer.mouse_drag_callbacks.append(on_mouse_press)
points_layer.events.data.connect(on_points_added)

# Disable (not hide) the transform, delete, and select buttons
# NOTE(review): napari emits a warning for `qt_viewer` access (see the output of this cell)
qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]
buttons_to_hide = ['transform_button', 
                   'delete_button', 
                   'select_button', 
]
for btn in buttons_to_hide:
    getattr(qctrl, btn).setEnabled(False)
                   

# Make the points layer the active layer and switch it to the 'add' tool
viewer.layers.selection.active = points_layer
viewer.layers.selection.active.mode = 'add'

v0.6.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]


In [20]:
# Object whose mask thread_predict writes into the labels layer (read there as a global)
obj_id = 0


@thread_worker
def thread_predict(frame_idx, max_imgs):
    '''
    Propagate the prediction through a batch of frames and show the result.

    For every frame yielded by `predictor.propagate_in_video`, the mask of the
    object given by the module-level `obj_id` is written into `labels_layer`
    with value `obj_id + 1`, the viewer is moved to that frame, and centroid
    and area of the mask are stored in `predictor.inference_state`
    ('centroids' / 'areas', keyed by object ID, then by frame index).

    Parameters
    ----------
    frame_idx : int
        Frame at which propagation starts.
    max_imgs : int
        Passed as `max_frame_num_to_track`; also the number of frames
        requested from `predictor.images` up front.

    Raises
    ------
    KeyError
        If the predictor returns no mask for `obj_id` on a frame.

    NOTE(review): the viewer and the labels layer are updated from inside the
    worker; presumably GUI updates should happen on the main thread -- confirm.
    '''
    # Prefetch images if they are not cached yet 
    _ = predictor.images[slice(frame_idx,frame_idx+max_imgs)]

    centroids = predictor.inference_state['centroids']
    areas = predictor.inference_state['areas']
    current_label = obj_id+1

    # Loop over frames and run prediction (single frame!)
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(start_frame_idx=frame_idx, 
                                                                                    max_frame_num_to_track=max_imgs):

        # Collect the masks of ALL objects of this frame. Previously the
        # per-frame dict was re-created for each object, so only the last
        # object's mask survived.
        frame_masks = {}
        for i, out_obj_id in enumerate(out_obj_ids):
            frame_masks[out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
            centroids.setdefault(out_obj_id, {})
            areas.setdefault(out_obj_id, {})

        # PICK ONE OBJ (OBJ_ID = 0 or whatever)
        # THIS NEEDS TO BE MADE LAYER SPECIFIC 
        # Convert to uint8 before labelling: a boolean array cannot hold a
        # label value other than 0/1.
        mask = frame_masks[obj_id].squeeze().astype(np.uint8)
        mask[mask > 0] = current_label

        # An empty mask has no region; skip the measurements for that frame
        # instead of failing on regionprops(...)[0].
        regions = measure.regionprops(mask.astype(int))
        if regions:
            props = regions[0]
            centroids[obj_id][out_frame_idx] = props.centroid
            areas[obj_id][out_frame_idx] = props.area

        labels_layer.data[out_frame_idx,:,:] = mask
        viewer.dims.set_point(0,out_frame_idx)
        labels_layer.refresh()
        
        

0
Running mask prediction
Running prefetcher
6
Running mask prediction
Running prefetcher
14
Running mask prediction
Running prefetcher


In [22]:
# Predict one chunk of frames, starting at the frame currently shown in the viewer
print(f'Current chunk size: {chunk_size}')
worker = thread_predict(frame_idx=viewer.dims.current_step[0], max_imgs=chunk_size) 
#worker.returned.connect(viewer.add_image)  # connect callback functions
worker.start()  # start the thread!

Current chunk size: 20


Predicting: 100%|██████████| 21/21 [00:08<00:00,  2.40it/s]


In [22]:
# # Tried to save the checkpoint, 
# # but this does not work. 
# # the check point model_state does not contain enough info 
# import torch
# model_output_path = sample_dir / 'model_output.pth'    
# torch.save({
#             'model_state_dict': predictor.state_dict(),
#             }, model_output_path)

### Predict the whole video as test 

In [22]:
# Test: predict the video chunk by chunk, one worker per chunk
# NOTE(review): 4000 is hard-coded; the loaded video reports its own length in num_frames -- confirm which is intended
for i in range(0,4000,chunk_size):
    
    prediction_worker = thread_predict(frame_idx=i, max_imgs=chunk_size)  
    prediction_worker.setAutoDelete(True)
    #worker.returned.connect(viewer.add_image)  # connect callback functions
    prediction_worker.start()  
    print(f'Highest cached index {int(np.nanmax(predictor.images.cached_indices))}')
    # NOTE(review): a fixed 12 s pause is the only synchronisation between chunks; if a chunk
    # takes longer, the next worker starts while the previous one is presumably still
    # using the same predictor -- consider waiting for the worker to finish instead
    time.sleep(12)

Highest cached index 40


Predicting: 100%|██████████| 21/21 [00:07<00:00,  2.90it/s]


KeyboardInterrupt: 

### Plot some results 

In [26]:
# # Plotting
# import seaborn as sns
# sns.set_theme(style='white')
# %config InlineBackend.figure_format = 'retina'
# from matplotlib import pyplot as plt
# import matplotlib.gridspec as gridspec
# import matplotlib as mpl

# plt.style.use('dark_background')
# mpl.rcParams.update({"axes.grid" : True, "grid.color": "grey", "grid.alpha": .1})
# plt.rcParams['xtick.major.size'] = 10
# plt.rcParams['xtick.major.width'] = 1
# plt.rcParams['ytick.major.size'] = 10
# plt.rcParams['ytick.major.width'] = 1
# plt.rcParams['xtick.bottom'] = True
# plt.rcParams['ytick.left'] = True
# mpl.rcParams['savefig.pad_inches'] = .1

In [None]:
# # Plot the centroids over time
# centroids = list(predictor.inference_state['centroids'][0].values())
# centroids = np.stack(centroids)
# areas = np.array(list(predictor.inference_state['areas'][0].values())).astype(float)
# figure = plt.figure(figsize=(10,10))
# plt.imshow(viewer.layers[0].data[0], cmap='gray')
# #plt.plot(centroids[:,1], centroids[:,0], '-', color='k', alpha=.6)   
# plt.scatter(centroids[:,1], centroids[:,0], s=areas/50, marker='.', color='pink', alpha=.15, lw=0)   
# sns.despine(left=True,bottom=True)
# plt.title(f'Centroids over time (n={centroids.shape[0]} frames)')   

In [None]:
# plt.plot(list(predictor.inference_state['areas'][0].values()),'-', color='w', alpha=.6 )
# plt.title('Area over time')

In [None]:
# output_dict_per_obj is huge 
# Structure
# -> obj_id
# --> cond_frame_outputs
# --> non_cond_frame_outputs