In [1]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.append('D:\\Dropbox\\stempy')
sys.path.append('D:\\Dropbox\\stemplot')
    
from stemplot import *

from stempy.io import *
from stempy.clustering import *
from stempy.plot import *
from stempy.denoise import *
from stempy.utils import *
from stempy.feature import *
from stempy.spatial import *
from stempy.manifold import *

In [2]:
# Pre-stacked image arrays for the two ADF channels. The .npy layout is not
# visible here — presumably one 2-D image per leading index (TODO confirm).
imgs1, imgs2 = np.load('all_adf1.npy'), np.load('all_adf2.npy')

In [4]:
from skimage.transform import warp_polar
from scipy.ndimage import uniform_filter1d

def baseline_correction(y, niter=10):
    """Estimate a smooth background under a 1-D signal by iterative peak clipping.

    The signal is first compressed with the operator
    ``v = log(log(sqrt(y + 1) + 1) + 1)``. Then, for each half-window
    ``p = 1 .. niter``, every interior sample ``v[i]`` is replaced by
    ``min(v[i], (v[i - p] + v[i + p]) / 2)``. This clips peaks narrower than
    the window and keeps the broad background. The result is mapped back with
    the exact inverse of the compression.

    Parameters
    ----------
    y : array_like
        1-D signal. Values must be >= -1 so that ``sqrt(y + 1)`` is real.
    niter : int, default 10
        Largest half-window, i.e. the number of clipping passes.

    Returns
    -------
    numpy.ndarray
        Baseline with the same length as ``y``. The first and last ``niter``
        samples are never clipped, so there the baseline equals ``y`` up to
        round-off. ``y`` itself is not modified.
    """
    y = np.asarray(y)
    n = len(y)
    # The forward transform allocates a new array, so the in-place clipping
    # below never touches the caller's data.
    v = np.log(np.log(np.sqrt(y + 1) + 1) + 1)

    for p in range(1, niter + 1):
        if n <= 2 * p:
            # No interior sample has two neighbours at distance p; wider
            # windows are also impossible.
            break
        # v[2p:] and v[:n-2p] are the right/left neighbours at distance p of
        # v[p:n-p]. Slicing avoids the two full copies np.roll would make.
        neighbour_mean = (v[2 * p:] + v[:n - 2 * p]) / 2
        v[p:n - p] = np.minimum(v[p:n - p], neighbour_mean)

    # Inverse of the forward transform.
    return (np.exp(np.exp(v) - 1) - 1) ** 2 - 1

def get_fft_line(img, niter=10, size=9, use_log=True, debug=True):
    """Background-subtracted, smoothed radial profile of an image's FFT magnitude.

    Parameters
    ----------
    img : 2-D array
        Input image.
    niter : int, default 10
        Number of clipping passes handed to ``baseline_correction``.
    size : int, default 9
        Window length of the final ``uniform_filter1d`` moving average.
    use_log : bool, default True
        Work on ``log(|FFT| + 1)`` instead of ``|FFT|``.
    debug : bool, default True
        Accepted for compatibility but not used by this function.

    Returns
    -------
    numpy.ndarray
        1-D profile indexed by radius (in spectrum pixels) from the polar
        centre. It is truncated to ``row`` samples, where ``row`` is the row
        index of that centre.
    """
    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(img)))
    if use_log:
        spectrum = np.log(spectrum + 1)

    # Polar centre = brightest spectrum pixel. For a non-negative image this is
    # the zero-frequency term, which fftshift places in the middle.
    row, col = np.unravel_index(np.argmax(spectrum), shape=img.shape)

    # Average the polar warp over axis 0 (the angle axis of warp_polar's
    # output) to obtain intensity versus radius.
    radial = warp_polar(spectrum, center=(row, col)).mean(axis=0)[:row]

    peaks = radial - baseline_correction(radial, niter=niter)
    return uniform_filter1d(peaks, size=size)

class TMDImage:
    """Derive a patch size from an image's FFT, then background-clean and denoise it.

    Attributes set in ``__init__``:

    * ``img``: the input image, stored as given.
    * ``fft_line``: radial FFT profile from ``get_fft_line(img, debug=False)``.
    * ``size``: ``round(img.shape[0] / k)``, where ``k`` is the index of the
      maximum of ``fft_line``. This is the real-space period in pixels for
      that spatial frequency, assuming a square image.
    * ``patch_size``: ``2 * size + 1`` (always odd).
    * ``img_clean``: ``remove_bg(img, size)``.
    * ``imgf``: ``denoise_svd(img_clean, patch_size, patch_size, 2)``. The
      positional arguments are presumably n_components, patch size and
      extraction step, as named at the call site.

    Raises
    ------
    ValueError
        If the profile's maximum is at radius 0. No period can be derived
        then; previously this divided by zero and produced a garbage size.
    """

    def __init__(self, img):
        self.img = img
        self.fft_line = get_fft_line(self.img, debug=False)
        ind = np.argmax(self.fft_line)
        if ind == 0:
            raise ValueError(
                "TMDImage: the FFT profile peaks at radius 0; "
                "cannot derive a patch size from it."
            )
        self.size = np.round(self.img.shape[0] / ind).astype(int)
        self.patch_size = 2 * self.size + 1

        self.img_clean = remove_bg(self.img, self.size)

        n_components = self.patch_size
        extraction_step = 2
        self.imgf = denoise_svd(self.img_clean, n_components, self.patch_size, extraction_step)
        

In [5]:
# Run the automatic patch-size / clean / denoise pipeline on the image at
# index 1 of imgs2 (loaded from all_adf2.npy).
aa = TMDImage(img=imgs2[1])

Extracting reference patches...
done in 0.10s.
Singular value decomposition...
done in 1.08s.
Reconstructing...
done in 0.69s.


In [37]:
from tqdm import tqdm
# Default location of the ADF1_<num>.dm4 / ADF2_<num>.dm4 files.
_DEFAULT_DATA_DIR = r'D:\Dropbox\data\Leyi\Dopant Valency Project\CrWSe2_total 8 datasets\CrWSe2_20211028'


def _load_clean_image(file_name, bg_size=7):
    """Load one image, scale it to [0, 1], remove the background, rescale to [0, 1]."""
    img = normalize_image(load_image(file_name), 0, 1)
    img = remove_bg(img, bg_size)
    return normalize_image(img, 0, 1)


def _zernike_moments(patches, n_max=10):
    """Fit a fresh ``ZPs`` model to ``patches`` and return its moments."""
    # A separate instance per call means the returned moments can never be
    # overwritten by a later fit(). ZPs internals are not visible here, so
    # reusing one instance was only safe if fit() rebinds ``moments``.
    zps = ZPs(n_max=n_max, size=patches.shape[1])
    zps.fit(patches)
    return zps.moments


def get_feature_X(num=6, threshold=0.3, min_distance=1, data_dir=None):
    """Zernike-moment features at the local maxima of one ADF1/ADF2 image pair.

    Parameters
    ----------
    num : int, default 6
        Dataset index; reads ``ADF1_<num>.dm4`` and ``ADF2_<num>.dm4``.
    threshold, min_distance :
        Passed to ``local_max`` for keypoint detection.
    data_dir : str, optional
        Directory holding the .dm4 files. Defaults to the original
        hard-coded Windows path.

    Returns
    -------
    ``X1.hstack(X2)``: the ADF1 moments followed by the ADF2 moments. Both are
    computed at the same keypoints, presumably one row per extracted patch.
    """
    import os

    if data_dir is None:
        data_dir = _DEFAULT_DATA_DIR

    img1 = _load_clean_image(os.path.join(data_dir, 'ADF1_{}.dm4'.format(num)))
    img2 = _load_clean_image(os.path.join(data_dir, 'ADF2_{}.dm4'.format(num)))

    # Keypoints are detected once, on the SVD-denoised ADF2 image, and reused
    # for both channels so the two feature blocks describe the same points.
    n_components = 32
    patch_size = 32
    extraction_step = 2
    imgf = denoise_svd(img2, n_components, patch_size, extraction_step, verbose=False)
    pts = local_max(imgf, min_distance=min_distance, threshold=threshold, plot=False)

    # Odd patch side length derived from img1.
    size = 2 * get_patch_size(img1, debug=False) + 1
    ps1 = KeyPoints(pts, img1, size).extract_patches(size)
    ps2 = KeyPoints(pts, img2, size).extract_patches(size)

    X1 = _zernike_moments(ps1)
    X2 = _zernike_moments(ps2)
    return X1.hstack(X2)

def vstack_zmarrays(a):
    """Stack moment arrays row-wise by chaining their ``vstack`` method.

    Parameters
    ----------
    a : iterable
        Non-empty iterable of objects providing ``vstack(other)``. Any
        iterable works, including generators.

    Returns
    -------
    ``a[0].vstack(a[1]).vstack(a[2])...``. A single element is returned
    unchanged.

    Raises
    ------
    ValueError
        If ``a`` is empty.
    """
    items = iter(a)
    try:
        stacked = next(items)
    except StopIteration:
        raise ValueError("vstack_zmarrays() needs at least one array") from None
    for item in items:
        stacked = stacked.vstack(item)
    return stacked

In [38]:
# Per-dataset arguments for get_feature_X: file index, local-max threshold and
# minimum peak distance. The lists are zipped together position by position;
# zip would silently drop entries on a length mismatch, so check it here.
nums = [6, 7, 9, 13, 14, 15, 17]
thresholds = [0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1]
min_dists = [1, 1, 3, 1, 1, 1, 1]
assert len(nums) == len(thresholds) == len(min_dists), "per-dataset settings must have equal lengths"

In [39]:
# Compute the features of every dataset. Keep them in `ll` (in `nums` order)
# and cache each one as X<k>.npy, where k is its position in `nums`.
ll = []
lbs_ = []  # NOTE(review): never filled; the per-dataset labelling that used it was abandoned.
for ii, (num, threshold, min_distance) in enumerate(
        tqdm(zip(nums, thresholds, min_dists), total=len(nums))):
    X = get_feature_X(num, threshold, min_distance)
    ll.append(X)
    np.save('X{}.npy'.format(ii), X)

100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:25<00:00,  3.65s/it]


In [13]:
# One feature array holding the rows of all datasets, in `nums` order.
X = vstack_zmarrays(a=ll)

In [14]:
# Split the stacked features back into the two channels. get_feature_X puts
# the ADF1 moments before the ADF2 moments. 66 is presumably the number of
# moments per channel for ZPs(n_max=10), since (10+1)*(10+2)/2 = 66.
# TODO confirm against ZPs.
X1, X2 = X[:, :66], X[:, 66:]

In [21]:
# Gaussian-mixture labels for all stacked features (gmm_lbs is project code;
# the `2` is presumably the number of components).
# NOTE(review): no random seed is passed, so labels may change between runs;
# confirm whether gmm_lbs seeds internally.
lbs = gmm_lbs(X, 2, n_init=1)

In [32]:
# PCA scatter of the stacked features coloured by the GMM labels (plot_pca is
# project code; the `2` is presumably the number of PCA components).
plot_pca(X, 2, lbs=lbs)

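## Create training data from X1 or X2
* Split the features into two classes (the GMM labels `lbs`).
* In each class, lasso clusters of points and save them as labelled training data.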

In [45]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path

def pca(X, n_components=2):
    """Fit a PCA model to ``X`` and return the projected samples.

    Thin wrapper around ``sklearn.decomposition.PCA.fit_transform``. A new
    model is fitted on every call and discarded afterwards.
    """
    return PCA(n_components=n_components).fit_transform(X)

class InteractiveAnnotation:
    """Manually label clusters in two groups of samples with a lasso tool.

    ``X`` is split by ``lbs`` into group 0, drawn on ``fig.axes[0]``, and
    group 1, drawn on ``fig.axes[1]``. Each group is embedded in 2-D with
    ``pca(group.rotinv().select(0))`` and shown as a scatter plot.

    Interaction:

    * Draw a lasso in either panel to select points. The latest non-empty
      selection becomes the active one.
    * ``enter`` stores the active selection as a new training cluster for
      that panel. Cluster ids start at 0 and count up independently per
      panel. The selection is then deactivated, so pressing ``enter`` again
      does not store the same points twice.
    * ``shift`` writes ``X1_train.npy``/``y1_train.npy`` and
      ``X2_train.npy``/``y2_train.npy`` to the working directory. Nothing is
      written unless both panels have at least one stored cluster.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure with at least two axes.
    X :
        Feature array supporting boolean-mask indexing, ``rotinv()`` and
        ``select()``. This is a project feature type, assumed from usage.
    lbs : numpy.ndarray
        Per-sample labels; only samples labelled 0 or 1 are shown.
    **kwargs :
        Forwarded to both ``Axes.scatter`` calls.
    """

    def __init__(self, fig, X, lbs, **kwargs):
        self.fig = fig
        self.ax1 = fig.axes[0]
        self.ax2 = fig.axes[1]

        self.X = X
        self.lbs = lbs
        # Split the samples (and their labels) into the two classes.
        self.X1 = self.X[self.lbs == 0]
        self.X2 = self.X[self.lbs == 1]
        self.lbs1 = self.lbs[self.lbs == 0]
        self.lbs2 = self.lbs[self.lbs == 1]

        # 2-D PCA coordinates, used both for plotting and for lasso hit-testing.
        self.xy1 = pca(self.X1.rotinv().select(0))
        self.xy2 = pca(self.X2.rotinv().select(0))

        self.colors1 = colors_from_lbs(self.lbs1)
        self.colors2 = colors_from_lbs(self.lbs2)

        self.path_collection1 = self.ax1.scatter(self.xy1[:, 0], self.xy1[:, 1], c=self.colors1, **kwargs)
        self.path_collection2 = self.ax2.scatter(self.xy2[:, 0], self.xy2[:, 1], c=self.colors2, **kwargs)
        self.ax1.axis('equal')
        self.ax2.axis('equal')

        # Indices (into X1 / X2) and samples of the latest lasso selection.
        self.ind1 = None
        self.ind2 = None

        self.X1_selected = None
        self.X2_selected = None

        # At most one panel's selection is active (committable with enter).
        self.lasso1_active = False
        self.lasso2_active = False

        self.lasso1 = LassoSelector(self.ax1, onselect=self.onselect1)
        self.lasso2 = LassoSelector(self.ax2, onselect=self.onselect2)

        self.press = self.fig.canvas.mpl_connect("key_press_event", self.press_key)

        # Committed clusters: blocks of samples and matching label arrays.
        self.X1_train = []
        self.X2_train = []
        self.y1_train = []
        self.y2_train = []

        # Next cluster id for each panel.
        self.num_clusters1 = 0
        self.num_clusters2 = 0

    def onselect1(self, event):
        """Lasso callback for the left panel.

        ``event`` holds the lasso vertices. The points of ``xy1`` inside the
        lasso become the active selection. An empty lasso only updates
        ``ind1``; the active flags and the previous selection are kept.
        """
        path = Path(event)
        self.ind1 = np.nonzero(path.contains_points(self.xy1))[0]
        if self.ind1.size != 0:
            self.lasso1_active = True
            self.lasso2_active = False
            self.X1_selected = self.X1[self.ind1]

    def onselect2(self, event):
        """Lasso callback for the right panel; mirrors ``onselect1``."""
        path = Path(event)
        self.ind2 = np.nonzero(path.contains_points(self.xy2))[0]
        if self.ind2.size != 0:
            self.lasso1_active = False
            self.lasso2_active = True
            self.X2_selected = self.X2[self.ind2]

    def press_key(self, event):
        """Key handler: ``enter`` commits the active selection, ``shift`` saves."""
        if event.key == "enter":
            if self.lasso1_active:
                self._commit_selection(self.X1_selected, self.X1_train, self.y1_train, self.num_clusters1)
                self.num_clusters1 += 1
                # Deactivate so a repeated enter cannot store the same points
                # again under a new cluster id.
                self.lasso1_active = False
            elif self.lasso2_active:
                self._commit_selection(self.X2_selected, self.X2_train, self.y2_train, self.num_clusters2)
                self.num_clusters2 += 1
                self.lasso2_active = False
        if event.key == 'shift':
            self._save_training_data()

    @staticmethod
    def _commit_selection(selected, X_train, y_train, cluster_id):
        """Append ``selected`` and a constant label array ``cluster_id`` to the lists."""
        X_train.append(selected)
        y_train.append(np.array([cluster_id] * len(selected)))
        print("One cluster has been created.")

    def _save_training_data(self):
        """Write the committed clusters of both panels to .npy files (all or nothing)."""
        if not self.X1_train or not self.X2_train:
            # np.vstack raises on an empty list. Report instead of raising
            # inside the matplotlib callback.
            print("Nothing saved: create at least one cluster in each panel first.")
            return
        # Build all four arrays before writing, so a stacking error cannot
        # leave a partial set of files behind.
        X1_train = np.vstack(self.X1_train)
        X2_train = np.vstack(self.X2_train)
        y1_train = np.hstack(self.y1_train)
        y2_train = np.hstack(self.y2_train)
        np.save('X1_train.npy', X1_train)
        np.save('X2_train.npy', X2_train)
        np.save('y1_train.npy', y1_train)
        np.save('y2_train.npy', y2_train)
        print("Created training datasets have been saved.")

def interactive_annotation(X, lbs, **kwargs):
    """Open a 1x2 figure and attach an ``InteractiveAnnotation`` to it.

    ``kwargs`` are forwarded to the scatter calls. Keep the returned object
    referenced (e.g. assign it to a variable). It owns the LassoSelector
    widgets and the key-press connection, and matplotlib widgets stop
    responding once nothing references them.
    """
    fig, _axes = plt.subplots(1, 2, figsize=(12, 6))
    return InteractiveAnnotation(fig, X, lbs, **kwargs)

In [47]:
# Features and GMM labels for dataset 6 only (the `2` is presumably the number
# of mixture components). This rebinds the notebook-level X and lbs that were
# computed from all datasets earlier.
X = get_feature_X(num=6, threshold = 0.3, min_distance=1)
lbs = gmm_lbs(X, 2, n_init=1)

In [48]:
# `s=1` is forwarded to Axes.scatter (marker size). Binding the result to
# `app` keeps the lasso widgets referenced, and therefore responsive.
app = interactive_annotation(X=X, lbs=lbs, s=1)

One cluster has been created.
One cluster has been created.
One cluster has been created.
One cluster has been created.
Created training datasets have beed saved.
