# STEM simulations with the PRISM algorithm

In [1]:
%load_ext autoreload
%autoreload 2
from ase.build import mx2
from ase.io import read, write
import numpy as np

from tensorwaves.waves import PrismWaves
from tensorwaves.potentials import Potential
from tensorwaves.detect import RingDetector

## Quick Simulation

In [2]:
atoms = read('../data/mos2.traj')

# Diagonal of the 3x3 cell matrix: the super cell lengths along x, y, z
# (meaningful as lengths only because this cell is orthogonal).
cell = np.diag(atoms.get_cell())
print(f'Simulation super cell: {cell}')

Simulation super cell: [12.72       11.01584314  5.19      ]


In [3]:
# PRISM wave functions for the quick run (interpolation factor 2).
# NOTE(review): units presumably eV for `energy`, radians for `cutoff`
# and Angstrom for `sampling` -- confirm against tensorwaves.
waves = PrismWaves(energy=80e3,
                   cutoff=.02,
                   interpolation=2,
                   sampling=.1)

# Propagate the scattering matrix through the loaded structure.
S = waves.multislice(atoms)

# Annular detector; inner/outer presumably in radians -- TODO confirm.
detector = RingDetector(inner=.05, outer=.2)

# Scan only the lower-left quarter of the super cell to keep the run short.
scan = S.gridscan(start=(0, 0),
                  end=(cell[0] / 2, cell[1] / 2),
                  sampling=.2,
                  detectors=detector)

Scanning [||||||||||||||||||||||||||||||||||||||||||||||||||] 837/837 


In [4]:
# Presumably assembles the detector signal recorded by the quick scan into an
# image and displays it.
# NOTE(review): in the recorded output this cell fails with a reshape error
# (19251 values vs. 837 scan positions, i.e. exactly 23 values per position).
# Presumably the detector yields more than one value per probe position;
# TODO investigate before relying on this cell.
image = scan.image()

image.show()

InvalidArgumentError: Input to reshape is a tensor with 19251 values, but the requested shape has 837 [Op:Reshape]

## Step-by-step Simulation

### Set up the orthogonal MoS$_2$ super cell

In [None]:
# Build an orthogonal MoS2 super cell from the hexagonal 2H unit cell.
atoms = mx2(formula='MoS2', kind='2H', a=3.18, thickness=3.19)
atoms *= (2, 2, 1)
# Zeroing the x-component of the second lattice vector turns the 2x2
# hexagonal cell into a rectangular cell of equal area; wrap() then maps
# the atoms back inside the new (orthogonal) cell.
atoms.cell[1, 0] = 0
atoms.wrap()
atoms *= (2, 2, 1)
# Add vacuum along z (on both sides of the slab).
atoms.center(vacuum=1, axis=2)

# Diagonal of the cell matrix: x, y, z lengths of the orthogonal super cell.
cell = np.diag(atoms.get_cell())

# NOTE(review): the Quick Simulation section reads '../data/mos2.traj', while
# this writes to the current working directory -- confirm the intended path.
write('MoS2.traj', atoms)

print('Simulation super cell:', cell)

### Create Potential

In [None]:
# Kirkland-parametrized potential of the super cell, split into 10 slices.
potential = Potential(
    atoms=atoms,
    parametrization='kirkland',
    num_slices=10,
    sampling=.1,
)

# Presumably selects which slice show() renders -- here slice index 3.
potential.current_slice = 3
potential.show(fig_scale=2)

### Create Scattering Matrix

In [None]:
# PRISM waves on exactly the potential's grid (no interpolation), then the
# scattering matrix propagated through the sliced potential.
waves = PrismWaves(
    energy=80e3,
    cutoff=.02,
    interpolation=1,
    gpts=potential.gpts,
    extent=potential.extent,
)
S = waves.get_scattering_matrix().multislice(potential)

### Set Aberrations & Examine Probe

In [None]:
# NOTE(review): these setters mutate S in place, so the probe position, C10
# and aperture radius set here presumably also apply to later uses of S.
S.position = (2, 0)

aberration_parameters = S.aberrations.parametrization
aberration_parameters.C10 = 10
S.aperture.radius = .02

probe = S.get_probe()
probe.show(fig_scale=2, display_space='direct')

In [None]:
# Inspect the defocus reported by the aberration parametrization after C10 was set.
# NOTE(review): many conventions define defocus = -C10; confirm the sign
# convention used by tensorwaves.
S.aberrations.parametrization.defocus

### Create Detector

In [None]:
# Annular detector matched to the scattering matrix's grid, extent and energy.
detector = RingDetector(
    inner=.05,
    outer=.2,
    gpts=S.gpts,
    extent=S.extent,
    energy=S.energy,
)

detector_tensor = detector.get_tensor()
detector_tensor.show(display_space='fourier')

### Perform Grid Scan 

In [None]:
# Scan the lower-left quarter of the orthogonal super cell.
# `cell` is the 1-D diagonal of the cell matrix (np.diag), so it takes a
# single index; cell[0, 0] would raise IndexError.
start = (0, 0)
end = (cell[0] / 2, cell[1] / 2)

scan = S.gridscan(start=start, end=end, sampling=.2, detectors=detector)

In [None]:
# Presumably assembles the grid-scan detector signal into an image and displays it.
image = scan.image()
image.show(fig_scale=2)

In [None]:
# Inspect the raw image values (presumably returned as a NumPy array).
# NOTE(review): for larger scans consider displaying only `.shape` or summary
# statistics instead of the full array.
image.numpy()