# Imports

In [None]:
%load_ext autoreload
import os, sys
sys.path.append(os.path.abspath('..'))

import warnings
from functools import partial

import numpy as np
import cmath
import scipy.constants

import numbers

import kwant

import holoviews as hv
hv.notebook_extension('bokeh')

import sns_system
import spectrum
import plotting_results

## Constants

In [None]:
# Physical constants in the unit system used by this notebook: energies in
# meV, and (via m_eff below) lengths in nm. They are merged into `params`
# together with the model parameters.
constants = dict(
    # 0.023 * m_e converted from kg to meV * s**2 / nm**2:
    # kg -> meV s^2/m^2 by dividing by (1 meV in J), then m^-2 -> nm^-2 via
    # 1e-18, so that hbar**2 / (2 * m_eff) comes out in meV * nm**2.
    m_eff=0.023 * scipy.constants.m_e / (scipy.constants.eV * 1e-3) / 1e18,
    # Reduced Planck constant in meV * s.
    hbar=scipy.constants.hbar / (scipy.constants.eV * 1e-3),
    # Elementary charge in coulombs (SI).
    e=scipy.constants.e,
    # k_B * e / hbar in nA per kelvin (A/K scaled by 1e9).
    current_unit=scipy.constants.k * scipy.constants.e / scipy.constants.hbar * 1e9,
    # Bohr magneton in meV / T.
    mu_B=scipy.constants.physical_constants['Bohr magneton'][0] / (scipy.constants.eV * 1e-3),
    # Boltzmann constant in meV / K.
    k=scipy.constants.k / (scipy.constants.eV * 1e-3),
    # Complex-valued math functions. NOTE(review): presumably referenced by
    # the symbolic Hamiltonian in sns_system -- confirm.
    exp=cmath.exp,
    cos=cmath.cos,
    sin=cmath.sin,
)

# Create system

In [None]:
# Geometry of the system, passed to sns_system.make_wrapped_system.
syst_pars = dict(
    Ll=1000,
    Lr=1000,
    Lm=500,
    Ly=10,
    a=10,
    mu_from_bottom_of_spin_orbit_bands=True,
)

# Model parameters (left lead / middle region / right lead).
params_raw = {
    'g_factor_middle': 10,
    'g_factor_left': 0,
    'g_factor_right': 0,
    'mu': 1.0,
    'alpha_middle': 28,
    'alpha_left': 28,
    'alpha_right': 28,
    'Delta_left': .18,
    'Delta_right': .18,
    'B': 0.5,
    'phase': np.pi / 2,
    'T': 0.025,
    'V': 0,
}

# Keyword-unpacking both dicts raises TypeError on any duplicated key.
params = dict(**constants, **params_raw)

syst = sns_system.make_wrapped_system(**syst_pars)

_ = plotting_results.plot_syst(syst_pars, sns_system.dummy_params)

# Function definitions

### Vector factory

In [None]:
def ensure_rng(rng=None):
    """Return a random number generator derived from *rng*.

    * ``None``: numpy's global ``RandomState`` (``np.random.mtrand._rand``).
    * An integer: a new ``np.random.RandomState`` seeded with it.
    * An object exposing ``random_sample``, ``randn``, ``randint`` and
      ``choice``: returned unchanged.

    Anything else raises ``ValueError``.
    """
    if rng is None:
        return np.random.mtrand._rand
    if isinstance(rng, numbers.Integral):
        return np.random.RandomState(rng)
    required = ('random_sample', 'randn', 'randint', 'choice')
    if not all(hasattr(rng, name) for name in required):
        raise ValueError("Expecting a seed or an object that offers the "
                         "numpy.random.RandomState interface.")
    return rng
    
def make_local_factory(site_indices=None, rng=0):
    """Return a `vector_factory` that outputs local vectors.

    The returned callable ``vector_factory(n)`` behaves as follows:

    * First call: a length-``n`` vector of random phases
      ``exp(2j * pi * r)`` with ``r`` drawn from `rng`.
    * Each later call: a unit vector ``e_j`` (complex, length ``n``), with
      ``j`` taken in order from ``site_indices[0]``, ``site_indices[1]``, ...

    Parameters
    ----------
    site_indices : sequence of int, optional
        Vector indices at which the local vectors are non-zero. If None, all
        indices ``0 .. n-1`` are used; ``n`` is fixed on the first call.
    rng : None, int or RandomState-like
        Passed to `ensure_rng`. It is used only for the random first vector,
        which -- per the original intent -- is used to find the bounds of the
        spectrum; a fixed seed makes that reproducible.

    Raises
    ------
    IndexError
        If more local vectors are requested than there are indices.
    """
    rng = ensure_rng(rng)
    idx = -1

    def vector_factory(n):
        nonlocal idx, site_indices

        if site_indices is None:
            site_indices = np.arange(n)
        if idx == -1:
            idx += 1
            return np.exp(rng.rand(n) * 2j * np.pi)
        if idx >= len(site_indices):
            raise IndexError(
                f"local vector factory exhausted: {len(site_indices)} "
                f"site indices, vector #{idx + 1} requested")
        vec = np.zeros(n, dtype=complex)
        vec[site_indices[idx]] = 1
        idx += 1
        return vec

    return vector_factory
        
def make_ev_factory(eigenvecs, rng=0):
    """Return a `vector_factory` that outputs the columns of `eigenvecs`.

    The returned callable ``vector_factory(n)`` behaves as follows:

    * First call: a length-``n`` vector of random phases
      ``exp(2j * pi * r)`` with ``r`` drawn from `rng` (the same first
      vector `make_local_factory` produces for the same seed).
    * Each later call: ``eigenvecs[:, 0]``, ``eigenvecs[:, 1]``, ... in
      order. Nothing is projected out; the columns are returned as-is. For a
      NumPy array these are views into `eigenvecs`, not copies.

    Parameters
    ----------
    eigenvecs : 2D array
        Vectors stored as columns; indexing past the last column raises
        IndexError.
    rng : None, int or RandomState-like
        Passed to `ensure_rng`. It is used only for the random first vector,
        which -- per the original intent -- is used to find the bounds of the
        spectrum; a fixed seed makes that reproducible.
    """
    rng = ensure_rng(rng)
    idx = -1

    def vector_factory(n):
        nonlocal idx

        if idx == -1:
            idx += 1
            return np.exp(rng.rand(n) * 2j * np.pi)
        vec = eigenvecs[:, idx]
        idx += 1
        return vec

    return vector_factory

### Function to get the sites and orbital indices of a cut

In [None]:
def get_cut_sites_and_indices(syst, cut_tag, direction, norbs=4):
    """Find the site pairs crossing a cut and the orbital indices of those sites.

    The cut lies between sites with ``site.tag[direction] == cut_tag`` (left)
    and ``site.tag[direction] == cut_tag + 1`` (right).

    Parameters
    ----------
    syst : finalized kwant system
        Only ``syst.sites`` (each with a ``tag``) is accessed.
    cut_tag : int
        Tag coordinate of the left side of the cut.
    direction : int
        Which tag component to compare.
    norbs : int, optional
        Orbitals per site. The orbitals of site ``i`` are assumed to occupy
        indices ``norbs*i .. norbs*i + norbs - 1``. The default of 4 matches
        the 4x4 ``sigz`` used in this notebook.

    Returns
    -------
    cut_indices : numpy.ndarray of int
        Orbital indices of every left or right cut site, in the order the
        sites appear in ``syst.sites``. Empty if no site lies on the cut.
    cut_sites : list of (site, site)
        Left and right sites paired in enumeration order. NOTE(review): this
        assumes both sides list their sites in matching order and with equal
        counts (``zip`` silently truncates otherwise).
    """
    l_cut = []
    r_cut = []
    cut_indices = []

    for site_idx, site in enumerate(syst.sites):
        tag = site.tag[direction]
        if tag == cut_tag:
            l_cut.append(site)
        elif tag == cut_tag + 1:
            r_cut.append(site)
        else:
            continue
        cut_indices.extend(range(norbs * site_idx, norbs * (site_idx + 1)))

    # np.array (unlike np.hstack) also handles an empty cut.
    cut_indices = np.array(cut_indices, dtype=int)
    cut_sites = list(zip(l_cut, r_cut))

    return (cut_indices, cut_sites)


### Fermi–Dirac function

In [None]:
# Reuse the distribution from the project-local `supercurrent` module. In this
# notebook it is called as `fermi_dirac(energy, params)`.
import supercurrent
fermi_dirac = supercurrent.fermi_dirac

### Projected and non-projected current operator

In [None]:
# Diagonal 4x4 matrix diag(1, 1, -1, -1), used as the onsite matrix of the
# current operators below. NOTE(review): presumably tau_z in a
# (particle, particle, hole, hole) orbital ordering -- confirm against
# sns_system. tinyarray is reached here through kwant's internal module path.
sigz = kwant.continuum.discretizer.ta.array([[1,0,0,0],
                                             [0,1,0,0],
                                             [0,0,-1,0],
                                             [0,0,0,-1]])

def make_projected_current(syst, params, eigvecs, cut=None):
    """Return a current operator with the span of `eigvecs` removed from the ket.

    The returned callable ``f(bra, ket)`` evaluates ``<bra| C (1 - P) |ket>``.
    Here ``C`` is ``kwant.operator.Current(syst, sigz, where=cut)`` bound to
    `params`, and ``P = eigvecs @ eigvecs^dagger``. ``P`` is a projector onto
    the Andreev vectors only if the columns of `eigvecs` are orthonormal; this
    is assumed, not checked.
    """
    bound_current = kwant.operator.Current(syst, sigz, where=cut).bind(params=params)

    def projected_current(bra, ket):
        # Components of `ket` along each column of `eigvecs`.
        overlaps = eigvecs.conj().T @ ket
        ket_perp = ket - eigvecs @ overlaps
        return bound_current(bra, ket_perp)

    return projected_current


def make_exact_current(syst, params, cut=None):
    """Return kwant's current operator (onsite matrix `sigz`) over `cut`, bound to `params`."""
    return kwant.operator.Current(syst, sigz, where=cut).bind(params=params)



## Current exact

In [None]:
def current_exact(syst_pars, params, cut_tag=0, direction=0):
    """Current through a cut from full dense diagonalization (reference result).

    Builds the system, diagonalizes the dense Hamiltonian and accumulates
    ``sum_n f(E_n) * J(psi_n)``. Here ``f`` is `fermi_dirac` and ``J`` is the
    kwant current operator with onsite matrix `sigz`, restricted to the
    site pairs of the cut.

    Parameters
    ----------
    syst_pars : dict
        Keyword arguments for ``sns_system.make_wrapped_system``.
    params : dict
        System parameters; must contain ``'e'`` and ``'hbar'``.
    cut_tag, direction : int
        Location of the cut (see `get_cut_sites_and_indices`).

    Returns
    -------
    The current summed over the hoppings of the cut, times
    ``params['e'] / params['hbar']``.
    """
    # Explicit import: the notebook header only imports scipy.constants.
    import scipy.linalg

    syst = sns_system.make_wrapped_system(**syst_pars)

    _, cut_sites = get_cut_sites_and_indices(syst, cut_tag, direction)

    current_operator = kwant.operator.Current(syst,
                                              onsite=sigz,
                                              where=cut_sites).bind(params=params)

    ham = syst.hamiltonian_submatrix(params=params)
    en, evs = scipy.linalg.eigh(ham)

    # Accumulates an array with one entry per hopping in the cut.
    I = 0
    for e, ev in zip(en, evs.T):
        I += fermi_dirac(e.real, params) * current_operator(ev)
    return sum(I)*params['e']/params['hbar']

## Calculate current using projected KPM

In [None]:
def current_kpm_projected(syst_pars, params,
                          k, energy_resolution,
                          cut_tag=0, direction=0):
    """Current through a cut: `k` exact eigenstates plus a projected KPM remainder.

    The `k` eigenpairs from ``spectrum.sparse_diag(ham, k=k, sigma=0)`` are
    treated exactly. All other states are accounted for by a KPM spectral
    density whose current operator removes the span of those eigenvectors
    from the ket (`make_projected_current`), so they are not counted twice.

    Parameters
    ----------
    syst_pars : dict
        Keyword arguments for ``sns_system.make_wrapped_system``.
    params : dict
        System parameters. The dict is not modified: ``sns_system.constants``
        are merged into a copy, and the constants take precedence.
    k : int
        Number of eigenpairs computed exactly.
    energy_resolution : float
        Passed to ``SpectralDensity.add_moments``.
    cut_tag, direction : int
        Location of the cut (see `get_cut_sites_and_indices`).

    Returns
    -------
    The current summed over the hoppings of the cut, times
    ``params['e'] / params['hbar']``.

    Warns
    -----
    Warning
        If the largest exactly computed eigenvalue is below
        ``params['Delta_left']``.
    """
    I = 0

    # 1. Make system. Merge the constants into a copy so the caller's dict
    #    is left untouched.
    params = dict(params, **sns_system.constants)
    syst = sns_system.make_wrapped_system(**syst_pars)

    #  i. Fermi-Dirac distribution with the parameters bound.
    _fermi_dirac = partial(fermi_dirac, params=params)

    # 2. Make the cut (site pairs and orbital indices).
    cut_indices, cut_sites = get_cut_sites_and_indices(syst, cut_tag, direction)

    # 3. Exact spectrum for k eigenvalues/vectors.
    ham = syst.hamiltonian_submatrix(params=params, sparse=True)
    en, evs = spectrum.sparse_diag(ham, k=k, sigma=0)
    # np.real also tolerates a complex-typed eigenvalue array.
    max_en = np.max(np.real(en))
    if max_en < params['Delta_left']:
        warnings.warn(
            f"max(en) = {max_en:.4g} < params['Delta_left'] = "
            f"{params['Delta_left']:.4g}", Warning)

    # 4. Current from the exact eigenstates.
    exact_current_operator = make_exact_current(syst, params, cut=cut_sites)
    for e, ev in zip(en, evs.T):
        I += _fermi_dirac(e.real) * exact_current_operator(ev)

    # 5. Current from the KPM part, with the exact eigenvectors projected out.
    kpm_current_operator = make_projected_current(syst, params, evs, cut=cut_sites)

    # One local vector per orbital of the cut sites.
    factory = make_local_factory(site_indices=cut_indices)

    sd = kwant.kpm.SpectralDensity(syst,
                                   params=params,
                                   operator=kpm_current_operator,
                                   num_vectors=len(cut_indices),
                                   num_moments=2,
                                   vector_factory=factory)

    # Add moments up to the requested resolution.
    sd.add_moments(energy_resolution=energy_resolution)

    # Presumably kwant averages over the vectors, so multiplying by their
    # number gives the sum over all local vectors.
    I += sd.integrate(distribution_function=_fermi_dirac) * len(cut_indices)

    return params['e'] / params['hbar'] * sum(I)

## Calculate current using non-projected KPM

In [None]:
def _kpm_current_trace(syst, params, operator, vector_factory, num_vectors,
                       energy_resolution, distribution_function):
    """Integrate a KPM current spectral density over `distribution_function`.

    The result is multiplied by `num_vectors`. Presumably kwant averages over
    the vectors, so the product is the sum over all vectors.
    """
    sd = kwant.kpm.SpectralDensity(syst,
                                   params=params,
                                   operator=operator,
                                   num_vectors=num_vectors,
                                   num_moments=2,
                                   vector_factory=vector_factory)
    sd.add_moments(energy_resolution=energy_resolution)
    return sd.integrate(distribution_function=distribution_function) * num_vectors


def current_kpm_non_projected(syst_pars, params,
                              k, energy_resolution,
                              cut_tag=0, direction=0):
    """Current through a cut: unprojected KPM with the eigenstate part swapped for exact.

    Computes ``I_AB_exact + (I_all_kpm - I_AB_kpm)``:

    * ``I_AB_exact``: the current of the eigenpairs from
      ``spectrum.sparse_diag(ham, k=k, sigma=0)``, evaluated exactly.
    * ``I_all_kpm``: KPM with local vectors on every orbital of the cut.
    * ``I_AB_kpm``: KPM with those eigenvectors as the vectors.

    The KPM estimate of the eigenstate contribution is thus replaced by the
    exact one.

    Parameters
    ----------
    syst_pars : dict
        Keyword arguments for ``sns_system.make_wrapped_system``.
    params : dict
        System parameters. The dict is not modified: ``sns_system.constants``
        are merged into a copy, and the constants take precedence.
    k : int
        Number of eigenpairs requested from ``spectrum.sparse_diag``.
    energy_resolution : float
        Passed to ``SpectralDensity.add_moments``.
    cut_tag, direction : int
        Location of the cut (see `get_cut_sites_and_indices`).

    Returns
    -------
    The current summed over the hoppings of the cut, times
    ``params['e'] / params['hbar']``.

    Warns
    -----
    Warning
        If the largest exactly computed eigenvalue is below
        ``params['Delta_left']``.
    """
    # 1. Make system. Merge the constants into a copy so the caller's dict
    #    is left untouched.
    params = dict(params, **sns_system.constants)
    syst = sns_system.make_wrapped_system(**syst_pars)

    #  i. Fermi-Dirac distribution with the parameters bound.
    _fermi_dirac = partial(fermi_dirac, params=params)

    # 2. Make the cut (site pairs and orbital indices).
    cut_indices, cut_sites = get_cut_sites_and_indices(syst, cut_tag, direction)

    # 3. Exact spectrum for k eigenvalues/vectors.
    ham = syst.hamiltonian_submatrix(params=params, sparse=True)
    en, evs = spectrum.sparse_diag(ham, k=k, sigma=0)
    # np.real also tolerates a complex-typed eigenvalue array.
    max_en = np.max(np.real(en))
    if max_en < params['Delta_left']:
        warnings.warn(
            f"max(en) = {max_en:.4g} < params['Delta_left'] = "
            f"{params['Delta_left']:.4g}", Warning)

    # One (unprojected) current operator serves the exact part and both
    # KPM evaluations.
    current_operator = make_exact_current(syst, params, cut=cut_sites)

    # 4. Current from the exact eigenstates.
    I_AB_exact = 0
    for e, ev in zip(en, evs.T):
        I_AB_exact += _fermi_dirac(e.real) * current_operator(ev)

    # 5. KPM over all states: one local vector per orbital of the cut sites.
    I_all_kpm = _kpm_current_trace(
        syst, params, current_operator,
        make_local_factory(site_indices=cut_indices),
        len(cut_indices), energy_resolution, _fermi_dirac)

    # 6. KPM restricted to the exact eigenvectors (one vector per column).
    num_ev = evs.shape[1]
    I_AB_kpm = _kpm_current_trace(
        syst, params, current_operator,
        make_ev_factory(evs),
        num_ev, energy_resolution, _fermi_dirac)

    return params['e'] / params['hbar'] * sum(I_AB_exact + (I_all_kpm - I_AB_kpm))

# Calculations

In [None]:
%%time
# Projected-KPM current at k_y = 0 with k = 20 exact eigenpairs and an energy
# resolution of Delta_left / 2.
current_kpm_projected(syst_pars=syst_pars,
                      params=dict(**params, k_y=0),
                      k=20,
                      energy_resolution=params['Delta_left']/2)

In [None]:
%%time
# Non-projected KPM current with the same settings as the projected run above,
# for comparison.
current_kpm_non_projected(syst_pars=syst_pars,
                          params=dict(**params, k_y=0),
                          k=20, 
                          energy_resolution=params['Delta_left']/2)

In [None]:
%%time
# Reference current at k_y = 0 from full dense diagonalization.
current_exact(syst_pars=syst_pars,
              params=dict(**params, k_y=0))

In [None]:
# Explicit import: the notebook header only imports scipy.constants.
import scipy.linalg

# Dense eigendecompositions at k_y = 0 (`en`, `evs`) and k_y = 2
# (`en_k`, `evs_k`); later cells compare the two eigenbases.
p = dict(**params, k_y=0)
params_k = dict(**params, k_y=2)

ham = syst.hamiltonian_submatrix(params=p)
(en, evs) = scipy.linalg.eigh(ham)
ham_k = syst.hamiltonian_submatrix(params=params_k)
(en_k, evs_k) = scipy.linalg.eigh(ham_k)

In [None]:
# Overlap matrix <evs_k[:, i] | evs[:, j]> between the k_y = 2 and k_y = 0
# eigenbases. It is stored in `_`, which shadows IPython's last-output variable.
_ = evs_k.conj().T @ evs

In [None]:
# Display the overlap matrix that the previous cell assigned to `_`.
_

In [None]:
# %%timeit
ham = syst.hamiltonian_submatrix(params=params_k, sparse=True)
e = (evs.T.conjugate()@ham@evs).diagonal()
# ham_evs = evs.T.conjugate()@ham@evs
# index = np.argmax(ham_evs, axis=0)
# e = ham_evs[index]/evs[index]

# _ = e.diagonal()

$E = A B C $

$(ABC)_{i,j} = \sum_{k} A_{i,k} (BC)_{k,j}$

$(BC)_{k,j} = \sum_{l} B_{k,l} C_{l,j}$

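$(ABC)_{i,j} = \sum_{k,l} A_{i,k} B_{k,l} C_{l,j}$

Only the diagonal is needed: $(ABC)_{j,j} = \sum_{k,l} A_{j,k} B_{k,l} C_{l,j}$. With $A = V^\dagger$ and $C = V$, this is $\sum_{k} \overline{V_{k,j}} \, (BV)_{k,j}$, so it can be evaluated without forming the full product $ABC$.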

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Difference between the sorted expectation values `e` (k_y = 2 Hamiltonian
# in the k_y = 0 eigenbasis) and the sorted k_y = 0 eigenvalues `en`.
plt.plot(np.sort(e.real)-np.sort(en))
# plt.plot(en)