In [1]:
# ipympl interactive backend, so figure canvases can be placed inside ipywidgets layouts
%matplotlib widget
# When True, reload_image writes the image dtype/min/max before and after preprocessing into error_widget
DEBUG = False

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from glob import glob
from skimage.io import imread
from skimage.color import rgb2gray
from ipyfilechooser import FileChooser
import traceback
import functools
import fnmatch
from astropy.io import fits
from scipy.ndimage import median_filter

# File name patterns offered in the file list. reload_image chooses the loader from the
# extension (.npy -> np.load, .fits -> astropy, anything else -> skimage imread).
FILE_PATTERNS = ['*.png', '*.jpg', '*.npy', '*.fits']

# Folder picker showing directories only; the files inside the chosen folder are then
# filtered with FILE_PATTERNS into file_selection_widget.
folder_selector = FileChooser(
    os.path.realpath('.'),
    show_only_dirs=True,
    select_desc="Select folder",
    change_desc="Change folder",
)

folder_selector.layout = widgets.Layout(width='600px')  # Expand width of the folder selection
# NOTE(review): `_select` is a private FileChooser attribute (presumably its select button)
folder_selector._select.layout = widgets.Layout(width="200px")

file_selection_widget = widgets.Select(
    options=[],
    description="Files in folder:",
    style={'description_width': 'initial'},
    layout={'width': '400px'}
)
refresh_button = widgets.Button(description="Refresh file list", style={'description_width': 'initial'})

# Preprocessing options applied by reload_image when the file is (re)loaded
normalize_to_max_checkbox = widgets.Checkbox(value=False, description='Normalize to max')
apply_median_filter_checkbox = widgets.Checkbox(value=False, description='Apply median 3x3 filter')

# Display controls. Sizes are multiples of 72 px, presumably to line up with the figure sizes
# (in inches) below at 72 px per inch -- TODO confirm against the actual canvas dpi.
vmin_vmax_slider = widgets.FloatRangeSlider(value=[0, 1], min=0, max=1., step=0.001, description='Contrast:', layout=widgets.Layout(width=f'{7*72}px'))
logy_checkbox = widgets.Checkbox(value=False, description='Log y scale', layout=widgets.Layout(width=f'{3*72}px'))
# Crop window sliders; their ranges are reset to the image size by reload_image
x_range_slider = widgets.IntRangeSlider(value=[0, 1], min=0, max=1, step=1, description='X Range:', layout=widgets.Layout(width=f'{8.5*72}px'))
y_range_slider = widgets.IntRangeSlider(value=[0, 1], min=0, max=1, step=1, description='Y Range:', orientation='vertical', layout=widgets.Layout(height=f'{7*72}px'))
# Receives HTML error reports from catch_errors (and DEBUG output from reload_image)
error_widget = widgets.HTML()

# Build the figures with interactive mode off so they are not auto-displayed here; their
# canvases are embedded in the widget layout further down, after which plt.ion() is called.
plt.ioff()

# Histogram strip (ax_hist): full [0, 1] intensity range with the contrast window marked
fig_contrast, ax_hist = plt.subplots(1, 1, figsize=(6.5, 0.25))
fig_contrast.subplots_adjust(left=0, right=1, top=1, bottom=0)
fig_contrast.canvas.header_visible = False # Hide "Figure 1"
fig_contrast.canvas.footer_visible = False # Hide "(x, y) [z]" coordinates on mouse hover at the bottom
fig_contrast.canvas.toolbar_visible = False  # Hide zooming controls

# Histogram strip (ax_colorbar) zoomed to the selected [vmin, vmax] contrast window
fig_colorbar, ax_colorbar = plt.subplots(1, 1, figsize=(6.5, 0.5))
fig_colorbar.subplots_adjust(left=0, right=1, top=1, bottom=0)
fig_colorbar.canvas.header_visible = False
fig_colorbar.canvas.footer_visible = False
fig_colorbar.canvas.toolbar_visible = False

# Main figure: image (top-left), per-row mean profile (top-right),
# per-column mean profile (bottom-left); bottom-right cell is unused.
fig, axes = plt.subplots(2, 2, figsize=(6, 6), gridspec_kw={'width_ratios': [4, 1], 'height_ratios': [4, 1], })
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
# NOTE(review): right after plt.subplots, plt.gca() is the last-created axes (bottom-right,
# i.e. ax_empty), not the image axes; coordinates are hidden anyway by footer_visible=False.
plt.gca().format_coord = lambda x, y: ""
ax_img, ax_right = axes[0]
ax_bottom, ax_empty = axes[1]
fig.subplots_adjust(left=0.1, right=1, top=1, bottom=0.05)

# Blank out the unused bottom-right cell
ax_empty.clear()  # Remove any existing content
ax_empty.set_xticks([])  # Remove x-axis ticks
ax_empty.set_yticks([])  # Remove y-axis ticks
ax_empty.set_xticklabels([])  # Remove x-axis tick labels
ax_empty.set_yticklabels([])  # Remove y-axis tick labels
ax_empty.set_frame_on(False)  # Remove the frame

# Module-level image state, assigned by reload_image
img = None  # current preprocessed 2D image shown in the viewer
original_image = None  # last raw RGBA array seen by reload_image, before compositing on white

def filter_files(filenames, patterns):
    """Keep only the names that match at least one glob pattern.

    Matching uses ``fnmatch.fnmatch`` (case-sensitive on POSIX, case-insensitive on
    Windows). The relative order of ``filenames`` is preserved.
    """
    kept = []
    for name in filenames:
        if any(fnmatch.fnmatch(name, glob_pattern) for glob_pattern in patterns):
            kept.append(name)
    return kept

def format_exception_html(exception):
    """Render an exception and its traceback as a red, monospace HTML snippet.

    The message and the traceback are HTML-escaped, so text such as
    ``'<' not supported between instances`` or paths containing ``&`` is shown
    literally instead of being parsed as markup. The traceback is taken from
    ``exception.__traceback__``, so this also works when called outside the
    ``except`` block that caught the exception.

    Args:
        exception: the exception instance to display.

    Returns:
        str: an HTML fragment suitable as the value of an HTML widget.
    """
    import html  # only needed for error rendering

    formatted_traceback = "".join(
        traceback.format_exception(type(exception), exception, exception.__traceback__)
    )
    return f"""
    <div style="color: red; font-family: monospace; white-space: pre-wrap;">
        <b>Error:</b> {html.escape(str(exception))}<br>
        <pre>{html.escape(formatted_traceback)}</pre>
    </div>
    """

def catch_errors(widget):
    """Build a decorator that reports a callback's exceptions in ``widget``.

    Each call first resets ``widget.value`` to an empty string. If the wrapped
    callback raises an ``Exception``, it is rendered with ``format_exception_html``
    into ``widget.value`` and the wrapper returns ``None`` instead of re-raising.
    Otherwise the callback's return value is passed through unchanged.
    """
    def decorate(callback):
        @functools.wraps(callback)
        def guarded(*args, **kwargs):
            widget.value = ""
            try:
                outcome = callback(*args, **kwargs)
            except Exception as exc:
                # Must run inside the handler: the renderer may inspect the active exception
                widget.value = format_exception_html(exc)
                return None
            return outcome
        return guarded
    return decorate

def update_files_in_folder(change=None):
    """Fill the file list with the matching files of the selected folder.

    Registered as the FileChooser callback and as the refresh button's click handler,
    so ``change`` receives whatever the triggering widget passes and is ignored.
    Only regular files whose names match ``FILE_PATTERNS`` are listed, sorted
    alphabetically, and the first one is selected. A change of the selected value
    triggers ``reload_image`` through the observer registered on the widget.
    """
    folder = folder_selector.selected
    if not folder:
        # No folder chosen yet: os.listdir(None) would silently list the current
        # working directory, whose files reload_image could not then open.
        file_selection_widget.options = []
        return
    # Skip sub-directories even if their name matches a pattern (e.g. "run.fits/"),
    # since reload_image can only load files.
    files = [name for name in os.listdir(folder)
             if os.path.isfile(os.path.join(folder, name))]
    # Only files matching, and sort alphabetically
    filtered_files = sorted(filter_files(files, FILE_PATTERNS))

    # Update widget
    file_selection_widget.options = filtered_files
    if filtered_files:
        # Select the first file
        file_selection_widget.value = filtered_files[0]

def update_hist_yscale(change=None):
    """Apply the "Log y scale" checkbox to both histogram strips.

    Log mode pins the lower y limit at 0.1 (zero is not representable on a log
    axis); linear mode pins it at 0. The upper limit comes from re-autoscaling
    on the current data. ``change`` is the widget notification and is unused.
    """
    scale_name, y_floor = ('log', 0.1) if logy_checkbox.value else ('linear', 0)
    for strip_ax in (ax_colorbar, ax_hist):
        strip_ax.set_yscale(scale_name)
        strip_ax.relim()  # recompute the data limits
        strip_ax.autoscale(axis='y')
        strip_ax.set_ylim(y_floor, None)
        

def _read_image_file(file_path):
    """Load the file at ``file_path`` into a NumPy array, choosing the reader by extension.

    ``.npy`` arrays and raster files (read with ``skimage.io.imread``) are flipped
    upside-down; FITS data is kept in stored row order. For FITS, the first HDU holding
    an array of at least two dimensions is used, because FITS files may leave the
    primary HDU without data and store the image in an extension.

    Raises:
        ValueError: if a FITS file contains no HDU with 2D (or higher) data.
    """
    if file_path.endswith('.npy'):
        # Image needs to be flipped upside-down
        return np.flipud(np.load(file_path))
    if file_path.endswith('.fits'):
        with fits.open(file_path) as hdul:
            for hdu in hdul:
                data = hdu.data
                if data is not None and np.ndim(data) >= 2:
                    # np.array copies the data before the file is closed. Converting to native
                    # byte order makes dtype checks such as `dtype in (np.uint8, np.uint16)`
                    # match FITS big-endian data ('>u2' does not compare equal to np.uint16).
                    return np.array(data, dtype=data.dtype.newbyteorder('='))
        raise ValueError(f"No 2D image data found in FITS file: {file_path}")
    # Image needs to be flipped upside-down
    return np.flipud(imread(file_path))


def _composite_rgba_on_white(rgba):
    """Blend an (H, W, 4) RGBA array onto a white background, returning (H, W, 3).

    Unsigned integer data uses its full dtype range (255 for uint8, 65535 for uint16)
    as "opaque"/"white" and keeps its dtype. Any other dtype keeps the 0-255 / uint8
    assumption of 8-bit images.
    """
    if rgba.dtype.kind == 'u':
        full_scale = np.iinfo(rgba.dtype).max
        out_dtype = rgba.dtype
    else:
        full_scale = 255
        out_dtype = np.uint8
    rgb = rgba[:, :, :3].astype(np.float32)
    # Keep a trailing length-1 axis so alpha broadcasts over the three colour channels
    alpha = rgba[:, :, 3:4].astype(np.float32) / full_scale
    blended = rgb * alpha + full_scale * (1 - alpha)
    return blended.astype(out_dtype)


@catch_errors(error_widget)
def reload_image(change=None):
    """(Re)load the selected file, preprocess it into a 2D image, and redraw.

    Steps:
    - RGBA images are composited onto white.
    - Colour images are converted to grayscale.
    - uint8/uint16 data is rescaled to [0, 1].
    - The median filter and normalize-to-max options are applied when checked.
    - The crop sliders are reset to the full frame, and ``update_image`` redraws.

    When no folder/file is selected, or the path is not a regular file, all plot
    axes are cleared instead. Sets the module globals ``img`` (and ``original_image``
    for RGBA inputs). Errors are shown in ``error_widget`` by the decorator.
    """
    global img
    global original_image

    folder = folder_selector.selected
    file_name = file_selection_widget.value
    file_path = os.path.join(folder, file_name) if folder and file_name else None
    # Nothing loadable selected: clear all plots and stop
    if not file_path or not os.path.isfile(file_path):
        for ax in (ax_hist, ax_colorbar, ax_img, ax_right, ax_bottom):
            ax.clear()
        return

    img = _read_image_file(file_path)

    if img.ndim == 3 and img.shape[2] == 4:
        # RGBA: keep the raw array, then flatten transparency onto a white background
        original_image = img
        img = _composite_rgba_on_white(img)

    # Collapse colour channels into a single grayscale plane
    if img.ndim == 3:
        img = rgb2gray(img)

    if DEBUG:
        error_widget.value = f"BEFORE: {img.dtype=}, {img.min()=}, {img.max()=}<br>"
    # Integer image formats are rescaled to [0, 1] using their full dtype range
    if img.dtype in (np.uint8, np.uint16):
        img = img.astype(float) / np.iinfo(img.dtype).max
    if apply_median_filter_checkbox.value:
        img = median_filter(img, size=3)
    if normalize_to_max_checkbox.value:
        # nanmax ignores NaN pixels. A non-positive (or all-NaN) maximum is skipped, since
        # dividing by it would give inf/NaN or a sign-flipped image. Plain division
        # returns a new float array, so integer inputs work too (in-place `/=` rejects them).
        peak = np.nanmax(img)
        if peak > 0:
            img = img / peak
    if DEBUG:
        error_widget.value += f"AFTER: {img.dtype=}, {img.min()=}, {img.max()=}<br>"

    # Resize the crop sliders to the new image and select the whole frame
    x_range_slider.min, x_range_slider.max = 0, img.shape[1]
    y_range_slider.min, y_range_slider.max = 0, img.shape[0]
    x_range_slider.value = 0, img.shape[1]
    y_range_slider.value = 0, img.shape[0]

    update_image(change=change)

def update_image(change=None):
    """Redraw the histogram strips, the image with crop lines, and the mean projections.

    Reads the module-level ``img`` together with the contrast slider
    (``vmin_vmax_slider``) and the crop sliders (``x_range_slider``, ``y_range_slider``).
    Does nothing when no image has been loaded yet, e.g. when a slider is moved
    before a file is selected. ``change`` is the widget notification and is unused.
    """
    if img is None:
        return

    vmin, vmax = vmin_vmax_slider.value

    # Histogram of the whole image over the fixed [0, 1] display range
    hist, bins = np.histogram(img.flatten(), bins=256, range=(0, 1))
    bin_centers = (bins[:-1] + bins[1:]) / 2
    for ax in (ax_hist, ax_colorbar):
        ax.clear()
        ax.plot(bin_centers, hist, color='red')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
    # ax.clear() resets the y scale, so re-apply the log/linear choice once both strips are redrawn
    update_hist_yscale()

    # ax_hist spans [0, 1]; ax_colorbar zooms into the [vmin, vmax] contrast window.
    # Both get a grayscale gradient background spanning [vmin, vmax].
    ax_hist.set_xlim(0, 1)
    ax_colorbar.set_xlim(vmin, vmax)
    gradient = np.linspace(0, 1, 256).reshape(1, -1)  # Horizontal gradient
    for ax in (ax_colorbar, ax_hist):
        y_min, y_max = ax.get_ylim()
        ax.imshow(gradient, extent=[vmin, vmax, y_min, y_max], aspect='auto', cmap='gray', vmin=0, vmax=1)
    ax_hist.axvline(vmin, color='blue', linestyle='--')
    ax_hist.axvline(vmax, color='blue', linestyle='--')

    # Image display; ylim (0, height) puts row 0 at the bottom
    ax_img.clear()
    ax_img.imshow(img, cmap='gray', vmin=vmin, vmax=vmax)
    ax_img.set_xlim(0, img.shape[1])
    ax_img.set_ylim(0, img.shape[0])
    ax_img.set_xticklabels([])
    ax_img.set_yticklabels([])
    ax_img.set_xticks([])
    ax_img.set_yticks([])

    # Crop lines
    x1, x2 = x_range_slider.value
    y1, y2 = y_range_slider.value
    ax_img.axvline(x1, color='red', linestyle='--')
    ax_img.axvline(x2, color='red', linestyle='--')
    ax_img.axhline(y1, color='blue', linestyle='--')
    ax_img.axhline(y2, color='blue', linestyle='--')

    # Extract cropped region and update projections
    cropped = img[int(y1):int(y2), int(x1):int(x2)]

    ax_bottom.clear()
    ax_right.clear()
    if cropped.size == 0:
        # Empty selection (equal slider handles): np.mean would warn about an empty
        # slice and zero-width axis limits would trigger matplotlib warnings
        return
    column_means = np.mean(cropped, axis=0)
    row_means = np.mean(cropped, axis=1)

    ax_bottom.plot(column_means)
    ax_bottom.set_xlim(0, cropped.shape[1])

    ax_right.plot(row_means, np.arange(len(row_means)))
    ax_right.set_ylim(0, cropped.shape[0])

# Connect widget events to their handlers
folder_selector.register_callback(update_files_in_folder)
refresh_button.on_click(update_files_in_folder)
file_selection_widget.observe(reload_image, names='value')
# Preprocessing options are applied inside reload_image, so toggling them reloads the file
for preprocessing_toggle in (apply_median_filter_checkbox, normalize_to_max_checkbox):
    preprocessing_toggle.observe(reload_image, names='value')
# Contrast and crop changes only need a redraw of the already-loaded image
for display_control in (vmin_vmax_slider, x_range_slider, y_range_slider):
    display_control.observe(update_image, names='value')
logy_checkbox.observe(update_hist_yscale, names='value')

viewer_layout = widgets.VBox([
    folder_selector,
    widgets.HBox([file_selection_widget, refresh_button]),
    error_widget,
    widgets.HBox([normalize_to_max_checkbox, apply_median_filter_checkbox]),
    fig_contrast.canvas,
    widgets.HBox([logy_checkbox, vmin_vmax_slider]),
    fig_colorbar.canvas,
    widgets.HBox([y_range_slider, fig.canvas]),
    x_range_slider,
])
display(viewer_layout)

# Figures are embedded now, so pyplot interactive mode can be switched back on
plt.ion()
# Initial draw; with no folder/file selected yet this just clears the plots
reload_image()

VBox(children=(FileChooser(path='/home/jovyan/work/ICON_tests', filename='', title='', show_hidden=False, sele…