In [1]:
import colorsys
import imageio
import math
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft2, ifft2, fftshift
import PIL
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from skimage import io

def bldclrwhl(ny, nx, sym):
    """Build an angular colour wheel as a uint8 RGB array of shape (ny, nx, 3).

    Every pixel is coloured by the angle of its offset from the image centre
    (nx/2, ny/2). The angle is mapped to a hue in [0, 1] and multiplied by
    ``sym``, so the hue cycle repeats ``sym`` times around the wheel. It is then
    converted to RGB at full saturation and value, using the same arithmetic
    as ``colorsys.hsv_to_rgb`` and the same round-half-to-even 8-bit rounding.

    Parameters
    ----------
    ny, nx : int
        Height and width of the wheel in pixels.
    sym : int or float
        Number of hue cycles per full turn (rotational symmetry of the colouring).

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (ny, nx, 3).
    """
    # Offsets from the centre for every pixel at once (rows are y, columns are x).
    ys, xs = np.mgrid[0:ny, 0:nx]
    rx = xs - nx / 2
    ry = ys - ny / 2

    hue = ((np.arctan2(ry, rx) / np.pi) + 1.0) / 2.0
    hue = hue * sym       # symmetry

    # Vectorised colorsys.hsv_to_rgb(hue, 1, 1.0). The statements mirror its
    # formulas (including t = 1 - (1 - f)), so the rounded values match the
    # per-pixel implementation.
    h6 = hue * 6.0
    sector = h6.astype(np.int64)   # hue >= 0, so truncation equals int()
    f = h6 - sector
    v = np.ones_like(f)
    p = np.zeros_like(f)
    q = 1.0 - f
    t = 1.0 - (1.0 - f)
    sector = sector % 6            # hues >= 1 wrap around, as in colorsys

    red = np.choose(sector, [v, q, p, p, t, v])
    green = np.choose(sector, [t, v, v, q, p, p])
    blue = np.choose(sector, [p, p, t, v, v, q])
    rgb = np.stack([red, green, blue], axis=-1)
    return np.round(rgb * 255.0).astype(np.uint8)

def nofft(whl, img, nx, ny):
    """Colour an image by the orientation of its Fourier components.

    The 2-D FFT of ``img`` is split into magnitude and phase. For each colour
    channel, the magnitude is weighted by the colour wheel ``whl``, whose
    centre is first moved to the zero-frequency bin with ``fftshift``. The
    weighted magnitude is recombined with the unchanged phase and
    inverse-transformed. The real part of each channel is then min-max
    normalised to [0, 1].

    Parameters
    ----------
    whl : array_like, shape (rows, cols, channels)
        Colour wheel with the same spatial shape as ``img`` (e.g. from ``bldclrwhl``).
    img : array_like, shape (rows, cols)
        Greyscale image.
    nx, ny : int
        Rows and columns of ``img``. Not used by the computation; kept so
        existing callers keep working.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (rows, cols, channels). A channel whose values
        are all equal (e.g. from an all-black image) is returned as zeros
        instead of NaN.
    """
    imnp = np.array(img)
    fimg = fft2(imnp)
    # Shift only the two spatial axes. Shifting the colour axis as well would
    # roll R/G/B, and the output colours would no longer match the wheel.
    whl = fftshift(np.asarray(whl), axes=(0, 1))
    magnitude = np.abs(fimg)[:, :, np.newaxis]
    phase = np.angle(fimg)[:, :, np.newaxis]
    comb = (whl * magnitude) * np.exp(1j * phase)

    proimg = np.zeros(comb.shape, dtype=float)
    for n in range(comb.shape[2]):
        channel = np.real(ifft2(comb[:, :, n]))
        channel = channel - np.min(channel)
        span = np.max(channel)
        # A flat channel has zero span; keep it at 0 instead of computing 0/0.
        if span > 0:
            channel = channel / span
        proimg[:, :, n] = channel
    return proimg

In [41]:
import base64
import io as file
import json
import time

import dash
from dash.dependencies import Input, Output, State, ALL
from dash.exceptions import PreventUpdate
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html
from flask_caching import Cache
from jupyter_dash import JupyterDash
from PIL import Image
import plotly
import plotly.express as px
from urllib.parse import quote as urlquote

# App
# Create the Dash app through JupyterDash so it can be launched from this notebook,
# styled with the Bootstrap theme plus the MLExchange stylesheet.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, "../assets/mlex-style.css"])

def header():
    """Build the sticky dark navigation bar.

    The bar shows the MLExchange logo, the app title and a navbar toggler.
    """
    logo_col = dbc.Col(
        html.Img(id="logo", src='assets/mlex.png', height="60px"),
        md="auto",
    )
    title_col = dbc.Col(
        [html.Div(children=html.H3("MLExchange | Colorwheel Orientation"), id="app-title")],
        md=True,
        align="center",
    )
    brand_row = dbc.Row([logo_col, title_col], align="center")
    toggler_row = dbc.Row(
        [dbc.Col([dbc.NavbarToggler(id="navbar-toggler")], md=2)],
        align="center",
    )
    return dbc.Navbar(
        dbc.Container([brand_row, toggler_row], fluid=True),
        dark=True,
        color="dark",
        sticky="top",
    )

def _tooltip_slider(label, slider_id, max_value, value, min_value=0, step=0.1):
    """Return a caption Div followed by a mouse-up slider with an always-visible tooltip."""
    return [
        html.Div(children=label),
        dcc.Slider(
            id=slider_id,
            min=min_value,
            max=max_value,
            step=step,
            value=value,
            tooltip={"placement": "bottom", "always_visible": True},
            updatemode='mouseup',
        ),
    ]


# Sidebar card holding every processing parameter; each slider id is an Input of update_figure.
sidebar = dbc.Card(
    id='slidebar',
    children=[
        # 'me-2' is the Bootstrap 5 spelling of the old 'mr-2' (the file uses BS5's 'ms-auto').
        dbc.CardHeader(dbc.Label('Parameters', className='me-2')),
        dbc.CardBody(
            children=[
                html.Div(children='Symmetry'),
                dcc.Slider(
                    id='symmetry-slider',
                    min=1,
                    max=12,
                    step=1,
                    value=6,
                    updatemode='mouseup',
                    # Marks only for the selectable values 1..12; a "0" mark would sit outside min=1.
                    marks={str(n): str(n) for n in range(1, 13)},
                ),
                *_tooltip_slider('Color Saturation', 'color-slider', max_value=20, value=1),
                *_tooltip_slider('Brightness', 'bright-slider', max_value=4, value=1),
                *_tooltip_slider('Contrast', 'contrast-slider', max_value=10, value=1),
                *_tooltip_slider('Blur', 'blur-slider', max_value=10, value=0),
                *_tooltip_slider('Overlap', 'overlap', max_value=1, value=1.0),
            ],
        ),
    ],
)

# Main panel: upload area, the container that image_upload fills with the graphs,
# the loading wrapper that receives the download link, and hidden client-side stores.
content = html.Div(
    [
        dcc.Upload(
            id='upload-data',
            children=html.Div(['Drag and Drop or ',
                               html.A('Select Files')
                               ]),
            style={
                'width': '100%',
                'height': '60px',
                'lineHeight': '60px',
                'borderWidth': '1px',
                'borderStyle': 'dashed',
                'borderRadius': '5px',
                'textAlign': 'center',
                'margin': '10px'
            },
            # Allow multiple files to be uploaded
            multiple=True
        ),
        html.Div(id='graph'),
        dcc.Loading(id='download'),
        html.Div(id='no-display',
                 children=[
                     dcc.Store(id='image-store', data={}),      # filename -> (content_type, base64 payload)
                     dcc.Store(id='cu_sym', data=-1),           # symmetry of the current colour wheel
                     dcc.Store(id='temp-img', data={}),         # overlay image before enhancement
                     dcc.Store(id='clrwhl', data=[]),           # colour wheel as a JSON string
                     dcc.Store(id='original_image', data=[]),   # not read by any callback in this notebook
                     dcc.Store(id='rgb', data=[])               # not read by any callback in this notebook
                 ])
    ],
    # camelCase keys, matching the Upload style above (React style convention).
    style={'marginTop': '1rem', 'marginRight': '1rem'},
)

# Page layout: navbar on top, then the sidebar (3/12 width) next to the main panel (9/12).
app.layout = html.Div(
    children=[
        header(),
        dbc.Row([dbc.Col(sidebar, width=3),
                 dbc.Col(content, width=9)]),
    ],
    # Same 3rem on all four sides; the shorthand avoids the hyphenated React style keys.
    style={'margin': '3rem'},
)


# Decodes one base64-encoded image file and returns it as a numpy array
def read_img_cache(image_cache):
    """Decode a base64-encoded image file into a numpy array.

    Parameters
    ----------
    image_cache : str or bytes
        Base64 payload of an image file, i.e. the part of a data URI after the
        comma. Callers split ``'data:<mime>;base64,<payload>'`` on ',' first.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the decoded PIL image (2-D for single-band images,
        (H, W, bands) otherwise).
    """
    img_bytes = base64.b64decode(image_cache)
    # The context manager releases the PIL image once the pixels are copied out.
    with PIL.Image.open(file.BytesIO(img_bytes)) as im:
        return np.array(im)


# Returns the figure
def make_figure(image_npy, clrwhl=None):
    """Wrap an image in a plotly ``imshow`` figure.

    Parameters
    ----------
    image_npy : array_like or PIL image
        Image to display (this notebook passes numpy arrays and PIL images).
    clrwhl : bool, optional
        Truthy when ``image_npy`` is the colour-wheel legend. The figure then
        gets explicit axis ranges, five labelled ticks per axis, zero margins
        and a height of 200. Otherwise tick labels are hidden.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = px.imshow(image_npy)
    axis_style = dict(showgrid=False, zeroline=False)
    if clrwhl:
        # Only the legend needs the pixel dimensions, so convert to an array only here.
        height, width = np.array(image_npy).shape[0:2]
        fig.update_xaxes(
            range=(0, width),
            showticklabels=True,
            tickvals=np.linspace(start=0, stop=width, num=5),
            ticktext=np.linspace(start=225, stop=315, num=5),
            **axis_style,
        )
        fig.update_yaxes(
            range=(height, 0),   # reversed so row 0 stays at the top
            showticklabels=True,
            tickvals=np.linspace(start=0, stop=height, num=5),
            ticktext=np.linspace(start=135, stop=225, num=5),
            **axis_style,
        )
        fig.update_layout(margin=dict(l=0, r=0, t=0, b=0),
                          height=200)
    else:
        fig.update_xaxes(showticklabels=False, **axis_style)
        fig.update_yaxes(showticklabels=False, **axis_style)
    return fig


# Define callback to upload image(s)
@app.callback(
    Output('image-store', 'data'),
    Output('graph', 'children'),
    Input('upload-data', 'contents'),
    Input('upload-data', 'filename')
)
def image_upload(upload_image_contents, upload_image_filename):
    """Store the uploaded files and create the (initially hidden) viewer panels.

    Each upload arrives as a data URI. It is split on ',' into
    ``(content_type, base64_payload)`` and stored under its filename.

    Returns the store dict and two hidden Divs:
    - the main graph plus an image-selection slider with one step per uploaded file;
    - the colour-wheel legend graph plus a SAVE button.

    ``update_figure`` later sets their styles so they are displayed.
    """
    if upload_image_contents is None:
        raise PreventUpdate

    image_store_data = {}
    for data_uri, filename in zip(upload_image_contents, upload_image_filename):
        mime_part, payload = data_uri.split(',')
        image_store_data[filename] = (mime_part, payload)
    last_index = len(upload_image_filename) - 1

    main_panel = html.Div(
        id={'type': 'contents', 'index': 0},
        children=[
            dcc.Graph(id={'type': 'graph', 'index': 0},
                      config={'displayModeBar': False}),
            dcc.Slider(id={'type': 'image-slider', 'index': 0},
                       min=0,
                       max=last_index,
                       value=0,
                       updatemode='mouseup',
                       tooltip={"placement": "bottom", "always_visible": True}),
        ],
        style={'display': 'none'},
    )
    legend_panel = html.Div(
        id={'type': 'contents', 'index': 1},
        children=[
            dcc.Graph(id={'type': 'graph', 'index': 1},
                      config={'displayModeBar': False},
                      style={'margin-bottom': '1rem'}),
            dbc.Button("SAVE",
                       id={'type': 'save-data', 'index': 0},
                       className="ms-auto",
                       n_clicks=0,
                       style={'width': '95%'}),
        ],
        style={'display': 'none'},
    )
    return image_store_data, [main_panel, legend_panel]


def _pick_image_payload(image_store_data, slider_value):
    """Return the base64 payload of the image selected by the image slider.

    ``image_store_data`` maps filename -> (content_type, base64_payload), as
    written by ``image_upload``. The first image is used when:
    - the slider value list is empty;
    - its value is None;
    - the value is out of range (e.g. left over from a previous, larger upload).
    """
    names = list(image_store_data.keys())
    index = slider_value[0] if slider_value else None
    if index is None or not 0 <= index < len(names):
        index = 0
    return image_store_data[names[index]][1]


def _split_channels(original_image):
    """Return ``(display_base, intensity)`` for a decoded image array.

    For images with 3 or more channels:
    - ``display_base`` is the RGB part (alpha dropped);
    - ``intensity`` is its first channel.

    For greyscale images (2-D, or greyscale + alpha):
    - ``display_base`` is the grey image with a trailing axis of length 1;
    - ``intensity`` is the grey image itself.
    """
    if original_image.ndim == 3 and original_image.shape[2] >= 3:
        rgb = original_image[:, :, :3]          # If 4 channels, convert to 3
        return rgb, rgb[:, :, 0]
    if original_image.ndim == 3:                # e.g. greyscale + alpha
        original_image = original_image[:, :, 0]
    return original_image[:, :, np.newaxis], original_image


def _enhance(image, blur_val, enh_val, bright_val, contra_val):
    """Apply a Gaussian blur, then colour, brightness and contrast enhancement (in that order)."""
    img = image.filter(ImageFilter.GaussianBlur(radius=blur_val))
    img = PIL.ImageEnhance.Color(img).enhance(enh_val)
    img = PIL.ImageEnhance.Brightness(img).enhance(bright_val)
    return PIL.ImageEnhance.Contrast(img).enhance(contra_val)


# Define callback to update graph
@app.callback(
    [
        Output({'type': 'graph', 'index': ALL}, 'figure'),
        Output({'type': 'contents', 'index': ALL}, 'style'),
        Output('temp-img', 'data'),
        Output('clrwhl', 'data'),
        Output('cu_sym', 'data'),
    ],
    [
        Input('symmetry-slider', 'value'),
        Input('color-slider', 'value'),
        Input('bright-slider', 'value'), 
        Input('contrast-slider', 'value'), 
        Input('blur-slider', 'value'),
        Input('overlap', 'value'), 
        Input('image-store', 'data'),
        Input({'type': 'image-slider', 'index': ALL}, 'value'),
        State('cu_sym', 'data'),
        State('temp-img', 'data'),
        State('clrwhl', 'data')
    ]
)
def update_figure(symmetry, enh_val, bright_val, contra_val, blur_val, overlap, image_store_data, slider_value, 
                  cu_sym, rgb2, clrwhl):
    """Rebuild the main figure and, when needed, the colour-wheel legend.

    Behaviour depends on which input triggered the callback:
    - Symmetry, upload or image-slider changes: decode the selected image and
      rebuild the colour wheel. Recolour the image with ``nofft`` and blend it
      with the original by ``overlap``.
    - Overlap changes: same, but reuse the wheel stored in ``clrwhl``.
    - Saturation/brightness/contrast/blur changes: start from the stored overlay.

    Blur and colour/brightness/contrast enhancement are applied last.

    Returns:
    - the two figures (main image, wheel legend or ``no_update``);
    - the container styles;
    - the un-enhanced overlay for 'temp-img';
    - the wheel as JSON for 'clrwhl';
    - the wheel's symmetry for 'cu_sym'.
    """
    if len(image_store_data) == 0:
        raise PreventUpdate
    changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]

    needs_overlay = (
        any(key in changed_id for key in ('overlap', 'image-store', 'symmetry-slider', 'image-slider'))
        # No stored overlay yet (initial {}), so it has to be computed.
        or not isinstance(rgb2, str)
    )
    if needs_overlay:
        original_image = read_img_cache(_pick_image_payload(image_store_data, slider_value))
        original_image, im = _split_channels(original_image)

        # An overlap-only change can reuse the stored wheel. Any other trigger,
        # or a store that holds no wheel yet, rebuilds it.
        if 'overlap' in changed_id and isinstance(clrwhl, str):
            clrwhl = np.array(json.loads(clrwhl), dtype='uint8')
            graph2 = dash.no_update
        else:
            clrwhl = bldclrwhl(original_image.shape[0], original_image.shape[1], symmetry)
            cu_sym = symmetry
            graph2 = make_figure(clrwhl, True)

        rgb = nofft(clrwhl, im, im.shape[0], im.shape[1])
        iim = np.repeat(im[:, :, np.newaxis], 3, axis=2)
        rgb2 = original_image * (1 - overlap) + overlap * iim * rgb
        clrwhl = json.dumps(clrwhl.tolist())
    else:
        # The stored overlay comes back as 'data:<mime>;base64,<payload>'.
        _, content_string = rgb2.split(',')
        rgb2 = read_img_cache(content_string)
        graph2 = dash.no_update

    # NOTE(review): rgb2 is returned to the 'temp-img' store as a PIL image. It is
    # presumably serialised to a data URI by Dash's JSON encoder, which is the
    # form the branch above reads back.
    rgb2 = Image.fromarray(np.uint8(rgb2))
    img2 = _enhance(rgb2, blur_val, enh_val, bright_val, contra_val)

    graph1 = make_figure(img2)
    # NOTE(review): '0 20' has no units, so browsers ignore this padding value.
    return [graph1, graph2], [{'width': '59%', 'display': 'inline-block', 'padding': '0 20'},
                              {'display': 'inline-block', 'width': '39%', 'verticalAlign': 'top',
                               'marginTop': '5rem'}], rgb2, clrwhl, cu_sym

                                
                                
# Offers the processed image shown in the main graph for download
@app.callback(
    Output('download', 'children'),
    Input({'type': 'save-data', 'index': ALL}, 'n_clicks'),
    State({'type': 'graph', 'index': ALL}, 'figure'),
    State('image-store', 'data'),
    State({'type': 'image-slider', 'index': ALL}, 'value'),
    prevent_initial_call=True,
)
def func(n_clicks, image, image_store_data, slider_value):
    """Return a download link for the processed image shown in the main graph.

    The link target is the ``source`` data URI of the main figure's first trace.

    The file extension is taken from that URI's own MIME type, so the name
    matches the bytes being saved. Using the uploaded file's extension would
    mislabel the content, because the figure is not re-encoded into that
    format. ``image_store_data`` and ``slider_value`` are therefore unused;
    they are kept so the callback signature stays unchanged.

    Raises ``PreventUpdate`` when the SAVE button has not been clicked.
    """
    if not n_clicks or not n_clicks[0]:
        raise PreventUpdate
    im_cache = image[0]['data'][0]['source']
    # 'data:image/png;base64,...' -> 'png'
    mime = im_cache.split(',', 1)[0].split(';', 1)[0]
    file_type = mime.split('/')[-1] or 'png'
    filename = 'image.' + file_type
    return html.A(download=filename, href=im_cache, children=["Click here to start download"])
                                
# Run the app. mode='external' does not embed the app in the notebook; it prints the
# URL the app is served on (see the output below). The commented line is the inline
# (embedded) alternative.
# app.run_server(mode='inline')
# NOTE(review): host='0.0.0.0' listens on all network interfaces, so the app is reachable
# by anyone who can reach this machine's port; use '127.0.0.1' unless remote access
# (e.g. from inside a container) is intended.
app.run_server(mode='external', host='0.0.0.0')

Dash app running on http://0.0.0.0:8050/
