# In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt

def k_path(t, high_sym_points, twist_angle=None):
    """
    Map parameter t ∈ [0, 3] to k-points along the path K → Γ → M → K'.

    Each unit interval of ``t`` is a linear interpolation between two
    consecutive high-symmetry points: [0, 1) covers K → Γ, [1, 2) covers
    Γ → M, and every other value (nominally [2, 3]) covers M → K'.

    Parameters:
    -----------
    t : float
        Path parameter (0 ≤ t ≤ 3). Values outside this range are not
        rejected; they fall into the M → K' branch and extrapolate along it.
    high_sym_points : dict
        Dictionary of high-symmetry points. Accepted so existing callers keep
        working, but not read: the points are recomputed from the angle below.
    twist_angle : float, optional
        Angle in radians from which the high-symmetry points are built.
        When omitted, the module-level ``theta`` is used, exactly as before.
        NOTE(review): presumably this is the same twist angle that
        compute_band_structure receives as ``theta`` — confirm.

    Returns:
    --------
    k_vec : ndarray
        Momentum vector (two components) at parameter t

    Raises:
    -------
    NameError
        If ``twist_angle`` is omitted and no module-level ``theta`` exists.
    """
    # Fall back to the module-level ``theta`` only when no angle is supplied;
    # the conditional expression never touches the global otherwise.
    angle = theta if twist_angle is None else twist_angle

    # High-symmetry points
    dkx = -(4 * np.pi / 3) * math.cos(angle / 2)
    dky = -(4 * np.pi / 3) * math.sin(angle / 2)

    K = np.array([dkx, -dky])
    # The Γ point was previously bound to ``G`` but read back as ``Γ``, which
    # raised NameError on the first two segments; one name is used throughout.
    G = np.array([dkx + np.sqrt(3) * dky, 0])
    M = np.array([dkx, 0])
    Kp = np.array([dkx, dky])

    if 0 <= t < 1:
        # K → Γ segment
        return K + t * (G - K)
    elif 1 <= t < 2:
        # Γ → M segment
        return G + (t - 1) * (M - G)
    else:
        # M → K' segment (also reached for t == 3 and out-of-range t)
        return M + (t - 2) * (Kp - M)
        

def compute_band_structure(nx, ny, delta, theta, hamiltonian_func,
                           n_points=200, bands_to_plot=None):
    """
    Diagonalise the Hamiltonian along the high-symmetry path and collect
    a window of its eigenvalues.

    The path parameter runs over ``n_points`` evenly spaced values in
    [0, 3]; each value is turned into a k-vector with ``k_path`` and the
    Hamiltonian returned by ``hamiltonian_func`` is diagonalised there.

    Parameters:
    -----------
    nx, ny : int
        Lattice dimensions
    delta : float
        Sublattice potential (eV)
    theta : float
        Twist angle in radians; forwarded to get_high_symmetry_points
    hamiltonian_func : callable
        Called as hamiltonian_func(nx, ny, delta, k_vec); must return a
        Hermitian matrix
    n_points : int
        Number of samples of the path parameter
    bands_to_plot : tuple or None
        Inclusive (band_min, band_max) eigenvalue indices to keep. When
        None, 8 indices centred on half of 4 * (nx + 1) * (ny + 1) are used
        (that product is taken to be the total number of bands).

    Returns:
    --------
    T : ndarray
        Path parameter values
    energies : ndarray
        Selected eigenvalues, shape (n_bands, n_points)
    """
    sym_points = get_high_symmetry_points(theta)
    T = np.linspace(0, 3, n_points)

    # Half of the assumed band count marks the centre of the default window.
    half = (4 * (nx + 1) * (ny + 1)) // 2

    if bands_to_plot is not None:
        first, last = bands_to_plot
    else:
        # Four indices below the centre and four from the centre upwards.
        first, last = half - 4, half + 3

    energies = np.zeros((last - first + 1, n_points))
    window = slice(first, last + 1)  # inclusive of ``last``

    for col, t in enumerate(T):
        H = hamiltonian_func(nx, ny, delta, k_path(t, sym_points))
        # eigvalsh returns the eigenvalues of a Hermitian matrix in
        # ascending order, so the window picks a fixed set of bands.
        energies[:, col] = np.linalg.eigvalsh(H)[window]

    return T, energies


def plot_band_structure(T, energies, energy_unit='meV', ylim=(-60, 60),
                       figsize=(8, 6), color='blue', alpha=0.85):
    """
    Draw every band against the path parameter and show the figure.

    Parameters:
    -----------
    T : ndarray
        Path parameter values (x-axis)
    energies : ndarray
        Energy values in eV, shape (n_bands, n_points); one curve is drawn
        per row
    energy_unit : str
        'eV' or 'meV'. Only 'meV' triggers a conversion (×1000); any other
        value leaves the energies unscaled and is used verbatim in the
        axis label.
    ylim : tuple
        Y-axis limits, in the displayed unit
    figsize : tuple
        Figure size
    color : str
        Line color
    alpha : float
        Line transparency
    """
    plt.figure(figsize=figsize)

    # Input is in eV; multiply only when meV output is requested.
    if energy_unit == 'meV':
        factor = 1000
    else:
        factor = 1

    for curve in energies:
        plt.plot(T, curve * factor, color=color, alpha=alpha, linewidth=1.5)

    # Axis label and ranges
    plt.ylabel(f'Energy ({energy_unit})', fontsize=12)
    plt.ylim(ylim)
    plt.xlim(0, 3)

    # Path-parameter positions of the high-symmetry points and their labels
    sym_ticks = {0: 'K', 1: 'Γ', 2: 'M', 3: "K'"}
    plt.xticks(list(sym_ticks), list(sym_ticks.values()), fontsize=12)

    # Faint vertical guide at each high-symmetry point
    for pos in sym_ticks:
        plt.axvline(pos, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)

    # Zero-energy reference line
    plt.axhline(0, color='black', linestyle='-', linewidth=0.8, alpha=0.7)

    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

