In [None]:
%aiida
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Attach this kernel to the app's AiiDA profile unless an earlier run already did.
_dbenv_needs_loading = not is_dbenv_loaded()
if _dbenv_needs_loading:
    load_dbenv(profile=settings.AIIDADB_PROFILE)

In [None]:
from aiida.orm.data.structure import StructureData
from aiida.orm.data.parameter import ParameterData
from aiida.orm.data.array import ArrayData
from aiida.orm.data.base import Int, Str, Float, Bool
from aiida.orm.data.remote import RemoteData
from aiida.work import workfunction
from aiida.work.process import WorkCalculation
from aiida.work.run import submit
from aiida_cp2k.calculations import Cp2kCalculation
from aiida.orm import Code, Computer
from aiida.orm.querybuilder import QueryBuilder

import ase
import ase.io
import numpy as np
import os
import nglview
from copy import deepcopy
from pprint import pprint
from collections import defaultdict

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from math import floor, log10

from apps.scanning_probe.hrstm.hrstm_workchain import HRSTMWorkChain

# Select structure

In [None]:
from apps.surfaces.structure_browser import StructureBrowser

def on_struct_change(c):
    """Redraw the viewer for the newly selected structure.

    ``c`` is the traitlets change notification; it is not used.
    """
    update_view()

viewer = nglview.NGLWidget()
struct_browser = StructureBrowser()
# Refresh the 3D view whenever the browser's selected result changes.
struct_browser.results.observe(on_struct_change, names='value')
clear_output()

display(ipw.VBox([struct_browser, viewer]))

def update_view():
    """Show the structure selected in ``struct_browser`` in ``viewer``.

    The previously displayed component (if any) is removed first, including its
    ball-and-stick representation and unit cell. If nothing is selected, the
    viewer is left empty. Otherwise the structure is added with its unit cell,
    centered, and the camera is placed so it looks from positive z towards the
    structure's center of mass.
    """
    # remove old components
    if hasattr(viewer, "component_0"):
        viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_unitcell()
        viewer.remove_component(viewer.component_0.id)

    structure = struct_browser.results.value
    if not structure:
        return

    atoms = structure.get_ase()
    atoms.pbc = [1, 1, 1]
    # add new component
    viewer.add_component(nglview.ASEStructure(atoms))  # adds ball+stick
    viewer.add_unitcell()
    viewer.center()

    # Orient camera to look from positive z. The z-entry of the orientation matrix
    # is minus the cell height along z, but at least 30 (presumably in Angstrom) so
    # that thin slabs are not viewed from too close.
    cell_z = atoms.cell[2, 2]
    com = atoms.get_center_of_mass()
    top_z_orientation = [1.0, 0.0, 0.0, 0,
                         0.0, 1.0, 0.0, 0,
                         0.0, 0.0, -np.max([cell_z, 30.0]), 0,
                         -com[0], -com[1], -com[2], 1]
    viewer._set_camera_orientation(top_z_orientation)

# Select computer and codes

In [None]:
style = {'description_width': '120px'}
layout = {'width': '70%'}

# Names of all enabled computers, in query order.
qb = QueryBuilder()
qb.append(Computer, filters={'enabled': True}, project='name')
computer_names = [name for (name,) in qb.all()]

drop_computer = ipw.Dropdown(description="Computer",
                             options=computer_names)

def comp_plugin_codes(computer_name, plugin_name):
    """Return the visible codes of plugin ``plugin_name`` on computer ``computer_name``.

    Only codes on enabled computers are considered. A code counts as visible when
    it has no ``hidden`` extra or that extra is False. Results are ordered by
    descending code id (newest first).
    """
    code_query = QueryBuilder()
    code_query.append(Computer, filters={'enabled': True}, project='name', tag='computer')
    code_query.append(Code, project='*', has_computer='computer', filters={
        'attributes.input_plugin': plugin_name,
        'or': [{'extras': {'!has_key': 'hidden'}}, {'extras.hidden': False}]
    })
    code_query.order_by({Code: {'id': 'desc'}})
    # Each row is [computer name, Code]; keep the codes of the requested computer.
    return [code for comp_name, code in code_query.all() if comp_name == computer_name]

def on_computer_change(c):
    """Reload the codes of the selected computer and refresh the code dropdowns.

    ``c`` is the traitlets change notification (or a placeholder on the initial
    manual call); it is not used. Updates the module-level ``cp2k_codes``,
    ``ppm_codes`` and ``hrstm_codes`` lists.
    """
    global cp2k_codes, ppm_codes, hrstm_codes
    computer = drop_computer.value
    cp2k_codes = comp_plugin_codes(computer, 'cp2k')
    ppm_codes = comp_plugin_codes(computer, 'spm.afm')
    hrstm_codes = comp_plugin_codes(computer, 'spm.hrstm')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    # Only PPM codes whose label contains "_2pp" are offered; the options carry the Code as value.
    drop_ppm.options = [(code.label, code) for code in ppm_codes if "_2pp" in code.label]
    drop_hrstm.options = [code.label for code in hrstm_codes]
    
    
# React only to changes of the selected computer. Observing all traits would rerun
# the three code queries for every index/label/value notification of one selection.
drop_computer.observe(on_computer_change, names='value')

drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_ppm = ipw.Dropdown(description="PPM code")

drop_hrstm = ipw.Dropdown(description="HR-STM code")

# Populate the code dropdowns for the initially selected computer.
on_computer_change(0)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(drop_computer, drop_cp2k, drop_ppm, drop_hrstm, elpa_check)

# Simulation Parameters

### Probe Particle Model Parameters

In [None]:
style = {'description_width': 'initial'}
layout = {'width': '60%'}
# Labels are raw strings so the MathJax backslashes stay literal without
# relying on Python's handling of invalid escape sequences.
scandx_ipw = ipw.BoundedFloatText(description=r"Scan d\(\mathbf{x}\) (\(Ang\))",
                                  value=0.1, min=0.05, max=0.5, step=0.05,
                                  style=style, layout=layout)
# NOTE: Minimum value of 3.0 because that is the size of the tip!
scanmin_ipw = ipw.BoundedFloatText(description=r"Scan \(z_{min}\) (\(Ang\))",
                                   value=4.5, min=3.0, max=10.0, step=0.1,
                                   style=style, layout=layout)
scanmax_ipw = ipw.BoundedFloatText(description=r"Scan \(z_{max}\) (\(Ang\))",
                                   value=7.5, min=3.0, max=10.0, step=0.1,
                                   style=style, layout=layout)
amp_ipw = ipw.FloatText(description=r"Amplitude (\(Ang\))",
                        value=1.4, step=0.1,
                        style=style, layout=layout)
f0cantilever_ipw = ipw.FloatText(description=r"Cantilever (\(f_0\)) ",
                                 value=22352.5, step=0.1,
                                 style=style, layout=layout)
# Each option maps to [ChargeCuUp, ChargeCuDown, Ccharge, Ocharge].
tpp_resp_ipw = ipw.ToggleButtons(description="2PP RESP model",
                                 options={
                                     'pentacene': [-0.0669933, -0.0627402, 0.212718, -0.11767],
                                     'ptcda':     [     -0.05,      -0.07,     0.23,    -0.13]
                                 },
                                 style=style, layout=layout)
# Show
display(scandx_ipw, scanmin_ipw, scanmax_ipw,
        amp_ipw, f0cantilever_ipw, tpp_resp_ipw)

### High-Resolution STM Parameters

In [None]:
style = {'description_width': 'initial'}
layout = {'width': '60%'}
layout_min = {'width': '25%'}
# Keyword arguments shared by the compact controls of the voltage range row.
_compact_kwargs = dict(style=style, layout=layout_min)

voltext_ipw = ipw.Label(value="Voltage Range Slider:", **_compact_kwargs)
volstep_ipw = ipw.BoundedFloatText(description="Step",
                                   value=0.1, min=0.01, max=0.5, step=0.01,
                                   **_compact_kwargs)
volmin_ipw = ipw.FloatText(description="Min", value=-0.5, **_compact_kwargs)
volmax_ipw = ipw.FloatText(description="Max", value=0.5, **_compact_kwargs)
# The slider starts with the bounds and step currently entered in the boxes above.
voltages_ipw = ipw.FloatRangeSlider(description=r"Bias Voltage Range (\(e\)V)",
                                    value=[-0.3, 0.3],
                                    min=volmin_ipw.value,
                                    max=volmax_ipw.value,
                                    step=volstep_ipw.value,
                                    readout_format='.2f',
                                    style=style, layout=layout)
def voltage_slider(vmin, vmax, vstep):
    """Keep ``voltages_ipw`` consistent with the Min/Max/Step boxes.

    The slider bounds become min(vmin, vmax) and max(vmin, vmax). Both handles are
    moved to the new lower bound plus the nearest whole number of ``vstep``
    increments. Finally the slider step is set to ``vstep``.
    """
    voltages_ipw.min = min(vmin, vmax)
    voltages_ipw.max = max(vmin, vmax)
    origin = voltages_ipw.min

    def snap(v):
        return round((v - origin) / vstep) * vstep + origin

    lower, upper = voltages_ipw.value
    voltages_ipw.value = [snap(lower), snap(upper)]
    voltages_ipw.step = vstep
volrange_ipw = ipw.interactive(voltage_slider, vmin=volmin_ipw, vmax=volmax_ipw, vstep=volstep_ipw)

# Sample-side settings
fwhm_ipw = ipw.BoundedFloatText(description=r"FWHM for DoS of Sample (\(e\)V)",
                                value=0.05, min=0.01, max=0.5, step=0.01,
                                style=style, layout=layout)
wfnstep_ipw = ipw.BoundedFloatText(description=r"Meshwidth for Grid Orbitals d\(\mathbf{x}\) (\(Ang\))",
                                   value=0.2, min=0.05, max=1.0, step=0.05,
                                   style=style, layout=layout)
extrap_ipw = ipw.BoundedFloatText(description=r"Extrapolation Plane (\(Ang\))",
                                  value=4.0, min=1.0, max=10.0, step=0.1,
                                  style=style, layout=layout)

# Tip settings
tiptype_ipw = ipw.ToggleButtons(description="Tip Type",
                                value='blunt', options=['parametrized', 'blunt', 'sharp'],
                                style=style, layout=layout)
rotate_ipw = ipw.Checkbox(description="Rotate Tip Coefficients",
                          value=True,
                          style=style, layout=layout)
orbstip_ipw = ipw.BoundedIntText(description="Maximal Tip Orbital",
                                 value=1, min=0, max=1, step=1,
                                 style=style, layout=layout)
# The tip DoS broadening is only editable for non-parametrized tips.
fwhmtip_ipw = ipw.BoundedFloatText(description=r"FWHM for DoS of Tip (\(e\)V)",
                                   value=0.00, min=0.00, max=1.0, step=0.01,
                                   disabled=(tiptype_ipw.value == 'parametrized'),
                                   style=style, layout=layout)

## Parametrized tip info
def _tip_coefficient_ipw(label, default):
    """Create a 0..1 coefficient box that is editable only for the parametrized tip."""
    return ipw.BoundedFloatText(description=label, value=default,
                                min=0.0, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value != 'parametrized'),
                                style=style, layout=layout)

stip_ipw = _tip_coefficient_ipw(r"\(s\)-Value", 0.1)
pytip_ipw = _tip_coefficient_ipw(r"\(p_y\)-Value", 0.5)
pztip_ipw = _tip_coefficient_ipw(r"\(p_z\)-Value", 0.0)
pxtip_ipw = _tip_coefficient_ipw(r"\(p_x\)-Value", 0.5)
para_list = [stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw, fwhmtip_ipw]

def para_values(value):
    """Toggle which tip inputs are editable for tip type ``value``.

    The s/py/pz/px coefficient boxes are enabled only for the 'parametrized' tip.
    The tip DoS FWHM box is enabled only for the other tip types.
    """
    parametrized = value == 'parametrized'
    for coeff_ipw in (stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw):
        coeff_ipw.disabled = not parametrized
    fwhmtip_ipw.disabled = parametrized

para_ipw = ipw.interactive(para_values, value=tiptype_ipw)
# Show
display(ipw.HBox([voltext_ipw, volmin_ipw, volmax_ipw, volstep_ipw], style=style, layout=layout),
        voltages_ipw)
display(fwhm_ipw, wfnstep_ipw, extrap_ipw,
        tiptype_ipw, rotate_ipw, orbstip_ipw, fwhmtip_ipw,
        stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw)

In [None]:
# Taken from AFM and slightly adjusted
def create_2pp_parameterdict(ase_geom):
    """Build the 2PP Probe Particle Model parameter dictionary for ``ase_geom``.

    Tip charges come from the selected RESP model (``tpp_resp_ipw``). The periodic
    grid is the cell of ``ase_geom``. The scan box spans [0, cell[0,0]] x
    [0, cell[1,1]] in x/y. In z it runs from the topmost atom height (rounded to
    0.1) plus the scan z_min value up to that height plus the scan z_max value.
    The box is sampled with step ``scandx_ipw`` in all three directions.
    """
    cell = ase_geom.cell
    z_top = np.round(np.max(ase_geom.positions[:, 2]), 1)
    step = scandx_ipw.value
    charges = tpp_resp_ipw.value  # [ChargeCuUp, ChargeCuDown, Ccharge, Ocharge]

    # Probe-particle tip model
    paramdict = {
        'Catom':        'Ctip',
        'Oatom':        'Otip',
        'ChargeCuUp':   charges[0],
        'ChargeCuDown': charges[1],
        'Ccharge':      charges[2],
        'Ocharge':      charges[3],
        'sigma':        0.7,
        'Cklat':        0.24600212465950813,
        'Oklat':        0.15085476515590224,
        'Ckrad':        20,
        'Okrad':        20,
        'rC0':          [0.0, 0.0, 1.82806112489999961213],
        'rO0':          [0.0, 0.0, 1.14881347770000097341],
    }
    # Periodic grid and scan box derived from the sample geometry
    paramdict.update({
        'PBC':          'True',
        'gridA':        list(cell[0]),
        'gridB':        list(cell[1]),
        'gridC':        list(cell[2]),
        'scanMin':      [0.0, 0.0, z_top + scanmin_ipw.value],
        'scanMax':      [cell[0, 0], cell[1, 1], z_top + scanmax_ipw.value],
        'scanStep':     [step] * 3,
        'Amplitude':    amp_ipw.value,
        'f0Cantilever': f0cantilever_ipw.value,
    })
    return paramdict

In [None]:
def find_struct_wf(structure):
    selected_comp = qb.all()[drop_computer.index][0]
    
    # check stm
    extras = structure.get_extras()
    for ex_k in extras.keys():
        if ex_k.startswith(('stm', 'pdos', 'afm', 'hrstm')):
            spm_workchain = load_node(extras[ex_k])
            cp2k_scf_calc = spm_workchain.get_outputs()[0]
            if cp2k_scf_calc.get_computer().hostname == selected_comp.hostname:
                wfn_path = cp2k_scf_calc.out.remote_folder.get_remote_path() + "/aiida-RESTART.wfn"
                # check if it exists
                file_exists = ! ssh {selected_comp.hostname} "if [ -f {wfn_path} ]; then echo 1 ; else echo 0 ; fi"
                if file_exists[0]:
                    print("Found .wfn from %s"%ex_k)
                    return wfn_path
    
    # check geo opt
    geo_opt_calc = structure.get_inputs()[0]
    geo_comp = geo_opt_calc.get_computer()
    if geo_comp.hostname == selected_comp.hostname:
        wfn_path = geo_opt_calc.out.remote_folder.get_remote_path() + "/aiida-RESTART.wfn"
        # check if it exists
        file_exists = ! ssh {selected_comp.hostname} "if [ -f {wfn_path} ]; then echo 1 ; else echo 0 ; fi"
        if file_exists[0]:
            print("Found .wfn from geo_opt")
            return wfn_path
    return ""

In [None]:
def on_submit(b):
    """Submit an HRSTMWorkChain for the selected structure using the current form values.

    Connected to ``btn_submit``; ``b`` is the clicked button and is not used. All
    messages, and the value returned by ``submit``, are printed into ``submit_out``.
    """
    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not drop_computer.value:
            print("Please select a computer.")
            return
        # A dropdown without options has index/value None.
        if drop_cp2k.index is None or drop_ppm.value is None or drop_hrstm.index is None:
            print("Please select a Cp2k, PPM and HR-STM code.")
            return

        cp2k_code = cp2k_codes[drop_cp2k.index]
        # drop_ppm lists only the "_2pp" subset of ppm_codes, so its index does not
        # line up with ppm_codes; its options carry the Code object as the value.
        ppm_code = drop_ppm.value
        hrstm_code = hrstm_codes[drop_hrstm.index]

        # External folders
        parent_dir = "parent_calc_folder/"
        ppm_dir = "ppm_calc_folder/"

        struct = struct_browser.results.value
        ase_geom = struct.get_ase()
        cell = ArrayData()
        # Only the diagonal of the cell is stored (NOTE(review): presumably assumes an orthorhombic cell).
        cell.set_array('cell', np.diag(ase_geom.cell))

        # PPM parameters
        ppm_params_dict = create_2pp_parameterdict(ase_geom)
        ppm_params = ParameterData(dict=ppm_params_dict)

        # PPM folder of position
        ppmQK = ppm_dir + "Q%1.2fK%1.2f/" % (ppm_params_dict['Ocharge'], ppm_params_dict['Oklat'])
        # Tip type to determine PDOS and PPM position files
        if tiptype_ipw.value != "parametrized":
            path = os.path.dirname(hrstm_code.get_remote_exec_path()) + "/tips/" + tiptype_ipw.value + "/"
            pdos_list = [path + "aiida-PDOS-list2-1.pdos", path + "aiida-PDOS-list1-1.pdos"]
            tip_pos = [ppmQK + "PPpos", ppmQK + "PPdisp"]
        else:  # Parametrized tip: s, py, pz, px coefficients
            pdos_list = [str(stip_ipw.value), str(pytip_ipw.value),
                         str(pztip_ipw.value), str(pxtip_ipw.value)]
            tip_pos = ppmQK + "PPdisp"

        # Bias voltages: evenly spaced from the lower to the upper slider value.
        # The point count is computed explicitly (with a small tolerance) because
        # arange(lo, hi + step, step) can emit a voltage beyond the selected range
        # due to floating-point rounding of the end point.
        v_lo, v_hi = voltages_ipw.value
        v_step = voltages_ipw.step
        n_volt = max(int(np.floor((v_hi - v_lo) / v_step + 1e-6)) + 1, 1)
        decimals = len(str(volstep_ipw.min).split('.')[-1])
        voltages = np.round(v_lo + v_step * np.arange(n_volt), decimals).tolist()

        # HRSTM parameters
        hrstm_params_dict = {
            '--output':          'hrstm',
            '--voltages':        voltages,
            # Sample information
            '--cp2k_input_file': parent_dir + 'aiida.inp',
            '--basis_set_file':  parent_dir + 'BASIS_MOLOPT',
            '--xyz_file':        parent_dir + 'geom.xyz',
            '--wfn_file':        parent_dir + 'aiida-RESTART.wfn',
            '--hartree_file':    parent_dir + 'aiida-HART-v_hartree-1_0.cube',
            # Energy window: the bias range widened by two sample FWHMs on each side
            '--emin':            str(v_lo - 2.0 * fwhm_ipw.value),
            '--emax':            str(v_hi + 2.0 * fwhm_ipw.value),
            '--fwhm_sam':        str(fwhm_ipw.value),
            '--dx_wfn':          str(wfnstep_ipw.value),
            '--extrap_dist':     str(extrap_ipw.value),
            # Tip information
            '--pdos_list':       pdos_list,
            '--orbs_tip':        str(orbstip_ipw.value),
            '--tip_shift':       str(ppm_params_dict["rC0"][2] + ppm_params_dict["rO0"][2]),
            '--tip_pos_files':   tip_pos,
            '--fwhm_tip':        str(fwhmtip_ipw.value),
        }
        if rotate_ipw.value:
            hrstm_params_dict['--rotate'] = ''
        hrstm_params = ParameterData(dict=hrstm_params_dict)

        ## Try to access the restart-wfn file ##
        # The lookup is best-effort: on failure, report why and submit without a wfn path.
        try:
            wfn_file_path = find_struct_wf(struct)
        except Exception as exc:
            print("Could not look up a restart .wfn file: %s" % exc)
            wfn_file_path = ""
        if wfn_file_path == "":
            print("Didn't find any accessible .wfn file.")

        outputs = submit(
            HRSTMWorkChain,
            cp2k_code=cp2k_code,
            structure=struct,
            cell=cell,
            wfn_file_path=Str(wfn_file_path),
            elpa_switch=Bool(elpa_check.value),
            ppm_code=ppm_code,
            ppm_params=ppm_params,
            hrstm_code=hrstm_code,
            hrstm_params=hrstm_params
        )

        print(outputs)

# Output area that receives validation messages and the submitted process.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
display(btn_submit, submit_out)