In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import py4DSTEM
import torch
from ipywidgets import IntSlider, interact
from py4DSTEM import DataCube

In [None]:
# Dataset read with py4DSTEM. Set the SKYRMION_DATA_PATH environment variable to
# point at a local copy; the fallback is the original author's machine path, so
# existing behaviour is unchanged when the variable is not set.
path = os.environ.get(
    'SKYRMION_DATA_PATH',
    '/Users/cadenmyers/billingelab/dev/skyrmion_lattices/yevgeny_proj/data/0020 - original-centralized-masked.h5',
)
dc = py4DSTEM.read(path)

In [None]:
def extract_datacube_subset(datacube, x_range, y_range):
    """
    Crop a Py4DSTEM DataCube to a rectangular window of scan positions.

    The first axis of ``datacube.data`` is treated as y and the second as x,
    i.e. the result holds ``datacube.data[y_start..y_end, x_start..x_end]``.
    Both bounds of each range are inclusive.

    Parameters:
        datacube (DataCube): The source Py4DSTEM DataCube.
        x_range (tuple): ``(start, end)`` along the second data axis, inclusive.
        y_range (tuple): ``(start, end)`` along the first data axis, inclusive.

    Returns:
        DataCube: A new DataCube built from the cropped data.
    """
    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range
    # +1 on the stop index because the upper bounds are inclusive.
    rows = slice(y_lo, y_hi + 1)
    cols = slice(x_lo, x_hi + 1)
    return DataCube(datacube.data[rows, cols])

# Scan window to analyse (inclusive bounds): x = 40..130, y = 0..50.
x_range, y_range = (40, 130), (0, 50)
dcsub = extract_datacube_subset(dc, x_range=x_range, y_range=y_range)

In [None]:
# FILTER FUNCTION

im_shape = dc.data[0][0].shape
print(im_shape)
# Centred pixel-coordinate grids with the same (rows, cols) shape as a single
# diffraction pattern: x varies along axis 1 (columns), y along axis 0 (rows).
# np.meshgrid's default 'xy' indexing returns arrays of shape
# (len(second arg), len(first arg)), so the column count must feed the first
# arange; otherwise non-square patterns get a transposed grid.
n_rows, n_cols = im_shape[0], im_shape[1]
x, y = np.meshgrid(np.arange(-n_cols // 2, n_cols // 2), np.arange(-n_rows // 2, n_rows // 2))
# Per-pixel angle from atan2(x, y) (note the argument order); the filter
# offsets used below are determined against this convention.
DATA_THETA = torch.atan2(torch.tensor(x), torch.tensor(y))
# NOTE(review): not used in the visible cells; presumably the starting value of
# a gradient-based offset fit (see LR / MAX_ITER_OFFSET).
offset1 = torch.tensor(0., requires_grad=True)

# for the model: presumably settings for a gradient-based offset fit
# (grid_search does not use them).
MAX_ITER_OFFSET, LR, OFFSET_ADJUSTMENT = 31, 1e-2, 60

# Angular filter / grid-search settings.
n_folds = 4        # rotational symmetry of filter_function
k = 70             # exponent sharpening the filter lobes: (cos**2) ** k
step_size = 0.5    # angular sweep step, in degrees
# One full period of an n_folds-fold filter, centred on 0: [-180/n_folds, 180/n_folds).
deg_range = np.arange(-180 / n_folds, 180 / n_folds, step_size)
print("n_folds =", n_folds)
print('k value =', k)
def filter_function(k, theta, n_folds=n_folds):
    """Angular filter with ``n_folds`` equally spaced lobes per full turn.

    Computes ``(cos(n_folds / 2 * theta) ** 2) ** k`` evaluated in log space as
    ``exp(k * log(cos ** 2))``; larger ``k`` gives narrower lobes.

    Parameters:
        k: exponent controlling lobe sharpness.
        theta (torch.Tensor): angle(s) in radians.
        n_folds (int): number of lobes per 2*pi (defaults to the notebook value).

    Returns:
        torch.Tensor with the same shape as ``theta``; values lie in [0, 1] for k > 0.
    """
    cos_sq = torch.cos(n_folds / 2 * theta) ** 2
    # log(0) = -inf, so exp(k * -inf) = 0 where the cosine vanishes.
    return torch.exp(torch.log(cos_sq) * k)

# Angular separation between the two lobe pairs of the rectangular filter.
# ~25 deg was estimated from the diffraction pattern at x=48, y=59.
delta = torch.tensor(25.).deg2rad()
print('delta =', round(np.rad2deg(delta.item()), 3), 'deg')
def rectangular_filter_function(k, theta, delta=delta, n_folds=n_folds):
    """Sum of two sharpened angular filters, the second shifted by ``delta``.

    Each term is ``exp(k * log(cos(n_folds / 4 * t) ** 2))``, with ``n_folds / 2``
    lobes per full turn; the second term is evaluated at ``t = theta + delta``.
    For n_folds = 4 the maxima sit at ``theta = m*pi`` and ``theta = m*pi - delta``,
    i.e. on the corners of a rectangle inscribed in the unit circle.

    Parameters:
        k: exponent controlling lobe sharpness.
        theta (torch.Tensor): angle(s) in radians.
        delta (torch.Tensor): angular offset of the second lobe pair, in radians.
        n_folds (int): total number of lobes of the combined filter.

    Returns:
        torch.Tensor with the same shape as ``theta``.
    """
    def _lobes(angle):
        return torch.exp(k * torch.log(torch.cos(n_folds / 4 * angle) ** 2))

    return _lobes(theta) + _lobes(theta + delta)


def normalize_min_max(data):
    """Linearly rescale values so the minimum maps to 0 and the maximum to 1.

    Parameters:
        data (np.ndarray | torch.Tensor | array-like): values to rescale.

    Returns:
        The same kind as the input: a torch.Tensor (on the input's device) for
        tensor input, otherwise a NumPy array. A constant input (max == min)
        yields all zeros instead of NaNs from a 0/0 division.
    """
    is_tensor = isinstance(data, torch.Tensor)
    # .cpu() is a no-op for CPU tensors and makes .numpy() work for GPU tensors.
    array = data.detach().cpu().numpy() if is_tensor else np.asarray(data)
    array_min = np.min(array)
    array_max = np.max(array)
    value_range = array_max - array_min
    if value_range == 0:
        # Every element equals array_min, so the shifted array is all zeros;
        # dividing by 1 keeps the usual output dtype and avoids 0/0 -> NaN.
        value_range = 1
    norm_array = (array - array_min) / value_range
    if is_tensor:
        return torch.tensor(norm_array, device=data.device)
    return norm_array
    

def mask_center(data, radius):
    '''Return a copy of a 2D diffraction pattern with the central disk zeroed.

    Pixels whose Euclidean distance from the centre pixel
    (shape[0] // 2, shape[1] // 2) is <= ``radius`` are set to 0. The input
    itself is left untouched.
    '''
    n_rows, n_cols = data.shape[0], data.shape[1]
    row_idx, col_idx = np.ogrid[:n_rows, :n_cols]
    dist_from_center = np.sqrt((col_idx - n_cols // 2) ** 2 + (row_idx - n_rows // 2) ** 2)

    masked = np.copy(data)
    masked[dist_from_center <= radius] = 0
    return masked

def grid_search(rx, ry, datacube, step_size=step_size, n_folds=n_folds, mask_radius=18):
    """Find the rotation of the rectangular filter that best matches one pattern.

    The pattern ``datacube.data[ry, rx]`` has its central disk (radius
    ``mask_radius`` pixels) zeroed. ``rectangular_filter_function`` is then
    evaluated at ``DATA_THETA + angle`` for every ``angle`` in
    ``[-360/n_folds, 360/n_folds)`` (step ``step_size`` degrees), and each
    angle is scored by the pixel-wise inner product ``sum(filter * pattern)``.

    Parameters:
        rx, ry (int): scan position; ``datacube.data[ry, rx]`` must be a 2D pattern.
        datacube: object exposing ``.data`` (e.g. a py4DSTEM DataCube).
        step_size (float): angular step of the sweep, in degrees.
        n_folds (int): sets only the sweep range; the filter uses its own default.
        mask_radius (float): radius in pixels of the zeroed central region.

    Returns:
        tuple: ``(deg_range, overlap_vals, dp, max_overlap_angle)`` - the swept
        angles (degrees), their overlap scores, the masked pattern, and the
        best-scoring angle in degrees (convert with np.deg2rad before adding
        it to DATA_THETA).
    """
    dp = mask_center(datacube.data[ry, rx], mask_radius)
    half_width = 360 / n_folds
    deg_range = np.arange(-half_width, half_width, step_size)

    overlap_vals = []
    for angle in deg_range:
        filt = rectangular_filter_function(k, DATA_THETA + np.deg2rad(angle)).numpy()
        # Inner product, not a sum of sums: (filt + dp).sum() equals
        # filt.sum() + dp.sum(), which does not depend on how the filter lines
        # up with the pattern, so its argmax would ignore the data.
        overlap_vals.append((filt * dp).sum())

    max_overlap_angle = deg_range[np.argmax(overlap_vals)]
    return deg_range, overlap_vals, dp, max_overlap_angle

# Sanity check: overlay the rectangular filter, evaluated at DATA_THETA - 0.7 rad,
# on the min-max normalized diffraction pattern at data index [59, 48].
x = rectangular_filter_function(k, DATA_THETA - 0.7)
plt.imshow(normalize_min_max(dc.data[59, 48]) + x.numpy())

In [None]:
# scrap: inspect the grid-search result at a single scan position of dcsub.
ry = 10
rx = 10

# Keep the search's angle grid in its own name: unpacking into `deg_range`
# would overwrite the module-level sweep range defined in the config cell.
search_degs, overlaps, dp, offset = grid_search(rx, ry, dcsub)

print(max(overlaps))
offset_index = np.argmax(overlaps)
print(search_degs[offset_index])
print(offset)
# grid_search reports the best angle in degrees, while DATA_THETA is in radians.
plt.imshow(normalize_min_max(dp) + rectangular_filter_function(k, DATA_THETA + np.deg2rad(offset)).numpy())
plt.show()
plt.imshow(dp)
plt.show()


## Run Grid Search

In [None]:
# Best filter angle (degrees) for every scan position of dcsub, collected in
# row-major (ry, rx) order so the flat list reshapes directly onto the scan grid.
# grid_search indexes data[ry, rx], so ry runs over axis 0 and rx over axis 1;
# range(n) already covers indices 0..n-1 (a "+1" would index past the end).
offset_list = []
for ry in range(dcsub.data.shape[0]):
    for rx in range(dcsub.data.shape[1]):
        _, _, _, offset = grid_search(rx, ry, dcsub)
        # offset is in degrees; DATA_THETA is in radians.
        optfilter = rectangular_filter_function(k, DATA_THETA + np.deg2rad(offset))
        offset_list.append((offset, optfilter))

In [None]:
offsets = [item[0] for item in offset_list]

# One offset per scan position; dcsub.data is indexed [ry, rx, ...], so the
# scan grid is shape[0] x shape[1] (no "+1": those are already the counts).
n_ry, n_rx = dcsub.data.shape[0], dcsub.data.shape[1]
print(n_ry * n_rx)

print(len(offsets))
print(len(offset_list))

# Assumes offset_list was filled with ry as the outer loop (row-major order).
reshaped_values = np.array(offsets).reshape(n_ry, n_rx)


In [None]:
# Live view of the filter_function overlap sweep at one scan position of the full datacube.
rx = 100
ry = 20
dp = dc.data[ry, rx]
# Recompute the sweep range here (same formula as the config cell) so this
# cell does not depend on whether another cell re-bound `deg_range`.
sweep_degs = np.arange(-360/(2*n_folds), 360/(2*n_folds), step_size)
overlap_vals = []
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# The pattern and axis labels do not change during the sweep: draw them once.
axs[1].imshow(dp)
axs[1].set_title(f'{rx}, {ry}')
axs[0].set_xlabel('filter angle')
axs[0].set_ylabel('overlap score')
(score_line,) = axs[0].plot([], [])

for angle in sweep_degs:
    rad_angle = np.deg2rad(angle)
    # Named `filt` so the builtin `filter` is not shadowed at notebook level.
    filt = filter_function(k, DATA_THETA + rad_angle).numpy()
    # Overlap = inner product of filter and pattern. A plain (filt + dp).sum()
    # equals filt.sum() + dp.sum(), which ignores how the two line up.
    overlap_score = float((filt * dp).sum())
    overlap_vals.append(overlap_score)

    # x and y must have equal length: plot only the angles scored so far.
    score_line.set_data(sweep_degs[:len(overlap_vals)], overlap_vals)
    axs[0].relim()
    axs[0].autoscale_view()
    axs[0].set_title(f'Y-range: {max(overlap_vals) - min(overlap_vals)}')
    plt.draw()
    plt.pause(0.01)  # lets interactive backends repaint between steps

plt.show()