In [None]:
%aiida
# Load the AiiDA database environment for the configured profile, unless it is
# already loaded in this kernel (e.g. when this cell is re-run).
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

In [None]:
from aiida.orm.data.structure import StructureData
from aiida.orm.data.parameter import ParameterData
from aiida.orm.data.array import ArrayData
from aiida.orm.data.base import Int, Str, Float, Bool
from aiida.orm.data.remote import RemoteData
from aiida.work import workfunction
from aiida.work.process import WorkCalculation
from aiida.work.run import submit
from aiida_cp2k.calculations import Cp2kCalculation
from aiida.orm import Code, Computer
from aiida.orm.querybuilder import QueryBuilder

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from apps.scanning_probe.hrstm.hrstm_workchain import HRSTMWorkChain

# Select structure

In [None]:
from apps.surfaces.structure_browser import StructureBrowser

def on_struct_change(c):
    """Observer for the browser's selected structure: redraw the 3D view."""
    update_view()

# Build the two widgets, wire the browser selection to the viewer, then show them stacked.
viewer = nglview.NGLWidget()
struct_browser = StructureBrowser()
struct_browser.results.observe(on_struct_change, names='value')
clear_output()

display(ipw.VBox([struct_browser, viewer]))

def update_view():
    """Redraw `viewer` with the structure currently selected in `struct_browser`.

    Removes the previously shown component (if any). If a structure is selected,
    it adds the structure (periodic in all directions) with its unit cell and
    orients the camera to look from positive z onto the centre of mass. If no
    structure is selected, the view is only cleared.
    """
    # remove old components
    if hasattr(viewer, "component_0"):
        viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_unitcell()
        viewer.remove_component(viewer.component_0.id)

    structure = struct_browser.results.value
    if not structure:
        return

    atoms = structure.get_ase()
    atoms.pbc = [1, 1, 1]
    # add new component (ASEStructure is rendered as ball+stick)
    viewer.add_component(nglview.ASEStructure(atoms))
    viewer.add_unitcell()
    viewer.center()

    # Orient camera to look from positive z.
    # NOTE(review): presumably a 4x4 matrix in NGL's column-major layout: identity
    # rotation, camera distance of at least 30 (or the cell height if larger), and
    # a translation that moves the centre of mass to the origin -- confirm.
    cell_z = atoms.cell[2, 2]
    com = atoms.get_center_of_mass()
    top_z_orientation = [1.0, 0.0, 0.0, 0,
                         0.0, 1.0, 0.0, 0,
                         0.0, 0.0, -np.max([cell_z, 30.0]) , 0,
                         -com[0], -com[1], -com[2], 1]
    viewer._set_camera_orientation(top_z_orientation)

# Select computer and codes

In [None]:
# Names of all enabled computers; each result row holds the single projected name.
qb = QueryBuilder()
qb.append(Computer, filters={'enabled': True}, project='name')
computer_names = [name for (name,) in qb.all()]

style = {'description_width': '120px'}
layout = {'width': '70%'}
drop_computer = ipw.Dropdown(options=computer_names, description="Computer")

def comp_plugin_codes(computer_name, plugin_name):
    """Return the visible codes of a given input plugin on one enabled computer.

    Codes whose extra ``hidden`` is True are skipped. The result is ordered by
    descending node id.

    :param computer_name: name of an enabled AiiDA computer; ``None`` or an empty
        name (e.g. no computer available) yields ``[]``
    :param plugin_name: the code's ``input_plugin`` attribute, e.g. ``'cp2k'``
    :return: list of Code nodes
    """
    if not computer_name:
        return []
    qb = QueryBuilder()
    # Filter on the computer name in the database instead of fetching the codes of
    # every enabled computer and discarding most of them in Python.
    qb.append(Computer, filters={'enabled': True, 'name': computer_name}, tag='computer')
    qb.append(Code, project='*', has_computer='computer', filters={
        'attributes.input_plugin': plugin_name,
        'or': [{'extras': {'!has_key': 'hidden'}}, {'extras.hidden': False}]
    })
    qb.order_by({Code: {'id': 'desc'}})
    # Only the Code is projected, so each row is [code].
    return [row[0] for row in qb.all()]

def on_computer_change(c):
    """Refresh the code lists and code dropdowns for the selected computer.

    Sets the module-level ``cp2k_codes``, ``ppm_codes`` and ``hrstm_codes`` so that
    each list lines up index-for-index with its dropdown's options.

    :param c: traitlets change notification (unused; the function is also called
        directly with a dummy argument)
    """
    global cp2k_codes, ppm_codes, hrstm_codes
    computer_name = drop_computer.value
    cp2k_codes = comp_plugin_codes(computer_name, 'cp2k')
    # Filter here (not only in the dropdown) so that ppm_codes[drop_ppm.index] refers
    # to the code actually shown in drop_ppm.
    ppm_codes = [code for code in comp_plugin_codes(computer_name, 'spm.afm')
                 if "_2pp" in code.label]
    hrstm_codes = comp_plugin_codes(computer_name, 'spm.hrstm')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    drop_ppm.options = [(code.label, code) for code in ppm_codes]
    drop_hrstm.options = [code.label for code in hrstm_codes]


# React only to selection changes. Without `names`, the handler runs for every trait
# change of the dropdown (index, label, value, ...) and repeats the queries each time.
drop_computer.observe(on_computer_change, names='value')

drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_ppm = ipw.Dropdown(description="PPM code")

drop_hrstm = ipw.Dropdown(description="HR-STM code")

# Populate the code dropdowns for the initially selected computer.
on_computer_change(0)

# Toggle passed to the workchain as ``elpa_switch``.
elpa_check = ipw.Checkbox(description='use ELPA', value=True, disabled=False)

display(drop_computer, drop_cp2k, drop_ppm, drop_hrstm, elpa_check)

# High-resolution STM parameters

In [None]:
style = {'description_width': 'initial'}
layout_large = {'width': '450px'}
layout_middle = {'width' : '320px'}
layout_small = {'width': '175px'}

# Descriptions containing LaTeX are raw strings. In a normal string "\(" is an
# invalid escape sequence (warning) and "\r" silently becomes a carriage return.

# Sample wavefunction
orbsSam_bit = ipw.BoundedIntText(description = "Maximal orbital for sample wavefunction",
                                 value = 1, min = 0, max = 1, step = 1,
                                 style = style, layout = layout_middle)
pbc_t = ipw.Text(description = "Periodic boundary conditions for sample",
                 value = '1 1 0',
                 style = style, layout = layout_middle)
ecut_bft = ipw.BoundedFloatText(description = r"Added energy cutoff (\(e\)V)",
                                value = 0.1, min = 0.01, max = 1.0, step = 0.01,
                                style = style, layout = layout_middle)
fwhm_bft = ipw.BoundedFloatText(description = r"FWHM of DoS (\(e\)V)",
                                value = 0.1, min = 0.02, max = 0.5, step = 0.01,
                                style = style, layout = layout_middle)

# Tip information
tipshape_d = ipw.Dropdown(description = "Shape of tip",
                          options = ['Sharp', 'Blunt'], value = 'Sharp',
                          style = style, layout = layout_middle)
orbsTip_bit = ipw.BoundedIntText(description = "Maximal orbital for tip",
                                 value = 1, min = 0, max = 1, step = 1,
                                 style = style, layout = layout_middle)
rotate_c = ipw.Checkbox(description = "Rotate tip coefficients",
                        value = True,
                        style = style, layout = layout_middle)

# Scan information
bias_fs = ipw.FloatRangeSlider(description = r"Bias voltages: \(V_{min}\), \(V_{max}\) (\(e\)V)",
                               value = [-2.0, 2.0], min = -4.0, max = 4.0, step=0.1,
                               orientation = 'horizontal',
                               readout = True, readout_format = '.1f',
                               style = style, layout = layout_large)
biasstep_bft = ipw.BoundedFloatText(description = r"Bias voltage step (\(e\)V)",
                                    value = 0.1, min = 0.01, max = 1.0, step = 0.01,
                                    style = style, layout = layout_middle)
height_t = ipw.Text(description = "Constant height scans",
                    value = '3.0 5.0',
                    style = style, layout = layout_middle)


def _scan_step_widget(axis):
    """Return the scan-step input (in Angstrom) for the given axis name ('x', 'y' or 'z')."""
    return ipw.BoundedFloatText(description = r"Scan d\(%s\) (\(\mathring{A}\))" % axis,
                                value = 0.1, min = 0.05, max = 0.5, step = 0.05,
                                style = style, layout = layout_small)


scanstepx_bft = _scan_step_widget('x')
scanstepy_bft = _scan_step_widget('y')
scanstepz_bft = _scan_step_widget('z')


display(orbsSam_bit, pbc_t, ecut_bft, fwhm_bft,
        tipshape_d, orbsTip_bit, rotate_c,
        bias_fs, biasstep_bft, height_t, scanstepx_bft, scanstepy_bft, scanstepz_bft)

In [None]:
def _default_tip_params():
    """Return a fresh dict of the probe-particle tip parameters shared by all tip shapes."""
    # TODO tip specific
    return {
        'ChargeCuUp':   -0.0669933,
        'ChargeCuDown': -0.0627402,
        'Ccharge':      0.212718,
        'Ocharge':      -0.11767,
        'sigma':        0.7,
        'Cklat':        0.24600212465950813,
        'Oklat':        0.15085476515590224,
        'Ckrad':        20,
        'Okrad':        20,
        'rC0':          [0.0, 0.0, 1.82806112489999961213],
        'rO0':          [0.0, 0.0, 1.14881347770000097341],
    }

# One independent parameter dict per tip shape offered by the "Shape of tip" dropdown.
ppm_params_list = {shape: _default_tip_params() for shape in ("Sharp", "Blunt")}

In [None]:
def _remote_file_exists(hostname, path):
    """Return True if `path` is a regular file on `hostname`, checked via ``ssh host test -f``.

    An ssh failure (unreachable host, authentication error) counts as "not found".
    """
    import subprocess
    # The remote `test -f` exits 0 only if the file exists; ssh itself exits 255 on failure.
    # (The previous `! ssh ... echo 0/1` check tested the *string* "0" or "1", both truthy.)
    return subprocess.call(['ssh', hostname, 'test', '-f', path]) == 0


def _restart_wfn_if_available(calc, hostname):
    """Return `calc`'s ``aiida-RESTART.wfn`` path if `calc` ran on `hostname` and the file exists, else ''."""
    computer = calc.get_computer()
    if computer is None or computer.hostname != hostname:
        return ""
    wfn_path = calc.out.remote_folder.get_remote_path() + "/aiida-RESTART.wfn"
    return wfn_path if _remote_file_exists(hostname, wfn_path) else ""


def find_struct_wf(structure):
    """Find a reusable CP2K ``aiida-RESTART.wfn`` for `structure` on the selected computer.

    Candidates are tried in this order:
      1. workchains referenced by the structure's extras whose key starts with
         'stm', 'pdos', 'afm' or 'hrstm' (their first output is used);
      2. the structure's first input (the calculation that produced it).
    A candidate is accepted if it ran on the computer selected in ``drop_computer``
    (compared by hostname) and the file is still present there.

    :param structure: StructureData node
    :return: remote path of the .wfn file, or '' if none was found
    """
    from aiida.orm import load_node

    # Look the computer up by name instead of by dropdown index: two unordered
    # queries are not guaranteed to return the computers in the same order.
    qb = QueryBuilder()
    qb.append(Computer, filters={'enabled': True, 'name': drop_computer.value})
    hostname = qb.first()[0].hostname

    # check spm workchains
    for ex_k, ex_val in structure.get_extras().items():
        if not ex_k.startswith(('stm', 'pdos', 'afm', 'hrstm')):
            continue
        spm_workchain = load_node(ex_val)
        # NOTE(review): assumes the first output is the CP2K SCF calculation -- confirm.
        wfn_path = _restart_wfn_if_available(spm_workchain.get_outputs()[0], hostname)
        if wfn_path:
            print("Found .wfn from %s"%ex_k)
            return wfn_path

    # check geo opt
    # NOTE(review): assumes the first input is the geometry optimisation that produced the structure.
    wfn_path = _restart_wfn_if_available(structure.get_inputs()[0], hostname)
    if wfn_path:
        print("Found .wfn from geo_opt")
    return wfn_path

def on_submit(b):
    """Validate the form and submit an HRSTMWorkChain for the selected structure.

    Button callback: `b` is the clicked Button and is not used. All messages and
    the submission result are printed into the ``submit_out`` output widget.
    """
    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not drop_computer.value:
            print("Please select a computer.")
            return

        # drop_ppm options are (label, Code) pairs, so its value is the Code itself.
        # Indexing ppm_codes by drop_ppm.index could pick a different code, because the
        # dropdown lists only the "_2pp" subset.
        ppm_code = drop_ppm.value
        if drop_cp2k.index is None or ppm_code is None or drop_hrstm.index is None:
            print("Please select a Cp2k, PPM and HR-STM code.")
            return
        cp2k_code = cp2k_codes[drop_cp2k.index]
        hrstm_code = hrstm_codes[drop_hrstm.index]

        # Requested scanning heights
        try:
            heights = [float(h) for h in height_t.value.split()]
        except ValueError:
            print("Constant height scans must be space-separated numbers.")
            return
        if not heights:
            print("Please specify at least one constant height scan.")
            return

        parent_dir = "parent_calc_folder/"

        struct = struct_browser.results.value
        ase_geom = struct.get_ase()

        # Diagonal of the cell matrix, stored as a node for the workchain. This is the
        # full cell only for orthorhombic cells. (The node was previously overwritten by
        # the raw ndarray, so the workchain never received it.)
        cell = ArrayData()
        cell.set_array('cell', np.diag(ase_geom.cell))

        # Bias voltages: the two end points of the range, each prefixed by a space.
        # NOTE(review): biasstep_bft is never read -- confirm whether --voltages should
        # list every voltage from V_min to V_max in steps of biasstep_bft.
        biases = ''.join(' ' + str(bias) for bias in bias_fs.value)

        top_z = np.max(ase_geom.positions[:, 2])
        # Scanning steps
        dx = scanstepx_bft.value
        dy = scanstepy_bft.value
        dz = scanstepz_bft.value
        # Deep-copy the per-tip template. Updating it in place would mutate the shared
        # ppm_params_list entry on every submission.
        ppm_params_dict = deepcopy(ppm_params_list[tipshape_d.value])
        ppm_params_dict.update({
            'Catom':        'Ctip',
            'Oatom':        'Otip',
            # TODO maybe should be true?
            'PBC':          'False',
            'gridA':        list(ase_geom.cell[0]),
            'gridB':        list(ase_geom.cell[1]),
            'gridC':        list(ase_geom.cell[2]),
            'scanMin':      [0.0, 0.0, top_z+np.min(heights)-dz],
            'scanMax':      [ase_geom.cell[0,0], ase_geom.cell[1,1], top_z+np.max(heights)+dz],
            'scanStep':     [dx,dy,dz],
            'Amplitude':    1.4,
            'f0Cantilever': 22352.5
        })
        ppm_params = ParameterData(dict=ppm_params_dict)

        # Extra tip informations
        # NOTE(review): rC0 + rO0 concatenates the two 3-vectors into one 6-element list
        # string; confirm this is the format --tip_shift expects.
        tip_shift = str(ppm_params_dict["rC0"]+ppm_params_dict["rO0"])
        # NOTE(review): pdos_list (a mapping keyed by tip shape) is not defined in this
        # notebook section; it must be defined before submitting.
        pdos = pdos_list[tipshape_d.value]

        hrstm_params = ParameterData(dict={
            '--output_file':  'hrstm.npz',
            '--voltages':     biases,
            # Sample information
            '--cp2k_input_s': parent_dir+'aiida.inp',
            '--basis_sets_s': parent_dir+'BASIS_MOLOPT',
            '--xyz_s':        parent_dir+'geom.xyz',
            '--coeffs_s':     parent_dir+'aiida-RESTART.wfn',
            '--orbs_sam':     str(orbsSam_bit.value),
            '--pbc':          str(pbc_t.value),
            '--emin':         str(bias_fs.value[0]-ecut_bft.value),
            '--emax':         str(bias_fs.value[1]+ecut_bft.value),
            '--fwhm':         str(fwhm_bft.value),
            # Tip information: only the PDOS entry of the selected tip shape.
            '--pdos_list':    pdos,
            '--orbs_tip':     str(orbsTip_bit.value),
            '--tip_shift':    tip_shift,
            '--etip':         str(0.0),
            # The checkbox state; the widget object itself cannot be stored in ParameterData.
            # NOTE(review): unlike the other keys this one has no '--' prefix -- confirm.
            'rotate':         rotate_c.value,
        })

        ## Try to access the restart-wfn file ##
        try:
            wfn_file_path = find_struct_wf(struct)
        except Exception as exc:
            # Best effort: fall back to no .wfn, but report the reason instead of hiding it.
            print("Could not look for a .wfn file: %s" % exc)
            wfn_file_path = ""
        if wfn_file_path == "":
            print("Didn't find any accessible .wfn file.")

        outputs = submit(
            HRSTMWorkChain,
            cp2k_code=cp2k_code,
            structure=struct,
            cell=cell,
            wfn_file_path=Str(wfn_file_path),
            elpa_switch=Bool(elpa_check.value),
            ppm_code=ppm_code,
            ppm_params=ppm_params,
            hrstm_code=hrstm_code,
            hrstm_params=hrstm_params
        )

        print(outputs)

# Output area that on_submit prints into, and the button that triggers it.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
display(btn_submit, submit_out)