In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch
from ipywidgets import interact, IntSlider
import py4DSTEM

# Data locations. The directory is absolute to the author's machine; edit DATA_DIR to run elsewhere.
DATA_DIR = Path('/Users/cadenmyers/billingelab/dev/skyrmion_lattices/yevgeny_proj/data')
path = str(DATA_DIR / '0020 - original-centralized-masked.h5')

# Per-scan-position angle offsets in degrees (later cells convert them with np.deg2rad).
# np.load on an .npz returns an NpzFile that keeps the archive open, so use it as a
# context manager and release the handle once the array has been read out.
with np.load(DATA_DIR / 'offsetheatmap_4fold.npz') as offsets_npz:
    offsets = offsets_npz['data']

dc = py4DSTEM.read(path)

In [None]:
# FILTER FUNCTION

# Polar-angle grid over one diffraction pattern; the angular filter is evaluated on it.
im_shape = dc.data[0][0].shape
print(im_shape)
n_rows, n_cols = im_shape[0], im_shape[1]
# x must vary along columns (axis 1, length n_cols) and y along rows (axis 0, length n_rows)
# so the grid has the same (n_rows, n_cols) shape as a pattern. Passing the row range
# first produced a transposed (n_cols, n_rows) grid for non-square detectors; for square
# detectors the result is identical.
x, y = np.meshgrid(np.arange(-n_cols // 2, n_cols // 2), np.arange(-n_rows // 2, n_rows // 2))
# atan2(x, y) rather than atan2(y, x): the angle is measured from the y axis. Presumably the
# stored offsets were computed with this convention — confirm before changing it.
DATA_THETA = torch.atan2(torch.tensor(x), torch.tensor(y))
# Trainable scalar angle offset (requires_grad); not used by the cells visible here.
offset1 = torch.tensor(0., requires_grad=True)

# Settings for the offset-fitting model. None of them is read by the cells shown here;
# the names suggest an iteration cap and a learning rate — TODO confirm against the model code.
MAX_ITER_OFFSET = 31
LR = 0.01
OFFSET_ADJUSTMENT = 60

# Filter parameters: rotational symmetry order and lobe-sharpness exponent.
n_folds = 4
k = 8
print("n_folds =", n_folds)
print('k value =', k)

def filter_function(k, theta, n_folds=n_folds):
    """Angular n-fold filter ``cos(n_folds / 2 * theta) ** 2 ** k``.

    Parameters
    ----------
    k : float
        Sharpness exponent; larger values give narrower lobes.
    theta : torch.Tensor
        Angles in radians (e.g. the polar-angle grid plus a rotation offset).
    n_folds : int, optional
        Number of lobes over a full turn. The default is the module-level
        ``n_folds`` value at the time this function is defined.

    Returns
    -------
    torch.Tensor
        Same shape as ``theta``; values lie in [0, 1] for ``k >= 0``.
    """
    # Mathematically equal to exp(k * log(cos^2(...))), but computed as a direct power.
    # The log form evaluates log(0) = -inf wherever the cosine is exactly zero, which gives
    # NaN when k == 0 and needlessly round-trips through exp/log.
    return torch.cos(n_folds / 2 * theta).pow(2).pow(k)

# plt.imshow(filter_function(k, DATA_THETA))

def normalize_min_max(data):
    """Rescale ``data`` linearly so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    data : numpy.ndarray or torch.Tensor
        Input values. Tensors are detached, moved to CPU and processed with NumPy,
        so a tensor result carries no autograd history.

    Returns
    -------
    numpy.ndarray or torch.Tensor
        Same container type as ``data``. A constant input (max == min) maps to all
        zeros rather than the all-NaN result of a 0/0 division.

    Raises
    ------
    ValueError
        From ``np.min`` when ``data`` is empty.
    """
    is_tensor = isinstance(data, torch.Tensor)
    array = data.detach().cpu().numpy() if is_tensor else np.asarray(data)
    array_min = np.min(array)
    array_range = np.max(array) - array_min
    if array_range == 0:
        # Constant input: avoid 0/0; use the dtype the division would have produced.
        norm_array = np.zeros(array.shape, dtype=np.result_type(array, 1.0))
    else:
        norm_array = (array - array_min) / array_range
    return torch.tensor(norm_array) if is_tensor else norm_array




In [None]:
# Preview a strip of scan positions: filter at the stored offset over the normalized
# pattern (left) and the raw pattern (right). One figure per position, so nothing is
# drawn onto stale axes and the cell runs on a fresh kernel.
# NOTE(review): patterns are indexed dc.data[ry, rx] but offsets as offsets[rx][ry] — confirm layouts.
for rx in range(83, 84):  # 217
    for ry in range(41, 62):  # 142
        dp = dc.data[ry, rx]
        rad_offset = np.deg2rad(offsets[rx][ry])
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(filter_function(k, DATA_THETA + rad_offset) + normalize_min_max(dp))
        axs[0].set_title('filter + normalized pattern')
        axs[1].imshow(dp)
        axs[1].set_title(f'rx,ry = {rx}, {ry}')
        plt.show()
        plt.close(fig)  # release the figure so many iterations don't accumulate memory

In [None]:

def overlap_vs_phi_plot(rx, ry):
    '''Plot the filter/pattern overlap score against filter rotation near the stored offset.

    Left panel: ``sum(filter * pattern)`` for filter rotations from ``offset - 45`` up to
    (but excluding) ``offset + 45`` degrees in 0.5 degree steps; the title shows the
    spread (max - min) of those scores. Right panel: the min-max normalized pattern with
    the filter at the stored offset added on top, titled with the offset in degrees.

    Parameters
    ----------
    rx, ry : int
        Scan position. The pattern is read as ``dc.data[ry, rx]`` and the offset as
        ``offsets[rx][ry]``. NOTE(review): opposite index orders — confirm this matches
        how the offset map was saved.
    '''
    dp = dc.data[ry, rx]
    deg_offset = offsets[rx][ry]

    # Overlap score for every trial rotation of the filter.
    range_deg = np.arange(deg_offset - 45, deg_offset + 45, .5)
    overlap_list = [
        (filter_function(k, DATA_THETA + np.deg2rad(angle_deg)) * dp).sum().item()
        for angle_deg in range_deg
    ]

    filt = filter_function(k, DATA_THETA + np.deg2rad(deg_offset)).numpy()
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    axs[0].plot(range_deg, overlap_list)
    axs[0].set_xlabel('filter rotation (deg)')
    axs[0].set_ylabel('overlap score')
    axs[0].set_title(f'y-range = {round(max(overlap_list) - min(overlap_list))}')

    # Draw the overlay a single time and attach the colorbar to that image.
    im = axs[1].imshow(normalize_min_max(dp) + filt)
    axs[1].set_title(f'{round(deg_offset, 3)}')
    fig.colorbar(im, ax=axs[1])

    plt.tight_layout()
    plt.show()

# Slider bounds are taken from the loaded arrays rather than hard-coded, so every (rx, ry)
# is a valid index for both dc.data[ry, rx] and offsets[rx][ry].
rx_max = min(dc.data.shape[1], offsets.shape[0]) - 1
ry_max = min(dc.data.shape[0], offsets.shape[1]) - 1
interact(
    overlap_vs_phi_plot,
    rx=IntSlider(value=0, min=0, max=rx_max, step=1, description='Rx'),
    ry=IntSlider(value=0, min=0, max=ry_max, step=1, description='Ry')
);