In [None]:
from abtem import (GridScan, PixelatedDetector, Potential,
                   Probe, show_atoms, SMatrix, AnnularDetector, FrozenPhonons)
from abtem.measure import Measurement, Calibration, bandlimit, center_of_mass
from abtem.utils import energy2wavelength
from abtem.structures import orthogonalize_cell
from abtem.transfer import CTF, scherzer_defocus
from ase.io import read
from ase.build import mx2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.patches import Rectangle
from pylab import cm
from matplotlib.widgets import RectangleSelector
import matplotlib.patches as pch
import tifffile
import tkinter.filedialog as tkf
import json

# abTEM compute device for the potential, probe and scan below.
# Switch to "cpu" when no GPU backend is available.
device = "gpu"

In [None]:
# Untwisted starting structure: a 2H-WSe2 cell, converted to an orthogonal cell,
# centred with vacuum padding along z, then tiled 30 x 18 x 2 (the two copies
# along z form the bilayer that is twisted further down).
unit_cell = orthogonalize_cell(
    mx2(formula='WSe2', kind='2H', a=3.286, thickness=3.362, size=(1, 1, 1), vacuum=None)
)
unit_cell.center(vacuum=2, axis=2)
atoms_init = unit_cell.repeat((30, 18, 2))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
show_atoms(atoms_init, ax=ax1, title='Top view', numbering=False)
show_atoms(atoms_init, ax=ax2, plane='yz', title='Side view', numbering=False)
plt.show()

In [None]:
# Bookkeeping on the untwisted supercell: total atom count and the distinct
# atomic numbers present.
num_atoms = len(atoms_init)
elements = np.unique(atoms_init.numbers)
print(num_atoms)
print(elements)

# Working copy that the twist/defect cells modify; atoms_init stays pristine.
atoms = atoms_init.copy()

In [None]:
# Supercell vectors, then the distinct atomic z heights (Å) of the untwisted
# structure; the heights are what the twist cell matches to pick the upper layer.
for quantity in (atoms.cell, np.unique(atoms_init.positions[:, 2])):
    print(quantity)

In [None]:
# Twisted bilayer: rotate the upper layer in-plane about the supercell centre by
# `twist_angle`; the lower layer is left untouched.
# Start from the untwisted supercell so that re-running this cell does not stack
# a second rotation on top of an already-twisted structure.
atoms = atoms_init.copy()

twist_angle = 1.44 # degree
center = np.array([atoms.cell[0][0]/2, atoms.cell[1][1]/2])
print(center)
alpha, beta = np.cos(np.deg2rad(twist_angle)), np.sin(np.deg2rad(twist_angle))
# 2x2 rotation acting on row vectors: [x, y] @ M = [alpha*x - beta*y, beta*x + alpha*y],
# i.e. counter-clockwise by twist_angle. Coordinates are shifted to `center`
# before rotating, so no translation column is needed.
M = np.array([[alpha, beta],
              [-beta, alpha]])

# Select the upper layer via the z heights (Å) of its three atomic planes.
# Compare with a tolerance: exact float equality against these printed literals
# can silently match nothing, which would leave the bilayer untwisted.
z_ = [9.362, 11.043, 12.724]
z_tol = 1e-3  # Å
layer_ind = []
for z in z_:
    layer_ind.extend(np.flatnonzero(np.isclose(atoms_init.positions[:, 2], z, rtol=0, atol=z_tol)))
print(len(layer_ind))
assert layer_ind, "no atoms matched the upper-layer z heights in z_"

# Rotate all selected atoms at once; atoms.positions is written in place, as the
# original per-atom assignment also relied on.
layer_ind = np.asarray(layer_ind, dtype=int)
atoms.positions[layer_ind, :2] = (atoms.positions[layer_ind, :2] - center) @ M + center

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
show_atoms(atoms, ax=ax, title='Top view', numbering=False)
plt.show()

In [None]:
# Vacancy generation: delete a random fraction `def_ratio` of the atoms whose
# atomic number is listed in `def_element`.
#def_element = elements
def_element = [74]  # W
def_ratio = 0.05
vacancy_seed = 0  # fixed seed -> reproducible vacancy configuration

ind_element = np.flatnonzero(np.isin(atoms.numbers, def_element)).tolist()

rng = np.random.default_rng(vacancy_seed)
ri = rng.choice(ind_element, int(len(ind_element)*def_ratio), replace=False)
print(ri)

# Delete from the highest index down. Every deletion shifts the indices of all
# later atoms, so deleting in the random order of `ri` would remove the wrong
# atoms (not necessarily of `def_element`) and could run past the end.
for ai in sorted(ri, reverse=True):
    del atoms[int(ai)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
show_atoms(atoms, ax=ax1, title='Top view', numbering=False)
show_atoms(atoms, ax=ax2, plane='xz', title='Side view')
plt.show()

In [None]:
# Substitutional defects: replace a random fraction `def_ratio` of the atoms in
# `def_element` by species drawn from `sub_element` with probabilities
# `sub_element_prob`.
def_element = [74]       # W
sub_element = [42]       # Mo
sub_element_prob = [1.0]
def_ratio = 0.1
substitution_seed = 1    # fixed seed -> reproducible substitution pattern

ind_element = np.flatnonzero(np.isin(atoms.numbers, def_element)).tolist()

rng = np.random.default_rng(substitution_seed)
ri = rng.choice(ind_element, int(len(ind_element)*def_ratio), replace=False)
print(ri)

# Draw all replacement species in one call and assign them element-wise, instead
# of writing a size-1 array into a single slot once per atom. Assigning through
# atoms.numbers updates `atoms` in place.
if len(ri) > 0:
    atoms.numbers[ri] = rng.choice(sub_element, size=len(ri), p=sub_element_prob)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
show_atoms(atoms, ax=ax1, title='Top view', numbering=False)
show_atoms(atoms, ax=ax2, plane='xz', title='Side view')
plt.show()

In [None]:
# Interstitial defect generation
# TODO: not implemented yet.

In [None]:
# Thermal diffuse scattering via frozen phonons: 64 randomly displaced
# configurations (sigmas=0.05, presumably Å; fixed seed), each turned into a
# Kirkland-parametrized potential with 1 Å slices on a 0.05 Å grid.
frozen_phonons = FrozenPhonons(atoms, 64, sigmas=0.05, seed=56)
potential_kwargs = dict(sampling=.05,
                        projection='infinite',
                        slice_thickness=1,
                        parametrization='kirkland',
                        device=device)
tds_potential = Potential(frozen_phonons, **potential_kwargs).build()

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
tds_potential.project().show(ax=ax, cmap="inferno")
plt.show()

In [None]:
# Microscope parameters.
energy = 80e3     # beam energy, eV
C3 = 1e3          # spherical aberration, 100 nm
semiangle = 20    # probe convergence semi-angle, mrad

sch_defocus = scherzer_defocus(C3, energy)
print(sch_defocus)

In [None]:
# Alternative aberrated-probe settings (not used: the probe below is built
# without these parameters):
#ctf = CTF(parameters={'C10': sch_defocus*0.9,'C12': 20, 'phi12': 0.785,'C30': C3}, semiangle_cutoff=semiangle)
#ctf = CTF(parameters={'C10': sch_defocus*0.9, 'C30': C3}, semiangle_cutoff=semiangle)

probe = Probe(semiangle_cutoff=semiangle, energy=energy, device=device)
probe.grid.match(tds_potential)  # adopt the potential's real-space grid

print(probe.ctf.nyquist_sampling)

# Angular step (mrad) of the probe's reciprocal grid along x and y:
# wavelength * 1000 / real-space extent.
angle_step_default = (probe.wavelength * 1000 / probe.extent[0],
                      probe.wavelength * 1000 / probe.extent[1])
for axis_label, step in zip(("x", "y"), angle_step_default):
    print("angle step " + axis_label, step)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, show_kwargs in zip(ax, ({}, {"power": 0.5})):
    probe.show(ax=panel, **show_kwargs)
plt.show()

In [None]:
# Probe positions: raster from the origin to the full potential extent.
probe_step = 0.6  # Å
scan_end = np.array(tds_potential.extent)
gridscan = GridScan((0,0), scan_end, sampling=probe_step)
print(gridscan.gpts)

In [None]:
# Pixelated (4D-STEM) detector collecting out to 4x the convergence semi-angle,
# with diffraction patterns resampled to the given angular pixel size.
angle_pixel_size = (0.4, 0.4)  # (mrad, mrad) tuple
detect_max_angle = 4 * semiangle
detector = PixelatedDetector(max_angle=detect_max_angle, resample=angle_pixel_size)

In [None]:
# Run the scan over the frozen-phonon potential with the pixelated detector.
detectors = [detector]
measurement = probe.scan(gridscan, detectors, tds_potential)
print(measurement.shape)

In [None]:
# Summarise the measurement: shape, calibration limits, and one line per axis
# with its name, units and sampling.
print(measurement.shape)
print(*measurement.calibration_limits, sep="\n")
for axis_index in range(measurement.dimensions):
    calibration = measurement.calibrations[axis_index]
    print(calibration.name, calibration.units, calibration.sampling)

In [None]:
# Real-space pixel size implied by each diffraction axis:
# wavelength * 1000 / (angular sampling [mrad] * number of pixels).
measurement_sampling = tuple(
    energy2wavelength(energy) * 1000 / (calibration.sampling * n_pixels)
    for calibration, n_pixels in zip(measurement.calibrations[-2:], measurement.shape[-2:])
)

print(f'pixelated_measurement sampling: {measurement_sampling} Å')

In [None]:
# Corresponding real-space field of view: pixel size times number of pixels.
measurement_extent = tuple(
    step * count for step, count in zip(measurement_sampling, measurement.shape[-2:])
)

print(f'pixelated_measurement extent: {measurement_extent} Å')

In [None]:
# Virtual detectors integrated from the 4D data set:
#   - "bright": annulus from 10 mrad to the convergence semi-angle (inside the
#     bright-field disk),
#   - MAADF: annulus from the convergence semi-angle to the detector limit,
#   - PACBED: diffraction patterns averaged over the two scan axes.
bright_detector = AnnularDetector(inner=10, outer=semiangle)
maadf_detector = AnnularDetector(inner=semiangle, outer=detect_max_angle)

bright_measurement = bright_detector.integrate(measurement)
maadf_measurement = maadf_detector.integrate(measurement)
pacbed = measurement.mean(axis=(0, 1))

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
panels = [
    (measurement, ax[0, 0], {"cmap": "inferno", "power": 0.5}),
    (pacbed, ax[0, 1], {"cmap": "inferno", "power": 0.5}),
    (bright_measurement, ax[1, 0], {}),
    (maadf_measurement, ax[1, 1], {}),
]
for image, panel, show_kwargs in panels:
    image.show(ax=panel, **show_kwargs)
fig.tight_layout()
plt.show()

In [None]:
# Save the raw 4D data set. The twist angle in the file name is derived from
# `twist_angle` (e.g. 1.44 -> "1_44") so the name cannot drift out of sync
# with the simulated structure.
angle_tag = f"{twist_angle:g}".replace(".", "_")
save_name = f"WSe2_twisted_bilayer_{angle_tag}deg_simulated_4DSTEM_01"
tifffile.imwrite(save_name+".tif", measurement.array)

In [None]:
# Metadata stored alongside the data set. Scan axes are calibrations 0-1 and
# diffraction axes are calibrations 2-3.
calibration_info = {
    "material": "WSe2",
    "beam energy": energy,
    "convergence angle": semiangle,
    "scan pixel size": tuple(cal.sampling for cal in measurement.calibrations[0:2]),
    "angle pixel size": tuple(cal.sampling for cal in measurement.calibrations[2:4]),
}

In [None]:
# Write the metadata as JSON next to the .tif. ensure_ascii=False allows
# non-ASCII characters in the output, so pin the file encoding to UTF-8 instead
# of relying on the platform's default locale encoding.
with open(save_name+".txt", 'w', encoding='utf-8') as file:
    json.dump(calibration_info, file, ensure_ascii=False)