In [None]:
from abtem import GridScan, PixelatedDetector, Potential, Probe, show_atoms, SMatrix, AnnularDetector, FrozenPhonons
from abtem.detect import PixelatedDetector
from abtem.reconstruct import MixedStatePtychographicOperator, RegularizedPtychographicOperator, MultislicePtychographicOperator
from abtem.measure import Measurement, Calibration, bandlimit, center_of_mass
from abtem.utils import energy2wavelength
from abtem.transfer import CTF, scherzer_defocus
from abtem.structures import orthogonalize_cell
from abtem.noise import poisson_noise
from ase.build import mx2, surface
from ase.io import read
import numpy as np
import matplotlib.pyplot as plt
# Prefer Times New Roman for figure text, but go through the generic 'serif' family so
# Matplotlib falls back to another serif font (rather than warning and dropping to the
# default sans-serif) on machines where Times New Roman is not installed.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman'] + list(plt.rcParams['font.serif'])
import tifffile

In [None]:
# SrTiO3 slab: (110) surface with 4 layers, tiled 2x3 laterally, then stacked 3x along z.
STO_unit_cell = read('data/SrTiO3.cif')
STO_atoms = surface(STO_unit_cell, (1, 1, 0), 4, periodic=True) * (2, 3, 1)
STO_atoms_thick = STO_atoms * (1, 1, 3)

# Beam-direction view plus the two orthogonal side views of the thick slab.
# show_atoms is imported in the first cell; no re-import needed here.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))

show_atoms(STO_atoms_thick, ax=ax1, title='Beam view')
show_atoms(STO_atoms_thick, ax=ax2, plane='yz', title='Side view (yz)')
show_atoms(STO_atoms_thick, ax=ax3, plane='xz', title='Side view (xz)')

fig.tight_layout()
plt.show()

In [None]:
# Microscope settings (abTEM units; presumably eV for energy and Å for aberrations).
energy = 2e5               # beam energy (200 keV)
aperture_semiangle = 30    # probe-forming semiangle, mrad
C3 = -7e4                  # spherical aberration C30 (negative Cs)

# Scherzer defocus for this Cs / energy combination.
sch_defocus = scherzer_defocus(C3, energy)
print(sch_defocus)

In [None]:
# 64 frozen-phonon configurations with the same displacement sigma (0.05) for every species.
# NOTE(review): .build() materialises every configuration at 0.02 sampling — this can be
# memory-hungry for a thick slab; confirm it fits before running on small machines.
frozen_phonons_thick = FrozenPhonons(
    STO_atoms_thick,
    64,
    dict.fromkeys(('Sr', 'Ti', 'O'), .05),
    seed=1,
)
potential_thick = Potential(
    frozen_phonons_thick,
    sampling=0.02,
    projection='infinite',
    parametrization='kirkland',
).build()

# Probe-forming optics. scherzer_defocus() returns a *defocus* value, and abTEM's
# convention is defocus = -C10, so the C10 coefficient is the negated value.
# Passing sch_defocus directly as C10 would give C10 the same sign as C30 and lose the
# Scherzer balancing between the two terms. The factor 0.9 targets 90% of Scherzer defocus.
ctf = CTF(parameters={'C10': -0.9 * sch_defocus, 'C30': C3}, semiangle_cutoff=aperture_semiangle)
probe = Probe(semiangle_cutoff=aperture_semiangle,
              energy=energy,
              ctf=ctf)

# Put the probe on the same grid as the potential it will be scanned over.
probe.match_grid(potential_thick)

# The probe shown twice: default scaling (left) and power=0.5 display scaling (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
probe.show(ax=axes[0])
probe.show(ax=axes[1], power=0.5)
plt.show()

In [None]:
# Projection of the (frozen-phonon) potential along the beam direction.
fig, ax1 = plt.subplots(figsize=(5, 5))
projected_potential = potential_thick.project()
projected_potential.show(ax=ax1, cmap="inferno")
plt.show()

In [None]:
# Record full diffraction patterns (out to 3x the probe semiangle) over the whole cell,
# with a scan step of 0.5 (presumably Å).
pixelated_detector = PixelatedDetector(max_angle=3.0 * aperture_semiangle)
gridscan = GridScan((0, 0), potential_thick.extent, sampling=0.5)
measurement = probe.scan(gridscan, pixelated_detector, potential_thick)

# Report the dataset shape and the calibration of every axis.
print(measurement.shape)
for axis in range(measurement.dimensions):
    calibration = measurement.calibrations[axis]
    print(calibration.name, calibration.units, calibration.sampling)

In [None]:
# Virtual detectors integrated from the 4D dataset (angles in mrad, like aperture_semiangle).
abf_angles = (10, 30)     # annular bright field, inside the 30 mrad probe aperture
maadf_angles = (50, 90)   # medium-angle annular dark field, up to the 3x-semiangle detector edge

bright_detector = AnnularDetector(inner=abf_angles[0], outer=abf_angles[1])
bright_measurement = bright_detector.integrate(measurement)

maadf_detector = AnnularDetector(inner=maadf_angles[0], outer=maadf_angles[1])
maadf_measurement = maadf_detector.integrate(measurement)

# Position-averaged CBED: mean over the first two (scan) axes.
pacbed = measurement.mean(axis=(0, 1))

fig, axes = plt.subplots(1, 4, figsize=(20, 5))
measurement.show(ax=axes[0], cmap="inferno", power=0.5)
pacbed.show(ax=axes[1], cmap="inferno", power=0.1)
bright_measurement.show(ax=axes[2])
maadf_measurement.show(ax=axes[3])

# Label every panel so the figure can be read on its own.
panel_titles = ['4D measurement', 'PACBED',
                'ABF %d-%d mrad' % abf_angles,
                'MAADF %d-%d mrad' % maadf_angles]
for panel_ax, panel_title in zip(axes, panel_titles):
    panel_ax.set_title(panel_title)

fig.tight_layout()
plt.show()

In [None]:
# Centre-of-mass shift of each diffraction pattern, as separate x and y maps.
com_x, com_y = center_of_mass(measurement)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
for com_component, target_ax in ((com_x, ax1), (com_y, ax2)):
    # interpolate(.1) presumably resamples the scan grid to 0.1 before display
    com_component.interpolate(.1).show(ax=target_ax)
fig.tight_layout()
plt.show()

In [None]:
# Integrated centre of mass: same routine, asked for the integrated (iCoM) image instead.
icom = center_of_mass(measurement,
                      return_icom=True)
icom.show()
plt.show()

In [None]:
# Band-limit the diffraction data at twice the probe semiangle (presumably a cutoff in mrad)
# before ptychography. The variable keeps its original spelling: later cells reference it.
band_limited_measurment = bandlimit(measurement, 2.0 * aperture_semiangle)
band_limited_measurment.show(power=0.5, cmap="inferno")
plt.show()

In [None]:
n_iter = 10

# Single-slice regularized ptychography on the band-limited data, with (10, 10) px object padding.
experimental_ptycho_operator = RegularizedPtychographicOperator(
    band_limited_measurment,
    semiangle_cutoff=aperture_semiangle,
    energy=energy,
    parameters={'object_px_padding': (10, 10)},
).preprocess()

# return_iterations=True: each output is a per-iteration history.
exp_objects, exp_probes, exp_positions, exp_sse = experimental_ptycho_operator.reconstruct(
    max_iterations=n_iter,
    random_seed=1,
    return_iterations=True,
    parameters={'alpha': 1.0, 'beta': 1.0},
)

# Shape and per-axis calibration of the final object.
final_object = exp_objects[-1]
print(final_object.shape)
for axis in range(final_object.dimensions):
    calibration = final_object.calibrations[axis]
    print(calibration.name, calibration.units, calibration.sampling)

In [None]:
# Snapshot grid for the regularized reconstruction: object phase (top row) and probe
# amplitude (bottom row) every `plot_every` iterations, followed by the final state.
# max(1, ...) keeps the range step valid when n_iter < 5 (a step of 0 raises ValueError).
plot_every = max(1, n_iter // 5)

# (index into the per-iteration histories, iteration number shown in the title)
snapshots = [(j, j + 1) for j in range(0, len(exp_objects), plot_every)]
snapshots.append((-1, n_iter))

# squeeze=False keeps `axes` 2-D whatever the number of columns.
fig, axes = plt.subplots(2, len(snapshots), figsize=(20, 8), squeeze=False)

for col, (j, iteration) in enumerate(snapshots):
    axes[0, col].imshow(np.angle(exp_objects[j].array).T, origin='lower', cmap='gray')
    axes[0, col].set_title('iteration: %d, SSE: %.2e' % (iteration, exp_sse[j]))
    axes[1, col].imshow(np.abs(exp_probes[j].array).T, origin='lower', cmap='gray')
    axes[0, col].axis("off")
    axes[1, col].axis("off")

fig.tight_layout()
plt.show()

In [None]:
# Final regularized result: object phase (left) and probe intensity (right).
final_sse = float(exp_sse[-1])
fig, axd = plt.subplots(1, 2, figsize=(20, 10))
object_ax, probe_ax = axd
exp_objects[-1].angle().show(ax=object_ax, title=f"SSE = {final_sse:.3e}", cmap='inferno')
exp_probes[-1].intensity().show(ax=probe_ax, cmap="gray", power=0.5)
fig.tight_layout()
plt.show()

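# Multislice ptychography
The single-slice reconstruction above treats the whole slab as one projected object. Here the object is instead split into `n_slice` slices along the beam direction and reconstructed slice by slice.

![Schematic of multislice ptychography](image/multislice_ptychography.jpg "Multislice ptychography")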

In [None]:
n_slice = 3
# Divide the slab's third cell length (z, the beam direction) evenly across the slices.
slab_thickness = STO_atoms_thick.cell.lengths()[2]
slice_thicknesses = slab_thickness / n_slice
print(slice_thicknesses)

In [None]:
# Multislice reconstruction: n_slice object slices of thickness slice_thicknesses,
# no object padding (unlike the single-slice run), fixed seed for reproducibility.
mspie_max_iterations = 5

multislice_reconstruction_ptycho_operator = MultislicePtychographicOperator(
    band_limited_measurment,
    semiangle_cutoff=aperture_semiangle,
    energy=energy,
    num_slices=n_slice,
    slice_thicknesses=slice_thicknesses,
    parameters={'object_px_padding': (0, 0)},
).preprocess()

mspie_objects, mspie_probes, mspie_positions, mspie_sse = multislice_reconstruction_ptycho_operator.reconstruct(
    max_iterations=mspie_max_iterations,
    verbose=True,
    random_seed=1,
    return_iterations=True,
)

In [None]:
# Snapshot grid for the multislice reconstruction. The iteration count comes from the
# returned history: this run used its own max_iterations, not n_iter from the
# single-slice run. The spacing and final label therefore match what actually ran. As
# in the per-snapshot titles (j + 1), one history entry per iteration is assumed.
mspie_n_iter = len(mspie_objects)
plot_every = max(1, mspie_n_iter // 5)  # max(1, ...) avoids a zero range step

# (index into the per-iteration histories, iteration number shown in the title)
snapshots = [(j, j + 1) for j in range(0, mspie_n_iter, plot_every)]
snapshots.append((-1, mspie_n_iter))

fig, axes = plt.subplots(2, len(snapshots), figsize=(20, 8), squeeze=False)

for col, (j, iteration) in enumerate(snapshots):
    # Object: phase of every slice, summed over axis 0 (the slice axis).
    axes[0, col].imshow(np.angle(mspie_objects[j].array).sum(axis=0).T, origin='lower', cmap='gray')
    axes[0, col].set_title('iteration: %d, SSE: %.2e' % (iteration, mspie_sse[j]))
    # Probe: squared sum of amplitudes over axis 0, unchanged from the original display.
    # NOTE(review): this is not the intensity of any single probe; the final-result cell
    # shows mspie_probes[-1][0] instead — confirm which view is intended.
    axes[1, col].imshow(np.sum(np.abs(mspie_probes[j].array), axis=0).T ** 2, origin='lower', cmap='gray')
    axes[0, col].axis("off")
    axes[1, col].axis("off")

fig.tight_layout()
plt.show()

In [None]:
# Final multislice result: phase summed over slices (left), intensity of probe [0] (right).
final_mspie_sse = float(mspie_sse[-1])
fig, axd = plt.subplots(1, 2, figsize=(10, 5))
phase_ax, probe_ax = axd
mspie_objects[-1].angle().sum(0).show(ax=phase_ax, title=f"SSE = {final_mspie_sse:.3e}", cmap='inferno')
mspie_probes[-1][0].intensity().show(ax=probe_ax, cmap="gray", power=0.5)
fig.tight_layout()
plt.show()

In [None]:
# Phase of each reconstructed slice. Drawn transposed with origin='lower' so the
# orientation matches every other reconstruction panel in this notebook.
# squeeze=False keeps the axes grid 2-D, so this also works for n_slice == 1.
fig, axes = plt.subplots(1, n_slice, figsize=(5 * n_slice, 5), squeeze=False)
final_slices = mspie_objects[-1].array
for i, slice_ax in enumerate(axes[0]):
    slice_ax.imshow(np.angle(final_slices[i]).T, origin='lower', cmap="inferno")
    slice_ax.set_title(f'slice {i + 1} / {n_slice}')
    slice_ax.axis("off")

fig.tight_layout()
plt.show()