In [None]:
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch
from ipywidgets import interact, IntSlider
import py4DSTEM
import scipy.signal as signal
import ipywidgets as widgets
from IPython.display import display


# Data location. Defaults to the original author's directory; set the SKYRMION_DATA_DIR
# environment variable to run this notebook on another machine.
DATA_DIR = Path(os.environ.get(
    'SKYRMION_DATA_DIR',
    '/Users/cadenmyers/billingelab/dev/skyrmion_lattices/yevgeny_proj/data',
))
path = str(DATA_DIR / '0020 - original-centralized-masked.h5')
# Read the array eagerly and close the .npz archive instead of leaving the handle open.
with np.load(DATA_DIR / 'offsetheatmap_4fold.npz') as offsets_npz:
    offsets = offsets_npz['data']
dc = py4DSTEM.read(path)

In [None]:
# Polar-angle map of the detector, centred on the pattern.
im_shape = dc.data[0][0].shape
print(im_shape)
# np.meshgrid uses 'xy' indexing: the FIRST range varies along columns (axis 1) and the
# SECOND along rows (axis 0). Feed column offsets first so the grid has shape im_shape
# (rows, cols), matching the pattern even for non-square detectors.
x, y = np.meshgrid(
    np.arange(-im_shape[1] // 2, im_shape[1] // 2),  # column offsets
    np.arange(-im_shape[0] // 2, im_shape[0] // 2),  # row offsets
)
# NOTE: argument order is atan2(x, y), not the usual atan2(y, x), so the angle is
# measured from the y axis; the filters below are rotated relative to this grid.
DATA_THETA = torch.atan2(torch.tensor(x), torch.tensor(y))
offset1 = torch.tensor(0., requires_grad=True)  # not used in the cells visible here

# for the model
MAX_ITER_OFFSET = 31
LR = 1e-2
OFFSET_ADJUSTMENT = 60

# Azimuthal filter settings: number of lobes (symmetry order) and lobe sharpness exponent.
n_folds = 2
k = 100
print("n_folds =", n_folds)
print('k =', k)
def azimuthal_filter_function(k, theta, n_folds=n_folds):
    '''
    Azimuthal lobe filter: cos(n_folds/2 * theta) ** (2k).

    With the default n_folds=2, the cos**2 factor has period pi, giving two lobes per full turn.
    Larger k makes the lobes narrower.

    inputs:
    k: sharpness exponent
    theta: angle(s) in radians; a tensor (e.g. DATA_THETA + offset) or anything torch.as_tensor accepts
    n_folds: symmetry order (defaults to the module-level n_folds at definition time)

    output:
    torch tensor with the same shape as theta, values in [0, 1] for k >= 0
    '''
    theta = torch.as_tensor(theta)
    # (cos^2)**k equals the previous exp(k*log(cos^2)), but avoids the 0*(-inf) = NaN it
    # produced where cos == 0 (k == 0), and the NaN gradient through log(0).
    return (torch.cos(n_folds / 2 * theta) ** 2) ** k

azimuthal_filter_function(k, DATA_THETA)  # return value discarded; the in-function plotting is commented out

# Rectangular (tetragonal) peak geometry in q-space. Peaks sit at distance r from the centre,
# and the side lengths are a = 2r*sin(delta), b = 2r*cos(delta), so 2r is the rectangle's diagonal.
# Used to work out the correct filter formula for tetragonal symmetry.
r = 28.  # pixels
delta = torch.tensor(0.8596625328063965)  # based on dp x=41 y=125
a, b = 2 * r * torch.sin(delta), 2 * r * torch.cos(delta)  # pixels
print(f'a = {round(a.item(), 3)} pixels')
print(f'b = {round(b.item(), 3)} pixels')
print(f'delta = {round(delta.item(), 3)} radians')
def rectangular_filter_function(k, theta1, theta2, delta=delta, n_folds=n_folds):
    '''
    Sum of two azimuthal lobe filters, the second phase-shifted by delta:

        cos(n_folds/4 * theta1) ** (2k) + cos(n_folds/4 * theta2 + delta) ** (2k)

    inputs:
    k: sharpness exponent
    theta1, theta2: angle(s) in radians (tensors, or anything torch.as_tensor accepts)
    delta: phase shift of the second term (defaults to the module-level delta)
    n_folds: symmetry order (defaults to the module-level n_folds)

    output:
    torch tensor broadcast from theta1/theta2
    '''
    theta1 = torch.as_tensor(theta1)
    theta2 = torch.as_tensor(theta2)
    # (cos^2)**k instead of exp(k*log(cos^2)): same values, but no NaN where cos == 0.
    lobe1 = torch.cos(n_folds / 4 * theta1) ** 2
    lobe2 = torch.cos(n_folds / 4 * theta2 + delta) ** 2
    return lobe1 ** k + lobe2 ** k

rectangular_filter_function(k, DATA_THETA, DATA_THETA)  # return value discarded; the in-function plotting is commented out

# Ring-filter defaults. circular_filter_function below takes r_0 as its default radius,
# but its `sd` default is its own literal 3, not this module-level `sd`.
r_0, sd = 20, 10
def circular_filter_function(r_0=r_0, sd=3, data_shape=(256, 256)):
    '''
    Generate a circularly symmetric ring filter of radius r_0 centred on the image.

    The value at a pixel a distance r from the centre is exp(-(r - r_0)**2 / sd).
    `sd` therefore plays the role of 2*sigma**2 of a Gaussian ring profile, not of sigma itself.

    inputs:
    r_0: ring radius in pixels (defaults to the module-level r_0 at definition time)
    sd: ring width parameter (see above)
    data_shape: (rows, cols) of the output

    output:
    torch tensor of shape data_shape
    '''
    rows = torch.arange(-data_shape[0] // 2, data_shape[0] // 2)
    cols = torch.arange(-data_shape[1] // 2, data_shape[1] // 2)
    # Explicit 'ij' (the historical default) silences torch's meshgrid deprecation warning.
    x, y = torch.meshgrid(rows, cols, indexing='ij')
    r = torch.sqrt(x**2 + y**2)
    return torch.exp(-(r - r_0)**2 / sd)

def normalize_min_max(data):
    '''
    Linearly rescale data to the range [0, 1].

    inputs:
    data: torch tensor, numpy array, or anything np.asarray accepts

    output:
    A (CPU) torch tensor if data is a tensor, otherwise a numpy array.
    Constant input (max == min) maps to all zeros rather than NaN.
    '''
    is_tensor = isinstance(data, torch.Tensor)
    # .cpu() so CUDA tensors can be converted as well
    array = data.detach().cpu().numpy() if is_tensor else np.asarray(data)
    array_min = np.min(array)
    array_max = np.max(array)
    span = array_max - array_min
    if span == 0:
        norm_array = np.zeros_like(array, dtype=float)
    else:
        norm_array = (array - array_min) / span
    return torch.tensor(norm_array) if is_tensor else norm_array

def mask_center(data, radius):
    '''
    Zero out the centre of a diffraction pattern.

    Returns a copy of data in which every pixel whose distance from the centre pixel
    (shape // 2 along each axis) is <= radius is set to 0. The input is not modified.
    '''
    n_rows, n_cols = data.shape[0], data.shape[1]
    rows, cols = np.ogrid[:n_rows, :n_cols]
    dist = np.sqrt((cols - n_cols // 2)**2 + (rows - n_rows // 2)**2)
    masked = np.copy(data)
    masked[dist <= radius] = 0
    return masked

In [None]:
# Single-pattern overview: radial profile, full-circle azimuthal profile, and filter overlay.
rx = 115
ry = 49
dp = mask_center(dc.data[ry][rx], 14)  # zero the centre (radius 14 px)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

angle_range = np.deg2rad(np.arange(0, 360, .5))
r_range = np.arange(1, 100, .5)

# RADIAL SUM: overlap of the pattern with a ring filter centred at each radius r
r_overlap_score = []
for r in r_range:
    overlap = (dp * circular_filter_function(r_0=r, data_shape=dp.shape).numpy()).sum()
    r_overlap_score.append(overlap)

r_overlap_score = np.array(r_overlap_score)
int_peaks, _ = signal.find_peaks(r_overlap_score, height=3000)  # height threshold hard-coded for this pattern
r_peaks = r_range[int_peaks]

for rs in r_peaks:
    axes[0].axvline(x=rs, color='purple', linestyle='--')
    axes[0].text(rs, 0, round(rs, 1), color='black', rotation=90)

axes[0].plot(r_range, r_overlap_score)
axes[0].set_xlabel('r (pixel)')
axes[0].grid(True)
axes[0].set_title('Radial Sum')

# AZIMUTHAL SUM: overlap with the lobe filter rotated by each angle
overlap_scores = []
for angle in angle_range:
    overlap = (dp * azimuthal_filter_function(k, DATA_THETA + angle).numpy()).sum()
    overlap_scores.append(overlap)

overlap_scores = np.array(overlap_scores)
# Peak-to-peak spread of the azimuthal profile (previously max + min, which is not a range).
y_range = round(overlap_scores.max() - overlap_scores.min())
peaks, _ = signal.find_peaks(overlap_scores, height=12000)  # height threshold hard-coded for this pattern
phi_peaks = angle_range[peaks]

for phi_rad in phi_peaks:
    x_value = np.rad2deg(phi_rad)
    axes[1].axvline(x=x_value, color='r', linestyle='--')
    axes[1].text(x_value, overlap_scores.min(), round(x_value, 1), color='black', rotation=90, verticalalignment='bottom')

axes[1].plot(np.rad2deg(angle_range), overlap_scores)
axes[1].set_title(f'Azimuthal Sum, Y-range={y_range}')
axes[1].set_xlabel('Azimuthal angle')
axes[1].grid(True)

# DIFFRACTION PATTERN: very sharp (k=10000) lobes mark the azimuthal peak directions
total_filter = np.zeros(dp.shape)
for phis in phi_peaks:
    total_filter += azimuthal_filter_function(10000, DATA_THETA + phis).numpy()

total_circle_filter = np.zeros(dp.shape)  # ring overlay currently disabled (loop below)
# for r in r_peaks:
#     total_circle_filter += circular_filter_function(r, 1).numpy()

axes[2].imshow(normalize_min_max(dp) + 1/2*total_filter + 1/2*total_circle_filter)
axes[2].set_title(f'({rx},{ry})')

plt.tight_layout()
plt.show()


In [None]:

# Scan position to analyse; the centre (radius 18 px) is zeroed before any filtering.
rx, ry = 115, 46
dp = mask_center(dc.data[ry][rx], 18)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

def radial_sum(rmin, rmax, dp, stepsize=.5, peak_height_threshold=3000):
    '''
    Radial intensity profile of dp, computed with circular_filter_function(); peaks are found and plotted.

    For each radius r in np.arange(rmin, rmax, stepsize), the pattern is multiplied by a ring
    filter of radius r and summed. Peaks of that profile are marked on the radial-sum panel.

    inputs:
    rmin (float): minimum radius of the grid search (pixels)
    rmax (float): maximum radius of the grid search (pixels, exclusive)
    dp (2D array): diffraction pattern
    stepsize (float): spacing of the radius grid (pixels)
    peak_height_threshold (float): radial peaks below this overlap score are ignored by the peak finder

    output:
    r_range = radii that were evaluated
    r_overlap_score = overlap score at each radius
    Also pretty plots!! Draws on the module-level `axes` (uses axes[0]).
    '''
    r_range = np.arange(rmin, rmax, stepsize)  # honour stepsize (was hard-coded to 0.5)
    r_overlap_score = np.array([
        (dp * circular_filter_function(r_0=r, data_shape=dp.shape).numpy()).sum()
        for r in r_range
    ])
    int_peaks, _ = signal.find_peaks(r_overlap_score, height=peak_height_threshold)
    r_peaks = r_range[int_peaks]  # candidate radii for plot_peaks_at_diff_radii()
    for rs in r_peaks:
        axes[0].axvline(x=rs, color='limegreen', linestyle='--')
        axes[0].text(rs, 0, round(rs, 1), rotation=90)
    axes[0].plot(r_range, r_overlap_score, color='blueviolet')
    axes[0].set_xlabel('r (pixel)')
    axes[0].grid(True)
    axes[0].set_title('Radial Sum')
    return r_range, r_overlap_score

r_range, r_overlap_score = radial_sum(rmin=1, rmax=100, dp=dp)  # radial profile over 1-100 px; also drawn on axes[0]


def plot_peaks_at_diff_radii(radius, dp, peak_height_threshold):
    '''
    Radially-masked azimuthal sum of dp at a single radius, with the peaks marked.

    The pattern is multiplied by a ring filter (circular_filter_function) at `radius`.
    azimuthal_filter_function is then rotated over 0-179 degrees in 1-degree steps,
    and the overlap score is recorded at each angle.

    inputs:
    radius: ring radius in pixels (a value found by radial_sum(), or pick your own)
    dp: 2D diffraction pattern
    peak_height_threshold: azimuthal peaks below this threshold value will not be found in
        the peak-finding algorithm; None uses 2/3 of the maximum overlap score

    output:
    phi_peaks: azimuthal peak locations (radians) of the radially masked diffraction pattern
    Also pretty plots!! Draws on the module-level `axes` (three panels) and reads the
    module-level r_overlap_score, k, DATA_THETA, rx and ry.
    '''
    angle_range = np.deg2rad(np.arange(0, 180, 1))
    axes[0].axvline(radius, color='g', linestyle='--')
    axes[0].text(radius, max(r_overlap_score)/2, radius, rotation=90, verticalalignment='bottom')

    # AZIMUTHAL SUM -- the ring mask does not depend on the angle, so apply it once
    ring_dp = dp * circular_filter_function(radius, data_shape=dp.shape).numpy()
    overlap_scores = np.array([
        (ring_dp * azimuthal_filter_function(k, DATA_THETA + angle).numpy()).sum()
        for angle in angle_range
    ])
    if peak_height_threshold is None:
        peak_height_threshold = 2/3 * overlap_scores.max()

    peaks, _ = signal.find_peaks(overlap_scores, height=peak_height_threshold)
    phi_peaks = angle_range[peaks]
    for phi_rad in phi_peaks:
        x_value = np.rad2deg(phi_rad)
        axes[1].axvline(x=x_value, color='dodgerblue', linestyle='--')
        axes[1].text(x_value, overlap_scores.min(), round(x_value, 1), color='black', rotation=90, verticalalignment='bottom')
    axes[1].plot(np.rad2deg(angle_range), overlap_scores, color='crimson')
    axes[1].set_title(f'Radially-Masked Azimuthal Sum, r={radius}')
    axes[1].set_xlabel('Azimuthal angle')
    axes[1].grid(True)
    print(peak_height_threshold)

    # DIFFRACTION PATTERN
    azimuthal_filter = np.zeros(dp.shape)
    for phis in phi_peaks:
        azimuthal_filter += azimuthal_filter_function(10000, DATA_THETA + phis).numpy()  # shows azimuthal location of peaks
    circular_mask = circular_filter_function(radius, 1, data_shape=dp.shape).numpy()  # thin ring for visualization
    axes[2].imshow(normalize_min_max(dp) + 1/8*azimuthal_filter + 1/8*circular_mask)
    axes[2].set_title(f'({rx},{ry})')

    plt.tight_layout()
    plt.show()
    return phi_peaks

r = 28
# Scan a row of positions (kept for reference):
# rx = 115
# ry = 46
# for x in range(106, 125):
#     dp = mask_center(dc.data[ry][x], 18)
#     array = plot_peaks_at_diff_radii(r_peaks[1], dp, None)
# print(np.rad2deg(array))

# Azimuthal peak angles (radians) at r = 42 px. Keep them: the symmetry cell below reads `array`.
array = plot_peaks_at_diff_radii(42, dp, None)
array

In [None]:

def interactive_plot(rx=115, ry=46, peak_height_threshold=3000):
    '''
    Widget callback: radial sum, radially-masked azimuthal sum, and annotated pattern for scan position (rx, ry).

    The azimuthal search runs on the second radial peak (r_peaks[1]), as in the manual
    workflow above. peak_height_threshold is used by both the radial and azimuthal peak finders.
    If fewer than two radial peaks pass the threshold, only the radial panel is drawn.
    '''
    dp = mask_center(dc.data[ry][rx], 18)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # RADIAL SUM:
    r_range = np.arange(1, 100, 0.5)
    r_overlap_score = np.array([
        (dp * circular_filter_function(r_0=r, data_shape=dp.shape).numpy()).sum()
        for r in r_range
    ])
    int_peaks, _ = signal.find_peaks(r_overlap_score, height=peak_height_threshold)
    r_peaks = r_range[int_peaks]

    for rs in r_peaks:
        axes[0].axvline(x=rs, color='limegreen', linestyle='--')
        axes[0].text(rs, 0, round(rs, 1), rotation=90)
    axes[0].plot(r_range, r_overlap_score, color='blueviolet')
    axes[0].set_xlabel('r (pixel)')
    axes[0].grid(True)
    axes[0].set_title('Radial Sum')

    # Guard: r_peaks[1] below would raise IndexError for low-signal positions / high thresholds.
    if len(r_peaks) < 2:
        axes[1].set_title(f'Need at least 2 radial peaks above {peak_height_threshold} (found {len(r_peaks)})')
        plt.tight_layout()
        plt.show()
        return

    # RADIALLY-MASKED AZIMUTHAL SUM:
    radius = r_peaks[1]
    angle_range = np.deg2rad(np.arange(0, 180, 1))
    axes[0].axvline(radius, color='g', linestyle='--')
    axes[0].text(radius, r_overlap_score.max() / 2, radius, rotation=90, verticalalignment='bottom')

    # The ring mask is angle-independent: build it once. The azimuthal filter must be evaluated
    # on the full DATA_THETA grid (previously a scalar tensor, which applied no angular filtering).
    ring_dp = dp * circular_filter_function(radius, data_shape=dp.shape).numpy()
    overlap_scores = np.array([
        (ring_dp * azimuthal_filter_function(k, DATA_THETA + angle).numpy()).sum()
        for angle in angle_range
    ])

    peaks, _ = signal.find_peaks(overlap_scores, height=peak_height_threshold)
    phi_peaks = angle_range[peaks]

    for phi_rad in phi_peaks:
        x_value = np.rad2deg(phi_rad)
        axes[1].axvline(x=x_value, color='dodgerblue', linestyle='--')
        axes[1].text(x_value, overlap_scores.min(), round(x_value, 1), color='black', rotation=90, verticalalignment='bottom')
    axes[1].plot(np.rad2deg(angle_range), overlap_scores, color='crimson')
    axes[1].set_title(f'Radially-Masked Azimuthal Sum, r={radius}')
    axes[1].set_xlabel('Azimuthal angle')
    axes[1].grid(True)

    # DIFFRACTION PATTERN: sharp lobes at the peak angles plus the ring used above
    azimuthal_filter = np.zeros(dp.shape)
    for phis in phi_peaks:
        # DATA_THETA + phis keeps theta a tensor grid; passing the bare numpy float crashed torch.cos
        azimuthal_filter += azimuthal_filter_function(10000, DATA_THETA + phis).numpy()
    circular_mask = circular_filter_function(radius, data_shape=dp.shape)
    axes[2].imshow(normalize_min_max(dp) + 1 / 8 * azimuthal_filter + 1 / 8 * circular_mask.numpy())
    axes[2].set_title(f'({rx},{ry})')

    plt.tight_layout()
    plt.show()

# Slider bounds follow the scan grid: patterns are indexed as dc.data[ry][rx], so axis 0 is ry
# and axis 1 is rx. The threshold step of 100 keeps both the default (3000) and the max (5000)
# on the slider grid; with min=100 and step=500 neither was reachable.
interact(
    interactive_plot,
    rx=IntSlider(min=0, max=dc.data.shape[1] - 1, step=1, value=115),
    ry=IntSlider(min=0, max=dc.data.shape[0] - 1, step=1, value=46),
    peak_height_threshold=IntSlider(min=100, max=5000, step=100, value=3000)
)

In [None]:
def two_fold_symmetry(angles, tolerance=5):
    '''
    Finds the pairs of angles related by two-fold (180 degree) rotation, within a tolerance.

    The separation of each pair is measured on the circle, so the angles need not lie in
    [0, 2*pi). For example, values from atan2 in (-pi, pi], or unwrapped angles, work too.
    For inputs in [0, 2*pi) the result matches the plain |a - b| ~ pi test.

    Parameters:
    angles (list or array): azimuthal angles in radians.
    tolerance (float): allowed deviation (in degrees) from an exact 180 degree separation.

    Returns:
    list: tuples (angles[i], angles[j]) with i < j, in index order, whose circular separation
    is within `tolerance` of pi.
    '''
    tol = np.deg2rad(tolerance)
    two_pi = 2 * np.pi
    two_fold_pairs = []
    n = len(angles)
    for i in range(n):
        for j in range(i + 1, n):
            # wrap the difference into [-pi, pi); its magnitude is the circular separation in [0, pi]
            separation = abs((angles[i] - angles[j] + np.pi) % two_pi - np.pi)
            if separation >= np.pi - tol:
                two_fold_pairs.append((angles[i], angles[j]))
    return two_fold_pairs
# `array` holds azimuthal peak angles (radians) returned by plot_peaks_at_diff_radii; it must already exist in the kernel.
print(two_fold_symmetry(angles=array))

def find_mirror_axes(angles, tolerance=5):
    """
    Finds the mirror symmetry axes (phi) for a given set of azimuthal angles.

    A candidate axis phi is accepted when reflecting every angle about it
    (theta -> 2*phi - theta) maps the set onto itself. Two conditions must hold:
    every reflected angle lies within `tolerance` of some original angle, and
    every original angle lies within `tolerance` of some reflected angle.
    Distances are measured on the circle, so matches across the 0 / 2*pi seam are found.

    Candidates are the bisector of every pair of angles plus each angle itself,
    since an axis may pass straight through a peak.

    Parameters:
    angles (list or array): azimuthal angles in radians; wrapped into [0, 2*pi).
    tolerance (float): Allowed angular difference (in degrees) for symmetry matching.

    Returns:
    numpy.ndarray: sorted unique phi values (radians, in [0, 2*pi)) of the mirror axes.
    phi and phi + pi describe the same line and may both appear.
    """
    tol = np.deg2rad(tolerance)
    two_pi = 2 * np.pi
    wrapped = np.sort(np.mod(np.asarray(angles, dtype=float), two_pi))

    def _is_symmetric(reflected):
        # circular distance between every reflected angle (rows) and every original angle (cols)
        diff = np.abs(reflected[:, None] - wrapped[None, :]) % two_pi
        dist = np.minimum(diff, two_pi - diff)
        return bool(np.all(dist.min(axis=1) <= tol) and np.all(dist.min(axis=0) <= tol))

    candidates = list(wrapped)  # axes passing through a peak
    for i in range(len(wrapped)):
        for j in range(i + 1, len(wrapped)):
            candidates.append((wrapped[i] + wrapped[j]) / 2)  # bisector of a pair

    mirror_axes = [phi for phi in candidates
                   if _is_symmetric(np.mod(2 * phi - wrapped, two_pi))]
    return np.unique(mirror_axes)

