# In [None]:
import numpy as np


def get_band_energy(nx, ny, kx, ky, hz, V, band_index):
    """
    Return the energy of a single band at the k-point (kx, ky).

    The Hamiltonian is built by ``build_hamiltonian`` (defined elsewhere)
    and diagonalized with ``numpy.linalg.eigvalsh``. That routine treats the
    matrix as Hermitian and returns its eigenvalues in ascending order.

    Parameters
    ----------
    nx, ny : int
        Passed through unchanged to ``build_hamiltonian``.
    kx, ky : float
        Components of the k-point.
    hz : float
        Passed through unchanged to ``build_hamiltonian``.
    V : float
        Passed through unchanged to ``build_hamiltonian``.
    band_index : int
        Index into the ascending list of eigenvalues (0 = lowest energy band;
        negative indices count from the top band).

    Returns
    -------
    energy : float
        Eigenvalue of the requested band. Only the energy is returned (no
        eigenvector), so callers must not unpack the result as a tuple.
    """
    hamiltonian = build_hamiltonian(nx, ny, kx, ky, hz, V)
    # eigvalsh returns only eigenvalues, sorted in ascending order.
    energies = np.linalg.eigvalsh(hamiltonian)
    return energies[band_index]


def compute_gradient(nx, ny, kx, ky, hz, V, band_index, dk=1e-5):
    """
    Compute the gradient of a band energy with respect to kx and ky.

    Uses central finite differences, E'(k) ~ (E(k+dk) - E(k-dk)) / (2 dk).
    Their truncation error is O(dk**2), compared with O(dk) for forward
    differences.

    Parameters
    ----------
    nx, ny, hz, V :
        Passed through unchanged to ``get_band_energy``.
    kx, ky : float
        k-point at which the gradient is evaluated.
    band_index : int
        Band index forwarded to ``get_band_energy`` (0 = lowest band).
    dk : float
        Small increment for the finite-difference calculation (default 1e-5).

    Returns
    -------
    grad : ndarray, shape (2,)
        Gradient vector [dE/dkx, dE/dky].
    """
    def energy(qx, qy):
        # get_band_energy returns a scalar eigenvalue, not an (E, psi) tuple.
        return get_band_energy(nx, ny, qx, qy, hz, V, band_index)

    dE_dkx = (energy(kx + dk, ky) - energy(kx - dk, ky)) / (2 * dk)
    dE_dky = (energy(kx, ky + dk) - energy(kx, ky - dk)) / (2 * dk)

    return np.array([dE_dkx, dE_dky])


def compute_hessian(nx, ny, kx, ky, hz, V, band_index, dk=1e-5):
    """
    Determinant of the 2x2 Hessian of a band energy in (kx, ky).

    Despite the function name, the result is the scalar
    ``E_xx * E_yy - E_xy**2``. The second derivatives are estimated with a
    9-point central-difference stencil of half-width ``dk`` centred on
    (kx, ky).

    Parameters
    ----------
    nx, ny, hz, V :
        Passed through unchanged to ``get_band_energy``.
    kx, ky : float
        Centre of the finite-difference stencil.
    band_index : int
        Band index forwarded to ``get_band_energy``.
    dk : float
        Finite-difference step (default 1e-5). The second-difference
        quotients divide by ``dk**2``, so very small steps amplify round-off
        in the eigenvalues.

    Returns
    -------
    hessian : float
        d²E/dkx² · d²E/dky² − (d²E/dkx dky)².
    """
    def energy_at(qx, qy):
        return get_band_energy(nx, ny, qx, qy, hz, V, band_index)

    x_up, x_dn = kx + dk, kx - dk
    y_up, y_dn = ky + dk, ky - dk

    # Sample the stencil: centre, axis neighbours, then the four diagonals.
    e_mid = energy_at(kx, ky)
    e_xup, e_xdn = energy_at(x_up, ky), energy_at(x_dn, ky)
    e_yup, e_ydn = energy_at(kx, y_up), energy_at(kx, y_dn)
    e_uu, e_ud = energy_at(x_up, y_up), energy_at(x_up, y_dn)
    e_du, e_dd = energy_at(x_dn, y_up), energy_at(x_dn, y_dn)

    step_sq = dk**2
    curv_xx = (e_xup - 2*e_mid + e_xdn) / step_sq
    curv_yy = (e_yup - 2*e_mid + e_ydn) / step_sq
    curv_xy = (e_uu - e_ud - e_du + e_dd) / (4 * step_sq)

    return curv_xx * curv_yy - curv_xy**2


# In [None]:
def get_Vcritical(nx, ny, hz, band_index, V_min=1.3, V_max=1.6, tol=1e-6,
                  n_scan=501, dk=1e-5):
    """
    Compute the critical value of V, for a given hz, at which the determinant
    of the Hessian of the band energy vanishes at the K valley.

    ``n_scan`` evenly spaced V values in [V_min, V_max] are scanned in order.
    The scan stops at the first grid interval where det(Hessian) changes
    sign, and the root inside that interval is refined with Brent's method.

    Parameters
    ----------
    nx, ny : int
        Forwarded to ``compute_hessian``.
    hz : float
        Forwarded to ``compute_hessian``.
    band_index : int
        Band whose dispersion is analysed (0 = lowest energy band).
    V_min, V_max : float
        Search range for V (default: 1.3 to 1.6).
    tol : float
        Absolute tolerance on V for the root refinement (default: 1e-6).
    n_scan : int
        Number of grid points used to bracket the root (default: 501).
    dk : float
        Finite-difference step forwarded to ``compute_hessian``
        (default: 1e-5).

    Returns
    -------
    V_crit : float
        Critical value of V where det(Hessian) = 0 at the K valley for the
        given hz. Returns ``nan`` if det(Hessian) neither changes sign nor
        hits exactly zero on the scan grid.

    Raises
    ------
    ValueError
        If ``n_scan < 2``.

    Notes
    -----
    The K valley point of the triangular lattice is at k = (4π/(3a), 0). The
    code uses k = (4π/3, 0), i.e. it works in units where a = 1.
    Per the original intent, a vanishing det(Hessian) marks a higher-order
    saddle point at K. A root where det(Hessian) only touches zero without
    changing sign is not detected by the sign-change scan.
    """
    from scipy.optimize import brentq

    if n_scan < 2:
        raise ValueError("n_scan must be at least 2 to bracket a root")

    kx_K, ky_K = 4 * np.pi / 3, 0.0  # K valley (lattice constant a = 1)

    def det_hessian(V):
        return compute_hessian(nx, ny, kx_K, ky_K, hz, V, band_index, dk=dk)

    V_grid = np.linspace(V_min, V_max, n_scan)

    V_prev = V_grid[0]
    det_prev = det_hessian(V_prev)
    if det_prev == 0:
        return float(V_prev)

    for V in V_grid[1:]:
        det_cur = det_hessian(V)
        if det_cur == 0:
            return float(V)
        # A strict sign change brackets a root. A NaN product compares False,
        # so non-finite evaluations are skipped instead of fed to brentq.
        if det_prev * det_cur < 0:
            return float(brentq(det_hessian, V_prev, V, xtol=tol))
        V_prev, det_prev = V, det_cur

    return np.nan
        
 def get_hovhs(nx, ny, band_index, V_min=1.3, V_max=1.6, hz_min = 0, hz_min = 1.0):
     """
    Computing the (hz,V) that corresponds to higher-
    order Van Hove singularity around the K valley point for a
    given band indicated by the band index.

    Paramaters
    ----------
    V_min, V_max : float
        Search range for V (default: 1.3 to 1.6).
    hz_min, hz_max: float
        Range for h values (default:0.0 to 1.0).

    Return
    ------
    hz_list: list
        hz values
    Vc_list: list
        Critical V values that correspond to HOVHS
     """

     #List of hz values within the given range.
     Nhz = 101; #Number of hz values in the given range
     hz_list = np.linspace(0,1,Nhz);

     #List of V_critical values
     Vc_list = []; 
     for hz in hz_list:
         Vc_list.append(get_Vcritical(nx, ny, hz, band_index, V_min=1.3, V_max=1.6, tol = 1e-6))

    return hz_list, Vc_list


def plot_hovhs(hz_list, Vc_list):
    """
    Plot the HOVHS line in the (hz, V) plane and show the figure.

    Parameters
    ----------
    hz_list : array_like
        hz values.
    Vc_list : array_like
        Critical V values that correspond to the HOVHS. ``None`` or NaN
        entries (no critical V found) are converted to NaN and appear as gaps
        in the curve.
    """
    # NOTE(review): `plt` (matplotlib.pyplot) is not imported in the visible
    # part of this file; confirm it is imported in an earlier cell.
    hz_values = np.asarray(hz_list, dtype=float)
    # dtype=float maps None -> nan so missing points do not break plotting.
    Vc_values = np.array(Vc_list, dtype=float)

    plt.plot(hz_values, Vc_values, 'k')

    # $...$ enables matplotlib mathtext; without it the raw LaTeX is printed.
    plt.xlabel(r"$h_z/\frac{v_F}{a}$")
    plt.ylabel(r"$V/\frac{v_F}{a}$")

    plt.show()