In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.widgets import TextBox, Button  # todo - add file loaders
from py4DSTEM.io import read
from py4DSTEM import DataCube
from filter_functions import filter_function, find_k_value, circular_filter_function, normalize_min_max, mask_images, azimuthal_sum_w_filter, data_theta

from tifffile import imread
from ipywidgets import interact
import matplotlib.patches as patches
# import torch
import PIL

In [None]:
# Load the STEM dataset with py4DSTEM's `read`; later cells use its `.data` array and `.shape`.
# NOTE(review): both paths are absolute and user-specific, so this cell only runs on the author's machine.
datacube = read('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/0020 - original-centralized-masked.h5')
# Load the accompanying TIFF image with tifffile; it is shown unmodified in the viewer's first panel.
img = imread('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/0020 - B1- biogenic guanine 23000 x STEM HAADF.tif')

In [None]:
def extract_datacube_subset(datacube, x_range, y_range):
    """
    Build a new DataCube from a rectangular window of an existing one.

    ``x_range`` and ``y_range`` are ``(start, end)`` index pairs. ``end`` is
    part of the window (the slice stops at ``end + 1``). The first data axis
    is sliced with ``y_range`` and the second with ``x_range``; any further
    axes are kept whole.
    """
    (x_first, x_last), (y_first, y_last) = x_range, y_range

    # Inclusive upper bounds: a range such as (0, 0) keeps exactly one index.
    rows = slice(y_first, y_last + 1)
    cols = slice(x_first, x_last + 1)
    return DataCube(datacube.data[rows, cols])



In [None]:

# Select the Tk GUI backend so the viewer opens in its own window.
matplotlib.use('TkAgg')
# plt.ion()

# The overview image is the TIFF already loaded with tifffile; the commented lines are a PIL alternative.
tif_img = img
# tif_img= PIL.Image.open(
#     '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/yevgeny_proj/data/0020 - B1- biogenic guanine 23000 x STEM HAADF.tif')

# Scan window as (start, end) index pairs; the trailing comments show an example sub-window.
x_range =  (0, datacube.shape[1]) # (40,130)
y_range = (0, datacube.shape[0]) #(0, 50)
# Extract the subset
dcsub = extract_datacube_subset(datacube, x_range, y_range)
# Sum over axes 2 and 3 to get one intensity value per (y, x) scan position.
stem_img = np.sum(dcsub.data, axis=(2, 3))
# The callbacks look patterns up as diffraction_patterns[y, x].
diffraction_patterns = dcsub.data

# Same summed image, computed from the whole (unclipped) datacube.
full_stem_img = np.sum(datacube.data, axis=(2,3))

# Create the four-panel figure: TIFF overview, full STEM image, STEM subset, diffraction pattern.
fig, (ax_tif, ax_full, ax_stem, ax_dp) = plt.subplots(1, 4, figsize=(20, 5), num='4DSTEM')

# create the initial images

# tif
ax_tif.imshow(tif_img)
ax_tif.axis('off')

# stem subsection or entire stem if you did not clip the datacube
im_stem = ax_stem.imshow(stem_img, picker=True)
ax_stem.set_xlabel("X")
ax_stem.set_ylabel("Y")
ax_stem.axis('off')

# full stem image
im_full = ax_full.imshow(full_stem_img, picker=True)
# ax_full.set_xlabel("X")
# ax_full.set_ylabel("Y")
# ax_full.axis('off')

# Outline the extracted subset on the full image.
# imshow centres pixel (col, row) on integer coordinates, so pixel edges sit at +/-0.5,
# and the subset includes its end index. Using the subset's own shape therefore frames
# exactly the extracted pixels, even when the requested end runs past the data.
rect = patches.Rectangle(
    (x_range[0] - 0.5, y_range[0] - 0.5),  # Outer edge of the first subset pixel (top-left on screen)
    stem_img.shape[1],  # Width: number of subset columns
    stem_img.shape[0],  # Height: number of subset rows
    linewidth=.5, edgecolor='red', facecolor='none'
)
ax_full.add_patch(rect)

## Cursor
# Current cursor position, starting at the centre of the subset image
x_cen = stem_img.shape[1] // 2
y_cen = stem_img.shape[0] // 2
current_x, current_y = x_cen, y_cen
cursor = ax_stem.scatter(current_x, current_y, marker="+", color='red')  # todo - generalize on other figs

# dp: show the pattern under the initial cursor position
im_dp = ax_dp.imshow(diffraction_patterns[y_cen, x_cen])
ax_dp.set_title(f"Diffraction Pattern at ({x_cen}, {y_cen})")
ax_dp.axis('off')

# Create text boxes for manual input.
# They start at the cursor's current position, so submitting an unedited box keeps the
# cursor where it is instead of jumping to (0, 0).
axbox_x = plt.axes([0.2, 0.01, 0.25, 0.05])  # X-coordinate text box position
axbox_y = plt.axes([0.55, 0.01, 0.25, 0.05])  # Y-coordinate text box position
text_box_x = TextBox(axbox_x, 'X:', initial=str(current_x))
text_box_y = TextBox(axbox_y, 'Y:', initial=str(current_y))

# Create axes for error message
error_ax = plt.axes([0.2, 0.08, 0.6, 0.05])  # Above the text boxes
error_ax.axis('off')  # Hide the error box frame
# Empty until a callback reports invalid or out-of-bounds input
error_text = error_ax.text(
    0.5, 0.5, "", color="red", ha="center", va="center", fontsize=10
)

# pick event on map
def on_click(event):
    """Move the cursor to the STEM pixel that was clicked.

    Clicks outside the STEM axes are ignored. The click position is rounded to
    the nearest integer because imshow centres each pixel on integer
    coordinates; truncating would pick the neighbouring pixel for clicks in the
    right or lower half of a pixel.
    """
    if event.inaxes != ax_stem or event.xdata is None or event.ydata is None:
        return
    x, y = int(round(event.xdata)), int(round(event.ydata))  # Nearest pixel centre
    print(x, y)
    update(x, y)
    fig.canvas.draw_idle()  # Redraw the figure to reflect changes


# Callback for text box input
def on_text_submit(_):
    """Read both coordinate boxes and move the cursor to the typed position.

    The submitted text argument is ignored; the X and Y boxes are always read
    together. Text that is not an integer produces an on-figure error message
    (also printed) instead of moving the cursor.
    """
    try:
        # X is parsed before Y; the first non-integer value aborts the move.
        coords = [int(text_box_x.text), int(text_box_y.text)]
        update(*coords)
    except ValueError:
        msg = "Invalid input. Please enter integers only!"
        error_text.set_text(msg)
        print(msg)
        fig.canvas.draw_idle()  # Show the message right away

# Callback for keyboard input
def on_key(event):
    """Nudge the cursor by one pixel when an arrow key is pressed.

    Other keys are ignored. Bounds checking is left to ``update``.
    """
    # (dx, dy) per arrow key; y grows downwards, so 'up' decreases y.
    steps = {
        'up': (0, -1),
        'down': (0, 1),
        'left': (-1, 0),
        'right': (1, 0),
    }
    step = steps.get(event.key)
    if step is not None:
        dx, dy = step
        update(current_x + dx, current_y + dy)


# Function to update the cursor
def update(x, y):
    """Move the cursor to pixel (x, y) and refresh everything that depends on it.

    In bounds: stores the new position in ``current_x``/``current_y``, moves the
    scatter cursor, shows the diffraction pattern at ``[y, x]``, updates the
    title and both text boxes, and clears any error message.
    Out of bounds: leaves the position unchanged and shows (and prints) an
    error message instead.
    """
    global current_x, current_y
    image = stem_img  # todo - generalize and set several images

    if not (0 <= x < image.shape[1] and 0 <= y < image.shape[0]):  # Check bounds
        msg = f"Coordinates out of bounds: x={x}, y={y}"
        error_text.set_text(msg)  # Tell the user why nothing moved
        print(msg)
        fig.canvas.draw_idle()  # Redraw the figure to reflect changes
        return

    current_x, current_y = x, y
    cursor.set_offsets([[x, y]])  # Update scatter point coordinates
    im_dp.set_data(diffraction_patterns[y, x])
    ax_dp.set_title(f"Cursor Location: x={x}, y={y}")  # Update the title

    # TextBox.set_val notifies the submit observers, which would call
    # on_text_submit -> update again with a half-updated (new x, old y) position.
    # Silence the widgets while they are synced, then restore their state.
    boxes = (text_box_x, text_box_y)
    saved_states = [box.eventson for box in boxes]
    for box in boxes:
        box.eventson = False
    try:
        text_box_x.set_val(str(x))  # Update the X text box
        text_box_y.set_val(str(y))  # Update the Y text box
    finally:
        for box, state in zip(boxes, saved_states):
            box.eventson = state

    error_text.set_text("")  # Clear any error message
    fig.canvas.draw_idle()  # Redraw the figure to reflect changes

# Connect the text box inputs, mouse clicks, and keyboard input to the handlers
text_box_x.on_submit(on_text_submit)
text_box_y.on_submit(on_text_submit)
fig.canvas.mpl_connect('button_press_event', on_click)
fig.canvas.mpl_connect('key_press_event', on_key)

plt.show()

In [None]:
def mask_center(data, radius=20):
    """Return a copy of ``data`` with a central disk of its last two axes set to 0.

    The disk is centred on ``(shape[-2] // 2, shape[-1] // 2)`` and contains every
    pixel whose Euclidean distance to that centre is at most ``radius``. Leading
    axes are broadcast over, and the input array is left untouched.
    """
    n_rows, n_cols = data.shape[-2], data.shape[-1]

    # Signed offsets from the centre pixel, shaped to broadcast into a 2-D grid.
    row_offset = np.arange(n_rows)[:, np.newaxis] - n_rows // 2
    col_offset = np.arange(n_cols)[np.newaxis, :] - n_cols // 2
    inside_disk = np.sqrt(col_offset**2 + row_offset**2) <= radius

    result = np.copy(data)
    result[..., inside_disk] = 0  # Boolean mask spans the last two axes
    return result

def min_max_normalize(arr):
    """Rescale ``arr`` linearly so that its minimum maps to 0 and its maximum to 1.

    When every element is equal there is no range to scale by; an array of zeros
    with the same shape and dtype as ``arr`` is returned instead.
    """
    lowest, highest = np.min(arr), np.max(arr)

    if highest != lowest:
        return (arr - lowest) / (highest - lowest)

    # Flat input: avoid dividing by zero (or return arr directly, depending on your use case)
    return np.zeros_like(arr)

# Define ranges
x_range = (0, 0)  # Adjust as needed
y_range = (0, datacube.shape[0])  # Adjust as needed

# Extract the subset
dcsub = extract_datacube_subset(datacube, x_range, y_range)
# mask
diffraction_patterns = mask_center(dcsub.data)

diffraction_patterns = np.squeeze(diffraction_patterns, axis=1)

# np.savez('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/subset_dps.npz', data=diffraction_patterns)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from matplotlib.widgets import TextBox, Slider
from tifffile import imread

# Select the Tk GUI backend so the viewer opens in its own window.
matplotlib.use('TkAgg')

# Load the TIF image
# NOTE(review): this path differs from the one used for `img` at the top of the notebook
# ('yevgeny_proj' vs '4DSTEM') -- confirm which location is current.
tif_img = imread('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/yevgeny_proj/data_and_figs/0020 - B1- biogenic guanine 23000 x STEM HAADF.tif')  # Replace with your file path

# Define ranges
x_range = (0, datacube.shape[1])  # Adjust as needed
y_range = (0, datacube.shape[0])  # Adjust as needed

# Extract the subset
dcsub = extract_datacube_subset(datacube, x_range, y_range)
stem_img = np.sum(dcsub.data, axis=(2, 3))

# Rebind the patterns to this cell's subset so they index as [y, x] in step with stem_img.
# Without this the viewer would pick up whatever an earlier cell left in
# `diffraction_patterns` (e.g. the squeezed single-column array), which no longer
# yields a 2-D pattern per scan position.
# NOTE(review): wrap in mask_center(...) if the masked patterns were meant to be shown -- confirm.
diffraction_patterns = dcsub.data

full_stem_img = np.sum(datacube.data, axis=(2, 3))

# Initialize the threshold value (also the slider's starting value)
threshold = 30  # Adjust as needed

# Function to apply the threshold
def apply_threshold(data, threshold):
    """Return ``data`` with every value that is not strictly above ``threshold`` replaced by 0.

    The input array is not modified.
    """
    above = data > threshold
    return np.where(above, data, 0)

# Create the figure and axes
# Create the four-panel figure: TIFF overview, full STEM image, STEM subset, diffraction pattern.
fig, (ax_tif, ax_full, ax_stem, ax_dp) = plt.subplots(1, 4, figsize=(20, 5), num='4DSTEM')

# Initial image display
ax_tif.imshow(tif_img)
ax_tif.axis('off')

im_stem = ax_stem.imshow(stem_img, picker=True)
ax_stem.axis('off')

im_full = ax_full.imshow(full_stem_img, picker=True)
# Outline the extracted subset on the full image.
# imshow centres pixels on integer coordinates (edges at +/-0.5) and the subset includes
# its end index, so the frame starts half a pixel early and is sized from the subset's
# own shape to enclose exactly the extracted pixels.
rect = patches.Rectangle(
    (x_range[0] - 0.5, y_range[0] - 0.5),
    stem_img.shape[1],
    stem_img.shape[0],
    linewidth=0.5, edgecolor='red', facecolor='none'
)
ax_full.add_patch(rect)

# Cursor and diffraction pattern, starting at the centre of the subset image
x_cen = stem_img.shape[1] // 2
y_cen = stem_img.shape[0] // 2
current_x, current_y = x_cen, y_cen
cursor = ax_stem.scatter(current_x, current_y, marker="+", color='red')
im_dp = ax_dp.imshow(apply_threshold(diffraction_patterns[y_cen, x_cen], threshold), cmap='viridis')
ax_dp.set_title(f"Diffraction Pattern at ({x_cen}, {y_cen})")
ax_dp.axis('off')

# Slider for threshold adjustment (integer steps from -20 to 50, starting at `threshold`)
ax_slider = plt.axes([0.2, 0.95, 0.6, 0.03], facecolor='lightgoldenrodyellow')
slider = Slider(ax_slider, 'Threshold', -20, 50, valinit=threshold, valstep=1)

# Update function
def update(x, y):
    """Move the cursor to pixel (x, y) and show its thresholded diffraction pattern.

    Positions outside the STEM image are silently ignored. The threshold is
    read from the slider's current value.
    """
    global current_x, current_y

    inside = 0 <= x < stem_img.shape[1] and 0 <= y < stem_img.shape[0]
    if not inside:
        return

    current_x, current_y = x, y
    cursor.set_offsets([[x, y]])
    im_dp.set_data(apply_threshold(diffraction_patterns[y, x], slider.val))
    ax_dp.set_title(f"Diffraction Pattern at ({x}, {y})")
    fig.canvas.draw_idle()

# Handle slider change
def update_threshold(val):
    """Slider callback: re-threshold the pattern under the cursor with the new value ``val``."""
    pattern = diffraction_patterns[current_y, current_x]
    im_dp.set_data(apply_threshold(pattern, val))
    fig.canvas.draw_idle()

# Re-threshold the displayed pattern whenever the slider value changes.
slider.on_changed(update_threshold)

# Event handlers for interactions
def on_click(event):
    """Move the cursor to the STEM pixel that was clicked.

    Clicks outside the STEM axes are ignored. The click position is rounded to
    the nearest integer because imshow centres each pixel on integer
    coordinates; truncating would pick the neighbouring pixel for clicks in the
    right or lower half of a pixel.
    """
    if event.inaxes != ax_stem or event.xdata is None or event.ydata is None:
        return
    x, y = int(round(event.xdata)), int(round(event.ydata))  # Nearest pixel centre
    update(x, y)

def on_key(event):
    """Nudge the cursor by one pixel when an arrow key is pressed.

    Other keys are ignored. Bounds checking is left to ``update``.
    """
    # (dx, dy) per arrow key; y grows downwards, so 'up' decreases y.
    steps = {
        'up': (0, -1),
        'down': (0, 1),
        'left': (-1, 0),
        'right': (1, 0),
    }
    step = steps.get(event.key)
    if step is not None:
        dx, dy = step
        update(current_x + dx, current_y + dy)

# Figure-wide mouse and keyboard events; on_click filters for the STEM axes itself.
fig.canvas.mpl_connect('button_press_event', on_click)
fig.canvas.mpl_connect('key_press_event', on_key)

# Open the viewer window.
plt.show()
