# NAPARI visualization of UNet Training Data

You can use this notebook to view, modified and save out training data for UNet models

Labels:
+ 0 - background 
+ 1 - GFP/Phase 
+ 2 - RFP


Extra key bindings:
+ 'w' - calculate weightmap
+ 'q' - calculate custom annotated weightmap (weightmask * weightmap)
+ '/' - save label current displayed
+ 's' - save all labels
+ 'o' - output all weightmaps and metadata for tfrecord creation
+ '\>' - grow the label under the mouse cursor
+ '\<' - shrink the label under the mouse cursor
+ 'h' - fill holes in cell mask under the mouse cursor
+ 'n' - count cells (will be updated to collect more stats)

---

```
Authors:
- Alan R. Lowe (a.lowe@ucl.ac.uk)
```

---
## Set up the data path, channel(s) used, weight amplitude and the number of images to load

In [1]:
DATA_PATH = '/home/nathan/analysis/training/training_data'
#DATA_PATH = '/home/nathan/analysis/fucci/working_dir/all/full_stack'

WEIGHT_AMPLITUDE = 50.
ACQUISITION_CHANNELS = ['PHASE'] #complete sets
INCOMPLETE_CHANNELS = ['GFP'] #, 'RFP', 'IRFP'] #list here if sets incomplete ie some images missing
NUMBER_OF_IMAGES_TO_LOAD = 1 #'ALL' #in each set, from index 0, loads in numerical order, if all images desired then enter 'ALL'
INDEX_OF_IMAGES = [1053, 1093] #enter [min, max] and ensure NUMBER_OF_IMAGES_TO_LOAD = 0
#INDEX_OF_IMAGES = [minimum:maximum]
#ACQUISITION_CHANNELS = ['GFP', 'RFP']

---

In [2]:
import os
import re
import enum
import json
import csv
import napari
import pandas as pd
from skimage import io
from skimage.util import random_noise
import numpy as np
from scipy import ndimage
from itertools import islice

from scipy.ndimage.morphology import distance_transform_edt
from scipy.ndimage import gaussian_filter

import matplotlib.pyplot as plt

In [3]:
@enum.unique
class Channels(enum.Enum):
    BRIGHTFIELD = 0 
    GFP = 1
    RFP = 2
    IRFP = 3
    PHASE = 4
    WEIGHTS = 98
    MASK = 99

In [4]:
global filename
#global files



def strip_modified_filename(filename):
    if filename.endswith('.modified.tif'):
        stripped_fn = filename[:-len('.modified.tif')]
        return stripped_fn
    return filename

def make_folder(foldername):
    if os.path.exists(foldername):
        return
    os.mkdir(foldername)
    
def file_root(filename):  #searches for the filename pattern inside the given filename, returns the result of that search as grps
    FILENAME_PATTERN = r'([a-zA-Z0-9]+)_([a-zA-Z0-9]+)_*.tif'
    grps = re.search(FILENAME_PATTERN, filename)
    return grps

def load_training_data(pth, channels=[Channels.GFP, Channels.RFP]):
    """ load training data for visualisation with napari"""
          
    # find the sets and sort them
    global sets #allows sets to be called from another function
    sets = [f for f in os.listdir(pth) if os.path.isdir(os.path.join(pth, f))]
    sets.sort(key = lambda s: int(s[3:]))
    
    def set_filename_format(filename): #both functions below serve to find filenames regardless of x_gfp.tif or gfp_x.tif
        grps = file_root(filename)
        
        if grps.group(1) in [c.name.lower() for c in all_channels]: #if a feature about the fn pattern is in the channel name (ie gfp) then FNFMT = 2, allowing for x_gfp.tif or gfp_x.tif to be sorted 
            FILENAME_FORMAT = 2 #gfp_x.tif
        else:
            FILENAME_FORMAT = 1 #x_gfp.tif
            
        def filename_formatter(filename, channel):
            # assert(channel in [c.name.lower() for c in all_channels])
            grps = file_root(filename)

            return f'{grps.group(FILENAME_FORMAT)}_{channel}.tif'
            # return '{}_{}.tif'.format(*[channel, grps.group(FILENAME_FORMAT)])
        
        return filename_formatter
    
    global label_files
    global fnfmt
    global files
    
    all_channels = [Channels.MASK, Channels.WEIGHTS]+channels
    files = {k:{'files':[], 'data':[], 'sets':[], 'path':[]} for k in all_channels}
    all_channels.remove(Channels.WEIGHTS)
       
    for s in sets:

        # root_folders
        l_root = os.path.join(pth, s, 'labels')
        
        # check that this folder exists 
        if not os.path.exists(l_root):
            raise IOError(f'{l_root} does not exist. Do you need to rename label -> labels?')

        # get the training label files
        label_files = [f for f in os.listdir(l_root) if f.endswith('.tif')]
        #print(label_files)
        
        # sort to remove unmodified files and replace with the modified files
        unmodified_files, modified_files = [], []
        
        #for index, item in enumerate(islice(items, limit)):
        for i, f in enumerate(label_files): #introducing islice here allows you to choose size of dataset! but only in a random order
            if f.endswith('.modified.tif'):
                modified_files.append(strip_modified_filename(f))
            else:
                unmodified_files.append(f)
                
        unmodified_files = list(set(unmodified_files).difference(set(modified_files))) #this is only unmod where mod doesnt exist
        label_files_full = unmodified_files + [f+'.modified.tif' for f in modified_files]    
        label_files_full.sort(key = lambda f: int(f[0:4])) #sorts files numerically

        N = NUMBER_OF_IMAGES_TO_LOAD
        n = INDEX_OF_IMAGES[0]
        m = INDEX_OF_IMAGES[1]
                
        if N == 'ALL':
            label_files = label_files_full
        elif N > 0:
            label_files = label_files_full[:N] #cuts list of files to N images
        else:
            label_files = label_files_full[n:m] #cuts list of files to n->m images       
        ## MAKE THIS ELIF TIDIER SO ONLY ONE INPUT IS REQUIRED ##        
        
        fnfmt = set_filename_format(label_files[0]) #why only label_files[0]? -> bc only need one example to set the fnfmt
                        
        files[Channels.MASK]['path'] += [s+'/labels/'+f for f in label_files]
        files[Channels.MASK]['files'] += [strip_modified_filename(f) for f in label_files]
        files[Channels.MASK]['data'] += [io.imread(os.path.join(l_root, f)) for f in label_files]
        files[Channels.MASK]['sets'] += [s] * len(label_files)
        
        for channel in channels:
            
           #if cfiles channel.name.lower dir doesnt contain gfp or rfp then add option to create noise data,
            channel_dir = os.path.join(pth, s, channel.name.lower())            
            #if os.path.exists(channel_dir): #if fluor/phase channels exist, continue as norm
            cfiles = [fnfmt(l, channel.name.lower()) for l in label_files]
            files[channel]['path'] += [s+'/'+channel.name.lower()+'/'+f for f in cfiles]
            files[channel]['files'] += cfiles #filenames
            files[channel]['data'] += [io.imread(os.path.join(pth, s, channel.name.lower(), f)) for f in cfiles] #actualdata
            files[channel]['sets'] += [s] * len(label_files)
            
            """else: #if fluor ch missing then create one and fill it with noise - should add this to separate func later
                os.mkdir(channel_dir) #create fl ch 
                cfiles = [fnfmt(l, channel.name.lower()) for l in label_files]
                files[channel]['path'] += [s+'/'+channel.name.lower()+'/'+f for f in cfiles]
                files[channel]['files'] += cfiles #filenames 
                files[channel]['sets'] += [s] * len(label_files)
                for f in cfiles:
                    print(f'Adding gaussian noise fluorescence file: {f}')
                    shape = [files[channels[0]]['data'][0].shape]
                    random = np.random.normal(loc=128, scale=128, size=(shape[0])) #(1024,1024)) what is the stdev of an 8bit image?
                    ### need to check if this is really the best way of generating a random image
                    random_path = os.path.join(channel_dir, f)
                    io.imsave(random_path, random.astype('uint8'))
                files[channel]['data'] += [io.imread(os.path.join(pth, s, channel.name.lower(), f)) for f in cfiles]            
                    """
        # now look for weights 
        w_root = os.path.join(pth, s, 'weights') ### change this back to 'weights' and have func to rename previous weight folder?
        #if os.path.exists(w_root):
        wfiles = [fnfmt(l, 'weights') for l in label_files]
        for weight_file in wfiles:
            files[Channels.WEIGHTS]['path'] += [f'{s}/weights/{weight_file}']
            files[Channels.WEIGHTS]['files'] += [weight_file]
            files[Channels.WEIGHTS]['sets'] += [s]
            if os.path.exists(os.path.join(w_root, weight_file)):
                files[Channels.WEIGHTS]['data'] += [io.imread(os.path.join(w_root, weight_file)).astype(np.float32)]
            else:
                print(f'Adding empty weight file: {weight_file}')
                mask_shape = files[channels[0]]['data'][0].shape
                files[Channels.WEIGHTS]['data'] += [np.zeros(mask_shape, dtype=np.float32)]
                
    """check for fluorescent channels, if not there, add fluorescent folders, then add gaussian noise"""            

    w_root = os.path.join(pth, s, 'weights')
                                                         
    # now make image stacks 
    for channel in files.keys():
        for i, im in enumerate(files[channel]['data']):
            print(channel, files[channel]['path'][i], im.shape, im.dtype)
        
        files[channel]['data'] = np.stack(files[channel]['data'], axis=0)
        
    global shape
    shape = files[channels[0]]['data'][0].shape

    return files



In [5]:
channels = [Channels[c.upper()] for c in ACQUISITION_CHANNELS]
data = load_training_data(DATA_PATH, channels)

Channels.MASK set1/labels/0100_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set2/labels/0031_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set3/labels/0000_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set4/labels/0000_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set5/labels/0000_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set6/labels/0009_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set7/labels/0229_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set8/labels/0000_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set9/labels/0020_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set10/labels/0000_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set11/labels/0625_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set12/labels/0033_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set13/labels/0039_mask.tif.modified.tif (1024, 1024) uint8
Channels.MASK set14/labels/0010_mask.tif.modified.tif (1024,



In [6]:
def normalize_images(stack):
    normed = stack.astype(np.float32)
    
    for i in range(stack.shape[0]):
        # normed[i,...] = (normed[i,...]-np.mean(normed[i,...])) / np.std(normed[i,...])
        c = normed[i,...]
        p_lo = np.percentile(c,5)
        p_hi = np.percentile(c,99)
        normed[i,...] = np.clip((c - p_lo) / p_hi, 0., 1.)
    return normed

In [7]:
def bounding_boxes(seg):
    lbl, nlbl = ndimage.label(seg)
    class_label, _, minxy, maxxy = ndimage.extrema(seg, lbl, index=np.arange(1, nlbl+1))
    return class_label, minxy, maxxy

In [8]:
seg = np.zeros(data[channels[0]]['data'].shape, dtype=np.uint8)
mask = data[Channels.MASK]['data']
if mask.ndim == 3:
    seg = mask > 0
elif mask.ndim == 4:
    seg[mask[:,0,...]>0] = 1
    seg[mask[:,1,...]>0] = 2

In [9]:
def convert_to_mask(labels, unique_labels=range(1,len(channels)+1)):
    print(unique_labels)
    seg = np.zeros((len(unique_labels),)+labels.shape, dtype=np.uint8)
    for i,l in enumerate(unique_labels):
        seg[i,...] = labels==l
    return np.squeeze(seg)

In [10]:
def save_labels(viewer):
    # get the current image 
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    source_set = data[Channels.MASK]['sets'][current_slice]
    source_file = data[Channels.MASK]['files'][current_slice]
    source_fn = os.path.join(source_set, 'labels', source_file)

    # get the current layer
    current_labels = viewer.layers['labels'].data[current_slice,...]
    current_mask = convert_to_mask(current_labels)

    # write out the modified segmentation mask
    new_file = os.path.join(DATA_PATH, source_fn+'.modified.tif')
    print(new_file)
    io.imsave(new_file, current_mask[0].astype('uint8'))

    print(current_slice, current_labels.shape, new_file)

In [11]:
def save_all_labels(viewer):
        # get the current image 
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    #current_slice = viewer.layers[viewer.active_layer].coordinates[1] # prints the x coord
    
    source_set = data[Channels.MASK]['sets'][current_slice]
    source_file = data[Channels.MASK]['files'][current_slice]
    source_fn = os.path.join(source_set, 'labels', source_file)

    # get the current layer
    current_labels = viewer.layers['labels'].data[current_slice,...]
    current_mask = convert_to_mask(current_labels)
    
    # for over all stack
    #for i in range(len(data[Channels.MASK]['files'])):
    for i in range(viewer.layers['labels'].data.shape[0]):
        current_slice = i
        
        source_set = data[Channels.MASK]['sets'][current_slice]
        source_file = data[Channels.MASK]['files'][current_slice]
        source_fn = os.path.join(source_set, 'labels', source_file)
        
        current_labels = viewer.layers['labels'].data[current_slice,...]
        current_mask = convert_to_mask(current_labels)
        
        new_file = os.path.join(DATA_PATH, source_fn+'.modified.tif')
        print(new_file)
        io.imsave(new_file, current_mask[0].astype('uint8'))

        print(current_slice, current_labels.shape, new_file)


In [12]:
# weightmaps = np.zeros((seg.shape), dtype=np.float32)

def calculate_weightmaps(viewer, w0=WEIGHT_AMPLITUDE, current_slice=0):
    # get the current layer and make it binary
    mask = viewer.layers['labels'].data[current_slice,...].astype(np.bool)
    
    # label the image 
    labelled, n_labels = ndimage.label(mask)
    #print(n_labels)
    weight_mask = np.zeros(mask.shape, dtype=np.float32)
    for i in range(1,n_labels+1):
        cell = labelled == i
        not_cell = np.logical_xor(cell, mask) #np.logical_and(labelled != i, labelled > 0)
        mask_diff = gaussian_filter(cell.astype(np.float32), sigma=5) * gaussian_filter(not_cell.astype(np.float32), sigma=5)
        weight_mask += mask_diff

    wmap = w0*weight_mask #* wmap
    
    # normalize it
    wmap += 1.   
    wmap[mask] = 1.
      
    viewer.layers['weightmaps'].data[current_slice,...] = wmap.astype(np.float32)
    viewer.layers['weightmaps'].contrast_limits = (np.min(wmap), np.max(wmap))
    viewer.layers['weightmaps'].visible = True
    
    return wmap

    

In [13]:
# secondary weighting = weightmap * weightmask
# purpose is to create a secondary weightmap that can be used to tell the network to focus/ignore on user-selected ROIs

def calculate_custom_weightmap(viewer, current_slice=0):
    
    w_mask = viewer.layers['weightmask'].data[current_slice,...].astype(np.bool)
    w_map =  viewer.layers['weightmaps'].data[current_slice,...]   
    cust_wmap = w_map*w_mask
        
    #viewer.add_image(cust_wmap, name='custom weightmap', colormap='plasma', visible=True) # by having this here we generate a new custwm image each time one is generated instead of adding to stack
    
    #cust_wmap += 1.   
    #cust_wmap[w_mask] = 1. #is this necessary? not at the moment
    
    viewer.layers['custom weightmap'].data[current_slice,...] = cust_wmap.astype(np.float32) # need this line to 
    viewer.layers['custom weightmap'].contrast_limits = (np.min(cust_wmap), np.max(cust_wmap))
    viewer.layers['custom weightmap'].visible = True
    
    return cust_wmap

In [14]:
def grow_shrink_label(viewer, grow=True, n_iter=1): #by editing the number of iterations i can edit the size of growshrink
    # get the current image 
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    current_labels = viewer.layers['labels'].data[current_slice,...] #label as in image
            
    cursor_coords = [int(p) for p in viewer.layers[viewer.active_layer].position]
    labelled, _ = ndimage.label(current_labels.astype(np.bool)) #each obj in label image is labelled
    real_label = current_labels[cursor_coords[0], cursor_coords[1]] #xy cursor coords in current label image
    
    if real_label < 1: return
    
    mask = labelled == labelled[cursor_coords[0], cursor_coords[1]] #assigning mask as labelled image with only xy of specific cell
    if grow:
        mask = ndimage.morphology.binary_dilation(mask, iterations=n_iter)   
    else:
        current_labels[mask] = 0
        mask = ndimage.morphology.binary_erosion(mask, iterations=n_iter)
    current_labels[mask] = real_label
    viewer.layers['labels'].data[current_slice,...] = current_labels
    viewer.layers['labels']._set_view_slice()

    

In [15]:
def count_cells(viewer):
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    mask = viewer.layers['labels'].data[current_slice,...].astype(np.bool) 
    labelled, n_labels = ndimage.label(mask)
       
    #print('number of labels/cells in displayed mask is:', n_labels)
    stat_file = os.path.join(DATA_PATH, 'stats.csv')
    
    i=0
    df=pd.DataFrame(data=[0], index=[i], columns=['number of cells'])
    
    for i in range(viewer.layers['labels'].data.shape[0]):
        current_slice = i
        mask = viewer.layers['labels'].data[current_slice,...].astype(np.bool) 
        labelled, n_labels = ndimage.label(mask)
            
        df.loc[i] = n_labels
        df.to_csv(stat_file, index=False)
    print(df)      

In [16]:
def fill_holes(viewer):
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    current_labels = viewer.layers['labels'].data[current_slice,...] #label as in image
            
    cursor_coords = [int(p) for p in viewer.layers[viewer.active_layer].position]
    labelled, _ = ndimage.label(current_labels.astype(np.bool)) #each obj in label image is labelled
    real_label = current_labels[cursor_coords[0], cursor_coords[1]] #xy cursor coords in current label image
    
    if real_label < 1: return
    
    mask = labelled == labelled[cursor_coords[0], cursor_coords[1]] #assigning mask as labelled image with only xy of specific cell
    mask = ndimage.morphology.binary_fill_holes(mask) 
    current_labels[mask] = real_label
    viewer.layers['labels'].data[current_slice,...] = current_labels
    viewer.layers['labels']._set_view_slice()

In [17]:
def single_cell_mask(viewer):  #by editing the number of iterations i can edit the size of growshrink
    # get the current image 
    current_slice = viewer.layers[viewer.active_layer].coordinates[0]
    current_labels = viewer.layers['labels'].data[current_slice,...] #label as in image
    
        
    cursor_coords = [int(p) for p in viewer.layers[viewer.active_layer].position]
    labelled, _ = ndimage.label(current_labels.astype(np.bool)) #each obj in label image is labelled
    real_label = current_labels[cursor_coords[0], cursor_coords[1]] #xy cursor coords in current label image
    
    if real_label < 1: return
    
    mask = labelled == labelled[cursor_coords[0], cursor_coords[1]] #assigning mask as labelled image with only xy of specific cell
    #print(mask) # we want this image mask here saved as current labels
    
    current_labels = mask    
    #current_labels[mask] = real_label
    viewer.layers['labels'].data[current_slice,...] = current_labels
    viewer.layers['labels']._set_view_slice()

In [18]:
def fluorescent_noise(viewer):
    pth = DATA_PATH
    for s in sets:
        #add input for gfp and rfp
        channels = [Channels[c.upper()] for c in INCOMPLETE_CHANNELS]
        incomplete_files = {k:{'files':[], 'data':[], 'sets':[], 'path':[]} for k in channels} #need to add something like this to include Channels.GFP etc in files[channels]        
        files.update(incomplete_files) #merging two dicts to create complete file dict
        for channel in channels:
            
               #if cfiles channel.name.lower dir doesnt contain gfp or rfp then add option to create noise data,
                channel_dir = os.path.join(pth, s, channel.name.lower())            
                if os.path.exists(channel_dir): #if fluor/phase channels exist, continue as norm
                    cfiles = [fnfmt(l, channel.name.lower()) for l in label_files]                  
                    files[channel]['path'] += [s+'/'+channel.name.lower()+'/'+f for f in cfiles]
                    files[channel]['files'] += cfiles #filenames
                    files[channel]['data'] += [io.imread(os.path.join(pth, s, channel.name.lower(), f)) for f in cfiles] #actualdata
                    files[channel]['sets'] += [s] * len(label_files)

                else: #if fluor ch missing then create one and fill it with noise - should add this to separate func later
                    os.mkdir(channel_dir) #create fl ch 
                    cfiles = [fnfmt(l, channel.name.lower()) for l in label_files]
                    files[channel]['path'] += [s+'/'+channel.name.lower()+'/'+f for f in cfiles]
                    files[channel]['files'] += cfiles #filenames 
                    files[channel]['sets'] += [s] * len(label_files)
                    for f in cfiles:
                        print(f'Adding gaussian noise fluorescence file: {f}')             
                        #shape = [files[channels[0]]['data'][0].shape]
                        random = np.random.normal(loc=128, scale=128, size=shape) #(1024,1024)) what is the stdev of an 8bit image?
                        ### need to check if this is really the best way of generating a random image
                        random_path = os.path.join(channel_dir, f)
                        io.imsave(random_path, random.astype('uint8'))
                    files[channel]['data'] += [io.imread(os.path.join(pth, s, channel.name.lower(), f)) for f in cfiles]

    if Channels.GFP in data:
        #gfp = normalize_images(data[Channels.GFP]['data'])
        gfp = data[Channels.GFP]['data']
        print(gfp)
        viewer.add_image(gfp, name='GFP', colormap='green', contrast_limits=(0.,1.))
        viewer.layers['RFP'].blending = 'additive'
        
    if Channels.RFP in data:
        #rfp = normalize_images(data[Channels.RFP]['data'])
        viewer.add_image(rfp, name='RFP', colormap='magenta', contrast_limits=(0.,1.))
        viewer.layers['RFP'].blending = 'additive'
        
    if Channels.IRFP in data:
        #irfp = normalize_images(data[Channels.RFP]['data'])
        viewer.add_image(irfp, name='iRFP', colormap='inferno', contrast_limits=(0.,1.)) #maybe change colormap
        viewer.layers['IRFP'].blending = 'additive'

In [None]:
# start napari
with napari.gui_qt():
    viewer = napari.Viewer()
    
    if Channels.GFP in data:
        gfp = normalize_images(data[Channels.GFP]['data'])
        viewer.add_image(gfp, name='GFP', colormap='green', contrast_limits=(0.,1.))
        
    if Channels.RFP in data:
        rfp = normalize_images(data[Channels.RFP]['data'])
        viewer.add_image(rfp, name='RFP', colormap='magenta', contrast_limits=(0.,1.))
        viewer.layers['RFP'].blending = 'additive'
        
    if Channels.PHASE in data:
        phase = normalize_images(data[Channels.PHASE]['data'])
        viewer.add_image(phase, name='Phase', colormap='gray')
    
    if Channels.WEIGHTS in data:
        weightmaps = data[Channels.WEIGHTS]['data']
        viewer.add_image(weightmaps, name='weightmaps', colormap='plasma', visible=False)
        
    if Channels.WEIGHTS in data:
        cust_wmap = data[Channels.WEIGHTS]['data']
        viewer.add_image(cust_wmap, name='custom weightmap', colormap='plasma', visible=False)
    
    viewer.add_labels(seg, name='labels')
    viewer.layers['labels'].opacity = 0.4
    viewer.layers['weightmaps'].blending = 'additive'
    
    weight_mask = np.ones(data[Channels.WEIGHTS]['data'].shape, dtype=np.uint8)
    #cust_wmap = np.ones(data[Channels.WEIGHTS]['data'].shape, dtype=np.uint8)
    viewer.add_labels(weight_mask, name='weightmask', visible=False)
    #viewer.add_image(cust_wmap, name='custom weightmap', colormap='plasma', visible=False)

    @viewer.bind_key('g')
    def k_fluorescent_noise(viewer):
        print("Creating simulated fluorescence channels with gaussian noise")
        fluorescent_noise(viewer) 
            
    @viewer.bind_key('/') #double command
    def k_save_labels(viewer):
        save_labels(viewer)  #need to change the (viewer) bit to save all images   
        
    @viewer.bind_key('s')
    def k_save_all_labels(viewer):
        save_all_labels(viewer)
        
    @viewer.bind_key('n')
    def k_count_cells(viewer):
        count_cells(viewer)
        
    @viewer.bind_key('w')
    def k_calculate_weightmaps(viewer):
        current_slice = viewer.layers[viewer.active_layer].coordinates[0]
        calculate_weightmaps(viewer, current_slice=current_slice)
        ## calculate cust_ here?
        
    @viewer.bind_key('q')
    def k_calculate_custom_weightmap(viewer):
        current_slice = viewer.layers[viewer.active_layer].coordinates[0]
        calculate_custom_weightmap(viewer, current_slice=current_slice)
        
    @viewer.bind_key('<')
    def k_shrink_label(viewer):
        print('shrink label')
        grow_shrink_label(viewer, grow=False)
    
    @viewer.bind_key('>') 
    def k_grow_label(viewer):
        print('grow label')
        grow_shrink_label(viewer, grow=True)
        
    @viewer.bind_key('c') #clear
    def k_single_cell(viewer):
        print('select cell')
        single_cell_mask(viewer)
        
    @viewer.bind_key('h') #fill holes in cell masks
    def k_fill_holes(viewer):
        print('filling holes in cell mask')
        fill_holes(viewer) 
    
    @viewer.bind_key('o')
    def k_output(viewer):
        print('Output all with metadata')
                
        for i in range(viewer.layers['weightmaps'].data.shape[0]):
            
            if np.sum(viewer.layers['weightmaps'].data[i,...]) == 0: # get rid of this condition to calculate weightmaps afresh regardless?
                print(f'Weightmap {i} is empty. Calculating...')
                wmap = calculate_weightmaps(viewer, current_slice=i)
                weight_folder = os.path.join(DATA_PATH, data[Channels.WEIGHTS]['sets'][i], 'weights')
                weight_fn = data[Channels.WEIGHTS]['files'][i]
                print(weight_folder, weight_fn)
                make_folder(weight_folder)
                io.imsave(os.path.join(weight_folder, weight_fn), wmap.astype(np.float32))
          
                #write out cust_weight_map too
                
                cust_wmap = calculate_custom_weightmap(viewer, current_slice=i)
                cust_weight_folder = os.path.join(DATA_PATH, data[Channels.WEIGHTS]['sets'][i], 'custom_weights')       
                #os.path.join((data[Channels.WEIGHTS]['files'][i][:+]),'.custom.tif')
                cust_weight_fn = data[Channels.WEIGHTS]['files'][i] ##add custom to this fn!!!!!!!!!!!! or do i want to keep fn as before for forward compatability?
                print(cust_weight_folder, cust_weight_fn)
                make_folder(cust_weight_folder)
                io.imsave(os.path.join(cust_weight_folder, cust_weight_fn), cust_wmap.astype(np.float32))
                
                ## adding save out of the weight masks
                w_mask = viewer.layers['weightmask'].data[i,...].astype(np.bool)                
                weight_mask_folder = os.path.join(DATA_PATH, data[Channels.WEIGHTS]['sets'][i], 'weight_masks')
                weight_mask_fn = data[Channels.MASK]['files'][i] ##add weight to fn
                print(weight_mask_folder, weight_mask_fn)
                make_folder(weight_mask_folder)
                io.imsave(os.path.join(weight_mask_folder, weight_mask_fn), w_mask.astype(np.float32))
                
                
        # write out a JSON file with the data
        jfn = os.path.join(DATA_PATH, 'training_metadata.json')
        jdata = {}
        for channel in data.keys():
            jdata[channel.name.lower()] = data[channel]['path']
            
        with open(jfn, 'w') as json_file:
            json.dump(jdata, json_file, indent=2, separators=(',', ': '))
        
        

Creating simulated fluorescence channels with gaussian noise
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
Adding gaussian noise fluorescence file: 0064_gfp.tif
[array([[193, 170,  54, ..., 125,  88,  45],
       [235, 117, 203, ...,   

AttributeError: 'list' object has no attribute 'shape'

In [None]:
# # convert segmentation output labels to multichannel stacks

# p = '/Users/arl/Dropbox/Data/TrainingData/set12'
# files = [f for f in os.listdir(os.path.join(p,'labels')) if f.endswith('.tif')]
# for f in files:
#     mask = io.imread(os.path.join(p, 'labels', f))
#     print(mask.shape)
#     gfp = mask==1
#     rfp = mask==2
#     new_mask = np.stack([gfp, rfp], axis=0)
#     io.imsave(os.path.join(p,f), new_mask.astype('uint8'))