In [None]:
%aiida
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Load the AiiDA database environment for the profile named in the backend
# settings, unless is_dbenv_loaded() reports that it is already loaded.
# NOTE(review): the %aiida magic above presumably loads the profile as well -
# confirm whether this guard is still required.
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

In [None]:
from aiida.orm.data.structure import StructureData
from aiida.orm.data.parameter import ParameterData
from aiida.orm.data.array import ArrayData
from aiida.orm.data.base import Int, Str, Float, Bool
from aiida.orm.data.remote import RemoteData
from aiida.work import workfunction
from aiida.work.process import WorkCalculation
from aiida.work.run import submit
from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import Code, Computer
from aiida.orm.querybuilder import QueryBuilder

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from apps.scanning_probe.orb.orb_workchain import OrbitalWorkChain

from apps.scanning_probe import common

from apps.surfaces.widgets import find_mol
from apps.surfaces.widgets.viewer_details import ViewerDetails

# Select structure

In [None]:
from apps.surfaces.structure_browser import StructureBrowser

# Notebook-level state updated by the structure-selection callback.
atoms, slab_analyzed = None, None


def on_struct_change(c):
    """React to a new selection in the structure browser.

    Stores the ASE representation of the selected structure in the global
    ``atoms`` (with periodic boundary conditions switched on in all three
    directions), hands it to the viewer widget and writes the diagonal of the
    cell matrix into the cell text box. Nothing happens while no structure is
    selected.

    c: change notification supplied by ``observe`` (not used).
    """
    global atoms, slab_analyzed
    selected = struct_browser.results.value
    if not selected:
        return

    atoms = selected.get_ase()
    atoms.pbc = [1, 1, 1]

    # slab analysis (find_mol.analyze_slab) is currently switched off
    viewer_widget.setup(atoms)

    cell_diagonal = np.diag(atoms.cell)
    cell_text.value = " ".join(str(length) for length in cell_diagonal)

    
# Structure browser: every change of its selected result is forwarded to
# on_struct_change.
struct_browser = StructureBrowser()
struct_browser.results.observe(on_struct_change, names='value')

# Viewer shown underneath the browser.
viewer_widget = ViewerDetails()

browser_and_viewer = [struct_browser, viewer_widget]
display(ipw.VBox(browser_and_viewer))


# Select computer and codes

In [None]:
# Collect the names of all enabled computers for the selection dropdown.
qb = QueryBuilder()
qb.append(Computer, project='name', filters={'enabled': True})
computer_names = []
for row in qb.all():
    # each row holds the single projected column: the computer name
    computer_names.append(row[0])

style = dict(description_width='120px')
layout = dict(width='70%')
drop_computer = ipw.Dropdown(options=computer_names,
                             description="Computer")

def comp_plugin_codes(computer_name, plugin_name):
    """Return the codes of one input plugin that belong to one computer.

    Queries all codes attached to an enabled computer whose input plugin is
    ``plugin_name`` and which either have no ``hidden`` extra or have it set
    to False, ordered by descending code id. Only the codes whose computer
    name equals ``computer_name`` are returned.

    computer_name: name of the computer the codes must belong to.
    plugin_name:   value of the code attribute ``input_plugin`` to match.
    Returns a list of Code objects (possibly empty).
    """
    not_hidden = [{'extras': {'!has_key': 'hidden'}}, {'extras.hidden': False}]
    code_filters = {
        'attributes.input_plugin': plugin_name,
        'or': not_hidden,
    }

    query = QueryBuilder()
    query.append(Computer, tag='computer', project='name', filters={'enabled': True})
    query.append(Code, has_computer='computer', project='*', filters=code_filters)
    query.order_by({Code: {'id': 'desc'}})

    # Every result row is [computer name, code]; keep the codes of the
    # requested computer only.
    return [row[1] for row in query.all() if row[0] == computer_name]

def on_computer_change(c):
    """Reload the code lists for the currently selected computer.

    Refreshes the globals ``cp2k_codes`` and ``stm_codes`` with the codes of
    the 'cp2k' and 'spm.stm' plugins on the selected computer and shows their
    labels in the two code dropdowns.

    c: change notification supplied by ``observe`` (not used).
    """
    global cp2k_codes, stm_codes
    selected_computer = drop_computer.value
    cp2k_codes = comp_plugin_codes(selected_computer, 'cp2k')
    stm_codes = comp_plugin_codes(selected_computer, 'spm.stm')

    # The dropdowns list labels only; the code objects are looked up later by
    # dropdown index in the two global lists.
    for dropdown, codes in ((drop_cp2k, cp2k_codes), (drop_stm, stm_codes)):
        dropdown.options = [code.label for code in codes]
    
    
# React to changes of the selected value only. Without ``names`` the handler
# is registered for every trait of the dropdown, so one selection change
# triggers it (and its database queries) several times.
drop_computer.observe(on_computer_change, names='value')

drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_stm = ipw.Dropdown(description="STM code")

# Fill the code dropdowns for the initially selected computer; the handler
# ignores its argument.
on_computer_change(0)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(drop_computer, drop_cp2k, drop_stm, elpa_check)

# SCF parameters

In [None]:
# Text box holding the three cell lengths. It is filled in by the structure
# selection callback and parsed again when the calculation is submitted.
style = dict(description_width='80px')
layout = dict(width='70%')

cell_text = ipw.Text(description='cell',
                     placeholder='10.0 20.0 30.0',
                     layout=dict(width='370px'),
                     style=style)
display(cell_text)

In [None]:
def enable_spin(b):
    """Enable the spin-related inputs only while the spin-polarized switch is on.

    b: change notification supplied by ``observe`` (not used).
    """
    spin_inputs_disabled = not uks_switch.value
    for widget in (spin_up_text, spin_dw_text, vis_spin_button, multiplicity_text):
        widget.disabled = spin_inputs_disabled

def visualize_spin_guess(b):
    """Highlight the atoms listed in the spin-up and spin-down text boxes.

    Each box is read as whitespace-separated integers; every number is
    decreased by one before it is passed to the viewer (presumably the user
    enters 1-based atom numbers - confirm against the viewer). The viewer is
    reset first, then the spin-up atoms are marked red and the spin-down
    atoms blue.

    b: the clicked button (not used).
    """
    def shifted_indices(text):
        return [int(token) - 1 for token in text.split()]

    up_indices = shifted_indices(spin_up_text.value)
    down_indices = shifted_indices(spin_dw_text.value)

    viewer_widget.reset()
    for indices, color in ((up_indices, 'red'), (down_indices, 'blue')):
        viewer_widget.highlight_atoms(indices, color=color, size=0.3, opacity=0.4)

# Switch for a spin-polarized calculation; its value is submitted as the
# 'uks' DFT parameter. The tooltip previously read 'VDW_POTENTIAL', which
# does not describe this switch.
uks_switch = ipw.ToggleButton(value=False,
                              description='Spin-polarized calculation',
                              tooltip='Unrestricted Kohn-Sham (UKS)',
                              style=style, layout={'width': '450px'})
# The spin inputs below start disabled and follow the state of the switch.
uks_switch.observe(enable_spin, names='value')

spin_up_text = ipw.Text(placeholder='1 2 3',
                        description='Spin up',
                        disabled=True,
                        style=style, layout={'width': '370px'})
spin_dw_text = ipw.Text(placeholder='1 2 3',
                        description='Spin down',
                        disabled=True,
                        style=style, layout={'width': '370px'})
vis_spin_button = ipw.Button(description="Visualize",
                             disabled=True,
                             style = {'description_width': '0px'}, layout={'width': '75px'})
vis_spin_button.on_click(visualize_spin_guess)

multiplicity_text = ipw.IntText(value=1,
                           description='Multiplicity',
                           disabled=True,
                           style=style, layout={'width': '20%'})


display(uks_switch, ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]), multiplicity_text)

# Orbital parameters

In [None]:
style = {'description_width': '140px'}
layout = {'width': '40%'}
layout_small = {'width': '25%'}

# BoundedIntText (not IntText) so that min/max are real traits and the
# entered value is kept inside [0, 100]; plain IntText has no such bounds.
n_homo_inttext = ipw.BoundedIntText(
                        description='num HOMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)
n_lumo_inttext = ipw.BoundedIntText(
                        description='num LUMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)

# Whitespace-separated heights; the text is split and passed on as strings
# when the calculation is submitted.
heights_text = ipw.Text(description='Heights (ang)',
                              value='3.0 5.0',
                              style=style, layout=layout)

extrap_plane_floattext = ipw.BoundedFloatText(
                        description='Extrap plane (ang)',
                        min=1.0,
                        max=10.0,
                        step=0.1,
                        value=3.0,
                        style=style, layout=layout_small)


display(n_homo_inttext, n_lumo_inttext, heights_text, extrap_plane_floattext)

In [None]:
def find_struct_available_wfn(structure_node, current_hostname):
    # can't move to python module because of "!"
    qb = QueryBuilder()
    qb.append(Computer, filters={'enabled': True})
    
    # check spm
    extras = structure_node.get_extras()
    for ex_k in extras.keys():
        if ex_k.startswith(('stm', 'pdos', 'afm', 'orb')):
            spm_workchain = load_node(extras[ex_k])
            cp2k_scf_calc = spm_workchain.get_outputs()[0]
            if cp2k_scf_calc.get_computer().hostname == current_hostname:
                wfn_path = cp2k_scf_calc.out.remote_folder.get_remote_path() + "/aiida-RESTART.wfn"
                # check if it exists
                file_exists = ! ssh {current_hostname} "if [ -f {wfn_path} ]; then echo 1 ; else echo 0 ; fi"
                if file_exists[0]:
                    print("Found .wfn from %s"%ex_k)
                    return wfn_path
    
    # check geo opt
    if len(structure_node.get_inputs()) > 0:
        geo_opt_calc = structure_node.get_inputs()[0]
        geo_comp = geo_opt_calc.get_computer()
        if geo_comp.hostname == current_hostname:
            wfn_path = geo_opt_calc.out.remote_folder.get_remote_path() + "/aiida-RESTART.wfn"
            # check if it exists
            file_exists = ! ssh {current_hostname} "if [ -f {wfn_path} ]; then echo 1 ; else echo 0 ; fi"
            if file_exists[0]:
                print("Found .wfn from geo_opt")
                return wfn_path
    return ""

def on_submit(b):
    """Validate the form and submit an OrbitalWorkChain.

    Reads the selected structure, computer, codes, cell, spin settings and
    orbital parameters from the widgets of this notebook. Any invalid input
    is reported in the output area and nothing is submitted. Otherwise the
    DFT and STM parameter nodes are built, a restart wavefunction is looked
    up on the target computer (the submission goes ahead without one if none
    is found) and the workchain is submitted.

    b: the clicked button (not used).
    """
    with submit_out:
        clear_output()
        struct = struct_browser.results.value
        if not struct:
            print("Please select a structure.")
            return
        if not drop_computer.value:
            print("Please select a computer.")
            return
        # A dropdown without options has index None, which cannot be used to
        # look up the code object in the code lists.
        if drop_cp2k.index is None or drop_stm.index is None:
            print("Please select a Cp2k code and an STM code.")
            return

        try:
            cell = np.array(cell_text.value.split(), dtype=float)
        except ValueError:
            # non-numeric text in the cell box
            print("Invalid cell")
            return

        # Every single length must be positive. (The former test
        # ``any(cell) <= 0.0`` compared one boolean with 0.0 and therefore
        # only rejected a cell whose entries were all zero.)
        if len(cell) != 3 or not (cell > 0.0).all():
            print("Invalid cell")
            return

        dft_params_dict = {
            'mgrid_cutoff':    600,
            'elpa_switch':     elpa_check.value,
            'cell':            list(cell),
            'uks':             uks_switch.value
        }
        if uks_switch.value:
            try:
                # the entered numbers are shifted down by one
                spin_up_guess = [int(v)-1 for v in spin_up_text.value.split()]
                spin_dw_guess = [int(v)-1 for v in spin_dw_text.value.split()]
            except ValueError:
                print("Invalid spin guess: expected space-separated integers.")
                return
            dft_params_dict['spin_up_guess'] = spin_up_guess
            dft_params_dict['spin_dw_guess'] = spin_dw_guess
            dft_params_dict['multiplicity']  = multiplicity_text.value
        dft_params = ParameterData(dict=dft_params_dict)

        extrap_plane = extrap_plane_floattext.value
        parent_dir = "parent_calc_folder/"
        stm_params = ParameterData(dict={
            '--cp2k_input_file':    parent_dir+'aiida.inp',
            '--basis_set_file':     parent_dir+'BASIS_MOLOPT',
            '--xyz_file':           parent_dir+'geom.xyz',
            '--wfn_file':           parent_dir+'aiida-RESTART.wfn',
            '--hartree_file':       parent_dir+'aiida-HART-v_hartree-1_0.cube',
            '--orb_output_file':    'orb.npz',
            '--emin':               '-2.0',
            '--emax':               '2.0',
            '--eval_region':        ['G', 'G', 'G', 'G', 'n-1.0_C', 'p%.1f'%extrap_plane],
            '--dx':                 '0.15',
            '--eval_cutoff':        '14.0',
            '--extrap_extent':      '5.0',
            '--orb_heights':        heights_text.value.split(),
            '--n_homo_ch':          str(n_homo_inttext.value),
            '--n_lumo_ch':          str(n_lumo_inttext.value),
        })

        # The dropdowns list labels only; the index selects the code object.
        cp2k_code = cp2k_codes[drop_cp2k.index]
        stm_code = stm_codes[drop_stm.index]

        ## Try to access the restart-wfn file ##
        hostname = cp2k_code.get_computer().hostname
        wfn_file_path = find_struct_available_wfn(struct, hostname)
        if wfn_file_path == "":
            # not fatal: the workchain is submitted with an empty path
            print("Didn't find any accessible .wfn file.")

        outputs = submit(OrbitalWorkChain,
                 cp2k_code=cp2k_code,
                 structure=struct,
                 wfn_file_path=Str(wfn_file_path),
                 dft_params=dft_params,
                 stm_code=stm_code,
                 stm_params=stm_params
                )

        print(outputs)

# Output area that on_submit prints into, and the button that triggers it.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)

display(btn_submit, submit_out)