In [None]:
%load_ext aiida
%aiida

In [None]:
import numpy as np
#import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
#import re
#import gzip
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import OrderedDict
import urllib.parse
import io

#from IPython.display import FileLink, FileLinks
from base64 import b64encode


#import matplotlib
import matplotlib.pyplot as plt

from scanning_probe import common

import scanning_probe.pdos.pdos_postprocess as pdos_pp

In [None]:
%%javascript
// Override OutputArea._should_scroll so it ignores `lines` and always answers false,
// presumably keeping long outputs (plots, widget rows) from being collapsed into a
// scroll box. This targets the classic Notebook front-end API — verify elsewhere.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# Module-level state: filled in by load_pk() and read by the plotting/initialisation helpers.
dos_data = dos_options = overlap_data = orbital_labels = energy_lim = None

def load_pk(b):
    """Load the workchain whose pk is in ``pk_select`` and rebuild the plot controls.

    Sets the module-level ``dos_data``, ``dos_options``, ``overlap_data``,
    ``orbital_labels`` and ``energy_lim``, fills ``geom_info``, then
    re-initialises the energy slider, the PDOS rows and the overlap rows.

    Args:
        b: the clicked button (unused; ``0``/``None`` when called directly).
    """
    global dos_data, dos_options, overlap_data, orbital_labels, energy_lim

    workcalc = load_node(pk=pk_select.value)
    overlap_calc = common.get_calc_by_label(workcalc, 'overlap')

    # When no 'slab_scf' sub-calculation is found (get_calc_by_label raises
    # AssertionError), the PDOS files are processed from the workchain itself
    # with the "new version" flag set.
    new_version = False
    try:
        slab_scf_calc = common.get_calc_by_label(workcalc, 'slab_scf')
    except AssertionError:
        slab_scf_calc = workcalc
        new_version = True

    pdos_lists = workcalc.inputs.pdos_lists
    # User-provided selection labels are used when the first entry is the
    # molecule selection and further selections follow it.
    labels_are_present = 'molecule' in pdos_lists[0] and len(pdos_lists) > 1

    overlap_params = workcalc.inputs.overlap_params
    energy_lim = [
        float(overlap_params['--emin1']),
        float(overlap_params['--emax1']),
    ]

    geom_info.value = common.get_slab_calc_info(workcalc.inputs.slabsys_structure)

    dos_data = pdos_pp.process_pdos_files(slab_scf_calc, new_version)

    labels = [sel[1] for sel in pdos_lists[1:]] if labels_are_present else []

    def selection_label(key):
        # Keys are assumed to look like 'sel_<n>' (TODO confirm against pdos_postprocess);
        # with custom labels, 'sel_<n>' maps to the label of pdos_lists[n - 1].
        if labels_are_present:
            return labels[int(key[4:]) - 2]
        return f"selection {key.split('_')[-1]}"

    dos_options = OrderedDict([
        ('total DOS', dos_data['tdos']),
        ('molecule PDOS', dos_data['mol']),
        *[(selection_label(k), dos_data[k]) for k in dos_data if k.startswith('sel')],
        *[(f"kind {k.split('_')[-1]}", dos_data[k]) for k in dos_data if k.startswith('kind_')],
    ])

    with overlap_calc.outputs.retrieved.open('overlap.npz', mode='rb') as fhandle:
        # load_overlap_npz is handed the opened file's name (presumably a filesystem path — TODO confirm).
        overlap_data = pdos_pp.load_overlap_npz(fhandle.name)

    orbital_labels = pdos_pp.get_full_orbital_labels(overlap_data)

    initialize_selections()
    initialize_pdos_lines()
    initialize_overlap_lines()

# Controls for choosing the workchain pk, shown beside the info HTML that load_pk() fills in.
pk_select = ipw.IntText(value=0, description='pk')
geom_info = ipw.HTML()

load_pk_btn = ipw.Button(description='Load pk')
load_pk_btn.on_click(load_pk)

display(ipw.HBox([
    ipw.VBox([pk_select, load_pk_btn]),
    geom_info,
]))

# PDOS and overlap

In [None]:
def create_the_plot():
    """Draw broadened PDOS and stacked overlap curves for the current widget settings.

    Reads the FWHM and energy-range sliders, every row in ``pdos_elem_list`` and
    every row in ``overlap_elem_list``. Curves of the second spin channel are
    drawn mirrored below zero.

    Returns:
        ``(fig, (headers, table))``. ``table`` has one row per energy point. Its
        first column is the energy grid; each further column is one plotted
        series (factor-scaled, not mirrored), named by the matching ``headers`` entry.
    """
    fwhm = fwhm_slider.value
    # Grid step: a tenth of the FWHM, but never coarser than 0.005 eV.
    de = np.min([fwhm / 10, 0.005])
    e_min, e_max = energy_range_slider.value
    energy_arr = np.arange(e_min, e_max, de)

    # Export table, built as a list of rows and stacked once at the end.
    collect_data_headers = ['energy [eV]']
    collect_rows = [energy_arr]

    fig, ax1 = plt.subplots(figsize=(12, 6))

    ylim = [None, None]

    ### ------------------------------
    ### PDOS part
    for series_sel, color_picker, norm_factor, _rm_btn in pdos_elem_list:
        data = dos_options[series_sel.value]
        factor = norm_factor.value
        label = series_sel.value
        if factor != 1.0:
            label = r'$%.2f\cdot$ %s' % (factor, label)

        for i_spin, spin_data in enumerate(data):
            series = pdos_pp.create_series_w_broadening(spin_data[:, 0], spin_data[:, 1], energy_arr, fwhm)
            series *= factor
            sign = 1 - 2 * i_spin  # +1 for the first spin channel, -1 (mirrored) for the second

            kwargs = {'label': label} if i_spin == 0 else {}
            if 'molecule' in label.lower():
                # The molecule PDOS is drawn on top and determines the y-range.
                kwargs['zorder'] = 300
                if i_spin == 0:
                    ylim[1] = 1.2 * np.max(series)
                else:
                    ylim[0] = 1.2 * np.min(-series)

            ax1.plot(energy_arr, series * sign, color_picker.value, **kwargs)
            ax1.fill_between(energy_arr, 0.0, series * sign, facecolor=color_picker.value, alpha=0.2)

            collect_data_headers.append(f'{label} s{i_spin}')
            collect_rows.append(series)

    ### ------------------------------
    ### Overlap part
    nspin = overlap_data['nspin_g2']
    for i_spin in range(nspin):
        sign = 1 - 2 * i_spin
        # Running sum of this channel's overlap series. A fresh array is built each
        # step, so series already stored in collect_rows are never modified.
        cumulative = np.zeros_like(energy_arr)

        for i_ser, (series_sel, color_picker, norm_factor, _rm_btn) in enumerate(overlap_elem_list[i_spin]):
            i_orb = orbital_labels[i_spin].index(series_sel.value)
            data = overlap_data['overlap_matrix'][i_spin][:, i_orb]

            factor = norm_factor.value
            label = series_sel.value
            if factor != 1.0:
                # Same precision as the PDOS labels; the factor widget steps by 0.01.
                label = r'$%.2f\cdot$ %s' % (factor, label)

            series = pdos_pp.create_series_w_broadening(overlap_data['energies_g1'][i_spin], data, energy_arr, fwhm)
            series *= factor

            cumulative = cumulative + series
            # Each band is filled up to the running sum; earlier rows get a higher zorder.
            ax1.fill_between(energy_arr, 0.0, cumulative * sign,
                             facecolor=color_picker.value, alpha=1.0, zorder=-i_ser + 100, label=label)

            collect_data_headers.append(f'{label} s{i_spin}')
            collect_rows.append(series)

        if i_spin == 0 and nspin == 2:
            # Add empty legend entries (one per PDOS row) to align the spin channels
            # in the two-column legend.
            for _ in pdos_elem_list:
                ax1.fill_between([0.0], 0.0, [0.0], color='w', alpha=0, label=' ')

    ax1.legend(ncol=nspin, loc='center left', bbox_to_anchor=(1.01, 0.5))

    ax1.set_xlim([np.min(energy_arr), np.max(energy_arr)])

    if nspin == 1:
        ylim[0] = 0.0
    ax1.set_ylim(ylim)

    ax1.axhline(0.0, color='k', lw=2.0, zorder=200)

    ax1.set_ylabel("Density of States [a.u.]")
    ax1.set_xlabel("$E-E_F$ [eV]")

    plt.show()

    return fig, (collect_data_headers, np.vstack(collect_rows).T)

In [None]:
def initialize_selections():
    """Fit the energy-range slider to ``energy_lim`` and select the whole range.

    ipywidgets range sliders raise TraitError when ``min`` is set above the
    current ``max`` (or ``max`` below ``min``). The slider starts at [0, 0] and
    may still hold a previous pk's range, so the bounds are assigned in an order
    that keeps ``min <= max`` at every step.
    """
    e_min, e_max = energy_lim
    slider = energy_range_slider
    if e_min > slider.max:
        # Raise the upper bound first so that the new lower bound is accepted.
        slider.max = e_max
        slider.min = e_min
    else:
        slider.min = e_min
        slider.max = e_max
    slider.value = [e_min, e_max]
    
def make_plot(b):
    """Button callback: draw the plot into ``plot_output`` followed by its export links.

    Output is appended to whatever ``plot_output`` already shows; nothing is cleared here.
    """
    with plot_output:
        figure, table = create_the_plot()
        for export_link in (mk_png_link, mk_pdf_link):
            export_link(figure)
        mk_txt_link(table)
        
def clear_plot(b):
    """Button callback: wipe everything displayed in ``plot_output``."""
    plot_output.clear_output()

style = {'description_width': '140px'}
layout = {'width': '50%'}

# Options shared by both sliders: update only when released, show a readout.
_slider_shared = dict(
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    style=style,
    layout=layout,
)

fwhm_slider = ipw.FloatSlider(
    value=0.10, min=0.01, max=0.2, step=0.01,
    description='broadening fwhm (eV)',
    readout_format='.2f',
    **_slider_shared,
)

# Placeholder [0, 0] range; initialize_selections() sets real bounds after a pk is loaded.
energy_range_slider = ipw.FloatRangeSlider(
    value=[0.0, 0.0], min=0.0, max=0.0, step=0.1,
    description='energy range (eV)',
    readout_format='.1f',
    **_slider_shared,
)

# PDOS rows: each entry is [series_sel, color_picker, norm_factor, rm_btn] as built by add_line_row.
pdos_elem_list = []
pdos_line_vbox = ipw.VBox([])
add_pdos_btn = ipw.Button(description='Add pdos')

# Overlap rows: one row list, one VBox and one "add" button per spin channel.
overlap_elem_list = [[] for _ in range(2)]
overlap_line_vboxes = [ipw.VBox([]) for _ in range(2)]
add_overlap_btns = [ipw.Button(description=text) for text in ('Add overlap', 'Add overlap beta')]
# The second-channel button starts hidden; initialize_overlap_lines() reveals it for 2 spins.
add_overlap_btns[1].layout.display = 'none'

plot_output = ipw.Output()
plot_btn = ipw.Button(description="plot")
clear_btn = ipw.Button(description="clear")
plot_btn.on_click(make_plot)
clear_btn.on_click(clear_plot)

display(
    fwhm_slider,
    energy_range_slider,
    pdos_line_vbox,
    add_pdos_btn,
    *overlap_line_vboxes,
    ipw.HBox(add_overlap_btns),
    ipw.HBox([plot_btn, clear_btn]),
    plot_output,
)

In [None]:
def mk_data_link_html(filename, mime_type, payload, link_id, text):
    """Return an ``<a>`` tag that downloads ``payload`` (bytes) as ``filename``.

    The payload is embedded as a base64 ``data:`` URI of type ``mime_type``.
    The link carries the HTML id ``link_id`` and the visible text ``text``.
    """
    encoded = b64encode(payload).decode()
    return (
        f'<a download="{filename}" href="data:{mime_type};name={filename};base64,{encoded}"'
        f' id="{link_id}" target="_blank">{text}</a>'
    )


def _fig_bytes(fig, **savefig_kwargs):
    """Render ``fig`` (tight bounding box) into memory and return the encoded bytes."""
    buffer = io.BytesIO()
    fig.savefig(buffer, bbox_inches='tight', **savefig_kwargs)
    return buffer.getvalue()


def mk_png_link(fig):
    """Display a link that downloads ``fig`` as a 300 dpi PNG named ``pdos.png``."""
    png_bytes = _fig_bytes(fig, format='png', dpi=300)
    display(HTML(mk_data_link_html("pdos.png", "image/png", png_bytes, "pdos_png_link", "Export png")))


def mk_pdf_link(fig):
    """Display a link that downloads ``fig`` as ``pdos.pdf``."""
    pdf_bytes = _fig_bytes(fig, format='pdf')
    # PDF MIME type, and an id distinct from the PNG link's.
    display(HTML(mk_data_link_html("pdos.pdf", "application/pdf", pdf_bytes, "pdos_pdf_link", "Export pdf")))


def mk_txt_link(collected_data):
    """Display a link that downloads ``(headers, table)`` as comma-separated text ``pdos.txt``.

    ``headers`` are joined into the ``#``-prefixed header line written by
    ``np.savetxt``; values are written as ``%.4e``.
    """
    headers, data = collected_data
    buffer = io.BytesIO()
    np.savetxt(buffer, data, header=", ".join(headers), fmt="%.4e", delimiter=', ')
    display(HTML(mk_data_link_html("pdos.txt", "text/plain", buffer.getvalue(), "export_link", "Export txt")))
    

In [None]:
def remove_from_tuple(tup, index):
    """Return a new tuple equal to ``tup`` without the item at ``index``.

    Negative indices count from the end; an out-of-range index raises IndexError.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def remove_line_row(b, elem_list, lines_vbox):
    """Delete the row whose remove button is ``b``.

    Entries of ``elem_list`` are ``[series_sel, color_picker, norm_factor, rm_btn]``.
    The row at the same position is dropped from both ``elem_list`` and
    ``lines_vbox.children``.
    """
    position = [row[3] for row in elem_list].index(b)
    del elem_list[position]
    remaining = list(lines_vbox.children)
    remaining.pop(position)
    lines_vbox.children = tuple(remaining)
    
def add_line_row(b, elem_list, lines_vbox, series, i_sel=0, col='black', factor=1.0):
    """Append one editable plot-line row to ``lines_vbox`` and record it in ``elem_list``.

    The row holds:
    - a dropdown over ``series``, preselecting ``series[i_sel]``;
    - a colour picker starting at ``col``;
    - a scaling-factor field starting at ``factor``;
    - an 'x' button that removes the row again.

    Args:
        b: the clicked button (unused; ``None`` when called directly).
        elem_list: receives ``[series_sel, color_picker, norm_factor, rm_btn]``.
        lines_vbox: VBox the new row is appended to.
        series: dropdown options.
    """
    auto_width = {'description_width': 'auto'}

    series_sel = ipw.Dropdown(
        options=series,
        value=series[i_sel],
        description='series:',
        disabled=False,
        style=auto_width,
        layout=ipw.Layout(width='200px'),
    )
    color_picker = ipw.ColorPicker(
        concise=False,
        description='color',
        value=col,
        disabled=False,
        style=auto_width,
        layout=ipw.Layout(width='200px'),
    )
    norm_factor = ipw.FloatText(
        value=factor,
        step=0.01,
        description='factor',
        disabled=False,
        style=auto_width,
        layout=ipw.Layout(width='150px'),
    )
    rm_btn = ipw.Button(description='x', layout=ipw.Layout(width='30px'))

    def _on_remove(btn):
        remove_line_row(btn, elem_list, lines_vbox)

    rm_btn.on_click(_on_remove)

    elements = [series_sel, color_picker, norm_factor, rm_btn]
    widths = ('210px', '210px', '160px', '35px')
    # Wrap each widget in its own bordered, fixed-width box.
    framed = []
    for widget, width in zip(elements, widths):
        framed.append(ipw.HBox([widget], layout=ipw.Layout(border='0.1px solid', width=width)))

    elem_list.append(elements)
    lines_vbox.children = lines_vbox.children + (ipw.HBox(framed),)

In [None]:
def _on_add_pdos_click(b):
    """'Add pdos' callback: append a row offering every series currently in ``dos_options``."""
    add_line_row(b, pdos_elem_list, pdos_line_vbox, list(dos_options.keys()))


def initialize_pdos_lines():
    """(Re)build the PDOS rows for freshly loaded data.

    Rows left from a previously loaded pk are dropped, because their series may
    not exist in the new ``dos_options``. The 'Add pdos' callback is registered
    exactly once. Two default rows are added:
    - option 0, in light gray, scaled by 0.02;
    - option 1, in black.
    """
    pdos_elem_list.clear()
    pdos_line_vbox.children = ()

    # on_click() appends a handler on every call; remove it first so that
    # reloading a pk does not make one click add several rows.
    add_pdos_btn.on_click(_on_add_pdos_click, remove=True)
    add_pdos_btn.on_click(_on_add_pdos_click)

    options = list(dos_options.keys())
    add_line_row(None, pdos_elem_list, pdos_line_vbox, options, 0, 'lightgray', 0.02)
    add_line_row(None, pdos_elem_list, pdos_line_vbox, options, 1, 'black')
    
def _make_overlap_row_adder(i_spin):
    """Return an 'Add overlap' callback that appends a row for spin channel ``i_spin``."""
    def _on_add_overlap_click(b):
        add_line_row(b, overlap_elem_list[i_spin], overlap_line_vboxes[i_spin], orbital_labels[i_spin])
    return _on_add_overlap_click


# One persistent callback per spin channel, so it can be de-registered on reload.
_overlap_row_adders = [_make_overlap_row_adder(0), _make_overlap_row_adder(1)]


def initialize_overlap_lines():
    """(Re)build the overlap rows for freshly loaded data.

    - Rows from a previously loaded pk are dropped.
    - The second-channel button is shown only when ``overlap_data['nspin_g2'] == 2``.
    - Each active 'Add overlap' button gets its callback exactly once.
    - One default row is added per orbital label. Labels containing 'HOMO' cycle
      through a red palette; all others cycle through a blue palette.
    """
    red_palette = ['#d90429', '#ee6c4d', '#fb8500', '#ffb703']
    blue_palette = ['#219ebc', '#e0fbfc', '#98c1d9', '#3d5a80']

    nspin = overlap_data['nspin_g2']

    for i_spin, add_btn in enumerate(add_overlap_btns):
        overlap_elem_list[i_spin].clear()
        overlap_line_vboxes[i_spin].children = ()
        # on_click() appends a handler on every call; drop any earlier registration.
        add_btn.on_click(_overlap_row_adders[i_spin], remove=True)

    if nspin == 2:
        add_overlap_btns[0].description = 'Add overlap alpha'
        add_overlap_btns[1].layout.display = None
    else:
        # Restore the single-channel layout in case spin-polarised data was loaded before.
        add_overlap_btns[0].description = 'Add overlap'
        add_overlap_btns[1].layout.display = 'none'

    for i_spin in range(nspin):
        add_overlap_btns[i_spin].on_click(_overlap_row_adders[i_spin])

        n_homo = 0
        n_other = 0
        for i_gas, label in enumerate(orbital_labels[i_spin]):
            if 'HOMO' in label:
                color = red_palette[n_homo % len(red_palette)]
                n_homo += 1
            else:
                color = blue_palette[n_other % len(blue_palette)]
                n_other += 1

            add_line_row(None, overlap_elem_list[i_spin], overlap_line_vboxes[i_spin], orbital_labels[i_spin],
                         i_gas, color)

In [None]:
### Load the URL after everything is set up ###
# If the notebook URL carries ?pk=<int>, preload that pk.
# jupyter_notebook_url is not defined in this notebook; it is presumably injected
# by the hosting environment (TODO confirm), so a NameError means "no URL available".
# Only "no usable pk in the URL" is skipped silently. A failure while loading is
# reported instead of being swallowed by a bare except.
try:
    url = urllib.parse.urlsplit(jupyter_notebook_url)
    url_pk = int(urllib.parse.parse_qs(url.query)['pk'][0])
except (NameError, TypeError, KeyError, IndexError, ValueError):
    url_pk = None

if url_pk is not None:
    pk_select.value = url_pk
    try:
        load_pk(0)
    except Exception as exc:  # keep the notebook usable, but say what went wrong
        print(f"Could not load pk {url_pk}: {exc}")