#  Submit Molecule on Slab Geometry Optimization

In [None]:
%aiida

In [None]:
# General imports
import nglview
import numpy as np
import ipywidgets as ipw
from collections import OrderedDict
from ase.data import vdw_radii
from ase import Atoms
from IPython.display import display, clear_output, HTML
import itertools

# AiiDA imports
from aiida_cp2k.utils import Cp2kInput
from aiida.orm import SinglefileData
from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget,  StructureBrowserWidget
from aiida_cp2k.workchains.base import Cp2kBaseWorkChain

# Custom imports
from apps.surfaces.widgets.metadata import MetadataWidget
from apps.surfaces.widgets import analyze_structure
from apps.surfaces.widgets.cp2k2dict import CP2K2DICT
from apps.surfaces.widgets.create_xyz_input_files import make_geom_file
from apps.surfaces.widgets.dft_details import DFTDetails
from apps.surfaces.widgets.viewer_details import ViewerDetails
from apps.surfaces.widgets.slab_validity import slab_is_valid
from apps.surfaces.widgets.cp2k_input_validity import validate_input #input_is_valid
from apps.surfaces.widgets.suggested_param import suggested_parameters
from apps.surfaces.widgets.get_cp2k_input import Get_CP2K_Input

from aiidalab_widgets_base import CodeDropdown, StructureManagerWidget, StructureBrowserWidget, StructureUploadWidget, SubmitButtonWidget, SmilesWidget

In [None]:
%%javascript
// Make OutputArea._should_scroll always return false, so long cell outputs are
// never collapsed into a scrolling sub-area (classic Jupyter Notebook front-end).
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# Structure picker: a structure can be uploaded from the local computer or
# chosen among the structures already stored in the AiiDA database.
structure_selector = StructureManagerWidget(
    storable=True,
    importers=[
        ("Import from computer", StructureUploadWidget()),
        ("AiiDA database", StructureBrowserWidget()),
    ],
)
display(structure_selector)

In [None]:
## GENERAL PARAMETERS
# Analysis of the current structure and the ASE Atoms object; both are
# (re)assigned by on_struct_change once a structure is selected.
slab_analyzed = None
atoms = None
# Name of the workchain this notebook prepares inputs for.
workchain = 'SlabGeoOptWorkChain'
# Calculation settings gathered from the widgets, seeded with the workchain.
details_dict = {'workchain': workchain}

In [None]:
# Free-text description attached to the submitted workchain and CP2K calculation.
process_description = ipw.Text(
    description='Process description: ',
    placeholder='Type the name here.',
    layout={'width': '70%'},
    style={'description_width': '120px'},
)

In [None]:
## WIDGETS MONITOR FUNCTIONS

def on_struct_change(c):
    """Analyse a newly selected structure and refresh the DFT settings.

    Observer of ``structure_selector``'s ``structure`` trait. It does nothing
    until a structure node is available. Otherwise the structure is made
    periodic in x, y and z and analysed with ``analyze_structure.analyze``.
    Only systems classified as ``'Slabxy'`` are accepted. For such a slab,
    ``details_dict`` is rebuilt from scratch (elements + workchain name) and
    the DFT widget is reset with the analysis. Either the analysis summary or
    an error message is printed into ``mol_ids_info_out``.

    ``c`` is the traitlets change object (unused).
    """
    global details_dict, slab_analyzed, atoms

    if not structure_selector.structure_node:
        return

    atoms = structure_selector.structure
    # NOTE(review): presumably this mutates the widget's own Atoms object.
    atoms.pbc = [1, 1, 1]
    slab_analyzed = analyze_structure.analyze(atoms)

    if slab_analyzed['system_type'] == 'Slabxy':
        details_dict = {
            'elements': slab_analyzed['all_elements'],
            'workchain': 'SlabGeoOptWorkChain',
        }
        dft_details_widget.reset(slab_analyzed)
        msg = slab_analyzed['summary']
    else:
        msg = 'Only XY Slabs are allowed'

    with mol_ids_info_out:
        clear_output()
        print(msg)
        
def on_fixed_atoms_btn_click(c):
    """Mirror the fixed-atoms toggle of the DFT widget in the structure viewer.

    While the toggle is pressed, the viewer selection shows the widget's
    fixed-atoms value; otherwise the selection is emptied.

    ``c`` is the clicked button (unused).
    """
    pressed = dft_details_widget.btn_fixed_pressed
    new_selection = dft_details_widget.fixed_atoms.value if pressed else set()
    structure_selector.viewer.selection = new_selection
        
        
#def guess_calc_params(slab_analyzed):
#    method = dft_details_widget.calc_type.value
#    valid_slab, msg = slab_is_valid(slab_analyzed,method)
#    if valid_slab:        
#        atoms_to_fix,num_nodes=suggested_parameters(slab_analyzed,method)
#        dft_details_widget.reset(fixed_atoms=atoms_to_fix,calc_type=method)
#    else:
#        print(msg)

In [None]:
## DISPLAY WIDGETS AND DEFINE JOB PARAMETERS

##STRUCTURE
# Re-run the slab analysis every time the selected structure changes.
structure_selector.observe(on_struct_change, names='structure')

##VIEWER
# Output area where on_struct_change prints the analysis summary or an error.
mol_ids_info_out = ipw.Output()
display(ipw.VBox(children=[mol_ids_info_out]))


##CODE
# Selector restricted to codes that use the 'cp2k' input plugin.
computer_code_dropdown = CodeDropdown(input_plugin='cp2k')

##DFT
# DFT settings widget; `slab_analyzed` is the last analysis stored by
# on_struct_change (None when no structure change has been handled yet).
dft_details_widget = DFTDetails(structure_details=slab_analyzed, workchain=workchain)
dft_details_widget.btn_fixed_atoms.on_click(on_fixed_atoms_btn_click)


##NUMBER NODES
metadata_widget = MetadataWidget()
# Seed the MPI-processes-per-machine field from the computer of the code that
# is pre-selected in the dropdown (if any). The computer may have no default
# configured (the getter then returns None); in that case the widget keeps its
# own default instead of being assigned None.
_preselected_code = computer_code_dropdown.selected_code
if _preselected_code:
    _default_mpiprocs = _preselected_code.computer.get_default_mpiprocs_per_machine()
    if _default_mpiprocs is not None:
        metadata_widget.num_mpiprocs_per_machine.value = _default_mpiprocs

def modify_mpiprocs_per_machine(change):
    """Sync the MPI-processes-per-machine field with the newly selected code.

    Observer of ``computer_code_dropdown.selected_code``. ``change['new']`` is
    the newly selected code. The field is left untouched in two cases:
    - The selection is cleared (falsy value); previously this raised
      AttributeError.
    - The code's computer has no default number of MPI processes per machine
      (the getter returns None).
    """
    code = change['new']
    if not code:
        return
    default_mpiprocs = code.computer.get_default_mpiprocs_per_machine()
    if default_mpiprocs is None:
        return
    metadata_widget.num_mpiprocs_per_machine.value = default_mpiprocs
    
# Keep the MPI-processes field in sync with the code selection.
computer_code_dropdown.observe(modify_mpiprocs_per_machine, names='selected_code')

##PLAIN TEXT INPUT
# Editable text area that receives the rendered CP2K input; it lives inside a
# collapsed accordion (selected_index=None) titled 'plain input'.
plain_input = ipw.Textarea(
    value='',
    disabled=False,
    layout={'width': '60%'},
)
plain_input_accordion = ipw.Accordion(selected_index=None)
plain_input_accordion.children = [plain_input]
plain_input_accordion.set_title(0, 'plain input')

##VALIDATE AND CREATE INPUT
create_input = ipw.Button(
    description='create input',
    layout={'width': '10%'},
)



def get_plain_input():
    """Assemble the Cp2kBaseWorkChain builder from the plain-text CP2K input.

    The text in ``plain_input`` is parsed with ``CP2K2DICT``. On success, a
    builder is filled with:
    - the selected code;
    - the selected structure written to ``mol_on_slab.xyz``;
    - the scheduler metadata from ``metadata_widget``;
    - labels and description;
    - the parsed CP2K parameters.
    When ``details_dict['calc_type']`` is not ``'Full DFT'``, the slab
    potential file and an xyz file containing only the molecule atoms are
    attached as well.

    Used as the ``widgets_values`` callable of the submit button.

    Returns:
        The populated process builder, or ``None`` when the plain input could
        not be parsed or no structure is selected (the reason is printed).
    """
    # Imported explicitly rather than relying on names injected by %aiida.
    from aiida.orm import Dict, List, Str

    error_msg, submit_dict = CP2K2DICT(input_lines=plain_input.value)
    if error_msg != "":
        print(error_msg)
        # Previously execution fell through to `return input_builder` with
        # `input_builder` never assigned (UnboundLocalError).
        return None

    # The old `struct_browser` widget is commented out in this notebook, so
    # the structure must come from `structure_selector`.
    structure = structure_selector.structure_node
    if structure is None:
        print("Please select a structure first.")
        return None

    input_builder = Cp2kBaseWorkChain.get_builder()

    # code
    input_builder.cp2k.code = computer_code_dropdown.selected_code
    # structure
    input_builder.cp2k.file.input_xyz = make_geom_file(structure, Str("mol_on_slab.xyz"))

    if details_dict['calc_type'] != 'Full DFT':
        # slab potential: only the first entry of `slab_elements` is used
        slab_element = list(slab_analyzed['slab_elements'])[0]
        input_builder.cp2k.file.pot_f = SinglefileData(
            file='/home/aiida/apps/surfaces/slab/' + slab_element + '.pot')
        # flatten the per-molecule index lists into one selection
        mol_indexes = list(itertools.chain.from_iterable(slab_analyzed['all_molecules']))
        input_builder.cp2k.file.mol_xyz = make_geom_file(structure,
                                                         Str("mol.xyz"),
                                                         selection=List(list=mol_indexes))

    input_builder.cp2k.metadata = metadata_widget.dict
    input_builder.metadata.label = "SlabGeoOptWorkChain"
    input_builder.metadata.description = process_description.value

    input_builder.cp2k.metadata['label'] = "slab_geo_opt"
    input_builder.cp2k.metadata['description'] = process_description.value
    input_builder.cp2k.parameters = Dict(dict=submit_dict)
    return input_builder
    
def on_create_input_btn_click(c):
    """Render the CP2K input from the widgets and gate the submit button.

    The steps are:
    1. Collect the DFT settings into ``details_dict`` via
       ``dft_details_widget.get_widget_values``.
    2. Add the total number of MPI tasks (machines x processes per machine)
       and the walltime from ``metadata_widget``.
    3. Render the input with ``Get_CP2K_Input``/``Cp2kInput`` into
       ``plain_input``.
    4. Run ``validate_input``. The submit button is enabled only when
       validation passes; otherwise the validation message is printed.

    ``c`` is the clicked button (unused).
    """
    # Without an analysed structure there is nothing to build the input for
    # (presumably Get_CP2K_Input/validate_input fail without it — TODO
    # confirm), so keep submission blocked instead of crashing.
    if slab_analyzed is None:
        btn_submit.btn_submit.disabled = True
        print("Please select a structure first.")
        return

    dft_details_widget.get_widget_values(details_dict)

    options = metadata_widget.dict['options']
    resources = options['resources']
    details_dict['mpi_tasks'] = resources['num_machines'] * resources['num_mpiprocs_per_machine']
    details_dict['walltime'] = options['max_wallclock_seconds']

    inp_dict = Get_CP2K_Input(input_dict=details_dict).inp
    plain_input.value = Cp2kInput(inp_dict).render()

    can_submit, error_msg = validate_input(slab_analyzed, details_dict)
    btn_submit.btn_submit.disabled = not can_submit
    if not can_submit:
        print(error_msg)


create_input.on_click(on_create_input_btn_click)

##DISPLAY
# Job set-up panel, top to bottom: code, DFT settings, editable plain input,
# process description, scheduler resources and the "create input" button.
display(ipw.VBox(children=[
    computer_code_dropdown,
    dft_details_widget,
    plain_input_accordion,
    process_description,
    metadata_widget,
    create_input,
]))


#display submit button
# Pressing the button calls get_plain_input to obtain the Cp2kBaseWorkChain builder.
btn_submit = SubmitButtonWidget(process=Cp2kBaseWorkChain, widgets_values=get_plain_input)
# Start disabled: on_create_input_btn_click enables the button only once the
# generated input passes validate_input, so that check cannot be skipped by
# submitting before "create input" was ever pressed.
btn_submit.btn_submit.disabled = True
display(btn_submit)