# Submit Molecule Geometry Optimization

In [None]:
%aiida

In [None]:
# General imports
import nglview
import numpy as np
import ipywidgets as ipw
from collections import OrderedDict
from ase.data import vdw_radii
from IPython.display import display, clear_output, HTML

# AiiDA imports
from aiida.orm import SinglefileData
from aiidalab_widgets_base import CodeDropdown
from aiida_cp2k.workchains.base import Cp2kBaseWorkChain

# Custom imports
from apps.surfaces.structure_browser import StructureBrowser
from apps.surfaces.widgets import find_mol
from apps.surfaces.widgets.computer_code_selection import ComputerCodeDropdown
from apps.surfaces.widgets.dft_details_dev import DFTDetails
from apps.surfaces.widgets.viewer_details import ViewerDetails
from apps.surfaces.widgets.slab_validity import slab_is_valid
from apps.surfaces.widgets.suggested_param import suggested_parameters
from apps.surfaces.widgets.submit_button_dev import SubmitButton
from apps.surfaces.widgets.metadata import MetadataWidget
from apps.surfaces.widgets.get_cp2k_input_dev import Get_CP2K_Input

In [None]:
%%javascript
// Override the notebook's output-area scroll check so it always answers
// "no" — presumably to keep long cell outputs (the widgets below) fully
// expanded instead of collapsing them into a scroll box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
## GENERAL PARAMETERS
# Job description shared with the widgets; filled in further as the user
# selects a structure and DFT options.
job_details = {
    'workchain': 'MoleculeOptWorkChain',
    'calc_type': 'Full DFT',
}
# Both are (re)assigned by on_struct_change once a structure is chosen.
slab_analyzed = None
atoms = None

In [None]:
## WIDGETS MONITOR FUNCTIONS
def on_struct_change(c):
    """Analyse the newly selected structure and refresh the dependent widgets.

    Observer for ``struct_browser.results``; the change event ``c`` is not
    used. Stores the ASE atoms (made periodic in all directions) and the
    ``find_mol.analyze_slab`` result in the module globals ``atoms`` and
    ``slab_analyzed`` and in ``job_details``. It then resets the DFT cell to
    the diagonal of the analysed cell, redraws the viewer and prints the
    analysis summary.
    """
    global job_details, slab_analyzed, atoms
    selected = struct_browser.results.value
    if not selected:
        return

    atoms = selected.get_ase()
    atoms.pbc = [1, 1, 1]
    slab_analyzed = find_mol.analyze_slab(atoms)
    job_details['slab_analyzed'] = slab_analyzed
    job_details['elements'] = slab_analyzed['all_elements']

    # Only the diagonal of the analysed cell is passed on, as "a b c".
    the_cell = slab_analyzed['the_cell']
    diagonal = ' '.join(str(the_cell[k][k]) for k in range(3))
    dft_details_widget.reset(cell=diagonal)

    viewer_widget.setup(atoms, slab_analyzed)

    with mol_ids_info_out:
        clear_output()
        print(slab_analyzed['summary'])

def on_fixed_atoms_btn_click(c):
    """Toggle the fixed-atom highlighting in the viewer.

    Click handler for ``dft_details_widget.btn_fixed_atoms``; the button
    argument ``c`` is unused. While the button is pressed, the atoms listed
    in the ``fixed_atoms`` field are shown; otherwise an empty selection is
    passed.
    """
    shown = dft_details_widget.fixed_atoms.value if dft_details_widget.btn_fixed_pressed else ""
    viewer_widget.show_fixed(shown)

In [None]:
def make_geom_file(structure, filename):
    """Write ``structure`` to an XYZ file and wrap it in a ``SinglefileData``.

    Element symbols of atoms listed by ``extract_spin_guess`` get a ``1``
    (first index list) or ``2`` (second index list) appended, e.g. ``Fe`` ->
    ``Fe1``, so that spin-guess labels survive in the XYZ file.

    :param structure: AiiDA structure node; ``get_ase()`` and
        ``attributes['sites']`` (via ``extract_spin_guess``) are used.
    :param filename: name of the XYZ file stored in the returned node.
    :returns: a ``SinglefileData`` node holding the XYZ content.
    """
    import os
    import shutil
    import tempfile

    spin_guess = extract_spin_guess(structure)
    # Sets make the per-line membership tests below O(1).
    spin_up = set(spin_guess[0]) if spin_guess else set()
    spin_dw = set(spin_guess[1]) if spin_guess else set()

    atoms = structure.get_ase()
    n_atoms = len(atoms)

    tmpdir = tempfile.mkdtemp()
    try:
        file_path = os.path.join(tmpdir, filename)
        # Let ASE write straight to disk and read the text back. The former
        # BytesIO round-trip breaks on Python 3 because the XYZ writer emits
        # str, and the resulting bytes lines cannot be joined with str.
        atoms.write(file_path, format='xyz')
        with open(file_path) as f:
            all_lines = f.readlines()
        comment = all_lines[1].strip()

        modif_lines = []
        for i_line, line in enumerate(all_lines[2:]):
            lsp = line.split()
            if i_line in spin_up:
                line = lsp[0] + "1 " + " ".join(lsp[1:]) + "\n"
            if i_line in spin_dw:
                line = lsp[0] + "2 " + " ".join(lsp[1:]) + "\n"
            modif_lines.append(line)

        final_str = "%d\n%s\n" % (n_atoms, comment) + "".join(modif_lines)
        with open(file_path, 'w') as f:
            f.write(final_str)
        aiida_f = SinglefileData(file=file_path)
    finally:
        # Always remove the scratch directory, even if a step above raised.
        shutil.rmtree(tmpdir)
    return aiida_f

def extract_spin_guess(struct_node):
    """Collect site indices whose kind name ends in '1' or '2'.

    :param struct_node: node exposing ``attributes['sites']``, a sequence of
        mappings that each carry a ``'kind_name'`` string.
    :returns: ``[up, down]``, where ``up`` lists the indices of sites whose
        kind name ends in ``'1'`` and ``down`` those ending in ``'2'``.
    """
    last_chars = [site['kind_name'][-1] for site in struct_node.attributes['sites']]
    return [
        [idx for idx, ch in enumerate(last_chars) if ch == '1'],
        [idx for idx, ch in enumerate(last_chars) if ch == '2'],
    ]
# ==========================================================================

In [None]:
## DISPLAY WIDGETS AND DEFINE JOB PARAMETERS

## STRUCTURE
# Selecting a structure triggers on_struct_change.
struct_browser = StructureBrowser()
struct_browser.results.observe(on_struct_change, names='value')

## VIEWER
viewer_widget = ViewerDetails()
# Output area that on_struct_change fills with the analysis summary.
mol_ids_info_out = ipw.Output()

display(ipw.VBox([struct_browser, viewer_widget, mol_ids_info_out]))

## CODE
computer_code_dropdown = CodeDropdown(input_plugin='cp2k')

## DFT
# True = control is disabled for this app, False = user-editable.
_dft_widgets_disabled = {
    'calc_type': True,
    'center_switch': False,
    'charge': False,
    'multiplicity': False,
    'uks_switch': False,
    'cell': False,
}
dft_details_widget = DFTDetails(
    job_details=job_details,
    widgets_disabled=_dft_widgets_disabled,
)
dft_details_widget.btn_fixed_atoms.on_click(on_fixed_atoms_btn_click)

metadata_widget = MetadataWidget()

def param_function():
    """Build the ``Cp2kBaseWorkChain`` inputs from the current widget state.

    Reads the selected structure (``struct_browser``), the code
    (``computer_code_dropdown``), the DFT options (``dft_details_widget``) and
    the metadata (``metadata_widget``). Along the way it updates the
    module-level ``job_details`` dict in place.

    :returns: a populated process builder, or ``None`` (after printing the
        reason) when the current inputs cannot be submitted.
    """
    # Dict is used below but is not imported at the top of the notebook.
    from aiida.orm import Dict

    structure = struct_browser.results.value
    # Guard: nothing to submit until a structure was selected and analysed.
    if not structure or slab_analyzed is None or atoms is None:
        print("Please select a structure first")
        return None

    # TODO: Carlo, fix this -- job_details post-processing starts here.
    dft_details_widget.read_widgets_and_update_job_details()
    job_details['elements'] = slab_analyzed['all_elements']

    ### ODD CHARGE AND RKS
    # TODO: check the UKS update
    odd_charge = slab_analyzed['total_charge']
    if 'charge' in job_details:
        odd_charge += job_details['charge']
    # rks is now always defined. Previously it was bound only inside the
    # 'charge' branch, so an odd count without 'charge' hit UnboundLocalError.
    rks = job_details.get('uks_switch') != 'UKS'
    if odd_charge % 2 > 0 and rks:
        print("ODD CHARGE AND RKS")
        return None

    if job_details['workchain'] == 'NEBWorkchain':
        if len(job_details['replica_pks'].split()) < 2:
            print('Please select at least two  replica_pks')
            return None

    # Cell handling:
    # - empty or missing: use the structure's own 3x3 cell, flattened;
    # - "a b c" string: build a diagonal cell from it;
    # - anything else (e.g. a list left by a previous call): keep it, rather
    #   than crashing on .split().
    cell = job_details.get('cell')
    if not cell:
        job_details['cell'] = np.asarray(atoms.cell).flatten().tolist()
    elif hasattr(cell, 'split'):
        cell_abc = cell.split()
        job_details['cell'] = np.diag(np.array(cell_abc, dtype=float)).flatten().tolist()
    # End of job_details post-processing.

    # Use the builder to assemble the inputs.
    input_builder = Cp2kBaseWorkChain.get_builder()
    input_builder.cp2k.code = computer_code_dropdown.selected_code
    input_builder.cp2k.file.input_xyz = make_geom_file(structure, "mol.xyz")
    input_builder.cp2k.parameters = Dict(dict=Get_CP2K_Input(input_dict=job_details).inp)
    input_builder.cp2k.metadata = metadata_widget.dict

    return input_builder


# Code selection, DFT parameters and job metadata shown together.
_settings_box = ipw.VBox([computer_code_dropdown, dft_details_widget, metadata_widget])
display(_settings_box)

# Submit button for Cp2kBaseWorkChain, wired to param_function.
btn_submit = SubmitButton(
    workchain=Cp2kBaseWorkChain,
    param_function=param_function,
)
display(btn_submit)