In [None]:
%aiida

In [None]:
import numpy as np
import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import gzip
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import OrderedDict
import urllib.parse
import io

from IPython.display import FileLink, FileLinks
from base64 import b64encode


import matplotlib
import matplotlib.pyplot as plt

from  apps.scanning_probe import common

In [None]:
%%javascript
// Override the classic Notebook's OutputArea._should_scroll so it always returns
// false: outputs (e.g. the PDOS plot and widget rows) are never collapsed into a
// scrolling box. NOTE(review): relies on the classic-Notebook `IPython` JS global;
// presumably has no effect in JupyterLab — confirm.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
def read_and_process_pdos_file(pdos_path, hartree_to_ev=27.21138602):
    """Read one PDOS file and contract it to a two-column array.

    Parameters
    ----------
    pdos_path : str
        Path to the PDOS file. Its first line must contain ``Fermi ... <number>``
        and may contain ``atomic kind <name>``. Data rows are read with
        ``np.loadtxt``, which skips the ``#`` header lines.
    hartree_to_ev : float, optional
        Factor applied to the energies. The default is the value that was
        previously hard-coded here, so existing results are unchanged.

    Returns
    -------
    (numpy.ndarray, str or None)
        ``out_data`` with shape ``(n_rows, 2)``. Column 0 is
        ``(column 1 - Fermi) * hartree_to_ev``. Column 1 is the sum of columns 3
        onward for that row (the "contracted" PDOS). NOTE(review): this assumes
        the CP2K layout (index, eigenvalue, occupation, projections...) — confirm.
        ``kind`` is the atomic-kind name from the header, or None if absent.

    Raises
    ------
    ValueError
        If no Fermi energy can be found in the header line.
    """
    # Read only the header line, and close the file right away.
    with open(pdos_path) as pdos_file:
        header = pdos_file.readline()

    fermi_match = re.search(r"Fermi.* ([+-]?[0-9]*[.]?[0-9]+)", header)
    if fermi_match is None:
        raise ValueError("No Fermi energy found in the header of %s" % pdos_path)
    fermi = float(fermi_match.group(1))

    # Only per-kind files carry "atomic kind"; list files do not.
    kind_match = re.search(r"atomic kind.(\S+)", header)
    kind = kind_match.group(1) if kind_match is not None else None

    # ndmin=2 keeps a single data row as shape (1, n_cols) rather than 1-D.
    data = np.loadtxt(pdos_path, ndmin=2)
    out_data = np.zeros((data.shape[0], 2))
    out_data[:, 0] = (data[:, 1] - fermi) * hartree_to_ev  # energy relative to Fermi
    out_data[:, 1] = np.sum(data[:, 3:], axis=1)  # "contracted pdos"
    return out_data, kind
    

def process_pdos_files(scf_calc):
    """Collect the PDOS arrays retrieved by ``scf_calc``.

    Files named ``aiida-list<N>...`` are projections on atom lists. List 1 is the
    molecule; the other lists become "selection" PDOS, ordered by N. Files named
    ``aiida-k...`` are per-atomic-kind PDOS, keyed by the kind name from their
    header.

    Returns
    -------
    (mol_pdos, kind_pdos, sel_pdos)
        ``mol_pdos`` is an array, or None when no list-1 file exists.
        ``kind_pdos`` is a dict mapping kind name to array.
        ``sel_pdos`` is a list of arrays.
    """
    retrieved = scf_calc.outputs.retrieved

    def _read(file_name):
        # Read via the on-disk path of the repository handle, then close the handle.
        # NOTE(review): relies on the handle exposing a real filesystem path in
        # ``.name`` (as the previous code did) — confirm for the AiiDA version in use.
        with retrieved.open(file_name) as handle:
            return read_and_process_pdos_file(handle.name)

    mol_pdos = None
    sel_pdos = []
    kind_pdos = {}
    list_files = []
    for file_name in sorted(retrieved.list_object_names()):
        # Parse the full list number, so 'aiida-list10' is not taken for list 1.
        list_match = re.match(r'aiida-list(\d+)', file_name)
        if list_match is not None:
            list_files.append((int(list_match.group(1)), file_name))
        elif file_name.startswith('aiida-k'):
            k_pdos, kind = _read(file_name)
            kind_pdos[kind] = k_pdos

    # Numeric order: list2, list3, ..., list10 (a plain lexical sort puts list10 before list2).
    for list_number, file_name in sorted(list_files):
        pdos, _ = _read(file_name)
        if list_number == 1:
            mol_pdos = pdos
        else:
            sel_pdos.append(pdos)
    return mol_pdos, kind_pdos, sel_pdos

In [None]:
# Data filled in by ``load_pk``; None until a pk has been loaded.
mol_pdos = kind_pdos = sel_pdos = tdos = None
ov_matrix = ov_energies = ov_gas_homo = ov_gas_energies = None

# Ordered mapping of PDOS series name -> data array (built by ``load_pk``).
pdos_options = None
# Gas-phase orbital labels such as "HOMO-1 (-6.12)"; mutated in place by ``load_pk``.
gas_orb_labels = []

def load_pk(b):
    """Load PDOS and overlap data for the workchain whose pk is in ``pk_select``.

    This is the callback of ``load_pk_btn``; ``b`` is the clicked button and is
    not used. It is also called directly with a dummy argument when a pk comes
    from the notebook URL.

    On success it refreshes the module-level data (``mol_pdos``, ``kind_pdos``,
    ``sel_pdos``, ``tdos``, ``pdos_options``, ``ov_*`` and ``gas_orb_labels``).
    It then re-initialises the energy slider and the line-selection rows.
    """
    global mol_pdos, kind_pdos, sel_pdos, tdos, pdos_options
    global ov_matrix, ov_energies, ov_gas_homo, ov_gas_energies
    try:
        workcalc = load_node(pk=pk_select.value)
        slab_scf_calc = common.get_calc_by_label(workcalc, 'slab_scf')
        overlap_calc = common.get_calc_by_label(workcalc, 'overlap')
    except Exception as exc:
        # Keep the notebook usable after a mistyped pk, but report why the lookup
        # failed. A bare ``except:`` would also swallow KeyboardInterrupt.
        print("Incorrect pk.")
        print("  %s: %s" % (type(exc).__name__, exc))
        return

    # --- PDOS ---------------------------------------------------------------
    mol_pdos, kind_pdos, sel_pdos = process_pdos_files(slab_scf_calc)
    if mol_pdos is None:
        # The total DOS is built on the molecule's energy column, so it is required.
        print("No molecule PDOS (aiida-list1 file) found in the slab_scf outputs.")
        return
    tdos = np.zeros(mol_pdos.shape)
    tdos[:, 0] = mol_pdos[:, 0]
    # Total DOS = sum of the per-kind PDOS (assumed to share the molecule's energy grid).
    for k_pdos in kind_pdos.values():
        tdos[:, 1] += k_pdos[:, 1]

    # Dropdown order: TDOS, molecule, sel0.., then the atomic kinds.
    pdos_options = OrderedDict([('TDOS', tdos), ('molecule', mol_pdos)])
    pdos_options.update(("sel%d" % i, pdos) for i, pdos in enumerate(sel_pdos))
    pdos_options.update(kind_pdos)

    # --- Overlap ------------------------------------------------------------
    # Close both the repository handle and the .npz archive once the arrays are read.
    with overlap_calc.outputs.retrieved.open('overlap.npz') as handle:
        with np.load(handle.name) as overlap_data:
            ov_matrix = overlap_data['overlap_matrix']
            ov_energies = overlap_data['en_grp1']
            ov_gas_energies = overlap_data['en_grp2']
            ov_gas_homo = int(overlap_data['homo_grp2'])

    # One label per column of ov_matrix, named relative to the gas-phase HOMO.
    # Clear first: this list is reused across loads and would otherwise keep growing.
    gas_orb_labels.clear()
    for i_gas in range(ov_matrix.shape[1]):
        wrt_h = i_gas - ov_gas_homo
        if wrt_h < 0:
            label = "HOMO%d" % wrt_h
        elif wrt_h == 0:
            label = "HOMO"
        elif wrt_h == 1:
            label = "LUMO"
        else:
            label = "LUMO+%d" % (wrt_h - 1)
        gas_orb_labels.append(label + " (%.2f)" % ov_gas_energies[i_gas])

    initialize_selections()

    initialize_pdos_lines()
    initialize_overlap_lines()

# pk of the workchain to inspect, plus the button that loads it through ``load_pk``.
pk_select = ipw.IntText(description='pk', value=0)
load_pk_btn = ipw.Button(description='Load pk')
load_pk_btn.on_click(load_pk)

display(pk_select, load_pk_btn)

# PDOS and overlap

In [None]:
def create_series_w_broadening(x_values, y_values, x_arr, fwhm, shape='g'):
    """Broaden a stick spectrum onto the grid ``x_arr``.

    Each stick at ``x_values[i]`` with weight ``y_values[i]`` is replaced by a
    unit-area peak of full width at half maximum ``fwhm``.

    Parameters
    ----------
    shape : str
        ``'g'`` selects Gaussian peaks. Any other value selects Lorentzian peaks.

    Returns
    -------
    numpy.ndarray
        Summed spectrum with the same length as ``x_arr``.
    """
    def gaussian(x_):
        sigma = fwhm/2.3548  # FWHM -> standard deviation
        return 1/(sigma*np.sqrt(2*np.pi))*np.exp(-x_**2/(2*sigma**2))

    def lorentzian(x_):
        return 0.5*fwhm/(np.pi*(x_**2+(0.5*fwhm)**2))

    kernel = gaussian if shape == 'g' else lorentzian

    spectrum = np.zeros(len(x_arr))
    for position, weight in zip(x_values, y_values):
        spectrum += weight*kernel(x_arr - position)
    return spectrum

In [None]:
def _series_label(name, factor):
    """Build the legend label for a series, prefixed with its scaling factor when that factor is not 1."""
    if factor != 1.0:
        return r'$%.1f\cdot$ %s' % (factor, name)
    return name


def _draw_series(ax, energy_arr, series, color, fill, label, **plot_kwargs):
    """Plot ``series`` on ``ax`` (optionally filled down to 0) and return its maximum (0.0 if empty)."""
    # Pass the color by keyword. As a positional "fmt" string, matplotlib may
    # read some values (e.g. digit strings) as marker/line codes.
    ax.plot(energy_arr, series, color=color, label=label, **plot_kwargs)
    if fill:
        ax.fill_between(energy_arr, 0.0, series, facecolor=color, alpha=0.4)
    return np.max(series) if series.size else 0.0


def create_the_plot():
    """Draw the broadened PDOS (left axis) and gas-orbital projections (right axis).

    Reads the widget state (``fwhm_slider``, ``energy_range_slider`` and the
    rows in ``pdos_elem_list`` / ``overlap_elem_list``). Also reads the data
    globals filled by ``load_pk``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, after it has been shown, so the caller can export it.
    """
    fwhm = fwhm_slider.value
    # Grid step: a tenth of the FWHM, but never coarser than 0.005 eV.
    de = np.min([fwhm/10, 0.005])
    elim = energy_range_slider.value
    energy_arr = np.arange(elim[0], elim[1], de)

    fig, ax1 = plt.subplots(figsize=(12, 6))

    ### -----------------------------------------------
    ### PDOS part (left axis)
    pdos_max = 0.0
    for series_sel, color_picker, fill_check, norm_factor, _rm_btn in pdos_elem_list:
        data = pdos_options[series_sel.value]
        series = create_series_w_broadening(data[:, 0], data[:, 1], energy_arr, fwhm) * norm_factor.value
        line_max = _draw_series(ax1, energy_arr, series, color_picker.value, fill_check.value,
                                _series_label(series_sel.value, norm_factor.value))
        pdos_max = max(pdos_max, line_max)

    ax1.set_xlim(elim)
    # With no positive PDOS drawn, [0, 0] would be degenerate limits; keep autoscaling then.
    if pdos_max > 0.0:
        ax1.set_ylim([0.0, pdos_max])
    ax1.set_ylabel("DOS [a.u.]")
    if pdos_elem_list:
        ax1.legend(loc='upper left')

    ### -----------------------------------------------
    ### Overlap part (right axis)
    ax2 = ax1.twinx()
    overlap_max = 0.0
    for series_sel, color_picker, fill_check, norm_factor, _rm_btn in overlap_elem_list:
        i_data = gas_orb_labels.index(series_sel.value)
        series = create_series_w_broadening(ov_energies, ov_matrix[:, i_data], energy_arr, fwhm) * norm_factor.value
        line_max = _draw_series(ax2, energy_arr, series, color_picker.value, fill_check.value,
                                _series_label(series_sel.value, norm_factor.value), lw=2.0)
        overlap_max = max(overlap_max, line_max)

    # Headroom: add 0.7 to the tallest projection and round to the nearest integer.
    ax2.set_ylim([0.0, np.around(overlap_max + 0.7)])
    ax2.set_ylabel("Projection density [1/eV]")
    if overlap_elem_list:
        ax2.legend(loc='upper right')

    ax1.set_xlabel("$E-E_F$ [eV]")

    plt.show()

    return fig

In [None]:
def initialize_selections():
    """Make ``energy_range_slider`` span the overlap eigenvalues ``ov_energies``, rounded to 0.1 eV."""
    min_e = np.around(np.min(ov_energies), 1)
    max_e = np.around(np.max(ov_energies), 1)

    # Each single assignment must keep min <= max on the slider. When the new
    # window lies entirely above the current one, raise ``max`` before ``min``.
    if min_e > energy_range_slider.max:
        energy_range_slider.max = max_e
        energy_range_slider.min = min_e
    else:
        energy_range_slider.min = min_e
        energy_range_slider.max = max_e
    energy_range_slider.value = [min_e, max_e]
    
def make_plot(b):
    """Button callback: draw the plot into ``plot_output``, followed by PNG and PDF export links."""
    with plot_output:
        figure = create_the_plot()
        for make_link in (mk_png_link, mk_pdf_link):
            make_link(figure)
        
def clear_plot(b):
    """Button callback: wipe everything currently shown in ``plot_output``."""
    plot_output.clear_output()

style = {'description_width': '140px'}
layout = {'width': '50%'}

# Keyword arguments shared by both sliders; values update only when the handle is released.
_slider_kwargs = dict(
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    style=style,
    layout=layout,
)

# Broadening width applied to both the PDOS and the projections.
fwhm_slider = ipw.FloatSlider(
    value=0.05, min=0.01, max=0.2, step=0.01,
    description='broadening fwhm (eV)',
    readout_format='.2f',
    **_slider_kwargs
)

# Plotted energy window; its bounds are set by ``initialize_selections`` after loading.
energy_range_slider = ipw.FloatRangeSlider(
    value=[0.0, 0.0], min=0.0, max=0.0, step=0.1,
    description='energy range (eV)',
    readout_format='.1f',
    **_slider_kwargs
)

# Line-selection rows: each *_elem_list entry mirrors one child of the matching VBox.
pdos_elem_list, pdos_line_vbox = [], ipw.VBox([])
add_pdos_btn = ipw.Button(description='Add pdos')

overlap_elem_list, overlap_line_vbox = [], ipw.VBox([])
add_overlap_btn = ipw.Button(description='Add overlap')

# Output area and the buttons that fill / wipe it.
plot_output = ipw.Output()
plot_btn = ipw.Button(description="plot")
clear_btn = ipw.Button(description="clear")
plot_btn.on_click(make_plot)
clear_btn.on_click(clear_plot)

display(
    fwhm_slider,
    energy_range_slider,
    pdos_line_vbox,
    add_pdos_btn,
    overlap_line_vbox,
    add_overlap_btn,
    ipw.HBox([plot_btn, clear_btn]),
    plot_output,
)

In [None]:
def _mk_download_link(fig, filename, mime_type, link_text, **savefig_kwargs):
    """Display an HTML link that downloads ``fig`` rendered in the format given by ``filename``'s extension.

    The figure is rendered in memory and embedded as a base64 ``data:`` URI,
    so nothing is written to disk.
    """
    fmt = filename.rsplit('.', 1)[-1]
    buffer = io.BytesIO()
    fig.savefig(buffer, format=fmt, bbox_inches='tight', **savefig_kwargs)
    encoded = b64encode(buffer.getvalue()).decode()

    # Unique per file type (e.g. "pdos_png_link", "pdos_pdf_link"), so both links can coexist.
    link_id = filename.replace('.', '_') + "_link"
    html = ('<a download="{name}" href="data:{mime};name={name};base64,{data}"'
            ' id="{link_id}" target="_blank">{text}</a>').format(
        name=filename, mime=mime_type, data=encoded, link_id=link_id, text=link_text)
    display(HTML(html))


def mk_png_link(fig):
    """Display an "Export png" link that downloads ``fig`` as a 300-dpi ``pdos.png``."""
    _mk_download_link(fig, "pdos.png", "image/png", "Export png", dpi=300)


def mk_pdf_link(fig):
    """Display an "Export pdf" link that downloads ``fig`` as ``pdos.pdf``."""
    # The previous version labelled the PDF payload as image/png.
    _mk_download_link(fig, "pdos.pdf", "application/pdf", "Export pdf")

In [None]:
def remove_from_tuple(tup, index):
    """Return a copy of ``tup`` without the element at ``index`` (negative indices allowed).

    Raises IndexError when ``index`` is out of range.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def remove_line_row(b, elem_list, lines_vbox):
    """Remove the row whose remove button is ``b`` from both ``elem_list`` and ``lines_vbox``.

    Each ``elem_list`` entry keeps its remove button at position 4. Raises
    ValueError if ``b`` belongs to no row.
    """
    row_index = [row[4] for row in elem_list].index(b)
    elem_list.pop(row_index)
    lines_vbox.children = remove_from_tuple(lines_vbox.children, row_index)
    
def add_line_row(b, elem_list, lines_vbox, series, i_sel=0, col='black', fill=False, factor=1.0):
    """Append one editable plot-line row to ``lines_vbox``.

    Parameters
    ----------
    b : ipywidgets.Button or None
        The clicked button when used as a click callback; not used.
    elem_list : list
        Bookkeeping list. It receives
        ``[series_sel, color_picker, fill_check, norm_factor, rm_btn]``.
    lines_vbox : ipywidgets.VBox
        Container that the new row (an HBox) is appended to.
    series : iterable of str
        Selectable series names. Any iterable is accepted, including dict key views.
    i_sel : int
        Index into ``series`` of the entry selected initially.
    col, fill, factor
        Initial color, fill flag and scaling factor of the line.
    """
    # Dict views (e.g. pdos_options.keys()) are not indexable; materialise them so
    # ``series[i_sel]`` works.
    series = list(series)

    series_sel = ipw.Dropdown(
        options=series,
        value=series[i_sel],
        description='series:',
        disabled=False,
        style = {'description_width': 'auto'},
        layout=ipw.Layout(width='200px')
    )

    color_picker = ipw.ColorPicker(
        concise=False,
        description='color',
        value=col,
        disabled=False,
        style = {'description_width': 'auto'},
        layout=ipw.Layout(width='200px')
    )

    fill_check = ipw.Checkbox(
        value=fill,
        description='fill',
        disabled=False,
        style = {'description_width': 'auto'},
        layout=ipw.Layout(width='auto')
    )

    norm_factor = ipw.FloatText(
        value=factor,
        step=0.01,
        description='factor',
        disabled=False,
        style = {'description_width': 'auto'},
        layout=ipw.Layout(width='150px')
    )

    rm_btn = ipw.Button(description='x', layout=ipw.Layout(width='30px'))
    rm_btn.on_click(lambda b: remove_line_row(b, elem_list, lines_vbox))

    # Wrap every control in a fixed-width bordered box so the rows line up as columns.
    elements = [series_sel, color_picker, fill_check, norm_factor, rm_btn]
    element_widths = ['210px', '210px', '70px', '160px', '30px']
    boxed_row = [ipw.HBox([row_el], layout=ipw.Layout(border='0.1px solid', width=row_w))
                 for row_el, row_w in zip(elements, element_widths)]

    elem_list.append(elements)
    lines_vbox.children += (ipw.HBox(boxed_row), )

In [None]:
def _on_add_pdos(b):
    """Click callback of ``add_pdos_btn``: append a row showing the first PDOS series."""
    # Pass a list: a keys view is not indexable, and add_line_row indexes its series.
    add_line_row(b, pdos_elem_list, pdos_line_vbox, list(pdos_options.keys()))


def initialize_pdos_lines():
    """(Re)build the PDOS rows after a load: a scaled, filled TDOS row and a filled molecule row."""
    # Drop rows from a previously loaded pk; their series may not exist in the new data.
    pdos_elem_list.clear()
    pdos_line_vbox.children = ()

    # Register one module-level callback exactly once. Creating a fresh lambda on
    # every load stacked handlers, so one click added several rows.
    add_pdos_btn.on_click(_on_add_pdos, remove=True)
    add_pdos_btn.on_click(_on_add_pdos)

    series = list(pdos_options.keys())
    add_line_row(None, pdos_elem_list, pdos_line_vbox, series, 0, 'lightgray', True, 0.1)
    add_line_row(None, pdos_elem_list, pdos_line_vbox, series, 1, 'black', True)
    
def _on_add_overlap(b):
    """Click callback of ``add_overlap_btn``: append a row showing the first gas-phase orbital."""
    add_line_row(b, overlap_elem_list, overlap_line_vbox, gas_orb_labels)


def initialize_overlap_lines():
    """(Re)build the overlap rows: one per gas-phase orbital, cycling matplotlib's default colors."""
    mpl_def_colors = [col['color'] for col in mpl.rcParams['axes.prop_cycle']]

    # Drop rows from a previously loaded pk; their labels may not match the new orbitals.
    overlap_elem_list.clear()
    overlap_line_vbox.children = ()

    # Register one module-level callback exactly once, so handlers do not stack across loads.
    add_overlap_btn.on_click(_on_add_overlap, remove=True)
    add_overlap_btn.on_click(_on_add_overlap)

    for i_gas in range(len(gas_orb_labels)):
        add_line_row(None, overlap_elem_list, overlap_line_vbox, gas_orb_labels, i_gas,
                     mpl_def_colors[i_gas % len(mpl_def_colors)], False)

In [None]:
### Load the URL after everything is set up ###
def get_pk_from_url(notebook_url=None):
    """Extract the integer ``pk`` query parameter from the notebook URL.

    Parameters
    ----------
    notebook_url : str, optional
        URL to parse. Defaults to the global ``jupyter_notebook_url``. This
        notebook does not define that global itself; presumably the hosting app
        injects it (TODO confirm).

    Returns
    -------
    int or None
        The pk. Returns None when no URL is available, the URL has no ``pk``
        parameter, or the value is not an integer.
    """
    if notebook_url is None:
        try:
            notebook_url = jupyter_notebook_url
        except NameError:
            return None
    query = urllib.parse.parse_qs(urllib.parse.urlsplit(notebook_url).query)
    pk_values = query.get('pk')
    if not pk_values:
        return None
    try:
        return int(pk_values[0])
    except ValueError:
        print("Ignoring non-integer pk in URL: %r" % pk_values[0])
        return None


# Auto-load only when the URL carries a valid pk. Errors raised by load_pk itself
# are no longer silently swallowed.
url_pk = get_pk_from_url()
if url_pk is not None:
    pk_select.value = url_pk
    load_pk(0)