In [None]:
%%javascript
// Replace the output area's `_should_scroll` hook with one that always
// answers "no", whatever line count it is given.
// NOTE(review): presumably this keeps long outputs fully expanded instead of
// being placed in a scroll box -- confirm against the notebook front end.
IPython.OutputArea.prototype._should_scroll = function (lines) { return false; };

In [None]:
%load_ext aiida
%aiida
import ipywidgets as ipw
from IPython.display import display, clear_output

import matplotlib.pyplot as plt
import ase

import re
import numpy as np
import urllib.parse
from tempfile import NamedTemporaryFile
from base64 import b64encode
from surface_tools.helpers import HART_2_EV, BOHR_2_ANG

# Nudged Elastic Band calculations

In [None]:

def render_thumbnail(atoms):
    """Render ``atoms`` to a PNG image and return it base64-encoded.

    Args:
        atoms: structure handed to ``ase.io.write``; in this notebook it is
            the object returned by ``get_ase()`` on a replica node.

    Returns:
        str: the content of the rendered PNG file, base64-encoded and decoded
        to text so it can be embedded in a ``data:image/png;base64,...`` URI.
    """
    # ``import ase`` alone does not guarantee that the ``ase.io`` submodule
    # has been loaded, so import it explicitly where it is used.
    import ase.io

    # Context managers release both the temporary file and the read handle
    # even when rendering raises.
    with NamedTemporaryFile() as tmp:
        ase.io.write(tmp.name, atoms, format='png')
        with open(tmp.name, 'rb') as png_file:
            raw = png_file.read()
    return b64encode(raw).decode()

def display_thumbnail(th):
    """Build the HTML ``<img>`` tag that embeds a base64-encoded PNG.

    Args:
        th: base64 text of the image, inserted verbatim into the data URI.

    Returns:
        str: the ``<img>`` tag, 400px wide, with an empty ``title``.
    """
    img_tag = f'<img width="400px" src="data:image/png;base64,{th}" title="">'
    return img_tag
def html_thumbnail(th):
    """Wrap a base64-encoded PNG in an ``ipywidgets`` HTML widget.

    Args:
        th: base64 text of the image, inserted verbatim into the data URI.

    Returns:
        ipw.HTML: widget whose content is a 400px wide ``<img>`` tag.
    """
    markup = f'<img width="400px" src="data:image/png;base64,{th}" title="">'
    return ipw.HTML(markup)


# Module-level list of ASE structures; process_and_show_neb() fills it with
# one entry per replica.
all_ase = list()


In [None]:
def make_replica_html(structure_data_list, energies, distances, n_col=4):
    """Build an HTML table with one cell per NEB replica.

    Each cell shows the replica thumbnail, its index, energy, distance to the
    previous replica, its PK and a link to the export notebook.

    Args:
        structure_data_list: replica nodes; each must provide
            ``get_extra('thumbnail')`` (base64 PNG text), ``pk`` and ``uuid``.
        energies: one energy per replica, formatted with 6 decimals and
            labelled ``eV``.
        distances: one distance per replica, formatted with 4 decimals and
            labelled ``ang``.
        n_col: number of cells per table row (default 4).

    Returns:
        str: the complete ``<table>...</table>`` markup. Iteration stops at
        the shortest of the three input sequences.

    Raises:
        ValueError: if ``n_col`` is smaller than 1.
    """
    if n_col < 1:
        raise ValueError(f"n_col must be a positive integer, got {n_col!r}")

    parts = ['<table>']
    row_open = False

    for i, (rep, en, dist) in enumerate(zip(structure_data_list, energies, distances)):
        thumbnail = rep.get_extra('thumbnail')

        # Start a new row every ``n_col`` cells.
        if i % n_col == 0:
            parts.append('<tr>')
            row_open = True

        # The table cell with the thumbnail.
        parts.append(f'<td><img width="400px" src="data:image/png;base64,{thumbnail}" title="">')

        # Some information about the replica.
        parts.append(
            f'<p><b>Nr: </b>{i} <br> <b>Energy:</b> {en:.6f} eV <br> '
            f'<b>Dist. to prev:</b> {dist:.4f} ang</p>'
        )
        parts.append(f'<p>pk: {rep.pk}</p>')

        # The download link; the cell is properly closed with </td>.
        parts.append(
            f'<p><a target="_blank" href="export_structure.ipynb?uuid={rep.uuid}">'
            'View & export</a></p></td>'
        )

        if i % n_col == n_col - 1:
            parts.append('</tr>')
            row_open = False

    # Only close a row that is still open, so complete rows and empty tables
    # do not get a stray </tr>.
    if row_open:
        parts.append('</tr>')

    parts.append('</table>')
    return ''.join(parts)

In [None]:
def sorted_opt_rep_keys(keys):
    """Pick the ``opt_replica`` keys and order them by their numeric index.

    Args:
        keys: iterable of key strings. Only keys containing ``'opt_replica'``
            are kept; the index is the third ``'_'``-separated field.

    Returns:
        list[tuple[int, str]]: ``(index, key)`` pairs in ascending order.
    """
    indexed = []
    for name in keys:
        if 'opt_replica' not in name:
            continue
        replica_nr = int(name.split('_')[2])
        indexed.append((replica_nr, name))
    indexed.sort()
    return indexed

def process_and_show_neb(_=None):
    """Load the NEB workchain selected in ``pk_select`` and display its results.

    The node whose PK is entered in ``pk_select`` is loaded; its optimized
    replica structures, replica energies and replica distances are read from
    the workchain outputs. Rendered into ``main_out``: the energy profile of
    the last row of the energies array, the maximum energy of every row
    (plotted as barrier versus iteration), a thumbnail table of the replicas
    and the list of replica PKs.

    Side effects:
        * ``all_ase`` is emptied and refilled with the ASE structure of every
          replica of the loaded workchain.
        * Replicas lacking a ``thumbnail`` extra get one stored on the node.
        * ``btn_show`` is disabled while processing and always re-enabled,
          even if processing raises.

    Args:
        _: ignored; present so the function can be registered directly as a
            button ``on_click`` callback.
    """
    global all_ase

    wc = load_node(pk_select.value)
    btn_show.disabled = True
    try:
        with main_out:
            clear_output()

        # Old workchains expose the replica count as a ``nreplicas`` input;
        # new ones keep it in ``neb_params``. ``Exception`` (not a bare
        # ``except``) so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            nreplicas = wc.inputs['nreplicas'].value
            old_workchain = True
        except Exception:
            nreplicas = wc.inputs.neb_params['number_of_replica']
            old_workchain = False

        structure_data_list = []
        for i_rep in range(nreplicas):
            if old_workchain:
                label = f"opt_replica_{i_rep}"
            else:
                # New workchains zero-pad the replica index to three digits.
                label = f"opt_replica_{str(i_rep).zfill(3)}"
            structure_data_list.append(wc.outputs[label])

        energies_array = wc.outputs['replica_energies'].get_array('energies') * HART_2_EV
        distances_array = wc.outputs['replica_distances'].get_array('distances') * BOHR_2_ANG

        # Make every row relative to the energy of its first replica.
        energies_array = np.array([e_arr - e_arr[0] for e_arr in energies_array])

        # Add thumbnails to replicas that do not have one yet and store the
        # ASE structures for the viz. The list is cleared in place first so
        # repeated calls do not pile up structures from earlier nodes.
        all_ase.clear()
        for rep in structure_data_list:
            the_ase = rep.get_ase()
            all_ase.append(the_ase)
            if "thumbnail" not in rep.extras:
                rep.set_extra("thumbnail", render_thumbnail(the_ase))

        replica_html = make_replica_html(structure_data_list, energies_array[-1], distances_array[-1])

        barrier_list = [np.max(e_arr) for e_arr in energies_array]

        with main_out:
            f, axarr = plt.subplots(1, 2, figsize=(14, 4))

            axarr[0].plot(energies_array[-1], 'o-')
            axarr[0].set_ylabel("Energy (eV)")
            axarr[0].set_xlabel("Replica nr")
            axarr[0].set_title("NEB energy profile")

            axarr[1].plot(barrier_list, 'o-')
            axarr[1].axhline(barrier_list[-1], linestyle='--', color='lightgray')
            axarr[1].set_ylabel("Barrier (eV)")
            axarr[1].set_xlabel("Iteration nr")
            axarr[1].set_title("NEB convergence")

            plt.show()

            display(ipw.HTML(replica_html))

            print("List of all replica PKs:")
            # join() yields "[1 2 3]" and a well-formed "[]" for an empty list.
            print("[" + " ".join(str(struct.pk) for struct in structure_data_list) + "]")
    finally:
        # Never leave the button stuck in the disabled state.
        btn_show.disabled = False

In [None]:
# Styling shared by the input widgets.
style = dict(description_width='120px')
layout = dict(width='70%')

# Integer field holding the PK of the node to load.
pk_select = ipw.IntText(description='Load node: ', layout=layout, style=style)

# Output area that process_and_show_neb() renders into.
main_out = ipw.Output()

# Button that triggers loading and rendering of the selected node.
btn_show = ipw.Button(description="Show")
btn_show.on_click(process_and_show_neb)

display(pk_select, btn_show, main_out)

In [None]:
# Load the URL after everything is set up.
# Best effort: if the notebook URL is unavailable or carries no usable ``pk``
# query parameter, nothing is loaded and the user enters a PK manually.
try:
    url = urllib.parse.urlsplit(jupyter_notebook_url)
    url_pk = int(urllib.parse.parse_qs(url.query)['pk'][0])
except (NameError, KeyError, IndexError, ValueError, TypeError, AttributeError):
    url_pk = None

if url_pk is not None:
    pk_select.value = url_pk
    try:
        process_and_show_neb()
    except Exception as exc:
        # Report the failure instead of hiding it; the widgets stay usable.
        with main_out:
            print(f"Could not display node {url_pk}: {exc}")