Interactive plot comparing phase contrast in TEM: in focus, defocused, and with a Zernike phase plate

Stephanie Ribet

2023 June


In [1]:
%matplotlib widget

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from IPython.display import display
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label
# from ipywidgets import widgets, interact, GridspecLayout, Layout,  Layout, Label, 
from matplotlib import cm
from ipywidgets import AppLayout, FloatSlider, FloatLogSlider, Layout
import abtem
from ase.io import read
from scipy.ndimage import gaussian_filter
from abtem.noise import poisson_noise
# from py4DSTEM.visualize import Complex2RGB

In [4]:
# Adapted from py4DSTEM's Complex2RGB (made robust to zero/integer/NaN amplitudes)
def Complex2RGB(complex_data, vmin=None, vmax = None, hue_start = 0, invert=False):
    """
    Map a complex array to RGB: hue encodes the phase, brightness the amplitude.

    complex_data (array): complex array to plot (real input is accepted; its phase is 0 or pi)
    vmin (float)        : lower amplitude clip. If the amplitude is not constant this is a
                          quantile in [0, 1] (default 0.02); if it is constant it is an
                          absolute amplitude (default 0).
    vmax (float)        : upper amplitude clip, interpreted like `vmin`
                          (default 0.98 as a quantile, or the constant amplitude itself).
    hue_start (float)   : rotational offset for colormap (degrees)
    invert (bool)       : if True, uses light color scheme (returns 1 - rgb)

    Returns
    -------
    rgb (array): float array of shape complex_data.shape + (3,). Entries are in [0, 1]
                 for finite input; an all-zero amplitude maps to black (white if invert).
    """
    # Cast to float so clipping/normalisation also work for integer input.
    amp = np.abs(complex_data).astype(float)
    if np.isclose(np.max(amp),np.min(amp)):
        # Constant amplitude: quantiles are meaningless, so vmin/vmax are absolute values.
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = np.max(amp)
    else:
        # vmin/vmax are quantiles of the NaN-free amplitude distribution.
        if vmin is None:
            vmin = 0.02
        if vmax is None:
            vmax = 0.98
        vals = np.sort(amp[~np.isnan(amp)])
        ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
        ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
        ind_vmin = np.max([0, ind_vmin])
        ind_vmax = np.min([len(vals) - 1, ind_vmax])
        vmin = vals[ind_vmin]
        vmax = vals[ind_vmax]

    amp = np.where(amp < vmin, vmin, amp)
    amp = np.where(amp > vmax, vmax, amp)

    # Normalise brightness to [0, 1]. nanmax keeps isolated NaNs from blanking the
    # whole image; skipping a zero peak avoids 0/0 for an all-zero amplitude.
    peak = np.nanmax(amp)
    if peak > 0:
        amp = amp / peak

    phase = np.angle(complex_data) + np.deg2rad(hue_start)
    rgb = np.zeros(phase.shape +(3,))
    rgb[...,0] = 0.5*(np.sin(phase)+1)*amp
    rgb[...,1] = 0.5*(np.sin(phase+np.pi/2)+1)*amp
    rgb[...,2] = 0.5*(-np.sin(phase)+1)*amp

    return 1-rgb if invert else rgb
# CTF with a Zernike-style phase plate, used for the third image panel
class Diverse_CTF(abtem.transfer.CTF):
    """
    Child class of the abTEM contrast transfer function that adds a
    Zernike-style phase plate.

    Every spatial frequency whose scattering angle ``alpha`` exceeds ``cutoff``
    is multiplied by ``exp(1j * phase_shift)``; frequencies at or inside the
    cutoff (including the unscattered beam) are left unshifted. Aperture and
    aberrations are evaluated by the abTEM parent class.
    """
    def __init__(
        self,
        energy,
        defocus,
        cutoff = 10/1000,
        phase_shift = np.pi/2,
        *args,
        **kwargs):
        """
        Parameters
        ----------
        energy : float
            Electron energy, forwarded to ``abtem.transfer.CTF``
            (this notebook passes 300e3). Also stored as ``self.E0``.
        defocus : float
            Defocus, forwarded to ``abtem.transfer.CTF``.
        cutoff : float
            Phase-plate radius, compared directly against the ``alpha`` passed
            to ``evaluate`` (this notebook passes mrad / 1000).
        phase_shift : float
            Phase in radians added to frequencies with ``alpha > cutoff``.
        *args, **kwargs
            Forwarded to ``abtem.transfer.CTF`` (e.g. ``semiangle_cutoff``).
        """
        self.cutoff = cutoff
        self.E0 = energy
        self.phase_shift = phase_shift

        super().__init__(
            energy = energy,
            defocus = defocus,
            *args,
            **kwargs)

    def evaluate(self,
                 alpha,
                 phi):
        """
        Evaluate the CTF including the phase plate.

        Parameters
        ----------
        alpha : array
            Scattering angles, compared against ``self.cutoff``.
        phi : array
            Azimuthal angles, forwarded to the parent aperture/aberration terms.

        Returns
        -------
        CTF: Union[float, np.ndarray]
            aperture * aberrations * exp(1j * phase plate)
        """
        from abtem.device import get_array_module

        # Array module matching alpha (NumPy, or CuPy on GPU); use it for all array ops.
        xp = get_array_module(alpha)

        CTF = self.evaluate_aperture(alpha,phi)*self.evaluate_aberrations(alpha, phi)

        # Phase plate: constant phase shift outside the cutoff radius, zero inside.
        zernike = xp.zeros_like(alpha)
        zernike[alpha > self.cutoff] = self.phase_shift
        # In-place multiply keeps the dtype of CTF.
        CTF *= xp.exp(1j * zernike)

        return CTF

In [5]:
# Load the structure and build a tight orthorhombic cell around it.
atoms = read('3jcl.xyz')
for axis in range(3):
    # Shift so the smallest coordinate along this axis is 0 ...
    atoms.positions[:, axis] -= atoms.positions[:, axis].min()
    # ... and make the cell edge equal to the extent along this axis.
    atoms.cell[axis][axis] = atoms.positions[:, axis].max()

# Add 10 (ASE length units, Å) of vacuum on each side in x and y.
atoms.center(vacuum = 10, axis = (0,1))

# Make the x-y cross-section square using the LARGER side, so no atom ends up
# outside the cell, then re-centre laterally in the enlarged cell.
side = max(atoms.cell[0][0], atoms.cell[1][1])
atoms.cell[0][0] = side
atoms.cell[1][1] = side
atoms.center(axis = (0,1))


In [6]:


# Discretisation of the potential (abTEM length units, Å)
sampling = 1           # lateral pixel size
slice_thickness = 1    # multislice slice thickness

# Sliced projected potential of `atoms`, built eagerly into an array.
potential = abtem.Potential(
    atoms,
    sampling=sampling,
    slice_thickness=slice_thickness,
    projection='infinite',
    parametrization='kirkland',
    precalculate=True,
).build()

# Smooth the potential with a Gaussian of sigma = 0.5 pixels. A scalar sigma
# filters EVERY axis of the array, so this also blurs across slices
# (presumably the array layout is (slices, x, y) - TODO confirm intent).
potential_blurred = abtem.potentials.PotentialArray(
    gaussian_filter(potential.array, 0.5),
    potential.slice_thicknesses,
    potential.extent,
    potential.sampling,
)

# Propagate a plane wave (energy 300e3, i.e. 300 keV) through the blurred potential.
wave = abtem.waves.PlaneWave(energy=300e3)
exit_waves = wave.multislice(potential_blurred)

Multislice:   0%|          | 0/137 [00:00<?, ?it/s]

In [7]:
# --- Figure layout: three image panels plus an inset showing the phase plate ---
with plt.ioff():
    # Keep ipympl from auto-displaying the figure; it is shown in the VBox below.
    fig = plt.figure(figsize = (9,3))

ax0 = fig.add_axes([0.04,  0.05,  0.28, 0.75])
ax1 = fig.add_axes([0.37,  0.05,  0.28, 0.75])
ax2 = fig.add_axes([0.70,  0.05,  0.28, 0.75])
ax3 = fig.add_axes([0.7,  0.6,  0.1, 0.2])   # inset over ax2: phase-plate CTF

# --- Fixed imaging conditions ---
energy = 300e3          # same beam energy as the PlaneWave used for exit_waves
semiangle_cutoff = 10   # objective aperture semiangle passed to abTEM (mrad)
cmap = 'gray'

# --- Controls (created before the first render so their defaults define the initial view) ---
style = {
    'description_width': 'initial',
}

dose = widgets.FloatLogSlider(
    value=100,
    base=10,
    min=0, # min exponent of base
    max=5, # max exponent of base
    step=0.05, # exponent step
    description = r"dose (e$^-$/A$^2$)",
    style = style,
)

defocus = widgets.IntSlider(
    value = 1000, min = -10000, max = 10000,
    step = 100,
    description = "defocus (A)",
    style = style
)

# Phase shift is in degrees. IntSlider casts `step` to int, so the previous
# np.pi/8 (a radian value, ~0.39) became a step of 0.
phase_shift = widgets.IntSlider(
    value = 90, min = 0, max = 180,
    step = 1,
    description = r"phase shift ($^\circ$)",
    style = style
)

zernike_radius = widgets.IntSlider(
    value = 1, min = 1, max = 10,
    step = 1,
    description = r"radius of shift (mrad)",
    style = style,
)


def _simulate_images(dose, defocus, phase_shift, zernike_radius):
    """
    Simulate noisy images of `exit_waves` for the three contrast mechanisms.

    dose (float)           : dose passed to `poisson_noise` (e-/A^2 per the slider label)
    defocus (float)        : defocus for the middle panel (A per the slider label)
    phase_shift (float)    : phase-plate shift in degrees
    zernike_radius (float) : phase-plate radius in mrad

    Returns
    -------
    images : list of 3 arrays
        Noisy intensities for [in focus, defocused, in focus + Zernike plate].
    plate_rgb : array
        RGB rendering of the Zernike CTF, for the inset.
    """
    ctf_zernike = Diverse_CTF(
        energy,
        defocus = 0,
        cutoff = zernike_radius / 1000,        # mrad -> rad
        semiangle_cutoff = semiangle_cutoff,
        phase_shift = np.deg2rad(phase_shift),
    )
    ctfs = [
        abtem.transfer.CTF(defocus = 0, semiangle_cutoff = semiangle_cutoff, energy = energy),
        abtem.transfer.CTF(defocus = defocus, semiangle_cutoff = semiangle_cutoff, energy = energy),
        ctf_zernike,
    ]
    images = [
        poisson_noise(exit_waves.apply_ctf(ctf).intensity(), dose).array
        for ctf in ctfs
    ]
    plate_rgb = Complex2RGB(ctf_zernike.as_complex_image(exit_waves.grid).array, 0, 1)
    return images, plate_rgb


def _shared_clim(images):
    """Common (vmin, vmax) over all images so the panels share one grey scale."""
    return (
        min(image.min() for image in images),
        max(image.max() for image in images),
    )


# --- Initial render, using the slider defaults ---
images, plate_rgb = _simulate_images(
    dose.value, defocus.value, phase_shift.value, zernike_radius.value
)
vmin, vmax = _shared_clim(images)
im0, im1, im2 = (
    ax.imshow(image, vmax = vmax, vmin = vmin, cmap = cmap)
    for ax, image in zip((ax0, ax1, ax2), images)
)
im3 = ax3.imshow(plate_rgb)

for ax in (ax0, ax1, ax2, ax3):
    ax.set_xticks([])
    ax.set_yticks([])
ax0.set_title('In focus intensity')
ax1.set_title('Defocused intensity')
ax2.set_title('Intensity with Zernike\nphase plate (In focus)')


def update_ims(dose, defocus, phase_shift, zernike_radius):
    """
    Slider callback: re-simulate the three images and the phase-plate inset,
    then redraw them in place. Arguments are as for `_simulate_images`
    (phase_shift in degrees, zernike_radius in mrad).
    """
    images, plate_rgb = _simulate_images(dose, defocus, phase_shift, zernike_radius)
    vmin, vmax = _shared_clim(images)
    for im, image in zip((im0, im1, im2), images):
        im.set_data(image)
        im.set_clim(vmax = vmax, vmin = vmin)
    im3.set_data(plate_rgb)

    fig.canvas.draw_idle()


# interactive_output re-runs update_ims when a slider changes. Anything the callback
# prints or raises is captured in the returned Output widget. That widget is shown
# below the controls rather than discarded, so errors are not silently hidden.
controls_output = widgets.interactive_output(
    update_ims,
    {
        'dose':dose,
        'defocus':defocus,
        'phase_shift':phase_shift,
        'zernike_radius':zernike_radius,
    },
)

VBox(
    [
        fig.canvas,
        HBox([
            dose, defocus,
            VBox([phase_shift, zernike_radius])
        ]),
        controls_output,
    ],
)


VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…