### Napari tests

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os , sys
# Make the parent directory importable (relative entry; resolved against the kernel's working directory).
sys.path.append('..')
from pathlib import Path
# Parent of the current working directory; the 'sam2_octron' folder is expected directly beneath it.
cur_path = Path(os.getcwd()).parent
sam2_path = cur_path / 'sam2_octron'
# Same parent directory again, this time as an absolute POSIX path.
sys.path.append(cur_path.as_posix())
from matplotlib import pyplot as plt
import cmasher as cmr
import numpy as np
import seaborn as sns
# Plot styling for the inline figures.
sns.set_theme(style='white')
%config InlineBackend.figure_format = 'retina'

In [3]:
import napari
from napari.qt.threading import thread_worker
from napari.utils import DirectLabelColormap
from napari.utils.notifications import show_info
import warnings
warnings.simplefilter(action='always', category=FutureWarning)
import time

#### Importing additional stuff 
from skimage import measure
from skimage.draw import polygon2mask

In [4]:
# ### Object organizer
# from octron.sam2_octron.object_organizer import Obj, ObjectOrganizer
# object_organizer = ObjectOrganizer()
# object_organizer.add_entry(0, Obj(label='worm', suffix='0'))
# object_organizer.add_entry(1, Obj(label='worm', suffix='1'))
# object_organizer.add_entry(2, Obj(label='worm', suffix='two'))
# object_organizer.add_entry(3, Obj(label='worm', suffix='three'))
# object_organizer.add_entry(4, Obj(label='octopus', suffix=''))
# object_organizer.add_entry(5, Obj(label='octopus', suffix='another'))
# object_organizer.add_entry(6, Obj(label='octopus', suffix='three'))

# object_organizer

In [5]:
import zarr

In [6]:
# zip_path = '/Users/horst/Downloads/octron_project/worm masks.zip'
# store = zarr.storage.ZipStore(zip_path, mode='w')

# num_frames = 100
# num_ch = 1
# image_height = 512
# image_width = 512
# chunk_size = 10 
# fill_value = 0
# dtype = 'uint8'
# image_zarr = zarr.create_array(store=store,
#                                name='masks',
#                                shape=(num_frames, num_ch, image_height, image_width),  
#                                chunks=(chunk_size, num_ch, image_height, image_width), 
#                                fill_value=fill_value,
#                                dtype=dtype,
#                                overwrite=True,
#                                       )
# store.close()

In [8]:
# NOTE(review): absolute, machine-specific path - the notebook will not run elsewhere without editing it.
# Despite the '.zip' suffix, this path is opened with zarr's LocalStore in the next cell.
zip_path = Path('/Users/horst/Downloads/octron_project/water masks.zip')
zip_path.exists()

True

In [9]:
# Open the on-disk store writable, then open its root group in append mode ('a').
store = zarr.storage.LocalStore(zip_path, read_only=False)  
root = zarr.open_group(store=store, mode='a')

In [10]:

print("Existing keys in zarr archive:", list(root.array_keys()))

Existing keys in zarr archive: ['masks']


In [11]:
# Grab the 'masks' array from the archive and display its metadata.
# NOTE(review): the name `image_zarr` is re-bound further down by create_image_zarr(...).
image_zarr = root['masks']
image_zarr.info

Type               : Array
Zarr format        : 3
Data type          : DataType.uint8
Shape              : (4067, 1000, 1000)
Chunk shape        : (20, 1000, 1000)
Order              : C
Read-only          : False
Store type         : LocalStore
Filters            : ()
Serializer         : BytesCodec(endian=<Endian.little: 'little'>)
Compressors        : (ZstdCodec(level=0, checksum=False),)
No. bytes          : 4067000000 (3.8G)

#### Initialize the SAM2 model the same way the napari plugin does

In [None]:
viewer = napari.Viewer()    

In [None]:
from sam2_octron.helpers.sam2_checks import check_sam2_models
from sam2_octron.helpers.sam2_octron import run_new_pred
from sam2_octron.helpers.build_sam2_octron import build_sam2_octron
from sam2_octron.helpers.sam2_zarr import create_image_zarr
from sam2_octron.helpers.sam2_colors import create_label_colors

In [None]:
# Look up the available SAM2 models from the YAML registry (no forced download).
models_yaml_path = sam2_path / 'models.yaml'
models_dict = check_sam2_models(
    SAM2p1_BASE_URL='',
    models_yaml_path=models_yaml_path,
    force_download=False,
)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Careful: these path descriptors differ slightly between the notebook
# and the plugin version.
model = models_dict['sam2_large']

config_path = Path(model['config_path'])
checkpoint_path = sam2_path / str(model['checkpoint_path'])
# Build the predictor; the helper also returns the compute device it selected.
predictor, device = build_sam2_octron(
    config_file=config_path.as_posix(),
    ckpt_path=checkpoint_path.as_posix(),
)

### From napari, after loading video file, extract info 

In [9]:
viewer.dims.set_point(0,0)

In [12]:
device

device(type='mps')

In [11]:
predictor.image_size

1024

In [24]:
# The first viewer layer is taken to be the loaded video; its shape is unpacked as
# (frames, height, width, channels), which requires a 4-D array.
num_frames, video_height, video_width, n_ch = viewer.layers[0].data.shape
print(f"Video shape: {viewer.layers[0].data.shape}")

Video shape: (16012, 1258, 2048, 3)


In [None]:
# Scale the video so that its longest edge equals the predictor's input size,
# keeping the aspect ratio.
largest_edge = max(video_width, video_height)
image_scaler = predictor.image_size / largest_edge
# Use exact integer arithmetic for the floored sizes: with floats,
# (image_size / largest_edge) * largest_edge can evaluate to just below
# image_size, so flooring would give image_size - 1 and trip the assert below.
resized_height = int(predictor.image_size * video_height // largest_edge)
resized_width = int(predictor.image_size * video_width // largest_edge)
assert max(resized_height, resized_width) == predictor.image_size

In [40]:
from sam2_octron.helpers.sam2_zarr import OctoZarr, Resize

In [42]:
_resize_img = Resize(size=(resized_height, resized_width))

In [45]:
import sam2

In [46]:
sam2.__file__

'/Users/horst/Documents/python/segment-anything-2/sam2/__init__.py'

In [43]:
158/152

1.0394736842105263

In [25]:
# Number of frames handled per batch: passed to create_image_zarr (chunk_size),
# the prefetch worker (batch_size) and thread_predict (max_imgs) further down.
chunk_size = 15

In [None]:
# Create temp output dir 
# Create temp output dir 
sample_dir = cur_path / 'sample_data'
sample_dir.mkdir(exist_ok=True)
sample_data_zarr = sample_dir / 'sample_data.zip'

# Create the image cache array for the predictor.
# NOTE(review): this re-binds `image_zarr`, which earlier referred to root['masks'].
# Only image_height is passed here; presumably the helper derives the width itself - confirm.
image_zarr = create_image_zarr(sample_data_zarr,
                               num_frames=num_frames,
                               image_height=predictor.image_size,
                               chunk_size=chunk_size,
                               )

In [None]:
# `video_data` was never defined in any cell of this notebook, so a fresh
# kernel failed here with a NameError. The video is taken from the first
# viewer layer, the same source used above to read the video shape.
# NOTE(review): confirm that viewer.layers[0] is always the video layer.
video_data = viewer.layers[0].data
predictor.init_state(video_data=video_data, zarr_store=image_zarr)
predictor.reset_state()

In [None]:
# Set up thread worker to deal with prefetching batches of images
@thread_worker
def thread_prefetch_images(batch_size):
    """
    Touch a batch of predictor images so that they get loaded/cached.

    The batch starts at the frame currently shown in the viewer and spans
    `batch_size` frames; the fetched images themselves are discarded.
    """
    first_frame = viewer.dims.current_step[0]
    print(f'Prefetching {batch_size} images, starting at frame {first_frame}')
    _ = predictor.images[first_frame:first_frame + batch_size]


# Create the worker once, keep it alive after it finishes (auto-delete off),
# and kick off a first prefetch right away.
prefetcher_worker = thread_prefetch_images(chunk_size)
prefetcher_worker.setAutoDelete(False)
prefetcher_worker.start()

In [None]:
# Build the label colors from the cmasher 'tropical' colormap.
colors = create_label_colors(cmap='cmr.tropical')
# Select colormap for labels layer based on category (label) and current object ID 
# Here the first entry (colors[0]) is used, with label value 1 as the selected label.
base_color = DirectLabelColormap(color_dict=colors[0], 
                                 use_selection=True, 
                                 selection=1,
                                 )

#### Implement layer remove events

In [None]:
def on_layer_removed(event):
    """
    Callback for the layer list's `removed` event.

    Currently it only logs that it was called. The planned logic
    (still disabled) would look at the flags set by `on_layer_removing`:
        if not remove_current_layer:
            viewer.add_layer(removed_layer)   # put the layer back
        else:
            print(f"Deleted layer {removed_layer}")
    """
    print('Calling on_layer_removed')
        
def on_layer_removing(event):
    """
    Callback for the layer list's `removing` event: ask the user to confirm.

    The answer is stored in the module-level flag `remove_current_layer`;
    when the user declines, the affected layer is kept in `removed_layer`.

    NOTE(review): `QMessageBox` is not imported in any visible cell of this
    notebook, so this callback raises a NameError unless it is imported
    elsewhere.
    """
    global remove_current_layer, removed_layer

    # Not sure whether more than one layer can be deleted at once.
    # If so, that case still needs handling (event sources would be a list).
    doomed_layer = event.source[event.index]

    answer = QMessageBox.question(
        None,
        "Confirmation",
        f"Are you sure you want to delete layer\n'{doomed_layer}'",
        QMessageBox.Yes | QMessageBox.No,
        QMessageBox.No,
    )

    if answer == QMessageBox.No:
        remove_current_layer = False
        removed_layer = doomed_layer
        return
    remove_current_layer = True
        
# Hook both callbacks into the layer list: `removing` handler asks for
# confirmation, `removed` handler currently only logs.
viewer.layers.events.removing.connect(on_layer_removing)
viewer.layers.events.removed.connect(on_layer_removed)

#### Mask layer

In [None]:
# Add the labels layer that will hold the predicted masks.
# NOTE(review): `image_layer` is not defined in any visible cell of this notebook;
# this cell fails on a fresh kernel unless it is created elsewhere.
labels_layer = viewer.add_labels(
    image_layer['mask_dummy'], 
    name='SAM2 masks',  
    opacity=0.4,  
    blending='additive',  
    colormap=base_color, 
)

# Hide the manual editing tools of the labels layer in the layer controls.
# NOTE(review): public access to `viewer.window.qt_viewer` is, I believe, deprecated
# in recent napari releases - confirm against the installed version.
qctrl = viewer.window.qt_viewer.controls.widgets[labels_layer]
buttons_to_hide = ['erase_button',
                   'fill_button',
                   'paint_button',
                   'pick_button',
                   'polygon_button',
                   'transform_button',
                   ]
for btn in buttons_to_hide: 
    getattr(qctrl, btn).hide()

#### Shapes layer 

In [None]:
# Normalised label name (trimmed, lower case) shared by the annotation layers.
label_name = 'wormsy '.strip().lower()
# Colour at key 1 of the first colour dictionary.
current_color = colors[0][1]

In [None]:
# Add the shapes layer to the viewer
# Add the shapes layer to the viewer
# Starts empty (data=None), is 3-dimensional (frame, y, x), outlined in the
# current label colour with a fully transparent face.
shapes_layer = viewer.add_shapes(None, 
                                 ndim=3,
                                 name=label_name, 
                                 scale=(1,1),
                                 edge_width=1,
                                 edge_color=current_color,
                                 face_color=[1,1,1,0],
                                 opacity=.4,
                                 )


def on_shapes_changed(event):
    """
    Run a SAM2 prediction whenever the shapes layer data changes.

    Only the actions 'added', 'removed' and 'changed' are handled.
    - Rectangle tool active: the most recent shape is interpreted as a box
      prompt, sent to the predictor and then removed from the layer again.
    - Any other tool: all shapes are rasterised into one binary mask that is
      used as a mask prompt.
    The resulting mask is written into the labels layer at the current frame,
    and a prefetch of the next image batch is triggered if none is running.

    NOTE(review): assigning `shapes_layer.data` inside this callback may emit
    another data event and re-enter this function, depending on the napari
    version - confirm. The empty-data guard below at least prevents an
    IndexError in that situation.
    """
    action = event.action
    if action not in ['added', 'removed', 'changed']:
        return

    frame_idx = viewer.dims.current_step[0]

    if shapes_layer.mode == 'add_rectangle':
        # Nothing to predict when a box was removed or no shape is left.
        if action == 'removed' or len(shapes_layer.data) == 0:
            return
        # Take care of box input first.
        # If the rectangle tool is selected, extract "box" coordinates
        box = shapes_layer.data[-1]
        if len(box) > 4:
            box = box[-4:]
        top_left, _, bottom_right, _ = box
        # Drop the leading frame index, keeping (y, x)
        top_left, bottom_right = top_left[1:], bottom_right[1:]
        mask = run_new_pred(predictor=predictor,
                            frame_idx=frame_idx,
                            obj_id=0,
                            labels=[1],
                            # Box is passed as x_min, y_min, x_max, y_max
                            box=[top_left[1],
                                 top_left[0],
                                 bottom_right[1],
                                 bottom_right[0],
                                 ],
                            )
        # The box only serves as a prompt; remove it from the layer again
        shapes_layer.data = shapes_layer.data[:-1]
        shapes_layer.refresh()
    else:
        # In all other cases, just treat shapes as masks
        shape_mask = shapes_layer.to_masks((video_height, video_width))
        shape_mask = np.sum(shape_mask, axis=0)
        if not isinstance(shape_mask, np.ndarray):
            return
        # Collapse overlapping shapes into a binary mask
        shape_mask[shape_mask > 0] = 1
        shape_mask = shape_mask.astype(np.uint8)

        label = 1  # Always positive for now
        mask = run_new_pred(predictor=predictor,
                            frame_idx=frame_idx,
                            obj_id=0,
                            labels=label,
                            masks=shape_mask,
                            )

    labels_layer.data[frame_idx] = mask
    labels_layer.refresh()

    # Prefetch batch of images
    if not prefetcher_worker.is_running:
        prefetcher_worker.run()
    return

# Store the initial length of the points data
# previous_length_points = len(shapes_layer.data)
# Hide the transform, delete, and select buttons
# Hide the line, path and polyline buttons of the shapes layer controls
# NOTE(review): public access to `viewer.window.qt_viewer` is, I believe, deprecated
# in recent napari releases - confirm against the installed version.
qctrl = viewer.window.qt_viewer.controls.widgets[shapes_layer]
buttons_to_hide = [
                   'line_button',
                   'path_button',
                   'polyline_button',
                   ]
for btn in buttons_to_hide:
    attr = getattr(qctrl, btn)
    attr.hide()

# Make the shapes layer the active layer, start in pan/zoom mode and attach the data callback
viewer.layers.selection.active = shapes_layer
viewer.layers.selection.active.mode = 'pan_zoom'
shapes_layer.events.data.connect(on_shapes_changed)

#### Points layer

In [None]:
# Add the points layer to the viewer
# Add the points layer to the viewer
# Starts empty (data=None) and is 3-dimensional (frame, y, x).
# Note that it is given the same `label_name` as the shapes layer.
points_layer = viewer.add_points(None, 
                                 ndim=3,
                                 name=label_name, 
                                 scale=(1,1),
                                 size=40,
                                 border_color='dimgrey',
                                 border_width=.2,
                                 opacity=.6,
                                 )


# Which mouse button was pressed last; read by the points callback.
left_right_click = 'left'
def on_mouse_press(layer, event):
    """
    Mouse callback that remembers whether the last press was a left or
    a right click (stored in the module-level `left_right_click`).
    Other event types and other buttons leave the stored value untouched.
    """
    global left_right_click
    if event.type != 'mouse_press':
        return
    # Button 1 is the left button, button 2 the right one
    pressed = {1: 'left', 2: 'right'}.get(event.button)
    if pressed is not None:
        left_right_click = pressed
    

def on_points_changed(event):
    """
    Callback for data changes on the points layer.

    A freshly added point is styled according to the last mouse button
    (left: positive prompt, green disc; right: negative prompt, white cross).
    Afterwards all points of the layer are collected together with their
    positive/negative labels (derived from the point symbols), a new
    prediction is run and its mask written to the current frame of the
    labels layer.
    """
    action = event.action
    frame_idx = viewer.dims.current_step[0]

    # Mouse button -> (face colour, symbol) of the newly added point
    click_styles = {
        'left': ([0.59607846, 0.98431373, 0.59607846, 1.], 'o'),
        'right': ([1., 1., 1., 1.], 'x'),
    }

    if action == 'added':
        # A new point has just been added; style it by click type
        new_style = click_styles.get(left_right_click)
        if new_style is not None:
            new_face_color, new_symbol = new_style
            points_layer.face_color[-1] = new_face_color
            points_layer.symbol[-1] = new_symbol
        points_layer.refresh()  # THIS IS IMPORTANT
        # Prefetch batch of images
        if not prefetcher_worker.is_running:
            prefetcher_worker.run()

    if action not in ['added', 'removed', 'changed']:
        return

    # The label of each point is encoded in its symbol: disc = positive (1),
    # anything else = negative (0)
    labels = [
        1 if points_layer.symbol[pt_no] in ['o', 'disc'] else 0
        for pt_no in range(len(points_layer.data))
    ]
    # Index 0 of each point is the frame number; the rest is reversed
    point_data = [pt[1:][::-1] for pt in points_layer.data]

    # Then run the actual prediction
    mask = run_new_pred(predictor=predictor,
                        frame_idx=frame_idx,
                        obj_id=0,
                        labels=labels,
                        points=point_data,
                        )
    labels_layer.data[frame_idx, :, :] = mask
    labels_layer.refresh()
        
    
# Track the mouse button first, then react to data changes on the points layer.
points_layer.mouse_drag_callbacks.append(on_mouse_press)
points_layer.events.data.connect(on_points_changed)
# Select the current, add tool for the points layer
viewer.layers.selection.active = points_layer
viewer.layers.selection.active.mode = 'add'

### Thread prediction

In [None]:
obj_id = 0

@thread_worker
def thread_predict(frame_idx, max_imgs):
    """
    Propagate the SAM2 prediction over a batch of frames in a worker thread.

    Parameters
    ----------
    frame_idx : int
        First frame of the batch.
    max_imgs : int
        Number of frames to prefetch and to track.

    For every frame yielded by the predictor, the masks of all objects are
    collected, and the mask of the module-level `obj_id` is written into
    `labels_layer` using the label value `obj_id + 1`. Centroid and area of
    that mask are stored in `predictor.inference_state['centroids']` and
    `predictor.inference_state['areas']`.

    NOTE(review): the viewer and the labels layer are updated directly from
    this worker thread; napari generally expects GUI updates to happen via
    worker signals on the main thread - confirm this is safe here.
    """
    video_segments = {}
    start_time = time.time()
    # Prefetch images if they are not cached yet
    _ = predictor.images[slice(frame_idx, frame_idx + max_imgs)]

    # Loop over frames and run prediction (single frame!)
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        start_frame_idx=frame_idx,
        max_frame_num_to_track=max_imgs,
    ):
        # Collect the masks of ALL objects of this frame. Previously the
        # per-frame dict was re-created for every object, so only the last
        # object survived.
        frame_masks = video_segments.setdefault(out_frame_idx, {})
        for i, out_obj_id in enumerate(out_obj_ids):
            frame_masks[out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
            predictor.inference_state['centroids'].setdefault(out_obj_id, {})
            predictor.inference_state['areas'].setdefault(out_obj_id, {})

        # PICK ONE OBJ (OBJ_ID = 0 or whatever)
        # THIS NEEDS TO BE MADE LAYER SPECIFIC
        if obj_id not in frame_masks:
            # The requested object was not part of this frame's output
            continue

        # Turn the boolean mask into a label image. A boolean array cannot
        # hold label values > 1, so convert before applying the label.
        current_label = obj_id + 1
        mask = frame_masks[obj_id].squeeze().astype(np.uint8) * current_label

        # An empty mask has no regions; skip the measurements instead of
        # failing with an IndexError and aborting the whole propagation.
        regions = measure.regionprops(mask.astype(int))
        if regions:
            props = regions[0]
            predictor.inference_state['centroids'][obj_id][out_frame_idx] = props.centroid
            predictor.inference_state['areas'][obj_id][out_frame_idx] = props.area

        labels_layer.data[out_frame_idx, :, :] = mask
        viewer.dims.set_point(0, out_frame_idx)
        labels_layer.refresh()
    end_time = time.time()
    #print(f'start idx {frame_idx} | {max_imgs} frames in {end_time-start_time} s')
        
        

In [None]:
# plt.imshow(np.mean(labels_layer.data, axis=0))

In [None]:
# from torch import tensor as torch_tensor
# from skimage.morphology import disk

# predictor.perform_morphological_operations = True

# disk_size=10
# compute_device=device
# predictor.closing_kernel = torch_tensor(disk(disk_size).tolist()).to(compute_device)

In [None]:
# Predict one chunk of frames, starting at the frame currently shown in the viewer.
print(f'Current chunk size: {chunk_size}')
worker = thread_predict(frame_idx=viewer.dims.current_step[0], max_imgs=chunk_size) 
#worker.returned.connect(viewer.add_image)  # connect callback functions
worker.start()  # start the thread!

In [None]:
# # Tried to save the checkpoint, 
# # but this does not work. 
# # the check point model_state does not contain enough info 
# import torch
# model_output_path = sample_dir / 'model_output.pth'    
# torch.save({
#             'model_state_dict': predictor.state_dict(),
#             }, model_output_path)

### **Danger zone** Predict the whole video as a test

In [None]:
# # Test 
# for i in range(0,500,chunk_size):
    
#     prediction_worker = thread_predict(frame_idx=i, max_imgs=chunk_size)  
#     prediction_worker.setAutoDelete(True)
#     #worker.returned.connect(viewer.add_image)  # connect callback functions
#     prediction_worker.start()  
#     print(f'Highest cached index {int(np.nanmax(predictor.images.cached_indices))}')
#     time.sleep(25)

### Plot some results 

In [None]:
# # Plotting
# import seaborn as sns
# sns.set_theme(style='white')
# %config InlineBackend.figure_format = 'retina'
# from matplotlib import pyplot as plt
# import matplotlib.gridspec as gridspec
# import matplotlib as mpl

# plt.style.use('dark_background')
# mpl.rcParams.update({"axes.grid" : True, "grid.color": "grey", "grid.alpha": .1})
# plt.rcParams['xtick.major.size'] = 10
# plt.rcParams['xtick.major.width'] = 1
# plt.rcParams['ytick.major.size'] = 10
# plt.rcParams['ytick.major.width'] = 1
# plt.rcParams['xtick.bottom'] = True
# plt.rcParams['ytick.left'] = True
# mpl.rcParams['savefig.pad_inches'] = .1

In [None]:
# # Plot the centroids over time
# centroids = list(predictor.inference_state['centroids'][0].values())
# centroids = np.stack(centroids)
# areas = np.array(list(predictor.inference_state['areas'][0].values())).astype(float)
# figure = plt.figure(figsize=(10,10))
# plt.imshow(viewer.layers[0].data[0], cmap='gray')
# #plt.plot(centroids[:,1], centroids[:,0], '-', color='k', alpha=.6)   
# plt.scatter(centroids[:,1], centroids[:,0], s=areas/50, marker='.', color='pink', alpha=.15, lw=0)   
# sns.despine(left=True,bottom=True)
# plt.title(f'Centroids over time (n={centroids.shape[0]} frames)')   

In [None]:
# plt.plot(list(predictor.inference_state['areas'][0].values()),'-', color='w', alpha=.6 )
# plt.title('Area over time')

In [None]:
# output_dict_per_obj is huge 
# Structure
# -> obj_id
# --> cond_frame_outputs
# --> non_cond_frame_outputs