# Connect to the cluster

In [None]:
import socket
# Only connect to the hpc05 cluster when running on the host named 'io'
# (`on_io` is also used below to choose the savefig dpi).
on_io = socket.gethostname() == 'io'
# NOTE: the `and 0` makes this condition always false, i.e. the cluster
# branch is currently disabled and `client` is always None.
if on_io and 0:
    import hpc05, hpc05_monitor
    client = hpc05.connect.start_remote_and_connect(130, profile='pbs_15GB', folder='~/Work/two_dim_majoranas')[0]
#     max_usage_task = hpc05_monitor.start(client, interval=1)
else:
    # Do not connect to the cluster (in CI, for example); `client` is then
    # passed as `executor=None` to the adaptive runners.
    client = None

## Numerics imports

In [None]:
from functools import partial
from copy import copy
import numpy as np
import kwant
import adaptive
import zigzag
# Enable adaptive's Jupyter notebook integration.
adaptive.notebook_extension()

import pathlib
# Make sure the output directory for the paper figures exists.
pathlib.Path('paper/figures').mkdir(parents=True, exist_ok=True)

## Plotting imports

In [None]:
import matplotlib
# Select the non-interactive 'agg' backend first so the notebook also runs
# headless; inside Jupyter the `%matplotlib inline` magic below then
# switches to inline (SVG) output.
matplotlib.use('agg')

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

class HistogramNormalize(matplotlib.colors.Normalize):
    """Colour normalisation blending histogram equalisation with a linear map.

    The histogram part maps a value to the fraction of reference data points
    that lie below it, which spreads the colours evenly over the data
    distribution. The returned value is
    ``mixing_degree * hist_norm + (1 - mixing_degree) * linear_norm``.

    Parameters
    ----------
    data : array_like
        Reference data defining the histogram. NaNs are ignored, as are
        values outside ``[vmin, vmax]`` when those bounds are given.
    vmin, vmax : float, optional
        Passed on to `matplotlib.colors.Normalize`.
    mixing_degree : float
        Weight of the histogram normalisation (1: pure histogram,
        0: pure linear normalisation).
    """

    def __init__(self, data, vmin=None, vmax=None, mixing_degree=1):
        self.mixing_degree = mixing_degree
        data = np.asanyarray(data)
        # NaNs sort to the end of `sorted_data` and would inflate the
        # denominator in `__call__`, so the histogram part could never
        # reach 1. (With a `vmin`/`vmax` they were already dropped by the
        # comparisons below; without bounds they were kept.)
        data = data[~np.isnan(data)]
        if vmin is not None:
            data = data[data >= vmin]
        if vmax is not None:
            data = data[data <= vmax]

        self.sorted_data = np.sort(data.flatten())
        matplotlib.colors.Normalize.__init__(self, vmin, vmax)

    def __call__(self, value, clip=None):
        # Fraction of reference points strictly below `value`.
        hist_norm = np.ma.masked_array(np.searchsorted(self.sorted_data, value) /
                                       len(self.sorted_data))
        linear_norm = super().__call__(value, clip)
        return self.mixing_degree * hist_norm + (1 - self.mixing_degree) * linear_norm

# Figure size: one journal column (246 pt) wide with a golden-ratio aspect.
golden_mean = (np.sqrt(5) - 1) / 2  # Aesthetic ratio
fig_width_pt = 246.0  # Columnwidth
inches_per_pt = 1.0 / 72.27  # Convert pt to inches
fig_width = fig_width_pt * inches_per_pt
fig_height = fig_width * golden_mean  # height in inches
fig_size = [fig_width, fig_height]

params = {
          'backend': 'ps',
          'axes.labelsize': 13,
          'font.size': 13,
          'legend.fontsize': 10,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10,
          'text.usetex': True,
          'figure.figsize': fig_size,
          'font.family': 'serif',
          'font.serif': 'Computer Modern Roman',
          'legend.frameon': True,
          'savefig.dpi': 100 if on_io else 300,
         }

plt.rcParams.update(params)
# `text.latex.preamble` must be a single string: lists of strings were
# deprecated in Matplotlib 3.3 and are rejected by newer releases.
plt.rc('text.latex', preamble='\n'.join([r'\usepackage{color}', r'\usepackage{bm}', r'\usepackage{xfrac}']))

## Parameter and system definitions

In [None]:
import scipy.constants
import cmath

# Physical constants passed to `zigzag` as parameters. The conversions put
# energies in meV and lengths in nm (m_eff in meV s^2 / nm^2, hbar in meV s),
# so hbar^2 k^2 / (2 m_eff) is in meV for k in 1/nm.
constants = dict(
    m_eff=0.02 * scipy.constants.m_e / (scipy.constants.eV * 1e-3) / 1e18,  # effective mass 0.02 m_e in meV s^2 / nm^2
    hbar=scipy.constants.hbar / (scipy.constants.eV * 1e-3),  # in meV s
    e = scipy.constants.e,  # elementary charge in C
    current_unit=scipy.constants.k * scipy.constants.e / scipy.constants.hbar * 1e9,  # to get nA
    mu_B=scipy.constants.physical_constants['Bohr magneton'][0] / (scipy.constants.eV * 1e-3),  # in meV / T
    k=scipy.constants.k / (scipy.constants.eV * 1e-3),  # Boltzmann constant in meV / K
    exp=cmath.exp,
    cos=cmath.cos,
    sin=cmath.sin
   )

# Default Hamiltonian parameters (merged with `constants`); individual
# cells override entries such as k_x, B_x, mu or phase.
default_params = dict(
    g_factor_middle=26, g_factor_left=0, g_factor_right=0,
    mu=10,
    alpha_middle=20, alpha_left=0, alpha_right=0,
    Delta_left=1, Delta_right=1,
    B_x=0,
    B_y=0,
    B_z=0,
    phase=np.pi,
    T=0.0,
    V=0,
    **constants)

# Default keyword arguments for `zigzag.system`; individual cells override
# entries such as z_y, L_x, shape or the superconductor widths.
default_syst_pars = dict(
    L_m=200,
    L_x=1300,
    L_sc_up=300,
    L_sc_down=300,
    z_x=1300,
    z_y=0,
    a=10,
    shape='parallel_curve',
    transverse_soi=True,
    mu_from_bottom_of_spin_orbit_bands=True,
    k_x_in_sc=True,
    wraparound=True,
    infinite=True,
    current=False,
    ns_junction=False)

# Functions used in multiple plots

def gist_heat_r_transparent():
    """Return the reversed 'gist_heat' colormap with a linear alpha ramp.

    The alpha channel rises from 0 (transparent) at the low end to 1
    (opaque) at the high end, so the map can be overlaid on another plot.
    """
    import matplotlib.cm
    import matplotlib.colors as mcolors
    n_colors = 128
    ramp = np.linspace(0, 1, n_colors)
    rgba = matplotlib.cm.gist_heat_r(ramp)
    rgba[:, 3] = ramp  # alpha channel
    return mcolors.LinearSegmentedColormap.from_list('gist_heat_r_transparent', rgba)

def deltas(syst, params):
    """Mark the sites of `syst` whose onsite [1, 0] matrix element vanishes.

    For every site i, |H_ii[1, 0]| is evaluated (presumably the pairing
    amplitude Delta). The result holds 1 where it is exactly zero and NaN
    elsewhere, one entry per site in `syst.sites` order.
    """
    markers = []
    for idx, _site in enumerate(syst.sites):
        pairing = np.abs(syst.hamiltonian(idx, idx, params=params)[1, 0])
        markers.append(1 if pairing == 0 else np.nan)
    return markers

# Gray colormap for the `deltas` maps; NaN entries are drawn in light gray.
# Work on a copy: `set_bad` on `matplotlib.cm.gray_r` itself would mutate
# Matplotlib's globally shared colormap object.
gray_r = copy(matplotlib.cm.gray_r)
gray_r.set_bad('#e0e0e0')

def num_to_latex_exp(x, only_exp=True, N_digits=1):
    r"""Format `x` in scientific notation as a LaTeX string (without `$`).

    Parameters
    ----------
    x : float
        Number to format.
    only_exp : bool
        If True, return only the power of ten, ``10^{e}`` with
        ``e = floor(log10(|x|))``. Otherwise return ``m \times 10^{e}``.
    N_digits : int
        Number of significant digits of the mantissa ``m`` when
        `only_exp` is False (1 -> ``1``, 2 -> ``1.2``, 3 -> ``1.23``, ...).

    Raises
    ------
    ValueError
        If `only_exp` is False and ``N_digits < 1``.
    """
    if not only_exp:
        if N_digits < 1:
            raise ValueError(f'N_digits must be >= 1, got {N_digits}')
        # One digit before the decimal point and N_digits - 1 after it;
        # previously only N_digits in {1, 2} worked (others hit an
        # UnboundLocalError).
        num, exponent = f'{x:.{N_digits - 1}e}'.split('e')
        return rf'{num} \times 10^{{{int(exponent)}}}'
    else:
        exponent = np.floor(np.log10(np.abs(x))).astype(int)
        return f'10^{{{int(exponent)}}}'

def fitted_majorana_size(syst, params):
    """Estimate the Majorana decay length from an exponential fit.

    The Majorana wave function from `zigzag.majorana_state` is summed over
    the sites of each x-column and ``log(wf_proj) = c - x / xi`` is fitted on
    the second quarter of the columns.

    Returns
    -------
    float
        The fitted decay length ``xi`` (in the units of ``site.pos``).
    """
    wf = zigzag.majorana_state(syst, params)[1]

    # Project onto the x-axis. The reshape assumes the sites are ordered
    # column by column with the same number of sites per column (taken
    # from the column at x == 0) — TODO confirm for all shapes used.
    xs, _ys = np.array([s.pos for s in syst.sites]).T
    sites_per_column = len(np.where(xs == 0)[0])
    wf_proj = wf.reshape((-1, sites_per_column)).sum(axis=1)
    xs = xs.reshape((-1, sites_per_column))[:, 0]

    # Only fit the second quarter of the columns (presumably to stay away
    # from the edge and from the other end of the system).
    N = len(xs) // 4
    xs_fit = xs[N:2*N]
    wf_proj_fit = wf_proj[N:2*N]

    # log(wf) is linear in x, so linear least squares gives the same optimum
    # as the former nonlinear `curve_fit` of log(a) - b*x - x0. That model
    # had redundant parameters (a, x0), so its covariance was undetermined,
    # and log(a) became NaN if an iteration stepped to a < 0.
    slope, _intercept = np.polyfit(xs_fit, np.log(wf_proj_fit), 1)

    return -1 / slope

def get_phs_breaking_potential(syst, V=10000):
    """Return a potential ``V_breaking(y)`` equal to `V` on the topmost row.

    This potential pushes one of the Majoranas away such that there is no
    overlap. It is `V` where `y` equals the largest y-coordinate of the
    sites of `syst` and 0 elsewhere.
    """
    # Compare against the y *coordinate*. The former
    # `max(syst.sites, key=lambda s: s.pos[1])` returned the Site itself,
    # which never equals a coordinate, so the potential was always 0.
    # Assumes `zigzag` calls V_breaking with the site's y-coordinate, as the
    # parameter name `y` suggests — TODO confirm.
    max_y = max(site.pos[1] for site in syst.sites)
    return lambda y: V if y == max_y else 0

# Figure 2. E_gap(z_y)

## Complete zigzag system

### System size examples

In [None]:
# z_ys_plots = [10, 50, 100, 200, 300]
# for ratio in [4, 8, 16]:
#     fig, axs = plt.subplots(ncols=len(z_ys_plots))
#     for (ax, z_y) in zip(axs.reshape(-1), z_ys_plots):
#         z_x = L_x = z_y * ratio
#         syst_pars = dict(default_syst_pars, z_y=z_y, z_x=z_x, L_x=L_x, 
#                          shape='sawtooth', wraparound=False, infinite=False)
#         syst = zigzag.system(**syst_pars)

#         Deltas = deltas(syst, default_params)
#         ax.set_title(f'{z_y} nm', fontsize=10)
#         ax.set_xticks([])
#         ax.set_yticks([])
#         fig.suptitle(f'ratio {ratio}')
#         kwant.plotter.map(syst, Deltas, cmap=gray_r, ax=ax, colorbar=False);
#     plt.show()

### Define learners

In [None]:
import operator
from adaptive.learner.learner1D import curvature_loss_function

def gap_wrapper2(k_x, z_y, ratio, syst_pars=copy(default_syst_pars), params=copy(default_params), nbands=10):
    """Spectrum at `k_x` of a sawtooth zigzag with period ``z_x = z_y * ratio``.

    Fixed overrides: mu=40, B_x=1, a=5, L_sc_up=L_sc_down=800 and
    L_x = z_x. Returns ``dict(E_min=min |E|, energies=<all nbands energies>)``.
    """
    # Imported locally so the function is self-contained when it is shipped
    # to an executor (presumably the cluster `client`).
    import numpy as np
    import zigzag
    period = z_y * ratio
    point_params = dict(params, k_x=k_x, mu=40, B_x=1)
    system = zigzag.system(**dict(syst_pars, z_y=z_y, z_x=period, L_x=period,
                                  a=5, shape='sawtooth', L_sc_up=800, L_sc_down=800))
    levels = zigzag.spectrum(system, point_params, k=nbands)[0]
    return dict(E_min=np.abs(levels).min(), energies=levels)

def fnames(learner):
    """Pickle path for one learner of the E_gap(z_y) BalancingLearner.

    Every keyword bound in the learner's `functools.partial` contributes
    ``'{key}_{value}.pickle'``; the parts are joined with ``'__'``.
    """
    bound = learner.function.keywords  # learner.function is a functools.partial
    pieces = (f'{key}_{value}.pickle' for key, value in bound.items())
    return 'data/z_y_vs_E_gap_mu40meV/' + '__'.join(pieces)

# Learner1D that learns `E_min` but also keeps the full return dict of
# `gap_wrapper2` (available as `extra_data`).
Learner = adaptive.make_datasaver(adaptive.Learner1D, arg_picker=operator.itemgetter('E_min'))
# k_x is sampled over [-pi, pi].
kwargs = dict(bounds=[-np.pi, np.pi], loss_per_interval=curvature_loss_function())
z_ys = np.arange(10, 405, 5)

# One learner per (z_y, ratio) combination.
combos = {'z_y': z_ys, 'ratio': [4, 8, 16]}
learner = adaptive.BalancingLearner.from_product(gap_wrapper2, Learner, kwargs, combos)
learner.strategy = 'npoints'
# Resume from previously saved data.
learner.load(fnames)

In [None]:
# Stop once the first learner has more than 95 points (the 'npoints'
# strategy fills the learners evenly). Failing points are retried up to 20
# times without aborting the run.
runner = adaptive.Runner(learner, lambda l: l.learners[0].npoints > 95, executor=client,
                         retries=20, raise_if_retries_exceeded=False)
runner.live_info()

In [None]:
# Save every learner to its `fnames` path every 600 s while the runner runs.
saving_task = runner.start_periodic_saving(dict(fname=fnames), 600)

In [None]:
# Plot the entire band structure
def plot_bands(l):
    """HoloViews overlay of every band stored in learner `l`, plus its minimum.

    Each band (a column of the saved 'energies' arrays) is drawn as black
    scatter points versus k_x, and a horizontal line marks the smallest
    learned value. Without data, an overlay of an empty scatter is returned.
    """
    import holoviews as hv
    if not l.data:
        return hv.Overlay(hv.Scatter([]))
    ks, extras = zip(*l.extra_data.items())
    band_energies = np.array([extra['energies'] for extra in extras]).T
    smallest = min(l.data.values())
    points = [hv.Scatter((ks, band)).opts(style=dict(color='k')) for band in band_energies]
    return hv.Overlay(points) * hv.HLine(smallest)

# Show the sampled bands of every learner with the custom `plot_bands` plotter.
learner.plot(plotter=plot_bands)

In [None]:
from collections import defaultdict
# Regroup the learners as d[ratio][z_y] -> learner (z_y in learner order).
d = defaultdict(dict)
for cdim, l in zip(learner._cdims_default, learner.learners):
    d[cdim['ratio']][cdim['z_y']] = l

def get_mins(learners_dict):
    """Return the smallest learned value of each learner in `learners_dict`.

    The list follows the dict's insertion order. A learner without data
    yields NaN instead of raising ``ValueError`` from ``min`` of an empty
    sequence, so partially computed results can still be plotted.
    """
    return [min(l.data.values(), default=np.nan) for l in learners_dict.values()]

# E_gap(z_y) of the complete zigzag system, one curve per ratio z_x / z_y.
# NOTE: relies on each d[ratio] being filled in the same order as `z_ys`.
fig, ax = plt.subplots()
for ratio, val in d.items():
    E_gaps = get_mins(val)
    ax.plot(z_ys, E_gaps, label=f'ratio {ratio}')
ax.legend()
ax.set_xlabel('$z_y$ (nm)')
ax.set_ylabel(r'$E_{\textrm{gap}}$ (meV)')
plt.savefig('paper/figures/z_y_vs_E_gap.pdf')
plt.show()

In [None]:
import pickle
# Cache the complete-system gaps so they can be compared with the
# quasi-classical estimate later. Nothing else guarantees that `data/qc`
# exists, so create it first.
E_gaps = {ratio: get_mins(val) for ratio, val in d.items()}
pathlib.Path('data/qc').mkdir(parents=True, exist_ok=True)
with open('data/qc/tmp.pickle', 'wb') as f:
    pickle.dump([z_ys, E_gaps], f)

## Quasi classical system

In [None]:
# "quasi classics"
# "quasi classics": a straight system (shape=None) that is one lattice
# constant `a` long (L_x = z_x = a) with 1600-wide superconductors. Its
# |E|(k_x) is combined with the zigzag geometry in `get_gaps` below.
a = 4
syst_pars = dict(
    default_syst_pars,
    L_sc_up=1600,
    L_sc_down=1600,
    L_x=a,
    z_x=a,
    a=a,
    shape=None,
    mu_from_bottom_of_spin_orbit_bands=False
)

# `B` is the total field magnitude; `bands` splits it into B_x and B_y.
params = dict(default_params, B=0.4, mu=40)

def bands(k_x, ratio, params=params, syst_pars=syst_pars):
    """Smallest |E| at momentum `k_x` of the straight system in a rotated field.

    The field of magnitude ``params['B']`` is rotated in the x-y plane by
    ``theta = arctan(4 / ratio)``. `k_x` is in inverse-length units (the
    learner samples it up to 1.1 k_F) and is multiplied by the lattice
    constant ``syst_pars['a']`` before being passed to `zigzag`.
    """
    syst = zigzag.system(**syst_pars)
    theta = np.arctan(4 / ratio)
    # Use the lattice constant of the system that is actually built rather
    # than the notebook-global `a`. The two diverge if another `syst_pars`
    # is passed or `a` is rebound.
    lattice_constant = syst_pars['a']
    params = dict(params,
                  theta=theta,
                  B_x=np.cos(theta)*params['B'],
                  B_y=np.sin(theta)*params['B'],
                  k_x=k_x*lattice_constant)
    return np.min(np.abs(zigzag.spectrum(syst, params, k=4)[0]))

ratios = [4, 8, 16]
# Fermi momentum of a parabolic band, k_F = sqrt(2 m_eff mu) / hbar.
k_F = np.sqrt(params['mu'] * (2 * params['m_eff'])) / params['hbar']
loss = adaptive.learner.learner1D.triangle_loss

# One learner per ratio, sampling k_x in [0, 1.1 k_F].
learner = adaptive.BalancingLearner.from_product(
    bands,
    learner_type=adaptive.Learner1D,
    learner_kwargs=dict(bounds=[0, 1.1*k_F], loss_per_interval=loss),
    combos=dict(ratio=ratios))
learner.strategy = 'npoints'

def fnames(learner):
    """Pickle path for one learner of the quasi-classical BalancingLearner.

    Every keyword bound in the learner's `functools.partial` contributes
    ``'{key}_{value}.pickle'``; the parts are joined with ``'__'``.
    """
    bound = learner.function.keywords  # learner.function is a functools.partial
    pieces = (f'{key}_{value}.pickle' for key, value in bound.items())
    return 'data/qc/' + '__'.join(pieces)

# Resume from previously saved quasi-classical data.
learner.load(fnames)

In [None]:
# No `executor` here (unlike the other runners): not sent to `client`.
runner = adaptive.Runner(learner, goal=lambda l: l.learners[0].npoints > 400)
runner.live_info()

In [None]:
# Plot the sampled min |E|(k_x) for every ratio.
learner.plot()

In [None]:
from types import SimpleNamespace
from scipy.optimize import minimize_scalar
from scipy.misc import derivative

def energies(k_y, params):
    """Energies of the two spin-split bands at momentum (params['k_x'], k_y).

    Parabolic dispersion ``hbar^2 (k_x^2 + k_y^2) / (2 m_eff) - mu`` plus/minus
    a square root combining the Zeeman energy ``E_z = g mu_B B / 2`` and the
    spin-orbit strength ``alpha_middle`` (presumably Rashba-type; the
    coupling depends on the field angle ``params['theta']``).

    Returns
    -------
    list
        ``[term - sqrt_term, term + sqrt_term]`` (lower band first).
    """
    p = SimpleNamespace(**params)
    g = p.g_factor_middle
    alpha = p.alpha_middle
    E_z = 0.5 * g * p.mu_B * p.B
    sqrt_term = np.sqrt(
        + E_z**2
        + alpha * (2 * E_z * (k_y * np.cos(p.theta) - p.k_x * np.sin(p.theta))
                   + alpha * (p.k_x**2 + k_y**2))
    )
    term = p.hbar**2 * (p.k_x**2 + k_y**2) / (2 * p.m_eff) - p.mu
    return [-sqrt_term + term, sqrt_term + term]

def find_k_ys(params):
    """For each band, find the k_y in [0, pi] where |E| is smallest.

    This is the band's Fermi momentum k_y at the k_x in `params` (if the
    band crosses zero energy in that range). Returns ``[k_y_band0, k_y_band1]``.
    """
    def abs_energy(k_y, band):
        return abs(energies(k_y, params)[band])

    k_ys = []
    for band in (0, 1):
        result = minimize_scalar(abs_energy, args=band, bounds=[0, np.pi], method='bounded')
        k_ys.append(result['x'])
    return k_ys

def find_dE_dkys(k_ys, params):
    """Return ``dE_i/dk_y`` of band ``i`` at ``k_y = k_ys[i]``.

    Uses a central finite difference with step 1e-6. This is the same
    three-point formula that `scipy.misc.derivative` used by default, but
    that function was deprecated in SciPy 1.10 and removed in 1.12.
    """
    dx = 1e-6
    slopes = []
    for i, k_y in enumerate(k_ys):
        e_plus = energies(k_y + dx, params)[i]
        e_minus = energies(k_y - dx, params)[i]
        slopes.append((e_plus - e_minus) / (2 * dx))
    return slopes

def find_dE_dkxs(k_ys, params):
    """Return ``dE_i/dk_x`` of band ``i`` at ``(params['k_x'], k_ys[i])``.

    Uses a central finite difference with step 1e-6. This is the same
    three-point formula that `scipy.misc.derivative` used by default, but
    that function was removed in SciPy 1.12. `params` is not modified;
    shifted copies are used instead.
    """
    dx = 1e-6
    # params['k_x'] may be a plain number or a length-1 array (e.g. from fmin).
    k_x = float(np.ravel(params['k_x'])[0])
    slopes = []
    for i, k_y in enumerate(k_ys):
        e_plus = energies(k_y, dict(params, k_x=float(k_x + dx)))[i]
        e_minus = energies(k_y, dict(params, k_x=float(k_x - dx)))[i]
        slopes.append((e_plus - e_minus) / (2 * dx))
    return slopes

def tan_a(params):
    """Return ``(dE/dk_y) / (dE/dk_x)`` of both bands at their Fermi k_y.

    The result is a length-2 array, one entry per band (see `find_k_ys`).
    """
    fermi_k_ys = find_k_ys(params)
    slope_kx = np.array(find_dE_dkxs(fermi_k_ys, params))
    slope_ky = np.array(find_dE_dkys(fermi_k_ys, params))
    return slope_ky / slope_kx

def a_min(z_x, z_y, W, *, correction=True):
    """Geometric slope threshold for a zigzag of period `z_x`, amplitude `z_y` and width `W`.

    With ``theta = arctan(4 z_y / z_x)``, returns NaN when
    ``z_y < W / cos(theta) / 2``. Otherwise, with ``h = 2 z_y``,
    ``p = z_x / 2`` and ``c = sqrt(W^2 + (h/p)^2 W^2)``, it returns
    ``W / sqrt(D^2 - W^2)``, where ``D = sqrt(p^2 + (h - c)^2)`` if
    `correction` is True and ``D = sqrt(p^2 + (h + c)^2)`` otherwise.
    """
    theta = np.arctan(4 * z_y / z_x)
    if z_y < (W / np.cos(theta)) / 2:
        return np.nan
    half_period = z_x / 2
    height = z_y * 2
    c = np.sqrt(W**2 + (height/half_period)**2 * W**2)
    offset = height - c if correction else height + c
    D = np.sqrt(half_period**2 + offset**2)
    return W / np.sqrt(D**2 - W**2)

def get_cutoff(z_x, z_y, W, params):
    """Find the k_x at which the smallest |tan_a| matches the threshold `a_min`.

    Starting from k_F / 2, `scipy.optimize.fmin` minimises
    ``| min|tan_a(k_x)| - a_min(z_x, z_y, W) |``.

    Returns
    -------
    float
        The cutoff momentum, or ``np.inf`` if the best match is not below
        1e-3. This includes the case where `a_min` is NaN.

    `params` is not modified.
    """
    a_m = a_min(z_x, z_y, W)

    k_F = np.sqrt(params['mu'] * (2 * params['m_eff'])) / params['hbar']
    # Work on a copy so the caller's dict does not keep a stale k_x.
    local_params = dict(params)

    def f(k_x):
        # fmin passes a length-1 array; store a plain float.
        local_params['k_x'] = float(np.ravel(k_x)[0])
        return abs(np.min(np.abs(tan_a(local_params))) - a_m)

    xopt, fopt = scipy.optimize.fmin(f, k_F/2, ftol=1e-7, xtol=1e-6, full_output=1, disp=0)[:2]
    # Return a scalar, not fmin's length-1 array. Mixing such arrays with
    # np.inf in `np.array([...])` (as get_gaps does) creates a ragged array,
    # which NumPy >= 1.24 rejects.
    return float(xopt[0]) if fopt < 1e-3 else np.inf

In [None]:
def get_gaps(z_ys, learner, ratio, syst_pars=syst_pars):
    """Quasi-classical E_gap(z_y) for one `ratio` from a learner of `bands`.

    For each z_y, the cutoff momentum k_c from `get_cutoff` (with
    ``z_x = ratio * z_y`` and ``W = syst_pars['L_m']``) is looked up in the
    running minimum over k_x of the learned |E|(k_x). The gap is therefore
    ``min_{k_x <= k_c} |E|``. An infinite cutoff gives the last sampled value,
    because `np.interp` clamps.

    Reads the notebook-global `params` without modifying it. It used to set
    ``params['theta']``, and that dict is also the default argument of
    `bands`.
    """
    ks, Es = zip(*sorted(learner.data.items()))
    W = syst_pars['L_m']
    ratio_params = dict(params, theta=np.arctan(4 / ratio))
    k_max_at_z_y = np.array([get_cutoff(ratio*z_y, z_y, W, ratio_params) for z_y in z_ys])
    E_gaps = np.interp(k_max_at_z_y, ks, np.minimum.accumulate(Es, 0))
    return E_gaps

# Quasi-classical gaps for 100 values of z_y, keyed by ratio.
z_ys = np.linspace(100, 400, 100)
data = {cdim['ratio']: get_gaps(z_ys, l, cdim['ratio'])
        for cdim, l in zip(learner._cdims_default, learner.learners)}

In [None]:
# Quasi-classical E_gap(z_y), one curve per ratio.
fig, ax = plt.subplots()
for ratio, E_gaps in data.items():
    ax.plot(z_ys, E_gaps, label=f'ratio {ratio}, QC')

ax.legend()
ax.set_xlabel('$z_y$ (nm)')
ax.set_ylabel(r'$E_{\textrm{gap}}$ (meV)')
# NOTE(review): the combined complete/QC plot further down saves to this same
# filename and overwrites this figure — confirm that is intended.
plt.savefig('paper/figures/z_y_vs_E_gap2.pdf')
plt.show()

In [None]:
import pickle
# Cache the quasi-classical gaps; make sure the target directory exists.
pathlib.Path('data/qc').mkdir(parents=True, exist_ok=True)
with open('data/qc/tmp_quasi.pickle', 'wb') as f:
    pickle.dump([z_ys, data], f)

In [None]:
import pickle
# Load the cached complete-system and quasi-classical gaps (files written by
# this notebook, so unpickling them is trusted).
with open('data/qc/tmp.pickle', 'rb') as f:
    z_ys_complete, E_gaps_complete = pickle.load(f)
    
with open('data/qc/tmp_quasi.pickle', 'rb') as f:
    z_ys_quasi, E_gaps_quasi = pickle.load(f)

In [None]:
# Complete zigzag results (solid) versus the quasi-classical estimate
# (dashed), with one colour per ratio.
fig, ax = plt.subplots()

col = {4: 'b', 8: 'g', 16: 'red'}

for ratio, E_gaps in E_gaps_quasi.items():
    ax.plot(z_ys_quasi, E_gaps, ls='--', c=col[ratio])

for ratio, E_gaps in E_gaps_complete.items():
    ax.plot(z_ys_complete, E_gaps, c=col[ratio], label=f'ratio {ratio}')
    
ax.legend()
ax.set_xlabel('$z_y$ (nm)')
ax.set_ylabel(r'$E_{\textrm{gap}}$ (meV)')
plt.savefig('paper/figures/z_y_vs_E_gap2.pdf')
plt.show()

In [None]:
# Fermi wavelength 2 pi / k_F (in nm) for mu = 20 meV and m = 0.02 m_e,
# computed in SI units. Note: this rebinds the globals mu, m, hbar and k_F.
mu = 20
mu = mu * scipy.constants.eV * 1e-3
m = 0.02 * scipy.constants.m_e
hbar = scipy.constants.hbar
k_F = np.sqrt(2 * mu * m / hbar ** 2)
(2 * np.pi / k_F) * 1e9

# Figure 3. Bandstructures

In [None]:
def spectrum_wrapper(k_x, z_y, B_x, phase, folded, syst_pars=copy(default_syst_pars),
                     params=copy(default_params), nbands=40):
    """Sorted `nbands` energies at `k_x` for the given z_y, B_x and phase.

    If `folded` is False, the unit cell is one lattice constant long
    (L_x = z_x = a). Otherwise the zigzag period from `syst_pars` is used.
    """
    # Imported locally so the function is self-contained when it is shipped
    # to an executor.
    import numpy as np
    import zigzag
    point_params = dict(params, k_x=k_x, B_x=B_x, phase=phase)
    system_pars = dict(syst_pars, z_y=z_y)
    if not folded:
        lattice_constant = system_pars['a']
        system_pars['L_x'] = lattice_constant
        system_pars['z_x'] = lattice_constant
    system = zigzag.system(**system_pars)
    levels = zigzag.spectrum(system, point_params, k=nbands)[0]
    return np.sort(levels)

def abs_min_log_loss(xs, ys):
    """Learner1D loss on ``log(min |y|)`` of each sampled spectrum `y`.

    Each spectrum is reduced to the logarithm of its smallest absolute
    energy before adaptive's `default_loss` is applied.
    """
    from adaptive.learner.learner1D import default_loss
    log_min_abs = list(map(lambda spectrum: np.log(np.abs(spectrum).min()), ys))
    return default_loss(xs, log_min_abs)

B = 1
W = default_syst_pars['L_m']
# Each entry: (z_y, B_x, phase, folded, k_x bounds). Unfolded spectra
# (folded=False) use a one-lattice-constant cell and are sampled on (-pi, 0).
combos = [
    (0, 0, 0, False, (-np.pi, 0)),
    (0, 0, 0, True, (0, np.pi)),
    (W/4, 0, 0, True, (-np.pi, np.pi)),
    (W/2, 0, 0, True, (-np.pi, np.pi)),
    (0, B, np.pi, False, (-np.pi, 0)),
    (0, B, np.pi, True, (0, np.pi)),
    (W/4, B, np.pi, True, (-np.pi, np.pi)),
    (W/2, B, np.pi, True, (-np.pi, np.pi))
]

learners = [adaptive.Learner1D(partial(spectrum_wrapper, z_y=z_y, B_x=B_x, phase=phase, folded=folded),
                               bounds=bounds, loss_per_interval=abs_min_log_loss)
            for z_y, B_x, phase, folded, bounds in combos]

# NOTE(review): each combo has 5 entries but only 4 cdims names are given;
# the trailing bounds are presumably dropped by adaptive — confirm.
learner = adaptive.BalancingLearner(learners, cdims=(['z_y', 'B_x', 'phi', 'folded'], combos), strategy='npoints')

# The file names embed the full combo (including bounds); keep them stable
# so that existing data still loads.
fnames = [f'data/bandstructures/spectrum_{combo}.pickle' for combo in combos]
learner.load(fnames)

runner = adaptive.Runner(learner, lambda l: l.learners[0].npoints > 400, executor=client)
runner.live_info()

In [None]:
# Debugging aid: print the stack of the runner's task.
runner.task.print_stack()

In [None]:
# Look up each band-structure learner by (z_y, B_x, phase, folded).
mapping = {(z_y, B_x, phase, folded): learner for (z_y, B_x, phase, folded, bounds), learner in zip(combos, learners)}

def plot(ax, key, color, xscale=1):
    """Draw all bands of the learner ``mapping[key]`` on `ax`.

    The sampled momenta are multiplied by `xscale`. Returns the list of
    lines created by `ax.plot`.
    """
    momenta, levels = zip(*sorted(mapping[key].data.items()))
    return ax.plot(xscale * np.array(momenta), np.array(levels), c=color, lw=0.8)

fig, axs = plt.subplots(3, sharex=False, sharey=True, figsize=(fig_width, 2*fig_height))
fig.subplots_adjust(hspace=0.35)
(ax1, ax2, ax3) = axs

# Panel (a), z_y = 0: folded spectra with and without field / phase.
line2 = plot(ax1, (0, B, np.pi, True), 'C1')[0]
line1 = plot(ax1, (0, 0, 0, True), 'C0')[0]

# Unfolded spectra (sampled for k_x < 0), stretched by `xscale`; the grey
# shaded left half of panel (a) gets its own red tick labels below.
xscale = 3
plot(ax1, (0, B, np.pi, False), 'C1', xscale)
plot(ax1, (0, 0, 0, False), 'C0', xscale)

plot(ax2, (W/4, B, np.pi, True), 'C1')
plot(ax2, (W/4, 0, 0, True), 'C0')

plot(ax3, (W/2, B, np.pi, True), 'C1')
plot(ax3, (W/2, 0, 0, True), 'C0')

for i, ax in enumerate(axs):
    ax.set_ylabel(r'$E/\Delta$')

    # ylims and yticks
    ax.set_ylim(-0.4, 0.4)
    yvals = [-0.3, 0, 0.3]
    ylabels = [f'${x}$' for x in yvals]
    ax.set_yticks(yvals)
    ax.set_yticklabels(ylabels)

    # xlims and xticks
    ax.set_xlim(-3, 3)
    xvals = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
    xlabels = [r'$-\pi$', r'$\sfrac{-\pi}{2}$', r'$0$', r'$\sfrac{\pi}{2}$', r'$\pi$']
    ax.set_xticks(xvals)
    ax.set_xticklabels(xlabels)

    # text inside image (raw f-string: '\m' is an invalid escape sequence,
    # a SyntaxWarning since Python 3.12)
    label = 'abc'[i]
    ax.text(0.95, 0.5, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='center', horizontalalignment='center')

    z_ys = ['0', 'W/2', 'W']  # we redefine z_y to be 2*z_y in paper Fig. 1
    ax.text(0.01, 0.5, f'$z_y={z_ys[i]}$', transform=ax.transAxes,
            verticalalignment='center', horizontalalignment='left')

ax1.legend((line1, line2), (r'$\phi=0$, $B=0$', rf'$\phi=\pi$, $B = {str(B)}$ T'),
           loc='upper center', bbox_to_anchor=(0.5, 1.7),
           fancybox=True, shadow=False, ncol=1)
ax1.axvline(0, linestyle='--', color='grey')
ax3.set_xlabel(r'$z_x k_x$')

# Red labels on the left half of panel (a) refer to the rescaled unfolded
# spectrum.
ax1_xticklabels = ax1.set_xticklabels([
    rf'${-1300 * np.pi/30:.0f}$', rf'${-1300 * np.pi/60:.0f}$',
    r'$0$', r'$\sfrac{\pi}{2}$', r'$\pi$'])
ax1_xticklabels[0].set_color('r')
ax1_xticklabels[1].set_color('r')
ax1.axvspan(-np.pi, 0, alpha=0.2, color='grey')

plt.savefig("paper/figures/bandstructures.pdf", bbox_inches="tight")
plt.show()

# Figure 4. Phase diagrams

## Phase boundaries for the straight system

In [None]:
import operator
from adaptive import Learner1D, make_datasaver
from statistics import mean

def phase_boundary_wrapper(B_x, z_y, syst_pars=copy(default_syst_pars),
                           params=copy(default_params), nbands=300, sigma=0):
    """Phase boundaries at field `B_x` from `zigzag.phase_bounds` (k_x = 0).

    Only the straight system (``z_y == 0``) is supported; any other `z_y`
    raises NotImplementedError.
    """
    import numpy as np
    import zigzag
    params = dict(params, B_x=B_x)
    syst_pars = dict(syst_pars, z_y=z_y, mu_from_bottom_of_spin_orbit_bands=True,
                     L_sc_up=800, L_sc_down=800)
    if not syst_pars['z_y'] == 0:
        raise NotImplementedError('No can do!')
    # Straight system: use z_x = L_x = a.
    syst_pars['z_x'] = syst_pars['L_x'] = syst_pars['a']
    syst = zigzag.system(**syst_pars)
    return zigzag.phase_bounds(syst, params, k_x=0, num_bands=nbands, sigma=sigma)


# Phase boundaries mu(B_x) of the straight system (z_y = 0), B_x in [0, 5].
W = default_syst_pars['L_m']
B_x_bounds = (0, 5)
z_y = 0
learner = Learner1D(partial(phase_boundary_wrapper,
    z_y=z_y, sigma=0), bounds=B_x_bounds)
fname = 'data/phase_boundary/phase_boundary_straight.pickle'
learner.load(fname)

runner = adaptive.Runner(learner, lambda l: l.npoints > 600, executor=client, raise_if_retries_exceeded=False)
runner.live_info()

In [None]:
def plot(l, transpose=False):
    """HoloViews overlay of the phase boundaries stored in learner `l`.

    Each boundary (a column of the learned values) becomes a black scatter
    of mu versus B_x, or of B_x versus mu when `transpose` is True.
    """
    import holoviews as hv
    if not l.data:
        overlay = hv.Overlay([hv.Scatter([])])
    else:
        fields, values = map(np.array, zip(*sorted(l.data.items())))
        overlay = hv.Overlay([
            hv.Scatter((column, fields) if transpose else (fields, column))
            .opts(style=dict(color='k'))
            for column in values.T])
    if transpose:
        return overlay.redim(x='mu', y='B_x')
    return overlay.redim(x='B_x', y='mu')

# plot(learner, transpose=True)

In [None]:
# Sampled B_x values and the matching phase-boundary mu values (one column per boundary).
B_xs_straight, mus_straight = map(np.array, zip(*sorted(learner.data.items())))

## Energy gaps

In [None]:
import operator
from adaptive import Learner2D, make_datasaver

def gap_wrapper(xy, keys, z_y, syst_pars=copy(default_syst_pars), params=copy(default_params)):
    """Energy gap at the parameter point `xy` (values for the names in `keys`).

    Returns a sequence whose item 1 is the gap (the datasaver learner picks
    item 1). For the straight system (``z_y == 0``) this is
    ``[None, zigzag.gap(...)]``. Otherwise it is the full output of
    `scipy.optimize.brute`, minimising the smallest |E| over k_x in [0, pi],
    so item 1 is the minimum found.
    """
    import numpy as np
    import zigzag
    from scipy.optimize import brute
    params = dict(params, **dict(zip(keys, xy)))

    # We can employ a faster algo for the straight system
    syst_pars = dict(syst_pars, z_y=z_y, L_sc_up=800, L_sc_down=800)
    straight_system = syst_pars['z_y'] == 0
    if straight_system:
        syst_pars['z_x'] = syst_pars['L_x'] = syst_pars['a']
        syst_pars['wraparound'] = False

    syst = zigzag.system(**syst_pars)

    if straight_system:
        return [None, zigzag.gap(syst, params)]

    def energies(k_x):
        # brute (and its fmin polishing step) pass a length-1 array.
        # float() of a non-0-d array is deprecated since NumPy 1.25, so take
        # the element explicitly.
        params['k_x'] = float(np.ravel(k_x)[0])
        return np.abs(zigzag.spectrum(syst,
            params, k=4)[0]).min()

    return brute(energies, ranges=((0, np.pi),), Ns=101, full_output=True)

# Gap maps over (mu, B_x) and (phase, B_x) for the straight (z_y = 0) and
# zigzag (z_y = W/2) systems.
W = default_syst_pars['L_m']
combos = [(0, 'mu', (0, 20)), (0, 'phase', (0, 2*np.pi)),
          (W/2, 'mu', (0, 20)), (W/2, 'phase', (0, 2*np.pi))]

# Learn item 1 (the gap) of gap_wrapper's return value and keep the rest as extra data.
Learner = make_datasaver(Learner2D, arg_picker=operator.itemgetter(1))

learners = [Learner(function=partial(gap_wrapper, keys=[key, 'B_x'], z_y=z_y),
                               bounds=[xbounds, (0, 5)])
            for z_y, key, xbounds in combos]

learner = adaptive.BalancingLearner(learners, cdims=(['z_y', 'key'], combos), strategy='npoints')

fnames = [f'data/phase_diagrams/phase_diagram_{combo}_longer_sc.pickle' for combo in combos]
learner.load(fnames)

In [None]:
# Stop once the first learner has more than 25000 points.
runner = adaptive.Runner(learner, lambda l: l.learners[0].npoints > 25000, executor=client)
runner.live_info()

In [None]:
# Save every learner to its `fnames` path every 180 s while the runner runs.
runner.start_periodic_saving(dict(fname=fnames), 180)

## Combined plot

In [None]:
# Look up each gap learner by (z_y, key).
mapping = {(z_y, key): learner for (z_y, key, xbounds), learner in zip(combos, learners)}

# Largest learned gap over all maps; used as the top of the colour scale.
E_max = max(max(l.data.values()) for l in learners)

def get_gap_from_learner(learner, mu, B_x):
    """Interpolated learned value of the 2D `learner` at the point (mu, B_x).

    Uses the learner's interpolator `ip()` at the point rescaled by the
    learner's private `_scale` method, and returns a Python float.
    """
    interpolator = learner.ip()
    scaled_point = learner._scale((mu, B_x))
    return float(interpolator(scaled_point).squeeze())

def plot(ax, z_y, key):
    """Show the interpolated gap map of ``mapping[(z_y, key)]`` on `ax`.

    The HoloViews image is flipped and rotated, and its extent is reordered
    to ``(b, t, l, r)``, so B_x (the learner's second coordinate) runs
    horizontally. Returns the viridis `AxesImage`.
    """
    image = mapping[(z_y, key)].plot().Image.I
    left, bottom, right, top = image.bounds.lbrt()
    rotated = np.rot90(np.flipud(image.data))
    return ax.imshow(rotated,
                     extent=(bottom, top, left, right),
                     aspect='auto',
                     cmap='viridis')

def color_spline(ax, c):
    """Colour all four spines of `ax` and its x/y tick marks and labels with `c`."""
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color(c)
    for axis in ('x', 'y'):
        ax.tick_params(axis=axis, colors=c)

fig, axs = plt.subplots(3, 2, sharex=False, sharey=False, figsize=(fig_width, 2.2*fig_height))
plt.subplots_adjust(bottom=0.2, left=0.125, right=0.85, top=0.8, hspace=0.22, wspace=0.12)
(ax1, ax2), (ax3, ax4), (ax5, ax6) = axs
# Rows 2-3: gap maps. [ax3, ax5, ax4, ax6] follows the order of `combos`, so the
# left column is z_y = 0 and the right column z_y = W/2.
ims = [plot(ax, z_y, key)  for ax, (z_y, key, xbounds) in zip([ax3, ax5, ax4, ax6], combos)]
# Tick labels collected here are recoloured black at the end, after
# `color_spline` has coloured the axes.
tick_labels = []

for ax in [ax3, ax4, ax5, ax6]:
    # xticks for B_x; labels only on the bottom row
    xvals = [0, 1, 2, 3, 4, 5]
    xlabels = [f'${x}$' if ax in [ax5, ax6] else '' for x in xvals]
    ax.set_xticks(xvals)
    tick_labels.extend(ax.set_xticklabels(xlabels))

for ax in [ax5, ax6]:
    # yticks for phi; labels only in the left column. Raw strings: '\p' is
    # an invalid escape sequence.
    yvals = [0, np.pi, 2*np.pi]
    ylabels = [r'$0$', r'$\pi$', r'$2\pi$'] if ax is ax5 else ['', '', '']
    ax.set_yticks(yvals)
    tick_labels.extend(ax.set_yticklabels(ylabels))

for ax in [ax3, ax4]:
    # yticks for mu; labels only in the left column
    yvals = [0, 10, 20]
    ylabels = [fr'${y}$' if ax is ax3 else '' for y in yvals]
    ax.set_yticks(yvals)
    tick_labels.extend(ax.set_yticklabels(ylabels))

for i, ax in enumerate(axs.reshape(-1)):
    # text inside image
    label = 'abcdef'[i]
    ax.text(0.99, 0.97, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='right',
            color='white' if i > 1 else 'black')

ax5.set_xlabel(r'$B_x$ (T)')
ax6.set_xlabel(r'$B_x$ (T)')
ax3.set_ylabel(r'$\mu$ (meV)')
ax5.set_ylabel(r'$\phi$ (rad)')

# One shared colour normalisation for the four gap maps, built from the
# learners' images at a common resolution N.
N = 200
all_data = np.vstack([l.plot(n=N).Image.I.data for l in learners])
norm = HistogramNormalize(all_data, vmin=0, vmax=E_max, mixing_degree=0.6)
for im in ims:
    im.set_norm(norm)

cax = fig.add_axes([0.88, 0.2, 0.04, 0.39])  # [l, b, w, h]
cb = fig.colorbar(ims[3], cax=cax)

cbar_ticks = [0, 0.20]
cb.set_ticks(cbar_ticks)
cb.set_label(r'$E_{\mathrm{gap}}/ \Delta$', labelpad=-10)

# Overlay the straight-system phase boundaries on its mu map.
for y in mus_straight.T:
    ax3.scatter(B_xs_straight, y, c='w', s=0.5)

ax3.set_xlim(0, 5)
ax3.set_ylim(*combos[0][-1])

# Get value of magnetic field with best gap: scan 50 values of B_x in [0, 2]
# at mu = 10 on the interpolated z_y = 0 gap map.
_xs = np.linspace(0, 2)
_ys = [get_gap_from_learner(mapping[(0, 'mu')], 10, x) for x in _xs]
B_x_opt = _xs[np.argmax(_ys)]

# On the mu maps, a dashed line marks mu = 10 (the mu of the phase row,
# whose spines are C6). The C5 dot marks (B_x_opt, mu = 10), the point shown
# in the C5-framed top row.
for ax in [ax3, ax4]:
    ax.axhline(10, c='C6', ls='--', zorder=1)
    ax.scatter([B_x_opt], [10], s=20, c='C5', zorder=2)
    color_spline(ax, 'C1')

# On the phase maps, a dashed line marks phase = pi (the default phase of the
# mu row, whose spines are C1).
for ax in [ax5, ax6]:
    ax.axhline(np.pi, c='C1', ls='--', zorder=1)
    ax.scatter([B_x_opt], [np.pi], s=20, c='C5', zorder=2)
    color_spline(ax, 'C6')
    
# Top row: for z_y = 0 and z_y = W/2, draw the normal region (`deltas`) and the
# Majorana wave function at B_x_opt, annotated with the gap and Majorana size.
for z_y, ax in zip([0, W/2], [ax1, ax2]):
    L_x = 3 * default_syst_pars['z_x']
    syst_pars = dict(default_syst_pars, L_sc_up=800, L_sc_down=800,
                     z_y=z_y, L_x=L_x, wraparound=False, infinite=False,
                     phs_breaking_potential=True)
    syst = zigzag.system(**syst_pars)
    # No breaking potential for the plotted wave function.
    params = dict(default_params, B_x=B_x_opt, V_breaking=lambda y: 0)
    Deltas = deltas(syst, params)
    E_M, wf = zigzag.majorana_state(syst, params)
    kwant.plotter.map(syst, Deltas, ax=ax, show=False, cmap=gray_r)
    kwant.plotter.map(syst, wf, ax=ax, show=False, cmap=gist_heat_r_transparent())

    # Gap at mu = 10, read from the interpolated gap map.
    E_gap = get_gap_from_learner(mapping[(z_y, 'mu')], 10, B_x_opt)

    # xi_M is divided by 1000 for the micrometre label below.
    if z_y == 0:
        # Use a new lead to calculate the Majorana size
        lead_pars = dict(syst_pars, L_x=syst_pars['a'], z_x=syst_pars['a'],
                         infinite=True, wraparound=False)
        lead = zigzag.system(**lead_pars)
        xi_M = zigzag.majorana_size(lead, params) / 1000
    else:
        # Use the syst to calculate the Majorana size
        params_decay = dict(params, V_breaking=get_phs_breaking_potential(syst, 100))
        xi_M = fitted_majorana_size(syst, params_decay) / 1000

    ax.text(0.01, 0.02, rf'$E_{{\textrm{{gap}}}} = {E_gap:.2f} \Delta$',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='left',
            color='black', fontsize=10.5)
    
    # One decimal below 1 um, none above.
    xi_M_str = f'{xi_M:.1f}' if xi_M < 1 else f'{xi_M:.0f}'
    ax.text(0.01, 0.94, rf'$\xi_M = {xi_M_str}$ $\mu$m',
            transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left',
            color='black', fontsize=10.5)

# Top-row panels: fixed window, no ticks, spines in C5 (the colour of the
# B_x_opt marker on the maps).
for ax in [ax1, ax2]:
    ax.set_xlim(0, 1300)
    ax.set_ylim(-400, 400)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    color_spline(ax, 'C5')

for tick in tick_labels:
    # color all the tick labels to black again!
    tick.set_color('k')

plt.savefig("paper/figures/phasediagrams.pdf", bbox_inches="tight")
plt.show()

# Figure 5. Wavefunctions

In [None]:
# Colour per system part (not used in the plotting code visible here).
part_to_color = {
    'middle_interior' : 'grey',
    'middle_barrier' : 'black',
    'bottom_superconductor' : 'gold',
    'top_superconductor' : 'gold',
    'top_cut' : 'red',
    'bottom_cut' : 'blue'
}

# One panel per entry; shapes, rough_edges and z_ys are zipped element-wise.
shapes = [
    'parallel_curve',
#     'parallel_curve',
    'sawtooth',
    'parallel_curve',
    'parallel_curve'
]

rough_edges = [
    None,
#     (60, 30, 1),
    None,
    None,
    (60, 30, 1)
]

z_ys = [
    0,
#     0,
    100,
    100,
    100
]


# Finite systems 3.5 zigzag periods long, with the breaking potential
# switched off for the wave-function plots.
syst_pars = dict(default_syst_pars,
                 wraparound=False,
                 infinite=False,
                 phs_breaking_potential=True,
                 L_x=3.5*default_syst_pars['z_x'])
params = dict(default_params, B_x=1, V_breaking=lambda y: 0)

In [None]:
import pickle
from scipy.optimize import brute
from functools import partial
import os.path
import pickle

def energies(k_x, lead):
    """Smallest |E| of `lead` at momentum `k_x`, using the notebook-global `params`.

    Serves as the objective of `scipy.optimize.brute` to find the lead's gap.
    """
    # brute passes a length-1 array. float() of a non-0-d array is
    # deprecated since NumPy 1.25, so extract the element explicitly.
    k_x = float(np.ravel(k_x)[0])
    return np.abs(zigzag.spectrum(
        lead, dict(params, k_x=k_x), k=4)[0]).min()

# Load the precalculated data if it exists. The file is written by this
# notebook, so unpickling it is trusted; delete it to force recalculation.
fname = 'data/wave_functions.pickle'
if os.path.exists(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
else:
    data = []


fig, axs = plt.subplots(len(shapes), 1, sharex=True, sharey=True,
                        figsize=(fig_width, 2*fig_height))

for i, (ax, shape, z_y, rough_edge) in enumerate(
    zip(axs, shapes, z_ys, rough_edges)):
    syst_pars = dict(syst_pars, z_y=z_y, shape=shape, rough_edge=rough_edge)
    syst = zigzag.system(**syst_pars)

    if not os.path.exists(fname):
        # Recalculate the data

        # Calculate the Majorana wf and energy
        E_M, wf = zigzag.majorana_state(syst, params)

        if z_y == 0:
            # Make a lead to calculate the gap and the Majorana size
            lead_pars = dict(syst_pars, L_x=syst_pars['a'], z_x=syst_pars['a'],
                             infinite=True, wraparound=False)
            lead = zigzag.system(**lead_pars)
            xi_M = zigzag.majorana_size(lead, params) / 1000
            E_gap = zigzag.gap(lead, params)
        else:
            # Use the syst to calculate the Majorana size
            params_decay = dict(params, V_breaking=get_phs_breaking_potential(syst, 100))
            xi_M = fitted_majorana_size(syst, params_decay) / 1000
            # Make a lead to calculate the gap
            lead_pars = dict(syst_pars, L_x=syst_pars['z_x'], infinite=True, wraparound=True)
            lead = zigzag.system(**lead_pars)
            E_gap = brute(partial(energies, lead=lead),
                          ranges=((0, np.pi),), Ns=31, full_output=True)[1]
        data.append([E_M, wf, E_gap, xi_M])
    else:
        # Cached entries are in the same order as the panels.
        E_M, wf, E_gap, xi_M = data[i]

    E_M_str = num_to_latex_exp(np.abs(E_M), only_exp=False)
    E_gap_str = num_to_latex_exp(E_gap, only_exp=False, N_digits=2)

    # text inside image (raw f-string: '\m' is an invalid escape sequence,
    # a SyntaxWarning since Python 3.12)
    label = 'abcdef'[i]
    ax.text(0.995, 0.955, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='right',
            fontsize=12, color='black')

    ax.text(0.01, 0.935, rf'$E_M={E_M_str} \Delta$',
            transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left',
            fontsize=12, color='black')

    ax.text(0.01, 0.01, rf'$E_{{\textrm{{gap}}}} = {E_gap_str} \Delta$',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='left',
            fontsize=12, color='black')

    ax.text(0.99, 0.02, rf'$\xi_M = {xi_M:.1f}$ $\mu$m',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='right',
            fontsize=12, color='black')

    ax.set_ylabel('$y$ (nm)')

    # Plot where the normal region is
    Deltas = deltas(syst, params)
    kwant.plotter.map(syst, Deltas, ax=ax, show=False, cmap=gray_r)

    # Plot the wf
    kwant.plotter.map(syst, wf, ax=ax, show=False, cmap=gist_heat_r_transparent())

# Save the data (only when it was recalculated above, i.e. no cache existed).
if not os.path.exists(fname):
    with open(fname, 'wb') as f:
        pickle.dump(data, f)

axs[-1].set_xlabel('$x$ (nm)')

plt.savefig("paper/figures/wavefunctions.pdf", bbox_inches="tight")
plt.show()