In [None]:
from filter_functions import filter_function, find_k_value, circular_filter_function, normalize_min_max, mask_images, azimuthal_sum_w_filter, data_theta
import numpy as np
import matplotlib.pyplot as plt

In [None]:
numor = 69958
# Root of the local project checkout; change this one constant to run the notebook elsewhere.
BASE_DIR = '/Users/cadenmyers/billingelab/dev/sym_adapted_filts'
path_to_npz = f'{BASE_DIR}/experimental_data/npz_sept_numor_data/{numor}.npz'

# np.load on an .npz archive returns an NpzFile that keeps the file handle open;
# the context manager closes it once the array has been read into memory.
with np.load(path_to_npz) as npz:
    data = npz['data']
intensity_image = data[26]  # assumes axis 0 indexes frames — TODO confirm

with np.load(f'{BASE_DIR}/4DSTEM/data_and_figs/subset_dps_masked.npz') as npz:
    stemdps = npz['data']

plt.imshow(stemdps[1])
plt.title('stemdps[1]')
plt.show()

In [None]:
def find_symmetries(dp, experimental_resolution, threshold=1000, n_folds=1, normalize=False):
    """
    Compute the rotational-symmetry spectrum of a diffraction pattern.

    The pattern is azimuthally summed with ``azimuthal_sum_w_filter`` using a filter
    parameter ``k`` obtained from ``find_k_value(experimental_resolution, n_folds)``.
    The normalized azimuthal score is then Fourier transformed. Bin ``m`` of the
    spectrum measures how strongly the score repeats ``m`` times over the sampled
    azimuthal range, so a peak at ``m`` suggests m-fold symmetry (assuming the score
    covers the full 360 degrees — depends on ``azimuthal_sum_w_filter``).

    Parameters
    ----------
    dp : array_like
        Diffraction pattern, passed unchanged to ``azimuthal_sum_w_filter``.
    experimental_resolution : float
        Forwarded to ``find_k_value`` to choose ``k``.
    threshold : float, default 1000
        Patterns whose maximum raw azimuthal score is below this value are treated
        as having no intensity peaks and are skipped.
    n_folds : int, default 1
        Forwarded to ``find_k_value`` and ``azimuthal_sum_w_filter``.
    normalize : bool, default False
        If True, min-max scale the returned magnitudes to [0, 1]; a flat spectrum
        becomes all zeros. The default returns the raw magnitudes, as before.

    Returns
    -------
    tuple of (ndarray, ndarray) or None
        ``(frequencies, magnitudes)``. ``frequencies`` are the symmetry orders
        1, 2, ..., len(score_norm)//2 - 1 and ``magnitudes`` the matching |FFT|
        values; the DC term (order 0) is dropped. ``None`` is returned (after
        printing a message) when the pattern fails the ``threshold`` check, so
        callers must check for it before unpacking.
    """
    k = find_k_value(experimental_resolution, n_folds)
    _phis, score, score_norm = azimuthal_sum_w_filter(dp, k=k, n_folds=n_folds)
    if np.max(score) < threshold:  # filter out images without intensity peak(s)
        print(f'no intensity peak(s) detected with threshold={threshold}')
        return None

    n_samples = len(score_norm)
    half = n_samples // 2
    # Keep the non-negative-frequency half of the spectrum. fftfreq(n, d=1) is in
    # cycles per sample; multiplying by n gives cycles per full scan, i.e. the
    # integer symmetry order m.
    magnitudes = np.abs(np.fft.fft(score_norm))[:half]
    frequencies = np.fft.fftfreq(n_samples, d=1)[:half] * n_samples
    # Drop only the DC term (order 0); it reflects the mean of score_norm, not symmetry.
    magnitudes = magnitudes[1:]
    frequencies = frequencies[1:]

    if normalize and magnitudes.size:
        span = magnitudes.max() - magnitudes.min()
        if span > 0:
            magnitudes = (magnitudes - magnitudes.min()) / span
        else:
            # Flat spectrum: avoid a 0/0 division and report "no preferred order".
            magnitudes = np.zeros_like(magnitudes)
    return frequencies, magnitudes



In [None]:
# For every stored pattern that passes the intensity threshold, show its symmetry
# spectrum next to the pattern itself — one figure per pattern.
experimental_resolution = 3  # forwarded to find_symmetries
for i, dp in enumerate(stemdps):
    result = find_symmetries(dp, experimental_resolution)
    if result is None:  # pattern failed the intensity threshold; nothing to plot
        continue
    freq, fft_mag = result
    # Create a fresh figure per pattern: with the inline backend plt.show() closes
    # the figure, so reusing one figure rendered nothing after the first pattern
    # and stacked every spectrum on the same axes.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].plot(freq, fft_mag)
    ax[0].set_xlabel('n_fold symmetry')
    ax[0].set_ylabel('|FFT| of azimuthal score')
    ax[0].set_xlim(0, 12)
    ax[0].set_xticks(np.arange(0, 13, 1))  # include the tick at the xlim edge (12)
    ax[0].grid(True)

    ax[1].imshow(dp)
    ax[1].set_title(f'stemdps[{i}]')
    plt.show()


## Test on 4DSTEM data

In [None]:
from py4DSTEM.io import read

# HDF5 4D-STEM acquisition (centralized and masked, per its filename).
h5_path = '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/4DSTEM/data_and_figs/0020 - original-centralized-masked.h5'
datacube = read(h5_path)


In [None]:
from scipy.signal import find_peaks

dps = datacube.data
# Scan part of one row of the 4D-STEM grid and record, for each probe position, the
# lowest symmetry order whose spectral peak clears `peak_height` (NaN if none).
y = 50                       # scan row (assumes dps is indexed [y][x] — TODO confirm)
x_positions = range(25, 35)  # scan columns to analyse
experimental_resolution = 3  # forwarded to find_symmetries
n_orders = 30                # only the first 30 spectrum bins are searched for peaks
peak_height = 0.4            # absolute height threshold passed to find_peaks
# NOTE(review): find_symmetries returns raw (un-normalized) |FFT| magnitudes by default,
# so this absolute threshold depends on the scale of each spectrum — confirm 0.4 is intended.

xs = []
syms = []
for x in x_positions:
    dp = dps[y][x]
    xs.append(x)
    result = find_symmetries(dp, experimental_resolution)
    if result is None:
        # Below the intensity threshold: record "no symmetry" instead of crashing
        # when tuple-unpacking None.
        syms.append(np.nan)
        continue
    freq, mag = result
    limited_freq = freq[:n_orders]
    limited_mag = mag[:n_orders]

    plt.plot(freq, mag)
    plt.xlabel('n_fold symmetry')
    plt.ylabel('|FFT| of azimuthal score')
    plt.title(f'Symmetry spectrum at (y={y}, x={x})')
    plt.grid(True)
    # NOTE(review): this window (orders 100-150) lies outside the first n_orders bins
    # used for peak detection below — confirm this is the intended view.
    plt.xlim(100, 150)
    plt.show()

    peaks, _ = find_peaks(limited_mag, height=peak_height)
    folds = limited_freq[peaks]
    syms.append(folds[0] if len(folds) else np.nan)


In [None]:
# Detected symmetry order vs. probe position along the scanned row.
# NaN entries (no peak found, or pattern below threshold) are skipped by scatter.
fig, ax = plt.subplots()
ax.scatter(xs, syms)
ax.set_xlabel('x coord of 4DSTEM')
ax.set_ylabel('detected symmetry, $n_{folds}$')
ax.set_title('Detected rotational symmetry along the scanned row')
ax.grid(True)
plt.show()

In [None]:
# Inspect a single pattern from the scanned row (probe position y=50, x=30).
plt.imshow(dps[50][30])
plt.title('dps[50][30]')
plt.show()