In [1]:
import os
import json
import napari

import numpy as np

from enum import Enum
from magicgui import magicgui
from skimage.io import imread
from pathlib import Path

from napari.layers import Image

from datetime import datetime

from zipfile import ZipFile
from skimage.external.tifffile import TiffWriter
import io

In [2]:
# Path to the image stack to annotate. Set the ANNOTATION_IMAGE environment
# variable to point at another file instead of editing this machine-specific default.
filename = os.environ.get("ANNOTATION_IMAGE", "/Users/arl/Desktop/Pos12_aligned.tif")

In [3]:
# Read the image stack from disk and remember where it came from, so the
# source file can be attached to the image layer and exported with the annotations.
data = imread(filename)
metadata = dict(filename=filename)

In [4]:
# Cell states that can be assigned to an annotation point. Values are the
# contiguous integers 0..4 in declaration order (start=0); the widget's
# next/previous key bindings rely on that when they wrap modulo len(CellState).
CellState = Enum(
    "CellState",
    ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"],
    start=0,
)

In [5]:
# Face colours cycled over the annotation 'State' property: the first five
# matplotlib "tab10" colours, five entries to match the five CellState members.
COLOR_CYCLE = "#1f77b4 #ff7f0e #2ca02c #d62728 #9467bd".split()

In [6]:
def get_image_patch(layers, coords, shape=64):
    """Get an image patch from the image layer data.

    @Kristina
    TODO: We need to extract the image patches from the layer data.
    Layers is a list of layers, each layer has a .data property containing the image data.

    Until the extraction is implemented, this returns a random placeholder patch
    of the requested size.

    Parameters
    ----------
    layers : list
        Image layers to extract the patch from (not used by the placeholder yet).
    coords : array-like
        Coordinates of the annotation point (not used by the placeholder yet).
    shape : int, optional
        Side length of the square patch, in pixels. Defaults to 64.

    Returns
    -------
    numpy.ndarray
        A ``(shape, shape)`` uint8 array.
    """
    # high is exclusive, so 256 makes the full uint8 range 0..255 reachable
    return np.random.randint(0, 256, size=(shape, shape), dtype=np.uint8)

In [7]:
def annotator(viewer):
    """Attach a point-based cell-state annotation tool to a napari viewer.

    Adds an empty "Annotation" points layer whose points are coloured by their
    ``State`` property (a ``CellState`` name). It also docks a magicgui widget
    that chooses the state for new points and exports the annotations. Finally,
    it binds ``.`` / ``,`` to step forwards / backwards through the states.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer that receives the points layer, the dock widget and the key bindings.
    """

    SESSION_TIME = datetime.now().strftime("%m-%d-%Y--%H-%M-%S")
    SESSION_NAME = f"annotation_{SESSION_TIME}"

    # add an empty points layer with the viewer's dimensionality, so the tool
    # does not depend on a notebook-global image array
    points_layer = viewer.add_points(
        name="Annotation",
        properties={'State': [s.name for s in CellState]},
        ndim=viewer.dims.ndim,
    )

    points_layer.mode = 'add'
    points_layer.face_color = 'State'
    points_layer.face_color_cycle = COLOR_CYCLE
    points_layer.face_color_mode = 'cycle'

    @magicgui(
        call_button="Export",
        layout="horizontal",
        filename={"label": "Export path:"},  # custom label
    )
    def cnn_annotation_widget(
        filename=Path.home(),  # path objects are provided a file picker
        shape=64,
        use_visible_layers=True,
        state=CellState.Interphase,
    ):
        """Export the annotations and image patches to a zip archive.

        The archive ``annotation_<session time>.zip`` is written into
        ``filename`` when it is a directory, otherwise into its parent
        directory. It holds one TIFF patch per point, stored under a folder
        named after the point's state. It also holds a JSON log with the patch
        shape, the metadata of the exported image layers, the point
        coordinates (one list per dimension) and the state labels.

        Returns the path of the written archive.
        """
        export_dir = Path(filename)
        if not export_dir.is_dir():
            export_dir = export_dir.parent
        zip_path = export_dir / f"{SESSION_NAME}.zip"

        export_data = {'shape': shape}

        # image layers to export: only the visible ones when requested,
        # otherwise every image layer
        image_layers = [
            layer for layer in viewer.layers
            if isinstance(layer, Image) and (layer.visible or not use_visible_layers)
        ]
        for layer in image_layers:
            export_data[layer.name] = layer.metadata

        # record the coordinates of the annotations, one list per dimension
        for idx in range(points_layer.data.shape[1]):
            export_data[f'coords-{idx}'] = points_layer.data[:, idx].tolist()

        # record the state labels of the annotations
        labels = points_layer.properties['State']
        export_data['labels'] = labels.tolist()

        with ZipFile(zip_path, 'w') as myzip:
            for idx, patch_coords in enumerate(points_layer.data):
                patch_label = labels[idx]

                # grab the image patch from the selected layers
                image_patch = get_image_patch(image_layers, patch_coords, shape=shape)
                image_patch_fn = f"{patch_label}/{patch_label}_{SESSION_TIME}_{idx}.tif"

                # encode the patch as TIFF in memory; read the buffer only after
                # the writer has closed so any data it flushes on close is included
                stream = io.BytesIO()
                with TiffWriter(stream) as tif:
                    tif.save(image_patch)
                myzip.writestr(image_patch_fn, stream.getvalue())

            # write out the json log to the zip file also
            myzip.writestr(f"{SESSION_NAME}.json", json.dumps(export_data, indent=2))

        return zip_path

    def _change_points_properties(event=None):
        """Make newly added points take the state currently selected in the widget."""
        # NOTE(review): this mutates the dict returned by `current_properties`;
        # confirm the installed napari version returns its live dict, not a copy
        points_layer.current_properties['State'] = np.array([cnn_annotation_widget.state.value.name])

    cnn_annotation_widget.state.changed.connect(_change_points_properties)
    # start the layer in sync with the widget's initial state
    _change_points_properties()

    # add the magicgui dock widget
    viewer.window.add_dock_widget(cnn_annotation_widget)

    def _step_state(step):
        """Move the widget's selected state by ``step``, wrapping around the enum."""
        new_state = (cnn_annotation_widget.state.value.value + step) % len(CellState)
        cnn_annotation_widget.state.value = CellState(new_state)

    @viewer.bind_key('.')
    def next_label(event=None):
        """Increment the label in the gui"""
        _step_state(1)

    @viewer.bind_key(',')
    def previous_label(event=None):
        """Decrement the label in the gui"""
        _step_state(-1)

    

        

        





In [8]:
with napari.gui_qt():
    # open a viewer on the loaded stack, then attach the annotation tool to it
    viewer = napari.Viewer()
    gfp_layer_options = {'name': 'GFP', 'metadata': metadata}
    viewer.add_image(data, **gfp_layer_options)
    annotator(viewer)
    