In [1]:
%matplotlib
%load_ext autoreload
%autoreload 2
import numpy as np
from IPython.display import clear_output
import lloyds
from poisson_disc import Grid
from labels import project_positions,create_label
from ase import Atoms,Atom
from pyqstem import PyQSTEM
from pyqstem.imaging import CTF
from pyqstem.util import atoms_plot
import matplotlib.pyplot as plt
import matplotlib.path as mplPath
from scipy.spatial import Voronoi
from ase.io import write,read
import scipy.spatial
from ase.visualize import view

Using matplotlib backend: TkAgg


In [2]:
def lookup_nearest(x0, y0, x, y, z):
    xi = np.abs(x-x0).argmin()
    yi = np.abs(y-y0).argmin()
    return z[yi,xi]

def strain(positions,direction,cell,power=-3,amplitude=10**3,N=(64,64)):
    noise=spectral_noise.power_law_noise(N,power)
    x=np.linspace(0,cell[0],N[0])
    y=np.linspace(0,cell[1],N[1])

    positions[:,direction]+=amplitude*np.array([lookup_nearest(p[0], p[1], x, y, noise) for p in positions]).T
    
    return positions

def random_sheet(cell,r):
    
    grid = Grid(r, cell[0], cell[1])
    
    rand = (np.random.uniform(0, cell[0]), np.random.uniform(0, cell[1]))
    positions = grid.poisson(rand)
    
    positions=lloyds.repeat(positions,cell[:2])
    
    vor=Voronoi(positions)
    
    positions=vor.vertices
    
    positions=positions[positions[:,0]<cell[0]]
    positions=positions[positions[:,0]>0]
    positions=positions[positions[:,1]<cell[1]]
    positions=positions[positions[:,1]>0]
    
    positions=lloyds.relax(positions,cell[:2],num_iter=1,bc='periodic')
    
    num_holes=np.random.randint(0,3)
    for i in range(num_holes):
        size=(.4+.6*np.random.rand())*cell[0]
        hole=size*blob()+[np.random.uniform(0, cell[0]), np.random.uniform(0, cell[1])]
        contained = mplPath.Path(hole).contains_points(positions)
        positions = positions[np.where(contained==0)[0]]
    
    positions=np.hstack((positions,np.array([[0]*len(positions)]).T))
    
    atoms=Atoms(['C']*len(positions),positions)

    atoms.set_cell(cell)
    atoms.set_positions(positions)
    atoms.set_pbc(1)
    atoms.wrap()
    atoms.center()
    
    print(len(atoms)/(cell[0]*cell[1]),2/5.24)
    
    return atoms

In [33]:
N=360
sampling=24.48/2048*10
L=sampling*N
cell=(L,L,5)

atoms=random_sheet(cell,1.9)
#atoms=grains_sheet(cell,8)

sites = atoms.get_positions()


mos2=Atoms()
mos2.set_cell(atoms.get_cell()*1.2)


dz=.5

for site in sites:
    
    r=np.random.rand()
    
    if r < .5:
        mos2+=Atom('Mo',position=site*1.2)
    elif r < 1:
        mos2+=Atom('S',position=site*1.2+[0,0,dz])
        mos2+=Atom('S',position=site*1.2+[0,0,-dz])
        
        

#atoms_plot(mos2)

view(mos2)

0.37371296995088416 0.38167938931297707


In [34]:
num_examples=1
#num_examples=len(examples)

dir_name='graphene-random'
first_number=0
label_size=(N,N)

for i in range(num_examples):
    #atoms=examples[i]
    #atoms=random_sheet(cell,1.9)
    
    atoms=mos2
    
    qstem=PyQSTEM('TEM')
    
    image_size=(int(atoms.get_cell()[0,0]*12),int(atoms.get_cell()[1,1]*12))

    qstem.set_atoms(atoms)
    qstem.build_wave('plane',80,image_size)
    qstem.build_potential(int(atoms.get_cell()[2,2]*2))
    qstem.run()
    wave=qstem.get_wave()
    wave.array=wave.array.astype(np.complex64)
    
    positions=project_positions(atoms,distance=0)/sampling
    classes=[0]*len(positions)
    label=create_label(positions,label_size,6)
    
    #np.save('../data/{0}/label/label_{1:04d}.npy'.format(dir_name,first_number+i),label)
    #write('../data/{0}/model/model_{1:04d}.cfg'.format(dir_name,first_number+i),atoms)
    #wave.save('../data/{0}/wave/wave_{1:04d}.npz'.format(dir_name,first_number+i))
    
    print('iteration',i)
    clear_output(wait=True)

iteration 0


In [40]:
wave.view()

In [39]:
sampling=24.48/2048*10
Cs=-11.5*10**4
defocus=-80/400000.*Cs+200
focal_spread=20


ctf=CTF(defocus=defocus,Cs=Cs,focal_spread=focal_spread)
image=wave.apply_ctf(ctf).detect(resample=sampling)

print(image.shape,label.shape)

fig,(ax1,ax2,ax3)=plt.subplots(1,3,figsize=(10,4))

atoms_plot(atoms,ax=ax1)
ax1.axis('off')

ax2.imshow(np.flipud(image.T),cmap='gray');
ax2.axis('off')

ax3.imshow(np.flipud(label[:,:,0].T))
ax3.axis('off')

#plt.tight_layout()

(432, 432) (360, 360, 1)


(-0.5, 359.5, 359.5, -0.5)