In [1]:
# Plotly 
import torch
import plotly
from plotly.offline import iplot as plt
from plotly import graph_objs as plt_type
from plotly import graph_objs as go
plotly.offline.init_notebook_mode(connected=True)


In [None]:
def defaultLayout(scale=1.):
    """Return a fresh base plotly layout dict.

    Body text uses Times New Roman at size 14 and titles at size 18; both sizes
    are multiplied by ``scale`` and truncated to int.
    """
    family = "Times New Roman"
    body_font = {'family': family, 'size': int(14 * scale)}
    title_font = {'family': family, 'size': int(18 * scale)}
    return {'font': body_font, 'titlefont': title_font}


def getTimestamp():
    """Return the current local time as a filename suffix, e.g. ``'_20240131T235959'``."""
    # Imported here: the notebook's `import time, datetime` cell runs *after* this
    # definition, so relying on it breaks if the cells are run in a different order.
    import datetime
    return '_' + datetime.datetime.now().strftime('%Y%m%dT%H%M%S')

In [None]:
import collections

def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct`` in place.

    Values from ``merge_dct`` that are not merged recursively are assigned by
    reference (not copied).

    :param dct: dict onto which the merge is executed (mutated)
    :param merge_dct: mapping merged into ``dct``
    :return: None
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Mapping

    for k, v in merge_dct.items():
        if k in dct and isinstance(dct[k], dict) and isinstance(v, Mapping):
            dict_merge(dct[k], v)
        else:
            dct[k] = v

In [2]:
# Create custom perceptually uniform colormap (called colorscale in plotly)
import colorcet
import numpy as np
def convertColorcetToPlotly(cur_cm, crop_low=None, crop_high=None, illum=1.0, colOrder=[0,1,2]):
    """Convert a list colormap (e.g. from ``colorcet``) into a plotly colorscale.

    :param cur_cm: sequence of RGB triplets with channels in [0, 1]
    :param crop_low: index of the first entry to keep (default 0)
    :param crop_high: one past the last entry to keep (default ``len(cur_cm)``)
    :param illum: brightness multiplier; scaled channels are clipped to [0, 255]
    :param colOrder: channel permutation, e.g. ``[1, 0, 2]`` swaps the first two channels
    :return: list of ``[position, 'rgb(r, g, b)']`` with positions running from 0 to 1
    """
    if crop_low is None:
        crop_low = 0
    if crop_high is None:
        crop_high = len(cur_cm)

    def to_byte(value):
        # Clip before truncating: with illum > 1 a channel can exceed 255, which
        # np.uint8 would wrap around (or warn about) instead of saturating.
        return int(min(max(value * 255 * illum, 0.), 255.))

    colorscale = []
    for i in range(crop_low, crop_high):
        r, g, b = (to_byte(cur_cm[i][c]) for c in colOrder)
        # Format explicitly: str() of a tuple of numpy scalars renders as
        # 'np.uint8(...)' under NumPy >= 2, which plotly cannot parse.
        colorscale.append([float(i - crop_low) / (crop_high - crop_low - 1),
                           'rgb({}, {}, {})'.format(r, g, b)])
    return colorscale

# Best linear green: convertColorcetToPlotly(colorcet.linear_ternary_green_0_46_c42, illum=2.0, crop_low=25)
# Best diverging bwr: convertColorcetToPlotly(colorcet.diverging_bwr_40_95_c42, illum=0.9)
# Best greyscale: convertColorcetToPlotly(colorcet.linear_grey_10_95_c0)

# Sinusoid test image
# import scipy.misc
# tmp = scipy.misc.imread('colourmaptest.tif')
# imagesc(tmp)

In [3]:
def setColorScale(colorscale='felfire'):
    """Translate a colour-scale keyword into a plotly colorscale.

    Recognised keywords (all built with convertColorcetToPlotly from colorcet maps):
        'div'     -- diverging_bwr_40_95_c42, illum 0.9
        'lin'     -- linear_ternary_green_0_46_c42, illum 2.0, first 25 entries dropped
        'grey'    -- linear_grey_10_95_c0
        'fire'    -- linear_kryw_5_100_c67
        'felfire' -- linear_kryw_5_100_c67 with the first two channels swapped, cut at entry 245
    Any other value (a plotly colorscale name or an explicit list) is returned as-is.
    """
    # Builders are lazy so only the requested map is converted; order mirrors the
    # keyword checks so comparisons happen in the same sequence.
    presets = (
        ('div', lambda: convertColorcetToPlotly(colorcet.diverging_bwr_40_95_c42, illum=0.9)),
        ('lin', lambda: convertColorcetToPlotly(colorcet.linear_ternary_green_0_46_c42, illum=2.0, crop_low=25)),
        ('grey', lambda: convertColorcetToPlotly(colorcet.linear_grey_10_95_c0)),
        ('fire', lambda: convertColorcetToPlotly(colorcet.linear_kryw_5_100_c67)),
        ('felfire', lambda: convertColorcetToPlotly(colorcet.linear_kryw_5_100_c67, colOrder=[1,0,2], crop_high=245)),
    )
    for keyword, build in presets:
        if colorscale == keyword:
            return build()
    return colorscale

In [10]:
# # Save colormaps to matlab use
# def convertColorcetToMatlab(cur_cm, crop_low=None, crop_high=None, illum=1.0, colOrder=[0,1,2]):
#     if crop_low is None:
#         crop_low = 0
        
#     if crop_high is None:
#         crop_high = len(cur_cm)
    
#     # For matlab we only need to keep the given values, transformed same as in Plotly, without *255
#     return [
#                      [ 
#                          cur_cm[i][colOrder[0]]*illum,
#                          cur_cm[i][colOrder[1]]*illum,
#                          cur_cm[i][colOrder[2]]*illum
#                       ]
#                    for i in range(crop_low, crop_high)]

# import scipy.io

# dictSaveToMatlab ={
#     'div': convertColorcetToMatlab(colorcet.diverging_bwr_40_95_c42, illum=0.9),
#     'lin': convertColorcetToMatlab(colorcet.linear_ternary_green_0_46_c42, illum=2.0, crop_low=25),
#     'gray': convertColorcetToMatlab(colorcet.linear_grey_10_95_c0),
#     'fire':  convertColorcetToMatlab(colorcet.linear_kryw_5_100_c67),
#     'felfire': convertColorcetToMatlab(colorcet.linear_kryw_5_100_c67, colOrder=[1,0,2], crop_high=245)
# }

# scipy.io.savemat('../Visualisation/matlab_colormaps', mdict=dictSaveToMatlab)




# Image visualisation

In [None]:
def createScalebar(imSize, 
                   pixels_per_micron=None, 
                   barSize = 50 # microns
                  ):
    """Build a white scale bar (line shape + text annotation) for an image heatmap.

    The bar ends at 90% of the image width and sits at 90% of the image height,
    in data (pixel) coordinates.

    :param imSize: image shape, indexed as (rows, cols)
    :param pixels_per_micron: image scale; when None no scale bar is produced
    :param barSize: bar length in microns
    :return: (shapes, annotations) lists for a plotly layout; both empty when
        ``pixels_per_micron`` is None
    """
    if pixels_per_micron is None:
        return [], []

    x_end = imSize[1]*0.9
    y_pos = imSize[0]*0.9
    scalebar_loc = {
        'x0': x_end - pixels_per_micron*barSize,
        'y0': y_pos,
        'x1': x_end,
        'y1': y_pos,
    }

    scalebar_shape = [{
        'type': 'line',
        'xref': 'x',
        'yref': 'y',
        **scalebar_loc,
        'line': {
            'color': 'rgb(255, 255, 255)',
            'width': 4,
        },
    }]

    # int() alone would label e.g. a 0.5 um bar as "0 um"; whole numbers keep the old form.
    size_value = float(barSize)
    label = str(int(size_value)) if size_value.is_integer() else '{:g}'.format(size_value)

    scalebar_text = [dict(
        x=(scalebar_loc['x0']+scalebar_loc['x1'])/2.,
        y=scalebar_loc['y0'],
        xref='x',
        yref='y',
        yanchor='top',
        text=label + ' &mu;m',
        showarrow=False,
        arrowhead=7,
        ax=0,
        ay=0,
        font=dict(
            color='rgb(255, 255, 255)',
            size=10
        )
    )]

    return scalebar_shape, scalebar_text

In [None]:
import time, datetime
def setDefaultImageExport(**kwargs):
    """Build export kwargs (filename and optional image settings) for heatmap figures.

    Recognised kwargs (anything else is ignored):
        time_stamp   -- suffix appended to the filename (default: getTimestamp())
        filename     -- base filename (default 'imagesc')
        image        -- export format, e.g. 'svg'; enables image export
        image_height -- export height in px when ``image`` is set (default 600)
        image_width  -- export width in px when ``image`` is set (default 800)
    :return: dict of keyword arguments for plotly's offline ``iplot``
    """
    # Only generate a timestamp when the caller did not supply one.
    time_stamp = kwargs['time_stamp'] if 'time_stamp' in kwargs else getTimestamp()

    imageExportArgs = {'filename': kwargs.get('filename', 'imagesc') + time_stamp}

    if 'image' in kwargs:
        imageExportArgs.update(
            image=kwargs['image'],
            image_height=kwargs.get('image_height', 600),
            image_width=kwargs.get('image_width', 800),
        )

    return imageExportArgs
        

In [None]:
def setDefaultImagePlot(imSize, **kwargs):
    """Default Heatmap and Layout kwargs for showing a 2-D image with square pixels.

    :param imSize: image shape, indexed as (rows, cols)
    :param kwargs: optional ``layout`` and ``heatmap`` dicts merged over the defaults.
        ``heatmap['colorscale']`` may be a setColorScale keyword or a plotly colorscale
        (default 'felfire'). The caller's dicts are not modified.
    :return: (heatmapArgs, layoutArgs)
    """
    layoutArgs = defaultLayout()
    dict_merge(layoutArgs, dict(
        autosize=False,
        height=600*0.8,
        width=800*0.8,
        # Plain dict: go.Margin is deprecated in newer plotly; Layout accepts a dict.
        margin=dict(l=50, r=50, b=50, t=50, pad=0, autoexpand=True),
        xaxis=dict(
            constrain='domain',
            position=0.,
            showgrid=False,
            zeroline=False,
            showline=False,
            ticks='',
            showticklabels=False,
        ),
        yaxis=dict(
            autorange='reversed',
            scaleanchor="x",
            scaleratio=1,
            constrain='domain',
            showgrid=False,
            zeroline=False,
            showline=False,
            ticks='',
            showticklabels=False,
        ),
    ))
    # Merge user layout first so the colorbar geometry below honours any
    # height/width/margin overrides (final layoutArgs is the same either way).
    dict_merge(layoutArgs, kwargs.get('layout', dict()))

    # Plot-area size in px: figure size minus margins.
    margin = layoutArgs['margin']
    fig_height = float(layoutArgs['height']) - float(margin['b']) - float(margin['t'])
    fig_width = float(layoutArgs['width']) - float(margin['l']) - float(margin['r'])

    # Copy so popping 'colorscale' does not mutate the caller's heatmap dict.
    heatmap_user = dict(kwargs.get('heatmap', dict()))
    colorscale = setColorScale(heatmap_user.pop('colorscale', 'felfire'))

    # Plot-area aspect (h/w) times image aspect (cols/rows): > 1 means the image is
    # height-limited. Colorbar length/position track the displayed image extent.
    aspect = fig_height/fig_width*float(imSize[1])/imSize[0]
    heatmapArgs = dict(
        colorscale=colorscale,
        colorbar=dict(
            len=min(1., 1./aspect),
            xanchor='left',
            x=min(1., 1.-(1.-aspect)/2.),
            xpad=15.,
            title='a.u.',
            titleside='top',
        )
    )

    dict_merge(heatmapArgs, heatmap_user)

    return heatmapArgs, layoutArgs

In [None]:
def imagesc(im, title='', pixels_per_micron=None, now=True, **kwargs):
    """
    Show a 2-D image as a plotly heatmap with reasonable, fully overridable defaults.

    :param im: 2-D image; anything with a ``.shape`` (torch tensor or numpy array)
    :param title: figure title
    :param pixels_per_micron: if set, a white scale bar of ``barSize`` microns is drawn
    :param now: display (and optionally save) immediately if True, else return the Figure
    Optional kwargs:
        layout = dict()       -- merged into the plotly layout
        heatmap = dict()      -- merged into the Heatmap trace ('colorscale' accepts setColorScale keywords)
        image = 'svg'         -- export format; enables image export
        filename = 'imagesc'  -- base export filename (a timestamp is appended)
        time_stamp            -- overrides the appended timestamp
        saveHtml              -- also write savedHtml/<filename>.html (default: True when ``image`` is given)
        barSize = 50.         -- scale-bar length in microns
    :return: the Figure when ``now`` is False, otherwise None
    """
    # .shape works for both torch tensors and numpy arrays (.size() is torch-only).
    im_size = im.shape

    # Get default settings
    heatmapArgs, layoutArgs = setDefaultImagePlot(im_size, **kwargs)
    imageExportArgs = setDefaultImageExport(**kwargs)
    saveHtml = kwargs.pop('saveHtml', 'image' in imageExportArgs)
    layoutArgs['title'] = title

    # The main image data
    data = [plt_type.Heatmap(z=im, **heatmapArgs)]

    # Add a scale bar if the scale is given
    scalebar_shape, scalebar_text = createScalebar(im_size, pixels_per_micron, kwargs.pop('barSize', 50.))

    fig = plt_type.Figure(
        data=data,
        layout=plt_type.Layout(
            **layoutArgs,
            shapes=list(scalebar_shape),
            annotations=list(scalebar_text),
        )
    )

    if not now:
        return fig

    # Save html if prompted; create the output folder so the write cannot fail on a fresh checkout.
    if saveHtml:
        import os
        os.makedirs('savedHtml', exist_ok=True)
        plotly.offline.plot(fig, filename='savedHtml/'+imageExportArgs['filename']+'.html', auto_open=False)

    plt(fig, **imageExportArgs)
    

In [None]:
torch.manual_seed(0)  # reproducible demo image
im = torch.randn(245, 245)

# Scale bar of 20 um, assuming 2 pixels per micron.
imagesc(im, pixels_per_micron=2., barSize=20.)

# Trace Visualisation

In [None]:
import time, datetime
def setDefaultPlotImageExport(**kwargs):
    """Build export kwargs (filename and optional image settings) for line-plot figures.

    Recognised kwargs (anything else is ignored):
        time_stamp   -- suffix appended to the filename (default: getTimestamp())
        filename     -- base filename (default 'plot')
        image        -- export format, e.g. 'svg'; enables image export
        image_height -- export height in px when ``image`` is set (default 600)
        image_width  -- export width in px when ``image`` is set (default 1000)
    :return: dict of keyword arguments for plotly's offline ``iplot``
    """
    # Only generate a timestamp when the caller did not supply one.
    time_stamp = kwargs['time_stamp'] if 'time_stamp' in kwargs else getTimestamp()

    imageExportArgs = {'filename': kwargs.get('filename', 'plot') + time_stamp}

    if 'image' in kwargs:
        imageExportArgs.update(
            image=kwargs['image'],
            image_height=kwargs.get('image_height', 600),
            image_width=kwargs.get('image_width', 1000),
        )

    return imageExportArgs

In [None]:
def exportFigure(fig, type='plot', **kwargs):
    """Display ``fig`` with export settings and optionally save an HTML copy.

    :param fig: plotly figure (or figure dict)
    :param type: 'plot' for line-plot export defaults, 'image' for heatmap defaults
    :param kwargs: forwarded to the export-defaults helper (filename, image, time_stamp, ...);
        ``saveHtml`` (default: True when ``image`` is given) also writes savedHtml/<filename>.html
    :raises ValueError: if ``type`` is neither 'plot' nor 'image'
    """
    if type == 'plot':
        imageExportArgs = setDefaultPlotImageExport(**kwargs)
    elif type == 'image':
        imageExportArgs = setDefaultImageExport(**kwargs)
    else:
        # Without this branch an unknown type surfaced as an UnboundLocalError below.
        raise ValueError("exportFigure: type must be 'plot' or 'image', got {!r}".format(type))
    saveHtml = kwargs.pop('saveHtml', 'image' in imageExportArgs)

    if saveHtml:
        import os
        # Create the output folder so the write cannot fail on a fresh checkout.
        os.makedirs('savedHtml', exist_ok=True)
        plotly.offline.plot(fig, filename='savedHtml/'+imageExportArgs['filename']+'.html', auto_open=False)

    plt(fig, **imageExportArgs)

In [None]:
def plot(Y, X=None, now = True, **kwargs):
    """
    Line plot of every column of ``Y``.

    :param Y: 1-D tensor (one line) or 2-D tensor (n x num_lines)
    :param X: x values; 1-D (shared by all lines) or 2-D with one column per line.
        Defaults to ``torch.arange(n)``.
    :param now: display immediately if True, otherwise return the list of traces
    :param kwargs: forwarded unchanged to every ``Scatter`` trace
    """
    if Y.ndimension() == 1:
        Y = Y.unsqueeze(-1)
    if X is None:
        X = torch.arange(Y.shape[0])

    traces = [
        plt_type.Scatter(
            x=X[:, col] if X.ndimension() == 2 else X,
            y=Y[:, col],
            **kwargs
        )
        for col in range(Y.shape[1])
    ]

    if not now:
        return traces
    plt(traces)

In [None]:
def setDefaultStackedPlot(n_traces, **kwargs):
    """Default per-trace and layout kwargs for plotStacked.

    :param n_traces: number of stacked traces
    :param kwargs: optional settings
        tracename  -- label prefix (default 'n_photon'); the last trace is named '<prefix> > n-2'
        colorscale -- setColorScale keyword or plotly colorscale list (default 'felfire')
        layout     -- dict merged into the layout
        data_all   -- dict merged into every trace
        data       -- {trace_index: dict} merged into individual traces
    :return: (dataArgs, layoutArgs); dataArgs maps trace index -> Scatter kwargs
    """
    import copy

    tracename = kwargs.pop('tracename', 'n_photon')
    colorscale = setColorScale(kwargs.pop('colorscale', 'felfire'))
    # Sample n_traces evenly spaced entries of the colorscale.
    cur_colors = torch.linspace(0., len(colorscale)-1, n_traces)

    dataArgs = dict()
    for num_line in range(n_traces):
        if num_line < n_traces - 1:
            name = tracename + ' = ' + str(num_line)
        else:
            # Last trace is labelled as "greater than" the previous index.
            name = tracename + ' > ' + str(num_line - 1)
        dataArgs[num_line] = dict(
            name=name,
            fill='tonexty',
            line=dict(color=colorscale[int(cur_colors[num_line])][1]),
        )

    layoutArgs = defaultLayout()
    dict_merge(layoutArgs, dict(
        autosize=False,
        height=600*0.8,
        width=1000*0.8,
        # Plain dict: go.Margin is deprecated in newer plotly; Layout accepts a dict.
        margin=dict(l=50, r=50, b=50, t=50, pad=0, autoexpand=True),
        xaxis=dict(
            title='Grey level in data',
            constrain='domain',
            position=0.,
            showgrid=False,
            zeroline=True,
            showline=True,
        ),
        yaxis=dict(
            title='Cumulative probability of photon count',
            constrain='domain',
            showgrid=False,
            zeroline=True,
            showline=False,
        )
    ))

    # Merge the potential inputs
    dict_merge(layoutArgs, kwargs.get('layout', dict()))

    # Merge 'data_all' into every trace. Iterate over the trace dicts (values), not
    # the int keys, and deep-copy so traces never share nested dicts that a later
    # per-trace 'data' merge would then modify for all traces (and the caller).
    data_all = kwargs.get('data_all', dict())
    for traceArgs in dataArgs.values():
        dict_merge(traceArgs, copy.deepcopy(data_all))

    # Merge the per-trace settings into their indices
    dict_merge(dataArgs, kwargs.get('data', dict()))

    return dataArgs, layoutArgs
    

In [None]:
import copy
def plotStacked(Y, X=None, now=True, **kwargs):
    """Stacked (cumulative) area plot of the columns of ``Y``.

    Trace ``k`` shows ``Y[:, 0] + ... + Y[:, k]`` and uses the styling produced by
    setDefaultStackedPlot (fill 'tonexty').

    :param Y: 2-D tensor (n x n_traces)
    :param X: x values; 1-D shared, or 2-D with one column per trace. Defaults to ``torch.arange(n)``.
    :param now: display immediately if True, otherwise return the Figure
    :param kwargs: forwarded to setDefaultStackedPlot (tracename, colorscale, layout, data_all, data)
    """
    if X is None:
        X = torch.arange(Y.shape[0])

    dataArgs, layoutArgs = setDefaultStackedPlot(Y.size(1), **kwargs)

    traces = []
    running_total = torch.zeros_like(Y[:, 0])
    for col in range(Y.shape[1]):
        # Out-of-place add: every trace keeps its own tensor, so no copy is needed.
        running_total = running_total + Y[:, col]
        traces.append(plt_type.Scatter(
            x=X[:, col] if X.ndimension() == 2 else X,
            y=running_total,
            **(dataArgs[col])
        ))

    fig = plt_type.Figure(data=traces, layout=plt_type.Layout(**layoutArgs))

    if not now:
        return fig
    plt(fig)

In [None]:
torch.manual_seed(0)  # reproducible demo data
a = torch.rand(500, 4)  # 500 samples x 4 traces
plotStacked(a)

In [None]:
# Display the random (500, 4) input of the plotStacked demo (torch summarises large tensors).
a

In [None]:


# scalebar_loc = {
#     'x0': im.size(1)*0.9-pixels_per_micron*50, # Add a 50 um bar
#     'y0': im.size(0)*0.9,
#     'x1': im.size(1)*0.9,
#     'y1': im.size(0)*0.9,
# }

# scalebar_shape = {
#     'type': 'line',
#     'xref': 'x',
#     'yref': 'y',
#     **scalebar_loc,
#     'line': {
#         'color': 'rgb(255, 255, 255)',
#         'width': 4,
#     },
# }

# scalebar_text= dict(
#             x=(scalebar_loc['x0']+scalebar_loc['x1'])/2.,
#             y=scalebar_loc['y0'],
#             xref='x',
#             yref='y',
#             yanchor='top',
#             text='50 &mu;m',
#             showarrow=False,
#             arrowhead=7,
#             ax=0,
#             ay=0,
#             font = dict(
#                 color='rgb(255, 255, 255)',
#                 size=9
#             )
#         )

# plt(plt_type.Figure(
#     data= [plt_type.Heatmap(z=im,
#         colorbar=dict(
#             len=min(1.0, im.size(0)/im.size(1))
#         ))
#           ], 
#     layout=plt_type.Layout(
#         autosize=False,
#         width=500,
#         height=500,
#         shapes=[scalebar_shape],
#         annotations=[scalebar_text],
#         margin=go.Margin(
#             l=50,
#             r=50,
#             b=50,
#             t=50,
#             pad=4),
#         xaxis = dict(
#             showgrid=False,
#             zeroline=False,
#             showline=False,
#             ticks='',
#             showticklabels=False
#         ),
#         yaxis = dict(
#           autorange='reversed',
#           scaleanchor="x", scaleratio=1,
#             showgrid=False,
#             zeroline=False,
#             showline=False,
#             ticks='',
#             showticklabels=False
#         )
#     )
# )
#    )

In [None]:
import numpy as np
import copy

def croppedHist(inp, bins=50, cropLow = 0., cropHigh = None, extremeVal = 1e12):
    """Histogram over [cropLow, cropHigh] whose outer bins also absorb out-of-range values.

    :param inp: data to histogram (anything ``np.histogram`` accepts; needs .min()/.max() when a crop is None)
    :param bins: number of bins
    :param cropLow: lower edge; None uses ``inp.min()`` and leaves that edge un-extended
    :param cropHigh: upper edge; None uses ``inp.max()`` and leaves that edge un-extended
    :param extremeVal: a cropped outer edge is replaced by -/+extremeVal, so values between
        -extremeVal and cropLow (or cropHigh and extremeVal) are counted in the first (last) bin
    :return: (counts, nominal bin edges for display, extended edges actually used for counting)
    """
    low = inp.min() if cropLow is None else cropLow
    high = inp.max() if cropHigh is None else cropHigh
    edges = np.linspace(low, high, bins + 1)

    counting_edges = edges.copy()
    if cropLow is not None:
        counting_edges[0] = -extremeVal
    if cropHigh is not None:
        counting_edges[-1] = extremeVal

    counts, _ = np.histogram(inp, counting_edges, range=(counting_edges[0], counting_edges[-1]))
    return counts, edges, counting_edges

### Hexagonal binning

In [None]:
# Implementation from https://github.com/plotly/plotly.js/issues/1574
import plotly.offline as offline
import plotly.graph_objs as go
import numpy as np
import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot
from matplotlib.colors import Normalize


def scatterHex(x, y, gridsize=30, logCounts = True, now=True):
    """Hexagonal-binning scatter plot drawn with plotly.

    matplotlib's ``hexbin`` does the binning. Every hexagon becomes a filled SVG path
    in ``layout.shapes``, and a size-0 Scattergl marker at its centre carries the
    hover text and the colorbar.

    :param x, y: coordinates; need ``.min()``/``.max()`` (e.g. numpy arrays), presumably equal length
    :param gridsize: number of hexagons in the x direction
    :param logCounts: colour by ``log10(count + 1) - 1`` (-1 marks an empty hexagon) instead of raw counts
    :param now: display immediately if True, otherwise return the figure dict
    :return: the figure dict when ``now`` is False, otherwise None
    """

    def compute_hexbin(x, y, gridsize=100, bins=None, cmap=matplotlib.pyplot.cm.Blues, logCounts=True):
        """Return (vertex xs, vertex ys, fill colours, per-hexagon values, hexagon centres)."""
        collection = matplotlib.pyplot.hexbin(x, y, bins=bins, gridsize=gridsize)
        matplotlib.pyplot.close()

        pts_in_hexagon = collection.get_array()
        if logCounts:
            pts_in_hexagon = np.log10(pts_in_hexagon + 1.) - 1

        # Fill colours for the SVG shapes.
        colors = ["#%02x%02x%02x" % (int(r), int(g), int(b))
                  for r, g, b, _ in 255*cmap(Normalize()(pts_in_hexagon))]

        # Vertices of a unit hexagon, in order around its perimeter.
        hx = [0, .5, .5, 0, -.5, -.5]
        hy = [-.5/np.cos(np.pi/6), -.5*np.tan(np.pi/6), .5*np.tan(np.pi/6),
              .5/np.cos(np.pi/6), .5*np.tan(np.pi/6), -.5*np.tan(np.pi/6)]

        centres = collection.get_offsets()
        m = len(centres)
        # Hexagon width in data units, and the y stretch for the data aspect ratio.
        n = (x.max() - x.min()) / gridsize
        y_scale = (y.max() - y.min())/(x.max() - x.min())

        hxs = np.array([hx]*m)*n + np.vstack(centres[:, 0])
        hys = np.array([hy]*m)*n*y_scale + np.vstack(centres[:, 1])

        return hxs.tolist(), hys.tolist(), colors, pts_in_hexagon, centres

    hex_xs, hex_ys, color_list, pts_in_hexagon, centres = compute_hexbin(
        x, y, gridsize=gridsize, logCounts=logCounts)

    shape_container = []
    for xs, ys, color in zip(hex_xs, hex_ys, color_list):
        # Path through all six vertices; 'Z' closes it so the outline also draws the last edge.
        svg_path = 'M {},{} '.format(xs[0], ys[0]) \
            + ' '.join('L {},{}'.format(px, py) for px, py in zip(xs[1:], ys[1:])) + ' Z'
        shape_container.append({
            "fillcolor": color,
            "line": {
                "color": color,
                "width": 1.5
            },
            "path": svg_path,
            "type": "path"
        })

    # Hover markers at the exact hexagon centres; rounding to 2 decimals would
    # collapse markers for data spanning a small range.
    hover_point_x = [float(c) for c in centres[:, 0]]
    hover_point_y = [float(c) for c in centres[:, 1]]

    if logCounts:
        counts = 10.**(pts_in_hexagon + 1) - 1.
        colorbar_title = "log10(# of points)<br>(-1 is no data)"
    else:
        counts = pts_in_hexagon
        colorbar_title = " # of points"

    trace = go.Scattergl(x=hover_point_x,
                         y=hover_point_y,
                         mode='markers')
    trace['marker']['colorbar'] = {"title": colorbar_title}
    trace['marker']['color'] = pts_in_hexagon
    # round(), not int(): 10**log10(c + 1) - 1 can come back as e.g. 2.9999999999 for c == 3.
    trace['text'] = ['Number of points: {}'.format(int(round(c))) for c in counts]
    trace['marker']['reversescale'] = True
    trace['marker']['colorscale'] = 'Blues'
    trace['marker']['size'] = 0

    layout = {'shapes': shape_container,
              'width': 850,
              'height': 700,
              'hovermode': 'closest'}

    fig = dict(data=[trace], layout=layout)

    if now:
        plt(fig)
    else:
        return fig

In [None]:
# import torch
# a = plot(torch.randn(50,3), now=False)

In [None]:
quant_funcs = torch.sort(torch.rand(300, 4), dim=0)[0]  # 4 random columns, each sorted ascending

In [None]:
numInvBins = 11  # number of evenly spaced levels in [0, 1]
inv_bins = torch.linspace(start=0., end=1., steps=numInvBins)

In [None]:
tmp = quant_funcs[:, :, None] - inv_bins[None, None, :]  # shape (300, 4, numInvBins): sample minus level

In [None]:
# For every level and column: index of the sample closest to that level -> shape (numInvBins, 4).
invQuantFuncY = (quant_funcs[:, :, None] - inv_bins[None, None, :]).abs().argmin(dim=0).t()

In [None]:
plot(invQuantFuncY, inv_bins)  # one line per column of invQuantFuncY, x = inv_bins

In [None]:
tmp.argmin(dim=0)  # NOTE(review): signed argmin (no .abs()), unlike invQuantFuncY; looks like scratch