In [None]:
#########################
# defs and imports
#########################

import logging
import os
import shutil

import numpy as np
from numpy.linalg import inv

from nis_util import *
from simple_detection import *

from skimage.transform import AffineTransform
from xmlrpc.client import ServerProxy


def copy_lock(src, dst, copyfun=shutil.copy2, lock_ending='lock'):
    """
    copy a file while an empty lock file marks the destination as "being written"

    An empty file '<destination file>.<lock_ending>' is created before `copyfun`
    is called and removed after it returns (presumably so that consumers of the
    destination can ignore files that are still being copied).

    Parameters
    ----------
    src: str
        file to copy
    dst: str
        destination file or directory; for a directory, the destination file keeps
        the base name of `src` (as with shutil.copy / shutil.copy2)
    copyfun: callable, optional
        copy function called as copyfun(src, dst), defaults to shutil.copy2
    lock_ending: str, optional
        extension appended to the destination file name to form the lock file name

    Notes
    -----
    If `copyfun` raises, the lock file is left in place (the copy is incomplete).
    """
    # use the same file name rule as shutil.copy2 (os.path.basename handles every
    # separator the platform accepts), so the lock always sits next to the real target
    target = os.path.join(dst, os.path.basename(src)) if os.path.isdir(dst) else dst
    lock_file = '.'.join([target, lock_ending])

    # create the (empty) lock file
    with open(lock_file, 'w'):
        pass

    copyfun(src, dst)
    os.remove(lock_file)


def _pix2unit3(x, offset, fov, pixel_size, cam_rotation=None, im_flip=None):
    """
    map a single point from image pixel coordinates to NIS stage coordinates

    The point is scaled from pixels to units and shifted so that the image
    centre becomes the origin. It is then flipped per axis (`im_flip`),
    transformed by the inverse of the camera rotation and finally shifted
    by `offset`.

    Parameters
    ----------
    x: array-like of length 2
        point to transform, in pixels
    offset: array-like
        extra offset added to the transformed point (in units)
        (center of image or center of first tile in large image)
    fov: array-like
        field-of-view size (in units)
    pixel_size: scalar
        pixel size in units
    cam_rotation: 2x2 mat, optional
        camera rotation matrix as provided by NIS; identity if None
    im_flip: array-like, optional
        per-axis scale of 1 or -1 (flip a dimension or not); defaults to [1, -1],
        i.e. flip the second coordinate

    Returns
    -------
    x_tr: array-like
        transformed point, in units
    """
    log = logging.getLogger(__name__)

    rot = np.array([[1, 0], [0, 1]], dtype=float) if cam_rotation is None else cam_rotation

    # embed the 2x2 rotation into a homogeneous 3x3 matrix, then invert it
    rot_h = np.array([[rot[0, 0], rot[0, 1], 0],
                      [rot[1, 0], rot[1, 1], 0],
                      [0, 0, 1]])
    forward = AffineTransform(rot_h)
    unrotate = AffineTransform(inv(forward.params))

    flip_scale = [1, -1] if im_flip is None else im_flip
    flip = AffineTransform(scale=flip_scale)

    # skimage: `a + b` is the transform that applies a first, then b
    pipeline = flip + unrotate

    point = np.array(x, dtype=float)
    centred = point * pixel_size - np.array(fov, dtype=float) / 2
    result = np.squeeze(np.array(offset, dtype=float) + pipeline(centred))

    log.debug('transformed point {} (pixels) to {} (units)'.format(point, result))
    return result
    
    
def bbox_pix2unit3(bbox, offset, fov, pixel_size, cam_rotation=None, im_flip=None):
    """
    transform a bounding box from pixel coordinates to NIS stage coordinates

    All four corners of the box are mapped with _pix2unit3 (taking into account
    offsets, fov, camera rotation and image flipping). The result is the
    axis-aligned box that encloses the transformed corners.

    Parameters
    ----------
    bbox: 4-tuple
        ymin, xmin, ymax, xmax (as output by skimage's regionprops, in pixels)
    offset: array-like
        extra offset to add to transformed bounding boxes (in units), (y, x) order
        (center of image or center of first tile in large image)
    fov: array-like
        field-of-view size (in units), (y, x) order
    pixel_size: scalar
        pixel size in units
    cam_rotation: 2x2 mat, optional
        camera rotation matrix as provided by NIS
    im_flip: array-like, optional
        array of 1,-1 indicating whether to flip coordinates in a dimension or not

    Returns
    -------
    bbox_tr: 4-tuple
        transformed bounding box (ymin, xmin, ymax, xmax - in units)
    """
    log = logging.getLogger(__name__)

    (ymin, xmin, ymax, xmax) = bbox

    # corners are handled as (x, y) points, so offset and fov are reversed once up front
    offset_xy = list(reversed(list(offset)))
    fov_xy = list(reversed(list(fov)))

    corners = np.array([[xmin, ymin],
                        [xmin, ymax],
                        [xmax, ymin],
                        [xmax, ymax]], dtype=float)
    corners_tr = np.array([_pix2unit3(corner, offset_xy, fov_xy, pixel_size, cam_rotation, im_flip)
                           for corner in corners])

    # bounds of the (possibly rotated) corner set
    min_ = corners_tr.min(axis=0)
    max_ = corners_tr.max(axis=0)

    log.debug('new min: {}, new max: {}'.format(min_, max_))

    # NB: go back from (x, y) to the original ymin, xmin, ymax, xmax order
    res = np.concatenate([min_[::-1], max_[::-1]]).astype(float)

    log.debug('bbox: {}, toUnit: {}'.format(bbox, res))
    return tuple(res)

def bbox_pix2unit2(bbox, start, pixsize, direction, fov, mat):
    """
    old, possibly wrong pix2unit
    TODO: remove if no longer necessary

    Legacy transform of a bounding box (ymin, xmin, ymax, xmax, in pixels) to
    stage units. The corners are multiplied by `mat` and negated, then the
    enclosing box is scaled by `pixsize` and `direction` and shifted by `start`,
    where `start` is corrected by half a field of view.

    Parameters
    ----------
    bbox: 4-tuple
        ymin, xmin, ymax, xmax in pixels
    start: array-like
        (y, x) start position in units
    pixsize: array-like or scalar
        pixel size in units
    direction: array-like
        (y, x) scan direction, 1 or -1
    fov: array-like
        (y, x) field of view in units
    mat: 2x2 ndarray
        camera rotation matrix

    Returns
    -------
    ndarray, shape (4,)
        ymin, xmin, ymax, xmax in units
    """
    log = logging.getLogger(__name__)

    # half a field of view in (x, y) order, rotated back and signed by the scan direction
    half_fov_xy = np.asarray(fov, dtype=float)[::-1] / 2
    dir_xy = np.asarray(direction, dtype=float)[::-1]
    extra_offset = inv(mat).dot(half_fov_xy) * dir_xy
    start_ = np.asarray(start, dtype=float) + extra_offset[::-1]

    log.debug('extra offset: {}, new start: {}'.format(extra_offset, start_))

    (ymin, xmin, ymax, xmax) = bbox
    corners = np.array([[xmin, ymin],
                        [xmin, ymax],
                        [xmax, ymin],
                        [xmax, ymax]], dtype=float)

    # TODO: WHY? (the sign flip is unexplained in the original implementation)
    flipped = -np.array([mat.dot(corner) for corner in corners])

    lo = flipped.min(axis=0)
    hi = flipped.max(axis=0)

    log.debug('new min: {}, new max: {}'.format(lo, hi))

    corners_yx = np.array([lo[::-1], hi[::-1]], dtype=float)
    scaled = corners_yx * np.array(pixsize, dtype=float) * np.array(direction, dtype=float) + start_
    out = scaled.reshape((4,))

    log.debug('bbox: {}, toUnit: {}'.format(bbox, out))
    return out


def _run_nd_acquisition(path_to_nis, wing_path, grid, z_range, z_step, z_drive, ocs=()):
    """
    run a multipoint nD acquisition of the tiles in `grid` and save it to `wing_path`

    Parameters
    ----------
    path_to_nis: str
        path to the NIS-Elements executable
    wing_path: str
        file the acquisition is saved to
    grid: iterable
        tile positions; the first two entries of each are used as the point's x, y
    z_range, z_step: number
        z-stack settings, passed to NDAcquisition.set_z as (z_range/2, z_range/2, z_step)
    z_drive: str
        z drive used for the z-stack
    ocs: iterable of str, optional
        optical configurations to add as channels (none by default)
    """
    # we scan around current z -> get that
    pos = get_position(path_to_nis)

    nda = NDAcquisition(wing_path)
    nda.set_z(int(z_range/2), int(z_range/2), int(z_step), z_drive)
    nda.add_points(map(lambda x: (x[0], x[1], pos[2] - pos[3]), grid))
    for oc in ocs:
        nda.add_c(oc)
    nda.prepare(path_to_nis)
    nda.run(path_to_nis)


def do_scan(field, oc_overview, ocs_detail, path_to_nis, save_base_path, prefix, server_path_local, server_path_remote, suffix='.nd2', do_plot=True,
           z_range = 10, z_step=2, z_drive='Ti2 ZDrive', dry_run_details=False, stitched=True, separate_dirs_on_server=True, re_use_ov=False,
           detection_server='http://eco-gpu:8000/'):
    """
    acquire an overview of a slide, detect wings in it and acquire a detail scan of every wing

    Parameters
    ----------
    field: 4-tuple
        (left, right, top, bottom) stage coordinates of the overview
        NB: not the area actually imaged, but rather [min + 1/2 fov, max - 1/2 fov]
    oc_overview: str
        NIS optical configuration for the overview
    ocs_detail: str or list of str
        optical configuration(s) for the detail scans; with more than one, a
        multi-channel nD acquisition is done, which has to be stitched manually
    path_to_nis: str
        path to the NIS-Elements executable
    save_base_path: str
        local directory the acquisitions are saved to
    prefix: str
        file name prefix (slide name)
    server_path_local: str
        locally mounted server directory the files are copied to
    server_path_remote: str
        server directory as passed to the detection server
    suffix: str, optional
        file extension of the acquisitions
    do_plot: bool, optional
        plot the overview together with the detected bounding boxes
    z_range, z_step: number, optional
        z-stack settings for nD acquisitions (around the current z)
    z_drive: str, optional
        z drive used for z-stacks
    dry_run_details: bool, optional
        detect wings, but do not acquire the detail scans
    stitched: bool, optional
        single-channel detail scans only: use NIS large image scan (stitched in NIS)
        instead of a multipoint nD acquisition
    separate_dirs_on_server: bool, optional
        copy the overview into an 'overviews' sub-directory of the server directories
    re_use_ov: bool, optional
        do not acquire the overview, re-use an existing file instead
    detection_server: str, optional
        URL of the XML-RPC server providing `detect_bbox`
    """
    logger = logging.getLogger(__name__)

    single_oc = np.isscalar(ocs_detail)

    if stitched and not single_oc:
        logger.warning('Doing multi-channel acquisition, cannot use NIS stitching. Please stitch manually.')

    # get field and directions
    # NB: this is not the actual field being scanned, but rather [min+1/2 fov - max-1/2fov]
    (left, right, top, bottom) = field
    direction = [1 if top<bottom else -1, 1 if left<right else -1]

    # set overview optical configuration
    set_optical_configuration(path_to_nis, oc_overview)

    # get resolution, binning and fov
    (xres, yres, siz, mag) = get_resolution(path_to_nis)

    _, capture_fmt = get_camera_format(path_to_nis)
    # binning factor = first number of the second token of the capture format (e.g. '2x2' -> 2.0)
    binning_factor = float(capture_fmt.split()[1].split('x')[0])

    fov_x = xres * siz / mag * binning_factor
    fov_y = yres * siz / mag * binning_factor

    logger.debug('overview resolution: {}, {}, {}, {}'.format(xres, yres, siz, mag))

    # do overview scan
    ov_name = prefix + '_overview' + suffix
    ov_path = os.path.join(save_base_path, ov_name)
    if not re_use_ov:
        do_large_image_scan(path_to_nis, ov_path, left, right, top, bottom)

    logger.info('finished overview, detecting wings...')

    # detect wings
    img = read_bf(ov_path)

    # copy to server mount
    # NB: only the sub-directory is conditional; the server path itself is always used
    ov_server_dir = os.path.join(server_path_local, 'overviews') if separate_dirs_on_server else server_path_local
    copy_lock(ov_path, ov_server_dir)

    if separate_dirs_on_server:
        remote_path = '/'.join([server_path_remote, 'overviews', ov_name])
    else:
        remote_path = '/'.join([server_path_remote, ov_name])

    # TODO: implement with CNN
    # we want 4x downsampling for detection
    ds = max(0, int(round(2 - np.log2(binning_factor))))

    if ds != 0:
        img = list(pyramid_gaussian(img, ds))[-1]

    flt = {
        'area': (20000, 80000)
    }

    with ServerProxy(detection_server) as proxy:
        bboxes = proxy.detect_bbox(remote_path, binning_factor, flt)

    if do_plot:
        plt.figure()
        plt.imshow(img)

    # boxes from the server are scaled by the binning factor to unbinned pixels
    bboxes_scaled = []
    for bbox in bboxes[0]:
        bbox_scaled = np.array(tuple(bbox)) * binning_factor
        logger.debug('bbox: {}, upsampled: {}'.format(bbox, bbox_scaled))
        bboxes_scaled.append(bbox_scaled)
        if do_plot:
            minr, minc, maxr, maxc = tuple(list(bbox))
            rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                      fill=False, edgecolor='red', linewidth=2)
            plt.gca().add_patch(rect)

    bboxes = bboxes_scaled

    # get camera rotation
    mat = np.array(get_rotation_matrix(path_to_nis)).reshape((2,2))

    # pixels to units
    bboxes = [bbox_pix2unit3(b, [top, left], [fov_y/binning_factor, fov_x/binning_factor], siz/mag, mat) for b in bboxes]

    # expand bounding boxes
    bboxes = [scale_bbox(bbox, expand_factor=.3) for bbox in bboxes]

    print('detected {} wings:'.format(len(bboxes)))

    for idx, bbox in enumerate(bboxes):

        print('scanning wing {}: {}'.format(idx, bbox))

        # order the bounds along the scan direction of the overview field:
        # (ymin, xmin) is the start side, (ymax, xmax) the end side
        (ymin, xmin, ymax, xmax) = bbox
        (ymin, xmin, ymax, xmax) = (ymin if direction[0] > 0 else ymax,
                                    xmin if direction[1] > 0 else xmax,
                                    ymin if direction[0] < 0 else ymax,
                                    xmin if direction[1] < 0 else xmax)

        wing_path = os.path.join(save_base_path, prefix + '_wing' + str(idx) + suffix)

        # set the (first) detail oc so we have correct magnification
        set_optical_configuration(path_to_nis, ocs_detail if single_oc else ocs_detail[0])

        # get resolution and fov (query NIS once and reuse the result)
        resolution = get_resolution(path_to_nis)
        (xres, yres, siz, mag) = resolution
        fov = get_fov_from_res(resolution)
        logger.debug('detail resolution: {}, {}, {}, {}'.format(xres, yres, siz, mag))
        logger.debug('fov: {}'.format(fov))

        if single_oc:
            # do not actually do the detail acquisition
            if dry_run_details:
                continue

            # do NIS's scan large image -> stitching is performed in NIS
            if stitched:
                # NOTE(review): x bounds are passed as (xmax, xmin) while the overview passes
                # (left, right); this file is also not copied to the server - confirm both are intended
                do_large_image_scan(path_to_nis, wing_path, xmax, xmin, ymin, ymax, 15, True)
                continue

        # manual grid acquisition via multipoint nD acquisition -> has to be stitched afterwards
        # (always the case with multiple channels)
        grid = gen_grid(fov, [xmin, ymin], [xmax, ymax], 0.15, True, True, True)

        for g in grid:
            logger.debug('wing {}: will scan tile at {}'.format(idx, g))

        # do not actually do the detail acquisition
        if dry_run_details:
            continue

        _run_nd_acquisition(path_to_nis, wing_path, grid, z_range, z_step, z_drive,
                            ocs=() if single_oc else ocs_detail)

        # copy to server mount
        copy_lock(wing_path, server_path_local)



In [None]:
###################
# set up the environment, nis, and image saving path
###################

# NIS-Elements executable, passed as `path_to_nis` to the nis_util helpers
path_to_nis = 'C:\\Program Files\\NIS-Elements\\nis_ar.exe'
# local directory the overviews and wing scans are saved to
save_base_path = 'C:\\Users\\Nikon\\Documents\\David\\tmpOverview'
# server directory as mounted on this machine (acquisitions are copied here)
save_server_path_local = 'Y:\\auto-test'
# server directory path handed to the detection server (presumably the same directory as above)
save_server_path_remote = '/data/wing-scanner/auto-test'

# IPython magic (not plain Python): show matplotlib figures inline in the notebook
%matplotlib inline
# default figure size (inches)
plt.rcParams['figure.figsize'] = [10,10]

In [None]:
#################
# do the scans
#################

logging.basicConfig(format='%(asctime)s - %(levelname)s in %(funcName)s: %(message)s', level=logging.DEBUG)
logger = logging.getLogger(__name__)

# name of the slides to image
# set to None to skip a slide

slide_left = 'DH_Overview_011'
slide_mid = None  # 'NG_Overview_024'
slide_right = None  # 'DH_Overview_005'

do_scan_left = slide_left is not None
do_scan_mid = slide_mid is not None
do_scan_right = slide_right is not None

# one entry per slide position:
# (scan enabled, slide name, stage field (left, right, top, bottom), label when scanning, label when skipping)
scan_plan = [
    (do_scan_left, slide_left, (53331, 28806, -20726, 20464), 'left', 'left'),
    (do_scan_mid, slide_mid, (13400, -7850, -20954, 18220), 'mid', 'middle'),
    (do_scan_right, slide_right, (-26500, -50000, -21053, 18177), 'right', 'right'),
]

for enabled, slide, field, scan_label, skip_label in scan_plan:
    if enabled:
        logger.info('Scanning {} scan.'.format(scan_label))
        do_scan(field, 'DIA4x', ['DIA10x', 'GFP 10x'], path_to_nis, save_base_path, slide,
                save_server_path_local, save_server_path_remote)
    else:
        logger.info('Skipping {} slide.'.format(skip_label))

# various test code below

In [None]:
#copy_lock(ov_path, server_path_local + (os.sep + 'overview') if separate_dirs_on_server else '')



In [None]:
from skimage.transform import pyramid_gaussian

ov_path = 'C:\\Users\\Nikon\\Documents\\David\\tmpOverview\\DH_Overview_00x_overview.nd2'
img = read_bf(ov_path)

# Gaussian pyramid with max_layer + 1 = 3 images: original, 2x and 4x downsampled
pyr = list(pyramid_gaussian(img, max_layer=2, downscale=2))
img_ds = pyr[2]

In [None]:
# get resolution, binning and fov
(xres, yres, siz, mag) = get_resolution(path_to_nis)
# NB: unlike in do_scan, these fov values do not include the binning factor
fov_x = xres * siz / mag
fov_y = yres * siz / mag
    
live_fmt, capture_fmt = get_camera_format(path_to_nis)
# binning factor = first number of the second token of the capture format (e.g. '2x2' -> 2.0); format assumed, TODO confirm
binning_factor = float(capture_fmt.split()[1].split('x')[0])

# run the simple local wing detection (instead of the detection server) on `img` from the previous cell
bboxes = detect_wings_simple(img, pixel_size=siz/mag*binning_factor, plot=True, layers=2)

In [None]:
from skimage.color import label2rgb
from skimage.exposure import rescale_intensity

# intensity-rescaled (to [0, 1]) 4x downsampled overview; computed here so the
# cell does not depend on `ii` being defined by a cell further down
ii = rescale_intensity(img_ds, out_range=(0, 1))

# overlay an empty label image on the rescaled overview
image_label_overlay = label2rgb(np.zeros(img_ds.shape), image=ii)
plt.imshow(image_label_overlay)

In [None]:
# quick look at intensity ranges
# NB: a notebook only displays the value of a cell's last expression, so this one is not shown
np.max(pyr[0])

# 4x downsampled overview (pyr[2]) rescaled to [0, 1]
ii = rescale_intensity(pyr[2], out_range=(0,1))
np.max(ii)