In [13]:
!pip install tifffile==2020.6.3 --exists-action i -q
!pip install numpy==1.18.1 --exists-action i -q
!pip install imagecodecs==2020.5.30 --exists-action i -q
!pip install matplotlib==3.1.3 --exists-action i -q
!pip install ipywidgets==7.5.1 --exists-action i -q

In [None]:
"""
The notebook consists of two parts, each in its own cell.

Cell 1 : 

Author: Nick Schaub (nick.schaub@nih.gov)

Description: This cell contains WippPy which consists of classes for interacting with WIPP.

Cell 2 :

Author: Gauhar Bains (gauhar.bains@labshare.org)
        
Description: This consists of the main UI and logic for the Polus Image Collection Preview Prototype.

"""

import json as json_lib
import requests, copy, re
from pathlib import Path
import logging

# Initialize the logger. basicConfig sets the root handler and message format;
# the "wipp" logger is capped at WARNING, and the "wipp.*" child loggers created
# by the classes below inherit that level unless they set their own.
logging.basicConfig(format='%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s',
                    datefmt='%d-%b-%y %H:%M:%S')
logger = logging.getLogger("wipp")
logger.setLevel(logging.WARNING)

# Initialize the WippData Class
class WippData(object):
    """ Wipp data superclass
    
    This class should be implemented by all Wipp data type classes.

    Subclasses set ``_entry_point`` (the REST path appended to ``api_route``;
    it is also the key under ``_embedded`` in paged list responses),
    ``_data_type_name`` and their own ``_logger``. Every key of an instance's
    ``json`` record is also set as an attribute of the instance.
    """
    _entry_point = None
    _data_type_name = None
    api_route = 'http://wipp-ui.ci.aws.labshare.org/api/'
    _headers = {'Content-Type': 'application/json'}
    _logger = logging.getLogger('wipp.Data')
    
    def __init__(self,name=None,create=False,json=None,api_route=None,**kwargs):
        """Wrap, create or fetch a WIPP record.

        Args:
            name: name of the record to create (only used with ``create=True``).
            create: when True and ``name`` is given, POST ``{'name': name}``.
            json: an already-fetched record; wrapped without any request.
            api_route: per-instance override of the class-level ``api_route``.
            **kwargs: forwarded to ``create``/``_get``. A ``data`` entry makes
                the constructor POST that payload to create a new record.

        The first matching source wins: create+name, then ``data``, then a
        non-empty ``json``, otherwise a GET on ``_entry_point``.
        """
        if api_route is not None:
            self._logger.info('api_route: {}'.format(api_route))
            self.api_route = api_route
        if create and name is not None:
            self._logger.info('create(): {}'.format(name))
            kwargs['data'] = {'name': name}
            self.json = self.create(**kwargs)
        elif 'data' in kwargs:
            self._logger.info('create(): attempting to create instance of {}'.format(self.__class__.__name__))
            self.json = self.create(**kwargs)
        elif json:
            self._logger.debug('creating object using json')
            self.json = json
        else:
            self.json = self._get(**kwargs)
    
        # _post returns None when the record already exists (HTTP 409).
        if self.json is not None:
            for key,value in self.json.items():
                self._logger.debug('setattr(): %s=%s',key,value)
                setattr(self,key,value)
        
    def __repr__(self):
        return f'{self.name} (id: {self.id})'
    
    def delete(self):
        """Send a DELETE request for this record (the response is not checked)."""
        url = self.api_route + self._entry_point + '/' + self.id
        self._logger.info('delete(): {}'.format(url))
        requests.delete(url)
        
    def create(self,**kwargs):
        """Create a record in WIPP; see ``_post`` for the accepted kwargs."""
        self._logger.info('create(): {}'.format(self.api_route + self._entry_point))
        return self._post(**kwargs)
        
    def _post(self,**kwargs):
        """POST to the WIPP API.

        Keyword Args:
            data: payload, serialized with ``json.dumps`` before sending.
            params, headers: forwarded to ``requests.post`` (headers default
                to JSON content type).
            entrypoint: path relative to ``api_route`` (default ``_entry_point``).

        Returns:
            The decoded JSON response for status 200/201, or None for status
            409 (the record already exists).

        Raises:
            ValueError: for any other status code.
        """
        entrypoint = kwargs.get('entrypoint',self._entry_point)
        self._logger.debug('_post(): endpoint={}'.format(self.api_route + entrypoint))
        
        config = {key:value for key,value in kwargs.items() if key in ['params','headers']}
        if 'data' in kwargs:
            config['data'] = json_lib.dumps(kwargs['data'])
        if 'headers' not in config:
            config['headers'] = self._headers
        
        for key,val in config.items():
            self._logger.debug('_post(): {}={}'.format(key,val))
            
        r = requests.post(self.api_route + entrypoint,**config)
        self._logger.debug('_post(): status_code={}'.format(r.status_code))
        if r.status_code in (200,201):
            return r.json()
        if r.status_code==409:
            self._logger.warning('_post(): {} already exists.'.format(self._data_type_name or self.__class__.__name__))
            return None
        # No exception is active here, so exc_info would only log "NoneType: None".
        self._logger.critical('_post(): message={}'.format(r.text))
        raise ValueError(self.__class__.__name__ + ' Error (status code {}): {}'.format(r.status_code,r.text))
        
    def _get(self,entrypoint=None,**kwargs):
        """GET from the WIPP API and return the decoded JSON.

        Args:
            entrypoint: path relative to ``api_route`` (default ``_entry_point``).
            **kwargs: ``params``, ``headers`` and ``data`` are forwarded to
                ``requests.get``; ``data`` is JSON-serialized first.

        Raises:
            ValueError: if the status code is not 200.
        """
        if entrypoint is None:
            entrypoint=self._entry_point
        self._logger.debug('_get(): endpoint={}'.format(self.api_route + entrypoint))
        
        config = {key:value for key,value in kwargs.items() if key in ['params','headers','data']}
        if 'data' in kwargs:
            config['data'] = json_lib.dumps(kwargs['data'])
        if 'headers' not in config:
            config['headers'] = self._headers
        
        for key,val in config.items():
            self._logger.debug('_get(): {}={}'.format(key,val))
        
        r = requests.get(self.api_route + entrypoint,**config)
        self._logger.debug('_get(): status_code={}'.format(r.status_code))
        if r.status_code==200:
            return r.json()
        self._logger.critical('_get(): message={}'.format(r.text))
        raise ValueError(self.__class__.__name__ + ' Error (status code {}): {}'.format(r.status_code,r.text))
            
    @classmethod
    def setWippUrl(cls,url):
        """Set the API base URL used by this class (and subclasses that do not override it)."""
        cls._logger.info('setWippUrl(): {}'.format(url))
        cls.api_route = url
            
    @classmethod
    def all(cls,entry_point=False):
        """Get all instances of a data type

        Args:
            cls: Class reference for handling a WIPP data type
            entry_point: API entry point, appended to api path
                (defaults to ``cls._entry_point``)

        Returns:
            A dictionary, where the keys are hashes referencing a data
            instance and values are data_class objects. Records from every
            page are accumulated; pages whose request fails are skipped, and a
            failing first page yields an empty dictionary.
        """
        if not entry_point:
            entry_point = cls._entry_point
        cls._logger.info('all(): getting all instances...')
        numel = 1000
        data = {}
        page = 0
        total_pages = 1  # updated from the first successful response
        while page < total_pages:
            r = requests.get(cls.api_route + entry_point,params={'page':page,'size':numel})
            if r.status_code==200:
                body = r.json()
                for datum in body.get('_embedded',{}).get(cls._entry_point,[]):
                    data[datum['id']] = cls(json=datum)
                    cls._logger.debug('all(): object={}'.format(data[datum['id']]))
                total_pages = body['page']['totalPages']
            page += 1
        return data
    
    @classmethod
    def get_by_id(cls,oid):
        """Get data by id

        Args:
            cls: Class reference for handling a WIPP data type
            oid: Hash reference of data to access

        Returns:
            An object of type cls, or None if the request does not return 200
        """    
        cls._logger.debug('get_by_id(): oid={}'.format(oid))
        r = requests.get(cls.api_route + cls._entry_point + '/' + oid)
        if r.status_code==200:
            return cls(json=r.json())
        cls._logger.warning('get_by_id(): returning NoneType')
        return None
    
    @staticmethod
    def get_name(dtype,value):
        """ Get the name of a data instance

        Args:
            dtype: WIPP data type (the ``_entry_point`` of a direct subclass)
            value: Unique hash reference

        Returns:
            A string containing the human readable dataset name, or None if no
            direct subclass of WippData has ``_entry_point == dtype``.

        Raises:
            KeyError: if no instance with hash ``value`` exists.
        """
        for cls in WippData.__subclasses__():
            if dtype==cls._entry_point:
                cls._logger.debug('get_name(): finding object associated with id={}'.format(value))
                return cls.all()[value].name
        return None
    
class WippJob(WippData):
    """ Class to handle WIPP Jobs

    Every key of the job JSON returned by WIPP (e.g. name, id) is exposed as an
    attribute; ``json`` holds the raw record.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all jobs {job hash: WippJob object}
        get_by_id(jid): Returns the job with hash equal to jid, or None

    Object Methods (inherited from WippData):
        delete(): Delete the job from WIPP.
        create(): Create the job in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippJob')
    _data_type_name = 'Job'
    _entry_point = 'jobs'

    # Template of the JSON body POSTed to create a job; callers deep-copy it
    # before filling it in.
    _payload = {
        'name': '',               # name of job
        'wippExecutable': '',     # plugin id
        'type': '',               # name of the plugin
        'dependencies': [],       # job ids for dependencies
        'parameters': {},         # dictionary of parameters
        'outputParameters': {},   # dictionary of output parameters
        'wippWorkflow': '',       # wipp workflow id
    }

    def __repr__(self):
        return '{} (id: {})'.format(self.name, self.id)
    
class WippWorkflow(WippData):
    """ Class to handle WIPP Workflows

    Every key of the workflow JSON returned by WIPP (e.g. name, id) is exposed
    as an attribute; ``json`` holds the raw record.
    
    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all workflows {workflow hash: WippWorkflow object}
        get_by_id(wid): Returns workflow with hash equal to wid, or None
    
    Object Methods:
        delete(): Delete the workflow from WIPP.
        create(): Create the workflow in WIPP.
        jobs(): Returns dictionary of all jobs in workflow, {job hash: WippJob object}
        update(): Refresh this object's attributes from WIPP.
        submit(): Submit the workflow for execution.
        add_job(): Create a job in this workflow.
    """
    _entry_point = 'workflows'
    _data_type_name = 'Workflow'
    _logger = logging.getLogger('wipp.Data.WippWorkflow')
    
    def jobs(self):
        """Return all jobs of this workflow as {job hash: WippJob object}.

        Raises:
            KeyError: if WIPP does not know this workflow id.
        """
        self._logger.info('jobs(): Getting all jobs for workflow={}'.format(self.id))
        # One lookup by id validates the workflow; listing every workflow is far more expensive.
        if WippWorkflow.get_by_id(self.id) is None:
            self._logger.critical('jobs(): Could not find workflow jobs')
            raise KeyError('Invalid workflow id.')
        
        r = self._get(entrypoint='jobs/search/findByWippWorkflow?wippWorkflow='+self.id)
        # NOTE(review): if this search endpoint is paginated server-side, only the
        # first page of jobs is returned here - confirm against the WIPP API.
        job_records = r.get('_embedded',{}).get('jobs',[])
        jobs = {job_json['id']:WippJob(json=job_json) for job_json in job_records}
        for job in jobs.values():
            self._logger.debug('jobs(): job={}'.format(job))
            
        return jobs
    
    def update(self):
        """Re-fetch this workflow from WIPP and overwrite its attributes."""
        self._logger.info('update(): updating workflow - {}'.format(self.name))
        wf = WippWorkflow.get_by_id(self.id)
        for key,value in wf.json.items():
            setattr(self,key,value)
    
    def submit(self):
        """Submit the workflow for execution in WIPP.

        The workflow id is part of the URL; no body or query parameters are sent.
        """
        self._post(entrypoint='workflows/' + self.id + '/submit')
    
    def add_job(self,plugin_name,job_name,inputs,plugin_version=None):
        """Create a job in this workflow that runs a WIPP plugin.

        Args:
            plugin_name: name of the plugin to run.
            job_name: name given to the new job.
            inputs: {plugin input name: value}. For a string value of the form
                '{{ <a>.<b> }}', the text before the last '.' is added to the
                job's dependencies.
            plugin_version: exact plugin version; the newest is used when None.

        Returns:
            The created WippJob.

        Raises:
            KeyError: if a required plugin input is missing from ``inputs``.
        """
        # Local import: a later notebook cell runs `from copy import copy`, which
        # rebinds the global name `copy` to a function and breaks `copy.deepcopy`.
        from copy import deepcopy
        payload = deepcopy(WippJob._payload)
        plugin = WippPlugin.get_by_name(plugin_name,plugin_version)
        dependency_pattern = re.compile(r'\{\{ (.*)\.(.*) \}\}')
        self._logger.info('add_job(): job_name={}, plugin_name={}, plugin_version={}'.format(job_name,plugin_name,plugin_version))
        
        # Add basic info to the payload
        payload['name'] = job_name
        payload['wippExecutable'] = plugin.id
        payload['type'] = plugin.name
        payload['wippWorkflow'] = self.id
        
        # validate and set inputs
        for inp in plugin.inputs:
            inp_name = inp['name']
            if inp_name not in inputs:
                if inp.get('required',False):
                    self._logger.critical('add_job(): Missing input {} for plugin {}'.format(inp_name,plugin.name))
                    raise KeyError('Missing input {} for plugin {}'.format(inp_name,plugin.name))
                continue
            value = inputs[inp_name]
            self._logger.debug('add_job(): {}={}'.format(inp_name,value))
            payload['parameters'][inp_name] = value
            
            # If input has {{ }}, then it has a dependency
            if isinstance(value,str):
                dependency = dependency_pattern.match(value)
                if dependency is not None:
                    self._logger.info('add_job(): adding dependency {}'.format(dependency.group(1)))
                    payload['dependencies'].append(dependency.group(1))
        
        return WippJob(data=payload)

class WippImageCollection(WippData):
    """ Class to handle WIPP Image Collections

    Every key of the image collection JSON returned by WIPP (e.g. name, id,
    locked) is exposed as an attribute; ``json`` holds the raw record.
    
    Class Methods:
        all(): Returns a dictionary of all image collections, {image collection hash: WippImageCollection object}
        get_by_id(icid): Returns image collection with hash equal to icid, or None
        get_by_name(ic_name): Returns the first result of a search of image collections matching ic_name
    
    Object Methods:
        create(): Create the image collection in WIPP.
        delete(): Delete the collection (refused when it is locked).
        update(): Refresh this object's attributes from WIPP.
        add_image(file_path): Prepare a WippImage for upload into this collection.
        lock(): Lock the collection in WIPP.
        images(): Return a list of dictionaries containing information on every image in the collection
    """
    _entry_point = 'imagesCollections'
    _data_type_name = 'Image Collection'
    _images = []  # replaced per instance by images(); never mutated in place
    _logger = logging.getLogger('wipp.Data.WippImageCollection')
    
    def delete(self):
        """ Throw an error only if image collection is locked"""
        self._logger.info('delete(): deleting image collection - {}'.format(self.name))
        if self.locked:
            self._logger.critical('delete(): Cannot delete locked image collection.')
            raise PermissionError('Cannot delete locked image collection.')
        super().delete()
    
    @classmethod
    def get_by_name(cls,ic_name):
        """Return the first collection found by name, or [] if the search request fails."""
        cls._logger.info('get_by_name(): getting image collection - {}'.format(ic_name))
        r = requests.get(cls.api_route + cls._entry_point + '/search/findByName',params={'name':ic_name})
        cls._logger.debug('get_by_name(): status_code={}'.format(r.status_code))
        if r.status_code==200:
            return cls(json=r.json()['_embedded'][cls._entry_point][0])
        return []
    
    def update(self):
        """Re-fetch this collection (by name) and overwrite its attributes."""
        self._logger.info('update(): updating image collection - {}'.format(self.name))
        ic = WippImageCollection.get_by_name(self.name)
        for key,value in ic.json.items():
            setattr(self,key,value)

    def add_image(self,file_path):
        """Return a WippImage bound to this collection; call its send() to upload.

        Raises:
            PermissionError: if the collection is locked.
        """
        self._logger.info('add_image(): file_path={}'.format(file_path))
        if self.locked:
            self._logger.info('add_image(): cannot add image to locked collection')
            raise PermissionError('Cannot add images to locked collection.')
        return WippImage(self.id,Path(file_path))
    
    def lock(self):
        """PATCH the collection in WIPP with {'locked': True} (the response is only logged)."""
        self._logger.info('lock(): locking imaging collection - {}'.format(self.name))
        r = requests.patch(self.api_route + self._entry_point + '/' + self.id,
                           headers={'Content-Type': 'application/json'},
                           data=json_lib.dumps({'locked': True}))
        self._logger.debug('lock(): status_code={}'.format(r.status_code))
        
    def images(self):
        """Return the image records of this collection, fetching every page.

        A locked collection reuses a previously fetched non-empty listing.

        Raises:
            ValueError: if a page request does not return status 200.
        """
        self._logger.info('images(): getting all images for image collection - {}'.format(self.name))
        if len(self._images) > 0 and self.locked:
            return self._images
        entrypoint = self._entry_point + '/' + self.id + '/images'
        numel = 1000
        r = self._get(entrypoint=entrypoint,params={'page':0,'size':numel})
        if '_embedded' not in r:
            return []
        
        images = r['_embedded']['images']
        for page in range(1,r['page']['totalPages']):
            # Same helper as the first page: requests has no `_get` function.
            r = self._get(entrypoint=entrypoint,params={'page':page,'size':numel})
            images.extend(r['_embedded']['images'])
        self._images = images
        return images
    
class WippCsvCollection(WippData):
    """ Class to handle WIPP Csv Collections.

    Instances wrap the JSON records served under the ``csvCollections`` entry
    point; every key of a record (e.g. name, id) becomes an attribute and the
    raw record is kept in ``json``.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all csv collections, {csv collection hash: WippCsvCollection object}
        get_by_id(oid): Returns the csv collection with hash equal to oid, or None

    Object Methods (inherited from WippData):
        delete(): Delete the csv collection from WIPP.
        create(): Create the csv collection in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippCsvCollection')
    _data_type_name = 'CSV Collection'
    _entry_point = 'csvCollections'
    
class WippNotebook(WippData):
    """ Class to handle WIPP Notebooks.

    Instances wrap the JSON records served under the ``notebooks`` entry
    point; every key of a record (e.g. name, id) becomes an attribute and the
    raw record is kept in ``json``.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all notebooks, {notebook hash: WippNotebook object}
        get_by_id(oid): Returns the notebook with hash equal to oid, or None

    Object Methods (inherited from WippData):
        delete(): Delete the notebook from WIPP.
        create(): Create the notebook in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippNotebook')
    _data_type_name = 'Notebook'
    _entry_point = 'notebooks'

class WippStitchingVector(WippData):
    """ Class to handle WIPP Stitching Vectors.

    Instances wrap the JSON records served under the ``stitchingVectors``
    entry point; every key of a record (e.g. name, id) becomes an attribute
    and the raw record is kept in ``json``.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all stitching vectors, {stitching vector hash: WippStitchingVector object}
        get_by_id(oid): Returns the stitching vector with hash equal to oid, or None

    Object Methods (inherited from WippData):
        delete(): Delete the stitching vector from WIPP.
        create(): Create the stitching vector in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippStitchingVector')
    _data_type_name = 'Stitching Vector'
    _entry_point = 'stitchingVectors'

class WippPyramid(WippData):
    """ Class to handle WIPP Pyramids.

    Instances wrap the JSON records served under the ``pyramids`` entry
    point; every key of a record (e.g. name, id) becomes an attribute and the
    raw record is kept in ``json``.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all pyramids, {pyramid hash: WippPyramid object}
        get_by_id(oid): Returns the pyramid with hash equal to oid, or None

    Object Methods (inherited from WippData):
        delete(): Delete the pyramid from WIPP.
        create(): Create the pyramid in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippPyramid')
    _data_type_name = 'Image Pyramid'
    _entry_point = 'pyramids'

class WippImage(object):
    """ Class to handle WIPP Images

    Unlike most other WIPP classes, the WippImage class acts very differently from the
    other data types. Part of this comes from images being a child of an image collection
    and therefore necessitates attachment to a WippImageCollection id.

    In general, the best way to instantiate this class is through an WippImageCollection
    object using either the images() method to get all images in a collection or the
    add_image() method to prepare an image to upload to an unlocked collection.

    Uploads use flow.js-style chunking: every chunk except the last is exactly
    ``_flowChunkSize`` bytes and the last chunk carries the remainder.

    Attributes:
        file_path: Path of the local file to upload
        params: flow.js query parameters sent with every chunk
    
    Object Methods:
        get_name(): Return the local file name.
        set_name(name): Change the file name used in WIPP (suffixes are kept).
        send(): Send the image to WIPP.
    """
    _entry_point = 'imagesCollections/{}/images'
    _data_type_name = 'Image'
    _flowChunkSize = 1048576
    _logger = logging.getLogger('wipp.Data.WippImage')

    def __init__(self,ic_id,file_path):
        """Prepare ``file_path`` for upload into the image collection ``ic_id``.

        Raises:
            FileNotFoundError: if ``file_path`` is not an existing file.
        """
        # The instance attribute shadows the class template with the full upload URL.
        self._entry_point = WippData.api_route + self._entry_point.format(ic_id)
        self.file_path = Path(file_path)
        if not self.file_path.is_file():
            self._logger.critical('__init__(): could not find file - {}'.format(str(self.file_path.absolute())))
            raise FileNotFoundError('Could not find file: {}'.format(str(self.file_path.absolute())))

        self._flowTotalSize = self.file_path.stat().st_size
        self._flowTotalChunks = self._count_chunks(self._flowTotalSize,self._flowChunkSize)
        self._flowFilename = self.file_path.name

        self.params = {'flowChunkNumber': 1,
                       'flowChunkSize': self._flowChunkSize,
                       'flowCurrentChunkSize': self._flowChunkSize,
                       'flowTotalSize': self._flowTotalSize,
                       'flowIdentifier': str(self._flowTotalSize) + '-' + self.file_path.name.replace('.',''),
                       'flowFilename': self.file_path.name,
                       'flowRelativePath': self.file_path.name,
                       'flowTotalChunks': self._flowTotalChunks}
        
        for key,val in self.params.items():
            self._logger.debug('__init__(): {}={}'.format(key,val))

    @staticmethod
    def _count_chunks(total_size,chunk_size):
        """Number of flow.js chunks: floor(total/chunk), but at least 1.

        Files smaller than one chunk are sent as a single chunk; otherwise the
        last chunk absorbs the remainder.
        """
        return max(total_size//chunk_size,1)
        
    def get_name(self):
        """Return the local file name."""
        self._logger.info('get_name(): name={}'.format(self.file_path.name))
        return self.file_path.name
        
    def set_name(self,name):
        """Rename the upload: keep the local file's suffixes, replace the stem with ``name``'s."""
        self._logger.info('set_name(): name={}'.format(name))
        suffix = ''.join(self.file_path.suffixes)
        name = name.split('.')[0] + suffix
        self.params['flowIdentifier'] = str(self._flowTotalSize) + '-' + name.replace('.','')
        self._logger.debug('set_name(): flowIdentifier={}'.format(self.params['flowIdentifier']))
        self.params['flowFilename'] = name
        self._logger.debug('set_name(): flowFilename={}'.format(self.params['flowFilename']))
        self.params['flowRelativePath'] = name
        self._logger.debug('set_name(): flowRelativePath={}'.format(self.params['flowRelativePath']))

    def _post_chunk(self,data):
        """POST one chunk using the current ``self.params``, retrying on network errors.

        Up to 10 attempts are made, 3 seconds apart; after the 10th failure the
        last exception is re-raised. The HTTP status code is only logged.
        """
        import time  # not imported at the top of this notebook cell
        for retry in range(10):
            try:
                r = requests.post(self._entry_point,
                                  params=self.params,
                                  headers={'Content-Type': 'image/tiff'},
                                  data=data)
            except requests.exceptions.RequestException:
                if retry==9:
                    print('{}: Reached max tries.'.format(self.params['flowFilename']))
                    raise
                print('{}: There was an upload error, will retry in 3 seconds (try {})'.format(self.params['flowFilename'],retry+1))
                time.sleep(3)
            else:
                self._logger.debug('send(): status_code={}'.format(r.status_code))
                return r
        
    def send(self):
        """Upload the file to WIPP chunk by chunk."""
        self._logger.info('send(): file={}'.format(self.file_path))
        with open(self.file_path,'rb') as in_file:
            # All chunks but the last are exactly _flowChunkSize bytes.
            for chunk in range(1,self._flowTotalChunks):
                self._logger.debug('send(): sending chunk {} of {} for file {}'.format(chunk,self._flowTotalChunks,self.file_path))
                self.params['flowChunkNumber'] = chunk
                self.params['flowCurrentChunkSize'] = self._flowChunkSize
                # Read once so a retry re-sends the same bytes without seeking.
                self._post_chunk(in_file.read(self._flowChunkSize))
            remaining = self._flowTotalSize-in_file.tell()
            self.params['flowChunkNumber'] = self._flowTotalChunks
            self.params['flowCurrentChunkSize'] = remaining
            self._logger.debug('send(): sending chunk {} of {} for file {}'.format(self._flowTotalChunks,self._flowTotalChunks,self.file_path))
            self._post_chunk(in_file.read(remaining))

class WippTensorflowModel(WippData):
    """ Class to handle WIPP Tensorflow Models

    Every key of the model JSON returned by WIPP (e.g. name, id) is exposed as
    an attribute; ``json`` holds the raw record.
    
    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all models, {tensorflow model hash: WippTensorflowModel object}
        get_by_id(oid): Returns the model with hash equal to oid, or None
    
    Object Methods (inherited from WippData):
        delete(): Delete the tensorflow model from WIPP.
        create(): Create the tensorflow model in WIPP.
    """
    # Fixed typo: the previous value 'tesorflowModels' does not match the
    # tensorflowModels collection path.
    _entry_point = 'tensorflowModels'
    _data_type_name = 'Tensorflow Models'
    _logger = logging.getLogger('wipp.Data.WippTensorflowModel')
        
class WippPlugin(WippData):
    """ Class to handle WIPP Plugins

    Every key of the plugin JSON returned by WIPP is exposed as an attribute;
    the code in this notebook relies on name, id, version and inputs (a list
    of dicts with 'name' and optionally 'required'). ``json`` holds the raw record.
    
    Class Methods:
        all(): Returns a dictionary of all plugins, {plugin hash: WippPlugin object}
        get_by_name(name, version=None): Returns the matching plugin (newest if no version)
        install(json): Create the plugin in WIPP and return it
        
    Object Methods:
        delete(): Delete the plugin from WIPP.
        create(): Create the plugin in WIPP.
    """
    _entry_point = 'plugins'
    _data_type_name = 'Plugin'
    _logger = logging.getLogger('wipp.Data.WippPlugin')

    @staticmethod
    def _version_key(plugin):
        """Sort key (major, minor, patch) parsed from ``plugin.version``.

        Anything after the patch number (e.g. '-beta') is ignored; versions not
        starting with 'X.Y.Z' sort below every valid version.
        """
        match = re.match(r"([0-9]+)\.([0-9]+)\.([0-9]+)",plugin.version)
        if match is None:
            return (-1,-1,-1)
        return tuple(int(part) for part in match.groups())

    # Get the newest plugin that matches a plugin name
    @classmethod
    def get_by_name(cls,name,version=None):
        """Return the plugin called ``name``.

        Args:
            name: plugin name.
            version: exact version string; when None the highest
                major.minor.patch version is returned.

        Raises:
            ValueError: if no plugin has this name, or the requested version
                is not installed.
        """
        cls._logger.info('get_by_name(): name={}, version={}'.format(name,version))
        matching_plugins = [p for p in cls.all().values() if p.name==name]
        
        # If there are no matching plugins, throw an error
        if not matching_plugins:
            raise ValueError('No plugins match the supplied name: {}'.format(name))
        
        # If no version provided, get the latest version of the plugin
        if version is None:
            if len(matching_plugins)==1:
                return matching_plugins[0]
            # max() keeps the first maximum it meets, so scanning in reverse
            # resolves ties in favour of the plugin listed last.
            return max(reversed(matching_plugins),key=cls._version_key)

        # Return specified version of plugin
        for p in reversed(matching_plugins):
            if p.version==version:
                return p
        # If the specified version could not be found, throw an error
        raise ValueError('Version {} of plugin {} was not found in WIPP. Try installing it.'.format(version,name))
        
    def __repr__(self):
        return f'{self.name} (id: {self.id}, version: {self.version})'
    
    @classmethod
    def all(cls):
        """Return all plugins, {plugin hash: WippPlugin object}."""
        return super().all(entry_point='plugins/')
    
    @classmethod
    def install(cls,json):
        """POST the plugin manifest ``json`` to WIPP and return the resulting WippPlugin.

        If WIPP answers 409 (already exists), the returned object has ``json`` None.
        """
        cls._logger.info('install(): installing plugin...')
        cls._logger.debug('install(): json={}'.format(json))
        return cls(data=json)

In [1]:
"""

This cell consists of the main UI and logic for the Polus Image Collection Preview Prototype.

"""


import cv2
import numpy as np
import re
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import math
import json 
import tifffile
from copy import copy, deepcopy


def get_selection_id(selection):
    """Extract the WIPP id from a selection label.

    Labels look like the objects' reprs, e.g. ``'name (id: abc123)'`` or
    ``'name (id: abc123, version: 1.0.0)'``; the alphanumeric run after
    ``'(id: '`` is returned.
    """
    match = re.match(r".* \(id: ([0-9A-Za-z]+).*\)",selection)
    return match.group(1)

def fig2data(fig):
    """ This function converts a matplotlib plot to an rgb image
    
    Input: 
        fig: matplotlib figure
    Output:
        data: writable uint8 array of shape (height, width, 3) holding the
              rendered RGB pixels of the figure
    """    
    fig.canvas.draw()
    # np.fromstring's binary mode is deprecated; frombuffer is its replacement.
    # frombuffer returns a read-only view of the bytes, so copy to keep the
    # result writable as before.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    width, height = fig.canvas.get_width_height()
    return data.reshape((height, width, 3))

def get_scale_factor(height,width,target_size=1500000):
    """
    Get the integer downscaling factor that brings an image close to a target pixel count.
    
    Input: 
          height : image height in pixels
          width : image width in pixels
          target_size : desired pixel count after scaling
                        (default 1,500,000, i.e. 1.5 megapixels)
    Output:
          scale_factor: sqrt(height*width/target_size) rounded to the nearest
                        integer, or 1 when the image has at most target_size pixels
    """
    scale_factor=math.sqrt((height*width)/target_size)
    return int(round(scale_factor)) if scale_factor>1 else 1

def get_metadata_json(path):
    """
    Read the TIFF tags of the first page of an image with tifffile.

    Input :
           path: path to image file
    Output:
           json_object : the tags as a {tag name: value} JSON string (indent 4),
                         without the ImageDescription tag
    """
    with tifffile.TiffFile(path) as tif:
        first_page_tags = tif.pages[0].tags.values()
        tif_tags = {tag.name: tag.value for tag in first_page_tags}
    tif_tags.pop('ImageDescription', None)
    return json.dumps(tif_tags, indent = 4)

    
def display_img(im_index):
    """
    Recompute channel `im_index` and refresh the image display widget.

    Called each time a display parameter of a channel changes. For channel
    `im_index`:
      1. apply contrast (multiplicative) and brightness (additive), clip to [0, 1];
      2. clamp to the intensity window and stretch that window onto [0, 1];
      3. zero every pixel outside the threshold window;
      4. update the channel histogram and tint the channel with its colour.
    Then all channels are summed, each weighted by its checkbox and alpha, and
    the result is PNG-encoded into `img_display`.

    Module-level state used: img, brightness_dict, contrast_dict, alpha_dict,
    check_box_dict, color_array, intensity_dictionary, threshold_dictionary,
    final_img (updated in place), img_display (updated in place).
    """
    # Slider windows are on a 0-255 scale; the image is on [0, 1].
    int_low = intensity_dictionary[im_index][0]/255
    int_high = intensity_dictionary[im_index][1]/255
    thr_low = threshold_dictionary[im_index][0]/255
    thr_high = threshold_dictionary[im_index][1]/255

    # brightness/contrast; np.clip returns a new array, so img[im_index] is never modified
    image=np.clip(img[im_index]*contrast_dict[im_index] + brightness_dict[im_index],0,1)

    # intensity window: clamp to it, then stretch it onto [0, 1]
    below=image < int_low
    above=image > int_high
    image[below] = 0
    image[above] = 1
    image=np.interp(image,(int_low,int_high),(0,1))
    
    # threshold: hide pixels outside the window
    image[(image < thr_low) | (image > thr_high)] = 0
    
    # update histogram
    update_histogram(im_index,image)  
    
    # update final image array and blend the visible channels
    final_img[:,:,:,im_index]=np.dstack((image,image,image))*color_array[im_index]
    # named `composite` so it does not shadow IPython's `display` function
    composite=final_img[:,:,:,0]*check_box_dict[0]*alpha_dict[0]
    for i in range(1,final_img.shape[3]):
        composite=composite+final_img[:,:,:,i]*check_box_dict[i]*alpha_dict[i]
    
    # update display widget (ndarray.tostring is deprecated; tobytes is equivalent)
    img_display.value=cv2.imencode('.png', composite)[1].tobytes()
    img_display.width=final_img.shape[1]
    img_display.height=final_img.shape[0]
        

def parse_color(im_index):
    """
    Build the observer for the colour-picker widget of channel `im_index`.

    The returned callback reads the '#RRGGBB' string in the change's 'new'
    entry and stores the colour in `color_array[im_index]` as a (1, 1, 3)
    array in blue, green, red order (the channel order cv2 assumes when
    display_img encodes the image). It then redraws the channel with display_img.
    """
    def call_back(*args):
        global color_array
        hex_color = args[0]['new']
        red, green, blue = (int(hex_color[pos:pos + 2], 16) for pos in (1, 3, 5))
        color_array[im_index] = np.array([blue, green, red]).reshape(1, 1, 3)
        display_img(im_index)
    return call_back

def intensity_observer(im_index,key):
    """
    Build the callback for the intensity / threshold range sliders of image ``im_index``.

    ``key`` chooses the table the slider writes into: ``'intensity'`` updates
    ``intensity_dictionary[im_index]`` and ``'threshold'`` updates
    ``threshold_dictionary[im_index]``. Both bounds are cast to ``int``. Any
    other ``key`` leaves the tables untouched. The display is redrawn with
    ``display_img`` in every case.
    """
    def call_back(*args):
        target = {'intensity': intensity_dictionary,
                  'threshold': threshold_dictionary}.get(key)
        if target is not None:
            low, high = args[0]['new'][0], args[0]['new'][1]
            target[im_index][0] = int(low)
            target[im_index][1] = int(high)
        display_img(im_index)

    return call_back
        
        

def image_checkbox_observer(im_index):
    """
    Build the callback for the visibility checkbox of image ``im_index``.

    The callback stores the checkbox state in ``check_box_dict[im_index]``,
    which ``display_img`` multiplies into that image's channel when composing
    the display, so ``False`` hides the image. It then redraws the display.
    """
    def call_back(*args):
        # The checkbox value is a boolean; store it directly rather than
        # branching on ``== True`` / ``== False``.
        check_box_dict[im_index] = bool(args[0]['new'])
        display_img(im_index)

    return call_back

image_collection_path=[]  # image directory selected for each image slot, indexed by slot
def image_collection_observer(image_collections,images_widget,img_index):
    """
    Build the callback for the ``Image Collections`` combobox of slot ``img_index``.

    On a selection the callback:

    * resolves the collection's image directory (``'Same as previous
      collection'`` reuses the directory of slot ``img_index - 1``);
    * stores it in ``image_collection_path[img_index]``, overwriting any
      earlier choice for the same slot;
    * enables ``images_widget`` and lists the files of that directory in it;
    * refreshes the combobox options from the WIPP backend.

    An empty value is ignored.
    """
    def call_back(*args):
        global image_collection_path

        new_value = args[0]['new']
        if not new_value:
            return

        if new_value == 'Same as previous collection':
            path = image_collection_path[img_index - 1]
        else:
            selection = get_selection_id(new_value)
            path = os.path.join('/opt/shared/wipp/collections', selection, 'images')

        # One entry per slot: re-selecting a collection must replace the
        # slot's path. Appending would leave ``image_collection_path[img_index]``
        # pointing at the first choice.
        while len(image_collection_path) <= img_index:
            image_collection_path.append(None)
        image_collection_path[img_index] = path

        # enable widget to select image within that collection
        images_widget.disabled = False
        images_widget.options = os.listdir(path)

        image_collections.value = new_value
        image_collections.options = [str(ic) for ic in WippImageCollection.all().values()]

    return call_back

# Initialise the module-level state shared by the display callbacks.
# Loaded images, normalised to [0, 1], one entry per image slot.
img = []
# Per-image display settings, keyed by image index.
(color_array,
 brightness_dict,
 contrast_dict,
 alpha_dict,
 check_box_dict,
 intensity_dictionary,
 threshold_dictionary) = ({}, {}, {}, {}, {}, {}, {})
# Stack of tinted image channels that the display widget shows.
final_img = []

def image_display_observer(img_index,metadata_textbox,view_image,histogram_widget):
    """
    Build the callback for the ``Images`` combobox of image slot ``img_index``.

    When a file is chosen the callback:

    1. reads it with ``tifffile`` and resamples it to the size given by
       ``get_scale_factor``;
    2. shows its metadata in ``metadata_textbox`` and appends the file name to
       the ``view_image`` checkbox label;
    3. normalises it to [0, 1] by the maximum of its integer dtype;
    4. (re)initialises this slot's display parameters (brightness, contrast,
       alpha, colour, visibility, intensity and threshold ranges);
    5. draws its histogram into ``histogram_widget``;
    6. stores the tinted image as channel ``img_index`` of ``final_img`` and
       redraws ``img_display`` as the alpha-weighted sum of visible channels.

    Choosing another file for the same slot replaces that slot's image
    instead of adding a new one.
    """
    def call_back(*args):
        global final_img

        add_image_button.disabled = False

        # get selection and load image
        selection = args[0]['new']
        if not selection:
            return
        path = os.path.join(image_collection_path[img_index], selection)
        img_new = tifffile.imread(path, is_ome=True)
        height, width = img_new.shape

        # scale down if image is larger than target size.
        # np.resize only truncates/repeats the flattened buffer, which
        # scrambles the rows, so sample the nearest source pixel instead.
        scale_factor = get_scale_factor(height, width)
        new_height = max(1, int(height / scale_factor))
        new_width = max(1, int(width / scale_factor))
        if (new_height, new_width) != (height, width):
            rows = np.arange(new_height) * height // new_height
            cols = np.arange(new_width) * width // new_width
            img_new = img_new[np.ix_(rows, cols)]

        # load metadata
        metadata_textbox.value = get_metadata_json(path)

        # normalize image array by the maximum of its integer dtype
        img_new = img_new / np.iinfo(img_new.dtype).max

        # (re)initialise this slot's display parameters
        brightness_dict[img_index] = 0
        contrast_dict[img_index] = 1
        alpha_dict[img_index] = 0.2
        color_array[img_index] = np.array([255, 255, 255]).reshape((1, 1, 3))
        check_box_dict[img_index] = True
        intensity_dictionary[img_index] = [0, 255]
        threshold_dictionary[img_index] = [0, 255]

        # keep the base label ("Image N") so re-selection doesn't pile up names
        base_label = view_image.description.split(' : ', 1)[0]
        view_image.description = base_label + " : " + selection

        # store the image in its slot, replacing it on re-selection
        if img_index < len(img):
            img[img_index] = img_new
        else:
            img.append(img_new)
        # np.clip returns a new array, so the stored image is left untouched
        image = np.clip(img_new * contrast_dict[img_index] + brightness_dict[img_index], 0, 1)

        # build and show histogram
        fig = plt.figure()
        fig.add_subplot(111)
        hist = cv2.calcHist([(image * 255).astype('uint8')], [0], None, [256], [0, 256])
        plt.plot(hist, c='r')
        plt.axvline(x=threshold_dictionary[img_index][0], c='g', ls='--')
        plt.axvline(x=threshold_dictionary[img_index][1], c='b', ls='--')
        plt.semilogy()
        hist_image = fig2data(fig)
        hist_bytes = cv2.imencode('.png', hist_image)[1].tobytes()
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)
        histogram_widget.value = hist_bytes
        histogram_widget.width = hist_image.shape[1]
        histogram_widget.height = hist_image.shape[0]

        # apply the intensity range (masks computed before either write)
        low_mask = image < intensity_dictionary[img_index][0] / 255
        high_mask = image > intensity_dictionary[img_index][1] / 255
        image[low_mask] = 0
        image[high_mask] = 1

        # build the final image array: one tinted channel per slot on axis 3
        channel = np.expand_dims(np.dstack((image, image, image)) * color_array[img_index], axis=3)
        if img_index == 0:
            # first slot starts a fresh stack (also discards any stale stack)
            final_img = channel
        elif img_index < final_img.shape[3]:
            # re-selection in an existing slot overwrites its channel
            final_img[:, :, :, img_index] = channel[:, :, :, 0]
        else:
            final_img = np.concatenate((final_img, channel), axis=3)

        display = final_img[:, :, :, 0] * check_box_dict[0] * alpha_dict[0]
        for i in range(1, final_img.shape[3]):
            display = display + final_img[:, :, :, i] * check_box_dict[i] * alpha_dict[i]

        # update the image display widget
        img_display.value = cv2.imencode('.png', display)[1].tobytes()
        img_display.width = final_img.shape[1]
        img_display.height = final_img.shape[0]

    return call_back

def update_histogram(img_index,image):
    """
    Redraw the histogram panel entry for image ``img_index``.

    Plots a 256-bin histogram of ``image`` on a logarithmic y-axis. The image
    is scaled by 255 and cast to uint8 before binning; the callers in this
    file pass values clipped to [0, 1]. Dashed vertical lines mark the slot's
    lower (green) and upper (blue) threshold from ``threshold_dictionary``.
    The PNG-encoded plot is written into ``histogram_panel.children[img_index]``.
    """
    fig = plt.figure()
    fig.add_subplot(111)
    hist = cv2.calcHist([(image * 255).astype('uint8')], [0], None, [256], [0, 256])
    plt.plot(hist, c='r')
    plt.semilogy()
    low, high = threshold_dictionary[img_index]
    plt.axvline(x=low, c='g', ls='--')
    plt.axvline(x=high, c='b', ls='--')
    hist_image = fig2data(fig)
    img_byte_encoded = cv2.imencode('.png', hist_image)[1].tobytes()
    # Close the figure. This runs on every slider change, and pyplot keeps
    # unclosed figures alive (and warns after 20).
    plt.close(fig)
    histogram_panel.children[img_index].value = img_byte_encoded

def brightness_contrast_observer(key,img_index):
    """
    Build the callback for the brightness, contrast and alpha sliders of image ``img_index``.

    ``key`` names the parameter the slider controls (``'brightness'``,
    ``'contrast'`` or ``'alpha'``). The new slider value is stored in the
    matching per-image dictionary, then the display is redrawn with
    ``display_img``. An unrecognised ``key`` only triggers the redraw.
    """
    def call_back(*args):
        new_value = args[0]['new']
        targets = {'brightness': brightness_dict,
                   'contrast': contrast_dict,
                   'alpha': alpha_dict}
        if key in targets:
            targets[key][img_index] = new_value
        display_img(img_index)

    return call_back


    
# NOTE: placeholder only; image_panel is rebound to an Accordion further down,
# before add_image() first runs.
image_panel=[]

# initialize image display with a blank 1024x1024 image.
# ndarray.tobytes replaces tostring, which is deprecated since NumPy 1.19.
img_display=widgets.Image(value=cv2.imencode('.png', np.zeros((1024,1024)))[1].tobytes(),
                          format='png',
                          width=500,
                          height=500)

# number of image slots created so far (incremented by add_image)
number_of_images=0
# visibility checkboxes, one per image slot
channel_list=[]
# NOTE(review): not read anywhere in this chunk; add_image rebuilds the
# collection list itself. Kept because it queries the backend at import time.
image_collection_list = [str(ic) for ic in WippImageCollection.all().values()]
def _slider_with_ticker(description, value, min_value, max_value, step, ticker_continuous_update):
    """Return a jslinked (FloatSlider, BoundedFloatText) pair for one display parameter."""
    slider = widgets.FloatSlider(value=value,
                                 min=min_value,
                                 max=max_value,
                                 step=step,
                                 description=description,
                                 continuous_update=False,
                                 orientation='horizontal',
                                 readout=False,
                                 layout=widgets.Layout(width='68%'),
                                 disabled=False)
    # BoundedFloatText (unlike FloatText) has min/max traits, so a typed value
    # stays inside the slider's range.
    ticker = widgets.BoundedFloatText(value=value,
                                      min=min_value,
                                      max=max_value,
                                      step=step,
                                      continuous_update=ticker_continuous_update,
                                      layout=widgets.Layout(width='20%'),
                                      disabled=False)
    widgets.jslink((ticker, 'value'), (slider, 'value'))
    return slider, ticker


def _range_slider(description):
    """Return a 0-255 integer range slider labelled ``description``."""
    return widgets.IntRangeSlider(value=[0, 255],
                                  min=0,
                                  max=255,
                                  step=1,
                                  description=description,
                                  disabled=False,
                                  continuous_update=False,
                                  orientation='horizontal',
                                  readout=True,
                                  readout_format='d')


def add_image(*args):
    """
    Append a new image slot to the UI. Also used as the ``Add Image`` button callback.

    Builds the slot's widgets and wires each one to its observer callback:

    * collection and image pickers;
    * brightness, contrast and alpha slider+ticker pairs;
    * intensity and threshold range sliders;
    * a colour picker and a visibility checkbox;
    * a metadata textbox and a histogram image.

    The widgets are then added to the accordion panels. The ``Add Image``
    button is disabled here; the image picker's callback re-enables it once
    an image is loaded.

    ``*args`` absorbs the button instance passed by ``Button.on_click``.
    """
    global number_of_images
    global channel_list
    add_image_button.disabled = True
    number_of_images += 1
    # index of the new slot; every observer below is bound to it
    idx = len(image_panel.children)

    # widget to list the image collections in the WIPP backend; later slots
    # may reuse the previous slot's collection
    collection_options = [str(ic) for ic in WippImageCollection.all().values()]
    if number_of_images > 1:
        collection_options = ['Same as previous collection'] + collection_options
    image_collections = widgets.Combobox(placeholder='Click on the box or start typing!',
                                         options=collection_options,
                                         description='Image Collections',
                                         ensure_option=True,
                                         disabled=False,
                                         layout=widgets.Layout(width='95%'))

    # widget to list images in the chosen image collection
    images = widgets.Combobox(placeholder='Select an image collection first',
                              options=[],
                              description='Images',
                              ensure_option=True,
                              disabled=True,
                              layout=widgets.Layout(width='95%'))

    # brightness, contrast and alpha controls (slider + linked ticker)
    brightness_slider, brightness_ticker = _slider_with_ticker(
        "Brightness:", value=0, min_value=0, max_value=1, step=0.01,
        ticker_continuous_update=False)
    contrast_slider, contrast_ticker = _slider_with_ticker(
        "Contrast:", value=1, min_value=0.1, max_value=5, step=0.1,
        ticker_continuous_update=True)
    trans_slider, trans_ticker = _slider_with_ticker(
        "alpha:", value=0.2, min_value=0, max_value=1, step=0.02,
        ticker_continuous_update=True)

    # textbox to display metadata
    metadata_textbox = widgets.Textarea(value='Waiting to load image..',
                                        placeholder='Type something',
                                        description='String:',
                                        disabled=False,
                                        layout=widgets.Layout(width='300px', height='500px'))

    # color picker to choose the image tint
    color = widgets.ColorPicker(concise=False,
                                description='Color:',
                                value='#ffffff',
                                layout=widgets.Layout(width='68%'),
                                disabled=False)

    view_image = widgets.Checkbox(value=True,
                                  description='Image {}'.format(idx + 1),
                                  disabled=False,
                                  indent=False,
                                  layout=widgets.Layout(object_positioning='right'))

    intensity_slider = _range_slider('Intensity:')
    threshold_slider = _range_slider('Threshold:')

    histogram_widget = widgets.Image(value=cv2.imencode('.png', np.zeros((1024, 1024)))[1].tobytes(),
                                     format='png',
                                     width=300,
                                     height=400)

    # arrange the corresponding sliders and tickers side by side
    brightness_box = widgets.HBox([brightness_slider, brightness_ticker])
    contrast_box = widgets.HBox([contrast_slider, contrast_ticker])
    trans_box = widgets.HBox([trans_slider, trans_ticker])

    # link the widgets with their corresponding call back functions
    image_collections.observe(image_collection_observer(image_collections, images, idx), 'value')
    images.observe(image_display_observer(idx, metadata_textbox, view_image, histogram_widget), 'value')
    brightness_slider.observe(brightness_contrast_observer('brightness', idx), 'value')
    contrast_slider.observe(brightness_contrast_observer('contrast', idx), 'value')
    trans_slider.observe(brightness_contrast_observer('alpha', idx), 'value')
    view_image.observe(image_checkbox_observer(idx), 'value')
    color.observe(parse_color(idx), 'value')
    intensity_slider.observe(intensity_observer(idx, 'intensity'), 'value')
    threshold_slider.observe(intensity_observer(idx, 'threshold'), 'value')

    # update the UI accordions
    channel_list.append(view_image)
    channel_panel.children = (widgets.VBox(channel_list),)
    select_image_panel.children = (widgets.VBox([image_collections, images, buttons]),)
    select_image_panel.set_title(0, 'Select Image {}'.format(idx + 1))

    image_panel.children = image_panel.children + (
        widgets.VBox([brightness_box, contrast_box, trans_box,
                      intensity_slider, threshold_slider, color]),)
    image_panel.set_title(idx, 'Img {}'.format(idx + 1))
    image_panel.selected_index = idx

    metadata_panel.children = metadata_panel.children + (metadata_textbox,)
    metadata_panel.set_title(idx, 'Image {} Metadata'.format(idx + 1))
    metadata_panel.selected_index = idx

    histogram_panel.children = histogram_panel.children + (histogram_widget,)
    histogram_panel.set_title(idx, 'Histogram :Image {}'.format(idx + 1))
    histogram_panel.selected_index = idx
    
    
def reset(*args):
    """
    Callback for the ``Reset`` button: discard every loaded image and start over.

    Steps:

    * clear all per-image state: image arrays, collection paths, display
      parameters and the composed ``final_img``;
    * empty the accordion panels;
    * blank the display widget and restore its initial 500x500 size;
    * create a fresh first image slot via ``add_image``.

    ``*args`` absorbs the button instance passed by ``Button.on_click``.
    """
    global number_of_images
    global channel_list
    global intensity_dictionary
    global img
    global final_img
    global color_array
    global brightness_dict
    global contrast_dict
    global alpha_dict
    global check_box_dict
    global threshold_dictionary
    global image_collection_path

    img = []
    # also drop the composed channel stack so no stale image data survives
    final_img = []
    image_collection_path = []
    color_array = {}
    brightness_dict = {}
    intensity_dictionary = {}
    contrast_dict = {}
    alpha_dict = {}
    check_box_dict = {}
    threshold_dictionary = {}
    number_of_images = 0
    channel_list = []

    select_image_panel.children = []

    image_panel.children = []
    metadata_panel.children = []
    histogram_panel.children = []

    channel_panel.children = []
    channel_panel.set_title(0, 'Image List')

    # blank the display and restore the size it was created with
    img_display.value = cv2.imencode('.png', np.zeros((1024, 1024)))[1].tobytes()
    img_display.width = 500
    img_display.height = 500

    add_image()
     

# Panels that add_image() fills with one entry per image slot.
select_image_panel = widgets.Accordion(children=[], description='Job inputs:')
image_panel = widgets.Accordion(children=[], description='Job inputs:',
                                layout=widgets.Layout(width='400px', overflow_x='auto'))
metadata_panel = widgets.Accordion(children=[], description='Job inputs:',
                                   layout=widgets.Layout(height='800px', overflow_y='auto'))
histogram_panel = widgets.Accordion(children=[], description='Job inputs:',
                                    layout=widgets.Layout(height='800px', overflow_y='auto'))
channel_panel = widgets.Accordion(children=[], description='Image List',
                                  layout=widgets.Layout(height='200px', overflow_y='auto'))
channel_panel.set_title(0, 'Image List')

# Right-hand side: histogram and metadata tabs.
right_panel = widgets.Tab(children=[histogram_panel, metadata_panel], description='Job inputs:')
right_panel.set_title(0, 'Histogram')
right_panel.set_title(1, 'Metadata')

# Buttons shown under the image picker.
add_image_button = widgets.Button(description='Add Image')
reset_button = widgets.Button(description='Reset')
reset_button.on_click(reset)
buttons = widgets.HBox([add_image_button, reset_button])

# Create the first slot, then let the button add further slots.
add_image()
add_image_button.on_click(add_image)
display(widgets.HBox([widgets.VBox([select_image_panel, image_panel, channel_panel]),
                      img_display,
                      right_panel]))









ModuleNotFoundError: No module named 'tifffile'