In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

In [None]:
import matplotlib

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter that renders every tick label with a fixed printf format.

    Offset and order-of-magnitude handling are inherited from ScalarFormatter;
    only the per-tick format string is overridden.

    Parameters
    ----------
    fformat : printf-style format applied to each (offset/scaled) tick value.
    offset : forwarded to ScalarFormatter as ``useOffset``.
    mathText : forwarded to ScalarFormatter as ``useMathText``; when true the
        format is wrapped in a mathtext ``$\\mathdefault{...}$`` string.
    """

    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        self.fformat = fformat
        super().__init__(useOffset=offset, useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        # Older matplotlib calls _set_format(vmin, vmax); newer releases call it
        # with no arguments. Accept both and ignore the values.
        self.format = self.fformat
        if self._useMathText:
            # matplotlib.ticker._mathdefault no longer exists in newer matplotlib;
            # it produced exactly this \mathdefault wrapper.
            self.format = r'$\mathdefault{%s}$' % self.format
            self.format = '$%s$' % matplotlib.ticker._mathdefault(self.format)


def make_plot(fig, ax, data, extent, title=None, title_size=None, center0=False, vmin=None, vmax=None, cmap='gist_heat', noadd=False):
    """Draw a 2D map on `ax` with imshow, optionally adding axis labels and a colorbar.

    Parameters
    ----------
    fig : matplotlib Figure owning `ax`; used only to attach the colorbar.
    ax : matplotlib Axes to draw into.
    data : 2D array. It is transposed before plotting, so its first index runs
        along x and its second along y.
    extent : [xmin, xmax, ymin, ymax] forwarded to imshow (None = pixel indices).
    title : axes title (None leaves it empty).
    title_size : optional title font size.
    center0 : if True, use colour limits symmetric about zero (+/- max |data|,
        ignoring NaNs); `vmin`/`vmax` are then ignored.
    vmin, vmax : explicit colour limits used when `center0` is False.
    cmap : colormap name.
    noadd : if True, hide ticks and skip axis labels and colorbar (compact panels).
    """
    if center0:
        # NaN-aware so a few missing pixels don't turn the limits into NaN.
        data_amax = np.nanmax(np.abs(data))
        vmin, vmax = -data_amax, data_amax
    im = ax.imshow(data.T, origin='lower', cmap=cmap, interpolation='bicubic',
                   extent=extent, vmin=vmin, vmax=vmax)

    if noadd:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        ax.set_xlabel(r"x ($\AA$)")
        ax.set_ylabel(r"y ($\AA$)")
        # Decide on the tick formatter by magnitude, not the signed maximum, so
        # negative-valued maps are classified by the size of their values.
        data_magnitude = np.nanmax(np.abs(data))
        if 1e-3 < data_magnitude < 1e3:
            cb = fig.colorbar(im, ax=ax)
        else:
            cb = fig.colorbar(im, ax=ax, format=FormatScalarFormatter("%.1f"))
        # Switch to scientific notation outside 1e-2 .. 1e2.
        cb.formatter.set_powerlimits((-2, 2))
        cb.update_ticks()
    ax.set_title(title)
    if title_size:
        ax.title.set_fontsize(title_size)
    ax.axis('scaled')
    
    
def make_series_plot(fig, data, voltages, extent=None):
    """Draw one compact panel per bias voltage, side by side, on `fig`.

    Parameters
    ----------
    fig : matplotlib Figure to draw into; one subplot is added per voltage.
    data : 3D array sliced as ``data[:, :, i]`` for the i-th voltage.
    voltages : sequence of bias values, used for panel titles ("V=%.2f").
    extent : optional [xmin, xmax, ymin, ymax] forwarded to `make_plot`, which
        requires this argument (None plots in pixel coordinates).
    """
    n_panels = len(voltages)
    for i_bias, bias in enumerate(voltages):
        # Add axes to `fig` itself; plt.subplot would target pyplot's current
        # figure, which need not be `fig`.
        ax = fig.add_subplot(1, n_panels, i_bias + 1)
        make_plot(fig, ax, data[:, :, i_bias], extent, title="V=%.2f" % bias,
                  title_size=22, cmap='gist_heat', noadd=True)

In [None]:
# Precomputed STM/STS archive. Set the STM_NPZ_PATH environment variable to use
# a different file instead of editing this machine-specific default path.
npz_path = os.environ.get("STM_NPZ_PATH", "/home/kristjan/local_work/stm-test/stms/stm_pa_0.2.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open; read every
# array eagerly into a plain dict and close the file immediately.
with np.load(npz_path) as npz_file:
    loaded_data = {key: npz_file[key] for key in npz_file.files}

In [None]:
# Sampling grids and parameters stored alongside the maps.
isovalues, heights, e_arr, x_arr, y_arr = (
    loaded_data[key] for key in ('isovalues', 'heights', 'e_arr', 'x_arr', 'y_arr')
)

# The four map datasets; the cc/ch prefixes presumably mean constant-current /
# constant-height — TODO confirm.
cc_sts, cc_stm, ch_sts, ch_stm = (
    loaded_data[key] for key in ('cc_sts', 'cc_stm', 'ch_sts', 'ch_stm')
)
# Only cc_stm is converted to float32 (reason not visible here — TODO confirm).
cc_stm = cc_stm.astype(np.float32)

In [None]:
# imshow extent [xmin, xmax, ymin, ymax] spanning the x/y sample grids.
# NOTE(review): if x_arr/y_arr are pixel centres, the image edges are off by half a pixel.
extent = [x_arr.min(), x_arr.max(), y_arr.min(), y_arr.max()]

In [None]:
# The four datasets shown side by side, with labels in the same order.
sets = [cc_stm, cc_sts, ch_stm, ch_sts]
set_names = ["cc_stm", "cc_sts", "ch_stm", "ch_sts"]
# Index into the leading axis of each dataset (presumably the isovalue/height
# selection — TODO confirm).
sel_set = [0, 0, 0, 0]

# 3 inches of width per panel; height follows the y/x aspect ratio of the scan area.
figsize = (3*len(sets), 3*(extent[3] - extent[2])/(extent[1] - extent[0]))

for i_e, e in enumerate(e_arr):
    fig, ax_arr = plt.subplots(1, len(sets), figsize=figsize)
    print("E = %.2f eV" % e)
    for ax, data_set, i_sel, name in zip(ax_arr, sets, sel_set, set_names):
        make_plot(fig, ax, data_set[i_sel, :, :, i_e], extent, title=name, noadd=True)
    plt.show()
    # Release each figure explicitly; non-inline backends otherwise keep every
    # figure alive (and warn after 20) while looping over many energies.
    plt.close(fig)