In [None]:
%%javascript
// Override the classic-notebook OutputArea hook so that `_should_scroll`
// always returns false. Presumably this keeps long outputs (plots, the
// replica table, the 3D viewer) from being collapsed into a scrolling box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
%aiida
from aiida.work.process import WorkCalculation
from aiida_cp2k.calculations import Cp2kCalculation
from aiida.orm.data.structure import StructureData
from aiida.orm.data.parameter import ParameterData
from aiida.orm.data.folder import FolderData

from aiida.common.exceptions import NotExistent
from aiida.orm.data.base import Str

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

import matplotlib.pyplot as plt
from pprint import pprint

import nglview

import ase.io
import tempfile

import re
import numpy as np
import scipy.constants


from apps.surfaces.widgets import find_mol

from collections import OrderedDict

from tempfile import NamedTemporaryFile
from base64 import b64encode
def render_thumbnail(atoms):
    """Render `atoms` to a PNG image and return it base64-encoded.

    The image is written to a temporary file and read back, because
    ase.io.write is not given an in-memory buffer here.

    :param atoms: ASE Atoms object to render.
    :returns: base64-encoded PNG data as an ASCII text string, ready to be
        used in an HTML ``data:image/png;base64,...`` URI.
    """
    with NamedTemporaryFile() as tmp:
        ase.io.write(tmp.name, atoms, format='png') # does not accept StringIO
        # PNG data is binary: read it in 'rb' mode (text mode fails to decode
        # it on Python 3 and mangles it on Windows) and close the handle.
        with open(tmp.name, 'rb') as png_file:
            raw = png_file.read()
    # Decode so the result is text on Python 3 too; otherwise formatting it
    # into HTML produces "b'...'".
    return b64encode(raw).decode('ascii')

def display_thumbnail(th):
    """Return an HTML <img> tag, 400 px wide, embedding the base64 PNG data `th`."""
    img_template = '<img width="400px" src="data:image/png;base64,{}" title="">'
    return img_template.format(th)
def html_thumbnail(th):
    """Return an ipywidgets HTML widget showing base64 PNG data `th` 400 px wide."""
    img_tag = '<img width="400px" src="data:image/png;base64,{}" title="">'.format(th)
    return ipw.HTML(img_tag)


# 3D viewer used by the structure-view helper functions.
viewer = nglview.NGLWidget()

# Common widget styling.
style = {'description_width': '120px'}
layout = {'width': '70%'}

# Selects which replica (1-based) is shown in the viewer; its maximum is
# adjusted once an NEB has been loaded.
slider_image_nr = ipw.IntSlider(
    description='image nr.:',
    value=1,
    step=1,
    min=1,
    max=2,
    style=style,
    layout=layout,
)

# Module-level state, filled in by process_and_show_neb():
all_ase = []          # ASE Atoms of each replica, in replica order
slab_analyzed = None  # atom classification used for the 3D representation

def on_image_nr_change(c):
    """Slider callback: show replica number ``slider_image_nr.value`` (1-based).

    :param c: change notification passed by ``observe`` (unused; the slider
        value is read directly).
    """
    idx = slider_image_nr.value - 1
    # The slider exists (min=1, max=2) before any NEB has been loaded, so it
    # can point past the end of `all_ase`; ignore such moves instead of
    # raising IndexError inside the widget callback.
    if not 0 <= idx < len(all_ase):
        return
    refresh_structure_view(all_ase[idx])
# Redraw the viewer with the selected replica whenever the slider value changes.
slider_image_nr.observe(on_image_nr_change, names='value')

clear_output()
display(ipw.VBox(children=[slider_image_nr, viewer]))

In [None]:

def visualize_atoms_in_viewer(viewer, atoms, details):
    """Add `atoms` to `viewer` as two ball-and-stick components plus the unit cell.

    Component 0 holds the molecule atoms; component 1 holds the slab atoms,
    bottom H, adatoms and unclassified atoms. ``atoms.pbc`` is set to
    all-periodic in place.

    :param viewer: nglview widget. The representations are attached to
        component indices 0 and 1, so it is assumed to contain no
        components yet.
    :param atoms: ASE Atoms of the full system.
    :param details: dict with keys 'all_molecules' (iterable of index lists),
        'slabatoms', 'bottom_H', 'adatoms' and 'unclassified' (index lists),
        presumably as returned by find_mol.analyze_slab -- verify.
    """
    atoms.pbc = [1, 1, 1]

    molecule_indices = []
    for molecule in details['all_molecules']:
        molecule_indices.extend(molecule)
    other_indices = (details['slabatoms'] + details['bottom_H']
                     + details['adatoms'] + details['unclassified'])

    mol_atoms = atoms[molecule_indices]
    other_atoms = atoms[other_indices]

    # component 0: molecules (thick sticks)
    viewer.add_component(nglview.ASEStructure(mol_atoms),
                         default_representation=False)
    viewer.add_ball_and_stick(aspectRatio=1.8, radius=0.25, opacity=1.0, component=0)

    # component 1: everything else
    viewer.add_component(nglview.ASEStructure(other_atoms),
                         default_representation=False)
    viewer.add_ball_and_stick(aspectRatio=10.0, opacity=1.0, component=1)

    viewer.add_unitcell()

def initialize_structure_view():
    """Clear the viewer, draw the first replica and orient the camera along z.

    Removes every existing viewer component. If `all_ase` is empty it stops
    there. Otherwise it draws ``all_ase[0]`` via visualize_atoms_in_viewer,
    using the module-level `slab_analyzed` classification, and sets a
    camera orientation built from the cell height and centre of mass.
    """
    # Delete all old components. The loop relies on the next remaining
    # component becoming `component_0` after each removal (presumably
    # nglview's renaming behaviour; otherwise the loop would not terminate).
    while hasattr(viewer, "component_0"):
        viewer.component_0.clear_representations()
        cid = viewer.component_0.id
        viewer.remove_component(cid)

    if len(all_ase) == 0:
        return

    atoms = all_ase[0]

    visualize_atoms_in_viewer(viewer, atoms, slab_analyzed)

    #### -----------------------------------------------------
    ## Camera orientation

    viewer.center()

    # Orient camera to look from positive z. The 16 values are presumably a
    # flattened 4x4 transform in NGL's camera-orientation format (TODO
    # confirm). The z entry uses the cell height, but at least 30.0, and the
    # last row holds the negated centre of mass.
    cell_z = atoms.cell[2, 2]
    com = atoms.get_center_of_mass()
    top_z_orientation = [1.0, 0.0, 0.0, 0,
                         0.0, 1.0, 0.0, 0,
                         0.0, 0.0, -np.max([cell_z, 30.0]), 0,
                         -com[0], -com[1], -com[2], 1]
    viewer._set_camera_orientation(top_z_orientation)
    #### -----------------------------------------------------
    

In [None]:
def refresh_structure_view(viz_atoms):
    """Redraw the viewer with `viz_atoms` while keeping the current camera orientation.

    All existing components are removed and the structure is re-added via
    visualize_atoms_in_viewer, using the module-level `slab_analyzed`
    classification. ``viz_atoms.pbc`` is set to all-periodic in place.
    """
    saved_orientation = viewer._camera_orientation

    # Remove every component. Presumably nglview exposes the next remaining
    # component as `component_0` after each removal; the loop depends on it.
    while hasattr(viewer, "component_0"):
        first_component = viewer.component_0
        first_component.clear_representations()
        viewer.remove_component(first_component.id)

    viz_atoms.pbc = [1, 1, 1]
    visualize_atoms_in_viewer(viewer, viz_atoms, slab_analyzed)

    # Restore the camera so that switching replicas does not move the view.
    viewer._set_camera_orientation(saved_orientation)


#From Kristjan    
#    viewer.add_component(nglview.ASEStructure(viz_atoms), default_representation=False) # adds ball+stick
#    #viewer.add_component(nglview.ASEStructure(slab_atoms)) # adds ball+stick
#    viewer.add_unitcell()
#    #viewer.center()
#    
#    viewer.component_0.add_ball_and_stick(aspectRatio=10.0, opacity=1.0)
#    
#    viewer._set_camera_orientation(old_camera_orientation)

In [None]:
def replicas_from_restart(restart_file_path):
    """Build one StructureData per NEB replica from a CP2K restart file.

    The cell is taken from the first ``&CELL`` section (its ``A``/``B``/``C``
    vector lines). Every ``&COORD`` section except the last is one replica's
    coordinates (numeric columns only). The last ``&COORD`` section is used
    only for its element symbols (first column).
    NOTE(review): relies on the order of sections in the restart file.
    Confirm if the CP2K version changes.

    :param restart_file_path: path to the CP2K ``.restart`` file.
    :returns: list of unstored StructureData, one per replica, in file order.
    :raises ValueError: if no ``&CELL`` or no ``&COORD`` section is found.
    """
    with open(restart_file_path) as restart_file:
        content = restart_file.read()

    cell_match = re.search(r'\n\s*&CELL\n(.*?)\n\s*&END CELL\n', content, re.DOTALL)
    if cell_match is None:
        raise ValueError("No &CELL section found in %s" % restart_file_path)
    cell_lines = [line.split() for line in cell_match.group(1).split("\n")]
    # Exact keyword match. The old substring test (`in 'ABC'`) also accepted
    # keywords such as ABC/AB/BC, and blank lines raised IndexError.
    cell_str = [line[1:] for line in cell_lines if line and line[0] in ('A', 'B', 'C')]
    cell = np.array(cell_str, np.float64)

    coord_blocks = re.findall(r'\n\s*&COORD\n(.*?)\n\s*&END COORD\n', content, re.DOTALL)
    if not coord_blocks:
        raise ValueError("No &COORD section found in %s" % restart_file_path)
    coord_line_sets = [
        [line.split() for line in block.split("\n") if line.strip()]
        for block in coord_blocks
    ]
    element_list = [line[0] for line in coord_line_sets[-1]]

    structure_data_list = []
    for rep_coord_lines in coord_line_sets[:-1]:
        positions = np.array(rep_coord_lines, np.float64)
        ase_atoms = ase.Atoms(symbols=element_list, positions=positions, cell=cell)
        structure_data_list.append(StructureData(ase=ase_atoms))

    return structure_data_list

def neb_energies_and_distances_from_output(output_file_path):
    """Parse NEB energy profiles and replica distances from a CP2K output file.

    Each ``ENERGIES [au] = ...`` block (text up to the next ``BAND``) gives
    one energy array. Each ``DISTANCES REP = ...`` block (text up to the
    next ``ENERGIES``) gives one distance array. Both lists are in file
    order, presumably one entry per NEB iteration.

    :param output_file_path: path to the CP2K main output file.
    :returns: ``(distances_list, energies_list)``.
        - energies: shifted so the first replica is 0, then multiplied by
          27.21138602 (Hartree -> eV).
        - distances: a leading 0.0 prepended, then multiplied by
          0.529177249 (presumably bohr -> Angstrom).
    """
    hartree_to_ev = 27.21138602
    bohr_to_angstrom = 0.529177249

    with open(output_file_path) as output_file:
        content = output_file.read()

    energy_str_list = re.findall(r'ENERGIES \[au\] = (.*?)BAND', content, re.DOTALL)
    energies_list = []
    for e_str in energy_str_list:
        e_arr = np.array(e_str.split(), dtype=float)
        energies_list.append((e_arr - e_arr[0]) * hartree_to_ev)

    dist_str_list = re.findall(r'DISTANCES REP = (.*?)ENERGIES', content, re.DOTALL)
    distances_list = [np.array(['0.0'] + ds.split(), dtype=float) * bohr_to_angstrom
                      for ds in dist_str_list]

    return distances_list, energies_list

In [None]:
def make_replica_html(structure_data_list, energies, columns=4):
    """Return an HTML table of replica thumbnails, `columns` replicas per row.

    Each cell shows the replica's 'thumbnail' extra (base64 PNG data), its
    index, its energy, its pk and an export link. Replicas and energies are
    paired with zip(), so the table is as long as the shorter of the two.

    :param structure_data_list: replica nodes providing ``get_extra``,
        ``pk`` and ``uuid``.
    :param energies: one energy value per replica (shown as-is).
    :param columns: number of replicas per table row (default 4).
    :returns: the table as an HTML string with balanced <tr>/<td> tags.
    """
    cells = []
    for i, (rep, en) in enumerate(zip(structure_data_list, energies)):
        thumbnail = rep.get_extra('thumbnail')
        cell = '<td><img width="400px" src="data:image/png;base64,{}" title="">'.format(thumbnail)
        # Some information about the replica...
        cell += '<p><b>Nr: </b>{} <br> <b>Energy:</b> {}</p>'.format(i, en)
        cell += '<p>pk: {}</p>'.format(rep.pk)
        # ... and the download link. The cell is closed with </td>; the
        # previous version emitted a stray opening <td> here.
        cell += ('<p><a target="_blank" href="../export_structure.ipynb?uuid={}">'
                 'Export Structure</a></p></td>').format(rep.uuid)
        cells.append(cell)

    # Group into rows; every row is opened and closed exactly once (the
    # previous version always appended an extra </tr>).
    rows = ['<tr>' + ''.join(cells[start:start + columns]) + '</tr>'
            for start in range(0, len(cells), columns)]
    return '<table>' + ''.join(rows) + '</table>'

In [None]:
def sorted_opt_rep_keys(keys):
    """Return ``(index, key)`` pairs for every key containing 'opt_replica', sorted.

    The index is the integer after the second underscore, e.g.
    ``'opt_replica_12'`` -> ``(12, 'opt_replica_12')``.
    """
    pairs = []
    for key in keys:
        if 'opt_replica' not in key:
            continue
        pairs.append((int(key.split('_')[2]), key))
    pairs.sort()
    return pairs

def _collect_replica_structures(cp2k_calc, cc_output_dict):
    """Return the StructureData nodes of the NEB replicas of `cp2k_calc`, by replica index.

    If the calculation has an output named ``opt_replica_0``, the
    ``opt_replica_<i>`` outputs are used. Otherwise the pks stored in the
    ``opt_replica_<i>`` extras are loaded. If there are none, the
    ``aiida-1.restart`` file of the retrieved folder is parsed, and the
    replicas are stored and linked through new extras.
    """
    structure_data_list = []

    if 'opt_replica_0' in cc_output_dict:
        # NEW: get replicas from output nodes
        for rep_index, rep_key in sorted_opt_rep_keys(cc_output_dict.keys()):
            # Append only when this index is not already covered by the list
            # (presumably guards against several keys with the same index).
            if rep_index >= len(structure_data_list):
                structure_data_list.append(cc_output_dict[rep_key])
        return structure_data_list

    # OLD: replicas stored in separate nodes and linked through extras (pks)
    extras = cp2k_calc.get_extras()
    for _, rep_key in sorted_opt_rep_keys(extras.keys()):
        structure_data_list.append(load_node(extras[rep_key]))

    if not structure_data_list:
        restart_file_path = cc_output_dict['retrieved'].get_abs_path('aiida-1.restart')
        structure_data_list = replicas_from_restart(restart_file_path)
        for i_rep, rep in enumerate(structure_data_list):
            rep.store()
            cp2k_calc.set_extra("opt_replica_%d" % i_rep, rep.pk)
            with main_out:
                print("Added replica %d with pk %d to database" % (i_rep, rep.pk))

    return structure_data_list


def process_and_show_neb(c):
    """Button callback: load the NEB selected in `drop_nebs` and display its results.

    Collects the replica structures and adds thumbnails where missing. It
    analyzes the first replica for the 3D view, then parses the CP2K output
    to plot two things: the last energy profile, and the barrier (maximum
    relative energy) of every parsed profile. Finally it shows the replica
    table and resets the slider and viewer to the first replica.
    The module-level `all_ase` / `slab_analyzed` are rebuilt from scratch on
    every call.

    :param c: the clicked button (unused).
    """
    global all_ase, slab_analyzed

    btn_show.disabled = True
    try:
        with main_out:
            clear_output()

        # Start from a clean state. Previously, replicas of an earlier-shown
        # NEB stayed in `all_ase` and the new ones were appended, so the
        # slider range, slab analysis and initial view used stale data.
        all_ase = []
        slab_analyzed = None

        wc = drop_nebs.value
        # NOTE(review): takes the first output node of the work chain as the
        # CP2K calculation -- confirm this ordering holds.
        cp2k_calc = wc.get_outputs()[0]
        cc_output_dict = cp2k_calc.get_outputs_dict()

        structure_data_list = _collect_replica_structures(cp2k_calc, cc_output_dict)

        # Store the ASE structures for the viewer; render thumbnails only once per node.
        for rep in structure_data_list:
            the_ase = rep.get_ase()
            all_ase.append(the_ase)
            if "thumbnail" not in rep.get_extras():
                rep.set_extra("thumbnail", render_thumbnail(the_ase))

        # Analyze the slab to use a nice representation in the 3d viewer.
        slab_analyzed = find_mol.analyze_slab(all_ase[0])

        output_file_path = cc_output_dict['retrieved'].get_abs_path('aiida.out')
        distances_list, energies_list = neb_energies_and_distances_from_output(output_file_path)

        replica_html = make_replica_html(structure_data_list, energies_list[-1])
        barrier_list = [np.max(e_arr) for e_arr in energies_list]

        with main_out:
            fig, axarr = plt.subplots(1, 2, figsize=(14, 4))

            axarr[0].plot(energies_list[-1], 'o-')
            axarr[0].set_ylabel("Energy (eV)")
            axarr[0].set_xlabel("Replica nr")
            axarr[0].set_title("NEB energy profile")

            axarr[1].plot(barrier_list, 'o-')
            axarr[1].axhline(barrier_list[-1], linestyle='--', color='lightgray')
            axarr[1].set_ylabel("Barrier (eV)")
            axarr[1].set_xlabel("Iteration nr")
            axarr[1].set_title("NEB convergence")

            plt.show()

            display(ipw.HTML(replica_html))

            print("List of all replica PKs:")
            print("[" + " ".join("%d" % struct.pk for struct in structure_data_list) + "]")

        # Point the slider at replica 1 (the one initialize_structure_view
        # draws) before adapting its range to the new number of replicas.
        slider_image_nr.value = 1
        slider_image_nr.max = len(all_ase)
        initialize_structure_view()
    finally:
        # Re-enable the button even if loading or parsing failed; any
        # exception still propagates to the widget's error log.
        btn_show.disabled = False

In [None]:
# NEB work chains ('NEBWorkchain') joined with the Cp2kCalculation outputs
# whose state is FINISHED, ordered by work-chain id, newest first. Only the
# work chain is projected. A work chain with several matching calculations
# presumably appears once per match.
qb = QueryBuilder()
qb.append(
    WorkCalculation,
    tag='wc',
    project='*',
    filters={'attributes._process_label': {'==': 'NEBWorkchain'}},
)
qb.append(
    Cp2kCalculation,
    output_of='wc',
    filters={'attributes.state': {'==': 'FINISHED'}},
)
qb.order_by({WorkCalculation: {'id': 'desc'}})
neb_wcs = qb.all()

In [None]:
# Dropdown label ("PK <pk>: <description>") -> work-chain node. The node is
# the first (projected) entry of each query row.
sel_options = OrderedDict([
    ("PK %d: " % (row[0].pk) + row[0].description, row[0])
    for row in neb_wcs
])

In [None]:
style = {'description_width': '120px'}
layout = {'width': '70%'}

# Output area that process_and_show_neb writes plots, tables and messages to.
main_out = ipw.Output()

# Selector for the NEB work chain to inspect (label -> work-chain node).
drop_nebs = ipw.Dropdown(
    options=sel_options,
    description='NEB: ',
    layout=layout,
    style=style,
)

# Clicking "Show" loads and displays the selected NEB.
btn_show = ipw.Button(description="Show")
btn_show.on_click(process_and_show_neb)

display(drop_nebs, btn_show, main_out)