# Projected Density of States

In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)
    
from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import re
import urlparse
import numpy as np
from xml.etree import ElementTree
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.collections import LineCollection
from matplotlib.ticker import FormatStrFormatter
import scipy.ndimage
import ipywidgets as ipw
from IPython.display import clear_output
from IPython.core.display import HTML, Javascript
import nglview

import io
from IPython.display import FileLink, FileLinks
from base64 import b64encode

import StringIO

In [None]:
def get_calc_by_label(workcalc, label):
    """Return the unique finished JobCalculation of ``workcalc`` with ``label``.

    Queries the AiiDA database for JobCalculation nodes that are outputs of the
    WorkCalculation ``workcalc`` and carry the given ``label``.

    Raises:
        RuntimeError: if the query does not match exactly one calculation, or
            if the matched calculation is not in the 'FINISHED' state.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(JobCalculation, output_of=WorkCalculation, filters={'label': label})
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    count = qb.count()
    if count != 1:
        raise RuntimeError("Expected exactly one calculation labelled %r in work "
                           "calculation %s, found %d" % (label, workcalc.uuid, count))
    calc = qb.first()[0]
    state = calc.get_state()
    if state != 'FINISHED':
        raise RuntimeError("Calculation labelled %r is in state %r, expected "
                           "'FINISHED'" % (label, state))
    return calc

In [None]:
# Resolve the work calculation from the ``pk`` query parameter of the notebook
# URL. NOTE(review): ``jupyter_notebook_url`` is not defined in this notebook;
# presumably it is injected by the hosting app — confirm.
url = urlparse.urlsplit(jupyter_notebook_url)
params = urlparse.parse_qs(url.query)
pk = params['pk'][0]  # parse_qs maps each key to a list of values

workcalc = load_node(pk=int(pk))
# Reference energies stored as extras on the work calculation.
vacuum_level = workcalc.get_extra('vacuum_level')
homo = workcalc.get_extra('homo')
lumo = workcalc.get_extra('lumo')

pdos_calc = get_calc_by_label(workcalc, "export_pdos")
bands_calc = get_calc_by_label(workcalc, "bands")
structure = bands_calc.inp.structure
ase_struct = structure.get_ase()
natoms = len(ase_struct)

# Band energies of the "bands" calculation; a 2-D result gets a leading axis so
# the array is always 3-D. NOTE(review): ``bands`` appears to be reassigned
# before use further down — confirm whether this value is still needed.
bands = bands_calc.out.output_band.get_bands()
if bands.ndim == 2:
    bands = bands[None, :, :]

In [None]:
# Parse projwfc.x's atomic_proj.xml: header dimensions, k-point weights,
# eigenvalues (Ry -> eV, shifted by the vacuum level) and the squared moduli of
# the complex projections onto each atomic wavefunction.
atomic_proj_xml = pdos_calc.out.retrieved.get_abs_path('atomic_proj.xml')

root = ElementTree.parse(atomic_proj_xml).getroot()
header = root.find('HEADER')
nbands = int(header.find('NUMBER_OF_BANDS').text)
nkpoints = int(header.find('NUMBER_OF_K-POINTS').text)
nspins = int(header.find('NUMBER_OF_SPIN_COMPONENTS').text)
natwfcs = int(header.find('NUMBER_OF_ATOMIC_WFC').text)

kpoint_weights = np.fromstring(root.find('WEIGHT_OF_K-POINTS').text, sep=' ')

RY_TO_EV = 13.60569806589  # 1 Rydberg in eV

eigvalues = np.zeros((nspins, nbands, nkpoints))
projections = np.zeros((nspins, nbands, nkpoints, natwfcs))
for ispin in range(nspins):
    # With more than one spin component the tags carry the 1-based spin index.
    eig_tag = 'EIG.%s' % (ispin + 1) if nspins > 1 else 'EIG'
    spin_dir = 'SPIN.%d/' % (ispin + 1) if nspins > 1 else ""
    for ikpt in range(nkpoints):
        eig_text = root.find('EIGENVALUES/K-POINT.%d/%s' % (ikpt + 1, eig_tag)).text
        eigvalues[ispin, :, ikpt] = np.fromstring(eig_text, sep='\n') * RY_TO_EV - vacuum_level
        for iwfc in range(natwfcs):
            proj_path = 'PROJECTIONS/K-POINT.%d/%sATMWFC.%d' % (ikpt + 1, spin_dir, iwfc + 1)
            proj_text = root.find(proj_path).text.replace(",", "\n")
            # Values come as (real, imaginary) pairs, one pair per band.
            re_im = np.fromstring(proj_text, sep="\n").reshape(nbands, 2)
            projections[ispin, :, ikpt, iwfc] = (re_im ** 2).sum(axis=1)

In [None]:
output_log = pdos_calc.out.retrieved.get_abs_path('aiida.out')

# Map each atomic wavefunction index ("state #") to the index of the atom it
# belongs to, parsed from projwfc.x output lines such as:
#     state #   2: atom   1 (C  ), wfc  2 (l=1 m= 1)
with open(output_log) as f:  # close the handle instead of leaking it
    content = f.read()
# Raw string: "\s" is not a valid string escape. re.DOTALL was dropped because
# the pattern contains no "." for it to affect.
m = re.findall(r"\n\s+state #\s*(\d+): atom\s*(\d+) ", content)
atmwfc2atom = {int(i): int(j) for i, j in m}
# Sanity checks (explicit so they survive ``python -O``): one entry per atomic
# wavefunction, and the wavefunctions cover ``natoms`` distinct atoms.
if len(atmwfc2atom) != natwfcs:
    raise RuntimeError("Parsed %d atomic wavefunctions from %s, expected %d"
                       % (len(atmwfc2atom), output_log, natwfcs))
if len(set(atmwfc2atom.values())) != natoms:
    raise RuntimeError("Atomic wavefunctions cover %d atoms, expected %d"
                       % (len(set(atmwfc2atom.values())), natoms))

In [None]:
#def correct_band_crossings(kpts, eigvals_in, wfc_projs_in):
#    """ 
#    use parabola fitting and a heuristic to determine
#    bands correctly in case of crossings
#    """
#    eigvals = np.copy(eigvals_in)
#    wfc_projs = np.copy(wfc_projs_in)
#    
#    for i_spin in range(nspins):
#        for i_k in range(1, nkpoints-1):
#            k_vals = kpts[i_k-1:i_k+2]
#
#            for i_band in range(nbands):
#                for i_band2 in range(i_band+1, nbands):
#
#                    eigval = eigvals[i_spin, i_band, i_k]
#                    dif = np.abs(eigval - eigvals[i_spin, i_band, i_k-1])
#                    if np.abs(eigvals[i_spin, i_band2, i_k+1] - eigval) < 2*dif:
#
#                        e_vals_cur = eigvals[i_spin, i_band, i_k-1:i_k+2]
#                        e_vals_pre = eigvals[i_spin, i_band2, i_k-1:i_k+2]
#
#                        # switched last points
#                        e_vals_cur_sw = [e_vals_cur[0], e_vals_cur[1], e_vals_pre[2]]
#                        e_vals_pre_sw = [e_vals_pre[0], e_vals_pre[1], e_vals_cur[2]]
#
#                        # fit parabolas
#                        fit_cur = np.polyfit(k_vals, e_vals_cur, 2)
#                        fit_pre = np.polyfit(k_vals, e_vals_pre, 2)
#                        fit_cur_sw = np.polyfit(k_vals, e_vals_cur_sw, 2)
#                        fit_pre_sw = np.polyfit(k_vals, e_vals_pre_sw, 2)                            
#                        
#                        # Switch if the sum of the curvature of parabolas is smaller
#                        if (np.abs(fit_cur_sw[0]) + np.abs(fit_pre_sw[0]))+10.0 < (np.abs(fit_cur[0]) + np.abs(fit_pre[0])):
#                            
#                            orig_band = np.copy(eigvals[i_spin, i_band, i_k+1:])
#                            eigvals[i_spin, i_band, i_k+1:] = eigvals[i_spin, i_band2, i_k+1:]
#                            eigvals[i_spin, i_band2, i_k+1:] = orig_band
#                            
#                            orig_wfc = np.copy(wfc_projs[i_spin, i_band, i_k+1:, :])
#                            wfc_projs[i_spin, i_band, i_k+1:, :] = wfc_projs[i_spin, i_band2, i_k+1:, :]
#                            wfc_projs[i_spin, i_band2, i_k+1:, :] = orig_wfc
#    
#    return eigvals, wfc_projs

# k-point abscissa: ``nkpoints`` evenly spaced values on [0, 0.5].
kpts = np.linspace(0.0, 0.5, num=nkpoints)
#eigvalues, projections = correct_band_crossings(kpts, eigvalues, projections)

# Reorder (nspins, nbands, nkpoints) -> (nspins, nkpoints, nbands) and add the
# vacuum level back, undoing the shift applied when the eigenvalues were parsed.
bands = eigvalues.transpose(0, 2, 1) + vacuum_level

In [None]:
def w0gauss(x, n):
    """Return the Methfessel-Paxton smeared delta function of order ``n`` at ``x``.

    Same recurrence as Quantum ESPRESSO's ``w0gauss`` routine: a Gaussian
    ``exp(-x**2)/sqrt(pi)`` plus ``n`` Hermite-polynomial corrections. The
    exponent is capped at 200 to avoid underflow. ``x`` may be a scalar or a
    numpy array (evaluated elementwise); ``n == 0`` gives the plain Gaussian.
    """
    gauss = np.exp(-np.minimum(200.0, x ** 2))
    result = gauss / np.sqrt(np.pi)
    if n == 0:
        return result

    # Hermite polynomials (times the Gaussian), advanced two orders per step.
    h_odd = 0.0
    h_even = gauss
    order = 0
    coeff = 1.0 / np.sqrt(np.pi)
    for i in range(1, n + 1):
        h_odd = 2.0 * x * h_even - 2.0 * order * h_odd
        order += 1
        coeff = -coeff / (i * 4.0)
        h_even = 2.0 * x * h_odd - 2.0 * order * h_even
        order += 1
        result = result + coeff * h_even
    return result

In [None]:
def calc_pdos(sigma, ngauss, Emin, Emax, atmwfcs=None, DeltaE=0.01):
    """Compute the (projected) density of states on a regular energy grid.

    Every eigenvalue in the global ``eigvalues`` array (indexed
    ``[spin, band, kpoint]``) is broadened with ``w0gauss`` and weighted by its
    projection onto the selected atomic wavefunctions and by its k-point weight.

    Args:
        sigma: broadening width, in the same units as ``eigvalues`` (eV).
        ngauss: Methfessel-Paxton order passed to ``w0gauss``.
        Emin, Emax: energy window; the grid is ``np.arange(Emin, Emax, DeltaE)``.
        atmwfcs: 0-based indices into the last axis of ``projections``.
            ``None`` sums over all atomic wavefunctions; an empty sequence
            selects nothing and yields zeros.
        DeltaE: energy grid step (default 0.01).

    Returns:
        ``(x, y)``: the energy grid, shape ``(nE,)``, and the DOS, shape
        ``(nE, nspins)``.
    """
    x = np.arange(Emin, Emax, DeltaE)

    # Broadcast the grid against (nspins, nbands, nkpoints) rather than tiling
    # it: identical values, one fewer full-size temporary array.
    arg = (x[:, None, None, None] - eigvalues) / sigma
    delta = w0gauss(arg, n=ngauss) / sigma

    # ``is None`` rather than truthiness: an empty selection must not silently
    # fall back to the total DOS, and numpy index arrays have no truth value.
    if atmwfcs is None:
        p = np.sum(projections, axis=3)  # sum over all atmwfcs
    else:
        p = np.sum(projections[:, :, :, atmwfcs], axis=3)  # sum over selected atmwfcs

    c = delta * p * kpoint_weights  # kpoint_weights broadcasts along the k axis
    y = np.sum(c, axis=(2, 3))  # sum over bands and kpoints
    return x, y

In [None]:
def igor_pdos():
    """Return the PDOS as the text of an Igor Pro ``.itx`` wave file.

    Uses a fixed window of +/-3 eV around the HOMO/LUMO midpoint, the current
    broadening widget values, and the currently selected atoms (all atomic
    wavefunctions when nothing is selected). Only spin channel 0 is exported,
    as waves ``e1`` (energy) and ``p1`` (DOS); lines end with a carriage return.
    """
    midgap = (homo + lumo) / 2.0
    atmwfcs = None
    if selected_atoms:
        atmwfcs = [k - 1 for k, v in atmwfc2atom.items() if v - 1 in selected_atoms]
    energies, dos = calc_pdos(ngauss=ngauss_slider.value, sigma=sigma_slider.value,
                              Emin=midgap - 3.0, Emax=midgap + 3.0, atmwfcs=atmwfcs)
    dos_spin0 = dos.transpose()[0]

    chunks = [u'IGOR\rWAVES\te1\tp1\rBEGIN\r']
    chunks.extend(u'\t{:.8f}\t{:.8f}\r'.format(e, p) for e, p in zip(energies, dos_spin0))
    chunks.append(u'END\r')
    chunks.append(u'X SetScale/P x 0,1,"", e1; SetScale y 0,0,"", e1\rX SetScale/P x 0,1,"", p1; SetScale y 0,0,"", p1\r')
    return u''.join(chunks)
    
def mk_igor_link():
    """Display an HTML download link for the PDOS as an Igor ``.itx`` file.

    The content comes from ``igor_pdos()`` and is embedded as a base64 data URI
    named ``<chemical formula>_pk<structure pk>.itx``.
    """
    igorvalue = igor_pdos()
    # b64encode needs bytes; decode its ASCII output so it formats cleanly
    # into the HTML string.
    igorfile = b64encode(igorvalue.encode('utf-8')).decode('ascii')
    filename = ase_struct.get_chemical_formula() + "_pk%d.itx" % structure.pk

    html = '<a download="{}" href="'.format(filename)
    html += 'data:chemical/x-igor;name={};base64,{}"'.format(filename, igorfile)
    html += ' id="pdos_link"'
    html += ' target="_blank">Export itx-PDOS</a>'

    display(HTML(html))
    
def mk_bands_txt_link():
    """Display an HTML download link for the spin-0 band energies as plain text.

    ``bands[0]`` (rows: k-points, columns: bands) is written with
    ``np.savetxt`` and embedded as a base64 data URI named
    ``<chemical formula>_pk<structure pk>.txt``.
    """
    with io.StringIO() as f:
        np.savetxt(f, bands[0])
        value = f.getvalue()

    # b64encode needs bytes; decode its ASCII output for the HTML string.
    enc_file = b64encode(value.encode('utf-8')).decode('ascii')
    filename = ase_struct.get_chemical_formula() + "_pk%d.txt" % structure.pk

    html = '<a download="{}" href="'.format(filename)
    # The payload is plain text, so advertise text/plain (not the Igor type).
    html += 'data:text/plain;name={};base64,{}"'.format(filename, enc_file)
    html += ' id="bands_link"'
    html += ' target="_blank">Export bands .txt</a>'

    display(HTML(html))
    
def mk_png_link(fig):
    """Display an HTML download link for ``fig`` rendered as a 300-dpi PNG.

    The image is embedded as a base64 data URI named
    ``<chemical formula>_pk<structure pk>.png``.
    """
    imgdata = io.BytesIO()
    fig.savefig(imgdata, format='png', dpi=300, bbox_inches='tight')
    # getvalue() is the public accessor; StringIO's private ``.buf`` only holds
    # the complete data after an internal flush triggered by seek().
    pngfile = b64encode(imgdata.getvalue()).decode('ascii')

    filename = ase_struct.get_chemical_formula() + "_pk%d.png" % structure.pk

    html = '<a download="{}" href="'.format(filename)
    html += 'data:image/png;name={};base64,{}"'.format(filename, pngfile)
    html += ' id="pdos_png_link"'
    html += ' target="_blank">Export png</a>'

    display(HTML(html))
    
def mk_pdf_link(fig):
    """Display an HTML download link for ``fig`` rendered as a PDF.

    The document is embedded as a base64 data URI named
    ``<chemical formula>_pk<structure pk>.pdf``.
    """
    imgdata = io.BytesIO()
    fig.savefig(imgdata, format='pdf', bbox_inches='tight')
    # Public getvalue() instead of StringIO's private ``.buf``.
    pdffile = b64encode(imgdata.getvalue()).decode('ascii')

    filename = ase_struct.get_chemical_formula() + "_pk%d.pdf" % structure.pk

    html = '<a download="{}" href="'.format(filename)
    # The payload is a PDF, so use application/pdf, and give the element an id
    # distinct from the PNG link's so the HTML ids stay unique.
    html += 'data:application/pdf;name={};base64,{}"'.format(filename, pdffile)
    html += ' id="pdos_pdf_link"'
    html += ' target="_blank">Export pdf</a>'

    display(HTML(html))

In [None]:
def plot_pdos(ax, pdos_full, ispin, pdos=None):
    """Draw the DOS of spin ``ispin`` on ``ax`` with energy on the vertical axis.

    Args:
        ax: matplotlib Axes to draw on.
        pdos_full: ``(x, y)`` as returned by ``calc_pdos`` for all atomic
            wavefunctions; drawn as a black line with a light-gray fill.
        ispin: spin channel, i.e. the column of ``y`` to plot.
        pdos: optional ``(x, y)`` for the selected atoms, drawn on top with the
            colour chosen in ``colorpicker``.
    """
    x, y = pdos_full
    ax.plot(y[:, ispin], x, color='black')  # vertical plot
    # fill_between fills along the horizontal axis, so draw it in a frame rotated
    # by 90 degrees; the DOS is negated so the fill lands on the positive side.
    tfrm = matplotlib.transforms.Affine2D().rotate_deg(90) + ax.transData
    ax.fill_between(x, 0.0, -y[:, ispin], facecolor='lightgray', transform=tfrm)

    # Scale to the maximum over all spin channels so the panels share a DOS scale.
    ax.set_xlim(0, 1.05 * np.amax(y))
    ax.set_xlabel('DOS [a.u.]')
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))

    if pdos is not None:
        x, y = pdos
        col = matplotlib.colors.to_rgb(colorpicker.value)
        ax.plot(y[:, ispin], x, color='k')
        ax.fill_between(x, 0.0, -y[:, ispin], facecolor=col, transform=tfrm)

In [None]:
def var_width_lines(x, y, lw, aspect):
    """Offset a polyline perpendicular to itself to outline a variable-width line.

    The local direction at each point is the difference "previous minus next"
    point (one-sided at both ends, centred in the interior). Its x component is
    divided by ``aspect`` before taking the perpendicular, and the resulting
    x offset is multiplied back by ``aspect``. This lets the caller make the offset
    perpendicular in scaled (on-screen) units rather than raw data units.

    Args:
        x, y: polyline coordinates (same length). Integer input is accepted.
        lw: offset distance for each point (same length as ``x``).
        aspect: factor dividing x differences to bring them to y-like units.

    Returns:
        ``(edge_up, edge_down)``, each a list ``[xs, ys]`` of float arrays: the
        polyline shifted by ``+lw`` and ``-lw`` along the perpendicular. Where
        two neighbouring points coincide no direction exists and the offset is 0.

    Raises:
        ValueError: if the polyline has exactly one point.
    """
    # Work in float: in-place division of an integer array truncates (Python 2)
    # or raises a casting error (Python 3).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    lw = np.asarray(lw, dtype=float)
    nx = len(x)
    if nx == 0:
        return [np.zeros(0), np.zeros(0)], [np.zeros(0), np.zeros(0)]
    if nx == 1:
        raise ValueError("var_width_lines needs at least two points")

    idx = np.arange(nx)
    prev_idx = np.maximum(idx - 1, 0)       # 0 for the first point
    next_idx = np.minimum(idx + 1, nx - 1)  # nx-1 for the last point
    dx = (x[prev_idx] - x[next_idx]) / aspect  # convert to "figure coordinates"
    dy = y[prev_idx] - y[next_idx]

    # Unit perpendicular (dy, -dx) scaled by the local width. A zero-length
    # direction has a zero perpendicular, so dividing by 1 there yields 0
    # instead of NaN.
    norm = np.sqrt(dy ** 2 + dx ** 2)
    safe_norm = np.where(norm > 0, norm, 1.0)
    shift_x = dy / safe_norm * lw * aspect  # back to "data coordinates"
    shift_y = -dx / safe_norm * lw

    edge_up = [x + shift_x, y + shift_y]
    edge_down = [x - shift_x, y - shift_y]
    return edge_up, edge_down

def plot_bands(ax, ispin, fig_aspect, atmwfcs=None):
    """Plot the bands of spin channel ``ispin`` on ``ax``, optionally as fat bands.

    Args:
        ax: matplotlib Axes to draw on.
        ispin: spin channel index into the global ``bands`` array, which is
            indexed ``bands[ispin, ikpt, iband]``.
        fig_aspect: k-axis/energy-axis unit conversion forwarded to
            ``var_width_lines`` so the band outlines get a roughly uniform
            on-screen width.
        atmwfcs: optional 0-based atomic-wavefunction indices. When given, each
            band is surrounded by a filled region whose half-width is its summed
            projection on these wavefunctions times ``band_proj_box.value``.
    """
    _, nkpoints, nbands = bands.shape

    ax.set_title("Spin %d" % ispin)
    ax.axhline(y=homo, linewidth=2, color='gray', ls='--')

    ax.set_xlabel(r'k [$2\pi/a$]')  # raw string: "\p" is not a valid escape
    x_data = np.linspace(0.0, 0.5, nkpoints)
    ax.set_xlim(0, 0.5)

    y_datas = bands[ispin, :, :] - vacuum_level

    if atmwfcs is not None:
        # Loop-invariant settings for the fat-band overlay.
        fill_color = matplotlib.colors.to_rgb(colorpicker.value)
        conv_kernel = np.ones(3) / 3  # 3-point moving average

    for i_band in range(nbands):
        y_data = y_datas[:, i_band]
        ax.plot(x_data, y_data, '-', color='black')

        if atmwfcs is None:
            continue

        # (nkpoints, natwfcs) slice, summed over the selected wavefunctions.
        line_widths = projections[ispin, i_band][:, atmwfcs].sum(axis=1) * band_proj_box.value
        edge_up, edge_down = var_width_lines(x_data, y_data, line_widths, fig_aspect)

        # Resample the offset outlines back onto the k grid.
        # NOTE(review): np.interp assumes increasing sample x values, which the
        # offset outline need not have near steep bands.
        edge_up_interp = np.interp(x_data, edge_up[0], edge_up[1])
        edge_down_interp = np.interp(x_data, edge_down[0], edge_down[1])

        # scipy.ndimage.convolve: the scipy.ndimage.filters alias is deprecated.
        edge_up_smooth = scipy.ndimage.convolve(edge_up_interp, conv_kernel)
        edge_down_smooth = scipy.ndimage.convolve(edge_down_interp, conv_kernel)

        ax.fill_between(x_data, edge_down_smooth, edge_up_smooth, facecolor=fill_color)

In [None]:
def plot_all():
    """Plot bands and DOS for every spin channel, then show the export links.

    Reads the broadening, smearing order and energy window from the widgets.
    It computes the total DOS and, if atoms are selected, the DOS projected on
    their atomic wavefunctions (these are also used for the fat-band overlay).
    Each spin channel gets two panels in a 1x4 grid, bands then DOS, all
    sharing the y axis. Finally it displays PNG, PDF, bands-txt and Igor
    download links. If Emax <= Emin, it prints a message and draws nothing.
    """
    sigma = sigma_slider.value
    ngauss = ngauss_slider.value
    emin = emin_box.value
    emax = emax_box.value
    if emax <= emin:
        # An empty or inverted window would otherwise crash below
        # (ZeroDivisionError in the aspect, or an empty DOS grid).
        print("Emax must be larger than Emin.")
        return

    figsize = (12, 8)
    fig = plt.figure()
    fig.set_size_inches(figsize[0], figsize[1])
    fig.subplots_adjust(wspace=0.1, hspace=0)

    # Approximate k-unit/energy-unit ratio on screen: panel height over panel
    # width (4 columns) times k-range (0.5) over energy range; margins ignored.
    fig_aspect = figsize[1] / (figsize[0] / 4.0) * 0.5 / (emax - emin)

    pdos_full = calc_pdos(ngauss=ngauss, sigma=sigma, Emin=emin, Emax=emax)

    # DOS projected onto the selected atoms
    pdos = None
    atmwfcs = None
    if selected_atoms:
        # Collect all atmwfcs located on selected atoms; ``k - 1``/``v - 1``
        # convert the 1-based indices parsed from the projwfc output.
        atmwfcs = [k - 1 for k, v in atmwfc2atom.items() if v - 1 in selected_atoms]
        print("Selected atmwfcs: " + str(atmwfcs))
        pdos = calc_pdos(ngauss=ngauss, sigma=sigma, Emin=emin, Emax=emax, atmwfcs=atmwfcs)

    sharey = None
    for ispin in range(nspins):
        # band plot
        ax1 = fig.add_subplot(1, 4, 2 * ispin + 1, sharey=sharey)
        if sharey is None:
            ax1.set_ylabel('E [eV]')
            sharey = ax1
        else:
            # Booleans, not 'on'/'off' strings: string values are no longer
            # converted by newer Matplotlib, and 'off' is truthy.
            ax1.tick_params(axis='y', which='both', left=True, right=False, labelleft=False)
        plot_bands(ax=ax1, ispin=ispin, fig_aspect=fig_aspect, atmwfcs=atmwfcs)

        # pdos plot
        ax2 = fig.add_subplot(1, 4, 2 * ispin + 2, sharey=sharey)
        ax2.tick_params(axis='y', which='both', left=True, right=False, labelleft=False)
        plot_pdos(ax=ax2, pdos_full=pdos_full, ispin=ispin, pdos=pdos)

    sharey.set_ylim(emin, emax)

    plt.show()

    mk_png_link(fig)
    mk_pdf_link(fig)
    mk_bands_txt_link()
    mk_igor_link()

In [None]:
def on_picked(c):
    """Toggle the selection of the atom clicked in the 3D viewer.

    Registered as an observer of ``viewer.picked``; the change object ``c`` is
    unused. Picks that did not hit an atom are ignored. The clicked atom's
    index is added to or removed from the global ``selected_atoms`` set, and
    the selected atoms are drawn as a red ball+stick representation. The plots
    are not redrawn here (the ``plot_all()`` call below is commented out).
    """
    global selected_atoms
    
    if 'atom1' not in viewer.picked.keys():
        return # did not click on atom
    with plot_out:
        clear_output()
        #viewer.clear_representations()
        # NOTE(review): called twice, presumably to drop both the base ball+stick
        # representation and the red selection overlay added below — confirm.
        viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_ball_and_stick()
        viewer.add_ball_and_stick()  # restore the base representation
        #viewer.add_unitcell()

        idx = viewer.picked['atom1']['index']

        # toggle
        if idx in selected_atoms:
            selected_atoms.remove(idx)
        else:
            selected_atoms.add(idx)

        #if(selection):
        # NGL selection "@i,j,..." — presumably addresses atoms by index.
        # NOTE(review): an empty set yields the selection "@"; confirm NGL treats
        # that as selecting nothing.
        sel_str = ",".join([str(i) for i in sorted(selected_atoms)])
        viewer.add_representation('ball+stick', selection="@"+sel_str, color='red', aspectRatio=3.0)
        #else:
        #    print ("nothing selected")
        viewer.picked = {} # reset; otherwise immediately picking the same atom again would not trigger a change event
        
        #plot_all()

In [None]:
#def on_change(c):
#    with plot_out:
#        clear_output()
#        plot_all()
        
def on_plot_click(c):
    """Callback for the "Plot" button: clear the output area and redraw all plots.

    The argument ``c`` (supplied by ``Button.on_click``) is unused.
    """
    with plot_out:
        clear_output()
        plot_all()

# --- Widgets: shared description width and layout ---------------------------
style = {"description_width":"200px"}
layout = ipw.Layout(width="600px")
# Broadening width passed to calc_pdos as ``sigma``.
sigma_slider = ipw.FloatSlider(description="Broadening [eV]", min=0.01, max=0.5, value=0.1, step=0.01,
                               continuous_update=False, layout=layout, style=style)
#sigma_slider.observe(on_change, names='value')
# Order ``n`` of the Methfessel-Paxton smearing in w0gauss (0 = plain Gaussian).
ngauss_slider = ipw.IntSlider(description="Methfessel-Paxton order", min=0, max=3, value=0,
                              continuous_update=False, layout=layout, style=style)
#ngauss_slider.observe(on_change, names='value')

# Fill colour for the projected DOS and the fat-band overlay.
colorpicker = ipw.ColorPicker(concise=True, description='PDOS color', value='orange', style=style)

# Default energy window: +/-3 eV around the HOMO/LUMO midpoint, rounded to 0.1.
center = (homo + lumo)/2.0
emin_box = ipw.FloatText(description="Emin (eV)", value=np.round(center-3.0, 1), step=0.1, style=style)
emax_box = ipw.FloatText(description="Emax (eV)", value=np.round(center+3.0, 1), step=0.1, style=style)
# Factor multiplying the band projections to give the fat-band half-width.
band_proj_box = ipw.FloatText(description="Max band width (eV)", value=0.1, step=0.01, style=style)


# Plots are only (re)drawn when this button is pressed.
plot_button = ipw.Button(description="Plot")
plot_button.on_click(on_plot_click)

# Indices of atoms picked in the viewer; toggled in place by on_picked.
selected_atoms = set()    
viewer = nglview.NGLWidget()

viewer.add_component(nglview.ASEStructure(ase_struct)) # adds ball+stick
viewer.add_unitcell()
viewer.center()

viewer.observe(on_picked, names='picked')
# Output area for plots and export links; on_picked also writes into it.
plot_out = ipw.Output()

display(sigma_slider, ngauss_slider, viewer, emin_box, emax_box, band_proj_box, colorpicker, plot_button, plot_out)
#on_change(None)