# Image widget development

Colin Ophus - 2023 July

This is an example widget for plotting images interactively.

In [1]:
%matplotlib widget

In [2]:
import os

import numpy as np
import matplotlib.pyplot as plt

In [3]:
from IPython.display import display
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, HTML, Text, Layout
# from ipywidgets import widgets, interact, GridspecLayout, Layout,  Layout, Label, 
# from matplotlib import cm
# from ipywidgets import AppLayout, FloatSlider, FloatLogSlider, Layout

In [4]:
# Load the image stack. The widget cell below treats index 0 as the exit-wave
# phase and index 1 as the exit-wave amplitude.
# Set the IMAGE_STACK_PATH environment variable to point at a local copy
# instead of editing this absolute path.
IMAGE_STACK_PATH = os.environ.get(
    'IMAGE_STACK_PATH',
    '/Users/cophus/repos/dev/im_graphene_EWR.npz',
)
# np.load on an .npz returns an NpzFile that keeps the archive open. Indexing
# reads the array into memory, and the context manager then closes the file.
with np.load(IMAGE_STACK_PATH) as npz_data:
    image_stack = npz_data['im_graphene_EWR']

In [5]:
# Image scaling and histogram plotting parameters. Both ranges are given in
# units of each image's standard deviation, measured from its mean.
hist_range_plot = (-4,4)   # extent of the histogram x-axis
hist_range_init = (-1,2)   # starting display range (vmin, vmax)
hist_num_bins = 200

# Per-image intensity statistics. Columns of int_ranges:
#   0 - mean
#   1 - standard deviation
#   2 - min hist range
#   3 - max hist range
#   4 - current vmin
#   5 - current vmax
int_ranges = np.zeros((image_stack.shape[0],6))

# Histogram bin centers and peak-normalized counts, one row per image.
hist_bins_all = np.zeros((image_stack.shape[0],hist_num_bins))
hist_data_all = np.zeros((image_stack.shape[0],hist_num_bins))

for a0 in range(image_stack.shape[0]):
    # mean and population standard deviation of this image
    int_mean = np.mean(image_stack[a0])
    int_std = np.sqrt(np.mean((image_stack[a0] - int_mean)**2))

    # convert the sigma-based ranges into intensity values
    int_min, int_max = (int_mean + k * int_std for k in hist_range_plot)
    init_min, init_max = (int_mean + k * int_std for k in hist_range_init)

    int_ranges[a0] = (int_mean, int_std, int_min, int_max, init_min, init_max)

    # histogram over the stored plotting range (columns 2-3),
    # rescaled so that the tallest bin equals 1
    hist_bins = np.linspace(*int_ranges[a0, 2:4], hist_num_bins + 1, endpoint=True)
    hist_data = np.histogram(image_stack[a0].ravel(), bins=hist_bins)[0].astype('float')
    hist_data /= np.max(hist_data)

    # bin centers = left edges shifted by half a bin width
    hist_bins_all[a0] = hist_bins[:-1] + (hist_bins[1] - hist_bins[0])/2
    hist_data_all[a0] = hist_data
    

In [6]:


# Variables
class WidgetData():
    """Figures, artists and interactive state for the image/histogram widget.

    Builds two figures while ``plt.ioff()`` is active:
    - a square figure showing one image of the stack;
    - a small histogram figure with two vertical lines marking the current
      display range (vmin, vmax) of that image.

    Parameters
    ----------
    image_stack : array
        Images indexed along the first axis; ``image_stack[0]`` is shown first.
    int_ranges : array, shape (num_images, 6)
        Per-image table: 0 mean, 1 standard deviation, 2/3 histogram x-range,
        4/5 display vmin/vmax. Stored by reference (not copied), so updates to
        columns 4-5 made through the widget are visible in the caller's array.
    hist_bins, hist_data : array, shape (num_images, num_bins), optional
        Histogram bin centers and normalized counts per image. When omitted,
        the notebook-level ``hist_bins_all`` / ``hist_data_all`` arrays are
        used, which preserves the original calling convention.
    """
    def __init__(
        self,
        image_stack,
        int_ranges,
        hist_bins=None,
        hist_data=None,
        ):
        # fall back to the notebook-level histogram arrays (original behavior)
        if hist_bins is None:
            hist_bins = hist_bins_all
        if hist_data is None:
            hist_data = hist_data_all

        self.image_stack = image_stack
        self.ind_image = 0        # index of the displayed image
        self.cmap = 'gray'        # colormap applied to the image
        self.ind_point = None     # vline being dragged (0 = vmin, 1 = vmax), or None
        self.int_ranges = int_ranges
        self.hist_bins = hist_bins
        self.hist_data = hist_data
        # vline height and histogram y-limit; sits just above 1.0, assuming
        # hist_data is peak-normalized as in the histogram cell
        self.y_h = 1.1

        # Widget figure generation. Interactive mode is off so the figures are
        # not auto-displayed here; their canvases are embedded in the layout.
        with plt.ioff():
            self.fig_image = plt.figure(figsize = (5,5))
            self.fig_hist = plt.figure(figsize = (3,2))

            # the image fills its whole figure; the histogram leaves room for labels
            self.ax_image = self.fig_image.add_axes([0.0, 0.0, 1.0, 1.0])
            self.ax_hist = self.fig_hist.add_axes([0.08, 0.22, 0.90, 0.76])
            self.h_hist_x_label = self.ax_hist.set_xlabel('exit wave phase')
            self.ax_hist.set_ylabel('Intensity')

            # image, displayed with its stored vmin / vmax
            self.h_image = self.ax_image.imshow(
                image_stack[0],
                vmin = int_ranges[0,4],
                vmax = int_ranges[0,5],
                cmap = self.cmap,
            )
            self.ax_image.set_xticks(())
            self.ax_image.set_yticks(())

            # histogram
            self.h_hist = self.ax_hist.fill_between(
                hist_bins[0],
                hist_data[0],
                color = (0, 0.7, 1.0, 1.0),
            )
            # draggable vmin / vmax markers
            self.h_vlines = self.ax_hist.vlines(
                int_ranges[0,4:6],
                ymin = 0,
                ymax = self.y_h,
                color = 'k',
            )
            self.ax_hist.set_xlim((int_ranges[0,2], int_ranges[0,3]))
            self.ax_hist.set_ylim((0, self.y_h))
            self.ax_hist.set(yticks=[])
            self.ax_hist.set(yticklabels=[])

            # General appearance: hide canvas header/footer and fix the size.
            # The image figure keeps its toolbar; the histogram figure hides it.
            self.fig_image.canvas.header_visible = False
            self.fig_image.canvas.footer_visible = False
            self.fig_image.canvas.resizable = False
            self.fig_hist.canvas.toolbar_visible = False
            self.fig_hist.canvas.header_visible = False
            self.fig_hist.canvas.footer_visible = False
            self.fig_hist.canvas.resizable = False


# Build the widget state and figures from the loaded stack and intensity table.
w = WidgetData(
    image_stack=image_stack,
    int_ranges=int_ranges,
)

In [7]:
#| label: app:image_widget

# Widget

# interactive movement of histogram 
def button_press_callback(event):
    """Select the vmin/vmax line nearest to a left click on the histogram.

    Stores 0 (vmin line) or 1 (vmax line) in ``w.ind_point``. Clicks outside
    any axes, or made with a button other than the left one, are ignored.
    """
    if event.inaxes is None or event.button != 1:
        return
    # pixel coordinates -> histogram data coordinates; only x is needed
    x_data, _ = w.ax_hist.transData.inverted().transform([event.x, event.y])
    # distance from the click to the current vmin (column 4) and vmax (column 5)
    distances = np.abs(w.int_ranges[w.ind_image, 4:6] - x_data)
    w.ind_point = np.argmin(distances).astype('int')

def motion_notify_callback(event):
    """Drag the selected vmin/vmax line and update the image color range.

    Only acts while a line is selected (``w.ind_point`` set by the press
    callback), the cursor is inside an axes and the left button is held.
    The new position is written to column 4 (vmin) or 5 (vmax) of
    ``w.int_ranges`` for the current image.
    """
    if w.ind_point is None:
        return
    if event.inaxes is None:
        return
    if event.button != 1:
        return
    # convert from screen coordinates to axis coordinates
    t = w.ax_hist.transData.inverted()
    xy = t.transform([event.x,event.y])
    # move the selected vline to the cursor
    p = w.h_vlines.get_segments()
    p[w.ind_point] = np.array([
        [xy[0], 0],
        [xy[0], w.y_h],
    ])
    w.h_vlines.set_segments(p)
    # update int_ranges: point 0 -> column 4 (vmin), point 1 -> column 5 (vmax)
    if w.ind_point == 0:
        w.int_ranges[w.ind_image,4] = xy[0]
    elif w.ind_point == 1:
        w.int_ranges[w.ind_image,5] = xy[0]
    # The two lines can be dragged past each other. matplotlib's Normalize
    # raises ValueError at draw time when vmin > vmax, so pass them in order.
    clim = w.int_ranges[w.ind_image,4:6]
    w.h_image.set_clim(clim.min(), clim.max())
    # request redraws so the changes appear without waiting for another event
    w.fig_hist.canvas.draw_idle()
    w.fig_image.canvas.draw_idle()

def button_release_callback(event):
    """End a drag: releasing the left mouse button deselects the vline."""
    if event.button == 1:
        w.ind_point = None
    
    
# button callbacks
# Attach press / drag / release handlers to the histogram canvas.
for event_name, event_handler in (
    ('button_press_event', button_press_callback),
    ('motion_notify_event', motion_notify_callback),
    ('button_release_event', button_release_callback),
):
    w.fig_hist.canvas.mpl_connect(event_name, event_handler)



   
# Dropdown widget to switch which image is plotted
option_list_image = (
    'exit wave phase', 
    'exit wave amplitude', 
)
def update_image(change):
    """Show the image selected in the image dropdown, with its histogram.

    ``change.new`` is the selected option label. Its position in
    ``option_list_image`` is the index into the image stack; unknown labels
    leave the current index unchanged. The stored range for that image is
    read from ``w.int_ranges``: columns 2-3 give the histogram x-range and
    columns 4-5 give vmin/vmax. Those values are applied to the image color
    limits, the histogram axis and the vmin/vmax lines.
    """
    if change.new in option_list_image:
        w.ind_image = option_list_image.index(change.new)
    i = w.ind_image
    ranges = w.int_ranges[i]

    # update image; limits are ordered because the lines may have been
    # dragged past each other (Normalize rejects vmin > vmax)
    w.h_image.set_data(w.image_stack[i])
    w.h_image.set_clim(min(ranges[4], ranges[5]), max(ranges[4], ranges[5]))

    # update histogram
    w.h_hist.remove()
    w.h_hist = w.ax_hist.fill_between(
        hist_bins_all[i],
        hist_data_all[i],
        color = (0, 0.7, 1.0, 1.0),
    )
    w.ax_hist.set_xlim((ranges[2], ranges[3]))
    w.h_hist_x_label.set_text(option_list_image[i])

    # move the vmin / vmax lines to this image's stored values
    w.h_vlines.set_segments([
        np.array([[ranges[4], 0], [ranges[4], w.y_h]]),
        np.array([[ranges[5], 0], [ranges[5], w.y_h]]),
    ])

    # request redraws of both canvases
    w.fig_hist.canvas.draw_idle()
    w.fig_image.canvas.draw_idle()
    
# Dropdown selecting which image of the stack is displayed.
dropdown_image = Dropdown(
    options=option_list_image,
    layout=Layout(width='300px', height='30px'),
)
dropdown_image.observe(update_image, names='value')




   
# Dropdown widget to switch the image colormap
option_list_cmap = (
    'gray', 
    'inferno',
    'turbo',
)
def update_cmap(change):
    """Apply the colormap selected in the colormap dropdown to the image.

    The option labels are matplotlib colormap names. The choice is also
    stored in ``w.cmap`` so the widget state matches what is shown.
    Labels outside ``option_list_cmap`` are ignored.
    """
    if change.new not in option_list_cmap:
        return
    w.cmap = change.new
    w.h_image.set_cmap(change.new)
    # request a redraw so the new colormap appears
    w.fig_image.canvas.draw_idle()
        
# Dropdown selecting the image colormap.
dropdown_cmap = Dropdown(
    options=option_list_cmap,
    layout=Layout(width='300px', height='30px'),
)
dropdown_cmap.observe(update_cmap, names='value')




# Output widget
# out = Output(layout={'border': '1px solid black'})



# Construct overall widget layout
# with out:
# Right-hand control column: histogram range selector plus the two dropdowns.
controls_panel = VBox([
    HTML(value='<h3>Image plotting range</h3>'),
    w.fig_hist.canvas,
    HTML(value='<h3>Plotting Image</h3>'),
    dropdown_image,
    HTML(value='<h3>Color map</h3>'),
    dropdown_cmap,
])
# Image canvas on the left, controls on the right.
widget_image = HBox([w.fig_image.canvas, controls_panel])
display(widget_image);

HBox(children=(Canvas(footer_visible=False, header_visible=False, resizable=False, toolbar=Toolbar(toolitems=[…

In [8]:
# NOTE(review): leftover inspection cell. This expression only displays the
# bound-method repr of set_xticks (see the output below) and does not call it.
# It can likely be removed.
w.ax_image.set_xticks

<bound method _AxesBase.set_xticks of <Axes: >>