# Figures for "Enhanced proximity effect in zigzag-shaped Majorana Josephson junctions."
By Tom Laeven, [Bas Nijholt](http://github.com/basnijholt), Michael Wimmer, Anton R. Akhmerov.

Figures 2 and 4 each have a "Computation" cell, which calculates the data (or loads it if available), and a "Plotting" cell, which should only be run once the calculation has finished (or the data has been loaded).

Figure 3 is calculated (or loaded, if the data is available) and plotted in a single cell.

All of the code can be run without the data present; however, a computational cluster is required to create Fig. 4 within a reasonable time.

## Connect to the cluster

In [None]:
import socket

# The TU Delft machine called 'io' has access to our cluster.
on_io = socket.gethostname() == 'io'

# Cluster use is switched off by default; set to True (on 'io') to run the
# computations on the hpc05 cluster.
use_cluster = False

if on_io and use_cluster:
    import hpc05, hpc05_monitor
    client = hpc05.connect.start_remote_and_connect(130, profile='pbs_15GB', folder='~/Work/two_dim_majoranas')[0]
    max_usage_task = hpc05_monitor.start(client, interval=1)
else:
    # Do not connect to the cluster, use `None` (passed as `executor=client`
    # to the adaptive runners), or start your own `ipcluster` or
    # `dask.Client` and assign it to `client`.
    client = None

## Numerics imports

In [None]:
from copy import copy
from functools import partial
import pickle
import os.path

import adaptive
import kwant
import numpy as np
# Set up adaptive's Jupyter integration (presumably required for the
# `runner.live_info()` widgets used in the computation cells).
adaptive.notebook_extension()

# Project-local module that builds the systems and computes spectra, gaps
# and Majorana states.
import zigzag

# Output directory for the PDF figures written with `plt.savefig`.
import pathlib
pathlib.Path('paper/figures').mkdir(parents=True, exist_ok=True)

## Parameter and system definitions

In [None]:
import scipy.constants
import cmath

# Physical constants handed to the Hamiltonian through `params`. Energies are
# rescaled to meV (division by eV * 1e-3) and lengths to nm (the 1e18 = m^2 ->
# nm^2 factor in m_eff); `e` is left in SI units (coulomb).
constants = dict(
    # Effective mass 0.02 m_e in meV s^2 / nm^2, so that
    # hbar^2 k^2 / (2 m_eff) is in meV for k in 1/nm.
    m_eff=0.02 * scipy.constants.m_e / (scipy.constants.eV * 1e-3) / 1e18,
    hbar=scipy.constants.hbar / (scipy.constants.eV * 1e-3),  # meV s
    e=scipy.constants.e,  # elementary charge in C
    # k_B e / hbar in nA/K: multiplied by a temperature in K this gives nA.
    current_unit=scipy.constants.k * scipy.constants.e / scipy.constants.hbar * 1e9,
    mu_B=scipy.constants.physical_constants['Bohr magneton'][0] / (scipy.constants.eV * 1e-3),  # meV/T
    k=scipy.constants.k / (scipy.constants.eV * 1e-3),  # Boltzmann constant in meV/K
    # Complex-valued math functions (cmath), presumably referenced by name in
    # the Hamiltonian built by `zigzag`.
    exp=cmath.exp,
    cos=cmath.cos,
    sin=cmath.sin,
)

# Default Hamiltonian parameters. Region-dependent quantities come in
# middle/left/right variants; the physical constants are appended last.
default_params = {
    'g_factor_middle': 26,
    'g_factor_left': 0,
    'g_factor_right': 0,
    'mu': 10,  # chemical potential (meV, cf. the mu axis label of Fig. 4)
    'alpha_middle': 20,
    'alpha_left': 0,
    'alpha_right': 0,
    'Delta_left': 1,
    'Delta_right': 1,
    # Magnetic field components (T, cf. the B_x axis label of Fig. 4).
    'B_x': 0,
    'B_y': 0,
    'B_z': 0,
    'phase': np.pi,  # phase difference phi (rad)
    'V': 0,
    **constants,
}

# Default geometry passed to `zigzag.system`. z_y = 0 is the straight
# junction; `a` is the discretization length (presumably in nm).
default_syst_pars = {
    'W': 200,
    'L_x': 1300,
    'L_sc_up': 300,
    'L_sc_down': 300,
    'z_x': 1300,
    'z_y': 0,
    'a': 10,
    'shape': 'parallel_curve',
    'transverse_soi': True,
    'mu_from_bottom_of_spin_orbit_bands': True,
    'k_x_in_sc': True,
    'wraparound': True,
    'infinite': True,
}

# Functions used in multiple plots
def deltas(syst, params):
    """Return a per-site mask marking where the onsite pairing vanishes.

    For every site the absolute value of element ``[1, 0]`` of the onsite
    Hamiltonian is inspected: the mask is ``1`` where it is exactly zero and
    ``nan`` everywhere else. It is plotted to show the normal region.
    """
    mask = []
    for index, _site in enumerate(syst.sites):
        onsite = syst.hamiltonian(index, index, params=params)
        pairing = np.abs(onsite[1, 0])
        mask.append(np.nan if pairing != 0 else 1)
    return mask

def get_phs_breaking_potential(syst, V=10000):
    """Return a callable that is ``V`` for one "top" site and ``0`` otherwise.

    According to the original comment, this potential pushes one of the
    Majoranas away so that the two do not overlap.

    NOTE(review): ``top_site`` is a site object (the first site with the
    largest y coordinate), not a y coordinate. The returned callable only
    yields ``V`` when called with something equal to that site. If `zigzag`
    calls ``V_breaking`` with a float y coordinate instead, the potential is
    zero everywhere. Confirm against how `zigzag` evaluates ``V_breaking``.
    """
    def _y_of(site):
        return site.pos[1]

    top_site = max(syst.sites, key=_y_of)

    def potential(y):
        return V if y == top_site else 0

    return potential

## Plotting imports and functions

In [None]:
import matplotlib
matplotlib.use('agg')

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

class HistogramNormalize(matplotlib.colors.Normalize):
    """Normalisation that mixes histogram equalisation with linear scaling.

    The histogram-equalised value of ``v`` is the fraction of the reference
    ``data`` that lies strictly below ``v``. When ``vmin``/``vmax`` are given,
    ``data`` is first restricted to that range. The result is
    ``mixing_degree * equalised + (1 - mixing_degree) * linear``. With
    ``mixing_degree=1`` it is pure equalisation; with ``0`` it is the plain
    linear ``Normalize``.

    Parameters
    ----------
    data : numpy.ndarray
        Reference values of any shape (flattened internally).
    vmin, vmax : float, optional
        Limits forwarded to ``Normalize`` and used to filter ``data``.
    mixing_degree : float
        Weight of the histogram-equalised contribution.
    """

    def __init__(self, data, vmin=None, vmax=None, mixing_degree=1):
        self.mixing_degree = mixing_degree
        reference = data
        if vmin is not None:
            reference = reference[reference >= vmin]
        if vmax is not None:
            reference = reference[reference <= vmax]
        self.sorted_data = np.sort(reference.ravel())
        super().__init__(vmin, vmax)

    def __call__(self, value, clip=None):
        # Position of `value` within the sorted reference data, as a fraction.
        rank = np.searchsorted(self.sorted_data, value)
        equalised = np.ma.masked_array(rank / len(self.sorted_data))
        linear = super().__call__(value, clip)
        weight = self.mixing_degree
        return weight * equalised + (1 - weight) * linear

# Figure geometry: one paper column (246 pt) wide, golden-ratio tall.
golden_mean = (np.sqrt(5) - 1) / 2  # Aesthetic ratio
fig_width_pt = 246.0  # Columnwidth
inches_per_pt = 1 / 72.27  # Convert pt to inches
fig_width = fig_width_pt * inches_per_pt
fig_height = fig_width * golden_mean  # height in inches
fig_size = [fig_width, fig_height]

# matplotlib rc settings; all text is typeset with LaTeX.
params = {
          'backend': 'ps',
          'axes.labelsize': 13,
          'font.size': 13,
          'legend.fontsize': 10,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10,
          'text.usetex': True,
          'figure.figsize': fig_size,
          'font.family': 'serif',
          'font.serif': 'Computer Modern Roman',
          'legend.frameon': True,
          'savefig.dpi': 100 if on_io else 300,
         }

plt.rcParams.update(params)
# `xfrac` provides the \sfrac used in the tick labels. The preamble must be a
# single string: lists of strings were deprecated in matplotlib 3.3 and are
# rejected by later versions, while a plain string works in all of them.
plt.rc('text.latex', preamble=r'\usepackage{xfrac}')

def gist_heat_r_transparent():
    """Return ``gist_heat_r`` resampled to 128 colours with a linear alpha ramp.

    The alpha channel rises from 0 (fully transparent) at the low end to 1
    (opaque) at the high end. Small values therefore let underlying plots
    show through.
    """
    import matplotlib.cm
    import matplotlib.colors as mcolors
    ramp = np.linspace(0, 1, 128)
    rgba = matplotlib.cm.gist_heat_r(ramp)
    rgba[:, 3] = ramp
    return mcolors.LinearSegmentedColormap.from_list('gist_heat_r_transparent', rgba)

# Greyscale colormap for the normal-region mask; NaN ("bad") entries, see
# `deltas`, are drawn light grey. Work on a copy: calling `set_bad` on
# `matplotlib.cm.gray_r` itself would modify the globally registered colormap
# (deprecated since matplotlib 3.3).
gray_r = copy(matplotlib.cm.gray_r)
gray_r.set_bad('#e0e0e0')

def num_to_latex_exp(x, only_exp=True, N_digits=1):
    r"""Format ``x`` as a LaTeX power of ten (without surrounding ``$``).

    Parameters
    ----------
    x : float
        Number to format.
    only_exp : bool
        If True, return only the order of magnitude ``10^{n}`` with
        ``n = floor(log10(|x|))``. ``x == 0`` gives ``'0'`` because zero has
        no order of magnitude. If False, return ``m \times 10^{n}``.
    N_digits : int
        Number of significant digits of the mantissa ``m`` (used only when
        ``only_exp`` is False). Must be at least 1.

    Returns
    -------
    str

    Raises
    ------
    ValueError
        If ``only_exp`` is False and ``N_digits < 1``.
    """
    if not only_exp:
        if N_digits < 1:
            raise ValueError(f'N_digits must be >= 1, got {N_digits}')
        # The 'e' format prints N_digits - 1 digits after the decimal point,
        # i.e. N_digits significant digits in total.
        num, exponent = f'{x:.{N_digits - 1}e}'.split('e')
        return rf'{num} \times 10^{{{int(exponent)}}}'
    if x == 0:
        return '0'
    exponent = int(np.floor(np.log10(np.abs(x))))
    return f'10^{{{exponent}}}'

# Figure 2. Bandstructures

### Computation

In [None]:
def spectrum_wrapper(k_x, z_y, B_x, phase, folded, syst_pars=copy(default_syst_pars),
                     params=copy(default_params), nbands=40):
    """Return the sorted eigenenergies of the junction at momentum ``k_x``.

    Parameters
    ----------
    k_x : float
        Momentum along x.
    z_y : float
        Zigzag parameter z_y (0 for the straight junction).
    B_x : float
        Magnetic field along x.
    phase : float
        Phase difference phi.
    folded : bool
        If False, the unit cell is shrunk to ``L_x = z_x = a``. Otherwise the
        unit cell from ``syst_pars`` is used.
    syst_pars, params : dict
        Base system and Hamiltonian parameters. New dicts are built; the
        arguments are not modified.
    nbands : int
        Number of eigenvalues requested from ``zigzag.spectrum`` (``k``).

    Returns
    -------
    numpy.ndarray
        The eigenvalues returned by ``zigzag.spectrum``, sorted ascending.
    """
    # Imported inside the function, presumably so that it can run on remote
    # workers (the runners use `executor=client`).
    import numpy as np
    import zigzag
    hamiltonian_params = {**params, 'k_x': k_x, 'B_x': B_x, 'phase': phase}
    system_kwargs = {**syst_pars, 'z_y': z_y}
    if not folded:
        cell = system_kwargs['a']
        system_kwargs['L_x'] = cell
        system_kwargs['z_x'] = cell
    syst = zigzag.system(**system_kwargs)
    eigenvalues = zigzag.spectrum(syst, hamiltonian_params, k=nbands)[0]
    return np.sort(eigenvalues)

def abs_min_log_loss(xs, ys):
    """Interval loss for ``adaptive.Learner1D`` based on the smallest |E|.

    Each entry of ``ys`` is a spectrum (as returned by ``spectrum_wrapper``).
    It is reduced to ``log(min |E|)`` and the result is fed to adaptive's
    ``default_loss``. Sampling is therefore refined where the energy closest
    to zero changes most on a log scale.
    """
    from adaptive.learner.learner1D import default_loss
    log_min_energies = [np.log(np.min(np.abs(spectrum))) for spectrum in ys]
    return default_loss(xs, log_min_energies)

# Magnetic field (T, see the legend) for the B != 0 spectra, and the width.
B = 1
W = default_syst_pars['W']
# Each entry: (z_y, B_x, phase, folded, k_x bounds). The straight junction
# (z_y = 0) is sampled both unfolded (folded=False, unit cell L_x = z_x = a)
# and folded.
combos = [
    (0, 0, 0, False, (-np.pi, 0)),
    (0, 0, 0, True, (0, np.pi)),
    (W/4, 0, 0, True, (-np.pi, np.pi)),
    (W/2, 0, 0, True, (-np.pi, np.pi)),
    (0, B, np.pi, False, (-np.pi, 0)),
    (0, B, np.pi, True, (0, np.pi)),
    (W/4, B, np.pi, True, (-np.pi, np.pi)),
    (W/2, B, np.pi, True, (-np.pi, np.pi))
]

# One 1D learner over k_x per combination; `abs_min_log_loss` refines where
# the smallest |E| varies.
learners = [adaptive.Learner1D(partial(spectrum_wrapper, z_y=z_y, B_x=B_x, phase=phase, folded=folded),
                               bounds=bounds, loss_per_interval=abs_min_log_loss)
            for z_y, B_x, phase, folded, bounds in combos]

# NOTE(review): cdims names 4 dimensions while each combo has 5 entries; the
# trailing bounds entry is presumably ignored -- confirm with adaptive.
learner = adaptive.BalancingLearner(learners, cdims=(['z_y', 'B_x', 'phi', 'folded'], combos), strategy='npoints')

# Load previously computed data, one pickle per combination, if available.
fnames = [f'data/bandstructures/spectrum_{combo}.pickle' for combo in combos]
learner.load(fnames)

# Run (on the cluster if `client` is set) until the first learner has more
# than 400 points. NOTE(review): nothing in this notebook calls
# `learner.save(fnames)`, so newly computed points are not written to disk.
runner = adaptive.Runner(learner, lambda l: l.learners[0].npoints > 400, executor=client)
runner.live_info()

### Plotting

In [None]:
# (z_y, B_x, phase, folded) -> Learner1D; the k_x bounds are dropped from the key.
mapping = dict(zip((combo[:4] for combo in combos), learners))

def plot(ax, key, color, xscale=1):
    """Plot all bands of the learner stored under ``key`` on ``ax``.

    The k_x values are multiplied by ``xscale``. The list of ``Line2D``
    objects returned by ``ax.plot`` is passed through.
    """
    samples = sorted(mapping[key].data.items())
    k_values, spectra = (np.array(column) for column in zip(*samples))
    return ax.plot(xscale * k_values, spectra, c=color, lw=0.8)

fig, axs = plt.subplots(3, sharex=False, sharey=True, figsize=(fig_width, 2*fig_height))
fig.subplots_adjust(hspace=0.35)
(ax1, ax2, ax3) = axs

# Panel (a): folded straight-junction spectra; the lines are reused in the legend.
line2 = plot(ax1, (0, B, np.pi, True), 'C1')[0]
line1 = plot(ax1, (0, 0, 0, True), 'C0')[0]

# Panel (a): unfolded spectra (sampled for k_x in (-pi, 0)), with the k axis
# stretched by `xscale`.
xscale = 3
plot(ax1, (0, B, np.pi, False), 'C1', xscale)
plot(ax1, (0, 0, 0, False), 'C0', xscale)

# Panels (b) and (c): zigzag junctions with z_y = W/4 and W/2.
plot(ax2, (W/4, B, np.pi, True), 'C1')
plot(ax2, (W/4, 0, 0, True), 'C0')

plot(ax3, (W/2, B, np.pi, True), 'C1')
plot(ax3, (W/2, 0, 0, True), 'C0')

for i, ax in enumerate(axs):
    ax.set_ylabel(r'$E/\Delta$')

    # ylims and yticks
    ax.set_ylim(-0.4, 0.4)
    yvals = [-0.3, 0, 0.3]
    ylabels = [f'${x}$' for x in yvals]
    ax.set_yticks(yvals)
    ax.set_yticklabels(ylabels)

    # xlims and xticks
    ax.set_xlim(-3, 3)
    xvals = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
    xlabels = [r'$-\pi$', r'$\sfrac{-\pi}{2}$', r'$0$', r'$\sfrac{\pi}{2}$', r'$\pi$']
    ax.set_xticks(xvals)
    ax.set_xticklabels(xlabels)

    # Panel label, e.g. "(a)"; raw string so that `\m` is not parsed as an
    # (invalid) escape sequence.
    label = 'abc'[i]
    ax.text(0.95, 0.5, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='center', horizontalalignment='center')

    z_ys = ['0', 'W/2', 'W']  # we redefine z_y to be 2*z_y in paper Fig. 1
    ax.text(0.2, 0.5, f'$z_y={z_ys[i]}$', transform=ax.transAxes,
            verticalalignment='center', horizontalalignment='left')

ax1.legend((line1, line2), (r'$\phi=0$, $B_x=0$', rf'$\phi=\pi$, $B_x = {str(B)}$ T'),
           loc='upper center', bbox_to_anchor=(0.5, 1.7),
           fancybox=True, shadow=False, ncol=1)
ax1.axvline(0, linestyle='--', color='grey')
ax3.set_xlabel(r'$z_x k_x$')

# The negative-k half of panel (a) shows the unfolded spectrum; its tick
# labels are drawn in red and the region is shaded grey.
ax1_xticklabels = ax1.set_xticklabels([
    rf'${-1300 * np.pi/30:.0f}$', rf'${-1300 * np.pi/60:.0f}$',
    r'$0$', r'$\sfrac{\pi}{2}$', r'$\pi$'])
ax1_xticklabels[0].set_color('r')
ax1_xticklabels[1].set_color('r')
ax1.axvspan(-np.pi, 0, alpha=0.2, color='grey')

plt.savefig("paper/figures/bandstructures.pdf", bbox_inches="tight")
plt.show()

# Figure 3. Wavefunctions

### Computation and plotting

In [None]:
# Settings for the four rows of Fig. 3: (shape, rough_edge, z_y).
# `rough_edge` is None for smooth edges or a tuple (X, Y, salt).
_fig3_panels = [
    ('parallel_curve', None, 0),
    ('sawtooth', None, 100),
    ('parallel_curve', None, 100),
    ('parallel_curve', (60, 30, 1), 100),
]
shapes = [panel[0] for panel in _fig3_panels]
rough_edges = [panel[1] for panel in _fig3_panels]
z_ys = [panel[2] for panel in _fig3_panels]

# Finite, non-wrapped system of length L_x = 3.5 z_x with the PHS-breaking
# potential term enabled (the potential itself returns 0 here).
syst_pars = {
    **default_syst_pars,
    'wraparound': False,
    'infinite': False,
    'phs_breaking_potential': True,
    'L_x': 3.5 * default_syst_pars['z_x'],
}

params = {**default_params, 'B_x': 1, 'V_breaking': lambda y: 0}

# Load the precalculated data if it exists
fname = 'data/wave_functions.pickle'
if not os.path.exists(fname):
    data = []
else:
    with open(fname, 'rb') as f:
        data = pickle.load(f)


fig, axs = plt.subplots(len(shapes), 1, sharex=True, sharey=True,
                        figsize=(fig_width, 2*fig_height))

# Decide once whether the results must be computed; they are written to
# `fname` only after all panels are done.
recalculate = not os.path.exists(fname)

for i, (ax, shape, z_y, rough_edge) in enumerate(
        zip(axs, shapes, z_ys, rough_edges)):
    syst_pars = dict(syst_pars, z_y=z_y, shape=shape, rough_edge=rough_edge)
    syst = zigzag.system(**syst_pars)

    if recalculate:
        # Calculate the Majorana wf and energy
        E_M, wf = zigzag.majorana_state(syst, params)

        if z_y == 0:
            # Straight junction: an infinite, non-wrapped lead with unit cell
            # L_x = z_x = a gives the gap and the Majorana size from its modes.
            lead_pars = dict(syst_pars, L_x=syst_pars['a'], z_x=syst_pars['a'],
                             infinite=True, wraparound=False)
            lead = zigzag.system(**lead_pars)
            xi_M = zigzag.majorana_size_from_modes(lead, params) / 1000  # displayed in um
            E_gap = zigzag.gap_from_modes(lead, params)
        else:
            # Use the syst to calculate the Majorana size
            params_decay = dict(params, V_breaking=get_phs_breaking_potential(syst, 100))
            xi_M = zigzag.majorana_size_from_fit(syst, params_decay) / 1000  # displayed in um
            # Make a lead (one zigzag unit cell, wrapped) to calculate the gap
            lead_pars = dict(syst_pars, L_x=syst_pars['z_x'], infinite=True, wraparound=True)
            lead = zigzag.system(**lead_pars)
            E_gap = zigzag.gap_from_band_structure(lead, params)
        data.append([E_M, wf, E_gap, xi_M])
    else:
        E_M, wf, E_gap, xi_M = data[i]

    E_M_str = num_to_latex_exp(np.abs(E_M), only_exp=False)
    E_gap_str = num_to_latex_exp(E_gap, only_exp=False, N_digits=2)

    # Panel label, e.g. "(a)"; raw string so `\m` is not an invalid escape.
    label = 'abcdef'[i]
    ax.text(0.995, 0.955, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='right',
            fontsize=12, color='black')

    ax.text(0.01, 0.935, rf'$E_M={E_M_str} \Delta$',
            transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left',
            fontsize=12, color='black')

    ax.text(0.01, 0.01, rf'$E_{{\textrm{{gap}}}} = {E_gap_str} \Delta$',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='left',
            fontsize=12, color='black')

    ax.text(0.99, 0.02, rf'$\xi_M = {xi_M:.1f}$ $\mu$m',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='right',
            fontsize=12, color='black')

    ax.set_ylabel('$y$ (nm)')

    # Plot where the normal region is
    Deltas = deltas(syst, params)
    kwant.plotter.map(syst, Deltas, ax=ax, show=False, cmap=gray_r)

    # Plot the wf
    kwant.plotter.map(syst, wf, ax=ax, show=False, cmap=gist_heat_r_transparent())

# Save the data, creating the output directory if it does not exist yet
# (the notebook is meant to run without any precomputed data).
if recalculate:
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    with open(fname, 'wb') as f:
        pickle.dump(data, f)

axs[-1].set_xlabel('$x$ (nm)')

plt.savefig("paper/figures/wavefunctions.pdf", bbox_inches="tight")
plt.show()

# Figure 4. Phase diagrams

This calculation consists of two parts. We calculate:
* the phase boundaries for the straight system
* the topological phase diagrams with energy gaps for a straight and a zigzag system

and combine them in a single plot.

## Phase boundaries for the straight system

### Computation

In [None]:
def phase_boundary_wrapper(B_x, syst_pars=copy(default_syst_pars),
                           params=copy(default_params), nbands=300, sigma=0):
    """Return ``zigzag.phase_bounds`` of the straight junction at field ``B_x``.

    The geometry is forced to the straight junction (``z_y = 0``) with
    ``L_sc_up = L_sc_down = 800`` and chemical potential measured from the
    bottom of the spin-orbit bands. The unit cell is shrunk to
    ``L_x = z_x = a`` and the bounds are evaluated at ``k_x = 0``.

    Raises
    ------
    NotImplementedError
        For a non-straight geometry (unreachable, since ``z_y`` is forced to 0).
    """
    # Imported inside the function, presumably so it can run on remote workers.
    import numpy as np
    import zigzag
    hamiltonian_params = {**params, 'B_x': B_x}
    system_kwargs = {
        **syst_pars,
        'z_y': 0,
        'mu_from_bottom_of_spin_orbit_bands': True,
        'L_sc_up': 800,
        'L_sc_down': 800,
    }
    if system_kwargs['z_y'] != 0:
        raise NotImplementedError('This takes too much memory.')
    cell = system_kwargs['a']
    system_kwargs['z_x'] = cell
    system_kwargs['L_x'] = cell
    syst = zigzag.system(**system_kwargs)
    return zigzag.phase_bounds(syst, hamiltonian_params, k_x=0,
                               num_bands=nbands, sigma=sigma)

# Sample the straight-junction phase boundaries mu(B_x) for B_x in [0, 5] T.
learner = adaptive.Learner1D(phase_boundary_wrapper, bounds=(0, 5))
fname = 'data/phase_boundary_straight.pickle'
learner.load(fname)

# Run until more than 600 points; with raise_if_retries_exceeded=False a
# repeatedly failing point presumably does not abort the run.
runner = adaptive.Runner(learner, lambda l: l.npoints > 600, executor=client, raise_if_retries_exceeded=False)
runner.live_info()

### Plotting

In [None]:
def plot(l):
    """Return a HoloViews overlay of the sampled phase boundaries of ``l``.

    Each column of the sampled values is drawn as a black scatter plot
    against ``B_x``. An empty overlay is returned when ``l`` has no data yet.
    """
    import holoviews as hv
    if not l.data:
        curves = [hv.Scatter([])]
    else:
        B_values, boundaries = map(np.array, zip(*sorted(l.data.items())))
        curves = [hv.Scatter((B_values, branch)).opts(style=dict(color='k'))
                  for branch in boundaries.T]
    return hv.Overlay(curves).redim(x='B_x', y='mu')

# plot(learner)  # uncomment for plot

In [None]:
# Write the data to variables for later: the sorted B_x samples and the
# phase-boundary values (one column per boundary, scattered in panel (c)).
# NOTE(review): `learner` must still be the phase-boundary learner here, i.e.
# this cell has to run before the "Energy gaps" cell rebinds `learner`.
B_xs_straight, mus_straight = map(np.array, zip(*sorted(learner.data.items())))

## Energy gaps

### Computation

In [None]:
def gap_wrapper(xy, keys, z_y, syst_pars=copy(default_syst_pars), params=copy(default_params)):
    """Return the energy gap at the parameter point ``xy``.

    Parameters
    ----------
    xy : sequence of float
        Values for the parameters named in ``keys``, e.g. ``(B_x, mu)``.
    keys : sequence of str
        Names of the ``params`` entries overridden by ``xy``.
    z_y : float
        Zigzag parameter; ``0`` selects the straight junction.
    syst_pars, params : dict
        Base system and Hamiltonian parameters (not modified; new dicts are
        built).

    Returns
    -------
    The gap from ``zigzag.gap_from_modes`` (straight junction) or from
    ``zigzag.gap_from_band_structure`` (zigzag junction).
    """
    # Imported inside the function, presumably so it can run on remote workers.
    import numpy as np
    import zigzag
    params = dict(params, **dict(zip(keys, xy)))
    syst_pars = dict(syst_pars, z_y=z_y, L_sc_up=800, L_sc_down=800)

    # We can employ a faster algo for the straight system
    if syst_pars['z_y'] == 0:  # straight system
        syst_pars['z_x'] = syst_pars['L_x'] = syst_pars['a']
        syst_pars['wraparound'] = False
        syst = zigzag.system(**syst_pars)
        return zigzag.gap_from_modes(syst, params)

    # Zigzag system: the gap comes from the band structure of the system
    # built here from `syst_pars` (infinite and wrapped unless overridden).
    syst = zigzag.system(**syst_pars)
    return zigzag.gap_from_band_structure(
        syst, params, Ns=101, full_output=False)

W = default_syst_pars['W']
# Each combo: (z_y, name of the parameter on the y axis, its bounds). The x
# axis of every learner is B_x in [0, 5] (T); z_y = 0 is the straight junction.
combos = [(0, 'mu', (0, 20)), (0, 'phase', (0, 2*np.pi)),
          (W/2, 'mu', (0, 20)), (W/2, 'phase', (0, 2*np.pi))]

learners = [adaptive.Learner2D(function=partial(gap_wrapper, keys=['B_x', key], z_y=z_y),
                               bounds=[(0, 5), ybounds])
            for z_y, key, ybounds in combos]

learner = adaptive.BalancingLearner(learners, cdims=(['z_y', 'key'], combos), strategy='npoints')

# Load previously computed data, one pickle per combination, if available.
fnames = [f'data/phase_diagrams/phase_diagram_{combo}.pickle' for combo in combos]
learner.load(fnames)

# Stop once the (W/2, 'phase') learner has more than 25000 points.
# NOTE(review): nothing in this notebook calls `learner.save(fnames)`.
runner = adaptive.Runner(learner, lambda l: l.learners[3].npoints > 25000, executor=client, raise_if_retries_exceeded=False)
runner.live_info()

## Plotting - combined energy gaps and phase boundaries

In [None]:
# (z_y, key) -> Learner2D, e.g. (0, 'mu') or (W/2, 'phase').
mapping = {combo[:2]: lrn for combo, lrn in zip(combos, learners)}

# Largest sampled gap over all learners; upper limit of the colour scale.
E_max = max(max(lrn.data.values()) for lrn in learners)

def value_from_learner(learner, x, y):
    """Return the learner's interpolated value at the unscaled point (x, y).

    Relies on the private adaptive API ``learner.ip()`` / ``learner._scale``.
    """
    # The interpolator presumably works in scaled coordinates, hence `_scale`.
    return float(learner.ip()(learner._scale((x, y))).squeeze())


def max_value_along_vertical_cut(learner, y, num=50):
    """Return the x value maximising ``learner``'s interpolated value at fixed ``y``.

    ``num`` evenly spaced x values spanning ``learner.bounds[0]`` are sampled,
    and the first one with the largest interpolated value is returned.
    """
    xs = np.linspace(*learner.bounds[0], num=num)
    vals = [value_from_learner(learner, x, y) for x in xs]
    return xs[np.argmax(vals)]

def plot(ax, z_y, key):
    """Draw the gap map of learner ``(z_y, key)`` on ``ax`` and return the image.

    The learner's HoloViews image is rendered with ``imshow`` using its
    bounds as extent and the ``viridis`` colormap.
    """
    image = mapping[(z_y, key)].plot().Image.I
    left, bottom, right, top = image.bounds.lbrt()
    return ax.imshow(image.data, extent=(left, right, bottom, top),
                     aspect='auto', cmap='viridis')

def color_spline(ax, c):
    """Colour the four spines of ``ax`` and its x/y ticks (and tick labels) with ``c``."""
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color(c)
    for axis_name in 'xy':
        ax.tick_params(axis=axis_name, colors=c)

fig, axs = plt.subplots(3, 2, sharex=False, sharey=False, figsize=(fig_width, 2.2*fig_height))
plt.subplots_adjust(bottom=0.2, left=0.125, right=0.85, top=0.8, hspace=0.22, wspace=0.12)
(ax1, ax2), (ax3, ax4), (ax5, ax6) = axs

# Gap maps: left column straight (z_y = 0), right column zigzag (z_y = W/2);
# middle row vs. (B_x, mu), bottom row vs. (B_x, phi).
plot_map = {ax3: (0, 'mu'), ax4: (W/2, 'mu'), ax5: (0, 'phase'), ax6: (W/2, 'phase')}
ims = [plot(ax, z_y, key) for ax, (z_y, key) in plot_map.items()]

# Tick labels collected here are reset to black at the end, because
# `color_spline` also colours the tick labels.
tick_labels = []

for ax in [ax3, ax4, ax5, ax6]:
    # xticks; only the bottom row gets x tick labels
    xvals = [0, 1, 2, 3, 4, 5]
    xlabels = [f'${x}$' if ax in [ax5, ax6] else '' for x in xvals]
    ax.set_xticks(xvals)
    tick_labels.extend(ax.set_xticklabels(xlabels))

# Get value of magnetic field with best gap (straight junction, mu = 10)
B_opt = max_value_along_vertical_cut(mapping[(0, 'mu')], y=10)

for ax in [ax3, ax4]:
    # Mark the mu = 10 cut and the optimal field B_opt
    ax.axhline(10, c='C6', ls='--', zorder=1)
    ax.scatter([B_opt], [10], s=20, c='C5', zorder=2)
    color_spline(ax, 'C1')
    # yticks for mu
    yvals = [0, 10, 20]
    ylabels = [fr'${y}$' if ax is ax3 else '' for y in yvals]
    ax.set_yticks(yvals)
    tick_labels.extend(ax.set_yticklabels(ylabels))

for ax in [ax5, ax6]:
    # Mark the phi = pi cut and the optimal field B_opt
    ax.axhline(np.pi, c='C1', ls='--', zorder=1)
    ax.scatter([B_opt], [np.pi], s=20, c='C5', zorder=2)
    color_spline(ax, 'C6')

    # yticks for phi
    yvals = [0, np.pi, 2*np.pi]
    ylabels = [r'$0$', r'$\pi$', r'$2\pi$'] if ax is ax5 else ['', '', '']
    ax.set_yticks(yvals)
    tick_labels.extend(ax.set_yticklabels(ylabels))

for i, ax in enumerate(axs.reshape(-1)):
    # Panel labels (a)-(f): black on the wave-function panels, white on the
    # gap maps. Raw string so `\m` is not an invalid escape sequence.
    label = 'abcdef'[i]
    ax.text(0.99, 0.97, rf'$\mathrm{{({label})}}$', transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='right',
            color='white' if i > 1 else 'black')

ax5.set_xlabel(r'$B_x$ (T)')
ax6.set_xlabel(r'$B_x$ (T)')
ax3.set_ylabel(r'$\mu$ (meV)')
ax5.set_ylabel(r'$\phi$ (rad)')

# Use the HistogramNormalize on all of the data and add a colorbar
N = 200  # or max(l.plot().Image.I.data.shape[0] for l in learners)
all_data = np.vstack([l.plot(n=N).Image.I.data for l in learners])
norm = HistogramNormalize(all_data, vmin=0, vmax=E_max, mixing_degree=0.6)
for im in ims:
    im.set_norm(norm)
cax = fig.add_axes([0.88, 0.2, 0.04, 0.39])  # [l, b, w, h]
cb = fig.colorbar(ims[3], cax=cax)
cbar_ticks = [0, 0.20]
cb.set_ticks(cbar_ticks)
cb.set_label(r'$E_{\mathrm{gap}}/ \Delta$', labelpad=-10)

# Plot the phase boundaries in (c).
for y in mus_straight.T:
    ax3.scatter(B_xs_straight, y, c='w', s=0.5)
ax3.set_xlim(0, 5)
ax3.set_ylim(*combos[0][-1])

# Set ticks and lims for the wave functions
for ax in [ax1, ax2]:
    ax.set_xlim(0, 1300)
    ax.set_ylim(-400, 400)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    color_spline(ax, 'C5')

# color all the tick labels to black again!
for tick in tick_labels:
    tick.set_color('k')

# Load the precalculated wf data if it exists
fname = 'data/wave_functions_for_phase_diagrams.pickle'
recalculate = not os.path.exists(fname)
if recalculate:
    data = {}
else:
    with open(fname, 'rb') as f:
        data = pickle.load(f)

# Plot the Majorana wave functions in panels (a) straight and (b) zigzag,
# at the optimal field B_opt.
L_x = 3 * default_syst_pars['z_x']
for z_y, ax in zip([0, W/2], [ax1, ax2]):
    syst_pars = dict(default_syst_pars, L_sc_up=800, L_sc_down=800,
                     z_y=z_y, L_x=L_x, wraparound=False, infinite=False,
                     phs_breaking_potential=True)

    syst = zigzag.system(**syst_pars)
    params = dict(default_params, B_x=B_opt, V_breaking=lambda y: 0)

    if recalculate:  # do the calculation if the data isn't there
        E_M, wf = zigzag.majorana_state(syst, params)
        if z_y == 0:
            # Use a new lead to calculate the Majorana size
            lead_pars = dict(syst_pars, L_x=syst_pars['a'], z_x=syst_pars['a'],
                             infinite=True, wraparound=False)
            lead = zigzag.system(**lead_pars)
            xi_M = zigzag.majorana_size_from_modes(lead, params) / 1000  # displayed in um
        else:
            # Use the syst to calculate the Majorana size
            params_decay = dict(params, V_breaking=get_phs_breaking_potential(syst, 100))
            xi_M = zigzag.majorana_size_from_fit(syst, params_decay) / 1000  # displayed in um
        data[z_y] = (E_M, wf, xi_M)
    else:
        E_M, wf, xi_M = data[z_y]

    Deltas = deltas(syst, params)
    kwant.plotter.map(syst, Deltas, ax=ax, show=False, cmap=gray_r)
    kwant.plotter.map(syst, wf, ax=ax, show=False, cmap=gist_heat_r_transparent())

    # Gap at the point (B_opt, mu = 10) marked in the mu panels.
    E_gap = value_from_learner(mapping[(z_y, 'mu')], B_opt, 10)

    ax.text(0.01, 0.02, rf'$E_{{\textrm{{gap}}}} = {E_gap:.2f} \Delta$',
            transform=ax.transAxes,
            verticalalignment='bottom', horizontalalignment='left',
            color='black', fontsize=10.5)

    xi_M_str = f'{xi_M:.1f}' if xi_M < 1 else f'{xi_M:.0f}'
    ax.text(0.01, 0.94, rf'$\xi_M = {xi_M_str}$ $\mu$m',
            transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left',
            color='black', fontsize=10.5)

# Save the wf data, creating the output directory if it does not exist yet.
if recalculate:
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    with open(fname, 'wb') as f:
        pickle.dump(data, f)

plt.savefig("paper/figures/phasediagrams.pdf", bbox_inches="tight")
plt.show()