In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

## Get base vectors using a 3D FFT

In [None]:
from nanomesh.volume import Volume
import pyvista as pv
from skimage import filters
import numpy as np

In [None]:
# Load the sample volume; the path is relative to the notebook's working directory.
vol = Volume.load('sample_data.npy')

In [None]:
# Magnitude spectrum of the volume with the zero-frequency term moved to the
# centre. abs() and fftshift() commute (the shift only reorders elements).
fourier = np.fft.fftn(vol.image)
scaled = np.fft.fftshift(np.abs(fourier))

# Cap values to [0, 1e7] for display only (presumably so a few very bright
# frequencies do not wash out the slice); `scaled` itself is left unclipped.
fs = Volume(scaled.clip(0, 1e7))
# NOTE(review): slice index 101 is hard-coded — confirm it matches the volume shape.
fs.show_slice(index=101, along='y')

In [None]:
import scipy.ndimage as ndi
from skimage import measure
from sklearn.cluster import DBSCAN, MeanShift

def find_peaks(image, *, threshold: float, min_sigma: float = 1.0, max_sigma: float = 2.0):
    """Find peaks in an image using a difference of Gaussians (DoG).

    The image is blurred at two scales and the wide blur is subtracted from
    the narrow one; every connected region where this difference exceeds
    ``threshold`` is reported as one peak.

    Parameters
    ----------
    image : array_like
        N-dimensional image. Converted to float64 before filtering.
    threshold : float
        Minimum DoG response for a voxel to belong to a peak region.
    min_sigma : float, optional
        Standard deviation of the narrow Gaussian.
    max_sigma : float, optional
        Standard deviation of the wide Gaussian.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_peaks, image.ndim)``: the unweighted centroid of each
        region in index coordinates. ``n_peaks`` may be 0.
    """
    # gaussian_filter returns the input dtype; on integer images that would
    # truncate the blurs and let the subtraction below wrap around.
    image = np.asarray(image, dtype=float)

    difference = ndi.gaussian_filter(image, min_sigma) - ndi.gaussian_filter(image, max_sigma)
    labels, num_labels = ndi.label(difference > threshold)

    if num_labels == 0:
        # Keep a 2-D shape so downstream code can still index by column.
        return np.empty((0, image.ndim))

    # Unweighted centroid of each labelled region (labels run 1..num_labels).
    # NOTE(review): pass `image` instead of the ones-array to weight by intensity,
    # if intensity-weighted peak positions were intended.
    centroids = ndi.center_of_mass(np.ones(labels.shape), labels, np.arange(1, num_labels + 1))
    return np.array(centroids, dtype=float)


def find_periodic(peaks, *, eps: float = 1.0, min_samples: int = 5) -> dict:
    """Find periodic vectors in a set of peak positions.

    The peaks are Delaunay-triangulated and, for every simplex, the vectors
    from its first vertex to each of its other vertices are collected. These
    difference vectors are clustered with DBSCAN, and each cluster of core
    samples is summarised as one candidate periodic vector.

    TODO: uniqify vectors (i.e. select one quadrant)

    Parameters
    ----------
    peaks : array_like
        Peak coordinates, shape ``(n_peaks, ndim)``. Delaunay triangulation
        needs at least ``ndim + 1`` points that do not all lie in one plane.
    eps : float, optional
        DBSCAN neighbourhood radius, in the same units as ``peaks``.
    min_samples : int, optional
        DBSCAN minimum neighbourhood size for a core sample.

    Returns
    -------
    dict
        Maps each DBSCAN cluster label to a dict with keys ``'count'``
        (number of core samples), ``'label'``, ``'vector'`` (mean vector),
        ``'std'`` (per-axis standard deviation), ``'confidence'`` (mean of
        ``'std'``; smaller means a tighter cluster) and ``'samples'`` (the
        core-sample vectors). Empty if DBSCAN finds no core samples.
    """
    from scipy.spatial import Delaunay

    points = np.asarray(peaks, dtype=float)
    triangulation = Delaunay(points)

    # Coordinates of every simplex corner: (n_simplices, ndim + 1, ndim).
    corners = triangulation.points[triangulation.simplices]
    # Edge vectors from each simplex's first vertex to its other vertices.
    # Vertex 0 is skipped: its difference with itself is the zero vector,
    # which would otherwise always form a spurious all-zero cluster.
    diff = (corners[:, 1:, :] - corners[:, :1, :]).reshape(-1, points.shape[1])

    db = DBSCAN(eps=eps, min_samples=min_samples).fit(diff)
    core_indices = db.core_sample_indices_
    # DBSCAN always assigns core samples to a cluster, so no -1 (noise) here.
    labels = db.labels_[core_indices]
    core_vectors = diff[core_indices]

    unique, counts = np.unique(labels, return_counts=True)

    clusters = {}
    for label, count in zip(unique, counts):
        vectors = core_vectors[labels == label]
        std = vectors.std(axis=0)
        clusters[label] = {
            'count': count,
            'label': label,
            'vector': vectors.mean(axis=0),
            'std': std,
            'confidence': std.mean(),
            'samples': vectors,
        }

    return clusters

In [None]:
# Locate bright spots in the centred magnitude spectrum, then cluster the
# vectors between Delaunay-neighbouring spots into candidate periodic vectors.
peaks = find_peaks(image=scaled, threshold=1e6)
clusters = find_periodic(peaks=peaks)

In [None]:
# All core-sample vectors of every cluster stacked into one array, plus one
# mean vector per cluster (vstack raises if `clusters` is empty).
samples = np.vstack([cluster['samples'] for cluster in clusters.values()])
centers = np.stack([cluster['vector'] for cluster in clusters.values()])

# Show the cluster means together with the individual difference vectors.
plotter = pv.PlotterITK()
plotter.add_points(centers)
plotter.add_points(samples)
plotter.show()