In [30]:
%matplotlib qt

import numpy as np
from ase import Atoms
from ase.cell import Cell
from ase.data import chemical_symbols
from abtem import show_atoms, orthogonalize_cell
from sklearn.neighbors import NearestNeighbors


def get_a_thickness(formula):
    """Return the lattice constant and layer thickness for an MX2 formula.

    Parameters
    ----------
    formula : str
        Chemical formula of the monolayer, e.g. ``'MoS2'``.

    Returns
    -------
    tuple of float
        ``(a, thickness)``: the in-plane lattice constant and the vertical
        chalcogen-chalcogen distance (``mx2_unit`` places the two chalcogens
        at +/- thickness / 2). Values are presumably in Angstrom (ASE's length
        unit). Formulas not in the table fall back to the MoS2 values
        ``(3.19, 3.128)``.
    """
    a_thickness_dict = {
        'TiS2': (3.35, 3.011),
        'VS2': (3.19, 2.962),
        'NbS2': (3.37, 3.397),
        'NbSe2': (3.5, 3.343),
        'MoS2': (3.19, 3.128),
        'MoSe2': (3.32, 3.31),
        'MoTe2': (3.57, 3.579),
        'TaS2': (3.34, 3.099),
        'TaSe2': (3.47, 3.331),
        'WS2': (3.18, 3.12),
        'WSe2': (3.32, 3.324),
        'WTe2': (3.56, 3.578),
        'ReSe2': (3.47, 3.1)
    }
    # A plain lookup with a default replaces the bare ``except:``, which also
    # swallowed KeyboardInterrupt/SystemExit and any unrelated error.
    a, t = a_thickness_dict.get(formula, (3.19, 3.128))
    return a, t


def mx2_unit(formula='MoS2', size=(1, 1, 1), vacuum=2):
    """Build a hexagonal MX2 monolayer cell (same recipe as ``ase.build.mx2``).

    Parameters
    ----------
    formula : str
        Formula handed to ``ase.Atoms``; its atoms receive the three scaled
        positions below in order (one metal, then two chalcogens for MX2).
        Also used to look up ``a`` and the thickness via ``get_a_thickness``.
    size : tuple of int
        Repetitions along the cell vectors, applied after centering.
    vacuum : float or None
        Vacuum passed to ``Atoms.center`` along z; ``None`` skips centering.

    Returns
    -------
    ase.Atoms
    """
    a, thickness = get_a_thickness(formula)

    # Metal at the origin; both chalcogens share the (2/3, 1/3) in-plane site,
    # offset by +/- thickness / 2 along z.
    # NOTE(review): the third cell vector is zero, so these z entries rely on
    # ASE completing that vector with a unit vector - confirm for your ASE.
    scaled = [(0, 0, 0),
              (2 / 3, 1 / 3, 0.5 * thickness),
              (2 / 3, 1 / 3, -0.5 * thickness)]

    # Hexagonal in-plane lattice with a 120-degree angle between a1 and a2.
    lattice = [[a, 0, 0],
               [-a / 2, a * 3 ** 0.5 / 2, 0],
               [0, 0, 0]]

    unit = Atoms(formula, cell=lattice, pbc=(1, 1, 0))
    unit.set_scaled_positions(scaled)
    if vacuum is not None:
        unit.center(vacuum, axis=2)
    return unit.repeat(size)


def make_it_orhto(atoms):
    """Make the cell rectangular by keeping only the y-part of the 2nd vector.

    The first and third cell vectors are left as they are. Atoms are then
    wrapped into the new cell and centered. ``atoms`` is modified in place and
    the same object is returned.
    """
    b_vec = atoms.cell[1]
    rect_cell = atoms.cell.copy()
    rect_cell[1] = [0., b_vec[1], 0.]

    atoms.set_cell(rect_cell)
    # Bring atoms that now lie outside the rectangle back in (along the
    # periodic directions), then center the structure in the cell.
    atoms.wrap()
    atoms.center()
    return atoms


def get_centered_m(atoms, return_ind=False):
    """Find the metal atom closest (in xy) to the middle of the cell.

    Any atom that is not S, Se or Te counts as a metal site, so substitutional
    dopants are candidates too.

    Parameters
    ----------
    atoms : ase.Atoms
    return_ind : bool
        If True, also return the index of the selected atom in ``atoms``.

    Returns
    -------
    numpy.ndarray or (numpy.ndarray, int)
        A copy of the selected atom's position with its z component set to 0.
        With ``return_ind``, also its index in ``atoms``. ``atoms`` itself is
        not modified.

    Raises
    ------
    ValueError
        If ``atoms`` contains no metal atoms.
    """
    chalcogens = ['S', 'Se', 'Te']
    symbols = np.array(atoms.get_chemical_symbols())
    # Global indices of metal atoms, so the nearest-neighbour result can be
    # mapped back to an index into ``atoms``.
    metal_inds = np.flatnonzero(~np.isin(symbols, chalcogens))
    if metal_inds.size == 0:
        raise ValueError('no metal atoms found')

    # Fancy indexing returns a copy, so later edits never touch ``atoms``.
    metal_pos = atoms.positions[metal_inds]
    cell_mid = atoms.cell.array.sum(axis=0)[None, 0:2] / 2

    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(metal_pos[:, 0:2])
    _, nn = nbrs.kneighbors(cell_mid)
    local = nn[0][0]

    # Explicit copy: the original code zeroed z on a view of atoms.positions,
    # which silently moved the selected atom to z = 0.
    p_xyz = metal_pos[local].copy()
    p_xyz[2] = 0
    if return_ind:
        return p_xyz, int(metal_inds[local])
    return p_xyz


def get_centered_x(atoms, return_ind=False):
    """Find the chalcogen atom closest (in xy) to the middle of the cell.

    All S, Se and Te atoms are candidates. The original code picked one
    chalcogen species from an unordered set, which is nondeterministic when
    several chalcogens are present.

    Parameters
    ----------
    atoms : ase.Atoms
    return_ind : bool
        If True, also return the index of the selected atom in ``atoms``
        (parity with ``get_centered_m``).

    Returns
    -------
    numpy.ndarray or (numpy.ndarray, int)
        A copy of the selected atom's position with z set to 0. With
        ``return_ind``, also its index in ``atoms``.

    Raises
    ------
    ValueError
        If ``atoms`` contains no chalcogen atoms.
    """
    chalcogens = ['S', 'Se', 'Te']
    symbols = np.array(atoms.get_chemical_symbols())
    x_inds = np.flatnonzero(np.isin(symbols, chalcogens))
    if x_inds.size == 0:
        raise ValueError('no chalcogen atoms found')

    x_pos = atoms.positions[x_inds]  # fancy indexing -> copy
    cell_mid = atoms.cell.array.sum(axis=0)[None, 0:2] / 2

    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(x_pos[:, 0:2])
    _, nn = nbrs.kneighbors(cell_mid)
    local = nn[0][0]

    p_xyz = x_pos[local].copy()
    p_xyz[2] = 0
    if return_ind:
        return p_xyz, int(x_inds[local])
    return p_xyz


def crop_atoms(atoms, L=20.):
    """Keep only atoms whose x and y both lie strictly inside (-L, L).

    Returns a new ``ase.Atoms`` built from the kept atomic numbers and
    positions plus the original cell. z is not filtered. Other per-atom arrays
    and the pbc flags are not carried over.
    """
    xy = atoms.positions[:, :2]
    # -L < v < L  <=>  |v| < L, required for both x and y.
    inside = np.all(np.abs(xy) < L, axis=1)
    return Atoms(numbers=atoms.numbers[inside],
                 positions=atoms.positions[inside],
                 cell=atoms.cell)


def get_mx2_atoms(L=10, formula='MoS2', vacuum=2, theta=0, center='m'):
    """Build a square MX2 patch of side ``L`` centred on a metal or chalcogen.

    Parameters
    ----------
    L : float
        Side length of the square xy window (atoms are kept within +/- L/2 of
        the chosen site). It is also the side length of the output cell.
    formula : str
        MX2 formula, see ``mx2_unit``.
    vacuum : float or None
        Vacuum along z, see ``mx2_unit``.
    theta : float
        Rotation about z applied before cropping (degrees, per ``Atoms.rotate``).
    center : {'m', 'x'}
        Centre the patch on the metal (``'m'``) or chalcogen (``'x'``) site
        nearest the supercell middle.

    Returns
    -------
    ase.Atoms
        Orthorhombic ``L x L x c`` cell, where ``c`` is the unit cell's
        z-length.

    Raises
    ------
    ValueError
        If ``center`` is neither ``'m'`` nor ``'x'``.
    """
    # Validate before doing any (comparatively expensive) structure building.
    if center == 'm':
        find_center = get_centered_m
    elif center == 'x':
        find_center = get_centered_x
    else:
        raise ValueError('center can only be m or x')

    half = L / 2.
    unit = mx2_unit(formula=formula, vacuum=vacuum)

    # Integer number of repeats; ceil(L/2) unit cells per direction.
    # NOTE(review): this covers the L x L window because a (~3.2-3.6) > 2.
    n_rep = np.ceil(half).astype(int)
    atoms = make_it_orhto(unit.repeat([n_rep, n_rep, 1]))

    center_xyz = find_center(atoms)
    atoms.translate(-center_xyz)
    atoms.rotate(theta, 'z')
    atoms = crop_atoms(atoms, half)
    atoms.translate([half, half, center_xyz[2]])

    # unit.cell[2] is a 3-vector; fromcellpar needs the scalar c length.
    c_len = float(np.linalg.norm(unit.cell[2]))
    cell_new = Cell.fromcellpar([2 * half, 2 * half, c_len, 90, 90, 90])

    return Atoms(atoms.symbols, atoms.positions, cell=cell_new)


def remove_atoms(atoms, inds):
    """Return a copy of ``atoms`` with the entries at ``inds`` deleted.

    ``inds`` is anything the container's ``__delitem__`` accepts (an index,
    a slice, ...). The input object is left unchanged.
    """
    trimmed = atoms.copy()
    del trimmed[inds]
    return trimmed


def replace_atoms(atoms, inds, element):
    """Return a copy of ``atoms`` with the symbol at each index set to ``element``.

    ``inds`` may be a single index or an iterable of indices. The input object
    is left unchanged.
    """
    doped = atoms.copy()
    # Treat a scalar index as a one-element sequence.
    targets = inds if np.iterable(inds) else (inds,)
    for i in targets:
        doped.symbols[i] = element
    return doped


def get_mx_elements(symbols):
    """Return ``(metal, chalcogen)`` element symbols present in ``symbols``.

    The most abundant non-chalcogen is taken as the metal and the most
    abundant of S/Se/Te as the chalcogen. Dopants therefore no longer displace
    the host elements: the old alphabetical rule returned ``('W', 'Ti')`` for
    Ti-doped WSe2. If either group is empty, the original heuristic is used:
    the two reverse-alphabetical elements, with a chalcogen (if any) put second.

    Parameters
    ----------
    symbols : iterable of str
        Per-atom chemical symbols (e.g. ``atoms.symbols``).

    Returns
    -------
    tuple
        ``(m_element, x_element)``.
    """
    chalcogens = ['S', 'Se', 'Te']
    elements, counts = np.unique(list(symbols), return_counts=True)
    # Stable sort keeps alphabetical order among equally abundant elements.
    by_abundance = elements[np.argsort(-counts, kind='stable')]
    metals = [e for e in by_abundance if e not in chalcogens]
    xs = [e for e in by_abundance if e in chalcogens]
    if metals and xs:
        return metals[0], xs[0]

    # Fallback: original behaviour for inputs lacking a metal or chalcogen.
    e1, e2 = elements[::-1][0:2]
    if e1 in chalcogens:
        return e2, e1
    return e1, e2


# use composition
# center index, and nearby three indices
class MX2:
    """Square MX2 monolayer patch built with ``get_mx2_atoms``.

    Attributes
    ----------
    formula : str
        Formula passed to the constructor.
    center : str
        ``'m'`` or ``'x'``, the site the patch is centred on.
    a, thickness : float
        Lattice constant and layer thickness from ``get_a_thickness``.
    atoms : ase.Atoms
        The generated structure.
    m_element, x_element : str
        Metal and chalcogen symbols found in ``atoms``.
    """

    def __init__(self, L=10, formula='MoS2', vacuum=2, theta=0, center='m'):
        self.formula = formula
        self.center = center
        # The thickness used to be unpacked into an unused local; keep it.
        self.a, self.thickness = get_a_thickness(self.formula)
        self.atoms = get_mx2_atoms(L=L, formula=formula, vacuum=vacuum, theta=theta, center=center)
        self.m_element, self.x_element = get_mx_elements(self.atoms.symbols)

    def show(self):
        """Display ``self.atoms`` with abTEM's ``show_atoms``."""
        show_atoms(self.atoms)

## Build atomic models (ASE `Atoms`) for multislice-simulation TIF files

In [33]:
formula = 'WSe2'
dopant_list = ['Ti', 'V', 'Cr', 'Mn', 'Co']
L = 20  # side length of the square patch passed to MX2

model = MX2(L=L, formula=formula)
atoms = model.atoms
# Index of the atom nearest the patch centre (intended dopant site).
_, ind = get_centered_m(atoms, return_ind=True)

# One substitutionally doped copy per dopant; replace_atoms works on copies.
doped_models = {dopant: replace_atoms(atoms, [ind], element=dopant)
                for dopant in dopant_list}
# Kept for later cells that refer to the first doped model by this name.
atoms1 = doped_models[dopant_list[0]]