# In [None]:
import numpy as np
import itertools
from scipy.signal import chirp, find_peaks, peak_widths
import matplotlib.pyplot as plt
from operator import itemgetter
from time import time
from tqdm import tqdm
from multiprocessing.pool import ThreadPool as Pool
import sys
import scipy.linalg
import scipy.sparse.linalg


#plt.style.use('bmh')
# Module-level plotting defaults applied to every figure drawn below.
# Marker shapes and the colours of the current default property cycle; neither
# name is referenced in the code shown here (presumably used by other cells).
markers = ["o", "X", "P", "p", "*"]
cols = [p['color'] for p in plt.rcParams['axes.prop_cycle']]
# LaTeX preamble (amsmath, braket, nicefrac). Matplotlib only uses this when
# 'text.usetex' is on, and that entry is commented out in the update below.
plt.rc('text.latex', preamble=r'\usepackage{amsmath}\usepackage{braket}\usepackage{nicefrac}')
# Presentation-style defaults: large fonts, thick axes/lines/markers, no grid,
# automatic layout, and a semi-transparent black-edged legend with tight spacing.
plt.rcParams.update({'font.size': 30,
                     'figure.figsize': (9,7),
                     'axes.facecolor': 'white',
                     'axes.edgecolor': 'lightgray',
                     "figure.autolayout": 'True',
                     'axes.xmargin': 0.03,
                     'axes.ymargin': 0.05,
                     'axes.grid': False,
                     'axes.linewidth': 5,
                     'lines.markersize': 15,
                     #'text.usetex': True,
                     'lines.linewidth': 8,
                     "legend.frameon": True,
                     "legend.framealpha": 0.7,
                     "legend.handletextpad": 1,
                     "legend.edgecolor": "black",
                     "legend.handlelength": 1,
                     "legend.labelspacing": 0,
                     "legend.columnspacing": 1,
                     "legend.fontsize": 35,
                    })
# Extra line styles and a rounded light-grey text-box style; not referenced
# in the code shown here.
linestyles = ["-", "--", ":"]
bbox = dict(boxstyle="round", facecolor="lightgray")


def B_RG_flow(J0,B0,D0,d=0.01,plot = False):
    """Integrate the discrete RG flow of the coupling J and field B.

    Starting from (J0, B0) at cutoff D0, the cutoff is lowered in steps of ``d``
    while it is still larger than ``d``. At each step, with
    ``s = -D0/2 - D/2 + J/4`` and ``den(J, B, D) = s**2 - (B/2)**2``:

        J' = J - J**2 * (s / den(J, B, D)) * d
        B' = B - (J'**2 / 4) * (B / den(J', B, D)) * d      # note: uses the updated J'
        D' = D - d

    The flow stops (and the offending step is discarded) as soon as
    den(J', B', D') * den(J, B, D) <= 0, J' * J <= 0, or B' * B < 0.

    Args:
        J0, B0, D0: initial coupling, field and cutoff.
        d: cutoff step size.
        plot: if True, plot J and B against D/D0, save both figures as PDFs and show them.

    Returns:
        (D, J, B) as numpy arrays ordered from the last (smallest) cutoff back to D0.
    """

    def gap(j, b, cutoff):
        # Shared shift s and denominator den of the RG equations.
        shift = -(D0/2) - (cutoff/2) + (j/4)
        return shift, shift**2 - (b/2)**2

    J = [J0]
    B = [B0]
    D = [D0]
    j, b, cutoff = J0, B0, D0

    while cutoff > d:
        shift, den_here = gap(j, b, cutoff)
        j_next = j - j**2 * (shift / den_here) * d
        _, den_mid = gap(j_next, b, cutoff)
        b_next = b - ((j_next**2) / 4) * (b / den_mid) * d
        cutoff_next = cutoff - d

        # Stop before recording a step that crosses a pole or flips the sign of J or B.
        _, den_next = gap(j_next, b_next, cutoff_next)
        crossed_pole = den_next * den_here <= 0
        if crossed_pole or j_next * j <= 0 or b_next * b < 0:
            break

        J.append(j_next)
        B.append(b_next)
        D.append(cutoff_next)
        j, b, cutoff = j_next, b_next, cutoff_next

    if plot:
        scaled_D = np.array(D)/D0
        plt.plot(scaled_D, J,'o-')
        plt.xlabel(r'$\leftarrow RG flow$')
        plt.ylabel(r'$J \rightarrow$')
        plt.title(r'$J_0={J0}, B_0={B0}, D_0={D0}$'.format(J0=J0,B0=B0,D0=D0))
        plt.savefig("R-J ={J0},B ={B0},D ={D0}_J.pdf".format(J0= J0, B0 =B0, D0=D0), bbox_inches='tight')
        plt.show()
        plt.plot(scaled_D, np.array(B), '-', c='c')
        plt.xlabel(r'$ \leftarrow RG flow$')
        plt.ylabel(r'$ B \rightarrow$')
        plt.title(r'$J_0={J0}, B_0={B0}, D_0={D0}$'.format(J0=J0,B0=B0,D0=D0))
        plt.savefig("R-J ={J0},B ={B0},D ={D0}_B.pdf".format(J0= J0, B0 =B0, D0=D0), bbox_inches='tight')
        plt.show()
    return np.flip(D), np.flip(J), np.flip(B)




def getBasis(num_levels, nTot=-1):
    """Enumerate the classical occupation states of ``num_levels`` single-particle levels.

    Each state is a string with one '0' (empty) or '1' (occupied) character per
    level, e.g. '0000', '0001', ..., '1111', listed in binary-counting order.

    Args:
        num_levels: number of single-particle levels (qubits).
        nTot: if not -1, keep only states with exactly ``nTot`` occupied levels.

    Returns:
        list of occupation strings.
    """
    all_states = ("".join(bits) for bits in itertools.product("01", repeat=num_levels))
    if nTot == -1:
        return list(all_states)
    return [state for state in all_states if state.count("1") == nTot]


def applyTermOnBasisState(bstate, int_kind, site_indices):
    """Apply a single operator string (no sums) to one basis state.

    ``int_kind`` is a string over the characters '+', '-', 'n', 'h' (create,
    annihilate, number, hole) and ``site_indices`` gives, for the n-th character,
    the site it acts on; e.g. ('+-', [0, 1]) is c^dag_0 c_1. The characters are
    applied right to left. Creation/annihilation picks up the fermionic sign
    (-1)**(number of occupied sites before the acted-on site) and flips that site.

    Returns:
        (new_state, coefficient); the coefficient is 0 when the operator kills the state.
    """
    allowed_ops = ['+', '-', 'n', 'h']
    assert all(op in allowed_ops for op in int_kind), "Interaction type not among +, - or n."
    assert len(int_kind) == len(site_indices), "Number of site indices in term does not match number of provided interaction types."

    coeff = 1
    # The right-most operator acts first.
    for op, site in reversed(list(zip(int_kind, site_indices))):
        occupancy = int(bstate[site])
        if op == "n":
            coeff *= occupancy
            continue
        if op == "h":
            coeff *= 1 - occupancy
            continue

        # Pauli blocking: cannot create on an occupied site or annihilate an empty one.
        blocked = (op == "+" and occupancy == 1) or (op == "-" and occupancy == 0)
        if blocked:
            coeff *= 0
            continue

        coeff *= (-1) ** sum(map(int, bstate[:site]))
        bstate = bstate[:site] + str(1 - occupancy) + bstate[site + 1:]

    return bstate, coeff


def getOperator(manyBodyBasis, int_kind, site_indices):
    """ Constructs a matrix operator given a prescription.
    manyBodyBasis is the list of all classical basis states (strings).
    int_kind is a string that defines the qubit operators taking
    part in the operator. For e.g., '+-' means 'c^dag c'.
    site_indices is a list that defines the indices of the states
    on whom the operators act. For e.g., [0,1] means the operator
    is c^dag_0 c_1.

    Returns a (len(manyBodyBasis) x len(manyBodyBasis)) np.longdouble matrix whose
    entry [end, start] is the coefficient of |end> in O|start>.
    """

    assert isinstance(manyBodyBasis, list)
    assert False not in [isinstance(item, str) for item in manyBodyBasis]
    # check that each operator in int_kind is from the set {+,-,n,h}, since these are the only ones we handle right now.
    assert isinstance(int_kind, str)
    assert False not in [k in ['+', '-', 'n', 'h'] for k in int_kind], "Interaction type not among +, - or n."
    # check that the number of qubit operators in int_kind matches the number of provided indices.
    assert isinstance(site_indices, list)
    assert False not in [isinstance(index, int) for index in site_indices]
    assert len(int_kind) == len(site_indices), "Number of site indices in term does not match number of provided interaction types."

    dim = len(manyBodyBasis)
    # initialises a zero matrix
    operator = np.zeros([dim, dim], dtype=np.longdouble)

    # Position of each basis string (first occurrence, matching list.index), so that
    # looking up the image of every state is O(1) instead of an O(n) list scan.
    position = {}
    for idx, state in enumerate(manyBodyBasis):
        position.setdefault(state, idx)

    # Each basis state |b1> contributes one column: O|b1> = coeff |b2>.
    for start_index, start_state in enumerate(manyBodyBasis):

        # get the action of 'int_kind' on the state b1
        end_state, mat_ele = applyTermOnBasisState(start_state, int_kind, site_indices)

        end_index = position.get(end_state)
        if end_index is not None:
            operator[end_index][start_index] = mat_ele
    return operator


def fermionicHamiltonian(manyBodyBasis, terms_list):
    """ Creates a matrix Hamiltonian from the specification provided in terms_list. terms_list is a dictionary
    of the form {'+-': [[1.1, [0,1]], [0.9, [1,2]], [2, [3,1]]], 'n': [[1, [0]], [0.5, [1]], [1.2, [2]], [2, [3]]]}.
    Each key is an operator string (see applyTermOnBasisState) such as '+-' (c^dag c) or 'n'. The value
    associated with that key is a list of pairs [g, [i_1, i_2, ...]], where the inner list gives the indices
    of the sites the operator acts on and the float g is the strength of that term in the Hamiltonian.
    For e.g., the first key-value pair represents 1.1c^dag_0 c_1 + 0.9c^dag_1 c_2 + 2c^dag_3 c_1, while
    the second pair represents 1n_0 + 0.5n_1 + 1.2n_2 + 2n_3.

    Returns a dense float64 matrix of shape (len(manyBodyBasis), len(manyBodyBasis)).
    """

    # initialise a zero matrix
    hamlt = np.zeros([len(manyBodyBasis), len(manyBodyBasis)])

    # loop over all keys of the dictionary, equivalent to looping over various terms of the Hamiltonian
    for int_kind, val in terms_list.items():

        # Accumulate this interaction's terms one operator at a time instead of first
        # materialising a list of dense matrices; the addition order matches sum().
        term_sum = 0
        for coupling, site_indices in val:
            term_sum = term_sum + coupling * getOperator(manyBodyBasis, int_kind, site_indices)
        hamlt += term_sum
    return np.array(hamlt)


def get_eSIAMHamiltonian(manyBodyBasis, num_bath_sites, couplings):
    """ Gives the string-based prescription to obtain a eSIAM Hamiltonian:
    H = sum_k Ek n_ksigma + hop_strength sum_ksigma c^dag_ksigma c_dsigma + hc 
        + imp_Ed sum_sigma n_dsigma + imp_U n_dup n_ddn + kondo J sum_12 vec S_d dot vec S_{12}
        + zerothsite_U n_{k1 up} n_{k1 dn}
    The coupling argument is the list [Ek, hop_strength, imp_U, imp_Ed, kondo_J, zerothsite_U],
    where Ek holds one energy per bath site. Site indices 0 and 1 are the impurity up/down
    orbitals; bath site k (k = 1..num_bath_sites) uses indices 2k (up) and 2k+1 (down).
    """

    Ek, hop_strength, imp_U, imp_Ed, kondo_J, zerothsite_U = couplings
    # ensure the number of terms in the kinetic energy is equal to the number of bath sites provided
    assert len(Ek) == num_bath_sites

    # adjust dispersion to make room for spin degeneracy: (Ek1, Ek2) --> (Ek1,  Ek1,  Ek2,  Ek2)
    #                                                      k1   k2            k1up  k1dn  k2up  k2dn
    Ek = np.repeat(Ek, 2)

    # create kinetic energy term, by looping over all bath site indices 2,3,...,2*num_bath_sites+1,
    # where 0 and 1 are reserved for the impurity orbitals and must therefore be skipped.
    ham_KE = fermionicHamiltonian(manyBodyBasis, {'n': [[Ek[i - 2], [i]] for i in range(2, 2 * num_bath_sites + 2)]})

    # create the impurity-bath hopping terms, by looping over the up orbital indices i = 2, 4, 6, ..., 2*num_bath_sites,
    # and obtaining the corresponding down orbital index as i + 1. The four terms are c^dag_dup c_kup, h.c., c^dag_ddn c_kdn, h.c.
    up_sites = range(2, 2 * num_bath_sites + 2, 2)
    ham_hop = (fermionicHamiltonian(manyBodyBasis, {'+-': [[hop_strength, [0, i]] for i in up_sites]})
               + fermionicHamiltonian(manyBodyBasis, {'+-': [[hop_strength, [i, 0]] for i in up_sites]})
               + fermionicHamiltonian(manyBodyBasis, {'+-': [[hop_strength, [1, i + 1]] for i in up_sites]})
               + fermionicHamiltonian(manyBodyBasis, {'+-': [[hop_strength, [i + 1, 1]] for i in up_sites]})
              )

    # create the impurity local terms for Ed, U
    ham_imp = (fermionicHamiltonian(manyBodyBasis, {'n': [[imp_Ed, [0]], [imp_Ed, [1]]]})
               + fermionicHamiltonian(manyBodyBasis, {'nn': [[imp_U, [0, 1]]]})
              )

    # Sd^z Skz-type terms as 'n+-' strings over all bath pairs (k1, k2):
    #   +J n_dup c^dag_{k1 up} c_{k2 up},  -J n_dup c^dag_{k1 dn} c_{k2 dn},
    #   -J n_ddn c^dag_{k1 up} c_{k2 up},  +J n_ddn c^dag_{k1 dn} c_{k2 dn}.
    k_pairs = list(itertools.product(range(1, num_bath_sites + 1), repeat=2))
    zz_terms = ([[kondo_J, [0, 2 * k1, 2 * k2]] for k1, k2 in k_pairs]
                + [[-kondo_J, [0, 2 * k1 + 1, 2 * k2 + 1]] for k1, k2 in k_pairs]
                + [[-kondo_J, [1, 2 * k1, 2 * k2]] for k1, k2 in k_pairs]
                + [[kondo_J, [1, 2 * k1 + 1, 2 * k2 + 1]] for k1, k2 in k_pairs]
               )
    Ham_zz = 0.25 * fermionicHamiltonian(manyBodyBasis, {'n+-': zz_terms})
    # c^dag_dup c_ddn c^dag_{k1 dn} c_{k2 up}; its Hermitian conjugate is added on return.
    Ham_plus_minus = 0.5 * (fermionicHamiltonian(manyBodyBasis, {'+-+-': [[kondo_J, [0, 1, 2 * k1 + 1, 2 * k2]] for k1, k2 in k_pairs]}))

    # local repulsion on the first bath site (indices 2, 3)
    ham_zerothsite = fermionicHamiltonian(manyBodyBasis, {'nn': [[zerothsite_U, [2, 3]]]})

    return ham_KE + ham_hop + ham_imp + Ham_zz + Ham_plus_minus + np.conj(np.transpose(Ham_plus_minus)) + ham_zerothsite

def get_computational_coefficients(basis, state):
    """Pair every basis state with its amplitude in ``state``.

    ``basis`` and ``state`` must have the same length; the result maps
    ``basis[i] -> state[i]`` (a later duplicate basis label overwrites an earlier one).
    """
    assert len(basis) == len(state)
    return dict(zip(basis, state))

def diagonalise(basis, hamlt):
    """ Diagonalise the provided Hermitian Hamiltonian matrix.

    Returns (E, eigstates): the eigenvalues from scipy.linalg.eigh (ascending order) and,
    for each eigenvector, a dict mapping basis state -> amplitude.
    """

    E, v = scipy.linalg.eigh(hamlt)
    # Building each dict is cheap pure-Python work that holds the GIL, so the former
    # ThreadPool only added thread start-up overhead; a plain comprehension is equivalent.
    eigstates = [get_computational_coefficients(basis, v[:, i]) for i in range(len(E))]
    return E, eigstates

def init_wavefunction(hamlt, mb_basis, displayGstate=False):
    """ Diagonalises the Hamiltonian provided as argument and returns its full spectrum.

    Args:
        hamlt: Hamiltonian matrix in the basis ``mb_basis``.
        mb_basis: list of basis-state strings labelling the rows/columns of ``hamlt``.
        displayGstate: if True, print ``visualise_state(mb_basis, <lowest eigenstate>)``.
            NOTE(review): ``visualise_state`` is not defined in this section; presumably
            defined in another notebook cell.

    Returns:
        (eigvals, eigstates) as returned by ``diagonalise``. No IOMS are taken into
        account at this point. A degenerate ground state is allowed (callers such as
        get_Spectral_function handle the degeneracy themselves).
    """
    import warnings

    eigvals, eigstates = diagonalise(mb_basis, hamlt)

    if displayGstate:
        # Previously a degenerate ground state left `gstate` unbound here (UnboundLocalError);
        # the old assertion was disabled, so warn and show the lowest eigenstate instead.
        rounded = np.round(eigvals, 10)
        if np.count_nonzero(rounded == rounded.min()) > 1:
            warnings.warn("Ground state is degenerate! No SU(2)-symmetric ground state exists.")
        print (visualise_state(mb_basis, eigstates[0]))

    return eigvals, eigstates

def applyOperatorOnState(initialState, terms_list, finalState=None, tqdmDesc=None):
    """ Applies a general operator on a general state.

    Args:
        initialState: dict mapping basis-state string -> amplitude.
        terms_list: operator specification in the same format as the terms_list argument of
            fermionicHamiltonian, e.g. {'+-': [[0.5, [0, 1]], [0.4, [1, 2]]]}.
        finalState: optional dict to accumulate the result into (it is modified in place
            and returned). A fresh dict is used when omitted.
        tqdmDesc: accepted for backward compatibility; no progress bar is shown.

    Returns:
        dict mapping basis-state string -> amplitude of O|initialState> (plus any
        amplitudes already present in ``finalState``); zero contributions are skipped.
    """
    # A fresh dict per call: the former `finalState=dict()` default was shared between
    # calls, so amplitudes leaked from one call into the next.
    if finalState is None:
        finalState = dict()

    # loop over all basis states for the given state, to see how the operator acts 
    # on each such basis state
    for bstate, coeff in initialState.items():

        # loop over each term (for eg the list [[0.5,[0,1]], [0.4,[1,2]]]) in the full interaction,
        # so that we can apply each such chunk to each basis state.
        for int_kind, val in terms_list.items():

            # coupling takes the values 0.5 and 0.4 in the above example, while site_indices
            # takes the values [0,1] and [1,2].
            for coupling, site_indices in val:

                # apply each such operator chunk to each basis state
                mod_bstate, mod_coeff = applyTermOnBasisState(bstate, int_kind, site_indices)

                # multiply this result with the coupling strength and any coefficient associated 
                # with the initial state
                mod_coeff *= coeff * coupling

                if mod_coeff != 0:
                    finalState[mod_bstate] = finalState.get(mod_bstate, 0) + mod_coeff

    return finalState

def innerProduct(state2, state1):
    """Return the overlap <state2 | state1> of two dict-encoded states.

    Only basis states present in both dicts contribute; the bra amplitudes are
    complex-conjugated. Returns 0 when the states share no basis state.
    """
    overlap = 0
    for bstate, ket_amp in state1.items():
        if bstate not in state2:
            continue
        overlap = overlap + np.conjugate(state2[bstate]) * ket_amp
    return overlap

def matrixElement(finalState, operator, initState):
    """Return <finalState | operator | initState>.

    ``operator`` uses the terms_list format of applyOperatorOnState; both states
    are dicts mapping basis-state strings to amplitudes.
    """
    return innerProduct(finalState, applyOperatorOnState(initState, operator, finalState={}))

def _spectral_weight(operator, gstates, gs_energies, allstates, all_energies,
                     bandwidth, broadening, degeneracy):
    """Gaussian-broadened spectral weight of ``operator``, averaged over the ground states.

    For each ground state |g_i> (energy ``gs_energies[i]``) and each eigenstate |n>
    (energy ``all_energies[n]``), the squared matrix elements <g_i|O|n>**2 and
    <n|O|g_i>**2 are placed as unit-area Gaussians of width ``broadening(w)`` centred at
    w = E_n - E_0 and w = E_0 - E_n respectively, scaled by broadening(w)/pi/degeneracy.

    Returns an array with one entry per point of ``bandwidth`` (or 0 if there are no
    ground states).
    """
    omega = np.asarray(bandwidth, dtype=float)
    eta = np.broadcast_to(broadening(omega), omega.shape)
    energies = np.asarray(all_energies)
    # Rows of the Gaussian matrices index omega, columns index eigenstates.
    gauss_norm = (1 / (np.abs(eta) * np.sqrt(np.pi)))[:, None]
    width = eta[:, None]
    prefactor = (1 / np.pi) * (1 / degeneracy) * eta

    weight = 0
    for gstate, e0 in zip(gstates, gs_energies):
        coeffs1 = np.array([matrixElement(gstate, operator, s) ** 2 for s in allstates])
        coeffs2 = np.array([matrixElement(s, operator, gstate) ** 2 for s in allstates])
        gauss1 = gauss_norm * np.exp(-((omega[:, None] + e0 - energies[None, :]) / width) ** 2)
        gauss2 = gauss_norm * np.exp(-((omega[:, None] - e0 + energies[None, :]) / width) ** 2)
        weight = weight + prefactor * (gauss1 @ coeffs1 + gauss2 @ coeffs2)
    return weight


def get_Spectral_function(J0, B0, D0, num_entangled, hamiltonianFunc, bandwidth):
    """Impurity spectral function A(w) accumulated along the RG flow of (J, B).

    Runs B_RG_flow (step 0.1, with plotting enabled). At every point of the flow it
    builds a zero-bandwidth Hamiltonian via ``hamiltonianFunc(basis, num_entangled,
    [Ek, J, B])``, where the ``num_entangled`` bath energies Ek are spread linearly over
    [-spacing, spacing], diagonalises it, and adds the broadened spectral weights of the
    "up" and "down" operators below. Each of the two accumulated weights is normalised to
    unit area over ``bandwidth``.

    Args:
        J0, B0, D0: initial coupling, field and bandwidth passed to B_RG_flow.
        num_entangled: number of bath levels in the zero-bandwidth model.
        hamiltonianFunc: callable ``(basis, num_entangled, [Ek, J, B]) -> matrix``.
        bandwidth: 1-D array of frequencies w at which A(w) is evaluated.

    Returns:
        A_up(w) + A_down(w) as a 1-D array.
    """
    deltaD = 0.1
    D, J, B = B_RG_flow(J0, B0, D0, d=deltaD, plot = True)

    # get the basis of all classical states: 2 impurity orbitals + 2 per bath level.
    mb_basis = getBasis(2 * (1 + num_entangled))

    # Bath level spread, linearly from deltaD to D[-1], one value per RG step.
    E_k_spacing = np.linspace(deltaD, D[-1], len(D))

    # Gaussian width; with b = 0 the |w|**6 correction is disabled and the width is constant.
    a = 5.485
    b = 0 #(1.3/(max(bandwidth)))**6

    def broadening(w):
        return a - b * (abs(w))**6

    # Operators whose spectral weight is computed; they do not depend on the RG step,
    # so they are built once. In applyTermOnBasisState notation:
    #   up:   sum_k c^dag_1 c_0 c_{2k+1} + 0.5 n_0 c_{2k} - 0.5 n_1 c_{2k}
    #   down: sum_k c^dag_0 c_1 c_{2k}   - 0.5 n_0 c_{2k+1} + 0.5 n_1 c_{2k+1}
    ks = range(1, num_entangled + 1)
    operator_up = dict()
    operator_up['+--'] = [[1, [1, 0, 2 * k + 1]] for k in ks]
    operator_up['n-'] = [[0.5, [0, 2 * k]] for k in ks] + [[-0.5, [1, 2 * k]] for k in ks]
    operator_down = dict()
    operator_down['+--'] = [[1, [0, 1, 2 * k]] for k in ks]
    operator_down['n-'] = [[-0.5, [0, 2 * k + 1]] for k in ks] + [[0.5, [1, 2 * k + 1]] for k in ks]

    A_UP = 0
    A_DOWN = 0
    for E_k_spacing_i, J_i, B_i in tqdm(zip(E_k_spacing, J, B)):
        Ek_i = np.linspace(-E_k_spacing_i, E_k_spacing_i, num_entangled)

        # obtain the zero-bandwidth Hamiltonian at this RG step
        hamlt = hamiltonianFunc(mb_basis, num_entangled, [Ek_i, J_i, B_i])
        eigvals, eigstates = init_wavefunction(hamlt, mb_basis)

        # Ground-state degeneracy: eigenvalues equal to the minimum after rounding to
        # 8 decimals. Eigenvalues are ascending, so the ground states come first.
        rounded = np.round(eigvals, 8)
        degeneracy = int(np.count_nonzero(rounded == rounded.min()))
        gstates = eigstates[:degeneracy]
        E_0 = eigvals[:degeneracy]

        A_UP = A_UP + _spectral_weight(operator_up, gstates, E_0, eigstates, eigvals,
                                       bandwidth, broadening, degeneracy)
        A_DOWN = A_DOWN + _spectral_weight(operator_down, gstates, E_0, eigstates, eigvals,
                                           bandwidth, broadening, degeneracy)

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (trapz deprecated); support both.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    A_UP = A_UP / integrate(A_UP, bandwidth)
    A_DOWN = A_DOWN / integrate(A_DOWN, bandwidth)

    A = A_UP + A_DOWN
    return A
    
def get_Spectral_function_plot(Bs, bandwidth, num_entangled=3, J0=0.2, D0=1):
    """Compute and plot the spectral function for every field value in ``Bs``.

    For each B in ``Bs`` this evaluates ``get_Spectral_function(J0, B, D0, num_entangled,
    getKondoHamiltonian, bandwidth)``. NOTE(review): ``getKondoHamiltonian`` is not defined
    in this section; presumably it comes from another notebook cell. All curves are drawn
    on one figure, saved to "Spectral_Function_RGflow_Gaussian.pdf" and shown.

    Args:
        Bs: sequence of field values, one curve each.
        bandwidth: frequency grid, passed through and used as the x-axis.
        num_entangled, J0, D0: model parameters; the defaults reproduce the previously
            hard-coded values 3, 0.2 and 1.
    """
    spectra = [get_Spectral_function(J0, B_i, D0, num_entangled, getKondoHamiltonian, bandwidth)
               for B_i in Bs]
    plt.show()

    for B_i, A_i in zip(Bs, spectra):
        plt.plot(bandwidth, A_i, label = r'$B = {B}$'.format(B = B_i))
    if spectra:
        # Axis decoration and legend only need to be applied once, after all curves.
        plt.xlabel(r"$\omega \rightarrow$")
        plt.ylabel(r"$A(\omega,B)\rightarrow$")
        plt.axvline(x=0, c= 'k', lw=1)
        plt.legend( loc ="upper left", fontsize="25")
        # plt.yscale("log")
    plt.savefig("Spectral_Function_RGflow_Gaussian.pdf" , bbox_inches='tight')
    plt.show()
    return
                 
# Field values B = 0, 1, ..., 8 and a 2000-point frequency grid on [-25, 25].
Bs = list(range(9))
bandPoints = 2000
bandwidth = np.linspace(-25, 25, num=bandPoints)
get_Spectral_function_plot(Bs, bandwidth)