In [434]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

import sys
import platform
# Make the local stempy checkout importable; the checkout location depends on the OS.
# NOTE(review): these are absolute, user-specific paths — consider reading them from an
# environment variable so the notebook runs on other machines.
if platform.system() == 'Darwin':
    _stempy_root = '/Users/jiadongdan/Dropbox/stempy'
else:
    _stempy_root = 'D:\\Dropbox\\stempy'
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if _stempy_root not in sys.path:
    sys.path.append(_stempy_root)
    
from stempy.io import *
from stempy.denoise import *
from stempy.datasets import *
from stempy.utils import *
from stempy.plot import *
from stempy.feature import *
from stempy.manifold import *
from stempy.clustering import *
from stempy.spatial import * 
from stempy.graph import *

In [440]:
import numpy as np
from ase import Atoms
from ase.cell import Cell
from ase.data import chemical_symbols
from ase.formula import Formula
from sklearn.neighbors import NearestNeighbors
from skimage.filters import gaussian
from skimage.transform import rotate


def mx2(formula='MoS2', kind='1H', a=3.18, c=6, size=(3, 3, 1), vacuum=None, center='m'):
    """Build a repeated MX2 monolayer from fractional coordinates.

    Parameters
    ----------
    formula : str
        Formula of one formula unit, e.g. 'MoS2'. For '1H' and '1T' it is passed to
        ``Atoms`` as-is, so its atom order (metal first, then the two chalcogens) must
        match the coordinate list. For "1T'" only the metal and chalcogen element
        symbols are taken from it.
    kind : {'1H', '1T', "1T'"}
        Structure type.
    a, c : float
        In-plane lattice constant and cell height.
    size : tuple of int
        Repetitions passed to ``Atoms.repeat``.
    vacuum : float or None
        If given, the slab is centered along z with this vacuum before repeating.
    center : str
        Unused; kept for backward compatibility.

    Returns
    -------
    ase.Atoms
        The repeated structure, periodic in x and y only.

    Raises
    ------
    ValueError
        If ``kind`` is not recognized.
    """
    if kind == '1H':
        coords = [(2 / 3., 1 / 3., 1 / 2.), (1 / 3., 2 / 3., 1 / 4.), (1 / 3., 2 / 3., 3 / 4.)]
        cell = Cell.fromcellpar([a, a, c, 90, 90, 120])
    elif kind == '1T':
        coords = [(1 / 2., 1 / 2., 1 / 2.), (1 / 6., 5 / 6., 1 / 4.), (5 / 6., 1 / 6., 3 / 4.)]
        cell = Cell.fromcellpar([a, a, c, 90, 90, 120])
    elif kind == "1T'":
        # Take the element symbols from the requested formula instead of hard-coding Mo/S.
        species = Formula(formula).count()
        element_x = [e for e in species if e in ['S', 'Se', 'Te']][0]
        element_m = [e for e in species if e not in ['S', 'Se', 'Te']][0]
        # Rectangular cell with two formula units: 4 chalcogens followed by 2 metals.
        # Sites at fractional x=0 and x=1 are the same site under periodic boundary
        # conditions, so each is listed only once (listing both created overlapping atoms).
        coords = [[0., 0.2042, 0.22344],
                  [0.5, 0.69727, 0.13079],
                  [0.5, 0.35203, 0.66327],
                  [0., 0.85774, 0.75609],
                  [0.5, 0.956, 0.45963],
                  [0., 0.60038, 0.42711]]
        # fromcellpar expects a single 6-element sequence (a, b, c, alpha, beta, gamma).
        cell = Cell.fromcellpar([a, a * 1.7911490156481573, c, 90, 90, 90])
        formula = f'{element_x}4{element_m}2'
    else:
        raise ValueError(f'Structure not recognized: {kind}')

    atoms = Atoms(formula, cell=cell, pbc=(1, 1, 0))
    atoms.set_scaled_positions(coords)
    if vacuum is not None:
        atoms.center(vacuum, axis=2)
    atoms = atoms.repeat(size)
    return atoms


def mx2_1H(formula='MoS2', a=3.18, c=6, t=3.1):
    """Build the 3-atom unit cell of a 1H MX2 monolayer.

    The metal sits at the cell origin at half height. Both chalcogens sit at
    fractional in-plane position (2/3, 1/3), one below and one above the metal, so
    they are stacked on top of each other.

    Parameters
    ----------
    formula : str
        Formula of one unit, passed to ``Atoms``; the metal must come first (e.g. 'MoS2')
        to match the order of the positions below.
    a : float
        In-plane lattice constant.
    c : float
        Cell height.
    t : float
        Vertical X-X separation, in the same length unit as ``a`` and ``c``. Defaults to
        3.1, the value previously hard-coded here.

    Returns
    -------
    ase.Atoms
        Unit cell periodic in x and y only.
    """
    half_thickness = t / c / 2  # fractional z offset of each X from the metal plane
    positions = [(0, 0, 0.5),
                 (2 / 3, 1 / 3, 0.5 - half_thickness),
                 (2 / 3, 1 / 3, 0.5 + half_thickness)]
    cell = Cell.fromcellpar([a, a, c, 90, 90, 120])
    atoms = Atoms(formula, cell=cell, pbc=(1, 1, 0))
    atoms.set_scaled_positions(positions)
    return atoms


def mx2_1T(formula='MoS2', a=3.18, c=6, t=3.1):
    """Build the 3-atom unit cell of a 1T MX2 monolayer.

    The metal sits at the cell origin at half height. The lower chalcogen sits at
    fractional in-plane position (2/3, 1/3) and the upper one at (1/3, 2/3), so the
    two X layers occupy different in-plane sites.

    Parameters
    ----------
    formula : str
        Formula of one unit, passed to ``Atoms``; the metal must come first (e.g. 'MoS2')
        to match the order of the positions below.
    a : float
        In-plane lattice constant.
    c : float
        Cell height.
    t : float
        Vertical X-X separation, in the same length unit as ``a`` and ``c``. Defaults to
        3.1, the value previously hard-coded here.

    Returns
    -------
    ase.Atoms
        Unit cell periodic in x and y only.
    """
    half_thickness = t / c / 2  # fractional z offset of each X from the metal plane
    positions = [(0, 0, 0.5),
                 (2 / 3, 1 / 3, 0.5 - half_thickness),
                 (1 / 3, 2 / 3, 0.5 + half_thickness)]
    cell = Cell.fromcellpar([a, a, c, 90, 90, 120])
    atoms = Atoms(formula, cell=cell, pbc=(1, 1, 0))
    atoms.set_scaled_positions(positions)
    return atoms


def mx2_1T_prime(formula='MoS2', a=3.16, c=6):
    """Build the rectangular 6-atom unit cell of a 1T' MX2 monolayer.

    Only the element symbols are taken from ``formula``: the first species that is a
    chalcogen (S, Se, Te) becomes X and the first one that is not becomes M. The cell
    holds four X atoms followed by two M atoms, with b = a * 1.79114... and all
    angles 90 degrees.

    Returns
    -------
    ase.Atoms
        Unit cell periodic in x and y only.
    """
    chalcogens = ['S', 'Se', 'Te']
    species = list(Formula(formula).count().keys())
    element_x = [s for s in species if s in chalcogens][0]
    element_m = [s for s in species if s not in chalcogens][0]
    # Expanded symbol string: X X X X M M (same order as the fractional coordinates).
    unit_formula = ''.join([element_x] * 4 + [element_m] * 2)

    scaled = [
        [0., 0.2042, 0.22344],
        [0.5, 0.69727, 0.13079],
        [0.5, 0.35203, 0.66327],
        [0., 0.85774, 0.75609],
        [0.5, 0.956, 0.45963],
        [0., 0.60038, 0.42711],
    ]
    b_length = a * 1.7911490156481573
    unit = Atoms(unit_formula,
                 cell=Cell.fromcellpar([a, b_length, c, 90, 90, 90]),
                 pbc=(1, 1, 0))
    unit.set_scaled_positions(scaled)
    return unit


def make_it_orhto(atoms):
    """Make the in-plane cell orthogonal.

    The second cell vector is replaced by (0, b_y, 0), keeping only its y component.
    The atoms are then wrapped into the new cell and centered. ``atoms`` is modified
    in place and also returned.
    """
    b_y = atoms.cell[1][1]
    new_cell = atoms.cell.copy()
    new_cell[1] = [0., b_y, 0.]

    atoms.set_cell(new_cell)
    atoms.wrap()     # bring atoms that now fall outside the cell back inside
    atoms.center()
    return atoms


def get_centered_m(atoms):
    """Return the position of the metal atom closest, in xy, to the cell center.

    The center is half the sum of the three cell vectors, projected to xy. The metal
    is the first species of ``atoms`` that is not S, Se or Te. The returned 3-vector
    has its z component set to 0.
    """
    number_of = {sym: num for num, sym in enumerate(chemical_symbols)}
    metal = [s for s in atoms.symbols.species() if s not in ['S', 'Se', 'Te']][0]

    candidates = atoms.positions[atoms.numbers == number_of[metal]]
    target_xy = atoms.cell.array.sum(axis=0)[None, 0:2] / 2

    finder = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
    finder.fit(candidates[:, 0:2])
    _, nearest = finder.kneighbors(target_xy)

    closest = candidates[nearest[0][0]]
    closest[2] = 0
    return closest


def get_centered_x(atoms):
    """Return the position of the chalcogen atom closest, in xy, to the cell center.

    The center is half the sum of the three cell vectors, projected to xy. The
    chalcogen is the first species of ``atoms`` that is S, Se or Te. The returned
    3-vector has its z component set to 0.
    """
    number_of = {sym: num for num, sym in enumerate(chemical_symbols)}
    chalcogen = [s for s in atoms.symbols.species() if s in ['S', 'Se', 'Te']][0]

    candidates = atoms.positions[atoms.numbers == number_of[chalcogen]]
    target_xy = atoms.cell.array.sum(axis=0)[None, 0:2] / 2

    finder = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
    finder.fit(candidates[:, 0:2])
    _, nearest = finder.kneighbors(target_xy)

    closest = candidates[nearest[0][0]]
    closest[2] = 0
    return closest


def crop_atoms(atoms, L=20):
    """Keep only the atoms whose x and y both lie strictly inside (-L, L).

    Returns a new ``Atoms`` built from the surviving atomic numbers and positions,
    reusing the input cell. Other per-atom data (and pbc) is not carried over.
    """
    positions = atoms.positions
    inside_x = (positions[:, 0] > -L) & (positions[:, 0] < L)
    inside_y = (positions[:, 1] > -L) & (positions[:, 1] < L)
    keep = inside_x & inside_y
    return Atoms(numbers=atoms.numbers[keep], positions=positions[keep], cell=atoms.cell)


def split_pts(pts, p):
    """Randomly split the rows of ``pts`` into (kept, defects).

    Parameters
    ----------
    pts : np.ndarray
        Points, one per row.
    p : float
        Fraction of rows to mark as defects. Whenever ``pts`` is non-empty, at least one
        row and at most ``len(pts)`` rows are selected.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The remaining rows and the selected defect rows, each in original order.
    """
    n = len(pts)
    if n == 0:
        # Nothing to sample from; np.random.choice would raise on an empty population.
        return pts, pts[:0]
    num_defects = min(n, max(1, int(n * p)))
    # Sample WITHOUT replacement so exactly num_defects distinct rows are picked
    # (with replacement, duplicate draws silently reduced the defect count).
    ind_select = np.random.choice(n, num_defects, replace=False)
    mask = np.zeros(n, dtype=bool)
    mask[ind_select] = True
    return pts[~mask], pts[mask]


class MX2(Atoms):
    """A square, cropped and optionally rotated MX2 monolayer flake.

    Construction proceeds in these steps:
    1. The unit cell (1H, 1T or 1T') is repeated L x L times.
    2. The in-plane cell is made orthogonal.
    3. The flake is shifted so that an M (or X) atom near the cell center is at the
       xy origin, then rotated by ``theta`` about z.
    4. It is cropped to the open square (-L, L)^2 and shifted into a new orthogonal
       2L x 2L x c cell.

    Attributes
    ----------
    m, x : str
        Metal and chalcogen element symbols parsed from ``formula``.
    dopant : str or None
        Dopant symbol set by ``make_dopants`` (None until then).
    a : float
        In-plane lattice constant used to build the unit cell.
    """

    def __init__(self, L=20, formula='MoS2', kind='1H', a=3.18, c=6, theta=0, center='m', **kwargs):
        # NOTE(review): **kwargs is accepted but ignored.
        # Validate up front so a bad value fails before the (possibly large) build;
        # previously it fell through to an UnboundLocalError on center_xyz.
        if center not in ('m', 'x'):
            raise ValueError("center must be 'm' or 'x', got {!r}".format(center))

        # Element symbols for M and X come from the requested formula
        # (this previously always parsed 'MoS2', ignoring ``formula``).
        e1, e2 = list(Formula(formula).count().keys())
        if e1 in ['S', 'Se', 'Te']:
            self.m, self.x = e2, e1
        else:
            self.m, self.x = e1, e2
        self.dopant = None

        self.a = a

        if kind == '1H':
            unit = mx2_1H(formula=formula, a=a, c=c)
        elif kind == '1T':
            unit = mx2_1T(formula=formula, a=a, c=c)
        elif kind == "1T'":
            unit = mx2_1T_prime(formula=formula, a=a, c=c)
        else:
            raise ValueError('Structure not recognized:', kind)

        # repeat, then make the in-plane cell orthogonal
        atoms = unit.repeat([L, L, 1])
        atoms = make_it_orhto(atoms)
        if center == 'm':
            center_xyz = get_centered_m(atoms)
        else:
            center_xyz = get_centered_x(atoms)

        # Put the chosen central atom at the xy origin, rotate about it, crop to
        # (-L, L)^2 and shift into the [0, 2L) square of the new cell.
        atoms.translate(-center_xyz)
        atoms.rotate(theta, 'z')
        atoms = crop_atoms(atoms, L)
        atoms.translate([L, L, center_xyz[2]])

        cell_new = Cell.fromcellpar([2 * L, 2 * L, c, 90, 90, 90])

        super().__init__(atoms.symbols, atoms.positions, cell=cell_new)

    @property
    def ptsm(self):
        """Positions of all atoms whose symbol is the metal ``self.m`` (dopants excluded)."""
        return self.positions[self.symbols == self.m]

    @property
    def ptsx(self):
        """Positions of the chalcogen atoms lying below the mean chalcogen height.

        In the 1H unit cell both X atoms share the same (x, y), so keeping only the lower
        layer yields one point per projected X column.
        """
        pts = self.positions[self.symbols == self.x]
        mask = pts[:, 2] < pts.mean(axis=0)[2]
        return pts[mask]

    def make_dopants(self, p=0.05, dopant='W'):
        """Replace a random fraction ``p`` of the metal atoms by ``dopant``, in place.

        Distinct sites are chosen: ``max(1, int(n * p))`` of them (capped at n), where n is
        the number of metal atoms. Does nothing if there are no metal atoms.
        """
        ind = np.where(self.symbols == self.m)[0]
        if len(ind) == 0:
            return
        num_dopants = min(len(ind), max(1, int(len(ind) * p)))
        # replace=False: with replacement, repeated draws silently lowered the dopant count.
        ind_select = np.random.choice(ind, num_dopants, replace=False)
        self.dopant = dopant
        self.symbols[ind_select] = dopant

    def to_image(self, size=512, p_dopant=0.05, p_vac=0.05, p_empty=0.02, ratio=0.5, ratio_dopant=1.5, sigma=None):
        """Render the flake to a (size, size) image normalized to a maximum of 1.

        Random defects are drawn on each call:
        - a fraction ``p_dopant`` of metal sites is drawn as dopants (weight ``ratio_dopant``);
        - a fraction ``p_vac`` of (lower-layer) X sites is drawn as single vacancies
          (weight ``ratio * 0.5``);
        - a further fraction ``p_empty`` of the remaining X sites is left out entirely.

        Every other metal site gets weight 1 and every other X site gets weight ``ratio``.
        Each site is a single pixel, blurred with a Gaussian of width ``sigma``. ``sigma``
        defaults to ``scale * a // 6``, where ``scale`` converts cell units to pixels.
        """
        ptsm = self.ptsm
        ptsx = self.ptsx
        # dopants
        ptsm, pts_dopants = split_pts(ptsm, p_dopant)
        # single vacancy
        ptsx, pts_vac = split_pts(ptsx, p_vac)
        # double vacancy (these sites are not drawn)
        ptsx, pts_empty = split_pts(ptsx, p_empty)

        img1 = np.zeros((size, size))
        img2 = np.zeros((size, size))
        img3 = np.zeros((size, size))
        img4 = np.zeros((size, size))

        scale = size / self.cell.cellpar()[0]
        if sigma is None:
            sigma = scale * self.a // 6

        for (x, y) in ptsm[:, 0:2]:
            img1[int(y * scale), int(x * scale)] = 1.
        for (x, y) in ptsx[:, 0:2]:
            img2[int(y * scale), int(x * scale)] = 1
        for (x, y) in pts_dopants[:, 0:2]:
            img3[int(y * scale), int(x * scale)] = 1
        for (x, y) in pts_vac[:, 0:2]:
            img4[int(y * scale), int(x * scale)] = 1
        img1 = gaussian(img1, sigma)
        img2 = gaussian(img2, sigma)
        img3 = gaussian(img3, sigma)
        img4 = gaussian(img4, sigma)

        layer_image = img1 + img2 * ratio + img3 * ratio_dopant + img4 * ratio * 0.5
        return layer_image / layer_image.max()
        return layer_image / layer_image.max()


def generate_image(size=512, a=12, p_dopant=0.05, p_vac=0.05, p_empty=0.02, ratio=0.5, ratio_dopant=1.5, sigma=None,
                   theta=None):
    """Render a synthetic MX2 image of exactly (size, size) pixels, optionally rotated.

    A larger canvas of side ``int(size * sqrt(2)) + 1`` is rendered first. Its
    diagonal-sized margin means a ``size``-pixel central crop stays inside the
    canvas after rotation by any angle.

    Parameters
    ----------
    size : int
        Side length of the returned image.
    a : float
        Lattice constant passed to ``MX2``.
    p_dopant, p_vac, p_empty, ratio, ratio_dopant, sigma
        Forwarded to ``MX2.to_image``.
    theta : float or None
        If given, the canvas is rotated by this angle (``skimage.transform.rotate``,
        degrees) before cropping.

    Returns
    -------
    np.ndarray
        Array of shape (size, size).
    """
    size_ = int(size * np.sqrt(2)) + 1
    L = size_ // 2
    layer = MX2(L=L, a=a)
    img = layer.to_image(size_, p_dopant, p_vac, p_empty, ratio, ratio_dopant, sigma)
    if theta is not None:
        img = rotate(img, angle=theta)
    # Central crop of exactly `size` pixels per side; the previous
    # `c - size//2 : c + size//2` slice returned size-1 pixels for odd `size`.
    start = size_ // 2 - size // 2
    return img[start:start + size, start:start + size]



def _check_pts(pts, size):
    """Keep only the points whose x and y are valid pixel indices of a (size, size) image.

    Both bounds are checked. A negative coordinate would otherwise be accepted and then
    silently wrap to the opposite image edge when used as a numpy index.
    """
    xy = pts[:, :2]
    mask = np.all((xy >= 0) & (xy < size), axis=1)
    return pts[mask]

class mx2image:
    """Synthetic MX2 image generator that also keeps per-site class labels.

    On construction an ``MX2`` flake is built and random defects are assigned. All site
    positions are rounded to integer pixel coordinates and restricted to the
    (size, size) image. The flake has ``L = size // 2`` and occupies [0, 2L) in x and y;
    positions are used directly as pixel coordinates.

    Attributes
    ----------
    ptsm, ptsx, pts_dopants, pts_single, pts_double : np.ndarray of int
        Rounded (x, y, z) positions of each site class.
    pts : np.ndarray of int
        All of the above stacked in that order.
    lbs : np.ndarray of int
        Label per row of ``pts``: 0 metal, 1 chalcogen, 2 dopant, 3 single vacancy,
        4 double vacancy.
    p_dopants, p_single, p_double : float
        Defect fractions passed to the constructor.
    """

    def __init__(self, size=512, a=12, theta=0, p_dopants=0.05, p_single=0.05, p_double=0.02):
        self.size = size
        self.a = a
        self.theta = theta

        # Keep the probabilities under their own names. They used to be stored in
        # self.pts_single / self.pts_double and were then overwritten by point arrays.
        self.p_dopants = p_dopants
        self.p_single = p_single
        self.p_double = p_double

        L = self.size // 2
        layer = MX2(L=L, a=self.a, theta=self.theta)

        # dopants, then single vacancies, then double vacancies among the remaining X sites
        ptsm, pts_dopants = split_pts(layer.ptsm, p_dopants)
        ptsx, pts_single = split_pts(layer.ptsx, p_single)
        ptsx, pts_double = split_pts(ptsx, p_double)

        def to_pixels(points):
            # round to integer pixel coordinates and drop points outside the image
            return _check_pts(np.round(points).astype(int), self.size)

        self.ptsm = to_pixels(ptsm)
        self.ptsx = to_pixels(ptsx)
        self.pts_dopants = to_pixels(pts_dopants)
        self.pts_single = to_pixels(pts_single)
        self.pts_double = to_pixels(pts_double)

        groups = [self.ptsm, self.ptsx, self.pts_dopants, self.pts_single, self.pts_double]
        self.pts = np.vstack(groups)
        self.lbs = np.repeat(np.arange(len(groups)), [len(g) for g in groups])

    def generate_data(self, s1=0.5, s2=1.5, sigma=None, show_components=False):
        """Render the stored sites into a (size, size) image normalized to a maximum of 1.

        Weights: metal 1, chalcogen ``s1``, dopant ``s2``, single vacancy ``s1 * 0.5``.
        Double-vacancy sites are not drawn.

        Parameters
        ----------
        s1, s2 : float
            Relative intensities of chalcogen and dopant sites.
        sigma : float or None
            Gaussian blur width in pixels; defaults to ``a // 6``.
        show_components : bool
            If True, show the four per-class images with ``plot_compare`` (debug view,
            previously always shown).
        """
        shape = (self.size, self.size)
        if sigma is None:
            sigma = self.a // 6

        def render(points):
            # one bright pixel per site (rows are y, columns are x), then blur
            img = np.zeros(shape)
            img[points[:, 1], points[:, 0]] = 1
            return gaussian(img, sigma, mode='constant')

        img_m = render(self.ptsm)
        img_x = render(self.ptsx)
        img_dopant = render(self.pts_dopants)
        img_single = render(self.pts_single)

        if show_components:
            plot_compare([img_m, img_x, img_dopant, img_single])

        # A single vacancy gets half the chalcogen weight (presumably one of the two
        # stacked X atoms remains).
        layer_image = img_m + img_x * s1 + img_dopant * s2 + img_single * s1 * 0.5
        return layer_image / layer_image.max()

In [444]:
# Fix the global NumPy seed so the random defect placement (np.random.choice in
# split_pts) is reproducible under Restart & Run All.
np.random.seed(0)
aa = mx2image(theta=20)

In [445]:
# Render the labelled flake with explicit (default) intensity ratios and blur.
img = aa.generate_data(s1=0.5, s2=1.5, sigma=None)

In [449]:
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))
ax.imshow(img, cmap='gray')
# Overlay the site labels (aa.lbs: 0 M, 1 X, 2 dopant, 3 single vacancy, 4 double vacancy).
ax.scatter(aa.pts[:, 0], aa.pts[:, 1], color=colors_from_lbs(aa.lbs), s=5)
ax.set_title(f'Synthetic MX2 image (theta={aa.theta}) with labelled sites')
ax.set_axis_off();

<matplotlib.collections.PathCollection at 0x1b60123dbe0>

In [450]:
# Quick standalone view of the normalized synthetic image.
# NOTE(review): `imshow` is a bare name, presumably provided by one of the stempy star imports.
imshow(img)

<matplotlib.image.AxesImage at 0x1b60124fa90>

In [384]:
# Render a synthetic image rotated by 10 degrees (other arguments at their defaults) and show it.
kk = generate_image(size=512, a=12, theta=10)
imshow(kk)

<matplotlib.image.AxesImage at 0x1b5d3a77280>

In [386]:
# NOTE(review): this rebinds `aa` from the mx2image instance above to a raw MX2 flake.
aa = MX2(theta=10, a=12, L=512 // 2)

In [387]:
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))
ax.scatter(aa.ptsm[:, 0], aa.ptsm[:, 1], s=3, label=f'M ({aa.m})')
ax.scatter(aa.ptsx[:, 0], aa.ptsx[:, 1], s=3, label=f'X ({aa.x}, lower layer)')
# Lattice positions: equal aspect so the hexagonal geometry is not distorted.
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('MX2 flake site positions')
ax.legend(loc='upper right');

<matplotlib.collections.PathCollection at 0x1b5d3ac7f70>