In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import numpy as np
import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import gzip
import matplotlib.pyplot as plt
from collections import OrderedDict
import urlparse
import io

import matplotlib.pyplot as plt

In [None]:
def get_calc_by_label(workcalc, label):
    """Return the single finished JobCalculation with ``label`` that is an output of ``workcalc``.

    Parameters
    ----------
    workcalc : WorkCalculation
        Node whose ``uuid`` is used to anchor the query.
    label : str
        Label of the JobCalculation to find among the outputs of ``workcalc``.

    Raises
    ------
    AssertionError
        If the query does not match exactly one calculation, or the matched
        calculation's state is not ``'FINISHED'``.  Raised explicitly (not via
        ``assert``) so the checks are not stripped when running with ``-O``.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(JobCalculation, output_of=WorkCalculation, filters={'label': label})
    # Execute the query once instead of count() followed by first().
    matches = qb.all()
    if len(matches) != 1:
        raise AssertionError("expected exactly one calculation labelled %r, found %d"
                             % (label, len(matches)))
    calc = matches[0][0]
    if calc.get_state() != 'FINISHED':
        raise AssertionError("calculation labelled %r is in state %r, not 'FINISHED'"
                             % (label, calc.get_state()))
    return calc

In [None]:
# Module-level state filled in by load_pk(): the three result archives
# (stm_ch, stm_cc, sts), the bias list, the scan coordinates and their
# meshgrid, and the figure size derived from the scan extent.
stm_ch = stm_cc = sts = None
voltages = None
x = y = None
x_grid = y_grid = None

figure_size_xy = None


def load_pk(b):
    """Load the STM/STS results of the workchain whose pk is in ``pk_select``.

    Button callback for ``load_pk_btn``; ``b`` is unused (the URL cell calls
    it as ``load_pk(0)``).  Finds the ``'stm_images'`` calculation of the
    workchain, reads ``stm_ch.npz``, ``stm_cc.npz`` and ``sts.npz`` from its
    retrieved folder into the module-level globals, derives the scan grid
    and figure size, then refreshes the STM and STS widgets.
    Prints "Incorrect pk." (plus the underlying error) and returns early if
    the node or its calculation cannot be found.
    """
    global stm_ch, stm_cc, sts, voltages, x, y, x_grid, y_grid
    global figure_size_xy
    try:
        workcalc = load_node(pk=pk_select.value)
        stm_image_calc = get_calc_by_label(workcalc, 'stm_images')
    except Exception as exc:
        # Catch ordinary errors only (a bare except would also trap
        # KeyboardInterrupt) and show why the pk was rejected.
        print("Incorrect pk.")
        print("  (%s: %s)" % (type(exc).__name__, exc))
        return

    retrieved = stm_image_calc.out.retrieved

    def _read_npz(filename):
        # np.load on an .npz returns a lazy NpzFile that keeps the file open;
        # copy every array into a plain dict and close the file right away.
        with np.load(retrieved.get_abs_path(filename)) as npz:
            return dict(npz)

    stm_ch = _read_npz('stm_ch.npz')
    stm_cc = _read_npz('stm_cc.npz')
    sts = _read_npz('sts.npz')

    voltages = stm_ch['bias']
    x = stm_ch['x']
    y = stm_ch['y']
    x_grid, y_grid = np.meshgrid(x, y, indexing='ij')

    # Fixed figure width; height follows the scan's y/x extent ratio.  float()
    # avoids Python 2 integer division, and a degenerate x range falls back to
    # a square figure instead of dividing by zero.
    fig_x_size = 6
    x_span = float(np.max(x) - np.min(x))
    y_span = float(np.max(y) - np.min(y))
    aspect = y_span / x_span if x_span > 0 else 1.0
    figure_size_xy = (fig_x_size, fig_x_size * aspect)

    load_stm_data_from_node()
    setup_sts_data()

# Integer input for the workchain pk plus a button that triggers load_pk().
pk_select = ipw.IntText(value=0, description='pk')
load_pk_btn = ipw.Button(description='Load pk')
load_pk_btn.on_click(load_pk)
display(pk_select, load_pk_btn)

# Scanning tunneling microscopy

In [None]:
import matplotlib

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter whose tick labels always use a fixed printf-style format.

    Parameters
    ----------
    fformat : str
        Format applied to every tick value, e.g. ``"%1.1f"``.
    offset : bool
        Passed to ScalarFormatter as ``useOffset``.
    mathText : bool
        Passed to ScalarFormatter as ``useMathText``; when true the format is
        wrapped in ``$\\mathdefault{...}$``.
    """
    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        self.fformat = fformat
        matplotlib.ticker.ScalarFormatter.__init__(self, useOffset=offset,
                                                   useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        # Older matplotlib calls _set_format(vmin, vmax); newer releases call
        # _set_format() with no arguments, so both parameters are optional.
        self.format = self.fformat
        if self._useMathText:
            # ticker._mathdefault was removed from newer matplotlib; fall back
            # to the equivalent \mathdefault wrapping.
            mathdefault = getattr(matplotlib.ticker, '_mathdefault', None)
            if mathdefault is not None:
                self.format = '$%s$' % mathdefault(self.format)
            else:
                self.format = '$\\mathdefault{%s}$' % self.format


def make_plot(data, title=None, center0=False, vmin=None, vmax=None, cmap='gist_heat'):
    """Plot one or more maps on the global (x_grid, y_grid) mesh, stacked vertically.

    Parameters
    ----------
    data : array or list of arrays
        Map(s) drawn with ``plt.pcolormesh(x_grid, y_grid, ...)``; a single
        (non-list) value is treated as a one-element list.
    title : str or list of str, optional
        Left-aligned panel titles.  A non-list value titles the first panel
        only; panels beyond the end of the list get no title.
    center0 : bool or list of bool
        When true for a panel, its colour range is symmetric about zero
        (``-max|data|`` .. ``+max|data|``) and ``vmin``/``vmax`` are ignored.
        A single (non-list) value applies to every panel.
    vmin, vmax : float, optional
        Colour limits for panels that are not centred on zero.
    cmap : str
        Matplotlib colormap name.

    Uses the module-level ``x_grid``, ``y_grid`` and ``figure_size_xy``
    (set by load_pk) and calls ``plt.show()`` at the end.
    """
    if not isinstance(data, list):
        data = [data]
    if not isinstance(title, list):
        title = [title]
    # Broadcast a scalar center0 to every panel (a one-element list would
    # raise IndexError from the second panel on).
    if not isinstance(center0, list):
        center0 = [center0] * len(data)

    n_panels = len(data)
    plt.figure(figsize=(figure_size_xy[0] + 1.0, n_panels * figure_size_xy[1]))
    for i, data_e in enumerate(data):
        plt.subplot(n_panels, 1, i + 1)
        if center0[i]:
            data_amax = np.max(np.abs(data_e))
            plt.pcolormesh(x_grid, y_grid, data_e, vmin=-data_amax, vmax=data_amax, cmap=cmap)
        else:
            plt.pcolormesh(x_grid, y_grid, data_e, vmin=vmin, vmax=vmax, cmap=cmap)
        plt.xlabel("x (angstrom)")
        plt.ylabel("y (angstrom)")
        # Pick the colour-bar formatter from this panel's own magnitude: a
        # plain colour bar for moderate values, a fixed-format formatter
        # otherwise.
        if 1e-3 < np.max(np.abs(data_e)) < 1e3:
            cb = plt.colorbar()
        else:
            cb = plt.colorbar(format=FormatScalarFormatter("%.1f"))
        cb.formatter.set_powerlimits((-2, 2))
        cb.update_ticks()
        if i < len(title):
            plt.title(title[i], loc='left')
        # Equal aspect for every panel, not only the last one.
        plt.axis('scaled')
    plt.show()

    
def make_series_plot(data, voltages):
    """Draw one panel per bias voltage, side by side, then show the figure.

    Panel k shows ``data[:, :, k]`` on the global ``x_grid``/``y_grid`` mesh
    with the 'gist_heat' colormap, equal aspect, no ticks, titled with
    ``voltages[k]``.  The figure is ``len(voltages)`` times as wide as the
    global ``figure_size_xy``.
    """
    n_bias = len(voltages)
    width, height = figure_size_xy
    plt.figure(figsize=(width * n_bias, height))
    for panel, bias in enumerate(voltages, 1):
        plt.subplot(1, n_bias, panel)
        plt.pcolormesh(x_grid, y_grid, data[:, :, panel - 1], cmap='gist_heat')
        plt.axis('scaled')
        plt.title("V=%.2f" % bias)
        plt.xticks([])
        plt.yticks([])
    plt.show()

## STM series

In [None]:

stm_options = None
stm_data = None

def load_stm_data_from_node():
    """Build the STM image stack and fill the STM dropdowns.

    Labels list the constant-current series (one per ``stm_cc['isovals']``)
    followed by the constant-height series (one per ``stm_ch['heights']``).
    ``stm_data`` concatenates ``stm_cc['data']`` and ``stm_ch['data']`` along
    the first axis in the same order, so a dropdown index selects the first
    index of ``stm_data`` (presumably one series per isovalue/height -- TODO
    confirm against the data producer).  The bias dropdown gets ``voltages``.
    """
    global stm_options, stm_data
    cc_labels = ["Const Cur. isoval=%.1e" % iso for iso in stm_cc['isovals']]
    ch_labels = ["Const Height h=%.1f" % h for h in stm_ch['heights']]
    stm_options = cc_labels + ch_labels
    stm_data = np.concatenate([stm_cc['data'], stm_ch['data']])

    for dropdown in (drop_stm_series, drop_stm_series_singl):
        dropdown.options = stm_options
    drop_voltage.options = voltages
    
    

def on_series_drop_change(c):
    """Append the bias series of the selected STM series to ``series_plot_out``.

    Observer for value changes of ``drop_stm_series``; ``c`` (the change
    notification) is unused.  Nothing is drawn when no series is selected
    or no STM data has been loaded yet.
    """
    series_index = drop_stm_series.index
    # index is None when the dropdown has no selection; indexing a numpy
    # array with None would silently add an axis instead of failing.
    if series_index is None or stm_data is None:
        return
    with series_plot_out:
        make_series_plot(stm_data[series_index], voltages)
        
def on_clear_click(b):
    """Button callback: wipe all series plots shown in ``series_plot_out``."""
    target = series_plot_out
    with target:
        clear_output()

# NOTE(review): `style` and `layout` are not referenced by any widget below.
style = {'description_width': '140px'}
layout = {'width': '50%'}

# Output area that accumulates series plots until "clear" is pressed.
series_plot_out = ipw.Output()

drop_stm_series = ipw.Dropdown(description="add series", options=[])
drop_stm_series.observe(on_series_drop_change, names='value')

clear_button = ipw.Button(description="clear")
clear_button.on_click(on_clear_click)

display(drop_stm_series, clear_button, series_plot_out)


## STM single

In [None]:

def plot_stm(c):
    """Redraw the single STM map for the selected series and bias.

    Observer for value changes of both ``drop_stm_series_singl`` and
    ``drop_voltage``; ``c`` (the change notification) is unused.  Does
    nothing until both dropdowns have a selection.  The previous map in
    ``stm_plot_out`` is cleared before drawing.
    """
    series_label = drop_stm_series_singl.value
    bias = drop_voltage.value
    if series_label is None or bias is None:
        return
    with stm_plot_out:
        clear_output()
        # Two decimals, matching the "V=%.2f" titles of make_series_plot;
        # one decimal rendered distinct biases (e.g. 0.05 and 0.10) alike.
        make_plot(stm_data[drop_stm_series_singl.index, :, :, drop_voltage.index],
                  title=series_label + ", v=%.2f" % bias)

# Output area holding the currently selected single STM map.
stm_plot_out = ipw.Output()

drop_stm_series_singl = ipw.Dropdown(description="series", options=[])
drop_voltage = ipw.Dropdown(description="bias", options=[])

# Redraw whenever either the series or the bias selection changes.
drop_stm_series_singl.observe(plot_stm, names='value')
drop_voltage.observe(plot_stm, names='value')

display(drop_stm_series_singl, drop_voltage, stm_plot_out)

# Scanning tunneling spectroscopy

In [None]:

def setup_sts_data():
    """Offer the values of ``sts['heights']`` in the STS height dropdown."""
    available_heights = sts['heights']
    drop_sts_height.options = available_heights

def plot_sts(c):
    """Append a row of STS maps (one per ``sts['e']`` entry) for the selected height.

    ``c`` is unused (the plot button calls ``plot_sts(0)``).  Each call adds
    a new ``ipw.Box`` (styled by ``box_layout``) to ``main_vbox``, holding one
    Output per energy.  All maps in the row share the colour range
    [min, max] of the selected height's data so they are directly
    comparable.  Nothing is added when no STS data is loaded or no height is
    selected.
    """
    i_height = drop_sts_height.index
    if sts is None or i_height is None:
        # Avoid appending an empty box / indexing the data with None.
        return
    height = sts['heights'][i_height]
    # sts['data'] is indexed [height, :, :, energy]; slice the height once.
    height_data = sts['data'][i_height]
    min_val = np.min(height_data)
    max_val = np.max(height_data)

    sts_plot_hbox = ipw.Box(layout=box_layout)
    main_vbox.children += (sts_plot_hbox, )
    for i_e, e in enumerate(sts['e']):
        sts_plot_out = ipw.Output()
        sts_plot_hbox.children += (sts_plot_out, )
        with sts_plot_out:
            make_plot(height_data[:, :, i_e],
                      title='h=%.1f ang, E=%.2f eV' % (height, e),
                      vmin=min_val, vmax=max_val, cmap='seismic')
            
def on_sts_clear(b):
    """Button callback: remove every STS row previously added to ``main_vbox``."""
    main_vbox.children = tuple()
def on_sts_plot(b):
    """Button callback: append STS maps for the currently selected height."""
    return plot_sts(0)
        
drop_sts_height = ipw.Dropdown(description="height", options=[])

sts_plot_btn = ipw.Button(description='plot')
sts_clear_btn = ipw.Button(description='clear')
sts_plot_btn.on_click(on_sts_plot)
sts_clear_btn.on_click(on_sts_clear)

button_hbox = ipw.HBox((sts_plot_btn, sts_clear_btn))

# Layout shared by every row appended by plot_sts: fixed 500px height,
# wrapping columns, horizontal scrolling.
box_layout = ipw.Layout(
    overflow_x='scroll',
    border='3px solid black',
    width='100%',
    height='500px',
    display='inline-flex',
    flex_flow='column wrap',
    align_items='flex-start',
)

# Container collecting the STS rows; emptied by the "clear" button.
main_vbox = ipw.VBox()

display(drop_sts_height, button_hbox, main_vbox)

In [None]:
#def plot_sts(c):
#    with sts_plot_out:
#        clear_output()
#        min_val = np.min(sts['data'][drop_sts_height.index, :, :, :])
#        max_val = np.max(sts['data'][drop_sts_height.index, :, :, :])
#        make_plot(sts['data'][drop_sts_height.index, :, :, e_slider.value],
#                  title='E=%.2f'%sts['e'][e_slider.value], vmin=min_val, vmax=max_val, cmap='seismic')
#
#drop_sts_height = ipw.Dropdown(description="height", options=sts['heights'])
#
#play = ipw.Play(
#    interval=1000,
#    value=0,
#    min=0,
#    max=len(sts['e'])-1,
#    step=1,
#    description="Press play",
#    disabled=False
#)
#e_slider = ipw.IntSlider(max=len(sts['e'])-1)
#ipw.jslink((play, 'value'), (e_slider, 'value'))
#
#e_slider.observe(plot_sts, names='value')
#
#sts_plot_out = ipw.Output(layout={'width': '50%', 'height':'500px'})
#
#display(drop_sts_height, ipw.VBox([ipw.HBox([play, e_slider]), sts_plot_out]))
#plot_sts(0)

In [None]:
### Load the URL after everything is set up ###
# Best-effort auto-load when the notebook URL carries a ``pk`` query
# parameter.  ``jupyter_notebook_url`` is not defined in this notebook
# (presumably injected by the hosting app); if it is missing (NameError) or
# has no ``pk`` parameter (KeyError), nothing is loaded.
try:
    url = urlparse.urlsplit(jupyter_notebook_url)
    pk_from_url = urlparse.parse_qs(url.query)['pk'][0]
except (NameError, KeyError, IndexError):
    pk_from_url = None

if pk_from_url is not None:
    try:
        pk_select.value = int(pk_from_url)
    except ValueError:
        print("Incorrect pk.")
    else:
        # Called outside the try so genuine loading errors are reported
        # instead of being silently discarded.
        load_pk(0)