In [None]:
%aiida

In [None]:

from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import Code, Computer
from aiida.orm.querybuilder import QueryBuilder
from aiida.engine import submit, run, run_get_node

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from apps.scanning_probe.orb.orb_workchain import OrbitalWorkChain

from apps.scanning_probe import common

from apps.surfaces.widgets import analyze_structure
from apps.surfaces.widgets.viewer_details import ViewerDetails

from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, MetadataWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown

# Select structure

In [None]:
# Module-level state updated by on_struct_change.
atoms = None
slab_analyzed = None


def on_struct_change(c):
    """Load the newly selected structure into the viewer and the cell box.

    Observer callback for ``struct_browser.results``; the change record ``c``
    is not inspected -- the structure is read back from the widget instead.
    Sets the global ``atoms`` to an ASE copy of the structure with periodic
    boundary conditions switched on along all three axes, and writes the
    diagonal of its cell matrix into ``cell_text`` (created in a later cell).
    Does nothing when no structure is selected.
    """
    global atoms, slab_analyzed
    structure = struct_browser.results.value
    if not structure:
        return

    atoms = structure.get_ase()
    atoms.pbc = [1, 1, 1]
    # Slab analysis is disabled; `slab_analyzed` keeps its initial None.
    # slab_analyzed = find_mol.analyze_slab(atoms)
    viewer_widget.setup(atoms)

    # Only the diagonal is used, i.e. off-diagonal cell components are dropped.
    cell_lengths = np.diag(atoms.cell)
    cell_text.value = " ".join(map(str, cell_lengths))

    
struct_browser = StructureBrowserWidget()
viewer_widget = ViewerDetails()

# Refresh the 3D view and the cell box whenever a different structure is picked.
struct_browser.results.observe(on_struct_change, names='value')

display(ipw.VBox(children=[struct_browser, viewer_widget]))


# Select computer and codes

In [None]:
computer_drop = ComputerDropdown()


def on_computer_change(c):
    """Refresh the Cp2k and STM code dropdowns for the selected computer.

    Observer callback; ``c`` (the change record) is unused.  Sets the globals
    ``cp2k_codes`` / ``stm_codes`` to the code lists returned by
    ``common.comp_plugin_codes`` for the 'cp2k' and 'spm.stm' plugins and
    mirrors their labels into ``drop_cp2k`` / ``drop_stm`` (so a dropdown
    index maps directly back into the corresponding list).

    When no computer is selected both lists are emptied instead of failing
    on ``None.name``.
    """
    global cp2k_codes, stm_codes
    computer = computer_drop.selected_computer
    if not computer:
        cp2k_codes, stm_codes = [], []
    else:
        cp2k_codes = common.comp_plugin_codes(computer.name, 'cp2k')
        stm_codes = common.comp_plugin_codes(computer.name, 'spm.stm')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    drop_stm.options = [code.label for code in stm_codes]
    
    
# Only react to actual selection changes; without `names` the callback also
# runs for every other trait change of the dropdown (index, label, options, ...),
# re-querying the codes several times per click.
# NOTE(review): `_dropdown` is a private attribute of ComputerDropdown and may
# change between aiidalab_widgets_base versions.
computer_drop._dropdown.observe(on_computer_change, names='value')

drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_stm = ipw.Dropdown(description="STM code")

# Populate the code dropdowns once for the initially selected computer.
on_computer_change(0)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(computer_drop, drop_cp2k, drop_stm, elpa_check)

# SCF parameters

In [None]:
# Label width shared by the widgets in this cell and in the spin cell below.
style = dict(description_width='80px')
layout = dict(width='70%')

# Orthorhombic cell lengths, space separated; filled in by on_struct_change.
cell_text = ipw.Text(
    description='cell',
    placeholder='10.0 20.0 30.0',
    layout={'width': '370px'},
    style=style,
)
display(cell_text)

In [None]:
def enable_spin(b):
    """Enable the spin-guess inputs only while ``uks_switch`` is on.

    Observer for ``uks_switch``'s value; ``b`` (the change record) is unused.
    """
    disabled = not uks_switch.value
    spin_widgets = (spin_up_text, spin_dw_text, vis_spin_button, multiplicity_text)
    for widget in spin_widgets:
        widget.disabled = disabled

def visualize_spin_guess(b):
    """Highlight the spin-up (red) and spin-down (blue) guess atoms in the viewer.

    Button callback; ``b`` is unused.  The text boxes hold whitespace-separated
    1-based atom indices, which are converted to 0-based before highlighting.
    """
    def to_zero_based(text):
        return [int(token) - 1 for token in text.split()]

    spin_up = to_zero_based(spin_up_text.value)
    spin_dw = to_zero_based(spin_dw_text.value)
    viewer_widget.reset()
    for indices, color in ((spin_up, 'red'), (spin_dw, 'blue')):
        viewer_widget.highlight_atoms(indices, color=color, size=0.3, opacity=0.4)

# Toggle for an unrestricted (UKS) spin-polarized calculation; enable_spin
# switches the spin-guess inputs below on/off with it.
uks_switch = ipw.ToggleButton(value=False,
                              description='Spin-polarized calculation',
                              tooltip='Enable a spin-polarized (UKS) calculation',
                              style=style, layout={'width': '450px'})
uks_switch.observe(enable_spin, names='value')

# 1-based atom indices used as the initial spin-up / spin-down guess.
spin_up_text = ipw.Text(placeholder='1 2 3',
                        description='Spin up',
                        disabled=True,
                        style=style, layout={'width': '370px'})
spin_dw_text = ipw.Text(placeholder='1 2 3',
                        description='Spin down',
                        disabled=True,
                        style=style, layout={'width': '370px'})
vis_spin_button = ipw.Button(description="Visualize",
                             disabled=True,
                             style={'description_width': '0px'}, layout={'width': '75px'})
vis_spin_button.on_click(visualize_spin_guess)

# Spin multiplicity is at least 1; a bounded widget rejects 0 / negative input
# (BoundedIntText keeps its default upper bound).
multiplicity_text = ipw.BoundedIntText(value=1,
                                       min=1,
                                       description='Multiplicity',
                                       disabled=True,
                                       style=style, layout={'width': '20%'})


display(uks_switch, ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]), multiplicity_text)

# Orbital parameters

In [None]:
style = {'description_width': '140px'}
layout = {'width': '40%'}
layout_small = {'width': '25%'}

# BoundedIntText rather than IntText: IntText has no min/max traits, so the
# 0..100 bounds passed to it were never enforced.
n_homo_inttext = ipw.BoundedIntText(
                        description='num HOMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)
n_lumo_inttext = ipw.BoundedIntText(
                        description='num LUMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)

# Space-separated lists; each is split and forwarded to the STM/orbital step.
heights_text = ipw.Text(description='Heights (ang)',
                        value='3.0 5.0',
                        style=style, layout=layout)

isovals_text = ipw.Text(description='Isovalues',
                        value='1e-7',
                        style=style, layout=layout)

fwhms_text = ipw.Text(description='FWHMs (eV)',
                      value='0.04',
                      style=style, layout=layout)

extrap_plane_floattext = ipw.BoundedFloatText(
                        description='Extrap plane (ang)',
                        min=1.0,
                        max=10.0,
                        step=0.1,
                        value=3.5,
                        style=style, layout=layout_small)


display(n_homo_inttext, n_lumo_inttext, heights_text, isovals_text, fwhms_text, extrap_plane_floattext)

In [None]:
def file_exists_func(hostname, path):
    """Check over SSH whether ``path`` is a regular file on ``hostname``.

    Passed as the existence check to ``common.find_struct_wf``.

    Args:
        hostname: SSH destination as understood by the local ``ssh`` client.
        path: file path on the remote host.  It is quoted for the remote
            shell, so it is taken literally (no ``~``/``$VAR`` expansion).

    Returns:
        True only if the remote ``[ -f ... ]`` test prints ``1``.  False if
        the file is missing, or if ssh itself fails (non-zero exit status,
        missing ``ssh`` binary) -- a failed connection is not a found file.
    """
    import shlex
    import subprocess

    remote_cmd = "if [ -f {} ]; then echo 1 ; else echo 0 ; fi".format(shlex.quote(str(path)))
    try:
        # Argument list: no local shell is involved, so nothing is re-parsed locally.
        result = subprocess.run(
            ["ssh", hostname, remote_cmd],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
    except OSError:
        return False
    # ssh exits with the remote command's status (0 here) or 255 on its own errors.
    if result.returncode != 0:
        return False
    lines = result.stdout.strip().splitlines()
    # The marker is the last stdout line; login banners/MOTD may precede it.
    return bool(lines) and lines[-1].strip() == "1"

def _parse_cell(text):
    """Parse exactly three strictly positive, finite cell lengths from ``text``.

    Returns a float numpy array of length 3, or None if ``text`` is anything else.
    """
    try:
        cell = np.array(text.split(), dtype=float)
    except ValueError:
        return None
    # Compare element-wise; a bool from `any(cell)` would not test the lengths.
    if len(cell) != 3 or not np.all(np.isfinite(cell)) or np.any(cell <= 0.0):
        return None
    return cell


def _parse_atom_indices(text):
    """Convert whitespace-separated 1-based atom indices to 0-based ints.

    Raises ValueError on a non-integer token or an index below 1.
    """
    indices = [int(v) - 1 for v in text.split()]
    if any(i < 0 for i in indices):
        raise ValueError("atom indices are 1-based and must be >= 1")
    return indices


def on_submit(b):
    """Validate the form inputs and submit an OrbitalWorkChain.

    Button callback for ``btn_submit``; ``b`` (the clicked button) is unused.
    Validation messages and the submission result are printed into
    ``submit_out``.  Reads the widgets defined in the cells above and the
    ``cp2k_codes`` / ``stm_codes`` globals set by ``on_computer_change``.
    The workchain is handed to the AiiDA daemon with ``submit`` so the kernel
    is not blocked for the whole DFT + orbital workflow.
    """
    # Import explicitly instead of relying on the names `%aiida` injects.
    from aiida.orm import Dict, Str

    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not computer_drop.selected_computer:
            print("Please select a computer.")
            return
        # `index` is None when the selected computer has no matching code.
        if drop_cp2k.index is None or drop_stm.index is None:
            print("Please select a Cp2k code and an STM code.")
            return

        cell = _parse_cell(cell_text.value)
        if cell is None:
            print("Invalid cell")
            return

        dft_params_dict = {
            'mgrid_cutoff':    600,
            'elpa_switch':     elpa_check.value,
            'cell':            cell.tolist(),
            'uks':             uks_switch.value
        }
        if uks_switch.value:
            try:
                dft_params_dict['spin_up_guess'] = _parse_atom_indices(spin_up_text.value)
                dft_params_dict['spin_dw_guess'] = _parse_atom_indices(spin_dw_text.value)
            except ValueError:
                print("Invalid spin guess: use whitespace-separated atom indices starting from 1.")
                return
            dft_params_dict['multiplicity'] = multiplicity_text.value
        dft_params = Dict(dict=dft_params_dict)

        extrap_plane = extrap_plane_floattext.value
        # Files of the parent (DFT) calculation as seen from the STM step.
        parent_dir = "parent_calc_folder/"
        stm_params = Dict(dict={
            '--cp2k_input_file':    parent_dir+'aiida.inp',
            '--basis_set_file':     parent_dir+'BASIS_MOLOPT',
            '--xyz_file':           parent_dir+'geom.xyz',
            '--wfn_file':           parent_dir+'aiida-RESTART.wfn',
            '--hartree_file':       parent_dir+'aiida-HART-v_hartree-1_0.cube',
            '--orb_output_file':    'orb.npz',
            '--eval_region':        ['G', 'G', 'G', 'G', 'n-1.0_C', 'p%.1f'%extrap_plane],
            '--dx':                 '0.15',
            '--eval_cutoff':        '14.0',
            '--extrap_extent':      '5.0',
            '--n_homo':             str(n_homo_inttext.value),
            '--n_lumo':             str(n_lumo_inttext.value),
            '--orb_heights':        heights_text.value.split(),
            '--orb_isovalues':      isovals_text.value.split(),
            '--orb_fwhms':          fwhms_text.value.split(),
        })

        cp2k_code = cp2k_codes[drop_cp2k.index]
        stm_code = stm_codes[drop_stm.index]

        struct = struct_browser.results.value

        ## Try to access the restart-wfn file ##
        if uks_switch.value:
            # A restart .wfn is never re-used for UKS, so skip the remote search.
            print("Not re-using .wfn for UKS calculation.")
            wfn_file_path = ""
        else:
            wfn_file_path = common.find_struct_wf(struct, cp2k_code.computer, file_exists_func) or ""
            if not wfn_file_path:
                print("Didn't find any accessible .wfn file.")

        node = submit(
            OrbitalWorkChain,
            cp2k_code=cp2k_code,
            structure=struct,
            wfn_file_path=Str(wfn_file_path),
            dft_params=dft_params,
            stm_code=stm_code,
            stm_params=stm_params
        )

        # set calculation version; also used to determine post-processing
        node.set_extra("version", 0)
        print("Submitted OrbitalWorkChain <{}>".format(node.pk))
        

# Output area for validation messages; created first so the callback can use it.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
display(btn_submit, submit_out)