In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from ipywidgets import interact, IntSlider
import py4DSTEM
import matplotlib.patches as patches
from py4DSTEM import DataCube
from filter_functions import filter_function, circular_filter_function, normalize_min_max
from bg_mpl_stylesheets.styles import all_styles
plt.style.use(all_styles["bg-style"])

In [None]:
# Load the 4D-STEM datacube. Later cells index it as dc.data[scan_y][scan_x]
# and sum over axes (2, 3) to collapse each diffraction pattern.
path = (
    '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/yevgeny_proj/data/'
    '0020 - original-centralized-masked.h5'
)
dc = py4DSTEM.read(path)

### Function to extract a 4D-STEM data subset

In [None]:
def extract_datacube_subset(datacube, x_range, y_range):
    """Return a new DataCube holding a rectangular window of scan positions.

    Both ranges are *inclusive* ``(start, end)`` pairs of scan indices. ``y``
    indexes the first axis of ``datacube.data`` and ``x`` the second.

    Parameters
    ----------
    datacube : DataCube
        Source datacube. Only its ``.data`` array is used; no other
        attributes are forwarded to the result.
    x_range, y_range : tuple of int
        Inclusive (start, end) scan-index bounds.

    Returns
    -------
    DataCube
        A DataCube built from the sliced ``.data`` array.
    """
    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range
    # +1 on the upper bounds makes the window inclusive of both endpoints.
    rows = slice(y_lo, y_hi + 1)
    cols = slice(x_lo, x_hi + 1)
    return DataCube(datacube.data[rows, cols])



In [None]:
# Integrated-intensity image: collapse each diffraction pattern (axes 2 and 3)
# to a single number per scan position.
stem_img = dc.data.sum(axis=(2, 3))

In [None]:
# SANITY CHECK: outline the scan window that extract_datacube_subset pulls out
# on top of the integrated-intensity image.

x_range = (100, 120)
y_range = (60, 90)
dcsub = extract_datacube_subset(dc, x_range, y_range)

# extract_datacube_subset treats both ranges as inclusive, so the window is
# (end - start + 1) pixels on each axis.
n_x = x_range[1] - x_range[0] + 1
n_y = y_range[1] - y_range[0] + 1
assert dcsub.data.shape[:2] == (n_y, n_x), dcsub.data.shape

# imshow centres pixel i on coordinate i, so the outline of pixels
# start..end runs from start - 0.5 to end + 0.5.
rect = patches.Rectangle(
    (x_range[0] - 0.5, y_range[0] - 0.5),  # lower-left corner (origin='lower')
    n_x,  # width in pixels
    n_y,  # height in pixels
    linewidth=.5, edgecolor='red', facecolor='none'
)

fig, ax = plt.subplots()
ax.imshow(stem_img, interpolation='nearest', origin='lower')
ax.add_patch(rect)
ax.set_xlabel('scan x (px)')
ax.set_ylabel('scan y (px)')
ax.set_title(f'x-range={x_range}, y-range={y_range}');


### Helpers: annulus masking and signal thresholding

In [None]:

def mask_annulus(data, inner_radius=22, outer_radius=80, center=None):
    '''Zero a 2-D pattern everywhere outside an annulus about a centre pixel.

    Pixels whose distance from ``center`` is <= ``inner_radius`` or
    > ``outer_radius`` are set to 0. Pixels with
    inner_radius < distance <= outer_radius keep their values. The input is
    not modified.

    Parameters
    ----------
    data : array_like
        Diffraction pattern; the first two axes are (rows, cols).
    inner_radius, outer_radius : float
        Annulus bounds in pixels.
    center : tuple (row, col), optional
        Annulus centre. Defaults to ``(rows // 2, cols // 2)``, which assumes
        the pattern is already centred.

    Returns
    -------
    numpy.ndarray
        A masked copy of ``data``.
    '''
    dp = np.copy(data)  # never modify the caller's array

    if center is None:
        center = (dp.shape[0] // 2, dp.shape[1] // 2)
    center_y, center_x = center

    # Open grids broadcast to a full (rows, cols) distance map.
    y, x = np.ogrid[:dp.shape[0], :dp.shape[1]]
    distance = np.sqrt((x - center_x)**2 + (y - center_y)**2)

    # True for pixels inside the inner radius or outside the outer radius.
    mask = (distance <= inner_radius) | (distance > outer_radius)
    dp[mask] = 0

    return dp

# Quick look: annulus-mask the pattern at scan row 20, column 0.
dat = dc.data[20][0]
dp = mask_annulus(dat, inner_radius=18, outer_radius=80)

In [None]:
# SIGNAL AMPLIFICATION CHECK
# Mask one pattern, then keep only pixels at or above half of its maximum
# (doubling them) and zero everything below.
dat = dc.data[85][0]
dp = mask_annulus(dat, 22, 80)

threshold = 0.5 * dp.max()  # half of the masked pattern's maximum
print(threshold)
# Both masks come from the same pre-update values, so the zeroed pixels are
# never touched by the boost.
weak = dp < threshold
strong = dp >= threshold
dp[weak] = 0
dp[strong] *= 2


### Filters and GD parameters

In [None]:

im_shape = dc.data[0][0].shape
print(im_shape)
# Centred pixel-coordinate grids, each of shape im_shape. x varies along the
# columns and y along the rows, with zero at (rows // 2, cols // 2), the same
# centre mask_annulus uses. This is also correct for odd sizes and for
# non-square patterns.
x, y = np.meshgrid(np.arange(im_shape[1]) - im_shape[1] // 2,
                   np.arange(im_shape[0]) - im_shape[0] // 2)
# Polar angle of every pixel. Note the argument order atan2(x, y): the angle
# is measured from the +y (row) axis towards +x.
DATA_THETA = torch.atan2(torch.tensor(x), torch.tensor(y))

# for the model
MAX_ITER_OFFSET = 101  # optimiser iterations per diffraction pattern
LR = 1e-2  # Adam learning rate
OFFSET_ADJUSTMENT = 60  # NOTE(review): not referenced elsewhere in this notebook
print(f'Iterations = {MAX_ITER_OFFSET}')
k = 100  # fixed filter exponent used by the preview/diagnostic cells

# a and b vectors in q-space, r is distance from center. Use this to find correct filter formula for tetragonal symmetry
a = 10.
b = 3.
r = 15.
# delta = 2*torch.arcsin(torch.tensor(.5*a/r)) or delta = 2*torch.arccos(torch.tensor(.5*b/r))
delta = torch.deg2rad(torch.tensor(38.))  # delta of 38 degrees is my approximation from dp x=60 y=59
n_fans = 4
print('n_fans = ', n_fans, '(used for rectangular filter)')
def rectangular_filter_function(k, theta1, theta2, delta=delta, n_fans=n_fans):
    '''Sum of two angular lobes, the second rotated by ``delta``.

    Returns cos^2(n_fans/4 * theta1)**k + cos^2(n_fans/4 * (theta2 + delta))**k.

    The power is taken of the non-negative base cos^2 directly rather than as
    exp(k * log(cos^2)). That form gives NaN at k = 0 (0 * -inf) and NaN
    gradients wherever cos == 0; the direct power is finite there for k >= 1.

    Parameters
    ----------
    k : float
        Lobe sharpness exponent.
    theta1, theta2 : torch.Tensor
        Pixel angles (radians) for the first and second lobe.
    delta : torch.Tensor
        Angular offset (radians) applied to the second lobe.
    n_fans : int
        Sets the angular period of each lobe via the n_fans / 4 factor.
    '''
    lobe_1 = (torch.cos(n_fans / 4 * theta1) ** 2) ** k
    lobe_2 = (torch.cos(n_fans / 4 * (theta2 + delta)) ** 2) ** k
    return lobe_1 + lobe_2

# Preview: the rectangular filter rotated by -54 degrees.
off = torch.deg2rad(torch.tensor(-54.))
rotated_theta = DATA_THETA + off
filt = rectangular_filter_function(k, rotated_theta, rotated_theta)
# To overlay on a pattern: plt.imshow(filt + normalize_min_max(dc.data[59][60]))


### Gradient Descent Functions

In [None]:
# GRADIENT DESCENT WITH DYNAMIC K

# Symmetry order of the angular filter. It must be assigned before the
# `n_folds=n_folds` defaults below are evaluated. 2 matches the 2-fold map
# 'offsetheatmap_2fold_dynamic_k.npz' loaded later in this notebook.
n_folds = 2


def FWHM_fit(k, n_folds=n_folds):
    '''Closed-form FWHM model (radians) for filter exponent ``k``.

    Evaluates 4/n_folds * arccos(sqrt(2**(-1/k))). ``k_from_fwhm`` is its
    exact inverse. Works element-wise on numpy arrays.
    '''
    y = 4/n_folds*np.arccos(np.sqrt(np.exp(np.log(1/2)/k)))
    return y


def k_from_fwhm(FWHM, n_folds=n_folds):
    '''Inverse of ``FWHM_fit``: the exponent k whose filter has width ``FWHM``.

    Solves (cos(FWHM * n_folds / 4)**2)**k == 1/2 for k. The result is
    positive for 0 < FWHM < 2*pi/n_folds.
    '''
    return np.log(1/2) / np.log(np.cos(FWHM * n_folds / 4)**2)


def linear_fwhm(iteration, max_iterations, start_fwhm, end_fwhm):
    '''Linearly interpolate from ``start_fwhm`` (iteration 0) to ``end_fwhm`` (iteration ``max_iterations``).'''
    m = (end_fwhm - start_fwhm) / max_iterations  # slope per iteration
    return m * iteration + start_fwhm


# Starting offset (degrees) used when no `offset` tensor is passed in.
OFFSET_INIT_DEG = -30.
offset_init = torch.tensor(np.deg2rad(OFFSET_INIT_DEG), requires_grad=True)
print(f'offset_init = {torch.rad2deg(offset_init)} deg')


def gradient_descent_optimize_offset(intensity, offset=None, k_start=1, k_end=100,
                                     linearly_decrease_k=False, n_folds=n_folds,
                                     plot_every=5):
    '''Optimise the rotation of ``filter_function`` to match a diffraction pattern.

    Preprocessing:
    - apply ``mask_annulus`` with its default radii;
    - pass through ``normalize_min_max``;
    - zero pixels below half of the resulting maximum and double the rest.

    Optimisation: Adam (lr ``LR``) runs ``MAX_ITER_OFFSET`` steps on
    ``offset`` to minimise -sum(intensity * filter). The filter sharpens
    during the run:
    - ``linearly_decrease_k=True``: k goes linearly from ``k_start`` to
      ``k_end``. NOTE: despite the name, k *increases* with the defaults.
    - otherwise: the FWHM goes linearly from FWHM_fit(k_start) to
      FWHM_fit(k_end), and k is recovered with ``k_from_fwhm``.

    Parameters
    ----------
    intensity : array_like
        One diffraction pattern; it is not modified.
    offset : torch.Tensor, optional
        0-d tensor (radians, requires_grad=True). The optimiser updates it
        in place. If omitted, a fresh tensor at ``OFFSET_INIT_DEG`` is
        created, so separate calls never share optimiser state.
    k_start, k_end : float
        Filter exponent at the first and last iteration.
    linearly_decrease_k : bool
        Selects the sharpening schedule described above.
    n_folds : int
        Symmetry order, used both for the FWHM schedule and ``filter_function``.
    plot_every : int or None
        Show filter + pattern every ``plot_every`` iterations (default 5).
        Pass None or 0 to disable plotting, e.g. in batch loops.

    Returns
    -------
    tuple
        ``(offset, intensity, evaluate_image_theta, loss)``:
        - ``offset``: optimised offset tensor;
        - ``intensity``: the preprocessed pattern used in the loss;
        - ``evaluate_image_theta``: filter evaluated in the final iteration
          (before its optimiser step);
        - ``loss``: final-iteration loss.
    '''
    if offset is None:
        offset = torch.tensor(np.deg2rad(OFFSET_INIT_DEG), requires_grad=True)
    opt = torch.optim.Adam([offset], lr=LR)

    # FWHM schedule, used when linearly_decrease_k is False.
    fwhm_start = FWHM_fit(k_start, n_folds)
    fwhm_end = FWHM_fit(k_end, n_folds)
    fwhm_linear = np.linspace(fwhm_start, fwhm_end, MAX_ITER_OFFSET)

    # NOTE(review): a constant masked pattern may give NaNs after min-max
    # normalisation (depends on normalize_min_max) -- confirm.
    intensity = normalize_min_max(mask_annulus(intensity))
    # Keep only the strongest features: threshold at half the maximum.
    threshold = 1/2*intensity.max()
    intensity[intensity < threshold] = 0  # suppress values below the threshold
    intensity[intensity >= threshold] *= 2  # boost the remaining values
    intensity_t = torch.tensor(intensity)  # loop-invariant; build it once

    last_step = max(MAX_ITER_OFFSET - 1, 1)  # avoid 0/0 when only one iteration
    for i in range(MAX_ITER_OFFSET):
        if linearly_decrease_k:
            k_current = k_start + (k_end - k_start) * (i / last_step)
        else:
            k_current = k_from_fwhm(fwhm_linear[i], n_folds)
        evaluate_image_theta = filter_function(k_current, DATA_THETA + offset, n_folds)
        loss = -(intensity_t * evaluate_image_theta).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if plot_every and i % plot_every == 0:
            filt = evaluate_image_theta.detach().numpy()
            plt.imshow(filt + intensity)
            plt.title(f'iter: {i}')
            plt.show()
    return offset, intensity, evaluate_image_theta, loss

# Demo: optimise a single pattern and overlay the final filter on it.
rx, ry = 60, 100
optoffset, intensity, optfilter, _ = gradient_descent_optimize_offset(dc.data[ry][rx])
print(f'optoffset = {round(np.rad2deg(optoffset.item()), 2)} deg')
print(delta)
overlay = optfilter.detach().numpy() + intensity
plt.imshow(overlay)
plt.title(f'n_folds={n_folds}, iterations={MAX_ITER_OFFSET}, optoffset={round(np.rad2deg(optoffset.item()), 3)}')
plt.show()

### Run gradient descent on every scan position


In [None]:

# Optimise the filter offset at every scan position and collect
# {'offset' (degrees), 'xcoord', 'ycoord'} records in `datalist`.
datalist = []
x_shape = dc.data.shape[1]
y_shape = dc.data.shape[0]
print('x-shape=', x_shape)
print('y-shape=', y_shape)
print('DPs evaluated=', x_shape*y_shape)

# NOTE(review): gradient_descent_optimize_offset shows a figure every 5
# iterations by default, so this full-map loop produces a very large number
# of plots. Expect long run time and heavy output.
for rx in range(x_shape):
    for ry in range(y_shape):
        dp = mask_annulus(dc.data[ry][rx])
        # A fresh 0-degree start per pattern; the optimiser updates this
        # tensor in place, so it must not be reused across positions.
        offset_init = torch.tensor(np.deg2rad(0.), requires_grad=True)
        offset, intensity, filt, _ = gradient_descent_optimize_offset(dp, offset_init)
        # `record`, not `dict`, so the builtin is not shadowed for later cells.
        record = {'offset': np.rad2deg(offset.item()), 'xcoord': rx, 'ycoord': ry}
        datalist.append(record)
        # Patterns whose masked maximum is 0 could instead be recorded with
        # offset np.nan (guard on dp.max() > 0).
    print(f'optimizing on column {rx+1}/{x_shape}')

In [None]:
# Define the interactive function to be used with sliders
def check_algorithm(x, y):
    '''Optimise and display the filter offset for scan position (x, y).

    Runs gradient_descent_optimize_offset on ``dc.data[y][x]`` from a 0-degree
    start, prints the settings and the optimised offset (radians), and
    overlays the final filter on the min-max normalised raw pattern.
    '''
    offset_init = torch.tensor(0.0, requires_grad=True)
    # Print before optimising: the optimiser updates offset_init in place.
    print(f'offset_init = {offset_init}')
    print(f'Iterations = {MAX_ITER_OFFSET}')
    print("n_folds = ", n_folds)
    # NOTE: k is the fixed exponent from the parameters cell; the optimiser
    # follows its own k_start -> k_end schedule.
    print('k value =', k)

    intensity = dc.data[y][x]  # pattern at the selected scan position
    # The optimiser returns (offset, processed_intensity, filter, loss). Pass
    # offset_init so the run really starts from 0 degrees.
    opt_offset, _, opt_filter, _ = gradient_descent_optimize_offset(intensity, offset_init)
    print(f'optoffset = {opt_offset.item()}')

    # Plot the result
    plt.imshow(opt_filter.detach().numpy() + normalize_min_max(intensity))
    plt.title(f"Optimized Result for (x={x}, y={y})")
    plt.show()

# One slider per scan axis, each covering every valid index
# (x -> axis 1 of dc.data, y -> axis 0).
x_slider = IntSlider(
    min=0, max=dc.data.shape[1] - 1, step=1, value=89, description='X Value:'
)
y_slider = IntSlider(
    min=0, max=dc.data.shape[0] - 1, step=1, value=98, description='Y Value:'
)

# Re-run check_algorithm whenever either slider moves.
interact(check_algorithm, x=x_slider, y=y_slider)

In [None]:
# FIND FILTER WIDTH DEPENDENCE ON K

def compute_fwhm(k_value, n_folds=n_folds, resolution=1000):
    '''Numerically measure the full width at half maximum of ``filter_function``.

    Samples the filter at ``resolution`` points over (-pi/n_folds, pi/n_folds)
    and locates its peak. Returns the distance in radians between the sample
    closest to half the peak value on the left flank and the one on the
    right flank. Accuracy is about one sample spacing.

    Parameters
    ----------
    k_value : float
        Filter exponent passed to ``filter_function``.
    n_folds : int
        Symmetry order passed to ``filter_function``.
    resolution : int
        Number of angular samples.

    Returns
    -------
    float
        Non-negative FWHM in radians.
    '''
    theta = torch.linspace(-np.pi / n_folds, np.pi / n_folds, resolution)
    filter_vals = filter_function(k_value, theta, n_folds)

    peak_idx = int(torch.argmax(filter_vals))
    half_max = 0.5 * filter_vals[peak_idx].item()
    diff = torch.abs(filter_vals - half_max)

    # Search each flank separately. This guarantees the two points straddle
    # the peak and the width is non-negative, which a global "two smallest
    # differences" search does not.
    left_idx = int(torch.argmin(diff[:peak_idx + 1]))
    right_idx = peak_idx + int(torch.argmin(diff[peak_idx:]))
    return theta[right_idx].item() - theta[left_idx].item()

# Example usage:
k_value = 1
fwhm_results = compute_fwhm(k_value)

print(f"k = {k_value}, FWHM = {fwhm_results} radians")

# Numerically measured FWHM for k = 1..99.
ks = list(range(1, 100))
fwhms = [compute_fwhm(k_i) for k_i in ks]

# Compare with the closed-form FWHM_fit defined in the gradient-descent cell
# (not redefined here). Evaluate it on a finer grid and plot against that
# same grid so the curve is smooth.
x_fit = np.linspace(min(ks), max(ks), 500)
ys = FWHM_fit(x_fit)
plt.plot(x_fit, ys, color='red', label='fit')
plt.scatter(ks, fwhms, label='calculated FWHM')
plt.xlabel('k')
plt.ylabel('FWHM (rad)')
plt.title(f'n_folds={n_folds}')
plt.legend()
plt.show()

In [None]:
# RELATE OFFSET TO X AND Y POSITION, PLOT OFFSET HEATMAP, PLOT FILTER + DP

from tifffile import imread
import matplotlib.gridspec as gridspec

# Inputs: the HAADF reference image and the saved offset map.
img = imread('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/0020 - B1- biogenic guanine 23000 x STEM HAADF.tif')
heatmap = np.load('/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/offsetheatmap_2fold_dynamic_k.npz')['data']

# Two equal-width panels with almost no gap:
# offset map on the left, HAADF image on the right.
fig = plt.figure(figsize=(15, 7.5))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1], wspace=0.02)
ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1])
cax0 = ax0.imshow(heatmap)
ax1.imshow(img)

# Neither panel shows ticks or tick labels.
for panel in (ax0, ax1):
    panel.tick_params(axis='both', which='both', bottom=False, top=False,
                      left=False, right=False, labelbottom=False, labelleft=False)

# Horizontal colour bar under the offset map, with room left at the bottom.
cbar = fig.colorbar(cax0, ax=ax0, orientation='horizontal', fraction=0.06, pad=0.02)
cbar.set_label('Lattice Orientation (degrees)', labelpad=5)
plt.subplots_adjust(bottom=0.283)
plt.show()

# Recipe used to build `offsetmap` from `datalist` (kept for reference):
# xcoords = [entry['xcoord'] for entry in datalist]
# ycoords = [entry['ycoord'] for entry in datalist]
# offsets = [entry['offset'] for entry in datalist]

# x_range = np.unique(xcoords)
# y_range = np.unique(ycoords)
# offsetmap = np.full((len(y_range), len(x_range)), np.nan)

# for offset, x, y in zip(offsets, xcoords, ycoords):
#     x_idx = np.where(x_range == x)[0][0]
#     y_idx = np.where(y_range == y)[0][0]
#     offsetmap[y_idx, x_idx] = offset

# fig, axs = plt.subplots(1, 2, figsize=(22, 10))
# im1 = axs[0].imshow(np.flipud(offsetmap), cmap='viridis', interpolation='nearest', origin='lower')
# axs[0].set_title("2-fold filter with dynamic k (linearly decreasing FWHM)")
# cbar1 = fig.colorbar(im1, ax=axs[0], orientation='vertical',fraction=0.02, pad=0.04)
# cbar1.set_label('Offset Value (degrees)')
# axs[1].imshow(img, cmap='gray', interpolation='nearest')

# def plot(index):
#     '''plotting filter + diff pattern'''
#     offset_val = np.deg2rad(offsets[index])
#     filt = filter_function(8, DATA_THETA+offset_val)
#     dp = dc.data[ry][index]
#     plt.imshow(normalize_min_max(dp) + normalize_min_max(filt.numpy()))
#     plt.title(f'{round(offsets[index])}')

# interact(plot, index=IntSlider(value=0, min=0, max=216, step=1, description='Image Index:'))

In [None]:
# SAVE OFFSET HEATMAP AS NPZ

# np.savez('offsetheatmap_2fold_dynamic_k.npz', data=offsetmap)

In [None]:

# Scan row whose patterns the slider below steps through.
ry = 0

def plotting(index):
    '''Plot the filter-overlap score versus rotation angle around the optimum.

    Steps for the pattern at scan position (x=index, y=ry):
    1. Annulus-mask the pattern and optimise the filter offset.
    2. Evaluate sum(filter * normalised pattern) at 0.5-degree steps within
       [-45, +45) degrees of the optimum, using the fixed exponent ``k`` and
       the same ``n_folds`` as the optimiser.

    Left panel: score vs. angle. Right panel: pattern + filter at the optimum.
    '''
    dp = mask_annulus(dc.data[ry][index])
    # The optimiser returns (offset, processed_intensity, filter, loss).
    offset, _, _, _ = gradient_descent_optimize_offset(dp)
    dp = normalize_min_max(dp)
    dp_t = torch.as_tensor(dp)  # explicit tensor for the torch products below

    angle = np.rad2deg(offset.item())
    range_deg = np.arange(angle - 45, angle + 45, .5)

    overlap_list = [
        (filter_function(k, DATA_THETA + np.deg2rad(phi), n_folds) * dp_t).sum().item()
        for phi in range_deg
    ]

    filt = filter_function(k, DATA_THETA + offset.item(), n_folds)
    overlay = dp + filt.numpy()  # dp is already min-max normalised

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    axs[0].plot(range_deg, overlap_list)
    axs[0].set_xlabel('filter angle (deg)')
    axs[0].set_ylabel('overlap')
    axs[0].set_title(f'y-range = {round(max(overlap_list) - min(overlap_list))}')

    im = axs[1].imshow(overlay)
    axs[1].set_title(f'{round(angle, 3)}')
    fig.colorbar(im, ax=axs[1])

    plt.tight_layout()
    plt.show()

interact(plotting, index=IntSlider(value=0, min=0, max=10, step=1, description='Image Index:'));