In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Load the AiiDA database environment for the configured profile, unless this
# kernel has already loaded it.
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import numpy as np
import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import gzip
import matplotlib.pyplot as plt
from collections import OrderedDict
import urlparse
import io
import zipfile
import StringIO

import matplotlib.pyplot as plt

In [None]:
def get_calc_by_label(workcalc, label):
    """Return the finished JobCalculation labelled ``label`` that ``workcalc`` produced.

    Queries JobCalculation nodes that are outputs of the WorkCalculation with
    the same uuid as ``workcalc`` and whose label equals ``label``.

    :param workcalc: the parent WorkCalculation node.
    :param label: label of the child JobCalculation to look up.
    :returns: the single matching JobCalculation.
    :raises ValueError: if zero or more than one calculation matches.
    :raises RuntimeError: if the matching calculation is not in state FINISHED.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(JobCalculation, output_of=WorkCalculation, filters={'label': label})
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O`` and the function would then silently return a wrong node.
    matches = qb.all()
    if len(matches) != 1:
        raise ValueError("Expected exactly one calculation labelled %r under %s, found %d"
                         % (label, workcalc.uuid, len(matches)))
    calc = matches[0][0]
    state = calc.get_state()
    if state != 'FINISHED':
        raise RuntimeError("Calculation %s (label %r) is in state %s, not FINISHED"
                           % (calc.pk, label, state))
    return calc

In [None]:
# Module-level state filled in by load_pk(); everything is None until a pk is loaded.
stm_ch = stm_cc = sts = None        # contents of the loaded .npz result files
voltages = x = y = extent = None    # bias values, grid axes and imshow extent

figure_size_xy = None               # (width, height) figsize of a single image


def load_pk(b):
    """Load the STM/STS results of the work calculation whose pk is in ``pk_select``.

    Reads ``stm_ch.npz``, ``stm_cc.npz`` and ``sts.npz`` from the retrieved
    folder of the ``stm_images`` child calculation, stores them in the
    module-level globals used by the plotting cells, refreshes the dependent
    widgets and enables the export buttons. On failure a message is printed
    and the previously loaded data (if any) is left untouched.

    :param b: the clicked button (unused; required by ``on_click``).
    """
    global stm_ch, stm_cc, sts, voltages, x, y, extent
    global figure_size_xy

    def _read_npz(calc, filename):
        # Materialise every array and close the file immediately: an open
        # NpzFile keeps its file handle and re-reads/decompresses an array on
        # every ``npz[key]`` access, which the plotting code does repeatedly.
        with np.load(calc.out.retrieved.get_abs_path(filename)) as npz:
            return dict((key, npz[key]) for key in npz.files)

    try:
        workcalc = load_node(pk=pk_select.value)
        stm_image_calc = get_calc_by_label(workcalc, 'stm_images')
        new_ch = _read_npz(stm_image_calc, 'stm_ch.npz')
        new_cc = _read_npz(stm_image_calc, 'stm_cc.npz')
        new_sts = _read_npz(stm_image_calc, 'sts.npz')
    except Exception as exc:
        # Report the reason instead of hiding it behind a generic message.
        print("Incorrect pk. (%s: %s)" % (type(exc).__name__, exc))
        return

    stm_ch, stm_cc, sts = new_ch, new_cc, new_sts

    voltages = stm_ch['bias']
    x = stm_ch['x']
    y = stm_ch['y']

    extent = [np.min(x), np.max(x), np.min(y), np.max(y)]

    # Fixed figure height; the width follows the x/y aspect ratio of the grid.
    fig_y_size = 4
    y_span = np.max(y) - np.min(y)
    if y_span > 0:
        figure_size_xy = (fig_y_size * (np.max(x) - np.min(x)) / y_span, fig_y_size)
    else:
        # Degenerate grid with a single y value: fall back to a square panel.
        figure_size_xy = (fig_y_size, fig_y_size)

    load_stm_data_from_node()
    setup_sts_data()

    stm_zip_btn.disabled = False
    sts_zip_btn.disabled = False

# Integer field for the work-calculation pk, plus a button that loads it.
pk_select = ipw.IntText(value=0, description='pk')
load_pk_btn = ipw.Button(description='Load pk')

load_pk_btn.on_click(load_pk)

display(pk_select, load_pk_btn)

# Scanning tunneling microscopy

In [None]:
import matplotlib

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter that renders every tick label with a fixed printf format.

    :param fformat: printf-style format applied to each tick value.
    :param offset: forwarded to ``ScalarFormatter`` as ``useOffset``.
    :param mathText: forwarded as ``useMathText``; when true the format is
        wrapped in a mathtext ``$...$`` string.
    """

    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        self.fformat = fformat
        matplotlib.ticker.ScalarFormatter.__init__(self, useOffset=offset,
                                                   useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        # Overrides a private matplotlib hook. Depending on the matplotlib
        # version it is called with (vmin, vmax) or with no arguments, so both
        # are optional and unused here.
        self.format = self.fformat
        if self._useMathText:
            mathdefault = getattr(matplotlib.ticker, '_mathdefault', None)
            if mathdefault is not None:
                self.format = '$%s$' % mathdefault(self.format)
            else:
                # ``_mathdefault`` no longer exists in recent matplotlib; this
                # is the wrapping it produced outside classic mode.
                self.format = r'$\mathdefault{%s}$' % self.format

def make_plot(fig, ax, data, title=None, title_size=None, center0=False, vmin=None, vmax=None, cmap='gist_heat', noadd=False):
    """Show the 2D map ``data`` on ``ax`` with ``imshow`` over the global ``extent``.

    ``data`` is transposed before display (origin lower-left), so its first
    axis runs along x.

    :param fig: figure that receives the colorbar (only when ``noadd`` is False).
    :param ax: axes to draw into.
    :param data: 2D array to display.
    :param title: axes title; ``title_size`` optionally sets its font size.
    :param center0: if True, use colour limits symmetric around zero
        (``vmin``/``vmax`` are then ignored).
    :param vmin: lower colour limit when ``center0`` is False.
    :param vmax: upper colour limit when ``center0`` is False.
    :param cmap: matplotlib colormap name.
    :param noadd: if True, hide the ticks and skip axis labels and colorbar.
    """
    if center0:
        data_amax = np.max(np.abs(data))
        vmin, vmax = -data_amax, data_amax
    im = ax.imshow(data.T, origin='lower', cmap=cmap, interpolation='bicubic',
                   extent=extent, vmin=vmin, vmax=vmax)

    if noadd:
        # Bare image: no ticks, no labels, no colorbar.
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        ax.set_xlabel(r"x ($\AA$)")
        ax.set_ylabel(r"y ($\AA$)")
        # The fixed "%.1f" formatter is used only when the maximum value lies
        # outside (1e-3, 1e3); otherwise the default colorbar formatter is kept.
        colorbar_kwargs = {}
        if not (1e-3 < np.max(data) < 1e3):
            colorbar_kwargs['format'] = FormatScalarFormatter("%.1f")
        cb = fig.colorbar(im, ax=ax, **colorbar_kwargs)
        cb.formatter.set_powerlimits((-2, 2))
        cb.update_ticks()
    ax.set_title(title)
    if title_size:
        ax.title.set_fontsize(title_size)
    ax.axis('scaled')

    
def make_series_plot(fig, data, voltages):
    """Draw one bare panel per bias voltage, side by side, into ``fig``.

    :param fig: figure that receives a 1 x len(voltages) grid of axes.
    :param data: array where ``data[:, :, i]`` is the 2D map for ``voltages[i]``.
    :param voltages: bias values, used for the panel titles.
    """
    n_panels = len(voltages)
    for i_bias, bias in enumerate(voltages):
        # Add the axes to ``fig`` explicitly; plt.subplot() would target
        # whatever figure happens to be current and ignore ``fig``.
        ax = fig.add_subplot(1, n_panels, i_bias + 1)
        make_plot(fig, ax, data[:, :, i_bias], title="V=%.2f" % bias, title_size=22,
                  cmap='gist_heat', noadd=True)

## STM series

In [None]:
# Series labels and the stacked image data; both are set by load_stm_data_from_node().
stm_options = stm_data = None

def load_stm_data_from_node():
    """Build STM series labels and data from the loaded globals, then refresh the dropdowns.

    Constant-current series come first and constant-height series second, in
    both ``stm_options`` and along the first axis of ``stm_data``.
    """
    global stm_options, stm_data
    cc_labels = ["Const Cur. isoval=%.1e" % isoval for isoval in stm_cc['isovals']]
    ch_labels = ["Const Height h=%.1f" % height for height in stm_ch['heights']]
    stm_options = cc_labels + ch_labels
    stm_data = np.concatenate([stm_cc['data'], stm_ch['data']])

    for dropdown in (drop_stm_series, drop_stm_series_singl):
        dropdown.options = stm_options
    drop_voltage.options = voltages
    

def on_series_drop_change(c):
    """Append a row of per-bias STM images for the selected series to ``series_plot_out``.

    Output is appended rather than replaced.

    :param c: traitlets change notification (unused).
    """
    i_series = drop_stm_series.index
    # With no selection the index is None, and ``stm_data[None]`` would
    # silently add an axis instead of failing.
    if i_series is None or stm_data is None:
        return
    with series_plot_out:
        fig = plt.figure(figsize=(figure_size_xy[0] * len(voltages), figure_size_xy[1]))
        make_series_plot(fig, stm_data[i_series], voltages)
        plt.show()
        
def on_clear_click(b):
    """Button callback: wipe every series plot from ``series_plot_out`` (``b`` is unused)."""
    target = series_plot_out
    with target:
        clear_output()

# NOTE(review): neither ``style`` nor ``layout`` is referenced anywhere else in the visible cells.
style = dict(description_width='140px')
layout = dict(width='50%')
        
# "add series" dropdown: each new selection appends that series' plots below.
drop_stm_series = ipw.Dropdown(description="add series", options=[])
drop_stm_series.observe(on_series_drop_change, names='value')

clear_button = ipw.Button(description="clear")
clear_button.on_click(on_clear_click)

series_plot_out = ipw.Output()

display(drop_stm_series, clear_button, series_plot_out)


## STM single

In [None]:

def plot_stm(c):
    """Redraw ``stm_plot_out`` with the image for the selected series and bias.

    Does nothing until both dropdowns have a selection that fits the
    currently loaded ``stm_data``.

    :param c: traitlets change notification (unused).
    """
    i_series = drop_stm_series_singl.index
    i_bias = drop_voltage.index
    if i_series is None or i_bias is None or stm_data is None:
        return
    # When a new pk is loaded the two dropdowns are refreshed one after the
    # other, so one index may briefly still refer to the previous data set.
    if i_series >= stm_data.shape[0] or i_bias >= stm_data.shape[-1]:
        return
    with stm_plot_out:
        clear_output()

        fig = plt.figure(figsize=(1.1 * figure_size_xy[0] + 1.0, 1.1 * figure_size_xy[1]))
        ax = plt.gca()
        make_plot(fig, ax, stm_data[i_series, :, :, i_bias],
                  title=drop_stm_series_singl.value + ", v=%.1f" % drop_voltage.value)
        plt.show()

# Single-image viewer: changing either the series or the bias redraws the plot.
drop_stm_series_singl = ipw.Dropdown(description="series", options=[])
drop_voltage = ipw.Dropdown(description="bias", options=[])

drop_stm_series_singl.observe(plot_stm, names='value')
drop_voltage.observe(plot_stm, names='value')

stm_plot_out = ipw.Output()

display(drop_stm_series_singl, drop_voltage, stm_plot_out)

# Scanning tunneling spectroscopy

In [None]:

def set_slider_bounds(slider, lo, hi):
    """Set ``slider.min``/``slider.max`` to ``lo``/``hi`` and select the full range.

    Bounded range sliders reject ``min > max`` (and ``max < min``) at
    assignment time, so the bound that cannot conflict with the current one
    is written first.
    """
    if lo > slider.max:
        slider.max = hi
        slider.min = lo
    else:
        slider.min = lo
        slider.max = hi
    slider.value = [lo, hi]


def setup_sts_data():
    """Populate the STS widgets from the loaded ``sts`` data and enable plotting."""
    drop_sts_height.options = sts['heights']
    energies = sts['e']
    set_slider_bounds(energy_range_slider, float(np.min(energies)), float(np.max(energies)))
    sts_plot_btn.disabled = False

def energy_index_range(energies, min_e, max_e):
    """Return ``(ie_1, ie_2)`` so that ``energies[ie_1:ie_2]`` covers ``[min_e, max_e]``.

    ``ie_1`` is the first index with ``energies >= min_e``. ``ie_2`` is one past
    the first index with ``energies >= max_e``, or ``len(energies)`` when every
    energy lies below ``max_e``. Presumably ``energies`` is sorted ascending
    (TODO confirm for sts.npz); otherwise the slice is not a contiguous window.
    """
    energies = np.asarray(energies)
    ie_1 = int(np.argmax(energies >= min_e))
    if np.max(energies) < max_e:
        # Was ``len(np.max(...))``: len() of a scalar, which raised TypeError.
        ie_2 = len(energies)
    else:
        ie_2 = int(np.argmax(energies >= max_e)) + 1
    return ie_1, ie_2


def plot_sts(c):
    """Append a row of STS maps for the selected height and energy window to ``main_vbox``.

    Each energy in the slider's window gets its own bare ``seismic`` image in
    a new ``box_layout`` box. All images share one colour scale.

    :param c: unused (the function is also called directly with 0).
    """
    sts_plot_hbox = ipw.Box(layout=box_layout)
    main_vbox.children += (sts_plot_hbox, )

    energies = sts['e']
    min_e, max_e = energy_range_slider.value
    ie_1, ie_2 = energy_index_range(energies, min_e, max_e)

    i_height = drop_sts_height.index
    height = sts['heights'][i_height]
    # Fetch the height slice once; the colour limits span the whole window.
    height_data = sts['data'][i_height]
    min_val = np.min(height_data[:, :, ie_1:ie_2])
    max_val = np.max(height_data[:, :, ie_1:ie_2])

    for i_e in range(ie_1, ie_2):
        sts_plot_out = ipw.Output()
        sts_plot_hbox.children += (sts_plot_out, )
        with sts_plot_out:
            fig = plt.figure(figsize=(figure_size_xy[0], figure_size_xy[1]))
            ax = plt.gca()
            make_plot(fig, ax, height_data[:, :, i_e],
                      title='h=%.1f ang, E=%.2f eV' % (height, energies[i_e]),
                      vmin=min_val, vmax=max_val, cmap='seismic', noadd=True)
            plt.show()
            
def on_sts_clear(b):
    """Button callback: remove every STS row from ``main_vbox`` (``b`` is unused)."""
    main_vbox.children = tuple()
def on_sts_plot(b):
    """Button callback: append one more STS row via ``plot_sts`` (``b`` is unused)."""
    plot_sts(c=0)
        
# --- STS controls -------------------------------------------------------------
drop_sts_height = ipw.Dropdown(description="height", options=[])

energy_range_slider = ipw.FloatRangeSlider(
    value=[0.0, 0.0],
    min=0.0,
    max=0.0,
    step=0.1,
    description='energy range',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

# "plot" stays disabled until setup_sts_data() runs for a loaded pk.
sts_plot_btn = ipw.Button(description='plot', disabled=True)
sts_plot_btn.on_click(on_sts_plot)
sts_clear_btn = ipw.Button(description='clear')
sts_clear_btn.on_click(on_sts_clear)
button_hbox = ipw.HBox((sts_plot_btn, sts_clear_btn))

# Layout shared by the row boxes that plot_sts appends: fixed height,
# horizontal scrolling.
box_layout = ipw.Layout(
    overflow_x='scroll',
    border='3px solid black',
    width='100%',
    height='320px',
    display='inline-flex',
    flex_flow='column wrap',
    align_items='flex-start',
)

# Container collecting the STS rows.
main_vbox = ipw.VBox()

display(drop_sts_height, energy_range_slider, button_hbox, main_vbox)

# Export

In [None]:
def create_zip_link(figure_method, zip_progress, html_link_out, filename, tmp_dir='tmp'):
    """Build a zip archive in memory, save it under ``tmp_dir`` and show a download link.

    :param figure_method: callable ``(zip_file, zip_progress)`` that writes the
        archive members into the open ``zipfile.ZipFile``.
    :param zip_progress: progress widget passed through to ``figure_method``.
    :param html_link_out: output widget that receives the link.
    :param filename: name of the zip file inside ``tmp_dir``.
    :param tmp_dir: directory relative to the kernel's working directory; it
        also forms the link's relative URL. Defaults to ``'tmp'``.
    """
    import os

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        figure_method(zip_file, zip_progress)

    # Only make sure the directory exists; nothing is emptied here. This is
    # pure Python instead of ``!mkdir -p`` so the function also works outside
    # IPython (and os.makedirs has no exist_ok on Python 2).
    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir)

    with open(os.path.join(tmp_dir, filename), 'wb') as f:
        f.write(zip_buffer.getvalue())

    with html_link_out:
        display(HTML('<a href="%s/%s" target="_blank">download zip</a>' % (tmp_dir, filename)))
    
def create_stm_zip_content(zip_file, zip_progress):
    """Write every STM figure plus its raw data into the open ``zip_file``.

    For each series: ``s_<series>.png`` with all biases side by side. For each
    (series, bias): a labelled PNG and ``txt/<name>.txt`` with the raw 2D map.
    ``zip_progress.value`` advances by an equal share per image.
    """
    n_series = stm_data.shape[0]
    n_bias = len(voltages)
    total_pics = n_series + n_series * n_bias
    # Equal share per image so the bar ends at 1.0. The previous
    # 1/(total_pics - 1) overshot and divided by zero for a single image.
    step = 1.0 / total_pics if total_pics > 0 else 0.0
    series_names = [stm_options[i_s].lower().replace(" ", '_').replace("=", '')
                    for i_s in range(n_series)]
    header = "xlim=(%.2f, %.2f), ylim=(%.2f, %.2f)" % (extent[0], extent[1],
                                                       extent[2], extent[3])

    # series overview images
    for i_s, series_name in enumerate(series_names):
        imgdata = io.BytesIO()  # PNG output is binary
        fig = plt.figure(figsize=(figure_size_xy[0] * n_bias, figure_size_xy[1]))
        make_series_plot(fig, stm_data[i_s], voltages)
        fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
        zip_file.writestr("s_" + series_name + ".png", imgdata.getvalue())
        plt.close(fig)

        zip_progress.value += step

    # individual images + raw data
    for i_s, series_name in enumerate(series_names):
        for i_v, bias in enumerate(voltages):
            plot_name = series_name + "_%dv%+.1f" % (i_v, bias)
            imgdata = io.BytesIO()
            fig = plt.figure(figsize=(1.1 * figure_size_xy[0] + 1.0, 1.1 * figure_size_xy[1]))
            ax = plt.gca()
            make_plot(fig, ax, stm_data[i_s, :, :, i_v],
                      title=stm_options[i_s] + ", v=%.1f" % bias)
            fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
            zip_file.writestr(plot_name + ".png", imgdata.getvalue())
            plt.close(fig)

            # Raw data of the same image as text.
            txtdata = StringIO.StringIO()
            np.savetxt(txtdata, stm_data[i_s, :, :, i_v], header=header, fmt="%.2e")
            zip_file.writestr("txt/" + plot_name + ".txt", txtdata.getvalue())

            zip_progress.value += step
            
def create_sts_zip_content(zip_file, zip_progress):
    """Write the STS maps of the selected height and energy window into ``zip_file``.

    Uses the same energy window and shared colour limits as ``plot_sts``. For
    each energy a PNG and a ``txt/<name>.txt`` raw-data file are written, and
    ``zip_progress.value`` advances by an equal share.
    """
    energies = sts['e']
    min_e, max_e = energy_range_slider.value
    ie_1 = np.argmax(energies >= min_e)
    if np.max(energies) < max_e:
        # Window extends past the last energy: keep everything to the end.
        # (Was ``len(np.max(...))``, i.e. len() of a scalar -> TypeError.)
        ie_2 = len(energies)
    else:
        ie_2 = np.argmax(energies >= max_e) + 1

    i_height = drop_sts_height.index
    height = sts['heights'][i_height]
    # Fetch the height slice once instead of re-indexing sts['data'] per image.
    height_data = sts['data'][i_height]
    min_val = np.min(height_data[:, :, ie_1:ie_2])
    max_val = np.max(height_data[:, :, ie_1:ie_2])

    header = "xlim=(%.2f, %.2f), ylim=(%.2f, %.2f)" % (extent[0], extent[1],
                                                       extent[2], extent[3])
    total_pics = ie_2 - ie_1
    # Equal share per image; 1/(total_pics - 1) crashed for a single energy.
    step = 1.0 / total_pics if total_pics > 0 else 0.0

    for i_e in range(ie_1, ie_2):
        en = energies[i_e]
        plot_name = "sts_h%.1f_%de%.2f" % (height, i_e - ie_1, en)
        imgdata = io.BytesIO()  # PNG output is binary
        fig = plt.figure(figsize=(figure_size_xy[0] + 1.0, figure_size_xy[1]))
        ax = plt.gca()
        make_plot(fig, ax, height_data[:, :, i_e],
                  title='h=%.1f ang, E=%.2f eV' % (height, en),
                  vmin=min_val, vmax=max_val, cmap='seismic', noadd=True)
        fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
        zip_file.writestr(plot_name + ".png", imgdata.getvalue())
        plt.close(fig)

        # Raw data of the same map as text.
        txtdata = StringIO.StringIO()
        np.savetxt(txtdata, height_data[:, :, i_e], header=header, fmt="%.2e")
        zip_file.writestr("txt/" + plot_name + ".txt", txtdata.getvalue())

        zip_progress.value += step


def create_stm_zip_link(b):
    """Button callback: export all STM images to ``stm_<pk>.zip`` and show a link.

    The button stays disabled after a successful export; the "clear tmp"
    button re-enables it. If the export fails, the button is re-enabled so
    the user can retry, and the error is re-raised.
    """
    stm_zip_btn.disabled = True
    try:
        create_zip_link(create_stm_zip_content, stm_zip_progress, stm_link_out,
                        "stm_%d.zip" % pk_select.value)
    except Exception:
        stm_zip_btn.disabled = False
        raise

def create_sts_zip_link(b):
    """Button callback: export the selected STS window to ``sts_<pk>_e<lo>_<hi>.zip``.

    The button stays disabled after a successful export; the "clear tmp"
    button re-enables it. If the export fails, the button is re-enabled so
    the user can retry, and the error is re-raised.
    """
    sts_zip_btn.disabled = True
    e1, e2 = energy_range_slider.value
    try:
        create_zip_link(create_sts_zip_content, sts_zip_progress, sts_link_out,
                        "sts_%d_e%.1f_%.1f.zip" % (pk_select.value, e1, e2))
    except Exception:
        sts_zip_btn.disabled = False
        raise
    
def _new_zip_progress():
    """Return a fresh 0..1 progress bar for one zip export."""
    return ipw.FloatProgress(
        value=0,
        min=0,
        max=1.0,
        description='progress:',
        bar_style='info',
        orientation='horizontal'
    )

# STM export: button, progress bar and output area for the download link.
stm_zip_btn = ipw.Button(description='STM zip', disabled=True)
stm_zip_btn.on_click(create_stm_zip_link)
stm_zip_progress = _new_zip_progress()
stm_link_out = ipw.Output()
display(ipw.HBox([stm_zip_btn, stm_zip_progress]), stm_link_out)

# STS export: the same set of controls for the STS maps.
sts_zip_btn = ipw.Button(description='STS zip', disabled=True)
sts_zip_btn.on_click(create_sts_zip_link)
sts_zip_progress = _new_zip_progress()
sts_link_out = ipw.Output()
display(ipw.HBox([sts_zip_btn, sts_zip_progress]), sts_link_out)

def clear_tmp(b, tmp_dir='tmp'):
    """Delete and recreate ``tmp_dir``, clear the download links and reset the export widgets.

    The zip buttons are re-enabled only when data has been loaded
    (``stm_ch`` is no longer None).

    :param b: the clicked button (unused).
    :param tmp_dir: directory holding the exported zips; defaults to ``'tmp'``.
    """
    import os
    import shutil

    # Pure-Python equivalent of ``rm -rf tmp && mkdir tmp``. It is portable
    # and valid outside IPython, where ``!`` shell escapes are a syntax error.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir)

    with stm_link_out:
        clear_output()
    with sts_link_out:
        clear_output()
    sts_zip_progress.value = 0.0
    stm_zip_progress.value = 0.0
    if stm_ch is not None:
        sts_zip_btn.disabled = False
        stm_zip_btn.disabled = False
    
    
# Button that deletes the exported zips and resets the export widgets.
clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)

display(clear_tmp_btn)


In [None]:
### Load the URL after everything is set up ###
# Optionally load a pk passed in the notebook URL (``...?pk=<int>``).
# NOTE(review): ``jupyter_notebook_url`` is not defined in this notebook; it is
# presumably injected by the hosting app (e.g. appmode), so a plain kernel lacks it.
try:
    url = urlparse.urlsplit(jupyter_notebook_url)
    url_pk = int(urlparse.parse_qs(url.query)['pk'][0])
except (NameError, AttributeError, TypeError, KeyError, IndexError, ValueError):
    # No URL available, no ``pk`` query parameter, or a non-integer value.
    url_pk = None

# Loading runs outside the try so genuine errors in load_pk are no longer
# swallowed by the URL-parsing fallback.
if url_pk is not None:
    pk_select.value = url_pk
    load_pk(0)