## Experimental workflow code for Fast automated drift correction for functional SPM imaging

Code by: Marti Checa

Publication: M. Checa et al., "Automated piezoresponse force microscopy domain tracking during fast thermally stimulated phase transition in CuInP2S6"

## Imports

In [None]:
from __future__ import division, print_function, absolute_import
from scipy.io import loadmat
import glob
import os
import skimage
import time

import pywt # conda install pywavelets
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
import cv2
from scipy.io import savemat

from skimage.color import rgb2gray
from skimage.data import stereo_motorcycle
from skimage.transform import warp
#from skimage.registration import optical_flow_tvl1, optical_flow_ilk
from skimage.registration import phase_cross_correlation
from skimage.registration._phase_cross_correlation import _upsampled_dft
from celluloid import Camera

import win32com.client  

## Compressive sensing reconstruction function

In [None]:
def SSTEM(yspar,mask,itern,levels,lambd):
    """Inpaint a sparsely sampled image by iterative wavelet soft-thresholding.

    Every iteration re-imposes the measured samples on the current estimate
    (pixels where ``mask`` is 1 take their value from ``yspar``), computes a
    stationary 2-D wavelet transform ('db2', ``levels`` levels),
    soft-thresholds the detail coefficients by ``lambd`` and inverts the
    transform.

    Parameters
    ----------
    yspar : 2-D array
        Sparse image. Rescaling its values to [0, 1] beforehand tends to
        reduce the number of iterations needed.
    mask : 2-D array, same shape as ``yspar``
        Binary array, 1 at sampled pixel locations.
    itern : int
        Number of iterations; about 20 is usually enough.
    levels : int
        Wavelet decomposition level. Common choices are 2, 3 or 4; larger
        values suit larger features. Use a smaller value if the result is
        too blurred.
    lambd : float
        Soft-threshold value for the detail coefficients; about 0.8 is
        usually fine.

    Returns
    -------
    2-D array
        The inpainted image. When ``itern`` is 0, ``yspar`` itself is returned.
    """
    def keep_measured(estimate):
        # measured pixels come from yspar, the rest from the current estimate
        return (1 - mask) * estimate + mask * yspar

    def shrink(level_coeffs):
        approx, (horiz, vert, diag) = level_coeffs
        # the approximation is "thresholded" at 0, i.e. left as is
        return (pywt.threshold(approx, 0, 'soft'),
                tuple(pywt.threshold(band, lambd, 'soft') for band in (horiz, vert, diag)))

    estimate = yspar
    for _ in range(itern):
        estimate = keep_measured(estimate)
        decomposition = pywt.swt2(estimate, 'db2', levels)
        estimate = pywt.iswt2([shrink(decomposition[lvl]) for lvl in range(levels)], 'db2')
    return estimate

## Function to calculate the phase cross-correlation

In [None]:
def shift_correction(topo_recon_norm1,topo_recon_norm2,mask_recon,res_factor):
    """Estimate the translation between two images with masked phase cross-correlation.

    Both images and the mask are upsampled by ``res_factor`` before the
    correlation. The returned shift therefore has a resolution of
    ``1/res_factor`` of an original pixel. A figure is then shown with the
    two upsampled images, an arrow drawn opposite to the shift, and their
    absolute difference.

    Parameters
    ----------
    topo_recon_norm1 : 2-D array
        Reference image.
    topo_recon_norm2 : 2-D array
        Moving image, same shape as the reference.
    mask_recon : 2-D array
        Validity mask (non-zero = valid pixel), same shape as the images.
        It is used as both the reference and the moving mask.
    res_factor : int
        Integer upsampling factor (1 = no upsampling).

    Returns
    -------
    numpy.ndarray
        Shift vector in *upsampled* pixels, ordered like the image axes
        (row, column), as reported by
        ``skimage.registration.phase_cross_correlation``. Divide it by
        ``res_factor`` to express it in original pixels.
    """
    def upsample(img, interpolation):
        # cv2.resize expects dsize as (width, height) == (n_cols, n_rows)
        n_rows, n_cols = img.shape[:2]
        return cv2.resize(img, dsize=(n_cols*res_factor, n_rows*res_factor),
                          interpolation=interpolation)

    # Upsample to get sub-pixel resolution in the drift estimate
    image0 = upsample(topo_recon_norm1, cv2.INTER_CUBIC)
    image1 = upsample(topo_recon_norm2, cv2.INTER_CUBIC)
    # Nearest-neighbour keeps the mask binary. Cubic interpolation creates
    # overshoot values around the edges, and these all count as "valid"
    # once the mask is interpreted as boolean.
    mask_recon_resize = upsample(mask_recon, cv2.INTER_NEAREST)

    # --- Compute the phase cross-correlation
    result = phase_cross_correlation(image0, image1,
                                     reference_mask=mask_recon_resize,
                                     moving_mask=mask_recon_resize)
    # Depending on the scikit-image version, the masked variant returns either
    # the shift array alone or a (shift, error, phasediff) tuple.
    shift = result[0] if isinstance(result, tuple) else result

    print('The shift vector is:', shift)

    # --- Display
    fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(6, 4))

    ax0.imshow(image0, cmap='viridis')
    ax0.set_title("Image 0")
    ax1.imshow(image1, cmap='viridis')
    ax1.set_title("Image 1")

    # Arrow anchored at the image centre (x = column, y = row), pointing
    # opposite to the measured shift and exaggerated for visibility
    ax0.quiver(image0.shape[1]/2, image0.shape[0]/2, -shift[1]*10, -shift[0]*10, color='r', units='width',
               angles='xy', scale_units='xy', lw=3, scale=0.5)

    ax2.imshow(abs(image0-image1), cmap='viridis')
    ax2.set_title("Difference")

    for ax in (ax0, ax1, ax2):
        ax.set_axis_off()
    fig.tight_layout()
    plt.show()

    return shift

## Launch the LabVIEW application that controls the FPGA

In [None]:
# Start the LabVIEW-built executable that drives the FPGA scanner and connect
# to it via COM automation (os.startfile and win32com are Windows-only).
exe_path = r'D:\User Data\Marti\FPGA Scanner 01 py 101622 01\py spiral scanner.exe'
# Alternative build path, currently unused:
#exe_path = 'D:\\User Data\Marti\FPGA Scanner 01 py 100622 01\\builds\FPGA Scanner 01\py spiral scannerV4\py spiral scannerV4.exe'
os.startfile(exe_path)

# Fixed delay so the application is running before we dispatch to it.
# NOTE(review): there is no readiness check; 5 s is presumably enough on this machine.
time.sleep(5)

#labview = win32com.client.Dispatch("PySpiralScanner.Application")
# COM object registered by the executable under this ProgID
labview = win32com.client.Dispatch("PySpiralScanner.Application")

# Reference to the host VI inside the built executable. Its controls ('x0', 'y0',
# 'spiral cluster 1', 'do scanning', 'AI0 image', ...) are set/read by the cells below.
VI = labview.getvireference(r'D:\User Data\Marti\FPGA Scanner 01 py 101622 01\py spiral scanner.exe\FPGA scanner 08 (host).vi')
#VI = labview.getvireference(r'D:\\User Data\Marti\FPGA Scanner 01 py 100622 01\\builds\FPGA Scanner 01\py spiral scannerV4\py spiral scannerV4.exe\FPGA scanner 08 (host).vi')

## Initialization of FPGA parameters

We set the FPGA parameters, perform initial scans, and tune the parameter values for compressed sensing.

In [None]:
#Set parameters to save data
path_save = r'D:\User Data\Marti\2023\Jan\12th\spiral_scans\\'
filename_base = 'Flake1_'

#Set parameters for labview (values passed to the VI controls)
Amplitude1, Amplitude2 = 0, 1   # amplitudes sent in 'spiral cluster 1'
N_Cylces = 20                   # cycle count sent in 'spiral cluster 1'
X0, Y0 = 0, 0                   # scan centre written to 'x0' / 'y0'
Duration = 1                    # in seconds
n_frames = 20
imagesize = 128                 # images are imagesize x imagesize pixels

In [None]:
#we set up the parameters and perform 3 scans
time.sleep(0.3)
VI.setcontrolvalue('x0',(X0))
time.sleep(0.3)
VI.setcontrolvalue('y0',(Y0))
VI.setcontrolvalue('spiral cluster 1',(Amplitude2, N_Cylces, Duration, 10.0, 2.0, 10.0, imagesize, Amplitude1))
VI.setcontrolvalue('do scan update',str(True))
time.sleep(2)

# Three identical scans: switch 'do scanning' on, wait, switch it off, pause 2 s
for _scan_idx in range(3):
    VI.setcontrolvalue('do scanning',str(True))
    time.sleep(Duration-0.5)
    VI.setcontrolvalue('do scanning',str(False))
    time.sleep(2)

# Read both channels from the VI and reorient them (rot90, then flipud)
ai0 = np.flipud(np.rot90(np.asarray(VI.getcontrolvalue('AI0 image'))))
ai1 = np.flipud(np.rot90(np.asarray(VI.getcontrolvalue('AI1 image'))))

# Sampling mask: 1 wherever AI0 is positive, 0 elsewhere
mask=np.zeros((imagesize,imagesize))
mask[np.nonzero(ai0 > 0)] = 1

# Parameter grid for the exploration:
#   itern  = ite_ini + i*ite_step
#   levels = lev_ini + j*lev_step
#   lambd  = lam_ini + k*lam_step   for i, j, k in range(n_steps)
n_steps = 3
ite_ini, ite_step = 1, 3
lev_ini, lev_step = 1, 1
lam_ini, lam_step = 1, 3

plt.title('Raw Image')
plt.imshow(ai0, cmap='copper')
plt.colorbar()
plt.show()

# One row per iteration count; columns are grouped by wavelet level, then by threshold
fig, axs = plt.subplots(n_steps, n_steps**2, figsize=(n_steps*6, n_steps*3), constrained_layout=True)

for row in range(n_steps):
    itern = ite_ini + row*ite_step
    for lev_idx in range(n_steps):
        levels = lev_ini + lev_idx*lev_step
        for lam_idx in range(n_steps):
            lambd = lam_ini + lam_idx*lam_step
            topo_recon = SSTEM(ai0, mask, itern, levels, lambd)
            param_print = 'it=' + str(itern) + ' lev=' + str(levels) + ' lam=' + str(lambd)
            panel = axs[row, lam_idx + lev_idx*n_steps]
            panel.imshow(topo_recon, cmap='viridis')
            panel.set_title(param_print)

plt.suptitle('Exploration of compressed sensing parameters',fontsize=20)
plt.show()

In [None]:
# Compressed-sensing settings used by the acquisition loop below:
# iterations, wavelet levels, soft-threshold value
itern, levels, lambd = 1, 2, 4

## Experiment workflow:
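1. Images are acquired consecutively; the first frame serves as the reference for all later frames.
2. The data are read and the images are inpainted using compressed sensing (CS).
3. The shift relative to the reference frame is calculated.
4. If the shift is nonzero (and within the safety limits), the center coordinates of the scan are corrected.
5. The next scan is triggered.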

In [None]:
#Set parameters to save data
path_save = r'D:\User Data\Marti\2023\Jan\13th\spiral scan\region 4\20cycles_1sec_withouttracking\\'
filename_base = 'CIPS_IPS_Flake_'

#Set parameters for labview (values passed to the VI controls)
Amplitude1, Amplitude2 = 0, 1   # amplitudes sent in 'spiral cluster 1'
N_Cylces = 20                   # cycle count sent in 'spiral cluster 1'
X0, Y0 = 0, 2.2                 # initial scan centre written to 'x0' / 'y0'
Duration = 1                    # in seconds
n_frames = 120                  # number of frames in the experiment
imagesize = 128                 # images are imagesize x imagesize pixels

In [None]:
#saving metadata
metadata=[Amplitude1,Amplitude2,N_Cylces,X0,Y0,Duration,n_frames,imagesize]
np.savetxt(path_save + filename_base + '_METADATA.txt', metadata, delimiter=',') 

#parameter for the shift calculation
pix_amplif=1 #upsampling factor passed to shift_correction (sub-pixel resolution)
option_corr=2 #1 for disabling drift correction, 2 for enabling drift correction

# Safety limits (volts) to avoid piezo damage: corrections of max_shift_V or more
# on either axis are not applied, and the scan centre is not moved once the
# previous |X| or |Y| reaches max_position_V.
max_shift_V=0.25
max_position_V=8

#variable initialization
channel1=np.zeros((n_frames,imagesize,imagesize))      # raw AI0 per frame
channel2=np.zeros((n_frames,imagesize,imagesize))      # raw AI1 per frame
channel3=np.zeros((n_frames,imagesize,imagesize))      # reconstructed AI0
channel4=np.zeros((n_frames,imagesize,imagesize))      # reconstructed AI1
norm_rec_ai0=np.zeros((n_frames,imagesize,imagesize))  # min-max normalised reconstructions
norm_rec_ai1=np.zeros((n_frames,imagesize,imagesize))
mask=np.zeros((imagesize,imagesize))        # sampled-pixel mask used by SSTEM
mask_recon=np.zeros((imagesize,imagesize))  # valid region for the cross-correlation
shift=np.zeros((n_frames,2))    # shift of each frame vs frame 0 (upsampled pixels, axis order)
shift_V=np.zeros((n_frames,2))  # same shift converted to volts
X=np.zeros((n_frames))          # scan centre per frame, as written to 'x0' / 'y0'
Y=np.zeros((n_frames))
Xpix=np.zeros((n_frames))       # scan centre tracked in image pixels
Ypix=np.zeros((n_frames))


def _do_single_scan():
    """Switch the VI's 'do scanning' control on, wait, switch it off and pause."""
    VI.setcontrolvalue('do scanning',str(True))
    # clamp so that Duration < 0.5 s does not make time.sleep raise
    time.sleep(max(Duration-0.5, 0))
    VI.setcontrolvalue('do scanning',str(False))
    time.sleep(0.5)


def _acquire_and_reconstruct():
    """Read AI0/AI1 from the VI, update the global `mask`, and inpaint both channels.

    Returns (ai0, ai1, ai0_rec, ai1_rec).
    """
    ai0 = np.flipud(np.rot90(np.asarray(VI.getcontrolvalue('AI0 image'))))
    ai1 = np.flipud(np.rot90(np.asarray(VI.getcontrolvalue('AI1 image'))))
    # Pixels with a positive AI0 reading are marked as sampled.
    # NOTE(review): the mask is never cleared, so it accumulates over all frames;
    # confirm this is intended.
    mask[ai0 > 0] = 1
    ai0_rec = SSTEM(ai0,mask,itern,levels,lambd)
    ai1_rec = SSTEM(ai1,mask,itern,levels,lambd)
    return ai0, ai1, ai0_rec, ai1_rec


def _store_frame(kk, ai0, ai1, ai0_rec, ai1_rec):
    """Keep frame kk's raw and reconstructed images plus min-max normalised reconstructions."""
    channel1[kk,:,:]=ai0
    channel2[kk,:,:]=ai1
    channel3[kk,:,:]=ai0_rec
    channel4[kk,:,:]=ai1_rec
    norm_rec_ai0[kk,:,:]= (ai0_rec - np.min(ai0_rec))/np.ptp(ai0_rec)
    norm_rec_ai1[kk,:,:]= (ai1_rec - np.min(ai1_rec))/np.ptp(ai1_rec)


def _save_frame(kk, ai0, ai1, ai0_rec, ai1_rec):
    """Write frame kk's raw and reconstructed images as comma-separated text files."""
    for suffix, data in (('_ai0.txt', ai0), ('_ai1.txt', ai1),
                         ('_ai0_rec.txt', ai0_rec), ('_ai1_rec.txt', ai1_rec)):
        np.savetxt(path_save + filename_base + str(kk) + suffix, data, delimiter=',')


def _plot_frame(kk, ai0, ai1, ai0_rec, ai1_rec):
    """Show frame kk's raw (top) and reconstructed (bottom) images."""
    fig, ax = plt.subplots(2, 2, figsize=(6, 6))
    panels = ((ax[0,0], ai0, 'copper', 'Topo frame num ='),
              (ax[0,1], ai1, 'viridis', 'Defl frame num ='),
              (ax[1,0], ai0_rec, 'copper', 'Topo rec frame num ='),
              (ax[1,1], ai1_rec, 'viridis', 'Defl rec frame num ='))
    for axis, image, cmap, label in panels:
        axis.imshow(image, cmap=cmap)
        axis.set_title(label + str(kk))
    plt.show()


for kk in range(n_frames):

    if kk==0: # the first frame is the reference for all later shifts
        X[kk]=X0
        Y[kk]=Y0
        Xpix[kk]=imagesize/2
        Ypix[kk]=imagesize/2
        VI.setcontrolvalue('spiral cluster 1',(Amplitude2, N_Cylces, Duration, 10.0, 2.0, 10.0, imagesize, Amplitude1))
        time.sleep(0.3)
        VI.setcontrolvalue('x0',(X0))
        time.sleep(0.3)
        VI.setcontrolvalue('y0',(Y0))
        time.sleep(0.3)
        VI.setcontrolvalue('do scan update',str(True))
        time.sleep(2)

        print('First 3 scans at --> (X0 , Y0) = (' + str(X0) + ' , ' + str(Y0) + ')')
        # Three scans, each overwriting frame 0, so only the last is kept as reference
        for _ in range(3):
            _do_single_scan()
            ai0, ai1, ai0_rec, ai1_rec = _acquire_and_reconstruct()
            # Valid region for the cross-correlation: pixels whose reconstructed
            # value differs from the corner pixel (presumably the constant
            # background outside the spiral; TODO confirm)
            mask_recon[:,:] = ai0_rec != ai0_rec[0,0]
            _store_frame(kk, ai0, ai1, ai0_rec, ai1_rec)
        # Write the reference frame to disk as well, like every later frame
        _save_frame(kk, ai0, ai1, ai0_rec, ai1_rec)
        continue

    _do_single_scan()
    ai0, ai1, ai0_rec, ai1_rec = _acquire_and_reconstruct()
    _store_frame(kk, ai0, ai1, ai0_rec, ai1_rec)
    _plot_frame(kk, ai0, ai1, ai0_rec, ai1_rec)
    _save_frame(kk, ai0, ai1, ai0_rec, ai1_rec)

    #we calculate the shift of this frame relative to the reference frame 0
    shift[kk,:]=shift_correction(norm_rec_ai0[0,:,:],norm_rec_ai0[kk,:,:],mask_recon,pix_amplif)
    # /pix_amplif converts upsampled pixels to image pixels; 2*Amplitude2/imagesize
    # is the volts-per-pixel factor this notebook assumes for the scan
    shift_V[kk,:]=shift[kk,:]*2*Amplitude2/(imagesize*pix_amplif)
    print('Shift in volts is:'+ str(shift_V[kk,:]) + ' V')

    # Carry the previous position forward so that a rejected correction keeps the
    # scan where it was. Otherwise X[kk]/Y[kk] would keep their initial 0.
    X[kk], Y[kk] = X[kk-1], Y[kk-1]
    Xpix[kk], Ypix[kk] = Xpix[kk-1], Ypix[kk-1]

    #we update the scan centre, respecting the voltage limits to avoid piezo damage
    if (abs(shift_V[kk,0]) < max_shift_V and abs(shift_V[kk,1]) < max_shift_V
            and abs(X[kk-1]) < max_position_V and abs(Y[kk-1]) < max_position_V):
        # column 1 of the shift is horizontal (X), column 0 vertical (Y)
        X[kk]=X[kk-1]-shift_V[kk,1]
        Y[kk]=Y[kk-1]-shift_V[kk,0]
        Xpix[kk]=Xpix[kk-1]-shift[kk,1]/pix_amplif
        Ypix[kk]=Ypix[kk-1]-shift[kk,0]/pix_amplif

    print('New X --> ' + str(X[kk]) + '   ' + 'New Y --> ' + str(Y[kk]))

    if option_corr==2 and (shift_V[kk,0]!=0 or shift_V[kk,1]!=0):

        if abs(shift_V[kk,0])>max_shift_V or abs(shift_V[kk,1])>max_shift_V:
            # shift too large: replace this frame's normalised image with the previous one
            norm_rec_ai0[kk,:,:]=norm_rec_ai0[kk-1,:,:]

        if abs(shift_V[kk,0])<max_shift_V and abs(shift_V[kk,1])<max_shift_V:
            print('warming up after change...')
            VI.setcontrolvalue('x0',(X[kk]))
            time.sleep(0.3)
            VI.setcontrolvalue('y0',(Y[kk]))
            time.sleep(0.3)
            #Do a warm up scan (its data is not read)
            _do_single_scan()

    if shift_V[kk,0]==0 and shift_V[kk,1]==0:
        norm_rec_ai0[kk,:,:]=norm_rec_ai0[kk-1,:,:]

#We go back to the initial position. Use the same 8-element cluster as at start-up;
#the centre is set through the separate 'x0' / 'y0' controls.
VI.setcontrolvalue('spiral cluster 1',(Amplitude2, N_Cylces, Duration, 10.0, 2.0, 10.0, imagesize, Amplitude1))
time.sleep(0.3)
VI.setcontrolvalue('x0',(X0))
time.sleep(0.3)
VI.setcontrolvalue('y0',(Y0))
time.sleep(0.5)
VI.setcontrolvalue('do scan update',str(True))
time.sleep(2)

print('Done!')

In [None]:
#We save the shift vectors (in volts) to a file.
#shift_V follows the image axis order: column 0 is the row (vertical, Y)
#component and column 1 the column (horizontal, X) component. This matches
#the acquisition loop, which updates X from shift_V[:,1] and Y from shift_V[:,0].
np.savetxt(path_save + filename_base + str(kk) +'_shift_vectorX.txt', shift_V[:,1], delimiter=',')
np.savetxt(path_save + filename_base + str(kk) +'_shift_vectorY.txt', shift_V[:,0], delimiter=',')