In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import numpy as np
import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import gzip
import matplotlib.pyplot as plt
from collections import OrderedDict
import urlparse
import io
import zipfile
import StringIO

import matplotlib.pyplot as plt

from apps.scanning_probe import common
from apps.scanning_probe import igor

In [None]:
# Colormaps offered in each series row's "colormap" dropdown.
colormaps = ['seismic', 'gist_heat']

def remove_from_tuple(tup, index):
    """Return a new tuple equal to `tup` with the element at `index` dropped.

    Negative indices count from the end (as with list deletion); an
    out-of-range index raises IndexError.
    """
    remaining = list(tup)
    remaining.pop(index)
    return tuple(remaining)

import matplotlib

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter that renders every tick label with a fixed printf format.

    make_plot uses it for colorbars whose values fall outside (1e-3, 1e3), so
    the mantissa is printed with a fixed precision while the parent class still
    handles the offset / scientific multiplier.

    :param fformat: printf-style format applied to every tick label.
    :param offset: forwarded to ScalarFormatter as ``useOffset``.
    :param mathText: forwarded to ScalarFormatter as ``useMathText``.
    """

    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        self.fformat = fformat
        matplotlib.ticker.ScalarFormatter.__init__(self, useOffset=offset,
                                                   useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        # Overrides a private ScalarFormatter hook. Older matplotlib calls it as
        # _set_format(vmin, vmax); newer releases call it without arguments, so
        # both parameters are optional (and unused).
        self.format = self.fformat
        if self._useMathText:
            # matplotlib.ticker._mathdefault no longer exists in newer
            # matplotlib; fall back to the equivalent \mathdefault wrapper.
            mathdefault = getattr(matplotlib.ticker, '_mathdefault', None)
            if mathdefault is not None:
                self.format = '$%s$' % mathdefault(self.format)
            else:
                self.format = r'$\mathdefault{%s}$' % self.format

def make_plot(fig, ax, data, extent, title=None, title_size=None, center0=False, vmin=None, vmax=None, cmap='gist_heat', noadd=False):
    """Draw the 2D map `data` on `ax` with imshow.

    :param fig: figure owning `ax`; only used to attach the colorbar.
    :param ax: matplotlib axes to draw on.
    :param data: 2D array; it is transposed before imshow, so its first axis
        runs horizontally (labelled x) and its second axis vertically (y).
    :param extent: [xmin, xmax, ymin, ymax] passed to imshow.
    :param title: axes title (may be None).
    :param title_size: optional font size for the title.
    :param center0: if True, use a color range symmetric around zero
        (+/- max|data|) and ignore `vmin` / `vmax`.
    :param vmin: lower color limit used when `center0` is False.
    :param vmax: upper color limit used when `center0` is False.
    :param cmap: matplotlib colormap name.
    :param noadd: if True, hide the ticks and omit axis labels and colorbar.
    """
    if center0:
        data_amax = np.max(np.abs(data))
        vmin, vmax = -data_amax, data_amax

    im = ax.imshow(data.T, origin='lower', cmap=cmap, interpolation='bicubic',
                   extent=extent, vmin=vmin, vmax=vmax)

    if noadd:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        ax.set_xlabel(r"x ($\AA$)")
        ax.set_ylabel(r"y ($\AA$)")
        # Choose the colorbar formatter from the magnitude of the data, so
        # maps with mostly negative values are treated like positive ones.
        if 1e-3 < np.max(np.abs(data)) < 1e3:
            cb = fig.colorbar(im, ax=ax)
        else:
            cb = fig.colorbar(im, ax=ax, format=FormatScalarFormatter("%.1f"))
        cb.formatter.set_powerlimits((-2, 2))
        cb.update_ticks()
    ax.set_title(title)
    if title_size:
        ax.title.set_fontsize(title_size)
    ax.axis('scaled')

class SeriesPlotter():
    """Interactive widgets for plotting and exporting image series.

    A "series" maps a label to an array indexed as ``array[i, :, :]``, where
    ``i`` selects one image and ``labels[i]`` names it. The user adds rows,
    each choosing a series, a colormap and whether to center the color scale at
    zero. The images in the index range returned by ``set_indexes_function``
    are then plotted (Plot button) or exported to a zip (Image zip button).
    """

    def __init__(self, set_indexes_function):
        """
        :param set_indexes_function: callable taking no arguments and returning
            ``(i_start, i_end)``; images ``i_start .. i_end-1`` are used.
        """
        self.series = None
        self.extent = None
        self.figure_xy_ratio = None
        self.labels = None
        self.wc_pk = None

        self.set_indexes_function = set_indexes_function

        ### -------------------------------------------
        ### Selector

        # One entry per row: [series dropdown, colormap dropdown,
        # sym. zero checkbox, remove button].
        self.elem_list = []
        self.selections_vbox = ipw.VBox([])

        self.add_row_btn = ipw.Button(description='Add series row', disabled=True)
        self.add_row_btn.on_click(lambda b: self.add_selection_row())

        self.selector_widget = ipw.VBox([self.add_row_btn, self.selections_vbox])
        ### -------------------------------------------
        ### Plotter

        self.plot_btn = ipw.Button(description='Plot', disabled=True)
        self.plot_btn.on_click(self.plot_series)

        self.clear_btn = ipw.Button(description='Clear', disabled=True)
        self.clear_btn.on_click(self.full_clear)

        self.plot_output = ipw.VBox()
        ### -------------------------------------------

        # Figure height (matplotlib figsize, inches) per image row.
        self.fig_y = 4

        ### -------------------------------------------
        ### Creating a zip
        self.zip_btn = ipw.Button(description='Image zip', disabled=True)
        self.zip_btn.on_click(self.create_zip_link)

        self.zip_progress = ipw.FloatProgress(
                value=0,
                min=0,
                max=1.0,
                description='progress:',
                bar_style='info',
                orientation='horizontal'
            )

        self.link_out = ipw.Output()

    def setup(self, series, extent, labels, wc_pk):
        """Attach a newly loaded dataset and enable the controls.

        Rows left over from a previous dataset are discarded because their
        dropdown values may not be keys of the new ``series``.

        :param series: dict label -> array indexed as [i, :, :].
        :param extent: [xmin, xmax, ymin, ymax] shared by all images.
        :param labels: per-image labels, indexed like the first array axis.
        :param wc_pk: workchain pk, used in the zip file name.
        """
        self.series = series
        self.extent = extent
        # float() avoids integer division under Python 2.
        self.figure_xy_ratio = float(extent[1] - extent[0]) / (extent[3] - extent[2])
        self.labels = labels
        self.wc_pk = wc_pk

        self.elem_list = []
        self.selections_vbox.children = ()
        self.add_selection_row()

        self.add_row_btn.disabled = False
        self.plot_btn.disabled = False
        self.clear_btn.disabled = False
        self.zip_btn.disabled = False

    def add_selection_row(self):
        """Append a row (series, colormap, sym. zero, remove button) to the selector."""
        drop_full_series = ipw.Dropdown(description="series", options=sorted(self.series.keys()),
            style = {'description_width': 'auto'})
        drop_cmap = ipw.Dropdown(description="colormap", options=colormaps,
            style = {'description_width': 'auto'})
        sym_check = ipw.Checkbox(
            value=False,
            description='sym. zero',
            disabled=False,
            style = {'description_width': 'auto'},
            layout=ipw.Layout(width='auto')
        )
        rm_btn = ipw.Button(description='x', layout=ipw.Layout(width='30px'))
        rm_btn.on_click(lambda b: self.remove_line_row(b))

        elements = [drop_full_series, drop_cmap, sym_check, rm_btn]
        element_widths = ['210px', '210px', '120px', '35px']
        boxed_row = ipw.HBox([ipw.HBox([row_el], layout=ipw.Layout(border='0.1px solid', width=row_w)) for row_el, row_w in zip(elements, element_widths)])

        self.elem_list.append(elements)
        self.selections_vbox.children += (boxed_row, )

    def remove_line_row(self, b):
        """Remove the selector row whose remove button is `b`."""
        rm_btn_list = [elem[3] for elem in self.elem_list]
        rm_index = rm_btn_list.index(b)
        del self.elem_list[rm_index]
        self.selections_vbox.children = remove_from_tuple(self.selections_vbox.children, rm_index)

    def plot_series(self, b):
        """Plot all selected series for the current index range.

        One Output widget (one figure, one subplot per selector row) is created
        per image index, inside a scrollable box appended to ``plot_output``.
        """
        num_series = len(self.elem_list)
        if num_series == 0:
            # Nothing selected: avoid creating an empty, zero-height figure.
            return

        fig_y_in_px = 0.8*self.fig_y*matplotlib.rcParams['figure.dpi']

        box_layout = ipw.Layout(overflow_x='scroll',
                        border='3px solid black',
                        width='100%',
                        height='%dpx' % (fig_y_in_px*num_series + 70),
                        display='inline-flex',
                        flex_flow='column wrap',
                        align_items='flex-start')

        plot_hbox = ipw.Box(layout=box_layout)
        self.plot_output.children += (plot_hbox, )

        i_start, i_end = self.set_indexes_function()

        for i in range(i_start, i_end):
            plot_out = ipw.Output()
            plot_hbox.children += (plot_out, )
            with plot_out:
                fig = plt.figure(figsize=(self.fig_y*self.figure_xy_ratio, self.fig_y*num_series))

                for i_ser, (series_drop, cmap_drop, sym_check_box, _) in enumerate(self.elem_list):
                    series_label = series_drop.value
                    title = '%s %s' % (series_label, self.labels[i])
                    data = self.series[series_label]

                    ax = plt.subplot(num_series, 1, i_ser+1)

                    make_plot(fig, ax, data[i, :, :], center0=sym_check_box.value,
                              extent=self.extent, title=title, cmap=cmap_drop.value, noadd=True)

                plt.show()

    def create_zip_link(self, b):
        """Write the selected images to tmp/orbs_pk<pk>.zip and display a download link.

        The button is left disabled afterwards; it is re-enabled by the
        notebook's tmp-clearing callback.
        """
        import os

        self.zip_btn.disabled = True

        filename = "orbs_pk%d.zip" % self.wc_pk

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
            self.data_to_zip(zip_file)

        # Pure-Python equivalent of `mkdir -p tmp`.
        if not os.path.isdir('tmp'):
            os.makedirs('tmp')

        with open(os.path.join('tmp', filename), 'wb') as f:
            f.write(zip_buffer.getvalue())

        with self.link_out:
            display(HTML('<a href="tmp/%s" target="_blank">download zip</a>' % filename))

    @staticmethod
    def _progress_fraction(n_done, n_total):
        """Fraction of finished items in [0, 1]; an empty job counts as complete."""
        if n_total <= 0:
            return 1.0
        return float(n_done) / n_total

    def data_to_zip(self, zip_file):
        """Add png, txt and IGOR (.itx) exports of the selected images to `zip_file`.

        ``zip_progress`` is reset to 0 and advanced to 1 as images are written.
        """
        i_start, i_end = self.set_indexes_function()

        total_pics = (i_end - i_start) * len(self.elem_list)
        n_done = 0
        self.zip_progress.value = 0.0

        # Identical for every txt export, so built once.
        header = "xlim=(%.2f, %.2f), ylim=(%.2f, %.2f)" % (self.extent[0], self.extent[1],
                                                           self.extent[2], self.extent[3])

        for i in range(i_start, i_end):
            for series_drop, cmap_drop, sym_check_box, _ in self.elem_list:

                series_label = series_drop.value
                title = '%s %s' % (series_label, self.labels[i])
                image = self.series[series_label][i, :, :]

                plot_name = series_label.lower().replace(" ", '_').replace("=", '')
                plot_name += "_%02d%s" % (i-i_start, self.labels[i])

                # ---------------------------------------------------
                # Add the png to zip

                fig = plt.figure(figsize=(self.fig_y*self.figure_xy_ratio, self.fig_y))
                ax = plt.gca()

                make_plot(fig, ax, image, center0=sym_check_box.value,
                          extent=self.extent, title=title, cmap=cmap_drop.value, noadd=False)

                imgdata = StringIO.StringIO()
                fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
                zip_file.writestr(plot_name+".png", imgdata.getvalue())
                # Close exactly this figure so figures do not accumulate.
                plt.close(fig)

                # ---------------------------------------------------
                # Add txt data to the zip
                txtdata = StringIO.StringIO()
                np.savetxt(txtdata, image, header=header, fmt="%.3e")
                zip_file.writestr("txt/"+plot_name+".txt", txtdata.getvalue())

                # ---------------------------------------------------
                # Add IGOR format to zip
                igorwave = igor.Wave2d(
                        data=image,
                        xmin=self.extent[0],
                        xmax=self.extent[1],
                        xlabel='x [Angstroms]',
                        ymin=self.extent[2],
                        ymax=self.extent[3],
                        ylabel='y [Angstroms]',
                )
                zip_file.writestr("itx/"+plot_name+".itx", str(igorwave))
                # ---------------------------------------------------

                # Set (rather than increment) the progress so it ends exactly
                # at 1 and never divides by zero for a single image.
                n_done += 1
                self.zip_progress.value = self._progress_fraction(n_done, total_pics)

    def full_clear(self, b):
        """Remove all plots shown so far."""
        self.plot_output.children = ()


In [None]:
# Module-level state filled in by load_pk: orb_indexes is read by
# selected_orbital_indexes, cp2k_calc by create_wfn_zip.
orb_indexes = None
cp2k_calc = None

def load_pk(b):
    """Load the orbital data of the workchain whose pk is in `pk_select`.

    Looks up the 'scf_diag' and 'orb' calculations of the workchain and reads
    orb.npz from the retrieved folder of the 'orb' calculation. It builds one
    image series per (height, spin) for the orbital values and one for their
    squares, then hands them to `series_plotter` and enables the cube-kit button.

    :param b: button that triggered the call (unused).
    """
    global cp2k_calc
    global orb_indexes
    try:
        workcalc = load_node(pk=pk_select.value)
        scf_calc = common.get_calc_by_label(workcalc, 'scf_diag')
        orb_calc = common.get_calc_by_label(workcalc, 'orb')
    except Exception:
        # Catch ordinary errors only; KeyboardInterrupt/SystemExit propagate.
        print("Incorrect pk.")
        return

    ### ----------------------------------------------------
    ### Load data
    # 0.529177 is the Bohr radius in Angstrom; the grid coordinates are
    # presumably stored in Bohr (TODO confirm) while plots are labelled in Angstrom.
    bohr_to_ang = 0.529177
    # The context manager closes the .npz file once the arrays are read.
    with np.load(orb_calc.out.retrieved.get_abs_path('orb.npz')) as loaded_data:
        orbital_data = loaded_data['orbitals']
        heights = loaded_data['heights']
        orb_list = loaded_data['orb_list']
        x_arr = loaded_data['x_arr'] * bohr_to_ang
        y_arr = loaded_data['y_arr'] * bohr_to_ang

    # Publish the new state only after the lookups and the data load have
    # succeeded, so a failed load does not leave the globals half-updated.
    cp2k_calc = scf_calc
    orb_indexes = orb_list

    nspin = len(orbital_data)

    ### ----------------------------------------------------
    ### Create series
    # orbital_data is indexed [spin, height, ...]; the remaining three axes are
    # presumably (orbital, x, y) - TODO confirm against the orb.npz writer.
    orbital_series = {}

    for i_spin in range(nspin):
        for i_h, h in enumerate(heights):
            orb = orbital_data[i_spin, i_h, :, :, :]
            orbital_series["orb h=%.1f, s%d" % (h, i_spin)] = orb
            orbital_series["orb^2 h=%.1f, s%d" % (h, i_spin)] = orb**2

    ### TODO: Also STS/STM series at orbital energies

    ### Labels for each MO in the series: index 0 is the HOMO, negative
    ### indexes are lower occupied orbitals, 1 is the LUMO.
    labels = []
    for i_orb in orb_list:
        if i_orb <= 0:
            labels.append("HOMO%+d" % i_orb)
        else:
            labels.append("LUMO%+d" % (i_orb-1))

    ### ----------------------------------------------------
    extent = [np.min(x_arr), np.max(x_arr), np.min(y_arr), np.max(y_arr)]

    series_plotter.setup(orbital_series, extent=extent, labels=labels, wc_pk=workcalc.pk)
    wfn_kit_button.disabled = False
    

# Widgets for entering a workchain pk and loading it.
style = {'description_width': '50px'}
layout = {'width': '70%'}

pk_select = ipw.IntText(description='pk', value=0, layout=layout, style=style)

load_pk_btn = ipw.Button(description='Load pk', layout=layout, style=style)
load_pk_btn.on_click(load_pk)

geom_info = ipw.HTML()

display(ipw.HBox([
    ipw.VBox([pk_select, load_pk_btn]),
    geom_info,
]))

# Orbital images

In [None]:
def selected_orbital_indexes(n_homo=None, n_lumo=None, orb_list=None):
    """Return the half-open (start, end) positions of the orbitals to show.

    Orbital indexes follow the labelling used in load_pk: 0 is the HOMO,
    negative values are lower occupied orbitals, and 1, 2, ... are LUMO,
    LUMO+1, ... The selection holds the `n_homo` highest occupied and the
    `n_lumo` lowest unoccupied orbitals present, as ``range(start, end)``.

    :param n_homo: number of occupied orbitals; defaults to n_homo_inttext.value.
    :param n_lumo: number of unoccupied orbitals; defaults to n_lumo_inttext.value.
    :param orb_list: orbital indexes, assumed sorted ascending (TODO confirm);
        defaults to the global orb_indexes.
    :returns: (start, end) as ints; (0, 0) when nothing is selected or no
        data has been loaded.
    """
    if n_homo is None:
        n_homo = n_homo_inttext.value
    if n_lumo is None:
        n_lumo = n_lumo_inttext.value
    if orb_list is None:
        orb_list = orb_indexes
    if orb_list is None:
        return 0, 0

    orb_arr = np.asarray(orb_list)
    selected = np.where((orb_arr > -n_homo) & (orb_arr <= n_lumo))[0]
    if len(selected) == 0:
        return 0, 0
    # end is exclusive: one past the last selected position, so the last
    # requested LUMO is kept even when it is the final orbital in the list.
    return int(selected[0]), int(selected[-1]) + 1

In [None]:
# Styling shared by the HOMO/LUMO count inputs.
style = {'description_width': '80px'}
layout = {'width': '40%'}

series_plotter = SeriesPlotter(set_indexes_function=selected_orbital_indexes)

### -----------------------------------------------
### Plot selector

# BoundedIntText (unlike IntText) enforces min/max, so the requested numbers
# of HOMOs/LUMOs stay within [0, 100].
n_homo_inttext = ipw.BoundedIntText(
                        description='num HOMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout)
n_lumo_inttext = ipw.BoundedIntText(
                        description='num LUMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout)

# Box containers have no `style` trait; only the layout applies to them.
n_orb_select = ipw.HBox([n_homo_inttext, n_lumo_inttext],
                        layout={'width': '60%'})

### -----------------------------------------------


display(series_plotter.selector_widget, n_orb_select,
        series_plotter.plot_btn, series_plotter.clear_btn, series_plotter.plot_output)

# Export
**Image zip** exports the currently selected orbital images in PNG, plain-text (TXT) and IGOR Pro (ITX) formats.

**Cube creation kit** creates an archive containing all necessary ingredients to generate the Kohn-Sham orbital cube files with the `cube_from_wfn.py` script available from https://github.com/eimrek/atomistic_tools.

In [None]:
# "Image zip" button next to its progress bar, followed by the download-link area.
display(
    ipw.HBox([series_plotter.zip_btn, series_plotter.zip_progress]),
    series_plotter.link_out,
)

In [None]:
def create_wfn_zip(b):
    """Build tmp/cube-kit.zip for regenerating orbital cube files and show a link.

    The archive contains files retrieved by the 'scf_diag' calculation
    (`cp2k_calc`, set by load_pk) plus the run_cube_from_wfn.py helper script.

    :param b: button that triggered the call (unused).
    """
    import os

    wfn_kit_button.disabled = True

    # Pure-Python equivalent of `mkdir -p tmp`.
    if not os.path.isdir('tmp'):
        os.makedirs('tmp')

    cube_kit_name = "cube-kit.zip"
    # NOTE(review): absolute path tied to the deployment layout.
    run_script_path = "/project/apps/scanning_probe/orb/misc/run_cube_from_wfn.py"

    fd = cp2k_calc.get_outputs_dict()['retrieved']
    # `with` closes the archive even if one of the files is missing.
    with zipfile.ZipFile(os.path.join('tmp', cube_kit_name), 'w', zipfile.ZIP_DEFLATED) as zipf:
        for fn in ['BASIS_MOLOPT', 'aiida.inp', 'aiida.out', 'geom.xyz', 'aiida-RESTART.wfn']:
            zipf.write(fd.get_abs_path(fn), arcname=fn)
        zipf.write(run_script_path, arcname="run_cube_from_wfn.py")

    with wfn_kit_output:
        display(HTML('<a href="tmp/%s" target="_blank">download zip</a>' % cube_kit_name))
        
# Area where the cube-kit download link is shown.
wfn_kit_output = ipw.Output()

# Button that builds the cube-creation kit; enabled once a pk is loaded.
wfn_kit_button = ipw.Button(description='Cube creation kit', disabled=True)
wfn_kit_button.on_click(create_wfn_zip)

display(wfn_kit_button, wfn_kit_output)

In [None]:
def clear_tmp(b):
    """Empty the tmp directory and reset the export widgets.

    Clears both download-link areas and resets the zip progress bar. If data
    has been loaded, it also re-enables the image-zip and cube-kit buttons.

    :param b: button that triggered the call (unused).
    """
    import os
    import shutil

    # Pure-Python equivalent of `rm -rf tmp && mkdir tmp`.
    shutil.rmtree('tmp', ignore_errors=True)
    if not os.path.isdir('tmp'):
        os.makedirs('tmp')

    with series_plotter.link_out:
        clear_output()
    series_plotter.zip_progress.value = 0.0

    with wfn_kit_output:
        clear_output()

    if series_plotter.series is not None:
        series_plotter.zip_btn.disabled = False
        wfn_kit_button.disabled = False
    
# Button that wipes tmp/ and re-arms the export buttons.
clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)

display(clear_tmp_btn)

In [None]:
### Load the URL after everything is set up ###
# `jupyter_notebook_url` is presumably injected by the hosting app and is
# undefined elsewhere (TODO confirm). Only reading the optional ?pk=... from the
# URL is best-effort; errors raised while loading that pk are not hidden.
try:
    url = urlparse.urlsplit(jupyter_notebook_url)
    url_pk = int(urlparse.parse_qs(url.query)['pk'][0])
except (NameError, KeyError, IndexError, ValueError):
    url_pk = None

if url_pk is not None:
    pk_select.value = url_pk
    load_pk(0)