# Imports

In [None]:
# Load IPython's autoreload extension; the later `%autoreload` call relies on it
# to pick up edits to the project modules without restarting the kernel.
%load_ext autoreload

In [None]:
import os, sys
import holoviews as hv
# Make modules in ../../two_dim_majoranas/ importable. The path is resolved
# relative to the kernel's working directory. hpc05 (and presumably the project
# modules sns_system / plotting_results / distributed_sns) live there.
sys.path.append(os.path.abspath('../../two_dim_majoranas/'))
import hpc05

import ipywidgets as widgets
# Safety latch for the cluster cells: they only run while this box is ticked,
# and each of them unticks it again after running.
run_cluster = widgets.Checkbox(
    value=False,
    description='Check to run cluster',
    disabled=False
)


def assert_cluster_checked():
    """Refuse to continue unless the ``run_cluster`` checkbox is ticked.

    Raises
    ------
    AssertionError
        If ``run_cluster.value`` is not ``True``. The message asks the user
        to tick the box.
    """
    # Explicit raise instead of an `assert` statement: asserts are removed when
    # Python runs with -O / PYTHONOPTIMIZE, which would silently disable this
    # guard and let the cluster commands run unconditionally.
    if run_cluster.value is not True:
        raise AssertionError("Command not run. Check the box above to run.")

In [None]:
import adaptive
# Enable adaptive's Jupyter integration (used later, e.g. by runner.live_info()).
adaptive.notebook_extension()

import numpy as np
import scipy.constants
import cmath

import functools as ft

# Project modules (presumably found via the sys.path entry added earlier).
import sns_system, plotting_results
# ASS is re-bound after `%autoreload` in the "Make ASS" section.
from distributed_sns import AggregatesSimulationSet as ASS
# NOTE(review): SS is not used in the cells shown -- candidate for removal.
from distributed_sns import SimulationSet as SS

### Define constants

In [None]:
# Physical constants in the notebook's unit system: energies in meV, lengths in
# nm, time in s, temperature in K, magnetic field in T.
_meV = scipy.constants.eV * 1e-3  # one meV expressed in joules

constants = dict(
    # Effective mass 0.023 m_e in meV * s^2 / nm^2 (NOT kg), so that
    # hbar^2 k^2 / (2 m_eff) is in meV when k is in 1/nm.
    m_eff=0.023 * scipy.constants.m_e / _meV / 1e18,
    hbar=scipy.constants.hbar / _meV,  # meV * s
    e=scipy.constants.e,  # elementary charge, SI (C)
    # k_B * e / hbar * 1e9, i.e. nA per K; used to express the current in nA.
    current_unit=scipy.constants.k * scipy.constants.e / scipy.constants.hbar * 1e9,
    mu_B=scipy.constants.physical_constants['Bohr magneton'][0] / _meV,  # meV / T
    k=scipy.constants.k / _meV,  # Boltzmann constant, meV / K
    # Complex-valued math functions, presumably for the system's parameter
    # expressions -- TODO confirm in sns_system.
    exp=cmath.exp,
    cos=cmath.cos,
    sin=cmath.sin,
)

# Cluster setup

In [None]:
# Display the guard checkbox. The cluster cells below only run while it is
# ticked, and each of them unticks it again after running.
run_cluster

In [None]:
# Guarded: raises AssertionError unless the run_cluster checkbox is ticked.
assert_cluster_checked()
# Presumably shuts down the remote ipcluster (hpc05 helper).
hpc05.kill_remote_ipcluster()
# Untick the checkbox so no guarded cell is re-run by accident.
run_cluster.value = False

In [None]:
# Guarded: raises AssertionError unless the run_cluster checkbox is ticked.
assert_cluster_checked()
# Start a remote cluster and connect. `client` is later used as the executor of
# the adaptive Runner. NOTE(review): 300 is presumably the number of engines and
# dview/lview presumably direct / load-balanced views -- confirm against hpc05.
client, dview, lview = hpc05.start_remote_and_connect(300, folder='~/two_dim_majoranas', timeout=180)
# Untick the checkbox so the cluster is not started twice by accident.
run_cluster.value = False

# Define and plot the system

In [None]:
# System geometry. Ll/Lr/Lm are presumably the lengths of the left, right and
# middle regions, Ly the width and a the lattice constant (nm) -- TODO confirm
# in sns_system. The chemical potential is measured from the bottom of the
# spin-orbit-split bands.
syst_pars = dict(
    Ll=500,
    Lr=500,
    Lm=500,
    Ly=12.5,
    a=12.5,
    mu_from_bottom_of_spin_orbit_bands=True,
)

_ = plotting_results.plot_syst(syst_pars, sns_system.dummy_params)

# Define standard parameters

In [None]:
# Baseline model parameters. The alpha_*, T and mu entries are overridden per
# point by the sweep dimensions added later.
params_raw = {
    'g_factor_middle': 10,
    'g_factor_left': 0,
    'g_factor_right': 0,
    'mu': 10.0,
    'alpha_middle': 28,
    'alpha_left': 28,
    'alpha_right': 28,
    'Delta_left': .18,
    'Delta_right': .18,
    'B': 0.5,
    'phase': np.pi / 2,
    'T': 0.025,
}

# Merge physical constants and model parameters into one dict. dict(**a, **b)
# raises TypeError if both define the same key, so a parameter can never
# silently shadow a constant (or vice versa).
params = dict(**constants,
              **params_raw)

### Define keys to be varied

In [None]:
# Parameters varied by each learner, mapped to their [lower, upper] bounds.
keys_with_bounds = dict(
    phase=[0, 2 * np.pi],
    B=[0, 1.2],
)

### Define metrics to be recorded

In [None]:
# Quantities recorded at every point, each with its own options.
metric_params_dict = dict(current=dict(tol=0.1))

### Define data folder

In [None]:
# Folder used by ass.load, by the periodic saver and for the saved PNG figure.
data_folder = 'low_spin_orbit_current'

# Create the AggregatesSimulationSet (ASS)

In [None]:
%autoreload
ASS= distributed_sns.AggregatesSimulationSet

In [None]:
# Build the simulation set from the scan keys, the system geometry, the
# physical parameters and the recorded metrics (positional, in this order).
ass = ASS(keys_with_bounds, syst_pars, params, metric_params_dict)

### Add extra sweep dimensions (alpha, temperature, mu)

In [None]:
# Values of the "alpha" sweep dimension. Each replaces alpha_middle/left/right
# (baseline 28). Units presumably meV*nm, as used by kfso() -- TODO confirm.
alphas = [0.1, 1, 5, 10]

In [None]:
def _params_alpha(syst_pars, params, alpha):
    for key in ['alpha_middle', 'alpha_left', 'alpha_right']:
        if key in params:
            params[key] = alpha
        if key in syst_pars:
            syst_pars[key] = alpha
    return alpha

# One setter per alpha value, with the value frozen in by functools.partial.
params_alpha = [
    ft.partial(_params_alpha, alpha=value)
    for value in alphas
]

In [None]:
# Register "alpha" as an extra dimension of the simulation set, one setter per value.
ass.add_dimension("alpha", params_alpha)

In [None]:
# Values of the "temperature" sweep dimension. Each replaces params['T']
# (baseline 0.025).
# NOTE(review): includes T = 0.0 -- confirm the current calculation supports
# zero temperature.
temperatures = [0.0, 0.1, 0.2, 0.3]

In [None]:
def _params_temperature(syst_pars, params, temperature):
    for key in ['T']:
        if key in params:
            params[key] = temperature
        if key in syst_pars:
            syst_pars[key] = temperature
    return temperature

# One setter per temperature, with the value frozen in by functools.partial.
params_temperature = [
    ft.partial(_params_temperature, temperature=t)
    for t in temperatures
]

In [None]:
# Register "temperature" as an extra dimension, one setter per value.
ass.add_dimension("temperature", params_temperature)

In [None]:
# Values of the "mu" sweep dimension. Each replaces params['mu'] (baseline 10.0).
mus = [0.1, 1, 10, 20]

In [None]:
def _params_mu(syst_pars, params, mu):
    for key in ['mu']:
        if key in params:
            params[key] = mu
        if key in syst_pars:
            syst_pars[key] = mu
    return mu

# One setter per chemical potential, with the value frozen in by functools.partial.
params_mu = [
    ft.partial(_params_mu, mu=chemical_potential)
    for chemical_potential in mus
]

In [None]:
# Register "mu" as an extra dimension, one setter per value.
ass.add_dimension("mu", params_mu)

### Make learner

In [None]:
# Build the balancing learner over all sub-learners, then load previously saved
# data from data_folder.
# NOTE(review): the meaning of the 1000 argument is defined by
# AggregatesSimulationSet (not visible here). The same value is passed to both
# calls and presumably must match -- TODO confirm.
ass.make_balancing_learner(1000)
ass.load(data_folder, 1000)

### Make runner with saver

In [None]:
# Run the balancing learner using the cluster `client` (from the cluster-setup
# cell) as executor.
runner = adaptive.Runner(ass.get_balancing_learner(), executor=client)
# Save to data_folder periodically (interval presumably in seconds).
ass.start_periodic_saver(runner, data_folder, interval=180)
# Show adaptive's live progress display.
runner.live_info()

In [None]:
# Stop the runner started in the previous cell.
runner.cancel()

# Plot

In [None]:
# NOTE(review): clears a *private* attribute (`_points`) of the balancing learner
# in place. Its meaning depends on adaptive's internals (presumably a cache of
# suggested points), and the plotting cells below do not reference it. Confirm
# intent before running, and consider removing this cell.
ass.learner._points.clear()

In [None]:
# (number of evaluated points, current loss) for every sub-learner
[(learner.npoints, learner.loss()) for learner in ass.learner.learners]



In [None]:
# Use matplotlib as the HoloViews backend (the figure-saving cell below
# requests fig='png').
hv.extension('matplotlib')

In [None]:
def kf_normal(mu):
    """Fermi momentum sqrt(2 m_eff mu) / hbar, with constants from global ``params``."""
    return np.sqrt(2*params['m_eff']*mu)/params['hbar']


def kfso(alpha):
    """Spin-orbit momentum m_eff alpha / hbar**2, with constants from global ``params``."""
    return params['m_eff']*alpha/params['hbar']**2


def plot_fermi_surface(mu, alpha, **_):
    """Overlay of two circles of radius kf_normal(mu) -/+ kfso(alpha).

    The result is cropped to the square [-r, r] x [-r, r], where
    r = kf_normal(mu) + kfso(alpha). Extra keyword arguments are ignored so
    the function can be called with a full dict of dimension values.
    """
    angles = np.linspace(0, 2*np.pi)
    k_f, k_so = kf_normal(mu), kfso(alpha)
    inner, outer = k_f - k_so, k_f + k_so
    circles = [hv.Path((radius*np.cos(angles), radius*np.sin(angles)))
               for radius in (inner, outer)]
    return (circles[0] * circles[1])[-outer:outer, -outer:outer]

In [None]:
# For every combination of the extra dimensions, build a layout of: the current
# map with the minimal-value path overlaid, the max-min path, the Fermi surface,
# and the two further panels returned by get_plot_dict(tables=True).
N_POINTS = 100
kdims, plot_dict = ass.get_plot_dict(N_POINTS, tables=True)
pdict = {key: layout[:, :] for key, layout in plot_dict.items()}

plot_dict_min_curr = {}
plot_dict_Ic = {}
total_dict = {}
for key, layout in pdict.items():
    entries = layout.items()
    image = entries[0][1]
    left, bottom, right, top = image.lbrt
    xs = np.linspace(left, right, N_POINTS)
    ys = np.linspace(bottom, top, N_POINTS)

    # Per row of the image: x position of the minimum, and the max - min spread.
    min_path = hv.Path((xs[np.argmin(image.data, axis=1)], ys))
    ic_path = hv.Path((np.max(image.data, axis=1) - np.min(image.data, axis=1), ys))
    plot_dict_min_curr[key] = min_path
    plot_dict_Ic[key] = ic_path

    total_dict[key] = (image * min_path
                       + ic_path[:ic_path.range('x')[1], :]
                       + plot_fermi_surface(**dict(zip(kdims, key)))
                       + entries[1][1]
                       + entries[2][1])

    


In [None]:
%%opts Image {+framewise +axiswise} [colorbar=True aspect=1]
%%opts Path (color='red') {+framewise +axiswise} [aspect=1] 
hv.HoloMap(total_dict, kdims=kdims)

In [None]:
%%opts Image {+framewise +axiswise} [colorbar=True aspect=1]
%%opts Path (color='red') {+framewise +axiswise} [aspect=1] 
%%output filename=f"./{data_folder}/spin_orbit_bottom_of_band" fig='png'
N_POINTS=100
kdims, plot_dict = ass.get_plot_dict(N_POINTS)
pdict = dict((k,v[:,:]) for k,v in plot_dict.items())

plot_dict_min_curr={}
plot_dict_Ic={}
for k,v in pdict.items():
    x0, y0, x1, y1 = r.lbrt
    xdim=np.linspace(x0, x1, N_POINTS)
    ydim=np.linspace(y0, y1, N_POINTS)
    plot_dict_min_curr[k] = hv.Path((xdim[np.argmin(v.data, axis=1)], ydim))
    plot_dict_Ic[k] = hv.Path((np.max(v.data, axis=1) - np.min(v.data, axis=1),ydim))
(hv.HoloMap(pdict, kdims=kdims) * 
 hv.HoloMap(plot_dict_min_curr, kdims=kdims) + 
 hv.HoloMap({k:v[:v.range('x')[1],:] for k,v in plot_dict_Ic.items()}, kdims=kdims) +
 hv.HoloMap({k:plot_fermi_surface(**dict(zip(kdims, k))) for k,v in plot_dict_Ic.items()}, kdims=kdims)
)

In [None]:
# NOTE(review): this cell re-binds kf_normal, kfso and plot_fermi_surface with
# definitions behaviorally identical to those in an earlier "Plot" cell. Keep
# the two in sync, or remove this copy.
# Fermi momentum sqrt(2 m_eff mu) / hbar, with constants from global `params`.
kf_normal = lambda mu: np.sqrt(2*params['m_eff']*mu)/params['hbar']
# Spin-orbit momentum m_eff alpha / hbar**2, with constants from global `params`.
kfso = lambda alpha: params['m_eff']*alpha/params['hbar']**2
def plot_fermi_surface(mu, alpha, **_):
    """Overlay of circles of radius kf_normal(mu) -/+ kfso(alpha), cropped to [-r1, r1]^2.

    Extra keyword arguments are ignored.
    """
    theta = np.linspace(0, 2*np.pi)
    r0 = kf_normal(mu) - kfso(alpha)
    r1 = kf_normal(mu) + kfso(alpha)
    x0, y0 = (r0*np.cos(theta), r0*np.sin(theta))
    x1, y1 = (r1*np.cos(theta), r1*np.sin(theta))
    return (hv.Path((x0,y0))*hv.Path((x1,y1)))[-r1:r1,-r1:r1]

In [None]:
def dispersion(kx, ky, mu, alpha, B):
    """Upper and lower spin-orbit-split bands at momentum (kx, ky).

    The material constants (``m_eff``, ``hbar``, ``mu_B``) and the g-factor
    ``g_factor_middle`` are read from the global ``params`` dict. The kinetic
    term includes the constant shift ``m_eff * alpha**2 / (2 * hbar**2)``, so
    at B = 0 the band bottom sits at ``-mu`` (mu is measured from the bottom of
    the spin-orbit bands).

    Returns
    -------
    tuple
        ``(E_kin + E_so, E_kin - E_so)`` with
        ``E_so = sqrt(alpha**2 ky**2 + (alpha kx - E_z)**2)`` and
        ``E_z = g_factor_middle * mu_B * B``.
    """
    mass = params['m_eff']
    hbar = params['hbar']
    zeeman = params['g_factor_middle'] * params['mu_B'] * B

    kinetic = hbar**2/(2*mass) * (kx**2 + ky**2) - mu + mass*alpha**2/(2*hbar**2)
    spin_orbit = np.sqrt(alpha**2*ky**2 + (alpha*kx - zeeman)**2)

    upper = kinetic + spin_orbit
    lower = kinetic - spin_orbit
    return upper, lower