### Napari tests

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os , sys
sys.path.append('..')
from pathlib import Path
cur_path = Path(os.getcwd()).parent

from matplotlib import pyplot as plt
import cmasher as cmr
import numpy as np
import seaborn as sns
sns.set_theme(style='white')
%config InlineBackend.figure_format = 'retina'

In [3]:
import napari
from napari.qt.threading import thread_worker
import warnings
warnings.simplefilter(action='always', category=FutureWarning)
import zarr
import shutil

In [4]:
# Open the napari viewer; the video is expected to be loaded into it as layer 0
# before the cells below are run (they read viewer.layers[0])
viewer = napari.Viewer()

2025-01-29 16:30:31.035 python[33449:14995171] +[IMKClient subclass]: chose IMKClient_Modern
2025-01-29 16:30:31.035 python[33449:14995171] +[IMKInputSession subclass]: chose IMKInputSession_Modern


#### Initialize the SAM2 model the same way it is done in napari

In [5]:
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
from octron.sam2_octron.helpers.build_sam2_octron import build_sam2_octron

In [6]:
sam2_folder = Path('sam2_octron')
checkpoint = 'sam2.1_hiera_base_plus.pt'   # file inside <sam2_folder>/checkpoints
model_cfg = 'sam2.1/sam2.1_hiera_b+.yaml'  # file inside configs/
# ------------------------------------------------------------------------------------
sam2_checkpoint = cur_path / sam2_folder / 'checkpoints' / checkpoint
model_cfg = Path('configs') / model_cfg

# Build the SAM2 OCTRON predictor from config + checkpoint; `device` is the
# torch device returned by the builder (presumably the one it selected)
predictor, device = build_sam2_octron(
    config_file=model_cfg.as_posix(),
    ckpt_path=sam2_checkpoint.as_posix(),
)


Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might give numerically different outputs and sometimes degraded performance on MPS. See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.
Uing device: mps


Loaded SAM2VideoPredictor OCTRON
Model image size: 1024


### After loading the video file in napari, extract its info

In [7]:
# Jump to frame 0 (axis 0 is the frame axis) and print the current step of every dimension
viewer.dims.set_point(0,0)
current_indices = viewer.dims.current_step
print(current_indices)

(0, 771, 771)


In [8]:
video_data = viewer.layers[0].data   # the whole video

print(f'video shape: {video_data.shape}')   
# The unpacking assumes a 4-D colour video: (frames, height, width, channels)
num_frames, height, width, channels = video_data.shape


video shape: (10801, 1544, 1544, 3)


In [9]:
# Create a zarr store to hold every video frame resized to the model input size
chunk_size = 25  # frames per zarr chunk; also used as the prefetch / propagation batch size

# Start from a fresh store in <project>/sample_data/sample_data.zarr
sample_dir = cur_path / 'sample_data'
sample_dir.mkdir(exist_ok=True)
sample_data_zarr = sample_dir / 'sample_data.zarr'
if sample_data_zarr.exists():
    shutil.rmtree(sample_data_zarr)

# Assuming a local store on a fast SSD, so no compression is wanted.
# zarr-python 3 applies default compressors unless compressors=None is passed,
# so disable them explicitly to match that intent.
store = zarr.storage.LocalStore(sample_data_zarr, read_only=False)
root = zarr.create_group(store=store)
image_zarr = root.create_array(name='masks',
                                shape=(num_frames, 3, predictor.image_size, predictor.image_size),
                                chunks=(chunk_size, 3, predictor.image_size, predictor.image_size),
                                fill_value=np.nan,
                                compressors=None,
                                dtype='float32')


In [10]:
# Hand the predictor the napari viewer, the index of the video layer (0) and the
# zarr array for resized frames (presumably used as its frame cache — confirm),
# then reset its state (assumed to clear previous prompts/tracking — verify)
predictor.init_state(napari_viewer=viewer, video_layer_idx=0, zarr_store=image_zarr)
predictor.reset_state()

  return torch._C._nn.upsample_bicubic2d(


Initialized SAM2 model


In [11]:
def run_new_pred(frame_idx,
                 obj_id,
                 label,
                 point):
    '''
    Add one click prompt to the SAM2 predictor and return the resulting mask.

    Parameters
    ----------
    frame_idx : int
        Index of the video frame the click was placed on.
    obj_id : int
        Id of the object the prompt belongs to (any integer, unique per object).
    label : int
        1 for a positive (foreground) click, 0 for a negative (background) click.
    point : array-like of length 2
        Click coordinates, sent to the predictor as a single point.

    Returns
    -------
    np.ndarray of uint8
        Mask for `obj_id`: background is 0 and foreground pixels are
        obj_id + 1 (the label value shown in the napari labels layer).
        Singleton dimensions are squeezed away.
    '''
    assert label in [0, 1]
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        frame_idx=frame_idx,
        obj_id=obj_id,
        points=np.array([point], dtype=np.float32),
        labels=np.array([label], np.int32),
    )
    # The predictor returns one row of logits per tracked object, so select
    # the row belonging to obj_id rather than assuming it is the first one.
    obj_row = list(out_obj_ids).index(obj_id)
    foreground = (out_mask_logits[obj_row] > 0).cpu().numpy()
    # Label all foreground pixels directly; unlike an np.unique-based relabel,
    # this also works when the whole mask is foreground.
    mask = np.where(foreground, obj_id + 1, 0).astype(np.uint8)
    return mask.squeeze()

In [12]:
# Set up thread worker to deal with prefetching batches of images

@thread_worker
def thread_prefetch_images(batch_size):
    '''
    Read `batch_size` frames of ``predictor.images`` ahead of prediction.

    The read starts at the frame currently shown in the viewer. The result is
    discarded; the read is done only for its side effect (presumably filling
    the predictor's resized-image zarr cache — confirm in SAM2 OCTRON).
    Calling this returns a napari worker; the body runs in a worker thread.
    '''
    # Honor the batch_size argument instead of silently using the global chunk_size.
    first_frame = viewer.dims.current_step[0]
    _ = predictor.images[first_frame:first_frame + batch_size]


prefetcher_worker = thread_prefetch_images(chunk_size)
# Disable Qt's auto-deletion of the runnable so the same worker can be
# started again later from the points callback.
prefetcher_worker.setAutoDelete(False)
prefetcher_worker.start()


In [13]:
# Instantiate the mask and annotation layers
# Keep them empty at start
# NOTE(review): np.zeros requests num_frames * height * width bytes
# (~25.7 GB for the (10801, 1544, 1544, 3) video printed above); consider a
# lazily backed array (e.g. a zarr array) if memory becomes a problem.
mask_layer_dummy = np.zeros((num_frames, height, width), dtype=np.uint8)

# Label colors: 8 colors sampled from cmasher's 'tropical' map as 0-255 ints,
# scaled to [0, 1] and given an opaque alpha channel
colors = cmr.take_cmap_colors('cmr.tropical', 8, cmap_range=(0, 1),
                              return_fmt='int')
colors_norm = np.stack(colors) / 255.0
cyclic_map = napari.utils.CyclicLabelColormap(
    np.hstack([colors_norm, np.ones((len(colors), 1))])
)
labels_layer = viewer.add_labels(
    mask_layer_dummy,
    name='Mask',  # Name of the layer
    opacity=0.4,  # Optional: opacity of the labels
    colormap=cyclic_map,
    blending='additive'  # Optional: blending mode
)
# Add the points layer to the viewer
points_layer = viewer.add_points(None,
                                 ndim=3,
                                 name='Annotations',
                                 scale=(1,1),
                                 size=40,
                                 border_color='dimgrey',
                                 border_width=.2,
                                 opacity=.6,
                                 )
# Store the initial length of the points data
previous_length_points = len(points_layer.data)


# Last mouse button pressed on the points layer: 'left' or 'right'
left_right_click = 'left'


def on_mouse_press(layer, event):
    '''
    Record which mouse button was pressed on `layer`.

    Button 1 stores 'left' and button 2 stores 'right' in the module-level
    ``left_right_click``. Other buttons and non-press events leave it
    untouched. `layer` itself is not used.
    '''
    global left_right_click
    if event.type != 'mouse_press':
        return
    pressed = {1: 'left', 2: 'right'}.get(event.button)
    if pressed is not None:
        left_right_click = pressed
    

def on_points_added(event):
    '''
    Run a SAM2 prediction whenever a new point is added to the points layer.

    Called on every ``points_layer.events.data`` event. If the layer grew, the
    newest point is styled according to the last mouse button. A left click is a
    positive click (pale green circle); any other value is treated as a negative
    click (white cross). The point is then sent to ``run_new_pred`` for object 0,
    and the returned mask is written into the labels layer at the point's frame.
    Finally, the image prefetch worker is started on the thread pool unless it
    is already running.
    '''
    global previous_length_points

    current_length = len(points_layer.data)
    grew = current_length > previous_length_points
    # Track the current length on every event, including deletions. Otherwise a
    # point added after a deletion would not register as "new".
    previous_length_points = current_length
    if not grew:
        return

    newest_point_data = points_layer.data[-1]
    if left_right_click == 'left':
        label = 1
        points_layer.face_color[-1] = [0.59607846, 0.98431373, 0.59607846, 1.]
        points_layer.symbol[-1] = 'o'
    else:
        label = 0
        points_layer.face_color[-1] = [1., 1., 1., 1.]
        points_layer.symbol[-1] = 'x'
    points_layer.refresh()

    # Points are stored as (frame, row, col); the remaining two coordinates are
    # reversed before being passed to the predictor.
    frame_idx = int(newest_point_data[0])
    point_data = newest_point_data[1:][::-1]
    mask = run_new_pred(frame_idx=frame_idx,
                        obj_id=0,
                        label=label,
                        point=point_data,
                        )
    labels_layer.data[frame_idx, :, :] = mask
    labels_layer.refresh()

    # Prefetch a batch of images. This is done here because attaching it to the
    # mouse interaction directly slows down the first prediction.
    # Use start(), not run(): run() executes the work synchronously in the
    # calling (GUI) thread instead of on the thread pool.
    if not prefetcher_worker.is_running:
        try:
            prefetcher_worker.start()
        except RuntimeError:
            # Already queued in the thread pool but not yet running.
            pass


# Register the callbacks: remember which mouse button was pressed, and react
# to changes of the points data
points_layer.mouse_drag_callbacks.append(on_mouse_press)
points_layer.events.data.connect(on_points_added)

# Hide the transform, delete and select buttons of the points layer controls.
# NOTE(review): public access to `Window.qt_viewer` is deprecated in recent
# napari releases — confirm for the napari version in use.
qctrl = viewer.window.qt_viewer.controls.widgets[points_layer]
for hidden_button in ('transform_button', 'delete_button', 'select_button'):
    getattr(qctrl, hidden_button).setVisible(False)

# Make the points layer the active layer and switch it to the 'add' tool
viewer.layers.selection.active = points_layer
points_layer.mode = 'add'

In [14]:
# Add the shapes layer to the viewer
shapes_layer = viewer.add_shapes(None,
                                 ndim=3,
                                 name='Shape annotations',
                                 scale=(1,1),
                                 edge_width=0,
                                 face_color=np.array(cyclic_map.colors),
                                 opacity=.6,
                                 )
# Store the initial number of shapes in a separate counter. Writing to
# `previous_length_points` here would reset the counter used by the points-layer
# callback, so the next points data event (e.g. a deletion) would re-run a
# prediction for an already-processed point.
previous_length_shapes = len(shapes_layer.data)

In [30]:
obj_id = 0


def _apply_predicted_mask(result):
    '''
    Write one propagated mask into the labels layer and jump to its frame.

    Connected to the worker's ``yielded`` signal so that all napari/Qt updates
    happen in the main (GUI) thread instead of the worker thread.

    Parameters
    ----------
    result : tuple of (int, np.ndarray)
        Frame index in the labels layer and the uint8 mask for that frame.
    '''
    frame, mask = result
    labels_layer.data[frame, :, :] = mask
    viewer.dims.set_point(0, frame)
    labels_layer.refresh()


@thread_worker(connect={'yielded': _apply_predicted_mask}, start_thread=False)
def thread_predict(frame_idx, max_imgs):
    '''
    Propagate the current prompts through the video in a worker thread.

    Parameters
    ----------
    frame_idx : int
        Frame to start propagation from.
    max_imgs : int
        Passed to the predictor as ``max_frame_num_to_track``.

    Yields
    ------
    tuple of (int, np.ndarray)
        Labels-layer frame index and a uint8 mask for object ``obj_id`` whose
        foreground pixels are set to obj_id + 1.

    Calling ``thread_predict(...)`` returns a napari worker that must be
    started with ``.start()``. Because this is a generator, ``.quit()`` can
    stop it between frames.
    '''
    # NOTE(review): torch.autocast state is thread-local. If autocast was
    # enabled on the main thread during model setup, it is not active in this
    # worker thread. That could explain a float/bfloat16 dtype mismatch in
    # attention — confirm.
    label_value = obj_id + 1
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            start_frame_idx=frame_idx,
            max_frame_num_to_track=max_imgs):
        # THIS NEEDS TO BE MADE LAYER SPECIFIC: only obj_id is displayed for now
        obj_ids = list(out_obj_ids)
        if obj_id not in obj_ids:
            continue
        logits = out_mask_logits[obj_ids.index(obj_id)]
        # Integer mask: a bool mask cannot carry label values above 1
        mask = np.where((logits > 0.0).cpu().numpy(), label_value, 0).astype(np.uint8).squeeze()
        # NOTE(review): the original `+ frame_idx` offset is kept. If
        # propagate_in_video yields absolute frame indices (as upstream SAM2
        # does), this double-counts whenever frame_idx != 0 — confirm for SAM2 OCTRON.
        yield out_frame_idx + frame_idx, mask


In [32]:
# Propagate the prompts from the frame currently shown in the viewer,
# using chunk_size as the tracking limit
worker = thread_predict(
    frame_idx=viewer.dims.current_step[0],
    max_imgs=chunk_size,
)
worker.start()  # hand the worker to napari's thread pool

Predicting:   0%|          | 0/26 [00:00<?, ?it/s]

non conditioned frame 0


Predicting:   0%|          | 0/26 [00:00<?, ?it/s]


Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::BFloat16 instead.


In [19]:
worker.quit()  # request the propagation worker to stop

In [None]:
# output_dict_per_obj is huge 
# Structure
# -> obj_id
# --> cond_frame_outputs
# --> non_cond_frame_outputs

In [None]:
# Full video shape: (frames, height, width, channels)
video_data.shape

In [24]:
# NOTE(review): unused — data_snippet is reassigned to video_data[100:150] in a
# later cell before anything reads it
data_snippet = video_data[100:200]

In [None]:
from torchvision.transforms import Resize
# Resize frames to the square model input, predictor.image_size x predictor.image_size
# (1024 x 1024 here, per the model output above). A single int would resize only
# the shorter edge and keep the aspect ratio, so both sides are given explicitly.
_resize_img = Resize(
    size=(predictor.image_size, predictor.image_size)
)

In [52]:
# Frames 100-149 as a float tensor: (frames, H, W, C) -> (frames, C, H, W),
# resized to the model input size, then divided by 255
# (assumes 8-bit frames, i.e. values in [0, 255]).
data_snippet = video_data[100:150]
data_snippet_torch = _resize_img(torch.from_numpy(data_snippet).permute(0, 3, 1, 2)).float()
data_snippet_torch /= 255.

In [None]:
# NOTE(review): img_mean is only defined in the normalization cell further down,
# so this cell raises NameError on a fresh Restart & Run All — move it after that cell
img_mean.shape

In [54]:
# Per-channel normalization constants (presumably the ImageNet mean/std used by
# SAM2 — confirm), reshaped to (3, 1, 1) so they broadcast over (frames, C, H, W)
img_mean = (0.485, 0.456, 0.406)
img_std  = (0.229, 0.224, 0.225)
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

# Normalize in place. NOTE(review): not idempotent — re-running this cell
# without re-running the resize cell normalizes the snippet twice.
data_snippet_torch -= img_mean
data_snippet_torch /= img_std

In [None]:
# Sanity check: per-channel mean and std of the normalized snippet.
# The normalization is not repeated here: the cell above applies it in place,
# so running it again would normalize data_snippet_torch twice.
data_snippet_torch.mean(dim=(0, 2, 3)), data_snippet_torch.std(dim=(0, 2, 3))

In [None]:
# (frames, channels, H, W) after the permute in the resize cell
data_snippet_torch.shape

In [None]:
# Show channel 2 of the first normalized frame
first_frame_ch2 = data_snippet_torch[0, 2].numpy()
plt.imshow(first_frame_ch2.squeeze())